树链剖分总结

算法竞赛 数据结构-树链剖分
编辑文章

作为一个上个月刚学完线段树的蒟蒻,看Splay又看不懂,便直接跳着来学树剖了。又在一个博客上看到说学树剖之前最好还要把LCA给学了,便去花了一天学了一个Tarjan求LCA(然而后来发现并不怎么需要),然后是几乎照着别人的代码把树剖抄懂的。在这里我就讲一下我理解的树剖。


准备工作

  1. 链表/链式前向星
  2. 线段树/树状数组/Splay等可以维护一段数据的数据结构
  3. LCA (其实只涉及到一些思想,而且几乎用不到,可以先不学)

树剖原理

我很早之前就听说过树剖,当时觉得实在是太高大上了,但现在发现只是名字比较高端,整个原理还是很简单的,只是码量比较大。

作为蒟蒻,我在图论题中几乎都只会搜索,也靠着搜索在SCOI上拿了仅有的几十分。不过我好歹还是学过前缀和差分的,如果一个图中所有的点连成一条线,那么来找两个点之间路径长只需要维护前缀和就行了。如果一棵树中的结点都连成一条线,我们就把它称作链。如果把树上的结点分为若干条链,那么很多问题就可以变得简单多了。

所谓树链剖分就是将一棵树给剖分成若干条链,再分别处理。

当然,一个节点也是可以算一条链的,不过如果这样分还不如不分。树剖这个算法的目的,便是将一棵树中每一个非叶节点分到链中,并且每一个节点都只属于一条链,这样查询起来又可以快很多。

在这里给出一棵树(图源百度百科):

树剖

在这张图中,粗线即为分成的链。要让每一个非叶节点在链上,我们就需要让一条链尽可能覆盖更多节点。所以在每一个节点的子节点中,我们选以它为根的子树节点个数最多的子节点来连成链。

比如在4号节点的子节点{8,9,10}中,以8和以10为根的子树的节点总个数为1,而以9为根的子树节点个数为3,所以我们就将9作为链上的一个节点继续向下连接。9就被称为是4号节点的重节点,其他的两个节点就被称作轻节点

继续扩展,父节点和重节点间的连线被称作重边,就是粗线;父节点和轻结点的连线被称作轻边,就是图中的细线;多条重边连接起来的路径叫重链,如路径1->4->9->13->14;多条轻边连接起来的路径叫轻链,如路径1->2->5。

通过一个表格将这些定义总结一下

定义含义
重节点以它为根的子树节点个数最多的节点
轻节点所有子节点中不是重节点的节点
重边父节点和重节点间的连线
轻边父节点和轻结点的连线
重链多条重边连接起来的路径
轻链多条轻边连接起来的路径

实现

我们使用两次dfs就能实现剖分,但是只剖分的话是并没有什么卵用的,一般题目中还会涉及到两节点间的权值和,权值最大值等问题。这里以 洛谷P3384【模板】树链剖分 为例。

剖分

首先先解释下我使用的变量

变量名意义
fa[x]x号节点的父亲
son[x]x号节点的重儿子(节点)
size[x]以x号节点为根的子树中节点个数
deep[x]x号节点的深度
top[x]x号节点所在的链顶的节点编号
w[x]x号节点的原权值
wnew[x]dfs序中第x号节点的权值
id[x]x号节点的dfs序
edge[]和head[]链式前向星数组
tree[]线段树

先进行第一次dfs,需要完成的任务是

  1. 确定这个点的深度
  2. 确定父亲节点
  3. 确定以这个节点为根的子树中节点个数
  4. 确定这个点的重儿子

具体实现方式见代码

void dfs1(int x,int f,int depth)
{
    deep[x]=depth;//深度
    fa[x]=f;//父亲节点
    size[x]=1;//子树节点个数至少有一个
    int mx=-1;
    for (int i=head[x];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if (v==f) continue;//不搜索父节点
        dfs1(v,x,depth+1);
        size[x]+=size[v];//节点个数加上子节点的
        if (size[v]>mx) //更新重儿子
        {
            mx=size[v];
            son[x]=v;
        }
    }
}

然后是第二次dfs。需要完成的任务是

  1. 确定新编号(dfs序)
  2. 赋权值到新编号上
  3. 确定这个点所在的链的顶端
  4. 处理每一条链

还是看代码吧

