多校1题解I/区间dp

区间 DP 豪题!

题面

这道题目我们首先看 n 的取值范围n <= 420所以我们可以用区间 dp来解决这道题目

首先,我们定义一个数据类型为pair<int, int>(也可以用 array<int, 2> 都是一个道理的 )的三维数组,pair<int, int>first代表的是不平衡度 bsecond代表的是代价,重点dp[i][j][k] 中的 k代表的是有多少个 pair<int, int> 数据,每一个数据代表的是第 i 个铁棒到第 j 个铁棒这一个长铁棒,用两个小铁棒组成当前长铁棒平衡度小于等于 dp[i][j][k].first最小代价dp[i][j][k].second是多少,所以我们每次在结束状态转移的后,我们会对 dp[i][j] 按照 dp[i][j][k].first 从小到大进行排序,然后遍历一遍将当前数据的 dp[i][j][k].second 与前一个数据进行 min 操作。

然后,我可以遍历一遍长铁棒的“长度”(就是由几个铁棒组成),长铁棒的起始点,还有长铁棒分成两个短铁棒的切割位置(前一段短铁棒到哪个位置),这个就是我们代码中的。

1
2
3
4
5
6
7
8
9
for(int len = 1; len <= n; len ++) {
for(int i = 1; i + len - 1 <= n; i ++) {
for(int k = i; k < j; k ++) {
int x = k;
//前一段短铁棒就是第i个铁棒到第x个铁棒
//后一段短铁棒就是第x + 1个铁棒到第j个铁棒
}
}
}

我们很好想到,假如 len == 1 的情况下实际上是不需要任何代价的,所以我们只需要向其中加入 {0, 0}即可。但是当 len != 1 时,我们就要找到我们的状态转移方程了,首先我们需要计算出他们的平衡度 b 我们需要求出两段短铁棒的长度,但是一个一个加很明显不是一个明智的选择,所以我们使用前缀和来进行优化,sum[i]代表前 i 个铁棒的长度和。那么我们就可以得出他们的平衡度应该是 b = (sum[j] - sum[x] - (sum[x] - sum[i - 1])) = sum[j] + sum[i -1] - sum[x] * 2,然后我们就计算他们的代价 cost = i到x这个铁棒平衡度小于等于b且最接近b的情况下的最小代价 + (x + 1)到j这个铁棒平衡度小于等于b且最接近b的情况下的最小代价 + 当前切割代价 。那么我们该如何得到”i 到 x 这个铁棒平衡度小于等于 b 且最接近 b 的情况下的最小代价“呢?这个就要追溯到我们前面讲的 dp[i][j][k] 的意义了,这个代表的是从第 i 个铁棒到第 j 个铁棒组成的长铁棒,在用两个小铁棒组成长铁棒平衡度小于等于 dp[i][j][k].first 的情况下所需要的最小代价是 dp[i][j][k].second。这里可能很多朋友回想我这两个小铁棒的平衡度 dp[i][j][k].first会不会起冲突,这个肯定是不会的,题目中要求平衡度 b1 > b2 >...... > bn - 1,但是谁先谁后又没有规定,谁大谁先呗,反正他们的平衡度肯定都是小于等于我当前两小铁棍的平衡度,所以我们不需要担心这个问题。还有一点我们为什么要找平衡度最接近 b的数据呢,因为数据越多代价数据也就越多,可能我想要的最小代价就在最近 b平衡度的位置呢,所以我为了找到这个数据,我使用二分查找来找(我的 dp[i][j]中的数据是有序的)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
auto find = [&](int i, int j, int b) -> int {
if (dp[i][j].empty()) return INF;
if (dp[i][j][0][0] > b) return INF;

int l = 0, r = (int)dp[i][j].size() - 1;
int ans = 0;
while (l <= r) {
int mid = l + (r - l) / 2;
if (dp[i][j][mid][0] <= b) {
l = mid + 1;
ans = mid;
}
else
r = mid - 1;
}
return dp[i][j][ans][1];
};

那么,我的 find(i, x, b) 代表的就是 ”i 到 x 这个铁棒平衡度小于等于 b 且最接近 b 的情况下的最小代价“,那我的代价应该是 find(i, x, b) + find(x + 1, j, b) + min(sum[j] - sum[x], sum[x] - sum[i - 1]) * log2((sum[j] - sum[i - 1]) * 2 - 1),然后将这一组 {b, cost}数据 push_backdp[i][j]中。

最后,我们对 dp[i][j]进行排序,然后从头到尾遍历 dp[i][j][k].second = min(dp[i][j][k].second, dp[i][j][k - 1].second)

其他的就没啥问题了,我们要找的答案应该是从第 1 个铁棒到第 n个铁棒,从第 k 个铁棒截断,所需要的最小价值,也就是 find(i, k, b) + find(x + 1, n, b) + log2(...) * min(...)

下面是源代码:

我用的是array<int, 2>,要是要用 pair<int, int> 的话可以自己换一下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <bits/stdc++.h>
using namespace std;
#define int long long
using ll = long long;
using ull = unsigned long long;

const int N = 2e5 + 5;
// const int M = 1e2 + 5;
const int MOD = 998244353;
const int INF = 1e15;

void solve() {
int n;
cin >> n;
vector<int> sum(n + 5, 0);
for (int i = 1; i <= n; i++) {
cin >> sum[i];
sum[i] += sum[i - 1];
}

vector dp(n + 5, vector<vector<array<int, 2>>>(n + 5));

auto find = [&](int i, int j, int b) -> int {
if (dp[i][j].empty()) return INF;
if (dp[i][j][0][0] > b) return INF;

int l = 0, r = (int)dp[i][j].size() - 1;
int ans = 0;
while (l <= r) {
int mid = l + (r - l) / 2;
if (dp[i][j][mid][0] <= b) {
l = mid + 1;
ans = mid;
}
else
r = mid - 1;
}
return dp[i][j][ans][1];
};

vector<int> ans(n + 5, INF);
for (int len = 1; len <= n; len++) {
for (int i = 1; i + len - 1 <= n; i++) {
int j = i + len - 1;
// cout << i << " " << j << "\n";
if (len == 1)
dp[i][j].push_back({0, 0});
else {
// int p = dp[i][j].size();
// cout << p << "f\n";
for (int k = i; k < j; k++) {
int x = k;
int b = abs(sum[j] + sum[i - 1] - 2 * sum[x]);
int lg = log2((sum[j] - sum[i - 1]) * 2 - 1);
int cost = find(i, x, b) + find(x + 1, j, b) +
min(sum[j] - sum[x], sum[x] - sum[i - 1]) * lg;
ans[k] = cost;
ans[k] = min(INF, ans[k]);
dp[i][j].push_back({b, cost});
}

sort(dp[i][j].begin(), dp[i][j].end());

for (int k = 1; k < dp[i][j].size(); k++)
dp[i][j][k][1] = min(dp[i][j][k][1], dp[i][j][k - 1][1]);

if (i == 1 && j == n) {
for (int k = 1; k < n; k++) {
if (ans[k] == INF)
cout << "-1 ";
else
cout << ans[k] << " ";
}
cout << "\n";
}
}
}
}
}

signed main() {
#ifdef local
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);

int __;
__ = 1;
cin >> __;
while (__--) {
solve();
}

return 0;
}