zl程序教程

您现在的位置是:首页 >  其它

当前栏目

【bzoj4182】Shopping 树的点分治+dfs序+背包dp

DP DFS 背包 分治
2023-09-11 14:22:39 时间

题目描述

给出一棵 $n$ 个点的树,每个点有物品重量 $w$ 、体积 $c$ 和数目 $d$ 。要求选出一个连通子图,使得总体积不超过背包容量 $m$ ,且总重量最大。求这个最大总重量。

输入

输入第一行一个正整数T,表示测试数据组数。

对于每组数据,
第一行两个正整数n;m;
第二行n个非负整数w1,w2...wn;
第三行n个正整数c1,c2...cn;
第四行n个正整数d1,d2...dn;
接下来n-1行每行两个正整数u;v表示u和v之间有一条道路
$n\le 500,m\le 4000$ 

输出

输出共T 行,每行一个整数,表示最大的喜爱度之和。

样例输入

1
3 2
1 2 3
1 1 1
1 2 1
1 2
1 3

样例输出

4


题解

树的点分治+dfs序+背包dp

终于get到了树形背包dp的正确姿势 = =

如果要求必须选 $x$ ,即做以 $x$ 为根的树形背包dp。

那么对于一个点,有两种情况:选和不选。

选的话即可选子节点,不选的话就不能选子树内的点。

因此使用dfs序进行dp。以 $x$ 为根进行dfs。设 $f[i][j]$ 表示使用dfs序上 $[i,n]$ 位置对应的节点,背包容量为 $j$ 时的最大重量。

对于位置 $i$ ,如果选,则从 $f[i+1][]$ 转移过来;否则子树内节点都不能选,从 $f[last[val[i]]+1][]$ 转移过来 。其中 $val[i]$ 表示dfs序上位置 $i$ 对应的节点编号,$last[i]$ 表示 $i$ 子树在dfs序上的区间右端点位置。

这样从后向前进行多重背包dp,最终的 $f[1][m]$ 即为以 $x$ 为根的树形背包dp的答案。

但是如果想本题这样,求任意一个连通块的结果呢?使用点分治,求出包含重心的答案,递归不包含重心的答案即可。

时间复杂度 $O(nm\log d\log n)$ ,如果使用单调队列优化多重背包的话即可使时间复杂度去掉一个log。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 510
using namespace std;
int m , w[N] , c[N] , d[N] , head[N] , to[N << 1] , next[N << 1] , cnt , si[N] , ms[N] , sum , root , vis[N] , val[N] , last[N] , tot , f[N][4010] , ans;
inline void add(int x , int y)
{
    to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void getroot(int x , int fa)
{
    int i;
    si[x] = 1 , ms[x] = 0;
    for(i = head[x] ; i ; i = next[i])
        if(!vis[to[i]] && to[i] != fa)
            getroot(to[i] , x) , si[x] += si[to[i]] , ms[x] = max(ms[x] , si[to[i]]);
    ms[x] = max(ms[x] , sum - si[x]);
    if(ms[x] < ms[root]) root = x;
}
void dfs(int x , int fa)
{
    int i;
    si[x] = 1 , val[++tot] = x;
    for(i = head[x] ; i ; i = next[i])
        if(!vis[to[i]] && to[i] != fa)
            dfs(to[i] , x) , si[x] += si[to[i]];
    last[x] = tot;
}
void solve(int x)
{
    int i , j , k , t;
    vis[x] = 1 , tot = 0 , dfs(x , 0);
    for(i = 1 ; i <= tot + 1 ; i ++ )
        for(j = 0 ; j <= m ; j ++ )
            f[i][j] = 0;
    for(i = tot ; i ; i -- )
    {
        t = d[val[i]] - 1;
        for(j = m ; j >= c[val[i]] ; j -- ) f[i][j] = f[i + 1][j - c[val[i]]] + w[val[i]];
        for(j = 1 ; j <= t ; t -= j , j <<= 1)
            for(k = m ; k >= j * c[val[i]] ; k -- )
                f[i][k] = max(f[i][k] , f[i][k - j * c[val[i]]] + j * w[val[i]]);
        if(t)
            for(j = m ; j >= t * c[val[i]] ; j -- )
                f[i][j] = max(f[i][j] , f[i][j - t * c[val[i]]] + t * w[val[i]]);
        for(j = m ; ~j ; j -- ) f[i][j] = max(f[i][j] , f[last[val[i]] + 1][j]);
    }
    ans = max(ans , f[1][m]);
    for(i = head[x] ; i ; i = next[i])
        if(!vis[to[i]])
            sum = si[to[i]] , root = 0 , getroot(to[i] , 0) , solve(root);
}
int main()
{
    int T;
    scanf("%d" , &T);
    while(T -- )
    {
        memset(head , 0 , sizeof(head)) , cnt = 0;
        memset(vis , 0 , sizeof(vis)) , ans = 0;
        int n , i , x , y;
        scanf("%d%d" , &n , &m);
        for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &w[i]);
        for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &c[i]);
        for(i = 1 ; i <= n ; i ++ ) scanf("%d" , &d[i]);
        for(i = 1 ; i < n ; i ++ ) scanf("%d%d" , &x , &y) , add(x , y) , add(y , x);
        sum = n , ms[0] = 1 << 30 , root = 0 , getroot(1 , 0) , solve(root);
        printf("%d\n" , ans);
    }
    return 0;
}