点分治总结

@Llf  April 26, 2018

我自己感觉点分治是一种比较玄学神奇的算法,感觉自己也只是理解了一点皮毛而已。这里就谈谈点分治的实现方法和一些运用。


方法

点分治其实是将一棵树的点不断分为多棵子树,分别得到子树节点到子树根的距离来进行处理。

既然要有一个点为根节点,那么应该选取哪一个点呢?

既然要向下分治,如果把一个原来的叶节点作为分治的根节点,那么遍历每一个点的时间开销是非常大的。然而如果这个点的左右子树都达到最大,那么遍历的时间开销就要少很多。而子树的重心就能满足这个性质,所以我们选取子树的重心作为分治中心,即分治子树的根。

1.求重心

首先我们需要了解重心是什么

树的重心也叫树的质心。找到一个点,其所有的子树中最大的子树节点数最少,那么这个点就是这棵树的重心,删去重心后,生成的多棵树尽可能平衡。-- 百度百科

提取关键词:

  1. 其所有的子树中最大的子树节点数最少
  2. 生成的多棵树尽可能平衡

第1点告诉我们怎么求重心,只需要记录每一个点的子树大小,然后按着定义做就行了。

第2点告诉我们选择重心作为分治点的原因,因为多棵树尽可能平衡,就可以使复杂度变得最优秀。

这个不是重点,直接上代码(感觉和树剖差不多):

这里可以先不管vis[],后面会解释
sizenow是当前这个子树的大小,后面会讲到
mx是当前分治到的最大的子树节点最小值

void getroot(int x,int f)
{
  size[x]=1;
  son[x]=0;
  for (int i=head[x];i;i=edge[i].next)//遍历每一个子节点
  {
    int y=edge[i].to;
    if (y==f || vis[y]) continue;
    getroot(y,x);
    size[x]+=size[y];//加上子节点的子树大小
    son[x]=max(son[x],size[y]);//求最大的子树
  }
  son[x]=max(son[x],sizenow-size[x]);
  if (son[x]<mx)//使最大的子树节点最少
  {
    mx=son[x];
    root=x;//记录分治中心(根节点)
  }
}

2.进行分治

分治其实就是一个递归,不断地将子树进行处理就行了,注意要用vis[]标记已经访问防止重复。先上代码:

solve就是解决问题,每一道题有不同的作用,这里不用管。

void divide(int x)
{
  vis[x]=true;//标记已访问
  solve(x,0,1);
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y]) continue;
    solve(y,edge[i].w,0);
    mx=2147483646;root=0;sizenow=size[y];//得到当前树的大小,将mx和root初始化
    getroot(y,0);//得到子树的分治中心
    divide(root);//分治子树
  }
}

至于为什么有两次solve,主要是为了去重。原因如下:

以下内容来自https://blog.csdn.net/qq_39553725/article/details/77542223


对于以下这棵树:
点分治
显然A点是它的重心。
我们假设现在分治到了A点(当前点为A)
我们一开始求解贡献时,会有以下路径被处理出来:
A—>A
A—>B
A—>B—>C
A—>B—>D
A—>E
A—>E—>F (按照先序遍历顺序罗列)
那么我们在合并答案是会将上述6条路径两两进行合并。
这是注意到:
合并A—>B—>C 和 A—>B—>D 肯定是不合法的!!
因为这并不是一条树上(简单)路径,出现了重边,我们要想办法把这种情况处理掉。
处理方法很简单,减去每个子树的单独贡献。
例如对于以B为根的子树,就会减去:
B—>B
B—>C
B—>D
这三条路径组合的贡献
读者可能会有疑问,这与上面的6条路径并不一样啊。
我们再回过头来看一看这两句代码:
ans = ans + solve(tr,0);
ans = ans - solve(v,t[i].len);
注意到了吧,solve函数的第二个初始值并不相同。
我们在处理子树时,将初始长度设为连接边长,这样做就相当于个子树的每个组合都加上了A—>的路径,从而变得与上面一样。
个人认为这是点分治一个极其重要的地方,读者们一定要理解清楚。


好了,分治就完了。不过和树剖一样,只是剖分完是并没有什么卵用的,还需要具体问题具体分析。

具体例题

1.洛谷P3806 【模板】点分治1

