zl程序教程

您现在的位置是:首页 >  其它

当前栏目

【luogu CF645E】Intellectual Inquiry(DP)(结论)(矩阵乘法)

矩阵 DP Luogu 乘法 结论
2023-09-27 14:28:28 时间

Intellectual Inquiry

题目链接:luogu CF645E

题目大意

给你一个序列,值域在 1~k,然后要你在后面再加上 m 个数,也要满足值域,然后使得本质不同的子序列个数最多,输出这个数量。

思路

首先别管那么多,如果 m = 0 m=0 m=0 怎么办。
发现 k k k 只有 100 100 100,而且因为是子序列,所以其实是跟最后一个数有关的。
s i s_i si i i i 这个数结尾的子序列个数。

那每次放进去一个新的数 x x x,那除了 s x s_x sx 别的都不会变,那就看看 s x s_x sx 怎么变。
那可以从所有的上面接,也可以自己新开,但是有一些接会导致重复,数量就是之前的 s x s_x sx,你相当于把 s x s_x sx 里面每种情况的最后的 x x x 删掉,得到的序列加上你这个都是重复的。
所以其实就是 s x = s x + ∑ i = 1 k s i − s x + 1 = ∑ i = 1 k s i + 1 s_x=s_x+\sum\limits_{i=1}^ks_i-s_x+1=\sum\limits_{i=1}^ks_i+1 sx=sx+i=1ksisx+1=i=1ksi+1
我们设 s u m = ∑ i = 1 k s i sum=\sum\limits_{i=1}^ks_i sum=i=1ksi,那就是 s x = s u m + 1 s_x=sum+1 sx=sum+1

发现每次操作之后,如果新加入的是 x x x,那 s x s_x sx 就会变成最大的。
而且无论放什么,它那个位置变成的值都是固定的,所以我们如果要在后面加数,我们肯定是选 s x s_x sx 最小的。

然后因为做的题其实后面加的 m m m 会很大,但也可以搞。

那如果后面要加很多,那我们就不停地选最小的,也就会变成一个循环。

那我们对于一次循环能得到转移矩阵,矩阵快速幂一下,剩下的 < k <k <k 个我们暴力转移即可。

但是有一个小小的问题是我们要如何找到最小的以确定循环转移的顺序,因为有取模是不能直接看。
但是我们考虑到上面的性质,每次操作之后会变成最大的,那我们用 l s t i lst_i lsti 记录 i i i 这个数最晚出现的时间,那我们 l s t lst lst 排序的结果就是我们的顺序。

至于怎么算转移矩阵,你就枚举每个数,分别弄一下, O ( n 3 ) O(n^3) O(n3) 搞好就行。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define mo 1000000007

using namespace std;

const int N = 1e6 + 100;
int n, k, a[N], lst[N], id[N];
int sum, s[101];
ll m;
char tmp[N];

bool cmp(int x, int y) {return lst[x] < lst[y];} 

int add(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int dec(int x, int y) {return x < y ? x - y + mo : x - y;}
int mul(int x, int y) {return 1ll * x * y % mo;}

struct matrix {
	int n, m;
	int a[101][101];
}A, B;

matrix operator *(matrix x, matrix y) {
	matrix re; re.n = x.n; re.m = y.m;
	for (int i = 0; i < re.n; i++)
		for (int j = 0; j < re.m; j++)
			re.a[i][j] = 0;
	for (int k = 0; k < x.m; k++)
		for (int i = 0; i < re.n; i++)
			for (int j = 0; j < re.m; j++)
				re.a[i][j] = add(re.a[i][j], mul(x.a[i][k], y.a[k][j]));
	return re; 
}

matrix ksm(matrix x, ll y) {
	if (!y) {
		matrix re; re.n = k + 1; re.m = k + 1;
		for (int i = 0; i < re.n; i++) for (int j = 0; j < re.m; j++) re.a[i][j] = 0;
		for (int i = 0; i < re.n; i++) re.a[i][i] = 1;
		return re;
	}
	matrix re = x; y--;
	while (y) {
		if (y & 1) re = re * x;
		x = x * x; y >>= 1;
	}
	return re;
} 

ll re; char c;
ll read() {
	re = 0; c = getchar();
	while (c < '0' || c > '9') c = getchar();
	while (c >= '0' && c <= '9') {
		re = (re << 3) + (re << 1) + c - '0';
		c = getchar(); 
	}
	return re;
}

int main() {
	m = read(); k = read();
//	scanf("%d %d %d", &n, &m, &k);
	scanf("%s", tmp + 1); n = strlen(tmp + 1);
	for (int i = 1; i <= n; i++) a[i] = tmp[i] - 'a' + 1;
	
	for (int i = 1; i <= n; i++) {
		ll bef = s[a[i]];
		s[a[i]] = add(sum, 1);
		sum = add(dec(sum, bef), s[a[i]]);
		lst[a[i]] = i;
	}
	
	for (int i = 1; i <= k; i++) id[i] = i;
	sort(id + 1, id + k + 1, cmp);
	
	A.n = 1; A.m = k + 1;
	for (int i = 1; i <= k; i++) A.a[0][i] = s[i]; A.a[0][0] = 1;
	B.n = k + 1; B.m = k + 1;
	for (int i = 0; i <= k; i++) B.a[i][i] = 1;
	for (int i = 1; i <= k; i++) {
		int sum = 1;//sum si
		for (int j = 1; j <= k; j++) {
			int lstsum = sum;
			sum = dec(mul(sum, 2), B.a[i][id[j]]);
			B.a[i][id[j]] = lstsum;
		}
	}
	int sum = 0;//+1
	for (int i = 1; i <= k; i++) {
		int lstsum = sum;
		sum = add(dec(mul(sum, 2), B.a[0][id[i]]), 1);
		B.a[0][id[i]] = add(lstsum, 1);
	}
	
	B = ksm(B, m / k);
	A = A * B;
	sum = 0; for (int i = 1; i <= k; i++) s[i] = A.a[0][i], sum = add(sum, s[i]);
	for (int i = 1; i <= m % k; i++) {
		int now = id[i], bef = s[now];
		s[now] = add(sum, 1);
		sum = add(dec(sum, bef), s[now]);
	}
	printf("%d", add(sum, 1));
	
	return 0;
}