// 代码底部有我的见解。 (Translation: "My notes are at the bottom of the code.")
#include <bits/stdc++.h>
#define int long long
using namespace std;
// NOTE(review): the problem statement is not part of this file. Judging by the
// computation, each test case prints the minimum of
//     T = sum of v[d] over every digit d of s
// and, over every non-empty subsequence of at most kMaxGroupDigits digits of s,
//     T - (sum of v over the chosen digits) + N,
// where N is the decimal number spelled by the chosen digits (leading zeros
// allowed). v[0] is 0 because the input only supplies v[1..9].
namespace {

// Longest subsequence that may be merged into one number; same cap as the
// previous DP's second dimension (0..6). Presumably longer numbers never pay
// off under the problem's limits — TODO confirm against the statement.
constexpr std::size_t kMaxGroupDigits = 6;

// Marks "no subsequence of this length exists yet". It is only compared
// against and never added to, so it cannot overflow.
constexpr long long kUnreachable = std::numeric_limits<long long>::min();

// Returns max(0, max over subsequences of `digits` with 1..kMaxGroupDigits
// characters of (sum of v[d] over the chosen digits) - N), N being the number
// those digits spell.
//
// Why a separate DP per length: keeping a single (prefix number, value) pair
// per state is not exact, because the value of later appends depends on the
// prefix number. Once the final length L is fixed, the k-th chosen digit d
// contributes exactly v[d] - d * 10^(L-k), independent of the other choices,
// so an "ordered pick of k items" DP is exact.
long long best_merge_gain(const std::string& digits,
                          const std::array<long long, 10>& v) {
    // pow10[p] = 10^p, the place value of a digit followed by p more digits.
    std::array<long long, kMaxGroupDigits> pow10{};
    pow10[0] = 1;
    for (std::size_t p = 1; p < kMaxGroupDigits; ++p) {
        pow10[p] = pow10[p - 1] * 10;
    }

    long long best = 0;  // 0 == merge nothing
    for (std::size_t len = 1; len <= kMaxGroupDigits; ++len) {
        // f[k] = best total weight after choosing k digits so far.
        std::array<long long, kMaxGroupDigits + 1> f;
        f.fill(kUnreachable);
        f[0] = 0;
        for (const char ch : digits) {
            const long long d = ch - '0';
            // Walk k downwards so each position is used at most once.
            for (std::size_t k = len; k >= 1; --k) {
                if (f[k - 1] == kUnreachable) {
                    continue;
                }
                f[k] = std::max(f[k], f[k - 1] + v[d] - d * pow10[len - k]);
            }
        }
        if (f[len] != kUnreachable) {
            best = std::max(best, f[len]);
        }
    }
    return best;
}

}  // namespace

// Input: a subtask id, the number of tests, then per test a digit string s
// followed by the nine values v[1..9]. Output: one line per test (see note
// at the top of this block).
signed main() {
    long long subtask_id = 0;  // read only to consume it
    long long tests = 0;
    std::scanf("%lld", &subtask_id);
    std::scanf("%lld", &tests);

    std::string s;
    std::array<long long, 10> v{};  // v[0] stays 0: the input has no value for digit 0
    while (tests-- > 0) {
        if (!(std::cin >> s)) {
            break;  // truncated input: nothing more to answer
        }
        for (std::size_t d = 1; d <= 9; ++d) {
            std::scanf("%lld", &v[d]);
        }
        long long total = 0;  // cost when no digits are merged
        for (const char ch : s) {
            total += v[ch - '0'];
        }
        std::printf("%lld\n", total - best_merge_gain(s, v));
    }
    return 0;
}