下面这份代码在寻找最优点时用的是暴力方法（沿原树逐步走向代价更小的相邻点），结果侥幸通过了；但我把寻找最优点的方式改成沿点分树的重心逐层向下查找之后，程序就出错了，而我只修改了这一处。
// by longdie
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity of every per-vertex array; vertex ids must stay below N
// (presumably n <= 1e5 in the problem statement — confirm).
const int N = 1e5 + 5;
// Reads one signed decimal integer from stdin.
// Non-digit characters are skipped; a '-' seen before the first digit makes the result negative.
// The defaulted parameters are scratch state, so callers simply write read().
// ch is an int (not char) so EOF stays distinguishable: at end of input the function returns
// instead of spinning forever, and isdigit() never receives a negative char value (UB).
inline int read(int s = 0, int f = 1, int ch = getchar()) {
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * f;
}
// n: vertex count, Q: number of update+query lines.
// head/cnt/e: adjacency list of the input tree (each edge stored in both directions, ids from 1).
// dis[v]: weighted distance from vertex 1 to v (filled by a.dfs0).
// rt, Max, totsize, siz: scratch state used by dfs0 while locating a centroid; rt0: centroid found in main.
// vis[v]: v has already been used as a separator; fa[v]: parent of v in the centroid tree (0 at its root).
// xin[v]: centroid picked for the component entered through neighbour v.
// NOTE(review): xin is indexed by vertex only, so deeper decomposition levels overwrite it
// (a vertex can be a neighbour of several separators); it is never read in this program.
int n, Q, head[N], cnt, dis[N], rt, rt0, Max, totsize, vis[N], siz[N], fa[N], xin[N];
// num[c]: total weight inside c's centroid subtree; f1[c]: sum of weight * dist(w, c) over that subtree;
// f2[c]: the same sums measured to fa[c], subtracted in query() to avoid double counting.
ll f1[N], f2[N], num[N];
// Edge record: target vertex, next edge id in the source's list, edge weight.
struct edge { int to, next, w; } e[N<<1];
// Pushes the directed edge x -> y (weight z) onto the front of x's adjacency list.
inline void add(int x, int y, int z) {
    ++cnt;
    e[cnt].to = y;
    e[cnt].next = head[x];
    e[cnt].w = z;
    head[x] = cnt;
}
// Heavy-light decomposition of the input tree, used for O(log n) LCA queries.
// Side effect: dfs0 also fills the global dis[] with weighted distances from the root it is given.
// NOTE(review): both passes are recursive, so a path-shaped tree recurses n levels deep.
struct LCA {
    int fa[N], siz[N], dep[N], top[N], son[N];  // parent, subtree size, depth, chain head, heavy child
    // First pass: parent, depth, root distance, subtree size and heavy child of every vertex under u.
    void dfs0(int u, int f) {
        fa[u] = f, siz[u] = 1;
        for (int i = head[u], v; i; i = e[i].next) {
            v = e[i].to;
            if (v == f) continue;
            dep[v] = dep[u] + 1;
            dis[v] = dis[u] + e[i].w;
            dfs0(v, u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]]) son[u] = v;  // son[u] starts at 0 and siz[0] stays 0
        }
    }
    // Second pass: chain heads. The heavy child continues chain k; every light child starts its own chain.
    void dfs1(int u, int k) {
        top[u] = k;
        if (son[u]) dfs1(son[u], k);
        for (int i = head[u], v; i; i = e[i].next) {
            v = e[i].to;
            if (v == fa[u] || v == son[u]) continue;
            dfs1(v, v);
        }
    }
    // LCA: repeatedly lift the endpoint whose chain head is deeper until both lie on one chain.
    int get_lca(int x, int y) {
        while (top[x] != top[y]) {
            if (dep[top[x]] < dep[top[y]]) swap(x, y);
            x = fa[top[x]];
        }
        return dep[x] < dep[y] ? x : y;
    }
} a;
// Weighted length of the tree path x..y, from root distances and their lowest common ancestor.
inline int get_dis(int x, int y) {
    const int meet = a.get_lca(x, y);
    return (dis[x] - dis[meet]) + (dis[y] - dis[meet]);
}
// Centroid search over the not-yet-separated component containing u (f = DFS parent, 0 at the start).
// Fills siz[] for that component. For each vertex, its largest piece is the bigger of its largest
// child subtree and the totsize - siz[u] part above it. rt is set to the vertex whose largest piece
// is the smallest seen so far, provided it is strictly below Max.
// The caller sets totsize and Max beforehand.
void dfs0(int u, int f) {
    int heaviest = 0;
    siz[u] = 1;
    for (int edge_id = head[u]; edge_id != 0; edge_id = e[edge_id].next) {
        const int to = e[edge_id].to;
        if (to != f && !vis[to]) {
            dfs0(to, u);
            siz[u] += siz[to];
            if (siz[to] > heaviest) heaviest = siz[to];
        }
    }
    const int above = totsize - siz[u];
    if (above > heaviest) heaviest = above;
    if (heaviest < Max) {
        Max = heaviest;
        rt = u;
    }
}
// Builds the centroid (decomposition) tree below separator u, whose parent separator is f.
// Marks u as used. Then, for each unvisited neighbour v: size v's component, find its centroid
// with dfs0, and recurse on that centroid (the recursive call sets its fa[] to u).
// Expects siz[] to hold the values left by the dfs0 call that selected u, and totsize to be the
// size of u's component.
void dfs1(int u, int f) {
    fa[u] = f, vis[u] = 1;
    const int now = totsize;
    const int su = siz[u];  // unchanged by the loop: recursion only touches other components
    for (int i = head[u], v; i; i = e[i].next) {
        v = e[i].to;
        if (v == f || vis[v]) continue;
        // In the DFS that produced siz[], v is a child of u exactly when siz[v] < siz[u];
        // otherwise v is u's DFS parent and its side has now - siz[u] vertices.
        // (The former test siz[v] < now misclassified a non-root parent and mis-sized its side.)
        if (siz[v] < su) totsize = siz[v];
        else totsize = now - su;
        Max = totsize, rt = 0, dfs0(v, 0);
        xin[v] = rt;  // NOTE(review): per-vertex slot, overwritten at deeper levels; unused here
        dfs1(rt, u);
    }
}
// Adds y to the weight of vertex x and refreshes the aggregates of every centroid ancestor of x.
void change(int x, ll y) {
    num[x] += y;
    int child = x;
    while (fa[child] != 0) {
        const int anc = fa[child];
        const int d = get_dis(x, anc);
        num[anc] += y;
        f1[anc] += y * d;
        f2[child] += y * d;
        child = anc;
    }
}
// Returns the sum of weight(w) * dist(w, u) over all weighted vertices w, assembled from the
// centroid ancestors of u. f2 removes the part of an ancestor's sum already covered below it.
ll query(int u) {
    ll total = f1[u];
    for (int cur = u; fa[cur] != 0; cur = fa[cur]) {
        const int up = fa[cur];
        total += f1[up] - f2[cur];
        total += (num[up] - num[cur]) * get_dis(up, u);
    }
    return total;
}
// Local search on the ORIGINAL tree (the only part that differs from the other version).
// Starting at u (f = vertex we arrived from, 0 initially), move to the first neighbour whose cost
// is strictly smaller. Return the cost at the first vertex where no neighbour improves.
// Written as a loop instead of recursion; the walk may still take O(n) steps.
ll get(int u, int f) {
    for (;;) {
        const ll here = query(u);
        int next = 0;  // vertex ids start at 1, so 0 means "no better neighbour"
        for (int i = head[u]; i; i = e[i].next) {
            const int v = e[i].to;
            if (v != f && query(v) < here) { next = v; break; }
        }
        if (next == 0) return here;
        f = u;
        u = next;
    }
}
// Reads the tree, builds the LCA structure and the centroid decomposition, then handles each
// line "u v": add v to u's weight and print the minimum weighted distance sum.
signed main() {
    n = read(), Q = read();
    for (int i = 1, x, y, z; i < n; ++i) {
        x = read(), y = read(), z = read();
        add(x, y, z), add(y, x, z);
    }
    a.dfs0(1, 0), a.dfs1(1, 1);
    Max = n, totsize = n, dfs0(1, 0);
    rt0 = rt;
    // Root the decomposition at the real centroid rt0. The former dfs1(1, 0) made vertex 1 the top
    // separator, so rt0 was NOT the root of the centroid tree described by fa[]. Any centroid-tree
    // descent started from rt0 then only explores rt0's own component, which is one reason
    // replacing get() with such a descent produced wrong answers.
    dfs1(rt0, 0);
    while (Q--) {
        int u = read(), v = read();
        change(u, v);
        ll res = get(rt0, 0);
        printf("%lld\n", res);
    }
    return 0;
}
正解（改为在点分树上逐层按重心查找的版本）：
// by longdie
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity of every per-vertex array; vertex ids must stay below N
// (presumably n <= 1e5 in the problem statement — confirm).
const int N = 1e5 + 5;
// Reads one signed decimal integer from stdin.
// Non-digit characters are skipped; a '-' seen before the first digit makes the result negative.
// The defaulted parameters are scratch state, so callers simply write read().
// ch is an int (not char) so EOF stays distinguishable: at end of input the function returns
// instead of spinning forever, and isdigit() never receives a negative char value (UB).
inline int read(int s = 0, int f = 1, int ch = getchar()) {
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * f;
}
// g[u][i]: i-th original-tree neighbour of separator u that was still unvisited (inside u's component);
// G[u][i]: centroid chosen for the sub-component entered through g[u][i] (a child of u in the centroid tree).
vector<int> g[N], G[N];
// n: vertex count, Q: number of update+query lines; head/cnt/e: adjacency list of the input tree.
// cnt starts at 1, so the two directions of an edge get ids 2k and 2k+1 (partners via i ^ 1).
// dis[v]: weighted distance from vertex 1; rt, Max, totsize, siz: dfs0 scratch; rt0: centroid from main.
// vis[v]: v already used as a separator; fa[v]: parent in the centroid tree (0 at its root).
// ji[i]: set for edges whose far end was already a separator when dfs1 scanned them; never read here.
int n, Q, head[N], cnt=1, dis[N], rt, rt0, Max, totsize, vis[N], siz[N], fa[N], ji[N<<1];
// num[c]: total weight in c's centroid subtree; f1[c]: sum of weight * dist(w, c) over it;
// f2[c]: the same sums measured to fa[c], subtracted in query() to avoid double counting.
ll f1[N], f2[N], num[N];
// Edge record: target vertex, next edge id in the source's list, edge weight.
struct edge { int to, next, w; } e[N<<1];
// Pushes the directed edge x -> y (weight z) onto the front of x's adjacency list.
inline void add(int x, int y, int z) {
    ++cnt;
    e[cnt].to = y;
    e[cnt].next = head[x];
    e[cnt].w = z;
    head[x] = cnt;
}
// Heavy-light decomposition of the input tree, used for O(log n) LCA queries.
// Side effect: dfs0 also fills the global dis[] with weighted distances from the root it is given.
// NOTE(review): both passes are recursive, so a path-shaped tree recurses n levels deep.
struct LCA {
    int fa[N], siz[N], dep[N], top[N], son[N];  // parent, subtree size, depth, chain head, heavy child
    // First pass: parent, depth, root distance, subtree size and heavy child of every vertex under u.
    void dfs0(int u, int f) {
        fa[u] = f, siz[u] = 1;
        for (int i = head[u], v; i; i = e[i].next) {
            v = e[i].to;
            if (v == f) continue;
            dep[v] = dep[u] + 1;
            dis[v] = dis[u] + e[i].w;
            dfs0(v, u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]]) son[u] = v;  // son[u] starts at 0 and siz[0] stays 0
        }
    }
    // Second pass: chain heads. The heavy child continues chain k; every light child starts its own chain.
    void dfs1(int u, int k) {
        top[u] = k;
        if (son[u]) dfs1(son[u], k);
        for (int i = head[u], v; i; i = e[i].next) {
            v = e[i].to;
            if (v == fa[u] || v == son[u]) continue;
            dfs1(v, v);
        }
    }
    // LCA: repeatedly lift the endpoint whose chain head is deeper until both lie on one chain.
    int get_lca(int x, int y) {
        while (top[x] != top[y]) {
            if (dep[top[x]] < dep[top[y]]) swap(x, y);
            x = fa[top[x]];
        }
        return dep[x] < dep[y] ? x : y;
    }
} a;
// Weighted length of the tree path x..y, from root distances and their lowest common ancestor.
inline int get_dis(int x, int y) {
    const int meet = a.get_lca(x, y);
    return (dis[x] - dis[meet]) + (dis[y] - dis[meet]);
}
// Centroid search over the not-yet-separated component containing u (f = DFS parent, 0 at the start).
// Fills siz[] for that component. For each vertex, its largest piece is the bigger of its largest
// child subtree and the totsize - siz[u] part above it. rt is set to the vertex whose largest piece
// is the smallest seen so far, provided it is strictly below Max.
// The caller sets totsize and Max beforehand.
void dfs0(int u, int f) {
    int heaviest = 0;
    siz[u] = 1;
    for (int edge_id = head[u]; edge_id != 0; edge_id = e[edge_id].next) {
        const int to = e[edge_id].to;
        if (to != f && !vis[to]) {
            dfs0(to, u);
            siz[u] += siz[to];
            if (siz[to] > heaviest) heaviest = siz[to];
        }
    }
    const int above = totsize - siz[u];
    if (above > heaviest) heaviest = above;
    if (heaviest < Max) {
        Max = heaviest;
        rt = u;
    }
}
// Builds the centroid (decomposition) tree below separator u, whose parent separator is f.
// Marks u as used. For each unvisited neighbour v: size v's component, find its centroid with dfs0,
// record the pair (v, centroid) in g[u]/G[u] for get(), and recurse on that centroid.
// Expects siz[] to hold the values left by the dfs0 call that selected u, and totsize to be the
// size of u's component.
void dfs1(int u, int f) {
    fa[u] = f, vis[u] = 1;
    const int now = totsize;
    const int su = siz[u];  // unchanged by the loop: recursion only touches other components
    for (int i = head[u], v; i; i = e[i].next) {
        v = e[i].to;
        if (vis[v]) { ji[i] = ji[i^1] = 1; continue; }
        // In the DFS that produced siz[], v is a child of u exactly when siz[v] < siz[u];
        // otherwise v is u's DFS parent and its side has now - siz[u] vertices.
        // (The former test siz[v] < now misclassified a non-root parent and mis-sized its side.)
        if (siz[v] < su) totsize = siz[v];
        else totsize = now - su;
        Max = totsize, rt = 0, dfs0(v, 0);
        g[u].push_back(v), G[u].push_back(rt);
        dfs1(rt, u);
    }
}
// Adds y to the weight of vertex x and refreshes the aggregates of every centroid ancestor of x.
void change(int x, ll y) {
    num[x] += y;
    int child = x;
    while (fa[child] != 0) {
        const int anc = fa[child];
        const int d = get_dis(x, anc);
        num[anc] += y;
        f1[anc] += y * d;
        f2[child] += y * d;
        child = anc;
    }
}
// Returns the sum of weight(w) * dist(w, u) over all weighted vertices w, assembled from the
// centroid ancestors of u. f2 removes the part of an ancestor's sum already covered below it.
ll query(int u) {
    ll total = f1[u];
    for (int cur = u; fa[cur] != 0; cur = fa[cur]) {
        const int up = fa[cur];
        total += f1[up] - f2[cur];
        total += (num[up] - num[cur]) * get_dis(up, u);
    }
    return total;
}
// Descends the centroid tree from separator u (this replaces the original-tree walk of the other
// version). If some in-component neighbour g[u][k] is strictly cheaper than u, continue at the
// centroid G[u][k] of that neighbour's sub-component; otherwise return u's cost.
// This assumes the cost is convex along tree paths, so a cheaper neighbour means the optimum lies
// on its side. To search the whole tree, start at the root of the centroid tree.
ll get(int u) {
    for (;;) {
        const ll here = query(u);
        int next = 0;  // vertex ids start at 1, so 0 means "no cheaper neighbour"
        for (size_t k = 0; k < g[u].size(); ++k) {
            if (query(g[u][k]) < here) { next = G[u][k]; break; }
        }
        if (next == 0) return here;
        u = next;
    }
}
// Reads the tree, builds the LCA structure and the centroid decomposition, then handles each
// line "u v": add v to u's weight and print the minimum weighted distance sum.
signed main() {
    n = read(), Q = read();
    for (int i = 1, x, y, z; i < n; ++i) {
        x = read(), y = read(), z = read();
        add(x, y, z), add(y, x, z);
    }
    a.dfs0(1, 0), a.dfs1(1, 1);
    Max = n, totsize = n, dfs0(1, 0);
    rt0 = rt;
    // get() descends from its argument, so the centroid tree must be rooted where the search starts.
    // The former dfs1(1, 0) rooted it at vertex 1 while the search started at rt0. Then only rt0's
    // component was explored: path 1-2-3 with all weight on vertex 1 printed 5 instead of 0.
    dfs1(rt0, 0);
    while (Q--) {
        int u = read(), v = read();
        change(u, v);
        ll res = get(rt0);
        printf("%lld\n", res);
    }
    return 0;
}
萌新求助啊……