又是一道妙题,加深了蒟蒻对 LCT\text{LCT} 的理解。

题意

给定一棵 nn 个节点的有根树,根节点为 11。最开始每个节点都有颜色,且颜色互不相同。

定义一条路径的权值为:路径上点的不同颜色数。

现在一共会有 mm 组询问,每组询问有三种:

  • 1 xxx 到根节点 11 上的所有点都染色为以前从未出现过的颜色。
  • 2 x y 询问 xyx \to y 路径的权值。
  • 3 x 询问 xx 的子树内的结点中,到根节点的路径的最大权值。

题解

思路

考虑如何维护 11 操作,容易想到:任一时刻,对于每种颜色,拥有该颜色的点在树上的联通块有且仅有一个,且一定是直上直下的链。

这个性质的存在,使得此题可以用 LCT\text{LCT} 解决。一个经典的处理方法是:

  • 对于 LCT\text{LCT} 上的一条边,若两端点的颜色相同,则该边为实边。
  • 若两端点颜色不同,该边为虚边。

容易证明这种赋实边、虚边的方法是符合 LCT\text{LCT} 的性质的。

所以操作 11 就可以转换成 LCT\text{LCT}access 操作,即:将 x1x \to 1 路径上的结点实边断开,再将路径上的边改为实边。

对于操作 22,记录一个 disxdis_x 表示 x1x \to 1 的路径上经历的虚边数量,即该路径的权值。操作 22 的答案就是 disx+disy2×disLCA+1dis_{x} + dis_{y} - 2 \times dis_{LCA} + 1

而操作 33 就是在查询 xx 子树内的 disdis 最大值,可以用线段树配合 DFS 序解决。

代码实现

这里要详细讲一下 LCT\text{LCT} 里的 access

回想一下普通 LCT\text{LCT} 里的 access


普通 LCT 里的 access:

1
2
3
4
void access(int x) {
for (int y = 0; x; y = x, x = fa[x])
splay(x), son[x][1] = y, pushup(x);
}

每次 for 循环内的流程是:1. 将 xx 旋转到当前 splay 的根。2. 将 xx 原本实边断开。3. 将 xx 现在的实边连到 yy。4. xx 的实儿子改变,故 pushup

那么对于这道题,2、3 两个步骤会对 disdis 产生影响,具体就是(有点绕):

  • 原树xx 的,子树中包含 sonx,1son_{x, 1},的儿子子树 disdis 全部加上 11
  • 原树xx 的,子树中包含 yy,的儿子子树 disdis 全部减去 11

所有为了找到那个儿子,又由于 splay 的性质:splay 中序遍历得到的序列,深度递增。所以只需找到 LCT\bold{\text{LCT}}sonx,1/yson_{x, 1} / y 子树最左的儿子就行了。


findroot 代码

1
2
3
4
int findroot(int x) {
while (son[x][0]) x = son[x][0];
return x;
}

access 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void access(int x) {
for (int y = 0; x; y = x, x = fa[x]) {
splay(x); int rt;
if (son[x][1]) {
rt = findroot(son[x][1]);
t.update(1, dfn[rt], dfn[rt] + sz[rt] - 1, 1);
}
if (y) {
rt = findroot(y);
t.update(1, dfn[rt], dfn[rt] + sz[rt] - 1, -1);
}
son[x][1] = y;
}
}

很不幸,上面两个代码是错的,原因是:findroot 时未 splay,这导致复杂度变为 O(n2)O(n^2)
但是如果在 findrootsplay,会使 access 中的 xx 无法成为当前 splay 的根节点,最终死循环。

解决办法也不难,只需将 findroot 中用到的所有 xx 存放在一个 vector 中,access 完毕后将 vector 中的所有结点都 splay 一遍即可。


findroot 和 access 的正确代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
vector <int> tosplay;
int findroot(int x) {
while (son[x][0]) x = son[x][0];
tosplay.emplace_back(x);
return x;
}
void access(int x) {
tosplay.clear();
for (int y = 0; x; y = x, x = fa[x]) {
splay(x); int rt;
if (son[x][1]) {
rt = findroot(son[x][1]);
t.update(1, dfn[rt], dfn[rt] + sz[rt] - 1, 1);
}
if (y) {
rt = findroot(y);
t.update(1, dfn[rt], dfn[rt] + sz[rt] - 1, -1);
}
son[x][1] = y;
}
for (auto p : tosplay) splay(p);
}

