洛谷3193 [HNOI2008]GT考试

算法竞赛 字符串 数学 数学-矩阵 字符串-kmp
编辑文章

题意

求不包含子串 $A$ 的长度为 $N$ 的数字 $X$ 的个数。答案对 $K$ 取模。

其中 $A$ 的长度 $M\le 20$,$N\le 10^9$,$K\le 1000$ 。

题解

矩阵乘法优化dp。

用 $\text{f[i][j]}$ 表示在 $X$ 中做到第 $i$ 位,匹配到 $A$ 中第 $j$ 位的方案个数。最终的答案即为:

$$\sum_{i=0}^{M-1} \text{f[n][i]} \tag{1}$$

枚举上一个状态匹配到 $A$ 中的第 $k$ 位,可以得到转移方程为:

$$\text{f[i][j]}=\sum_{k=0}^{M-1} \text{f[i-1][k]}\times \text{g[k][j]} \tag{2}$$

其中 $\text{g[k][j]}$ 表示当前匹配到 $A$ 中的第 $k$ 位,要求匹配到第 $j$ 位的方案数。显然 $\text{g[k][j]}$ 可以通过暴力枚举 $k$ ,然后用 $\text{kmp}$ 得到能转移到的 $j$ 预处理出来。

$(2)$ 式已经很像矩阵乘法了,且 $\text{f[i][j]}$ 是通过不断乘 $\text{g[k][j]}$ 得到的。如果直接递推 $N$ 次肯定会爆,考虑优化。

将 $g[k][j]$ 看成矩阵 $G$ ,可以发现 $G$ 是不变的。而 $\text{f[0][0]}=1$ ,所以答案就是 $G^n$ 。

我矩阵下标是从 $1$ 开始,所以我给下标都加了 $1$: x.s[i+1][j+1]++;

#include<bits/stdc++.h>

using namespace std;

inline int read()
{
    char ch=getchar();
    int f=1,x=0;
    while (ch<'0' || ch>'9')
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0' && ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return f*x;
}

char ch[25];
int n,m,ha,ans,nxt[25];
struct matrix {
    int n,s[25][25];
    matrix(int len=0) { n=len; memset(s,0,sizeof(s)); }
    matrix operator * (const matrix &x) const {
        matrix y;
        y.n=n;
        for (int i=1;i<=n;i++)
            for (int j=1;j<=n;j++)
                for (int k=1;k<=n;k++)
                    y.s[i][j]=(y.s[i][j]+s[i][k]*x.s[k][j]%ha)%ha;
        return y;
    }
};

inline void init(matrix &x);
inline matrix qpow(matrix x,int y);
inline void get_next();
inline matrix kmp();

signed main()
{
    n=read(); m=read(); ha=read();
    scanf("%s",ch+1);
    get_next();
    matrix x=qpow(kmp(),n);
    for (int i=1;i<=m;i++) ans=(ans+x.s[1][i])%ha;
    return !printf("%d",ans);
}

inline void init(matrix &x) { for (int i=1;i<=x.n;i++) x.s[i][i]=1; }

inline matrix qpow(matrix x,int y)
{
    matrix ans(x.n); init(ans);
    while (y)
    {
        if (y&1) ans=ans*x;
        y>>=1;
        x=x*x;
    }
    return ans;
}

inline void get_next()
{
    int j=0;
    nxt[1]=0;
    for (int i=2;i<=m;i++)
    {
        while (j && ch[i]!=ch[j+1]) j=nxt[j];
        if (ch[i]==ch[j+1]) j++;
        nxt[i]=j;
    }
}

inline matrix kmp()
{
    matrix x(m);
    for (int i=0;i<m;i++)
    {
        for (char k='0';k<='9';k++)
        {
            int j=i;
            while (j && ch[j+1]!=k) j=nxt[j];
            if (ch[j+1]==k) j++;
            x.s[i+1][j+1]++;
        }
    }
    return x;
}

新评论

称呼不能为空
邮箱格式不合法
网站格式不合法
内容不能为空