题目链接:https://www.luogu.org/problemnew/show/P3806

对于这道题询问的距离为k的点是否存在,我们可以用一种类似桶排序的笨方法,将距离作为数组下标,统计到这个距离就将数组++,最后不为0就存在,反之不存在。

要得到两点之间的距离,只需要记录两个点分别到根的距离,然后相加即可。

这里的query函数用作求当前分治到的树中点到根的距离,solve用于相加得到两点间距离并统计。

#include<bits/stdc++.h>
using namespace std;
struct Edge{
  int next,to,w;
} edge[20005];
int size[20005],son[20005],dis[20005],head[20005];
bool vis[20005];
int sum[10000005];//存距离是否存在,注意大小和n不一样
int n,m,k,a,b,c,ord,cnt=1,mx,root,sizenow;
inline int read()
{
  char ch=getchar();
  int f=1,x=0;
  while (ch<'0' || ch>'9')
  {
    if (ch=='-') f=-1;
    ch=getchar();
  }
  while (ch>='0' && ch<='9')
  {
    x=x*10+ch-'0';
    ch=getchar();
  }
  return f*x;
}
inline void add(int u,int v,int w)
{
  edge[cnt].to=v;
  edge[cnt].w=w;
  edge[cnt].next=head[u];
  head[u]=cnt++;
}
void getroot(int x,int f)
{
  size[x]=1;
  son[x]=0;
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (y==f || vis[y]) continue;
    getroot(y,x);
    size[x]+=size[y];
    son[x]=max(son[x],size[y]);
  }
  son[x]=max(son[x],sizenow-size[x]);
  if (son[x]<mx)
  {
    mx=son[x];
    root=x;
  }
}
void query(int x,int f,int dist)
{
  dis[++ord]=dist;//得到到根的距离
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (y==f || vis[y]) continue;
    query(y,x,dist+edge[i].w);//查询子节点
  }
}
void solve(int x,int length,bool s) //s=false时要去重
{
  ord=0;
  query(x,0,length);
  for (int i=1;i<=ord-1;i++)
    for (int j=i+1;j<=ord;j++)
    {
      if (s==true) sum[dis[i]+dis[j]]++;
      else sum[dis[i]+dis[j]]--; //进行统计
    }
}
void divide(int x)
{
  vis[x]=true;
  solve(x,0,1);
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y]) continue;
    solve(y,edge[i].w,0);
    mx=2147483646;root=0;sizenow=size[y];
    getroot(y,0);
    divide(root);
  }
}
int main()
{
  memset(vis,false,sizeof(vis));
  n=read();m=read();
  for (int i=1;i<n;i++)
  {
    a=read();b=read();c=read();
    add(a,b,c);
    add(b,a,c);
  }
  root=0;mx=2147483646;sizenow=n;
  getroot(1,0);
  divide(root);
  for (int i=1;i<=m;i++)
  {
    a=read();
    if (sum[a]) printf("AYE\n");
    else printf("NAY\n");
  }
  return 0;
}

2.[国家集训队]聪聪可可

题目链接:

这里使用一个mod3[3]数组来记录两点间距离%3后的值分别为0,1,2的个数。

对每个节点求出其子树内的dis,经过该点的路径数即为mod3[1]*mod3[2]*2+mod3[0]^2

因为要求答案互质,所以还需要一个求最大公约数的操作。

