zl程序教程

您现在的位置是:首页 >  其它

当前栏目

【luogu P4229】某位歌姬的故事(离散化)(DP)

DP 故事 Luogu 离散
2023-09-27 14:28:28 时间

某位歌姬的故事

题目链接:luogu P4229

题目大意

有一个数组,告诉你长度和值域,然后有若干个限制条件是某个区间的最大值是某个数。
然后问你有多少个满足条件的数组。

思路

首先发现小的只有限制个数,那我们考虑用限制的区间来离散化。
然后因为这里是区间,所以离散化最好弄成左闭右开的。

然后发现每个位置它可以选的时候是有上限的,然后对于要求上限不同的区间,它们之间必然没有联系,所以我们可以对每种上限的要求分开考虑。

然后剩下的其实就简单了,你就直接 DP,设 f i , j f_{i,j} fi,j 为当前到位置 i i i,上一次取最大值是在位置 j j j
然后根据限制条件转移即可。
(取最大值是 w l e n − ( w − 1 ) l e n w^{len}-(w-1)^{len} wlen(w1)len,不取就是 ( w − 1 ) l e n (w-1)^{len} (w1)len

具体实现可以看看代码,主要是实现。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#define mo 998244353
#define ll long long
#define INF 0x3f3f3f3f3f3f3f3f

using namespace std;

struct node {
	int l, r, m, bl, br;
}q[501];
int n, Q, A, sz[1050], a[1050], w[1050], ws[1050];
int pl[1050], lim[1050], tmp, aw[1050];
ll ans;

ll ksm(ll x, ll y) {
	ll re = 1;
	while (y) {
		if (y & 1) re = re * x % mo;
		x = x * x % mo; y >>= 1;
	}
	return re;
}

void Init() {
	a[0] = 0; ans = 1; aw[0] = 0;
	for (int i = 1; i < 1050; i++) w[i] = INF;
}

bool check_no_ans() {
	for (int i = 1; i <= Q; i++)
		for (int j = 1; j <= Q; j++)
			if (q[i].m > q[j].m && q[j].bl <= q[i].bl && q[i].br <= q[j].br) return 1;
	return 0;
}

ll f[1050][1050];

ll clac(int W) {
	int nn = 0;
	for (int i = 1; i < a[0]; i++) if (w[i] == W) pl[++nn] = i;
	for (int i = 1; i <= Q; i++) if (q[i].m == W) {
		int l1 = pl[lower_bound(pl + 1, pl + nn + 1, q[i].bl) - pl];
		int r1 = pl[lower_bound(pl + 1, pl + nn + 1, q[i].br + 1) - pl - 1];
		lim[r1] = max(lim[r1], l1);
	}
	for (int i = 1; i < a[0]; i++) lim[i] = max(lim[i - 1], lim[i]);
	f[0][0] = 1;
	for (int x = 1; x <= nn; x++) { int i = pl[x];
		for (int y = 0; y < i; y++) { int j = pl[y];
			f[x][i] = (f[x][i] + f[x - 1][y] * (ksm(W, sz[i]) - ksm(W - 1, sz[i]) + mo) % mo) % mo;
			if (lim[i] <= y) f[x][y] = f[x - 1][y] * ksm(W - 1, sz[i]) % mo;
		}
	}
	ll re = 0; for (int i = 1; i <= pl[nn]; i++) re = (re + f[nn][i]) % mo;
	for (int i = 0; i < a[0]; i++) lim[i] = 0;
	for (int i = 0; i <= nn; i++)
		for (int j = 0; j <= pl[nn]; j++) f[i][j] = 0;
	return re;
}

int main() {
	int T; scanf("%d", &T); tmp = INF;
	while (T--) {
		scanf("%d %d %d", &n, &Q, &A);
		Init();
		for (int i = 1; i <= Q; i++) {
			scanf("%d %d %d", &q[i].l, &q[i].r, &q[i].m);
			a[++a[0]] = q[i].l; a[++a[0]] = q[i].r + 1; aw[++aw[0]] = q[i].m;
		}
		a[++a[0]] = 1; a[++a[0]] = n + 1;
		sort(a + 1, a + a[0] + 1); a[0] = unique(a + 1, a + a[0] + 1) - a - 1;
		for (int i = 1; i <= Q; i++) {
			q[i].bl = lower_bound(a + 1, a + a[0] + 1, q[i].l) - a;
			q[i].br = lower_bound(a + 1, a + a[0] + 1, q[i].r + 1) - a - 1;
			for (int j = q[i].bl; j <= q[i].br; j++) w[j] = min(w[j], q[i].m);
		}
		
		if (check_no_ans()) {printf("0\n"); continue;}
		
		sort(aw + 1, aw + aw[0] + 1); aw[0] = unique(aw + 1, aw + aw[0] + 1) - aw - 1;
		for (int i = 1; i < a[0]; i++) sz[i] = a[i + 1] - a[i];
		memcpy(ws, w, sizeof(ws));
		sort(ws + 1, ws + a[0]); ws[0] = unique(ws + 1, ws + a[0]) - ws - 1;
		for (int i = 1; i <= aw[0]; i++) {
			ans = ans * clac(aw[i]) % mo;
		}
		for (int i = 1; i < a[0]; i++)
			if (w[i] == tmp) ans = ans * ksm(A, sz[i]) % mo;
		printf("%lld\n", ans);
	}
	
	return 0;
}