有人想要来一起写道概率题的std吗
  • 板块学术版
  • 楼主封禁用户
  • 当前回复1
  • 已保存回复1
  • 发布时间2021/11/6 15:38
  • 上次更新2023/11/4 01:16:47
查看原帖
有人想要来一起写道概率题的std吗
373037
封禁用户楼主2021/11/6 15:38

题目是这样的
而这道题又是基于这个简单版本的。
我对简单版给出了下图所示的题解(虽然还没写完但是要点应该够了):
并且在简单版中和同学对拍顺利后,我交上了这样的std代码:简单版 std
(码风别喷,谢谢)

然后呢?我开始研究plus版本。
直接暴力的复杂度约为O(n^2^k)级别,不可用。
原定的正解思路是二分,先枚举提问序列中间的数字,然后对提问序列左右两边分别递归求解,这样枚举大约是O(n^k)的复杂度。(事实上原本只有k=3所以O(n^3)能过我就不想更优的了)
然而出题人手贱把3改成了k。
经过半天的挣扎,我形成了一套区间dp思路,这样的复杂度约为O(k n^3)。

所以这道题的思路:
定义dp[i][j][r]表示在区间[i,j]之间提问r次能够猜对的概率的最大值。
初始化:dp[i][i][0]=1,这是因为区间长度为1时一定能猜对。同理,dp[i][j][0]=1.0/(j-i+1),证明也可以在上图找到。
然后状态转移方程:

dp[i][j][r]=
max{ (s[a-1]-s[i-1])*dp[i][a-1][r-1] + (s[a]-s[a-1])*dp[a][a][0] + (s[j]-s[a])*dp[a+1][j][r-1] }/(s[j]-s[i-1])

事实上,上述状态转移方程只是个通常情况,具体的话……如果谁愿意就看我接下来的代码好了。
我自己手造了几个数据,用std和O(n^2^k)暴力做了一下,发现当k>1时,std给出的答案中概率总是略大于暴力给出的答案。并且,其中一个暴力给出的答案经过人工计算是正确的。
因此……啊我深感自己的无能,希望有人有兴趣来帮我一起写这个std,或者赏光查查我的std代码有什么问题。不尽感激!

std代码(待查错):

#include<bits/stdc++.h>
using namespace std;
template<typename type>
bool update(type &todo,const type &standard){
	if(todo<standard){
		todo=standard;
		return 1;
	}
	return 0;
}
int n,k;
vector<double> s;
vector<vector<vector<int> > > x;
vector<vector<vector<double> > > dp;
void putx(int l=1,int r=n,int rest=k){
	if(rest){
		int now=x[l][r][rest];
		putx(l,now-1,rest-1);
		cout<<now<<" ";
		putx(now+1,r,rest-1);
	}
}
int main(){
	cin>>n>>k;
	s.resize(n+1);
	x.resize(n+1);
	dp.resize(n+1);
	for(int i=1;i<=n;i++){
		dp[i].resize(n+1);
		x[i].resize(n+1);
		int a,b;
		scanf("%d%d",&a,&b);
		s[i]=s[i-1]+1.0*a/b;
		for(int j=i;j<=n;j++){
			dp[i][j].resize(k+1);
			x[i][j].resize(k+1);
			dp[i][j][0]=1.0/(j-i+1);
		}
	}
	for(int i=1;i<=n;i++)
		for(int j=i;j<=n;j++)
			dp[i][j][0]=1.0/(j-i+1);
	#define j i+l
	for(int l=1;l<=n-1;l++)
		for(int i=1;i<=n-l;i++){
			if(update(dp[i][j][1],((s[i]-s[i-1])*dp[i][i][0]+(s[j]-s[i])*dp[i+1][j][0])/(s[j]-s[i-1])))
				x[i][j][1]=i;
			for(int a=i+1;a<=j-1;a++)
				if(update(dp[i][j][1],((s[a-1]-s[i-1])*dp[i][a-1][0]+(s[a]-s[a-1])*dp[a][a][0]+(s[j]-s[a])*dp[a+1][j][0])/(s[j]-s[i-1])))
					x[i][j][1]=a;
			if(update(dp[i][j][1],((s[j]-s[j-1])*dp[j][j][0]+(s[j-1]-s[i-1])*dp[i][j-1][0])/(s[j]-s[i-1])))
				x[i][j][1]=j;
		}
	for(int r=2;r<=k;r++)
		for(int l=(1<<r+1)-2;l<=n-1;l++)
			for(int i=1;i<=n-l;i++)
				for(int a=i+(1<<r-1)-1;a<=j-(1<<r-1)+1;a++)
					if(update(dp[i][j][r],((s[a-1]-s[i-1])*dp[i][a-1][r-1]+(s[a]-s[a-1])*dp[a][a][0]+(s[j]-s[a])*dp[a+1][j][r-1])/(s[j]-s[i-1])))
						x[i][j][r]=a;
	#undef j
//	for(int r=0;r<=k;r++){
//		for(int i=1;i<=n;i++){
//			for(int j=i;j<=n;j++){
//				cout<<ends<<dp[i][j][r];
//			}
//			cout<<endl;
//		}
//		cout<<endl;
//	}
	putx();
	printf("\n%.4lf",dp[1][n][k]);
}

附:我自己用的暴力程序:

#include<bits/stdc++.h>
using namespace std;
#define udm unordered_map
template<typename type>
bool update(type &todo,const type &standard){
	if(todo<standard){
		todo=standard;
		return 1;
	}
	return 0;
}
int n,k;
double answer;
vector<int> ans;
vector<double> p,s;
int main(){
	ios::sync_with_stdio(0);
	cin>>n>>k;
	s.resize(n+1);
	p.resize(n+1);
	if(k^2)
		return 0;
	for(int i=1;i<=n;i++){
		int a,b;
		cin>>a>>b;
		p[i]=1.0*a/b;
		s[i]=s[i-1]+p[i];
	}
	for(int a=2;a<=n;a++)
		for(int b=a+2;b<=n;b++)
			for(int c=b+2;c<=n;c++)
//				for(int d=c+2;d<=n;d++)
//					for(int e=d+2;e<=n;e++)
//						for(int f=e+2;f<=n;f++)
//							for(int g=f+2;g<=n;g++)
								if(update(answer,s[a-1]/(a-1)+p[a]+(s[b-1]-s[a])/(b-a-1)+p[b]+(s[c-1]-s[b])/(c-b-1)+p[c]+(s[n]-s[c])/(n-c)))
									ans=vector<int>({a,b,c});
//								if(update(answer,s[a-1]/(a-1)+p[a]+(s[b-1]-s[a])/(b-a-1)+p[b]+(s[c-1]-s[b])/(c-b-1)+p[c]+(s[d-1]-s[c])/(d-c-1)+p[d]+(s[e-1]-s[d])/(e-d-1)+p[e]+(s[f-1]-s[e])/(f-e-1)+p[f]+(s[g-1]-s[f])/(g-f-1)+p[g]+(s[n]-s[g])/(n-g)))
//									ans=vector<int>({a,b,c,d,e,f,g});
	for(auto i:ans)
		cout<<ends<<i;
	cout<<endl<<answer;
	return 0;
}
2021/11/6 15:38
加载中...