讨论区中提到的 WA 60 分,原因是 k=3 的矩阵在计算 $f_{i,1}$ 时没有考虑加上 val[u];但是我已经考虑了这一点,仍然只有 60 pts。
思路和第一篇题解一样,但是我的矩阵使用行向量,即 $[f_{i-1,0},\ f_{i-1,1},\ f_{i-1,2}] \times \text{转移矩阵} = [f_{i,0},\ f_{i,1},\ f_{i,2}]$,而不是题解中的写法。
#include<bits/stdc++.h>
// From here on every `int` in this file is a 64-bit `long long`
// (which is why the entry point is spelled `signed main`).
#define int long long
using namespace std;
// N:   size of every per-node array.
// inf: fill value the Matrix methods use for "no transition"; Matrix
//      products are capped at it, and summing two entries stays far below
//      the long long limit.
const int N = 2e5 + 7, inf = 1e17;
// n, Q, k:  the three integers main() reads first.
// val[i]:   per-node value read by main() for i = 1..n.
// num[u]:   filled by dfs(); becomes entry a[1][1] of Matrix::LZX when k == 3.
// st[j][u]: st[0][u] is the DFS parent of u (0 for the root); main() extends
//           it so that st[j][u] is the 2^j-th ancestor (0 when none exists).
// dep[u]:   set by dfs() to dep[parent] + 1.
int n, Q, k, val[N], num[N], st[22][N], dep[N];
// Adjacency lists; main() pushes both directions of every edge it reads.
vector<int> g[N];
// 3x3 matrix used with the (min, +) product defined by operator* in this
// file. `inf` marks "no transition". LCA() multiplies a state held in row 0
// on the left (state * matrix), so a[r][c] is the cost added when going from
// old state r to new state c.
struct Matrix {
    int a[3][3];

    // Fill every entry with inf.
    void clear() {
        for (auto &row : a)
            for (auto &cell : row)
                cell = inf;
    }

    // Make this the (min, +) identity: 0 on the diagonal, inf elsewhere.
    void csh() {
        clear();
        for (int d = 0; d < 3; ++d)
            a[d][d] = 0;
    }

    // Build the transition matrix of node p for the global k (1, 2 or 3).
    // Column 0 always charges val[p]; for k == 3 entry a[1][1] charges num[p].
    // Presumably state j means "j steps away from the current path node" --
    // confirm against the problem statement.
    // For any other k the matrix is left untouched.
    void LZX(int p) {
        switch (k) {
        case 1:
            clear();
            a[0][0] = val[p];
            break;
        case 2:
            clear();
            a[0][0] = a[1][0] = val[p];
            a[0][1] = 0;
            break;
        case 3:
            clear();
            a[0][0] = a[1][0] = a[2][0] = val[p];
            a[0][1] = 0;
            a[1][1] = num[p];
            a[1][2] = 0;
            break;
        default:
            break;
        }
    }

    // Starting state for a query that begins at node p: only a[0][0] is
    // finite and equals val[p].
    void build(int p) {
        clear();
        a[0][0] = val[p];
    }

    // Debug helper: dump the matrix row by row, framed by blank lines.
    void print() {
        printf("\n");
        for (const auto &row : a) {
            for (const auto &cell : row)
                printf("%lld ", cell);
            printf("\n");
        }
        printf("\n");
    }
} base[N], mat[22][N], __mat[22][N];
// (min, +) product of two 3x3 matrices:
//   res[i][j] = min over m of x[i][m] + y[m][j],
// capped at inf because res starts as the all-inf matrix.
//
// Operands are taken by const reference so that the many products done
// during preprocessing and queries do not copy two matrices per call.
// The result is a fresh local, so `r = r * m` is still safe. The inner
// index is named `mid` to avoid shadowing the global `k`.
Matrix operator* (const Matrix &x, const Matrix &y) {
    Matrix res; res.clear();
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            for (int mid = 0; mid < 3; mid++) {
                res.a[i][j] = min(res.a[i][j], x.a[i][mid] + y.a[mid][j]);
            }
        }
    }
    return res;
}
// Depth-first traversal from u, whose parent is already stored in st[0][u]
// (0 for the root). Fills, for u and its whole subtree:
//   dep[u]   - dep[parent] + 1,
//   st[0][v] - parent of every child v,
//   num[u]   - minimum val[] over ALL nodes adjacent to u,
//   base[u]  - transition matrix of u, built by Matrix::LZX once num[u] is final.
//
// num[u] has to cover the parent as well as the children. LCA() applies
// base[] to the topmost node of a query path too, and that node's parent is
// not on the path. Taking the minimum over children only can therefore miss
// the cheapest adjacent node used by the k == 3 entry a[1][1].
// g[u] holds only real neighbours, so the root never picks up val[0].
void dfs(int u) {
    dep[u] = dep[st[0][u]] + 1; num[u] = inf;
    for(int v : g[u]) {
        num[u] = min(num[u], val[v]); // every neighbour counts, parent included
        if(v != st[0][u]) {
            st[0][v] = u; // must be set before recursing: dfs(v) reads it
            dfs(v);
        }
    }
    base[u].LZX(u);
}
int LCA(int u, int v) {
if(u == v) return val[u];
if(dep[u] < dep[v]) swap(u, v); //u为较深的点
Matrix lzx, res; lzx.build(u); res.csh();
int delt = dep[u] - dep[v];
for(int i = 18; i >= 0; i--) {
if(delt & (1 << i)) {
res = res * mat[i][u];
u = st[i][u];
}
}
if(u == v){
res = lzx * res;
return res.a[0][0];
}
int tv = v;
Matrix tmep; tmep.csh();
for(int i = 18; i >= 0; i--) {
if(st[i][u] != st[i][v]) {
res = res * mat[i][u]; tmep = __mat[i][v] * tmep; // ?
u = st[i][u]; v = st[i][v];
}
}
u = st[0][u];
res = lzx * res * base[u] * tmep * base[tv];
return res.a[0][0];
}
signed main() {
//动态DP:矩阵优化树形DP
//先st表预处理出base序列逆序积和顺序积
scanf("%lld%lld%lld", &n, &Q, &k);
for(int i = 1; i <= n; i++)
scanf("%lld ", &val[i]);
for(int u, v, i = 2; i <= n; i ++) {
scanf("%lld%lld", &u, &v);
g[u].push_back(v); g[v].push_back(u);
}
for(int i = 0; i <= n; i++)
base[i].clear();
dfs(1);
// for(int i = 1; i <= n; i++) base[i].print();
for(int i = 1; i <= n; i++) __mat[0][i] = mat[0][i] = base[st[0][i]];
for(int lg = 1; lg <= 18; lg++) {
for(int i = 1; i <= n; i++) {
st[lg][i] = st[lg - 1][st[lg - 1][i]];
mat[lg][i] = mat[lg - 1][i] * mat[lg - 1][st[lg - 1][i]];
__mat[lg][i] = __mat[lg - 1][st[lg - 1][i]] * __mat[lg - 1][i];
}
}
// mat[0][2].print();
while(Q --) {
int x, y;
scanf("%lld%lld", &x, &y);
printf("%lld\n", LCA(x, y));
}
return 0;
}