zl程序教程

您现在的位置是:首页 >  其它

当前栏目

【YBT2023寒假Day7 B】打怪兽(cdq分治)(斜率优化)

优化 分治 寒假 斜率 Day7
2023-09-27 14:28:29 时间

打怪兽

题目链接:YBT2023寒假Day7 B

题目大意

有 n 个怪,每个怪有攻击力和血量。
你每次可以选一个怪打 b 的伤害,如果一个怪的血量小于等于 0 就死了。
然后每次你打完之后所有活着的怪会各打你一次。
一开始你可以选两个怪秒杀,问你打死所有怪需要的最小血量。

思路

首先把血量 c c c 改成要打死的次数 ( c + b − 1 ) / b (c+b-1)/b (c+b1)/b(就是 ⌈ c b ⌉ \left\lceil\dfrac{c}{b}\right\rceil bc

首先不考虑秒杀,其实是一个经典的排序问题。
(你直接看谁先就直接假设只有这两个怪,谁先打掉的血少谁就放前面)
(至于原因把式子列一下就看得出来)

接下来考虑看秒杀,考虑秒掉之后减少的费用。
我们假设 i , j ( i < j ) i,j(i<j) i,j(i<j) 为秒掉的两个,那减少的费用是:
a i ( ∑ k = 1 i c k − 1 ) + c i ∑ k = i + 1 n a k + a j ( ∑ k = 1 j c k − c i − 1 ) + c j ∑ k = j + 1 n a k a_i(\sum\limits_{k=1}^ic_k-1)+c_i\sum\limits_{k=i+1}^na_k+a_j(\sum\limits_{k=1}^jc_k-c_i-1)+c_j\sum\limits_{k=j+1}^na_k ai(k=1ick1)+cik=i+1nak+aj(k=1jckci1)+cjk=j+1nak
(第一个和第三个是它打你的次数减少了,第二个和第四个是你打它的时间省了使得后面的怪打你的次数少了)

会发现可以写成 A i + B j − a j c i A_i+B_j-a_jc_i Ai+Bjajci 的形式(其实 A i , B i A_i,B_i Ai,Bi 是同一个数组,都是 a i ( ∑ k = 1 i c i − 1 ) + c i ∑ k = i + 1 n a i a_i(\sum\limits_{k=1}^ic_i-1)+c_i\sum\limits_{k=i+1}^na_i ai(k=1ici1)+cik=i+1nai,前缀后缀预处理一下就行)

那如果我们枚举一个 j j j,它就会变成要求 A i − a j c i A_i-a_jc_i Aiajci 的最大值,那这个东西是线段,可以维护上凸壳 j j j 变大,在凸壳上二分(也可以直接指针找)最大值加上点。
但是一个问题是 a , c a,c a,c 不能同时递增。

于是这种两个变量的考虑通过 cdq 分治减少一个。
大概是给前面一半按 c i c_i ci 排序,后面一半按 a i a_i ai 排序。
然后推 A i 1 − a j c i 1 > A i 2 − a j c i 2 A_{i1}-a_jc_{i1}>A_{i2}-a_jc_{i2} Ai1ajci1>Ai2ajci2 这个式子得到比较方式。(就是斜率优化)

代码

#include<cstdio>
#include<iostream>
#include<algorithm> 
#define ll long long

using namespace std;

const int N = 3e5 + 100;
struct node {
	ll a, d, A;
}a[N];
int n, b, tot;
ll ans, sta[N];

bool cmp(node x, node y) {
	return x.a * (x.d - 1) + y.a * (x.d + y.d - 1) < y.a * (y.d - 1) + x.a * (y.d + x.d - 1);
}

bool cmpd(node x, node y) {return x.d < y.d;}
bool cmpa(node x, node y) {return x.a < y.a;}

void cdq(int l, int r) {
	if (l == r) return ;
	int mid = (l + r) >> 1;
	cdq(l, mid); cdq(mid + 1, r);
	sort(a + l, a + mid + 1, cmpd); sort(a + mid + 1, a + r + 1, cmpa);
	tot = 1; int L = 1; sta[0] = sta[1] = 0;
	for (int i = l; i <= mid; i++) {
		while (L < tot && (a[i].A - a[sta[tot]].A) * (a[sta[tot]].d - a[sta[tot - 1]].d) >= (a[sta[tot]].A - a[sta[tot - 1]].A) * (a[i].d - a[sta[tot]].d)) tot--;
		sta[++tot] = i;
	}
	for (int i = mid + 1; i <= r; i++) {
		while (L < tot && (a[sta[L + 1]].A - a[sta[L]].A) >= a[i].a * (a[sta[L + 1]].d - a[sta[L]].d)) L++;
		ans = max(ans, a[i].A + a[sta[L]].A - a[i].a * a[sta[L]].d);
	}
}

int main() {
	freopen("fittest.in", "r", stdin);
	freopen("fittest.out", "w", stdout);
	
	scanf("%d %d", &n, &b);
	for (int i = 1; i <= n; i++) {
		scanf("%lld %lld", &a[i].a, &a[i].d); a[i].d = (a[i].d + b - 1) / b;
	}
	sort(a + 1, a + n + 1, cmp);
	
	ll all = 0, tmp = 0;
	for (int i = 1; i <= n; i++) {
		tmp += a[i].d;
		all += a[i].a * (tmp - 1);
		a[i].A = a[i].a * (tmp - 1);
	}
	tmp = 0;
	for (int i = n; i >= 1; i--) {
		a[i].A += a[i].d * tmp;
		tmp += a[i].a;
	}
	
	cdq(1, n);
	printf("%lld", all - ans);
	
	return 0;
}