void dfs2(int x,int topf)
{
    id[x]=++dfsord;//标记每一个节点的dfs序
    wnew[dfsord]=w[x];//得到新编号(dfs序)
    top[x]=topf;//得到这条链的顶端
    if (!son[x]) return;//无儿子返回
    dfs2(son[x],topf);//先处理重儿子
    for (int i=head[x];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if (v==fa[x] || v==son[x]) continue;
        dfs2(v,v);//如果是轻儿子,新的链一定以自己为顶端
    }
}
先处理重儿子是为了保证每一条链都是被连续处理的

好了,剖分就结束了,是不是很简单啊

处理问题

操作1,2

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

在处理u和v号节点之间路径上所有节点时,我们一般先让u和v属于同一条链,然后因为dfs序的特点就可以直接用线段树处理了。

让u和v顶端相同的方法是将u和v中较深的节点往上跳到这条链顶端的上方,跳完一个再交换跳下一个。每次更改或查询只要更改当前跳的节点到它所在链的顶端即可,最后到一条链上了以后直接处理两点之间就行了。

代码如下:

注:我的线段树是左闭右开区间,即表示[left,right),所以处理时右边要+1

操作1

void uprange(int u,int v,int delta)
{
    delta%=p;//按题意取%
    while (top[u]!=top[v])
    {
        if (deep[top[u]]<deep[top[v]]) swap(u,v);//交换为更深的点
        change(1,id[top[u]],id[u]+1,delta);
        u=fa[top[u]];//向上跳
    }
    if (deep[u]>deep[v]) swap(u,v);//交换为更深的点保证u的dfs序在前
    change(1,id[u],id[v]+1,delta);
}

操作2

int qrange(int u,int v)//求u到v节点的路径中节点之和
{
    int ans=0;
    while (top[u]!=top[v])//不停将u向上跳,直到在一条链上
    {
        if (deep[top[u]]<deep[top[v]]) swap(u,v);//交换成更深的点
        ans+=query(1,id[top[u]],id[u]+1);
        ans%=p;
        u=fa[top[u]];//向上跳
    }
    //已经在一条链上
    if (deep[u]>deep[v]) swap(u,v);
    ans+=query(1,id[u],id[v]+1);
    ans%=p;
    return ans;
}

操作3,4

可以根据dfs序的性质直到,子树区间右端点为id[x]+siz[x]-1,直接处理即可。

代码:

操作3

inline void upson(int u,int delta)
{
    change(1,id[u],id[u]+size[u],delta);
}

操作4

inline int qson(int u)
{
    return query(1,id[u],id[u]+size[u]);
}

其它细节

  1. 在一些没有指定根的问题中其实以任意节点为根都是可以的
  2. 根节点开始dfs时可以以0作为它的根,顶端就是它本身
  3. 一定要记得先dfs在建树,因为线段树是处理dfs序的
dfs1(root,0,1);
dfs2(root,root);
build(1,1,n+1);

完整代码

