我提交了一份 O(n∣S∣) 的代码,最慢的点跑了 276 ms。
然后我对该代码做出了修改并把时间复杂度降到了 O(n),最慢的点竟然还有 120 ms。
为什么?
代码 1:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n,a[2000100][30],s[2000100],t[2000100];// a[i][j]表示以i结尾最短的可消除子串的第一个字符的位置且该子串前方的字符为j
char c[2000100];
ll ans;
int main(){
scanf("%lld%s",&n,c+1);
for(int i=1;i<=n;i++){
// printf("%lld %c\n",i,c[i]);
if(c[i]==c[i-1]){
memcpy(a[i],a[i-2],sizeof(a[i]));
ll u=c[i-2]-'a';
if(i==2) u=26;
a[i][u]=i-1;
}
else{
ll u=a[i-1][c[i]-'a'];
if(u==0) continue;
u-=2;
memcpy(a[i],a[u],sizeof(a[i]));
ll v=c[u]-'a';
if(u==0) v=26;
a[i][v]=u+1;
}
}
for(int i=1;i<=n;i++){
for(int j=0;j<=26;j++){
s[i]=max(s[i],a[i][j]);
}
}
for(int i=1;i<=n;i++){
ll u=s[i];
t[u]=i;
}
/* for(int i=1;i<=n;i++){
printf("%lld ",s[i]);
}
putchar('\n');
for(int i=1;i<=n;i++){
printf("%lld ",t[i]);
}
putchar('\n');*/
for(int i=1;i<=n;i++){
if(!t[i]) continue;
ll ps=i,cnt=0;
while(t[ps]){
cnt++;
ll tp=t[ps];
t[ps]=0;
ps=tp;
s[ps]=0;
ps++;
}
// printf("%lld\n",cnt);
ans+=cnt*(cnt+1)/2;
/* for(int i=1;i<=n;i++){
printf("%lld ",s[i]);
}
putchar('\n');
for(int i=1;i<=n;i++){
printf("%lld ",t[i]);
}
putchar('\n');*/
}
printf("%lld\n",ans);
/* for(int i=0;i<=26;i++){
for(int j=1;j<=n;j++){
printf("%lld ",a[j][i]);
}
putchar('\n');
}*/
}
代码 2:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll n,a[2000100][30],r[2000100],s[2000100],t[2000100];// a[i][j]表示以i结尾最短的可消除子串的第一个字符的位置且该子串前方的字符为j
char c[2000100];
int to[2000100];
ll ans;
int main(){
scanf("%lld%s",&n,c+1);
for(ll i=1;i<=n;i++){
to[i]=i;
if(c[i]==c[i-1]){
to[i]=to[i-2];
ll u=c[i-2]-'a';
if(i==2) u=26;
a[to[i]][u]=i-1;
r[to[i]]=max(r[to[i]],i-1);
s[i]=r[to[i]];
}
else{
ll u=a[to[i-1]][c[i]-'a'];
if(u==0) continue;
u-=2;
to[i]=to[u];
ll v=c[u]-'a';
if(u==0) v=26;
a[to[i]][v]=u+1;
r[to[i]]=max(r[to[i]],u+1);
s[i]=r[to[i]];
}
}
for(int i=1;i<=n;i++){
ll u=s[i];
t[u]=i;
}
/* for(int i=1;i<=n;i++){
printf("%lld ",s[i]);
}
putchar('\n');
for(int i=1;i<=n;i++){
printf("%lld ",t[i]);
}
putchar('\n');
for(int i=1;i<=n;i++){
printf("%lld ",to[i]);
}
putchar('\n');*/
for(int i=1;i<=n;i++){
if(!t[i]) continue;
ll ps=i,cnt=0;
while(t[ps]){
cnt++;
ll tp=t[ps];
t[ps]=0;
ps=tp;
s[ps]=0;
ps++;
}
ans+=cnt*(cnt+1)/2;
}
printf("%lld\n",ans);
}