斜率优化dp专题小结
优化 小结 专题 DP
2023-06-13 09:15:04 时间
斜率优化dp是一种通过构造斜率表达式,用维护凸包的方法来去除多余的点以减少算法复杂度的方法。通常可以将问题规模减小一个维度,从而提高运行效率。这个算法的关键是将dp的状态转移方程进行转换,比如对于如下状态转移方程:
如果直接dp那么复杂度将会是(O(n_2)),某些情况下就会显得效率不够。这时候就可以用斜率dp进行优化,将其优化到
O(n)。
首先我们需要将状态转移方程进行变形,在计算dp[i]的时候,对于任何x和y,如果x比y更优,那么也就是说:
令:
现在从左到右,设x\lt y\lt z,如果Cmp(z,y)\lt Cmp(y,x),那么y点便永远不可能成为最优解,可以直接将它踢出我们的最优解集。同时,由于sum[i]单调增,所以如果Cmp(y,x)\lt sum[i]那么x点也不可能成为最优解。据此,我们可以便可以通过维护这样的一个队列,每加入一个元素就判断排除所有不可能是最优解的点从而进行优化。斜率优化dp的套路基本是固定的,基本上就是用数组模拟队列,然后两个while循环判断是否可以去除无用的点。
#include<stdio.h>
#include<string.h>
int a[500004];
int dp[500004];
int sum[500004];
int q[500004];
int getUp(int i, int j){
return dp[i] + sum[i] * sum[i] - dp[j] - sum[j] * sum[j];
}
int getDown(int i, int j){
return 2 * (sum[i] - sum[j]);
}
void init(){
memset(a, 0, sizeof a);
memset(dp, 0, sizeof dp);
memset(sum, 0, sizeof sum);
memset(q, 0, sizeof q);
}
int main(){
int n, m;
while (scanf("%d%d", &n, &m) == 2){
init();
for (int i = 1; i <= n; i++){
scanf("%d", &a[i]);
sum[i] = sum[i - 1] + a[i];
}
int head = 0, tail = 0;
q[tail++] = 0;
for (int i = 1; i <= n; i++){
while (head + 1 < tail){
int a1 = q[head], a2 = q[head + 1];
if (getUp(a2, a1) <= getDown(a2, a1)*sum[i]){
head++;
}
else{
break;
}
}
int k=q[head];
dp[i] = dp[k] + m + (sum[i] - sum[k])*(sum[i] - sum[k]);
q[tail++] = i;
while (head + 2 < tail){
int a1 = q[tail - 3];
int a2 = q[tail - 2];
int a3 = q[tail - 1];
if (getUp(a2, a1)*getDown(a3, a2) >= getDown(a2, a1)*getUp(a3, a2)){
tail-=2;
q[tail++]=a3;
}
else{
break;
}
}
}
printf("%dn", dp[n]);
}
}
#include<cstdio>
#include<cstring>
using namespace std;
int a[10004];
int dp[5004][10004];
int q[10004];
int getUp(int i,int x,int y){
return dp[i - 1][x] - dp[i - 1][y] +a[x+1]*a[x+1]-a[y+1]*a[y+1];
}
int getDown(int x, int y){
return a[x + 1] - a[y + 1];
}
int getRight(int i){
return 2 * a[i];
}
int main(){
int tt;
scanf("%d", &tt);
for (int t = 1; t <= tt; t++){
int n, m;
memset(a, 0, sizeof a);
memset(dp, 0, sizeof dp);
memset(q, 0, sizeof q);
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++){
scanf("%d", &a[i]);
}
sort(a + 1, a + 1 + n);
for (int i = 1; i <= n; i++){
dp[1][i] = (a[i] - a[1])*(a[i] - a[1]);
}
for (int i = 2; i <= m; i++){
int head =0,tail = 0;
q[tail++] = i - 1;
for (int j = i; j <= n; j++){
while (head + 1 < tail){
int a1 = q[head];
int a2 = q[head+1];
if (getUp(i, a2, a1) <= getDown(a2, a1)*getRight(j)){
head++;
}
else{
break;
}
}
int k = q[head];
dp[i][j] = dp[i-1][k]+(a[j]-a[k+1])*(a[j]-a[k+1]);
q[tail++] = j;
while (head + 2 < tail){
int a1 = q[tail - 3];
int a2 = q[tail - 2];
int a3 = q[tail - 1];
if (getUp(i,a2,a1)*getDown(a3,a2)>=getUp(i,a3,a2)*getDown(a2,a1)){
tail -= 2;
q[tail++] = a3;
}
else{
break;
}
}
}
}
printf("Case %d: %dn",t, dp[m][n]);
}
}
#include<cstdio>
#include<cstring>
using namespace std;
int sum[1004];
int a[1004];
int dp[1004][1004];
int q[100004];
int getUp(int i, int x, int y){
return dp[i - 1][x] - dp[i - 1][y]-dp[0][x]+dp[0][y]+sum[x]*sum[x]-sum[y]*sum[y];
}
int getDown(int x, int y){
return sum[x] - sum[y];
}
int main(){
int n, m;
while (scanf("%d%d", &n, &m), n || m){
memset(sum, 0, sizeof sum);
memset(dp, 0, sizeof dp);
memset(q, 0, sizeof q);
memset(a, 0, sizeof a);
for (int i = 1; i <= n; i++){
scanf("%d", &a[i]);
sum[i] = sum[i - 1] + a[i];
dp[0][i] = dp[0][i - 1] + a[i] * sum[i - 1];
}
for (int i = 1; i <= m; i++){
int head = 0, tail = 0;
q[tail++] = 0;
for (int j = 1; j <= n; j++){
while (head + 1 < tail){
int a1 = q[head];
int a2 = q[head + 1];
if (getUp(i, a2, a1) <= getDown(a2, a1)*sum[j]){
head++;
}
else{
break;
}
}
int k = q[head];
dp[i][j] = dp[i - 1][k] +dp[0][j]-dp[0][k]- (sum[j] - sum[k])*sum[k];
q[tail++] = j;
while (head + 2 < tail){
int a1 = q[tail - 3];
int a2 = q[tail - 2];
int a3 = q[tail - 1];
if (getUp(i, a2, a1)*getDown(a3, a2) >= getUp(i, a3, a2)*getDown(a2, a1)){
tail -= 2;
q[tail++] = a3;
}
else{
break;
}
}
}
}
printf("%dn", dp[m][n]);
}
}
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
long long a[400004];
long long sum[400004];
long long dp[400004];
long long q[400004];
long long getUp(long long x, long long y){
return dp[x] - dp[y] - sum[x] + sum[y] + a[x+1] * x - a[y+1] * y;
}
long long getDown(long long x, long long y){
return a[x + 1] - a[y + 1];
}
int main(){
long long n, m;
while (scanf("%I64d%I64d", &n, &m) == 2){
memset(dp, 0, sizeof dp);
memset(q, 0, sizeof q);
memset(sum, 0, sizeof sum);
memset(a, 0, sizeof a);
for (long long i = 1; i <= n; i++){
scanf("%I64d", &a[i]);
}
sort(a + 1, a + 1 + n);
for (long long i = 1; i <= n; i++){
sum[i] = sum[i - 1] + a[i];
}
long long head = 0, tail = 0;
q[tail++] = 0;
for (long long i = 1; i <= n; i++){
while (head + 1 < tail){
long long a1 = q[head], a2 = q[head + 1];
if (getUp(a2, a1) <= getDown(a2, a1)*i){
head++;
}
else{
break;
}
}
long long k = q[head];
dp[i] = dp[k] + sum[i] - sum[k] - a[k + 1] * (i - k);
if (i - m + 1 < m){
continue;
}
q[tail++] = i - m + 1;
while (head + 2 < tail){
long long a1 = q[tail - 3];
long long a2 = q[tail - 2];
long long a3 = q[tail - 1];
if (getUp(a2, a1)*getDown(a3, a2) >= getDown(a2, a1)*getUp(a3, a2)){
tail -= 2;
q[tail++] = a3;
}
else{
break;
}
}
}
printf("%I64dn", dp[n]);
}
}
相关文章
- SQL Server 2016 查询存储性能优化小结
- 灵活的Redis:利用索引实现性能优化(redis索引)
- Linux多线程调度优化实践(linux多线程调度)
- MySQL分页查询优化:提升查询效率(mysql分页查询优化)
- MSSQL技术的233层优化:高效提升系统性能(mssql233)
- MacOS快速优化,提升性能(macos优化)
- AVX技术带来的Linux运算优化(avxlinux)
- 优化Linux操作系统界面优化:提升体验的利器(linux操作系统的界面)
- 优化MSSQL服务器:缓存设置实践(mssql缓存设置)
- 100万数据库极致优化经验秘籍(100万mysql优化)
- 海量数据库的查询优化及分页算法方案集合2/2
- php程序效率优化的一些策略小结
- 一次失败的jQuery优化尝试小结
- SQLServer游标的使用/关闭/释放/优化小结