求助卡常,玄关
查看原帖
求助卡常,玄关
974277
水星湖psgqwq楼主2024/11/29 23:48

卡不了一点,最大点要跑接近三秒/kk

#include <bits/stdc++.h>
using namespace std;
namespace z {
const int BS = 1 << 15; char buf[BS]; int bi = 0, bs = 0;
inline void lb() { bs = fread(buf, 1, BS, stdin); bi = 0; }
inline char nc() { if (bi >= bs) lb(); return bs == 0 ? EOF : buf[bi++]; }
template<typename T> inline T read() { char c; T res = 0, sign = 1; do { c = nc(); } while (!isdigit(c) && c != '-'); if (c == '-') sign = -1, c = nc(); do { res = res * 10 + (c - '0'); c = nc(); } while (isdigit(c)); return sign * res; }
#define rd  read<int>
void out(long long x)
{
    if(x < 10)putchar(x + '0');
    else out(x / 10), putchar(x % 10 + '0');
}
const int N = 1e6 + 5;
int n, q, a[N], fa[N], vis[N], sz[N], rt[N], f[N];
long long mul[N];
vector<int> p[N];
inline int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
void merge(int x, int y) { if(sz[x] > sz[y]) swap(x, y); sz[find(y)] += sz[find(x)], fa[find(x)] = find(y); }
void dfs(int u) {
    for(auto v : p[u]) {
        dfs(v);
        if(a[v] == a[u]) merge(u, v);
    }
}
unordered_map<long long, vector<int>> mp, mp1;
unordered_map<long long, long long> mem;
inline pair<int, long long> dfs(int u, pair<int, int> x, int fa, int t) {
    int res = sz[u]; long long re = 1ll * sz[u] * sz[u];
    vis[u] = t;
    if(x.first == a[u]) {
        long long tmp = 1ll * u * n + x.second;
        for(int v : mp[tmp]) {
            if(v == fa) continue;
            auto [r1, r2] = dfs(v, x, u, t);
            res += r1;
            re += r2;
        }
    } else {
        long long tmp = 1ll * u * n + x.first;
        for(int v : mp[tmp]) {
            if(v == fa) continue;
            auto [r1, r2] = dfs(v, x, u, t);
            res += r1;
            re += r2;
        }
    }
    return {res, re};
}
void file() {
    freopen("sub7-1.in", "r", stdin);
    freopen("out.txt", "w", stdout);
}
void main() {
    #ifndef ONLINE_JUDGE 
        file();
    #endif
    n = rd(), q = rd();
    for(int i = 1; i <= n; i=-~i) fa[i] = i, sz[i] = 1;
    for(int i = 1; i <= n; i=-~i)
        a[i] = rd();
    for(int i = 2; i <= n; i=-~i) {
        int u; u = rd();
        p[u].push_back(i);
    }
    dfs(1); 
    for(int i = 1; i <= n; i=-~i) {
        rt[i] = find(i);
        if(i == rt[i])
            mul[a[i]] += 1ll * sz[i] * sz[i];
    }
    for(int u = 1; u <= n; u=-~u) 
        for(int v : p[u]) 
            if(rt[u] != rt[v]) {
                mp[1ll * rt[u] * n + a[v]].push_back(rt[v]);
                mp[1ll * rt[v] * n + a[u]].push_back(rt[u]);
                int t1 = a[u], t2 = a[v];
                if(t1 > t2) swap(t1, t2);
                mp1[1ll * t1 * n + t2].push_back(rt[u]);
            }        
    memset(vis, 0, sizeof(vis));
    while(q--) {
        long long ans = 0, t, t1, x, y; 
        x = rd(), y = rd();
        if(x > y) swap(x, y);
        if(mem[x * n + y]) {
            out(mem[x * n + y]);
            putchar('\n');
            continue;
        }
        long long tmp = x * n + y;
        if(mp1.count(tmp)) {
            for(int u : mp1[tmp]) 
                if(!vis[u]) tie(t, t1) = dfs(u, {x, y}, 0, 1), ans += t * t - t1;
            for(int u : mp1[tmp]) 
                if(vis[u]) dfs(u, {x, y}, 0, 0);
        }
        out(ans + mul[x] + mul[y]); putchar('\n');
        mem[x * n + y] = ans + mul[x] + mul[y];
    }
}
}

signed main()
{
    z::main();
    return 0;
}
2024/11/29 23:48
加载中...