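// As stated in the title: I cross-checked this locally against the editorial solution on 2,500 random tests and found no mismatches.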
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<stack>
using namespace std;
// Capacity of every per-position array (valid indices 0..Maxn-1).
const int Maxn=1000010;
// Segment-tree storage: four nodes per leaf is always enough.
const int Maxm=Maxn*4;
// Larger than any position; the starting value of query_min's result.
const int inf=0x3f3f3f3f;

// Height array (indexed by rank) and rank of each starting position.
int h[Maxn];
int rk[Maxn];

// Segment tree over suffix-array order: min / max of the start positions p[].
int minv[Maxm];
int maxv[Maxm];

// Per-rank window bounds produced by the monotonic-stack passes in work().
int l[Maxn];
int r[Maxn];

// Counting-sort buckets, a scratch copy, and the suffix array itself.
int c[Maxn];
int f[Maxn];
int p[Maxn];

// Second input string, and the first input string (extended in place to A$B).
char a[Maxn];
char s[Maxn];
// Rank stored at pos, or 0 (the "empty" key, which sorts first) when pos lies
// past the end of the n-character string. Keeps prefix doubling inside the
// first n+1 array slots instead of reading up to index 2n-1.
static inline int key_at(const int arr[], int pos, int n)
{
    return pos<=n ? arr[pos] : 0;
}

// Builds the suffix array of the 1-indexed, NUL-terminated string s[1..n]
// (n = strlen(s+1), assumed < Maxn) and fills these globals:
//   p[i]  - starting position of the suffix with rank i (1..n); p[0] is left 0
//   rk[i] - rank (1..n) of the suffix starting at position i
//   h[i]  - length of the common prefix of suffixes p[i-1] and p[i]; h[1] = 0
//   l[i]  - l[i]-1 is the nearest rank j < i with h[j] <= h[i] (0 if none)
//   r[i]  - r[i]+1 is the nearest rank j > i with h[j] <= h[i] (n+1 if none)
// Uses the global arrays c and f as scratch space.
void work(char s[])
{
    int n=strlen(s+1);
    int m=max(n,256);   // largest possible key: byte values first, ranks later
    for(int i=1;i<=n;++i)
        p[i]=i,rk[i]=static_cast<unsigned char>(s[i]);   // keys must be >= 0
    // Prefix doubling. Stop as soon as every rank is distinct; this also makes
    // n == 1 end with rk[1] == 1 rather than the raw byte value.
    for(int cur=1;;cur<<=1)
    {
        // LSD radix sort, pass 1: stable sort of p by the second key rk[pos+cur].
        memset(c,0,sizeof(c[0])*(m+1));
        for(int i=1;i<=n;++i)
            f[i]=p[i];
        for(int i=1;i<=n;++i)
            ++c[key_at(rk,f[i]+cur,n)];
        for(int i=1;i<=m;++i)
            c[i]+=c[i-1];
        for(int i=n;i;--i)
            p[c[key_at(rk,f[i]+cur,n)]--]=f[i];
        // Pass 2: stable sort by the first key rk[pos].
        memset(c,0,sizeof(c[0])*(m+1));
        for(int i=1;i<=n;++i)
            f[i]=p[i];
        for(int i=1;i<=n;++i)
            ++c[rk[f[i]]];
        for(int i=1;i<=m;++i)
            c[i]+=c[i-1];
        for(int i=n;i;--i)
            p[c[rk[f[i]]]--]=f[i];
        // Re-rank: equal (first, second) key pairs share a rank.
        memcpy(f,rk,sizeof(rk[0])*(n+1));
        int tmp=0;
        for(int i=1;i<=n;++i)
        {
            bool same=i>1 && f[p[i]]==f[p[i-1]]
                      && key_at(f,p[i]+cur,n)==key_at(f,p[i-1]+cur,n);
            rk[p[i]]=same ? tmp : ++tmp;
        }
        if(tmp==n)
            break;
        m=tmp;   // keys for the next round lie in [0, tmp]
    }
    // Kasai's algorithm: h[rk[i+1]] >= h[rk[i]] - 1, so k only drops by one per step.
    for(int i=1,k=0;i<=n;++i)
    {
        if(k)--k;
        if(rk[i]==1)
        {
            // The smallest suffix has no predecessor; never compare it
            // against s[0..], which lies outside the string.
            k=0;
            h[1]=0;
            continue;
        }
        int j=p[rk[i]-1];
        while(s[i+k]==s[j+k])++k;   // the NUL at s[n+1] stops the scan
        h[rk[i]]=k;
    }
    // Monotonic stacks over h (named stk so it does not shadow the global a).
    stack<int> stk;
    stk.push(0);
    for(int i=1;i<=n;++i)
    {
        while(stk.top() && h[stk.top()]>h[i])stk.pop();
        l[i]=stk.top()+1;
        stk.push(i);
    }
    while(!stk.empty())stk.pop();
    stk.push(n+1);
    for(int i=n;i;--i)
    {
        while(stk.top()!=n+1 && h[stk.top()]>h[i])stk.pop();
        r[i]=stk.top()-1;
        stk.push(i);
    }
}
// Recomputes node k's min and max from its two children, 2k and 2k+1.
inline void push_up(int k)
{
    const int lc=k*2;
    const int rc=k*2+1;
    maxv[k]=max(maxv[lc],maxv[rc]);
    minv[k]=min(minv[lc],minv[rc]);
}
// Builds the subtree rooted at node k covering ranks [l, r]. A leaf stores
// p[l] as both its min and its max; inner nodes combine their children.
void build(int k,int l,int r)
{
    if(l!=r)
    {
        const int mid=(l+r)/2;
        build(2*k,l,mid);
        build(2*k+1,mid+1,r);
        push_up(k);
    }
    else
    {
        minv[k]=p[l];
        maxv[k]=p[l];
    }
}
// Smallest p value over ranks [x, y], searched within node k, which covers
// [l, r]. Callers pass a query range that overlaps [l, r].
int query_min(int k,int l,int r,int x,int y)
{
    // The node's interval lies entirely inside the query: use its aggregate.
    if(x<=l && r<=y)
        return minv[k];
    const int mid=(l+r)>>1;
    int best=inf;
    if(x<=mid)
        best=query_min(2*k,l,mid,x,y);
    if(y>mid)
        best=min(best,query_min(2*k+1,mid+1,r,x,y));
    return best;
}
// Largest p value over ranks [x, y], searched within node k, which covers
// [l, r]. Callers pass a query range that overlaps [l, r].
int query_max(int k,int l,int r,int x,int y)
{
    // The node's interval lies entirely inside the query: use its aggregate.
    if(x<=l && r<=y)
        return maxv[k];
    const int mid=(l+r)>>1;
    int best=0;
    if(x<=mid)
        best=query_max(2*k,l,mid,x,y);
    if(y>mid)
        best=max(best,query_max(2*k+1,mid+1,r,x,y));
    return best;
}
int main()
{
// freopen("in.txt","r",stdin);
int n,m,ans=0;
scanf("%s",s+1);
n=strlen(s+1);
scanf("%s",a+1);
m=strlen(a+1);
s[n+1]='$';
for(int i=1;i<=m;++i)
s[n+i+1]=a[i];
work(s);
int k=n+m+1;
build(1,1,k);
for(int i=1;i<=k;++i)
{
// printf("i = %d\n",i);
int x=l[i],y=r[i];
if(x>1)--x;
bool u=query_min(1,1,k,x,i)<=n;
bool v=query_max(1,1,k,i,y)>n+1;
if(u && v)ans=max(ans,h[i]);
u=query_min(1,1,k,i,y)<=n;
v=query_max(1,1,k,x,i)>n+1;
if(u && v)ans=max(ans,h[i]);
}
printf("%d\n",ans);
return 0;
}