C++ 代码
#include<bits/stdc++.h>
using namespace std;
const int N = 6e3 + 10;
//h[N]内部存的是idx,e[N], ne[N]是用于描述idx的,(idx,e[N], ne[N]共同组成节点)
//但h[N]的下标存的是题目中的点
int h[N], e[N], ne[N], idx;
int happy[N];
int f[N][2];
bool has_fa[N];
//添加有向边 u->v
void add(int u, int v)//因为只有上司有多位下属,而下属只有一位上司,因此邻接表的方向是父节点->子节点
{
e[idx] = v, ne[idx] = h[u], h[u] = idx++;
}
void dfs(int u)
{
f[u][1] = happy[u];
for (int i = h[u]; i != -1; i = ne[i])
{
int v = e[i];
dfs(v);//计算方向为:自下而上,因此必须先计算好子节点,然后以此计算父节点
f[u][1] += f[v][0];
f[u][0] += max(f[v][0], f[v][1]);
}
}
int main()
{
int n; cin >> n;
for (int i = 1; i <= n; i++)scanf("%d", &happy[i]);
memset(h, -1, sizeof(h));
for (int i = 0; i < n - 1; i++)
{
int a, b; scanf("%d%d", &a, &b);
add(b, a);
has_fa[a] = 1;
}
int root = 1;//注意:此为题中的节点(1~N)而不是idx
while (has_fa[root])root++;
dfs(root);
cout << max(f[root][0], f[root][1]);
}