小问题:线段树不是应该开4N的大小吗,但是如果是4N会36分,8N45分,9N会73分,why?
#include<bits/stdc++.h>
using namespace std;
const int N = 100010;
struct node{
int l,r,lazy,sum;
}t[9 * N];
int n,m,root,p,a[N],size[N],dfn[N],son[N],dep[N],top[N],fa[N],dfnn;
vector<int> v[N];
void pushup(int id){
t[id].sum = (t[id * 2].sum + t[id * 2 + 1].sum) % p;
}
void pushdown(int id){
if(t[id].lazy){
t[id * 2].sum += (t[id * 2].r - t[id * 2].l + 1) * t[id].lazy;
t[id * 2 + 1].sum += (t[id * 2 + 1].r - t[id * 2 + 1].l + 1) * t[id].lazy;
t[id * 2].lazy += t[id].lazy;
t[id * 2 + 1].lazy += t[id].lazy;
t[id * 2].sum %= p;
t[id * 2 + 1].sum %= p;
t[id * 2].lazy %= p;
t[id * 2 + 1].lazy %= p;
t[id].lazy = 0;
}
}
void build(int id,int l,int r){
t[id].l = l;
t[id].r = r;
if(l == r){
// t[id].sum = a[l];
return;
}
int mid = (l + r) >> 1;
build(id * 2,l,mid);
build(id * 2 + 1,mid + 1,r);
// pushup(id);
}
void change(int id,int l,int r,int k){
if(l <= t[id].l && t[id].r <= r){
t[id].sum += k * (t[id].r - t[id].l + 1);
t[id].sum %= p;
t[id].lazy += k;
t[id].lazy %= p;
return;
}
pushdown(id);
int mid = (t[id].l + t[id].r) / 2;
if(mid >= l){
change(id * 2,l,r,k);
}
if(mid < r){
change(id * 2 + 1,l,r,k);
}
pushup(id);
}
int query(int id,int l,int r){
if(l <= t[id].l && t[id].r <= r){
return t[id].sum;
}
pushdown(id);
int ans = 0;
int mid = (t[id].l + t[id].r) / 2;
if(mid >= l){
ans += query(id * 2,l,r);
ans %= p;
}
if(mid < r){
ans += query(id * 2 + 1,l,r);
ans %= p;
}
return ans % p;
}
void dfs1(int u,int fat,int deep){
size[u] = 1;
fa[u] = fat;
dep[u] = deep;
int ma = 0;
for(int i = 0;i < v[u].size();i++){
if(v[u][i] != fat){
dfs1(v[u][i],u,deep + 1);
size[u] += size[v[u][i]];
if(size[v[u][i]] > ma){
ma = size[v[u][i]];
son[u] = v[u][i];
}
}
}
}
void dfs2(int u,int chain){
top[u] = chain;
dfn[u] = ++dfnn;
if(!son[u]){
return;
}
dfs2(son[u],chain);
for(int i = 0;i < v[u].size();i++){
if(v[u][i] != fa[u] && v[u][i] != son[u]){
dfs2(v[u][i],v[u][i]);
}
}
}
void solve1(int u,int v,int k){
k %= p;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]){
swap(u,v);
}
change(root,dfn[top[u]],dfn[u],k);
u = fa[top[u]];
}
if(dfn[u] > dfn[v]){
swap(u,v);
}
change(root,dfn[u],dfn[v],k);
}
int solve2(int u,int v){
int ans = 0;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]){
swap(u,v);
}
ans += query(root,dfn[top[u]],dfn[u]);
ans %= p;
u = fa[top[u]];
}
if(dfn[u] > dfn[v]){
swap(u,v);
}
ans += query(root,dfn[u],dfn[v]);
ans %= p;
return ans % p;
}
int main(){
cin >> n >> m >> root >> p;
for(int i = 1;i <= n;i++){
cin >> a[i];
}
for(int i = 1;i <= n - 1;i++){
int x,y;
cin >> x >> y;
v[x].push_back(y);
v[y].push_back(x);
}
dfs1(root,0,1);
dfs2(root,root);
build(root,1,n);
for(int i = 1;i <= n;i++){
change(root,dfn[i],dfn[i],a[i]);
}
// for(int i = 1;i <= n;i++){
// cout << query(root,dfn[i],dfn[i]) << " ";
// }
while(m--){
int op,x,y,z;
cin >> op >> x;
if(op == 1){
cin >> y >> z;
solve1(x,y,z);
}else if(op == 2){
cin >> y;
cout << solve2(x,y) << endl;
}else if(op == 3){
cin >> z;
change(root,dfn[x],dfn[x] + size[x] - 1,z);
}else{
cout << query(root,dfn[x],dfn[x] + size[x] - 1) << endl;
}
}
return 0;
}