四个小时才过样例 (~﹃~)~zZ
#include <bits/stdc++.h>
using namespace std;
const int N=2e5;
struct edge{
int u,v,w,c,nxt;
}e[2*N+5];
struct node{
int sum,l,r;
}tr[4*N+5];
struct node1{
int sum,ls,rs,lz;
}gs[41*N+5],z[41*N+5];
int n,q;
int head[N+5],tot=1;
int sz[N+5],mxsn[N+5],fa[N+5],dep[N+5];
int dfn[N+5],top[N+5],w[N+5],sx;
int bv[N+5],rt[N+5],cnt,rtt[N+5],cntt;
void add(int u,int v,int w,int col){
e[tot].u=u;
e[tot].v=v;
e[tot].w=w;
e[tot].c=col;
e[tot].nxt=head[u];
head[u]=tot++;
return;
}
void dfs1(int x,int faa){
sz[x]=1;
for(int i=head[x];i;i=e[i].nxt){
int y=e[i].v;
if(y==faa) continue;
fa[y]=x;
dep[y]=dep[x]+1;
dfs1(y,x);
sz[x]+=sz[y];
if(sz[mxsn[x]]<sz[y]) mxsn[x]=y;
}
return;
}
void dfs2(int x,int faa){
top[x]=faa;
dfn[x]=++sx;
w[sx]=x;
if(!mxsn[x]) return;
dfs2(mxsn[x],faa);
for(int i=head[x];i;i=e[i].nxt){
int y=e[i].v;
if(y==fa[x] || y==mxsn[x]) continue;
dfs2(y,y);
}
return;
}
void updata(int p){
tr[p].sum=tr[2*p].sum+tr[2*p+1].sum;
tr[p].l=tr[2*p].l;
tr[p].r=tr[2*p+1].r;
}
void build(int p,int l,int r){
if(l==r){
tr[p].l=l;
tr[p].r=r;
tr[p].sum=bv[l];
return;
}
int mid=(l+r)>>1;
build(2*p,l,mid);
build(2*p+1,mid+1,r);
updata(p);
return;
}
int search(int p,int l,int r){
int s=0;
if(tr[p].l==l && tr[p].r==r){
return tr[p].sum;
}
if(tr[2*p].r>=r) s+=search(2*p,l,r);
else if(tr[2*p+1].l<=l) s+=search(2*p+1,l,r);
else{
s+=search(2*p,l,tr[2*p].r);
s+=search(2*p+1,tr[2*p+1].l,r);
}
return s;
}
void updatags(int p){
gs[p].sum=gs[gs[p].ls].sum+gs[gs[p].rs].sum;
return;
}
void changegs(int &p,int l,int r,int ll,int rr,int w){
if(!p) p=++cnt;
if(ll<=l && rr>=r){
gs[p].lz+=w;
gs[p].sum+=(r-l+1)*w;
return;
}
int mid=(l+r)>>1;
if(ll<=mid) changegs(gs[p].ls,l,mid,ll,rr,w);
if(rr>=mid+1) changegs(gs[p].rs,mid+1,r,ll,rr,w);
updata(p);
return;
}
void pushdowngs(int p,int l,int r){
if(gs[p].ls==0) gs[p].ls=++cnt;
if(gs[p].rs==0) gs[p].rs=++cnt;
int mid=(l+r)>>1;
gs[gs[p].ls].sum+=(mid-l+1)*gs[p].lz;
gs[gs[p].rs].sum+=(r-mid)*gs[p].lz;
gs[gs[p].ls].lz+=gs[p].lz;
gs[gs[p].rs].lz+=gs[p].lz;
return;
}
int searchgs(int p,int l,int r,int ll,int rr){
int s=0;
if(l==ll && r==rr) return gs[p].sum;
if(!p) return 0;
if(gs[p].lz!=0) pushdowngs(p,l,r);
int mid=(l+r)>>1;
if(rr<=mid){
s+=searchgs(gs[p].ls,l,mid,ll,rr);
}
else if(ll>=mid+1){
s+=searchgs(gs[p].rs,mid+1,r,ll,rr);
}
else{
s+=searchgs(gs[p].ls,l,mid,ll,mid);
s+=searchgs(gs[p].rs,mid+1,r,mid+1,rr);
}
return s;
}
void updataz(int p){
z[p].sum=z[z[p].ls].sum+z[z[p].rs].sum;
return;
}
void changez(int &p,int l,int r,int ll,int rr,int w){
if(!p) p=++cntt;
if(ll<=l && rr>=r){
z[p].lz+=w;
z[p].sum+=(r-l+1)*w;
return;
}
int mid=(l+r)>>1;
if(ll<=mid) changez(z[p].ls,l,mid,ll,rr,w);
if(rr>=mid+1) changez(z[p].rs,mid+1,r,ll,rr,w);
updata(p);
return;
}
void pushdownz(int p,int l,int r){
if(z[p].ls==0) z[p].ls=++cntt;
if(z[p].rs==0) z[p].rs=++cntt;
int mid=(l+r)>>1;
z[z[p].ls].sum+=(mid-l+1)*z[p].lz;
z[z[p].rs].sum+=(r-mid)*z[p].lz;
z[z[p].ls].lz+=z[p].lz;
z[z[p].rs].lz+=z[p].lz;
return;
}
int searchz(int p,int l,int r,int ll,int rr){
int s=0;
if(l==ll && r==rr) return z[p].sum;
if(!p) return 0;
if(z[p].lz!=0) pushdownz(p,l,r);
int mid=(l+r)>>1;
if(rr<=mid){
s+=searchz(z[p].ls,l,mid,ll,rr);
}
else if(ll>=mid+1){
s+=searchz(z[p].rs,mid+1,r,ll,rr);
}
else{
s+=searchz(z[p].ls,l,mid,ll,mid);
s+=searchz(z[p].rs,mid+1,r,mid+1,rr);
}
return s;
}
int work(int c,int w,int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
int lb=dfn[top[x]],rb=dfn[x];
if(lb<=rb) ans+=search(1,lb,rb)-searchz(rtt[c],1,n,lb,rb)+searchgs(rt[c],1,n,lb,rb)*w;
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
ans+=search(1,dfn[y]+1,dfn[x])-searchz(rtt[c],1,n,dfn[y]+1,dfn[x])+searchgs(rt[c],1,n,dfn[y]+1,dfn[x])*w;
return ans;
}
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<n;i++){
int a,b,c,d;
scanf("%d%d%d%d",&a,&b,&c,&d);
add(a,b,d,c),add(b,a,d,c);
}
dep[1]=1;
dfs1(1,0);
dfs2(1,1);
for(int i=1;i<=(n-1)*2;i+=2){
int u=e[i].u,v=e[i].v;
if(dep[u]>dep[v]) swap(u,v);
bv[dfn[v]]=e[i].w;
changegs(rt[e[i].c],1,n,dfn[v],dfn[v],1);
changez(rtt[e[i].c],1,n,dfn[v],dfn[v],e[i].w);
}
build(1,1,n);
while(q--){
int c,w,x,y;
scanf("%d%d%d%d",&c,&w,&x,&y);
printf("%d\n",work(c,w,x,y));
}
return 0;
}