#include<bits/stdc++.h>
using namespace std;
struct Edge{
  int next,to,w;
} edge[40005];
int head[40005],size[40005],son[40005],mod3[3];
bool vis[40005];
int n,m,a,b,c,cnt=1,ord,root,mx,sizenow,ans;
inline int read()
{
  char ch=getchar();
  int f=1,x=0;
  while (ch<'0' || ch>'9')
  {
    if (ch=='-') f=-1;
    ch=getchar();
  }
  while (ch>='0' && ch<='9')
  {
    x=x*10+ch-'0';
    ch=getchar();
  }
  return f*x;
}
inline void add(int u,int v,int w)
{
  edge[cnt].to=v;
  edge[cnt].w=w;
  edge[cnt].next=head[u];
  head[u]=cnt++;
}
inline void init(int sizen)
{
  mx=2147483647;root=0;sizenow=sizen;
}
void getroot(int x,int f)
{
  size[x]=1;
  son[x]=0;
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y] || y==f) continue;
    getroot(y,x);
    size[x]+=size[y];
    son[x]=max(son[x],size[y]);
  }
  son[x]=max(son[x],sizenow-size[x]);
  if (mx>son[x])
  {
    mx=son[x];
    root=x;
  }
}
void query(int x,int f,int dist)
{
  mod3[dist%3]++;
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y] || y==f) continue;
    query(y,x,dist+edge[i].w);
  }
}
inline int solve(int x,int length)
{
  memset(mod3,0,sizeof(mod3));
  query(x,0,length);
  return mod3[0]*mod3[0]+mod3[1]*mod3[2]*2;
}
void divide(int x)
{
  vis[x]=true;
  ans+=solve(x,0);
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y]) continue;
    ans-=solve(y,edge[i].w);
    init(size[y]);
    getroot(y,0);
    divide(root);
  }
}
int gcd(int x,int y)//求最大公约数
{
  if (y==0) return x;
  return gcd(y,x%y);
}
int main()
{
  n=read();
  for (int i=1;i<n;i++)
  {
    a=read();b=read();c=read();
    add(a,b,c);
    add(b,a,c);
  }
  init(n);
  getroot(1,0);
  divide(root);
  printf("%d%c%d",ans/gcd(ans,n*n),'/',n*n/gcd(ans,n*n));
  return 0;
}

3.[IOI2011]Race

题目链接:

BZOJ是权限题,若没交钱就用洛谷,或者用离线BZOJ题库(by ruanxingzhi):
https://bzoj.llf0703.com/p/2599.html

洛谷上这道题是@larryzhong大佬提供的,比我小3岁却吊打我,在这里先%%%为敬

因为这道题需要得到边数最小,我们对于距离为i的点建立tmp[i],表示在当前递归到的子树中,走到距离为i的顶点最少需要走多少边

点分治,每次先对每棵子树遍历,求出每个点i到root的距离dis[i],以及走过的边数d[i],那么ans=min(ans,tmp[k-dis[i]]+d[i]).

遍历完这棵子树再修改被访问了的tmp[dis[i]],然后下一棵。最后所有子树遍历完了以后,再遍历一遍所有节点,把修改到的tmp值变回inf(初始就是inf)

#include<bits/stdc++.h>
#define maxint 1e9
using namespace std;
struct Edge{
  int next,to,w;
} edge[400005];
bool vis[200005];
int size[200005],son[200005],dis[200005],head[400005],esum[200005];
//tmp[i],表示在当前递归到的子树中,走到距离为i的顶点最少需要走多少边
//每个点i到root的距离dis[i],以及走过的边数esum[i]
int tmp[1000005];
int n,m,k,a,b,c,cnt=1,sizenow,mx,ans,root;
inline int read()
{
  char ch=getchar();
  int f=1,x=0;
  while (ch<'0' || ch>'9')
  {
    if (ch=='-') f=-1;
    ch=getchar();
  }
  while (ch>='0' && ch<='9')
  {
    x=x*10+ch-'0';
    ch=getchar();
  }
  return x*f;
}
inline void add(int u,int v,int w)
{
  edge[cnt].to=v;
  edge[cnt].w=w;
  edge[cnt].next=head[u];
  head[u]=cnt++;
}
inline void init(int sizen)
{
  mx=maxint;
  root=0;
  sizenow=sizen;
}
void getroot(int x,int f)
{
  size[x]=1;
  son[x]=0;
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y] || y==f) continue;
    getroot(y,x);
    size[x]+=size[y];
    son[x]=max(son[x],size[y]);
  }
  son[x]=max(son[x],sizenow-son[x]);
  if (mx>son[x])
  {
    mx=son[x];
    root=x;
  }
}
void query(int x,int f,int dist,int edgesum)
{
  dis[x]=dist;
  esum[x]=edgesum;
  if (dis[x]<=k) ans=min(ans,tmp[k-dis[x]]+esum[x]);
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (y==f || vis[y]) continue;
    query(y,x,dist+edge[i].w,edgesum+1);
  }
}
//遍历完这棵子树再修改被访问了的tmp[dis[i]],然后下一棵
void update(int x,int f,bool s)
{
  if (dis[x]<=k)
    if (s) tmp[dis[x]]=min(tmp[dis[x]],esum[x]);
    else tmp[dis[x]]=maxint;
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (y==f || vis[y]) continue;
    update(y,x,s);
  }
}
void divide(int x)
{
  vis[x]=true;
  tmp[0]=0;//每次进入dfs_solve时tmp[0]=0,因为这个当前的根到自己距离为0,走过了0条边
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y]) continue;
    query(y,0,edge[i].w,1);
    update(y,0,true);
  }
  for (int i=head[x];i;i=edge[i].next)//所有子树遍历完了以后,再遍历一遍所有节点,把修改到的tmp值变回inf(初始就是inf)
  {
    int y=edge[i].to;
    if (vis[y]) continue;
    update(y,0,false);
  }
  for (int i=head[x];i;i=edge[i].next)
  {
    int y=edge[i].to;
    if (vis[y]) continue;
    init(size[y]);
    getroot(y,0);
    divide(root);
  }
}
int main()
{
  n=read();k=read();
  for (int i=1;i<=k;i++) tmp[i]=maxint;
  ans=maxint;
  for (int i=1;i<n;i++)
  {
    a=read();b=read();c=read();
    a++;b++;
    add(a,b,c);
    add(b,a,c);
  }
  init(n);
  getroot(1,0);
  divide(root);
  if (ans==maxint) printf("-1");
  else printf("%d",ans);
  return 0;
}

