银牌换根dp。
给定n个点的树,每个点有权值ai。选两条链,将链上相交点除外,取贡献和。求最大贡献。
相交最多一个点相交,枚举所有点,取上下共四条最大单链即可;不相交时是两条独立的链,枚举所有点作为两条链的最高点,再枚举该点的所有邻点,取邻点做最高点时的最大双链后去掉邻点,上下取两条剩余最长链即可。主要难点在于邻点是父节点时,最大双链可以包含自己,要剔除这一部分。处理这种最大双链的方式是:将父节点所有子节点的双链最大值(不包含自己)和父节点下取两最长链最大值(不包含自己)取max即可。
m1表示向下第1大单链,m2表示向下第2大单链,以此类推。md[i]表示i做最高点最大双链。md1表示子节点中双链最大值,md2表示子节点中双链次大值。
总时间复杂度O(n)。
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define endl '\n'
#define deb(n) cout << #n << "=" << n << " "
#define debug(n) cout << #n << "=" << n << endl
#define div() cout << "_______________" << endl
const int N = 2e5 + 10;
int n, ans;
int A[N], md[N], md1[N], md2[N], mu[N], m1[N], m2[N], m3[N], m4[N];
vector<int> v[N];
void dfs_d(int pre, int from) {
md[pre] = md1[pre] = md2[pre] = m1[pre] = m2[pre] = m3[pre] = m4[pre] = 0;
for (auto& to : v[pre]) {
if (to == from) continue;
dfs_d(to, pre);
if (m1[to] >= m1[pre]) m4[pre] = m3[pre], m3[pre] = m2[pre], m2[pre] = m1[pre], m1[pre] = m1[to];
else if (m1[to] >= m2[pre]) m4[pre] = m3[pre], m3[pre] = m2[pre], m2[pre] = m1[to];
else if (m1[to] >= m3[pre]) m4[pre] = m3[pre], m3[pre] = m1[to];
else if (m1[to] > m4[pre]) m4[pre] = m1[to];
md[pre] = max(md[pre], md[to]);
if (md[to] >= md1[pre]) md2[pre] = md1[pre], md1[pre] = md[to];
else if (md[to] > md2[pre]) md2[pre] = md[to];
}
md[pre] = max(md[pre], A[pre] + m1[pre] + m2[pre]);
m1[pre] += A[pre], m2[pre] += A[pre], m3[pre] += A[pre], m4[pre] += A[pre];
}
void dfs_u(int pre, int from) {
if (pre == 1) mu[pre] = 0;
else if (m1[pre] + A[from] == m1[from]) mu[pre] = A[pre] + max(mu[from], m2[from]);
else mu[pre] = A[pre] + max(mu[from], m1[from]);
for (auto& to : v[pre]) {
if (to == from) continue;
dfs_u(to, pre);
}
}
void dfs_j(int pre, int from) {
vector<int> save;
save.push_back(mu[pre] - A[pre]);
save.push_back(m1[pre] - A[pre]);
save.push_back(m2[pre] - A[pre]);
save.push_back(m3[pre] - A[pre]);
save.push_back(m4[pre] - A[pre]);
sort(save.begin(), save.end());
reverse(save.begin(), save.end());
int sum = 0;
for (int i = 0; i < 4; i++) sum += save[i];
ans = max(ans, sum);
for (auto& to : v[pre]) {
if (to == from) continue;
dfs_j(to, pre);
}
}
void dfs_n(int pre, int from) {
vector<int> save;
save.push_back(mu[pre] - A[pre]);
save.push_back(m1[pre] - A[pre]);
save.push_back(m2[pre] - A[pre]);
save.push_back(m3[pre] - A[pre]);
save.push_back(m4[pre] - A[pre]);
sort(save.begin(), save.end());
reverse(save.begin(), save.end());
for (auto& to : v[pre]) {
if (to == from) continue;
if (m1[to] == save[0]) ans = max(ans, A[pre] + md[to] + save[1] + save[2]);
else if (m1[to] == save[1]) ans = max(ans, A[pre] + md[to] + save[0] + save[2]);
else ans = max(ans, A[pre] + md[to] + save[0] + save[1]);
}
if (pre != 1) {
int shuanglian = 0;
if (md[pre] == md1[from]) shuanglian = md2[from];
else shuanglian = md1[from];
int xiangqie = 0;
if (m1[pre] == m1[from] - A[from]) xiangqie = A[from] + m2[from] - A[from] + m3[from] - A[from];
else if (m1[pre] == m2[from] - A[from]) xiangqie = A[from] + m1[from] - A[from] + m3[from] - A[from];
else xiangqie = A[from] + m1[from] - A[from] + m2[from] - A[from];
ans = max(ans, A[pre] + m1[pre] - A[pre] + m2[pre] - A[pre] + max(shuanglian, xiangqie));
}
for (auto& to : v[pre]) {
if (to == from) continue;
dfs_n(to, pre);
}
}
void oper() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> A[i];
for (int i = 1; i < n; i++) {
int a, b; cin >> a >> b;
v[a].push_back(b);
v[b].push_back(a);
}
if (n == 1) {
cout << "0" << endl;
return;
}
dfs_d(1, -1);
dfs_u(1, -1);
dfs_j(1, -1);
dfs_n(1, -1);
cout << ans << endl;
}
signed main() {
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
int t = 1; //cin >> t;
while (t--) oper();
return 0;
}