定义

我们现在有一个长为 $n$ 的字符串 $s$,我们定义这个字符串的后缀 $i$ 表示 $s[i,n]$。

现在,我们要对 $s$ 所产生的 $n$ 个后缀进行排序,得到第 $i$ 个后缀是第几名,我们记其为 $rk_i$。同时,我们还能得到第 $i$ 名的是哪个后缀,记为 $sa_i$。

如何求 SA

暴力

我们有一种极其暴力的做法,把这 $n$ 个后缀存下来,再排序。总共有 $O(n\log n)$ 次比较,每次比较最坏 $O(n)$,则复杂度是 $O(n^2\log n)$,遥遥落后。

倍增法

我们换一种思路:每次计算长度为 $w$ 的所有子串的排名,这样就可以通过合并排名来统计答案,为此,我们修改一下定义:

假设当前考虑的子串长度为 $w$,对于在结尾不足 $w$ 位的子串,我们给它补上当前字符集中最小的字符(实现中是值 $0$)。这样,我们就一共有 $n$ 个长为 $w$ 的子串了。

$rk_i$ 表示 $s[i,i+w-1]$ 在这 $n$ 个长为 $w$ 的子串中的排名。$sa_i$ 类似。

我们从 $w=1$ 的情况开始考虑。此时显然我们可以轻而易举地计算出 $rk_i$,然后根据 $rk$ 来计算 $sa$。

当我们考虑到 $w=2^p$ 时,假设我们已经有了 $w’=2^{p-1}$ 时的 $rk$ 和 $sa$。因为 $s[i,i+w-1]=s[i,i+w’-1]+s[i+w’,i+w-1]$,所以我们可以根据 $rk_i$ 和 $rk_{i+w’}$ 来进行一个排序。

先放一下代码,结合代码讲解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int n;
char s[N];
int w,rk[N*2],oldrk[N*2],sa[N*2];
bool cmp(int x,int y){
if(rk[x]==rk[y])return rk[x+w]<rk[y+w];
else return rk[x]<rk[y];
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>(s+1);
n=strlen(s+1);
for(int i=1;i<=n;i++)sa[i]=i,rk[i]=s[i];
for(w=1;w<n;w<<=1){
sort(sa+1,sa+1+n,cmp);
memcpy(oldrk,rk,sizeof(rk));
for(int p=0,i=1;i<=n;i++){
if(oldrk[sa[i]]==oldrk[sa[i-1]]&&oldrk[sa[i]+w]==oldrk[sa[i-1]+w])rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
}
for(int i=1;i<=n;i++)cout<<sa[i]<<' ';
return 0;
}

注意在上面的这个实现中,最外层 for 循环开始时,$rk$ 的值就是子串长度为 $w$ 的值,而 $sa$ 的值是子串长度为 $w/2$ 的时候的值。每次 for 循环,先根据当前的 $rk$ 来把 $sa$ 更新到当前状态,再根据 $sa$ 计算出下一个 $rk$。

根据 $rk$ 来给 $sa$ 排序时(也就是 sort 函数),此时的 $sa$ 数组里什么值其实是无关紧要的,只要是任意一个 $n$ 的排列就行(因为关键字和 $sa$ 没关系),我们以 $rk_{sa_i}$ 为第一关键字,$rk_{sa_i+w}$ 为第二关键字。原因是在子串 $s[sa_i,sa_i+2w]$ 中,根据字符串比较的原则,要先比前面的 $s[sa_i,sa_i+w]$ 的部分。

后面根据 $sa$ 来更新 $rk$ 时,注意当两个子串相等时他们的 $rk$ 也要相等。

倍增 $O(\log n)$,排序 $O(n\log n)$,总时间复杂度 $O(n\log^2n)$。

基数排序优化

在给 $sa$ 排序的过程中,我们可以使用基数排序来优化到 $O(n)$ 排序。

首先我们要了解基数排序,而我们所使用的基数排序又依赖于计数排序,所以我们先讲讲计数排序。

1
2
3
for(int i=1;i<=n;i++)cnt[a[i]]++;
for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)ans[i]=cnt[a[i]]--;

十分的简洁明了。其中 $a$ 是待排序数组,$m$ 是值域,$ans_i$ 表示 $a_i$ 的排名。值得注意的是,这是一个稳定的排序算法,因为统计答案的时候我们采用了原数组倒序的方式。如果 $m=n$,那么 $ans$ 就是排序之后的答案数组。时间复杂度显然是 $O(n+m)$。

下面我们介绍一下基数排序。这是一个多关键字的排序,我们现在有 $n$ 个元素,每个元素有 $2$ 个关键字。$a_i$ 表示第 $i$ 个元素的第 $1$ 个关键字,$b_i$ 是第二个。

1
2
3
4
5
6
7
for(int i=1;i<=n;i++)cnt[b[i]]++;
for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)tmp[i]=cnt[b[i]]--;
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=n;i++)cnt[tmp[i]]++;
for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)ans[i]=cnt[tmp[i]]--;