本题的所有代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
const int N = 1E5 + 5, M = 20;
int n, m, dep[N], wt[N], dfn[N], tot, sz[N];
vector <int> G[N];
struct segt {
struct node {
int l, r;
int tag, mx;
} t[N << 2];
int lson(int x) {return x << 1;}
int rson(int x) {return x << 1 | 1;}
void pushup(int x) {t[x].mx = max(t[lson(x)].mx, t[rson(x)].mx);}
void build(int x, int l, int r) {
t[x].l = l; t[x].r = r; t[x].tag = 0;
if (l == r) {
t[x].mx = wt[l];
return ;
} int mid = (l + r) >> 1;
build(lson(x), l, mid); build(rson(x), mid + 1, r);
pushup(x);
}
void upd(int x, int val) {t[x].mx += val; t[x].tag += val;}
void pushdown(int x) {
upd(lson(x), t[x].tag); upd(rson(x), t[x].tag);
t[x].tag = 0;
}
void update(int x, int L, int R, int val) {
if (t[x].l >= L && t[x].r <= R) return upd(x, val);
int mid = (t[x].l + t[x].r) >> 1; pushdown(x);
if (L <= mid) update(lson(x), L, R, val);
if (R > mid) update(rson(x), L, R, val);
pushup(x);
}
int query(int x, int L, int R) {
if (t[x].l >= L && t[x].r <= R) return t[x].mx;
int mid = (t[x].l + t[x].r) >> 1, res = 0; pushdown(x);
if (L <= mid) res = max(res, query(lson(x), L, R));
if (R > mid) res = max(res, query(rson(x), L, R));
return res;
}
int query(int x) {return query(1, x, x);}
} t;
struct lct {
int son[N][2], fa[N];
bool checkroot(int x) {return son[fa[x]][0] != x && son[fa[x]][1] != x;}
bool checkson(int x) {return son[fa[x]][1] == x;}
void rotate(int x) {
int y = fa[x], z = fa[y], chx = checkson(x), chy = checkson(y);
if (!checkroot(y)) son[z][chy] = x; fa[x] = z;
son[y][chx] = son[x][!chx]; fa[son[x][!chx]] = y;
son[x][!chx] = y; fa[y] = x;
}
void splay(int x) {
while (!checkroot(x)) {
if (!checkroot(fa[x])) rotate(checkson(x) != checkson(fa[x]) ? x : fa[x]);
rotate(x);
}
}
vector <int> tosplay;
int findroot(int x) {
while (son[x][0]) x = son[x][0];
tosplay.emplace_back(x);
return x;
}
void access(int x) {
tosplay.clear();
for (int y = 0; x; y = x, x = fa[x]) {
splay(x); int rt;
if (son[x][1]) {
rt = findroot(son[x][1]);
t.update(1, dfn[rt], dfn[rt] + sz[rt] - 1, 1);
}
if (y) {
rt = findroot(y);
t.update(1, dfn[rt], dfn[rt] + sz[rt] - 1, -1);
}
son[x][1] = y;
}
for (auto p : tosplay) splay(p);
}
} f;
struct bz {
int yf[N][M + 1];
void dfs(int x, int fa) {
wt[dfn[x] = ++tot] = dep[x]; sz[x] = 1;
for (auto v : G[x]) {
if (v == fa) continue;
dep[v] = dep[x] + 1; yf[v][0] = x; f.fa[v] = x;
for (int i = 1; i <= M; ++i)
yf[v][i] = yf[yf[v][i - 1]][i - 1];
dfs(v, x);
sz[x] += sz[v];
}
}
int getlca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
for (int i = M; ~i; --i)
if (dep[u] - (1 << i) >= dep[v]) u = yf[u][i];
if (u == v) return u;
for (int i = M; ~i; --i)
if (yf[u][i] != yf[v][i])
u = yf[u][i], v = yf[v][i];
return yf[u][0];
}
void solve() {
for (int i = 0; i <= M; ++i) yf[1][i] = 1;
dep[1] = 1; dfs(1, 0); t.build(1, 1, n);
}
} b;
signed main(void) {
ios :: sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n >> m;
for (int i = 1; i < n; ++i) {
int u, v; cin >> u >> v;
G[u].emplace_back(v);
G[v].emplace_back(u);
} b.solve();
for (int i = 1; i <= m; ++i) {
int opt, x; cin >> opt >> x;
if (opt == 1) f.access(x);
else if (opt == 2) {
int y; cin >> y;
int lca = b.getlca(x, y);
cout << t.query(dfn[x]) + t.query(dfn[y]) - 2 * t.query(dfn[lca]) + 1 << '\n';
} else cout << t.query(1, dfn[x], dfn[x] + sz[x] - 1) << '\n';
for (int i = 1; i <= n; ++i) cout << t.query(dfn[i]) << ' ';
cout << '\n';
}
return 0;
}