为什么遍历树时先全部dfs子树一遍,再进行背包能AC,但一边dfs一遍做背包就会WA?
AC代码:
#include<bits/stdc++.h>
#define Maxn 10005
#define Maxt 50005
using namespace std;
inline int read()
{
int x=1,f=0;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')x=-1;c=getchar();}
while(c>='0'&&c<='9'){f=f*10+c-'0';c=getchar();}
return x*f;
}
inline void out(int x)
{
if(x<0) putchar('-'),x=-x;
if(x>=10) out(x/10);
putchar(x%10+'0');
}
inline void print(int x,char c)
{
out(x),putchar(c);
}
int End[2*Maxn],Next[2*Maxn],Last[Maxn];
int f[Maxn],dp[Maxt],A[Maxn],n,cnt;
void addedge(int x,int y)
{
End[++cnt]=y;
Next[cnt]=Last[x];
Last[x]=cnt;
}
void dfs(int u,int fa)
{
// printf("u=%d fa=%d\n",u,fa);
for(int i=Last[u];i;i=Next[i])
{
int v=End[i];
if(v==fa) continue;
dfs(v,u);
}
memset(dp,0x3f,sizeof(dp));
dp[0]=0;
for(int i=Last[u];i;i=Next[i])
{
int v=End[i];
if(v==fa) continue;
for(int j=A[u];j>=0;j--)
{
int tmp=0x3f3f3f3f;
if(j>=A[v]) tmp=min(tmp,dp[j-A[v]]+f[v]);
if(j>=f[v]) tmp=min(tmp,dp[j-f[v]]+A[v]);
dp[j]=tmp;
}
}
for(int i=A[u];i>=0;i--)
{
f[u]=min(f[u],dp[i]);
}
return;
}
int main()
{
memset(f,0x3f,sizeof(f));
n=read();
for(int i=2;i<=n;i++)
{
int x=read();
addedge(x,i);
}
for(int i=1;i<=n;i++) A[i]=read();
dfs(1,0);
// for(int i=1;i<=n;i++) printf("f[%d]=%d\n",i,f[i]);
if (f[1]<0x3f3f3f3f) printf("POSSIBLE\n");
else printf("IMPOSSIBLE\n");
return 0;
}
WA代码:
#include<bits/stdc++.h>
#define Maxn 10005
#define Maxt 50005
using namespace std;
inline int read()
{
int x=1,f=0;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')x=-1;c=getchar();}
while(c>='0'&&c<='9'){f=f*10+c-'0';c=getchar();}
return x*f;
}
inline void out(int x)
{
if(x<0) putchar('-'),x=-x;
if(x>=10) out(x/10);
putchar(x%10+'0');
}
inline void print(int x,char c)
{
out(x),putchar(c);
}
int End[2*Maxn],Next[2*Maxn],Last[Maxn];
int f[Maxn],dp[Maxt],A[Maxn],n,cnt;
void addedge(int x,int y)
{
End[++cnt]=y;
Next[cnt]=Last[x];
Last[x]=cnt;
}
void dfs(int u,int fa)
{
// printf("u=%d fa=%d\n",u,fa);
memset(dp,0x3f,sizeof(dp));
dp[0]=0;
for(int i=Last[u];i;i=Next[i])
{
int v=End[i];
if(v==fa) continue;
dfs(v,u);
for(int j=A[u];j>=0;j--)
{
int tmp=0x3f3f3f3f;
if(j>=A[v]) tmp=min(tmp,dp[j-A[v]]+f[v]);
if(j>=f[v]) tmp=min(tmp,dp[j-f[v]]+A[v]);
dp[j]=tmp;
}
}
for(int i=A[u];i>=0;i--)
{
f[u]=min(f[u],dp[i]);
}
return;
}
int main()
{
memset(f,0x3f,sizeof(f));
n=read();
for(int i=2;i<=n;i++)
{
int x=read();
addedge(x,i);
}
for(int i=1;i<=n;i++) A[i]=read();
dfs(1,0);
// for(int i=1;i<=n;i++) printf("f[%d]=%d\n",i,f[i]);
if (f[1]<0x3f3f3f3f) printf("POSSIBLE\n");
else printf("IMPOSSIBLE\n");
return 0;
}