当前栏目
树上的边题解
题面
给一棵 n 个点的树,定义 f(l,r) 为:
\forall i \in [l,r],j \in [l,r],都存在 i \to j 的路径时,需要选择的最少树边数量。
现在你要求出:
题解
枚举 l 或者 r 是没有前途的,时间复杂度的下界是 O(n^2),并且无法在左端点移动时更新整体的答案,没有进一步优化的空间。
因此换一个思路,考虑统计每条边的贡献。
一条 u \to v 的边,将原树分成两棵子树。该边会对答案产生贡献,当且仅当:
\exists i \in [l,r],j \in [l,r],满足 i \in \operatorname{subtree}(u),j \in \operatorname{subtree}(v).
对于一个编号序列 1,2,3,\ldots,n,每个编号只有两种状态:i \in \operatorname{subtree}(u) 或 i \in \operatorname{subtree}(v).
选择的一个连续编号区间 [l,r],钦定的边会产生贡献,当且仅当该区间内同时包含 i \in \operatorname{subtree}(u) 与 j \in \operatorname{subtree}(v).
那么有一个自然的想法:维护每一段连续的相同状态的编号,即维护若干个线段 [L,R] 满足 \forall i \in [L,R], i \in \operatorname{subtree}(u \text{ or } v),那么对于每一条线段,只要 l < L,r \ge Rl \le L,r > R
然而不难发现,这样会算重很多贡献,容斥不太可做的样子。
正难则反,统计不合法的情况数,然后用总情况数减去不合法情况统计合法情况数。
一个连续编号区间 [l,r] 不合法,当且仅当 \forall i \in [l,r],i \in \operatorname{subtree}(u \text{ or } v),可以发现当且仅当 [l,r] 是求出的线段 [L,R] 的连续子段,它就是不合法的。
那么对于每一个连续相同状态编号区间,可以计算出它对应的不合法情况数:\dbinom{R - L + 1}{2}.
而对于一条边,它不能产生贡献的总情况数是:\sum \dbinom{R - L + 1}{2},总情况数是 \dbinom{n}{2},能产生的贡献就是 \dbinom{n}{2} - \sum \dbinom{R - L + 1}{2}.
现在只剩下一个问题:如何维护连续相同状态编号区间及其对应贡献?
这个做法多多,但核心大同小异:加入一个点最多改变一条线段的贡献。
例如有一种 dsu on tree 的做法,用 set 来暴力维护线段,当加入一个点时,找到对应线段更新贡献,时间复杂度 O(n \log^2 n).
当然可以使用线段树来维护线段,这就是一个维护连续子段的变体。
假如当前搜索到 u 点,定义在 u 点子树内状态为 1,在 u 点子树外状态为 0,以编号为下标,每个子树就有一个唯一对应的 01 串。
在这个 01 串上,用线段树维护连续子段产生的贡献。唯一的难点就是上推操作。
对于左半线段与右半线段的上推合并,仅有左半线段的右端点与右半线段的左端点拼接可能产生的贡献。因此记录左右端点开始的最长连续子段,记录左右端点的值,然后更新贡献即可。
可以使用线段树合并,每个点仅会产生一次插入操作,时间复杂度 O(n \log n),使用优先重儿子搜索、垃圾回收的空间优化方法,可以比较轻松地通过。
代码
不算难码,细节也不多。为了实现方便,牺牲了一定的常数。
#include <cstdio>
#include <algorithm>
#include <ctype.h>
const int bufSize = 1e6;
inline char nc()
{
#ifdef DEBUG
return getchar();
#endif
static char buf[bufSize], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, bufSize, stdin), p1 == p2) ? EOF : *p1++;
}
template<typename T>
inline T read(T &r)
{
static char c;
r = 0;
for (c = nc(); !isdigit(c); c = nc());
for (; isdigit(c); c = nc()) r = r * 10 + c - 48;
return r;
}
const int maxn = 1e5 +100;
struct node
{
int to, next;
} E[maxn << 1];
int head[maxn], tot;
inline void add(const int& x, const int& y) { E[++tot].next = head[x], E[tot].to = y, head[x] = tot; }
struct segnode
{
int llen, rlen, lval, rval, ls, rs;
long long val;
bool isEmpty() const { return llen == 0; }
} A[maxn << 2];
int n, root[maxn],s[maxn << 2], top, ind, size[maxn], son[maxn];
inline int newnode() { return top > 0 ? s[top--] : ++ind; }
inline void delnode(int x) { if (x) A[x] = segnode(), s[++top] = x; }
inline long long getx(int x) { return 1ll * x * (x - 1) >> 1; }
inline segnode makezero(int l, int r)
{
segnode res;
res.llen = res.rlen = (r - l + 1), res.lval = res.rval = 0;
res.val = getx(res.llen);
return res;
}
inline void pushup(int l, int r, segnode& res, segnode ls, segnode rs)
{
int mid = (l + r) >> 1;
if (ls.isEmpty()) ls = makezero(l, mid);
if (rs.isEmpty()) rs = makezero(mid + 1, r);
res.llen = ls.llen, res.rlen = rs.rlen, res.lval = ls.lval, res.rval = rs.rval;
res.val = ls.val + rs.val;
if (ls.rval == rs.lval)
{
res.val -= getx(ls.rlen) + getx(rs.llen);
res.val += getx(ls.rlen + rs.llen);
if (ls.llen == (mid - l + 1)) res.llen += rs.llen;
if (rs.rlen == (r - mid)) res.rlen += ls.rlen;
}
}
void modify(int l, int r, int& p, int pos, int k)
{
if(!p) p = newnode();
if (l == r) return (void)(A[p].llen = A[p].rlen = 1, A[p].lval = A[p].rval = k, A[p].val = 0);
int mid = (l + r) >> 1;
if (pos <= mid) modify(l, mid, A[p].ls, pos, k);
else modify(mid + 1, r, A[p].rs, pos, k);
pushup(l, r, A[p], A[A[p].ls], A[A[p].rs]);
}
void merge(int l, int r, int& x, int y)
{
if (!x || !y) return (void)(x += y);
int mid = (l + r) >> 1;
merge(l, mid, A[x].ls, A[y].ls), merge(mid + 1, r, A[x].rs, A[y].rs), delnode(y);
pushup(l, r, A[x], A[A[x].ls], A[A[x].rs]);
}
long long res;
void dfs1(int u, int fa)
{
size[u] = 1;
for (int p = head[u]; p; p = E[p].next)
{
int v = E[p].to;
if (v == fa) continue;
dfs1(v, u), size[u] += size[v];
if (size[v] > size[son[u]]) son[u] = v;
}
}
void dfs2(int u, int fa)
{
if (son[u]) dfs2(son[u], u), merge(1, n, root[u], root[son[u]]);
for (int p = head[u]; p; p = E[p].next)
{
int v = E[p].to;
if (v == fa || v == son[u]) continue;
dfs2(v, u), merge(1, n, root[u], root[v]);
}
modify(1, n, root[u], u, 1);
if (fa) res += getx(n) - A[root[u]].val;
}
int main()
{
read(n);
for (int i = 1, a, b; i < n; ++i) read(a), read(b), add(a, b), add(b, a);
dfs1(1, 0), dfs2(1, 0);
printf("%lld\n", res);
return 0;
}