求助插头dp
  • 板块学术版
  • 楼主chen_qian
  • 当前回复3
  • 已保存回复3
  • 发布时间2021/7/1 20:07
  • 上次更新2023/11/4 20:27:22
查看原帖
求助插头dp
128870
chen_qian楼主2021/7/1 20:07

这个题是 cdq 论文里的第二道例题,求矩阵的哈密顿路径有多少条,孩子人都调傻了,目前过了 6 个点,有没有大佬帮 debug ,或者 hack 也行。。

#include<bits/stdc++.h>
#define N 13
#define mod 299987
#define ll unsigned long long 
using namespace std;
int n,m;
ll head[300000],next[1<<18],val[2][1<<18],inc[N],num[2][1<<18],idx[2],now,ans;
void add(ll bit,ll v){
	int x=bit%mod+1;
	for(int i=head[x];i;i=next[i]){
		if(num[now][i]==bit){
			val[now][i]+=v;
			return ;
		}
	}
	val[now][++idx[now]]=v;
	num[now][idx[now]]=bit;
	next[idx[now]]=head[x];
	head[x]=idx[now];
}
void solve(){
	idx[now]=1,num[now][1]=0,val[now][1]=1;
	for(int i=1;i<=n;i++){
		for(int j=1;j<=idx[now];j++) num[now][j]<<=2;
		for(int j=1;j<=m;j++){
			//cout<<i<<' '<<j<<endl;
//			for(int k=1;k<=idx[now];k++){
//				ll bit=num[now][k];
//				for(int l=0;l<=m;l++){
//					cout<<(bit>>(2*l))%4<<' ';
//				}
//				cout<<val[now][k];
//				cout<<endl;
//			}
			int last=now;now^=1;
			memset(head,0,sizeof(head));
			idx[now]=0;
			for(int k=1;k<=idx[last];k++){
				ll bit=num[last][k],v=val[last][k];
				int p1=((bit>>(2*j-2))%4),p2=((bit>>(2*j))%4);
				int cnt=0;
				for(int l=0;l<=m;l++)//记录个数 
					if((bit>>(2*l))%4==3) cnt++;
				if(cnt>2) continue;
				if(!p1&&!p2){//上下没有
					if(i+1<=n&&j+1<=m) add(bit+inc[j-1]+2*inc[j],v);//普通转移 
					if(i+1<=n&&cnt<2) add(bit+inc[j-1]*3,v);//添加不匹配节点 
					if(j+1<=m&&cnt<2) add(bit+inc[j]*3,v);
				}
				else if(p1&&!p2){
					if(p1!=3){//没有一个不匹配 
						if(i+1<=n) add(bit,v);
						if(j+1<=n) add(bit-inc[j-1]*p1+inc[j]*p1,v);//常规转移 
						if(p1==1&&cnt<2){//或者将当前节点看做一个终止点,不向外额外扩展,此时对应括号变成不匹配节点, 
							int tp=1;
							for(int l=j+1;l<=m;l++){
								if((bit>>(2*l))%4==2) tp--;
								if((bit>>(2*l))%4==1) tp++;
								if(!tp){
									add(bit-inc[j-1]+inc[l],v);
									break;
								}
							}
						}
						else if(cnt<2){ 
							int tp=1;
							for(int l=j-2;l>=0;l--){
								if((bit>>(2*l))%4==2) tp++;
								if((bit>>(2*l))%4==1) tp--;
								if(!tp){
									add(bit-inc[j-1]*2+2*inc[l],v);
									break;
								}
							}
						}	
					}
					else{
						if(i==n&&j==m) ans+=v;//有一个 3 ,如果为结束节点则直接结束 
						else{//否则可以延伸这个插头 
							if(i+1<=n) add(bit,v);
							if(j+1<=m) add(bit-3*inc[j-1]+3*inc[j],v);
						}
					}
				}
				else if(!p1&&p2){//大致同上的讨论 
					if(p2!=3){
						if(i+1<=n) add(bit-inc[j]*p2+inc[j-1]*p2,v);
						if(j+1<=n) add(bit,v);
						if(p2==1&&cnt<2){
							int tp=1;
							for(int l=j+1;l<=m;l++){
								if((bit>>(2*l))%4==2) tp--;
								if((bit>>(2*l))%4==1) tp++;
								if(!tp){
									add(bit-inc[j]+inc[l],v);
									break;
								}
							}
						}
						else if(cnt<2){
							int tp=1;
							for(int l=j-2;l>=0;l--){
								if((bit>>(2*l))%4==2) tp++;
								if((bit>>(2*l))%4==1) tp--;
								if(!tp){
									add(bit-inc[j]*2+2*inc[l],v);
									break;
								}
							}
						}	
					}
					else{
						if(i==n&&j==m) ans+=v;
						else{
							if(i+1<=n) add(bit+3*inc[j-1]-3*inc[j],v); 
							if(j+1<=m) add(bit,v);
						}
					}
				}
				else if(p1==1&&p2==1){//同回路 
					int tp=1;
					for(int l=j+1;l<=m;l++){
						if((bit>>(2*l))%4==2) tp--;
						if((bit>>(2*l))%4==1) tp++;
						if(!tp){
							add(bit-inc[j]-inc[j-1]-inc[l],v);
							break;
						}
					}
				}
				else if(p1==2&&p2==2){//同回路 
					int tp=1;
					for(int l=j-2;l>=0;l--){
						if((bit>>(2*l))%4==2) tp++;
						if((bit>>(2*l))%4==1) tp--;
						if(!tp){
							add(bit-inc[j-1]*2-inc[j]*2+inc[l],v);
							break;
						}
					}
				}
				else if(p1==2&&p2==1) add(bit-inc[j-1]*2-inc[j],v);
				else if(p1==3&&p2==3){//两个终止节点,只能在最后一个匹配起来,不然后面的节点无法被遍历 
					if(i==n&&j==m) ans+=v;
				}
				else if(p1==3||p2==3){//其中一个为终止节点,只能匹配在一起,括号另一边变成终止节点 
					if(p1==3){
						if(p2==1){
							int tp=1;
							for(int l=j+1;l<=m;l++){
								if((bit>>(2*l))%4==2) tp--;
								if((bit>>(2*l))%4==1) tp++;
								if(!tp){
									add(bit+inc[l]-3*inc[j-1]-inc[j],v);
									break;
								}
							}
						}
						else{
							int tp=1;
							for(int l=j-2;l>=0;l--){
								if((bit>>(2*l))%4==2) tp++;
								if((bit>>(2*l))%4==1) tp--;
								if(!tp){
									add(bit+2*inc[l]-inc[j-1]*3-2*inc[j],v);
									break;
								}
							}
						}
					}
					if(p2==3){//同上讨论 
						if(p1==1){
							int tp=1;
							for(int l=j+1;l<=m;l++){
								if((bit>>(2*l))%4==2) tp--;
								if((bit>>(2*l))%4==1) tp++;
								if(!tp){
									add(bit+inc[l]-3*inc[j]-inc[j-1],v);
									break;
								}
							}
						}
						else{
							int tp=1;
							for(int l=j-2;l>=0;l--){
								if((bit>>(2*l))%4==2) tp++;
								if((bit>>(2*l))%4==1) tp--;
								if(!tp){
									add(bit+2*inc[l]-inc[j]*3-2*inc[j-1],v);
									break;
								}
							}
						}
					}
				}
			}
		}
	}
}
int main(){
	scanf("%d%d",&n,&m);
	if(n==1&&m==1){
		puts("1");
		return 0;
	} 
	inc[0]=1;
	for(int i=1;i<=12;i++) inc[i]=inc[i-1]<<2;
	solve();
	printf("%lld\n",ans); 
	return 0;
}
2021/7/1 20:07
加载中...