O(n2) 暴力 dp,前 4 个点挂,很奇怪。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e3+5;
int T,n,a[N];
ll f[N][N][2];
int main(){
scanf("%d",&T);
while(T--){
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
if(n==3){
if(a[1]!=a[2]&&a[2]!=a[3]&&a[1]!=a[3]) puts("0");
else if(a[1]==a[2]&&a[2]==a[3]) printf("%d\n",2*a[2]);
else if(a[1]==a[2]||a[2]==a[3]) printf("%d\n",a[2]);
else printf("%d\n",a[3]);
continue;
}
memset(f,0,sizeof(f));
for(int i=1;i<=n;i++)
for(int j=1;j<i;j++){
if(j!=i-1){
f[i][j][0]=max(f[i][j][0],f[i-1][j][0]+(a[i]==a[i-1])*a[i]);
f[i][j][1]=max(f[i][j][1],f[i-1][j][1]+(a[i]==a[i-1])*a[i]);
}
else{
for(int k=1;k<j;k++)
if(k!=j-1)
f[i][j][0]=max(f[i][j][0],f[j][k][1]+(a[i]==a[k])*a[i]);
else
for(int l=1;l<k;l++)
f[i][j][0]=max(f[i][j][0],f[k][l][0]+(a[i]==a[k])*a[i]+(a[j]==a[l])*a[j]);
for(int k=1;k<j;k++)
if(k!=j-1)
f[i][j][1]=max(f[i][j][1],f[j][k][0]+(a[i]==a[k])*a[i]);
else
for(int l=1;l<k;l++)
f[i][j][1]=max(f[i][j][1],f[k][l][1]+(a[i]==a[k])*a[i]+(a[j]==a[l])*a[j]);
}
}
ll ans=0;
for(int i=1;i<n;i++)
ans=max({ans,f[n][i][0],f[n][i][1]});
printf("%lld\n",ans);
}
return 0;
}