zl程序教程

您现在的位置是:首页 > 

当前栏目

树上的边题解

题解 树上
2023-06-13 09:11:02 时间

题面

给一棵 n 个点的树,定义 f(l,r) 为:

\forall i \in [l,r]j \in [l,r],都存在 i \to j 的路径时,需要选择的最少树边数量。

现在你要求出:

\sum \limits _{l=1}^n \sum \limits _{r=l}^n f(l,r)

题解

枚举 l 或者 r 是没有前途的,时间复杂度的下界是 O(n^2),并且无法在左端点移动时更新整体的答案,没有进一步优化的空间。

因此换一个思路,考虑统计每条边的贡献。

一条 u \to v 的边,将原树分成两棵子树。该边会对答案产生贡献,当且仅当:

\exists i \in [l,r],j \in [l,r],满足 i \in \operatorname{subtree}(u),j \in \operatorname{subtree}(v).

对于一个编号序列 1,2,3,\ldots,n,每个编号只有两种状态:i \in \operatorname{subtree}(u)i \in \operatorname{subtree}(v).

选择的一个连续编号区间 [l,r],钦定的边会产生贡献,当且仅当该区间内同时包含 i \in \operatorname{subtree}(u)j \in \operatorname{subtree}(v).

那么有一个自然的想法:维护每一段连续的相同状态的编号,即维护若干个线段 [L,R] 满足 \forall i \in [L,R], i \in \operatorname{subtree}(u \text{ or } v),那么对于每一条线段,只要 l < L,r \ge Rl \le L,r > R

然而不难发现,这样会算重很多贡献,容斥不太可做的样子。

正难则反,统计不合法的情况数,然后用总情况数减去不合法情况统计合法情况数。

一个连续编号区间 [l,r] 不合法,当且仅当 \forall i \in [l,r]i \in \operatorname{subtree}(u \text{ or } v),可以发现当且仅当 [l,r] 是求出的线段 [L,R] 的连续子段,它就是不合法的。

那么对于每一个连续相同状态编号区间,可以计算出它对应的不合法情况数:\dbinom{R - L + 1}{2}.

而对于一条边,它不能产生贡献的总情况数是:\sum \dbinom{R - L + 1}{2},总情况数是 \dbinom{n}{2},能产生的贡献就是 \dbinom{n}{2} - \sum \dbinom{R - L + 1}{2}.

现在只剩下一个问题:如何维护连续相同状态编号区间及其对应贡献?

这个做法多多,但核心大同小异:加入一个点最多改变一条线段的贡献。

例如有一种 dsu on tree 的做法,用 set 来暴力维护线段,当加入一个点时,找到对应线段更新贡献,时间复杂度 O(n \log^2 n).

当然可以使用线段树来维护线段,这就是一个维护连续子段的变体。

假如当前搜索到 u 点,定义在 u 点子树内状态为 1,在 u 点子树外状态为 0,以编号为下标,每个子树就有一个唯一对应的 01 串。

在这个 01 串上,用线段树维护连续子段产生的贡献。唯一的难点就是上推操作。

对于左半线段与右半线段的上推合并,仅有左半线段的右端点与右半线段的左端点拼接可能产生的贡献。因此记录左右端点开始的最长连续子段,记录左右端点的值,然后更新贡献即可。

可以使用线段树合并,每个点仅会产生一次插入操作,时间复杂度 O(n \log n),使用优先重儿子搜索、垃圾回收的空间优化方法,可以比较轻松地通过。

代码

不算难码,细节也不多。为了实现方便,牺牲了一定的常数。

#include <cstdio>
#include <algorithm>
#include <ctype.h>
const int bufSize = 1e6;
inline char nc()
{
    #ifdef DEBUG
    return getchar();
    #endif
    static char buf[bufSize], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, bufSize, stdin), p1 == p2) ? EOF : *p1++;
}
template<typename T>
inline T read(T &r)
{
    static char c;
    r = 0;
    for (c = nc(); !isdigit(c); c = nc());
    for (; isdigit(c); c = nc()) r = r * 10 + c - 48;
    return r;
}
const int maxn = 1e5 +100;
struct node
{
    int to, next;
} E[maxn << 1];
int head[maxn], tot;
inline void add(const int& x, const int& y) { E[++tot].next = head[x], E[tot].to = y, head[x] = tot; }
struct segnode
{
    int llen, rlen, lval, rval, ls, rs;
    long long val;
    bool isEmpty() const { return llen == 0; }
} A[maxn << 2];
int n, root[maxn],s[maxn << 2], top, ind, size[maxn], son[maxn];
inline int newnode() { return top > 0 ? s[top--] : ++ind; }
inline void delnode(int x) { if (x) A[x] = segnode(), s[++top] = x; }
inline long long getx(int x) { return 1ll * x * (x - 1) >> 1; }
inline segnode makezero(int l, int r)
{
    segnode res;
    res.llen = res.rlen = (r - l + 1), res.lval = res.rval = 0;
    res.val = getx(res.llen);
    return res;
}
inline void pushup(int l, int r, segnode& res, segnode ls, segnode rs)
{
    int mid = (l + r) >> 1;
    if (ls.isEmpty()) ls = makezero(l, mid);
    if (rs.isEmpty()) rs = makezero(mid + 1, r);
    res.llen = ls.llen, res.rlen = rs.rlen, res.lval = ls.lval, res.rval = rs.rval;
    res.val = ls.val + rs.val;
    if (ls.rval == rs.lval)
    {
        res.val -= getx(ls.rlen) + getx(rs.llen);
        res.val += getx(ls.rlen + rs.llen);
        if (ls.llen == (mid - l + 1)) res.llen += rs.llen;
        if (rs.rlen == (r - mid)) res.rlen += ls.rlen;
    }
}
void modify(int l, int r, int& p, int pos, int k)
{
    if(!p) p = newnode();
    if (l == r) return (void)(A[p].llen = A[p].rlen = 1, A[p].lval = A[p].rval = k, A[p].val = 0);
    int mid = (l + r) >> 1;
    if (pos <= mid) modify(l, mid, A[p].ls, pos, k);
    else modify(mid + 1, r, A[p].rs, pos, k);
    pushup(l, r, A[p], A[A[p].ls], A[A[p].rs]);
}
void merge(int l, int r, int& x, int y)
{
    if (!x || !y) return (void)(x += y);
    int mid = (l + r) >> 1;
    merge(l, mid, A[x].ls, A[y].ls), merge(mid + 1, r, A[x].rs, A[y].rs), delnode(y);
    pushup(l, r, A[x], A[A[x].ls], A[A[x].rs]);
}
long long res;
void dfs1(int u, int fa)
{
    size[u] = 1;
    for (int p = head[u]; p; p = E[p].next)
    {
        int v = E[p].to;
        if (v == fa) continue;
        dfs1(v, u), size[u] += size[v];
        if (size[v] > size[son[u]]) son[u] = v;
    }
}
void dfs2(int u, int fa)
{
    if (son[u]) dfs2(son[u], u), merge(1, n, root[u], root[son[u]]);
    for (int p = head[u]; p; p = E[p].next)
    {
        int v = E[p].to;
        if (v == fa || v == son[u]) continue;
        dfs2(v, u), merge(1, n, root[u], root[v]);
    }
    modify(1, n, root[u], u, 1);
    if (fa) res += getx(n) - A[root[u]].val;
}
int main()
{
    read(n);
    for (int i = 1, a, b; i < n; ++i) read(a), read(b), add(a, b), add(b, a);
    dfs1(1, 0), dfs2(1, 0);
    printf("%lld\n", res);
    return 0;
}