线段树

May 07th, 2019
  • 在其它设备中阅读本文章

Summary

线段树相关.

Links

两道模板题:

  1. [P3374 [模板]树状数组 1](https://www.luogu.org/problemnew/show/P3374)
  2. [P3372 [模板]线段树 1](https://www.luogu.org/problemnew/show/P3372)

Text

保存一颗线段树

线段树的长相大概是

这个样子, 像一颗以 $[1,10]$ 为根的二叉树.
仔细观察可以发现:

  • 这是一颗接近满二叉树的二叉树 (只有最后一层不满);
  • 非叶子节点 $[L,R]$ 的左儿子为 $[L,\lfloor\frac{L+R}{2}\rfloor]$, 右儿子 (如果有) 为 $[\lfloor\frac{L+R}{2}\rfloor+1,R]$;
  • 叶子节点 $[L,R]$ 有 $L=R$.

所以可以方便的使用一个数组(设为sum)来保存一颗线段树. 节点 $i$ 的左儿子为 $i*2$, 右儿子为 $i*2+1$. 节点 $[L,R]$ 本来需要保存以下信息 (记 $v_i$ 为第 $i$ 个位置上的值):

  • $L$,$R$;
  • $\sum_{i=L}^Rv_i$.

但是由之前观察得到的结论

非叶子节点 $[L,R]$ 的左儿子为 $[L,\lfloor\frac{L+R}{2}\rfloor]$, 右儿子 (如果有) 为 $[\lfloor\frac{L+R}{2}\rfloor+1,R]$.

可以看出一个节点的 $L$,$R$ 是可以由父节点的 $L$,$R$ 推出来的, 而根节点 $L=1$,$=n$ 已知, 所以只需要保存 $\sum_{i=L}^Rv_i$ 就可以了.
在建树时需要用到下面的代码, 解释会在后面给出.
Code:

void build(int l, int r, int rt)
{
    if (l == r)
    {
        scanf("%d", &sum[rt]);
        return;
    }
    int m = (l + r) >> 1;
    build(l, m, rt << 1);
    build(m + 1, r, rt << 1 | 1);
    pushup(rt);
}

.

线段树单点修改与区间查询

单点 $p$ 更新时可以看作更新区间 $[p,p]$, 设当前区间为 $[l,r]$,$m=\frac{l+r}{2}$, 则

  • 令 $l=1$,$r=n$;
  • 若当前 $l\not=r$

    • 当 $l\leq p\leq m$ 时,$[p,p]$ 在 $[l,r]$ 的左儿子 $[l,m]$ 内, 递归访问 $[l,m]$;
    • 否则 ($m<p\leq r$ 时)$[p,p]$ 在 $[l,r]$ 的右儿子 $[m+1,r]$ 内, 递归访问 $[m+1,r]$.
  • 否则当前 $l=r=p$;

然后使 $v_{[p,p]的编号}$ 加上所需的值就可以了.

Code:

void update(int p, int c, int l, int r, int rt)
{
    if (l == r)
    {
        sum[rt] += c;
        return;
    }
    int m = (l + r) >> 1;
    if (p <= m)
        update(p, c, l, m, rt << 1);
    else
        update(p, c, m + 1, r, rt << 1 | 1);
    pushup(rt);
}

, 注意到里面有一个pushup()函数, 作用等会再讲.

在求和时, 显然 $\sum_{i=L}^Rv_i=\sum_{j=L}^Mv_j+\sum_{k=M+1}^Rv_k$, 线段树正是基于此种思想来查询区间和的.

比如我要查询区间 $[2,8]$ 中所有数的和, 则需要查询的区间 (黄色) 为

, 但是由之前的结论, 其实只需要查询绿色区间即可(为什么)

. 所以接下来我们应该利用黄色区间维护绿色区间. 线段树 pushup 操作.
Code:

void pushup(int rt)
{
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

, 在每次单点更新后pushup()一下即可维护线段树.
现在考虑对 $[L,R]$ 求和, 则

  • 令 $l=1$,$r=n$;
  • 若 $L\leq l$,$R\geq r$($[l,r]$ 完全包含于 $[L,R]$), 则返回当前区间和;
  • 否则

    • 若 $L\leq\lfloor\frac{L+R}{2}\rfloor$(说明在左子树中还有一部分属于 $[L,R]$)则递归访问 $[l,\lfloor\frac{L+R}{2}\rfloor]$;
    • 若 $R>\lfloor\frac{L+R}{2}\rfloor$(说明在右子树中还有一部分属于 $[L,R]$)则递归访问 $[\lfloor\frac{L+R}{2}\rfloor+1,r]$;

Code:

int getsum(int L, int R, int l, int r, int rt)
{
    if (L <= l && r <= R)
        return sum[rt];
    int m = (l + r) >> 1, ret = 0;
    if (L <= m)
        ret += getsum(L, R, l, m, rt << 1);
    if (m < R)
        ret += getsum(L, R, m + 1, r, rt << 1 | 1);
    return ret;
}

.

分析一下建树的代码

建树的代码上面已经给过了, 现在来分析一下.

void build(int l, int r, int rt)
{

,emmm, 这个就算了吧.

    if (l == r)
    {
        scanf("%d", &sum[rt]);
        return;
    }

, 这一段就像单点更新, 找到了目标点则更新, 然后返回.

    int m = (l + r) >> 1;
    build(l, m, rt << 1);
    build(m + 1, r, rt << 1 | 1);

, 这一段表示如果还没有走到叶子节点则递归访问左右子树建树.

    pushup(rt);

, 在建好左右子树后更新上一层.

}

.

如何做到区间更新

简单嘛,for i:l~r update();.
显然这样会超时, 所以我们需要新的方法: 延迟更新.
具体来说, 我们在更新一个区间时不会去 真正地更新, 而是做一个标记(laz)代替. 比如: 更新 $[L,R]$, 将其中每个数加上 $c$, 当需要用到 整个 $[L,R]$ 区间 时, 这次更新会使区间和增加 $(R-L+1)*c$, 所以只需要一个标记记录 "这个区间中每个数增加了 c" 即可.
但是当我们需要用到 $[L,R]$ 当中的一部分 时则必须向下更新这个区间(没有在更新时立即向下更新从而实现了延迟), 即pushdown().
Code:

void pushdown(int rt, int m)
{
    if (!laz[rt])
        return;
    laz[rt << 1] += laz[rt];
    laz[rt << 1 | 1] += laz[rt];
    sum[rt << 1] += laz[rt] * (m - (m >> 1));
    sum[rt << 1 | 1] += laz[rt] * (m >> 1);
    laz[rt] = 0;
}

, 对于update()getsum()也要略作修改.
Code:

void update(int L, int R, int c, int l, int r, int rt)
{
    if (L <= l && r <= R)
    {
        laz[rt] += c;
        sum[rt] += c * (r - l + 1);
        return;
    }
    pushdown(rt, r - l + 1);
    int m = (l + r) >> 1;
    if (L <= m)
        update(L, R, c, l, m, rt << 1);
    if (m < R)
        update(L, R, c, m + 1, r, rt << 1 | 1);
    pushup(rt);
}
int getsum(int L, int R, int l, int r, int rt)
{
    if (L <= l && r <= R)
        return sum[rt];
    int m = (l + r) >> 1, ret = 0;
    pushdown(rt, r - l + 1);
    if (L <= m)
        ret += getsum(L, R, l, m, rt << 1);
    if (m < R)
        ret += getsum(L, R, m + 1, r, rt << 1 | 1);
    return ret;
}

完整代码

template <typename _Tp>
class SegmentTree
{
public:
    void Build(int TotalLenth, int (*ReadInitialValue)(), int LeftEnd = 1, int RightEnd = -1, int CurrentPoint = 1)
    {
        m_TotalLenth = TotalLenth;
        if (RightEnd == -1)
            RightEnd = m_TotalLenth;
        if (LeftEnd == RightEnd)
        {
            m_Summation[CurrentPoint] = ReadInitialValue();
            return;
        }
        int Midpoint = (LeftEnd + RightEnd) >> 1;
        Build(m_TotalLenth, ReadInitialValue, LeftEnd, Midpoint, CurrentPoint << 1);
        Build(m_TotalLenth, ReadInitialValue, Midpoint + 1, RightEnd, CurrentPoint << 1 | 1);
        __Pushup(CurrentPoint);
    }
    _Tp GetSum(_Tp InitialValue, int GetsumLeftEnd, int GetsumRightEnd, int LeftEnd = 1, int RightEnd = -1, int CurrentPoint = 1)
    {
        if (RightEnd == -1)
            RightEnd = m_TotalLenth;
        if (GetsumLeftEnd <= LeftEnd && RightEnd <= GetsumRightEnd)
            return InitialValue + m_Summation[CurrentPoint];
        int Midpoint = (LeftEnd + RightEnd) >> 1;
        _Tp ReturnValue = InitialValue;
        __Pushdown(CurrentPoint, RightEnd - LeftEnd + 1);
        if (GetsumLeftEnd <= Midpoint)
            ReturnValue += GetSum(InitialValue, GetsumLeftEnd, GetsumRightEnd, LeftEnd, Midpoint, CurrentPoint << 1);
        if (Midpoint < GetsumRightEnd)
            ReturnValue += GetSum(InitialValue, GetsumLeftEnd, GetsumRightEnd, Midpoint + 1, RightEnd, CurrentPoint << 1 | 1);
        return ReturnValue;
    }
    void Update(_Tp UpdateValue, int UpdateLeftEnd, int UpdateRightEnd = -1, int LeftEnd = 1, int RightEnd = -1, int CurrentPoint = 1)
    {
        if (UpdateRightEnd == -1)
            UpdateRightEnd = UpdateLeftEnd;
        if (RightEnd == -1)
            RightEnd = m_TotalLenth;
        if (UpdateLeftEnd <= LeftEnd && RightEnd <= UpdateRightEnd)
        {
            m_LazyTag[CurrentPoint] += UpdateValue;
            m_Summation[CurrentPoint] += UpdateValue * (RightEnd - LeftEnd + 1);
            return;
        }
        __Pushdown(CurrentPoint, RightEnd - LeftEnd + 1);
        int Midpoint = (LeftEnd + RightEnd) >> 1;
        if (UpdateLeftEnd <= Midpoint)
            Update(UpdateValue, UpdateLeftEnd, UpdateRightEnd, LeftEnd, Midpoint, CurrentPoint << 1);
        if (Midpoint < UpdateRightEnd)
            Update(UpdateValue, UpdateLeftEnd, UpdateRightEnd, Midpoint + 1, RightEnd, CurrentPoint << 1 | 1);
        __Pushup(CurrentPoint);
    }

private:
    _Tp m_LazyTag[400005], m_Summation[400005];
    int m_TotalLenth;
    void __Pushdown(int CurrentPoint, int Midpoint)
    {
        if (!m_LazyTag[CurrentPoint])
            return;
        m_LazyTag[CurrentPoint << 1] += m_LazyTag[CurrentPoint];
        m_LazyTag[CurrentPoint << 1 | 1] += m_LazyTag[CurrentPoint];
        m_Summation[CurrentPoint << 1] += m_LazyTag[CurrentPoint] * (Midpoint - (Midpoint >> 1));
        m_Summation[CurrentPoint << 1 | 1] += m_LazyTag[CurrentPoint] * (Midpoint >> 1);
        m_LazyTag[CurrentPoint] = 0;
    }
    void __Pushup(int CurrentPoint)
    {
        m_Summation[CurrentPoint] = m_Summation[CurrentPoint << 1] + m_Summation[CurrentPoint << 1 | 1];
    }
};

owo

mo-ha