4.洛谷 P4178 Tree

题目链接: https://www.luogu.org/problemnew/show/P4178

这道题好像比上面那道还简单一点,我也不知道我为什么先刷的上面那一道

注意的是可以在统计有多少个路径权值<=k时可以进行一次排序,直接枚举两个端点就可以处理了,可以节省时间。

#include<bits/stdc++.h>
#define inf 0xfffffff
using namespace std;
struct Edge{
    int next,to,w;
} edge[80005];
bool vis[40005];
int dis[40005],size[40005],son[40005],head[40005];
int root,mx,sizenow,cnt=1,n,m,a,b,c,ord,k,ans;
inline int read()
{
    char ch=getchar();
    int f=1,x=0;
    while (ch<'0' || ch>'9')
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0' && ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return f*x;
}
inline void add(int u,int v,int w)
{
    edge[cnt].to=v;
    edge[cnt].w=w;
    edge[cnt].next=head[u];
    head[u]=cnt++;
}
inline void init(int sizen)
{
    mx=inf; sizenow=sizen; root=0;
}
void getroot(int x,int f)
{
    size[x]=1;
    son[x]=0;
    for (int i=head[x];i;i=edge[i].next)
    {
        int y=edge[i].to;
        if (y==f || vis[y]) continue;
        getroot(y,x);
        size[x]+=size[y];
        son[x]=max(son[x],size[y]);
    }
    son[x]=max(son[x],sizenow-size[x]);
    if (mx>son[x])
    {
        mx=son[x];
        root=x;
    }
}
void query(int x,int f,int dist)
{
    dis[++ord]=dist;
    for (int i=head[x];i;i=edge[i].next)
    {
        int y=edge[i].to;
        if (y==f || vis[y]) continue;
        query(y,x,dist+edge[i].w);
    }
}
inline int solve(int x,int length)
{
    ord=0;
    query(x,0,length);
    sort(dis+1,dis+ord+1);
    int ans=0,stat=1;
    while (stat<ord)
    {
        if (dis[stat]+dis[ord]<=k)
        {
            ans+=ord-stat;
            stat++;
        }
        else ord--;
    }
    return ans;
}
void divide(int x)
{
    ans+=solve(x,0);
    vis[x]=true;
    for (int i=head[x];i;i=edge[i].next)
    {
        int y=edge[i].to;
        if (vis[y]) continue;
        ans-=solve(y,edge[i].w);
        init(size[y]);
        getroot(y,0);
        divide(root);
    }
}
int main()
{
    n=read();
    for (int i=1;i<n;i++)
    {
        a=read();b=read();c=read();
        add(a,b,c);
        add(b,a,c);
    }
    k=read();
    init(n);
    getroot(1,0);
    divide(root);
    printf("%d",ans);
    return 0;
}

以上内容更新于2018.4.26



添加新评论