为什么这样是正确的呢?第一次我们先对第二关键字排序。第二次排序的时候,当第一关键字不同,此时可以被正确排序。而当第一关键字不同的时候,因为计数排序是稳定的排序,所以我们原先保留的第二关键字的顺序不会变,就完成了排序。

把这项技术运用到我们的倍增法里,我们就得到了一个 $O(n\log n)$ 的求后缀数组的算法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int n,rk[N*2],oldrk[N*2],sa[N*2],id[N*2],cnt[N];
char s[N];
signed main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>(s+1);
n=strlen(s+1);
for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++;
for(int i=1;i<128;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)sa[cnt[rk[i]]--]=i;
memcpy(oldrk,rk,sizeof(rk));
for(int p=0,i=1;i<=n;i++){
if(oldrk[sa[i-1]]==oldrk[sa[i]])rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
for(int w=1;w<n;w<<=1){
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=n;i++)cnt[rk[i+w]]++;
for(int i=1;i<=n;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)id[cnt[rk[i+w]]--]=i;
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=n;i++)cnt[rk[i]]++;
for(int i=1;i<=n;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)sa[cnt[rk[id[i]]]--]=id[i];
memcpy(oldrk,rk,sizeof(rk));
for(int p=0,i=1;i<=n;i++){
if(oldrk[sa[i]]==oldrk[sa[i-1]]&&oldrk[sa[i]+w]==oldrk[sa[i-1]+w])rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
}
for(int i=1;i<=n;i++)cout<<sa[i]<<' ';
return 0;
}

为了帮助理解,我决定还是讲解一下。

首先我们计算出了子串长度为 $1$ 时的 $sa$,再据此计算出 $rk$ 的数值版本。在 for 循环中,我们先根据第二关键字计数排序(也就是 $rk_{sa_i+w}$)。排完序后的 $id_i$ 表示 $s[j+w,j+2w-1]$ 在 $j=1,2,3,\cdots,n$ 这 $n$ 个子串中的排名为 $i$ 的 $j$ 的值。然后再根据第一关键字,以 $id$ 倒序的顺序来进行计数排序。这一次计数排序保证了在 $rk_i$ 不同的时候,较小的排在前面。而 $id$ 倒序的顺序保证了在 $rk_i$ 相同的时候,$rk_{i+w}$ 较小的能排在前面。

常数优化

我们已经写出了 $O(n\log n)$ 的算法,但这份代码的常数巨大,需要优化常数。

先优化几个比较显然的点。

  1. 我们每次计数排序的值域不用开到 $n$,由代码可知上一次的 $p$ 就是值域。
  2. 当计算完 $rk$ 后,若 $p=n$,则说明算法完成了。因为此时 $rk$ 两两不同,再往后排也不会有新的结果。

然后思考一下对第二关键字排序的实质。

我们发现,对于 $i+w>n$ 的部分,$rk_{i+w}$ 实质上是等于 $0$ 的,因此,这一部分会被放在 $id$ 的最前面。

在对第二关键字排序的时候,此时的 $sa$ 数组其实是上一次的保留结果,此时 $sa_i$ 表示所有长为 $\frac{w}{2}$ 的子串中,第 $i$ 名的起始位置。我们直接从第 $1$ 名开始,如果 $sa_i>w$,我们直接把 $sa_i-w$ 放进 $id$ 即可。此时我们发现,$sa_i$ 最大是 $n$,所以 $sa_i-w$ 最大是 $n-w$,正好跟前面对上了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+100;
int n;
char s[N];
int p,m,rk[N*2],sa[N*2],oldrk[N*2],id[N*2],cnt[N];
inline bool cmp(int x,int y,int w){
if(oldrk[x]==oldrk[y]&&oldrk[x+w]==oldrk[y+w])return 1;
return 0;
}
signed main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>(s+1);
n=strlen(s+1);
m=128;
for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++;
for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)sa[cnt[rk[i]]--]=i;
for(int w=1;;w<<=1,m=p){
int cur=0;
for(int i=n-w+1;i<=n;i++)id[++cur]=i;
for(int i=1;i<=n;i++)
if(sa[i]>w)id[++cur]=sa[i]-w;
for(int i=0;i<=m;i++)cnt[i]=0;
for(int i=1;i<=n;i++)cnt[rk[i]]++;
for(int i=1;i<=m;i++)cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--)sa[cnt[rk[id[i]]]--]=id[i];
p=0;
for(int i=1;i<=n;i++)oldrk[i]=rk[i];
for(int i=1;i<=n;i++){
if(oldrk[sa[i]]==oldrk[sa[i-1]]&&oldrk[sa[i]+w]==oldrk[sa[i-1]+w])rk[sa[i]]=p;
else rk[sa[i]]=++p;
}
if(p==n)break;
}
for(int i=1;i<=n;i++)cout<<sa[i]<<' ';
return 0;
}