zl程序教程

您现在的位置是:首页 >  其他

当前栏目

【UNR #6 D】小火车(折半搜索)(二分)

搜索 二分 折半 火车
2023-09-27 14:28:30 时间

小火车

题目链接:UNR #6 D

题目大意

给你一个序列,你要构造一个只有 0,1,-1 的序列,使得两个序列每一项乘起来的和为 p 的倍数。
其中保证 p 小于 2^n,n 为序列长度。

思路

首先由一个 3 n 3^n 3n 的暴搜。
考虑 p < 2 n p<2^n p<2n,那我们会发现如果你就看 0 , + 1 0,+1 0,+1,如果搜到两个结果,那我们就可以第一个结果减去第二个结果(那样就会有 0 , 1 , − 1 0,1,-1 0,1,1),那就可以构造出方案。
那如果没有答案,那 2 n 2^n 2n 个全部都要不一样的,但是只有 p p p 个位置,所以是一定有答案的。

然后考虑通过这个来搞,那问题就是如何找到重复的位置 p p p
考虑 s ( l , r ) s(l,r) s(l,r) 2 n 2^n 2n 中有多少个是在 l ∼ r l\sim r lr 之间。
然后因为 s ( l , r ) > r − l + 1 s(l,r)>r-l+1 s(l,r)>rl+1,所以 s ( l , m i d ) > m i d − l + 1 s(l,mid)>mid-l+1 s(l,mid)>midl+1 s ( m i d + 1 , r ) > r − m i d s(mid+1,r)>r-mid s(mid+1,r)>rmid 之间一定有一个成立,不然大的就不成立,那我们就可以通过这个二分出 p p p 的位置,然后就可以找到方案了。

然后考虑怎么求,直接 2 n 2^n 2n 肯定不行,考虑到 n n n 40 40 40,于是可以折半搜索把两边的情况搜出来,然后双指针就可以找到结果了。(其实一个是走的指针,另一个维护的指针是一个区间)
(然后注意是模 p p p,所以你要看 l ∼ r l\sim r lr l + p ∼ r + p l+p\sim r+p l+pr+p 的结果的和)

代码

#include<cstdio>
#include<vector>
#include<algorithm>
#define ll long long

using namespace std;

const int N = 41;
int n, B, ans[N];
ll p, a[N];
vector <pair<ll, int> > x, y;

ll get_num(ll l, ll r) {
	ll re = 0; int jl = 0, jr = -1;
	for (int i = x.size() - 1; i >= 0; i--) {
		while (jr < (int)y.size() - 1 && x[i].first + y[jr + 1].first <= r) jr++;
		while (jl < y.size() && x[i].first + y[jl].first < l) jl++;
		if (jl <= jr) re += jr - jl + 1; 
	}
	for (int i = x.size() - 1; i >= 0; i--) {
		while (jr < (int)y.size() - 1 && x[i].first + y[jr + 1].first <= r + p) jr++;
		while (jl < y.size() && x[i].first + y[jl].first < l + p) jl++;
		if (jl <= jr) re += jr - jl + 1; 
	}
	return re;
}

int main() {
	scanf("%d %lld", &n, &p);
	for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
	
	B = n / 2;
	for (int i = 0; i < 1 << B; i++) {
		ll val = 0; for (int j = 1; j <= B; j++) if ((i >> (j - 1)) & 1) (val += a[j]) %= p;
		x.push_back(make_pair(val, i));
	}
	sort(x.begin(), x.end());
	for (int i = 0; i < 1 << (n - B); i++) {
		ll val = 0; for (int j = 1; j <= (n - B); j++) if ((i >> (j - 1)) & 1) (val += a[B + j]) %= p;
		y.push_back(make_pair(val, i));
	}
	sort(y.begin(), y.end());
	
	ll L = 0, R = p - 1;
	while (L < R) {
		ll mid = (L + R) >> 1;
		if (get_num(L, mid) > mid - L + 1) R = mid;
			else L = mid + 1;
	}
	ll pl = L;
	
	int yes = 0; int j = 0;
	for (int i = x.size() - 1; i >= 0; i--) {
		while (j < y.size() - 1 && x[i].first + y[j].first < pl) j++;
		if (x[i].first + y[j].first != pl) continue;
		if (!yes) {
			yes = 1;
			for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]++;
			for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]++;
		}
		else {
			yes = 2;
			for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]--;
			for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]--;
			break;
		}
		if (j < y.size() - 1 && y[j + 1].first == y[j].first) { j++;
			if (!yes) {
				yes = 1;
				for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]++;
				for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]++;
			}
			else {
				yes = 2;
				for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]--;
				for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]--;
				break;
			}
		}
	}
	if (yes != 2) {
		for (int i = x.size() - 1; i >= 0; i--) {
			while (j < y.size() - 1 && x[i].first + y[j].first < pl + p) j++;
			if (x[i].first + y[j].first != pl + p) continue;
			if (!yes) {
				yes = 1;
				for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]++;
				for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]++;
			}
			else {
				for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]--;
				for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]--;
				break;
			}
			if (j < y.size() - 1 && y[j + 1].first == y[j].first) { j++;
				if (!yes) {
					yes = 1;
					for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]++;
					for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]++;
				}
				else {
					for (int u = 1; u <= B; u++) if ((x[i].second >> (u - 1)) & 1) ans[u]--;
					for (int u = 1; u <= n - B; u++) if ((y[j].second >> (u - 1)) & 1) ans[B + u]--;
					break;
				}
			}
		}
	}
	
	for (int i = 1; i <= n; i++) printf("%d ", ans[i]);
	
	return 0;
}