定义
我们现在有一个长为 $n$ 的字符串 $s$,我们定义这个字符串的后缀 $i$ 表示 $s[i,n]$。
现在,我们要对 $s$ 所产生的 $n$ 个后缀进行排序,得到第 $i$ 个后缀是第几名,我们记其为 $rk_i$。同时,我们还能得到第 $i$ 名的是哪个后缀,记为 $sa_i$。
如何求 SA
暴力
我们有一种极其暴力的做法,把这 $n$ 个后缀存下来,再排序。总共有 $O(n\log n)$ 次比较,每次比较最坏 $O(n)$,则复杂度是 $O(n^2\log n)$,遥遥落后。
倍增法
我们换一种思路:每次计算长度为 $w$ 的所有子串的排名,这样就可以通过合并排名来统计答案,为此,我们修改一下定义:
假设当前考虑的子串长度为 $w$,对于在结尾不足 $w$ 位的子串,我们给它补上当前字符集中最小的字符(实现中是值 $0$)。这样,我们就一共有 $n$ 个长为 $w$ 的子串了。
$rk_i$ 表示 $s[i,i+w-1]$ 在这 $n$ 个长为 $w$ 的子串中的排名。$sa_i$ 类似。
我们从 $w=1$ 的情况开始考虑。此时显然我们可以轻而易举地计算出 $rk_i$,然后根据 $rk$ 来计算 $sa$。
当我们考虑到 $w=2^p$ 时,假设我们已经有了 $w’=2^{p-1}$ 时的 $rk$ 和 $sa$。因为 $s[i,i+w-1]=s[i,i+w’-1]+s[i+w’,i+w-1]$,所以我们可以根据 $rk_i$ 和 $rk_{i+w’}$ 来进行一个排序。
先放一下代码,结合代码讲解。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| #include <bits/stdc++.h> using namespace std; const int N=1e6+10; int n; char s[N]; int w,rk[N*2],oldrk[N*2],sa[N*2]; bool cmp(int x,int y){ if(rk[x]==rk[y])return rk[x+w]<rk[y+w]; else return rk[x]<rk[y]; } signed main(){ ios::sync_with_stdio(0); cin.tie(0),cout.tie(0); cin>>(s+1); n=strlen(s+1); for(int i=1;i<=n;i++)sa[i]=i,rk[i]=s[i]; for(w=1;w<n;w<<=1){ sort(sa+1,sa+1+n,cmp); memcpy(oldrk,rk,sizeof(rk)); for(int p=0,i=1;i<=n;i++){ if(oldrk[sa[i]]==oldrk[sa[i-1]]&&oldrk[sa[i]+w]==oldrk[sa[i-1]+w])rk[sa[i]]=p; else rk[sa[i]]=++p; } } for(int i=1;i<=n;i++)cout<<sa[i]<<' '; return 0; }
|
注意在上面的这个实现中,最外层 for 循环开始时,$rk$ 的值就是子串长度为 $w$ 的值,而 $sa$ 的值是子串长度为 $w/2$ 的时候的值。每次 for 循环,先根据当前的 $rk$ 来把 $sa$ 更新到当前状态,再根据 $sa$ 计算出下一个 $rk$。
根据 $rk$ 来给 $sa$ 排序时(也就是 sort 函数),此时的 $sa$ 数组里什么值其实是无关紧要的,只要是任意一个 $n$ 的排列就行(因为关键字和 $sa$ 没关系),我们以 $rk_{sa_i}$ 为第一关键字,$rk_{sa_i+w}$ 为第二关键字。原因是在子串 $s[sa_i,sa_i+2w]$ 中,根据字符串比较的原则,要先比前面的 $s[sa_i,sa_i+w]$ 的部分。
后面根据 $sa$ 来更新 $rk$ 时,注意当两个子串相等时他们的 $rk$ 也要相等。
倍增 $O(\log n)$,排序 $O(n\log n)$,总时间复杂度 $O(n\log^2n)$。
基数排序优化
在给 $sa$ 排序的过程中,我们可以使用基数排序来优化到 $O(n)$ 排序。
首先我们要了解基数排序,而我们所使用的基数排序又依赖于计数排序,所以我们先讲讲计数排序。
1 2 3
| for(int i=1;i<=n;i++)cnt[a[i]]++; for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)ans[i]=cnt[a[i]]--;
|
十分的简洁明了。其中 $a$ 是待排序数组,$m$ 是值域,$ans_i$ 表示 $a_i$ 的排名。值得注意的是,这是一个稳定的排序算法,因为统计答案的时候我们采用了原数组倒序的方式。如果 $m=n$,那么 $ans$ 就是排序之后的答案数组。时间复杂度显然是 $O(n+m)$。
下面我们介绍一下基数排序。这是一个多关键字的排序,我们现在有 $n$ 个元素,每个元素有 $2$ 个关键字。$a_i$ 表示第 $i$ 个元素的第 $1$ 个关键字,$b_i$ 是第二个。
1 2 3 4 5 6 7
| for(int i=1;i<=n;i++)cnt[b[i]]++; for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)tmp[i]=cnt[b[i]]--; memset(cnt,0,sizeof(cnt)); for(int i=1;i<=n;i++)cnt[tmp[i]]++; for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)ans[i]=cnt[tmp[i]]--;
|
为什么这样是正确的呢?第一次我们先对第二关键字排序。第二次排序的时候,当第一关键字不同,此时可以被正确排序。而当第一关键字不同的时候,因为计数排序是稳定的排序,所以我们原先保留的第二关键字的顺序不会变,就完成了排序。
把这项技术运用到我们的倍增法里,我们就得到了一个 $O(n\log n)$ 的求后缀数组的算法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| #include <bits/stdc++.h> using namespace std; const int N=1e6+10; int n,rk[N*2],oldrk[N*2],sa[N*2],id[N*2],cnt[N]; char s[N]; signed main(){ ios::sync_with_stdio(false); cin.tie(0),cout.tie(0); cin>>(s+1); n=strlen(s+1); for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++; for(int i=1;i<128;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)sa[cnt[rk[i]]--]=i; memcpy(oldrk,rk,sizeof(rk)); for(int p=0,i=1;i<=n;i++){ if(oldrk[sa[i-1]]==oldrk[sa[i]])rk[sa[i]]=p; else rk[sa[i]]=++p; } for(int w=1;w<n;w<<=1){ memset(cnt,0,sizeof(cnt)); for(int i=1;i<=n;i++)cnt[rk[i+w]]++; for(int i=1;i<=n;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)id[cnt[rk[i+w]]--]=i; memset(cnt,0,sizeof(cnt)); for(int i=1;i<=n;i++)cnt[rk[i]]++; for(int i=1;i<=n;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)sa[cnt[rk[id[i]]]--]=id[i]; memcpy(oldrk,rk,sizeof(rk)); for(int p=0,i=1;i<=n;i++){ if(oldrk[sa[i]]==oldrk[sa[i-1]]&&oldrk[sa[i]+w]==oldrk[sa[i-1]+w])rk[sa[i]]=p; else rk[sa[i]]=++p; } } for(int i=1;i<=n;i++)cout<<sa[i]<<' '; return 0; }
|
为了帮助理解,我决定还是讲解一下。
首先我们计算出了子串长度为 $1$ 时的 $sa$,再据此计算出 $rk$ 的数值版本。在 for 循环中,我们先根据第二关键字计数排序(也就是 $rk_{sa_i+w}$)。排完序后的 $id_i$ 表示 $s[j+w,j+2w-1]$ 在 $j=1,2,3,\cdots,n$ 这 $n$ 个子串中的排名为 $i$ 的 $j$ 的值。然后再根据第一关键字,以 $id$ 倒序的顺序来进行计数排序。这一次计数排序保证了在 $rk_i$ 不同的时候,较小的排在前面。而 $id$ 倒序的顺序保证了在 $rk_i$ 相同的时候,$rk_{i+w}$ 较小的能排在前面。
常数优化
我们已经写出了 $O(n\log n)$ 的算法,但这份代码的常数巨大,需要优化常数。
先优化几个比较显然的点。
- 我们每次计数排序的值域不用开到 $n$,由代码可知上一次的 $p$ 就是值域。
- 当计算完 $rk$ 后,若 $p=n$,则说明算法完成了。因为此时 $rk$ 两两不同,再往后排也不会有新的结果。
然后思考一下对第二关键字排序的实质。
我们发现,对于 $i+w>n$ 的部分,$rk_{i+w}$ 实质上是等于 $0$ 的,因此,这一部分会被放在 $id$ 的最前面。
在对第二关键字排序的时候,此时的 $sa$ 数组其实是上一次的保留结果,此时 $sa_i$ 表示所有长为 $\frac{w}{2}$ 的子串中,第 $i$ 名的起始位置。我们直接从第 $1$ 名开始,如果 $sa_i>w$,我们直接把 $sa_i-w$ 放进 $id$ 即可。此时我们发现,$sa_i$ 最大是 $n$,所以 $sa_i-w$ 最大是 $n-w$,正好跟前面对上了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| #include <bits/stdc++.h> using namespace std; const int N=1e6+100; int n; char s[N]; int p,m,rk[N*2],sa[N*2],oldrk[N*2],id[N*2],cnt[N]; inline bool cmp(int x,int y,int w){ if(oldrk[x]==oldrk[y]&&oldrk[x+w]==oldrk[y+w])return 1; return 0; } signed main(){ ios::sync_with_stdio(0); cin.tie(0),cout.tie(0); cin>>(s+1); n=strlen(s+1); m=128; for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++; for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)sa[cnt[rk[i]]--]=i; for(int w=1;;w<<=1,m=p){ int cur=0; for(int i=n-w+1;i<=n;i++)id[++cur]=i; for(int i=1;i<=n;i++) if(sa[i]>w)id[++cur]=sa[i]-w; for(int i=0;i<=m;i++)cnt[i]=0; for(int i=1;i<=n;i++)cnt[rk[i]]++; for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1]; for(int i=n;i>=1;i--)sa[cnt[rk[id[i]]]--]=id[i]; p=0; for(int i=1;i<=n;i++)oldrk[i]=rk[i]; for(int i=1;i<=n;i++){ if(oldrk[sa[i]]==oldrk[sa[i-1]]&&oldrk[sa[i]+w]==oldrk[sa[i-1]+w])rk[sa[i]]=p; else rk[sa[i]]=++p; } if(p==n)break; } for(int i=1;i<=n;i++)cout<<sa[i]<<' '; return 0; }
|