#include<bits/stdc++.h>
using namespace std;
struct Edge{
    int next,to;
} edge[200005];
int fa[200005],size[200005],deep[200005],w[200005],wnew[200005],head[200005],son[200005],id[200005],top[200005];
struct Tree{
    int left,right,sum,delta;
} tree[800005];
int cnt=1,ans,n,m,a,b,c,d,p,dfsord,root;
inline void add(int u,int v)
{
    edge[cnt].to=v;
    edge[cnt].next=head[u];
    head[u]=cnt++;
}
void build(int x,int l,int r)
{
    tree[x].left=l;
    tree[x].right=r;
    if (r-l==1) tree[x].sum=wnew[l];
    else
    {
        build(x*2,l,(l+r)/2);
        build(x*2+1,(l+r)/2,r);
        tree[x].sum=(tree[x*2].sum+tree[x*2+1].sum)%p;
    }
}
inline void update(int x)
{
    tree[x*2].sum+=tree[x].delta*(tree[x*2].right-tree[x*2].left);
    tree[x*2+1].sum+=tree[x].delta*(tree[x*2+1].right-tree[x*2+1].left);
    tree[x*2].sum%=p;
    tree[x*2+1].sum%=p;
    tree[x*2].delta+=tree[x].delta;
    tree[x*2+1].delta+=tree[x].delta;
    tree[x].delta=0;
}
void change(int x,int l,int r,int delta)
{
    if (l<=tree[x].left && r>=tree[x].right) 
    {
        tree[x].delta+=delta;
        tree[x].sum+=delta*(tree[x].right-tree[x].left);
        tree[x].sum%=p;
    }
    else
    {
        if (tree[x].delta!=0) update(x);
        if (l<(tree[x].left+tree[x].right)/2) change(x*2,l,r,delta);
        if (r>(tree[x].left+tree[x].right)/2) change(x*2+1,l,r,delta);
        tree[x].sum=(tree[x*2].sum+tree[x*2+1].sum)%p;
    }
}
int query(int x,int l,int r)
{
    if (l<=tree[x].left && r>=tree[x].right) return tree[x].sum%p;
    else
    {
        if (tree[x].delta!=0) update(x);
        int ans=0;
        if (l<(tree[x].left+tree[x].right)/2) ans+=query(x*2,l,r);
        if (r>(tree[x].left+tree[x].right)/2) ans+=query(x*2+1,l,r);
        return ans%p;
    }
}               
/*  dfs1
    标记每个点的深度dep[]   
    标记每个点的父亲fa[]
    标记每个非叶子节点的子树大小(含它自己)    
    标记每个非叶子节点的重儿子编号son[]
*/
void dfs1(int x,int f,int depth)
{
    deep[x]=depth;
    fa[x]=f;
    size[x]=1;
    int mx=-1;
    for (int i=head[x];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if (v==f) continue;//不搜索父节点
        dfs1(v,x,depth+1);
        size[x]+=size[v];
        if (size[v]>mx) //更新重儿子
        {
            mx=size[v];
            son[x]=v;
        }
    }
}
/*  dfs2
    标记每个点的新编号
    赋值每个点的初始值到新编号上
    处理每个点所在链的顶端
    处理每条链
*/
void dfs2(int x,int topf)
{
    id[x]=++dfsord;//标记每一个节点的dfs序
    wnew[dfsord]=w[x];//得到新编号(dfs序)
    top[x]=topf;//得到这条链的顶端
    if (!son[x]) return;//无儿子返回
    dfs2(son[x],topf);//先处理重儿子
    for (int i=head[x];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if (v==fa[x] || v==son[x]) continue;
        dfs2(v,v);//如果是轻儿子,新的链一定以自己为顶端
    }
}
int qrange(int u,int v)//求u到v节点的路径中节点之和
{
    int ans=0;
    while (top[u]!=top[v])//不停将u向上跳,直到在一条链上
    {
        if (deep[top[u]]<deep[top[v]]) swap(u,v);//交换成更深的点
        ans+=query(1,id[top[u]],id[u]+1);
        ans%=p;
        u=fa[top[u]];//向上跳
    }
    //已经在一条链上
    if (deep[u]>deep[v]) swap(u,v);
    ans+=query(1,id[u],id[v]+1);
    ans%=p;
    return ans;
}
void uprange(int u,int v,int delta)
{
    delta%=p;
    while (top[u]!=top[v])
    {
        if (deep[top[u]]<deep[top[v]]) swap(u,v);
        change(1,id[top[u]],id[u]+1,delta);
        u=fa[top[u]];
    }
    if (deep[u]>deep[v]) swap(u,v);
    change(1,id[u],id[v]+1,delta);
}
inline int qson(int u)
{
    return query(1,id[u],id[u]+size[u]);//子树区间右端点为id[x]+siz[x]-1 
}
inline void upson(int u,int delta)
{
    change(1,id[u],id[u]+size[u],delta);
}
int main()
{
    memset(head,-1,sizeof(head));
    scanf("%d%d%d%d",&n,&m,&root,&p);
    for (int i=1;i<=n;i++) scanf("%d",&w[i]);
    for (int i=1;i<n;i++)
    {
        scanf("%d%d",&a,&b);
        add(a,b);
        add(b,a);
    }
    dfs1(root,0,1);
    dfs2(root,root);
    build(1,1,n+1);
    for (int i=1;i<=m;i++)
    {
        scanf("%d",&a);
        if (a==1)
        {
            scanf("%d%d%d",&b,&c,&d);
            uprange(b,c,d);
        }
        if (a==2)
        {
            scanf("%d%d",&b,&c);
            printf("%d\n",qrange(b,c));
        }
        if (a==3)
        {
            scanf("%d%d",&b,&c);
            upson(b,c);
        }
        if (a==4)
        {
            scanf("%d",&b);
            printf("%d\n",qson(b));
        }
    }
    return 0;
}

相关例题

待更新……

新评论

称呼不能为空
邮箱格式不合法
网站格式不合法
内容不能为空