树形DP
由于是一个有向无环图且入度为 $0$ 的点只有 $1$ 个,可以用两次 $DFS $解决:
-
第一次 $DFS$ 从入度为 $0$ 的点出发,找出每个点往下走的最长的链的长度,要么就是 $f[u] $,要么是子节点 $f[j] + 1$,取最大值。
-
找到每个节点的最长的链的长度后,第二次 $DFS$ 从入度为 $0$ 的点出发,依次找到子节点中满足 $f[j] = f[u] - 1$ 的结点的最小值(题目要求字典序最小),继续递归。
-
在第二次 $DFS$ 时,边递归边存入结果,最后输出答案。
附C++,Java,Go代码
C++代码:
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 10010;
int h[N], e[N], ne[N];
int d[N], f[N];
vector<int> res;
int n, idx;
void add(int a, int b)
{
e[idx] = b;
ne[idx] = h[a];
h[a] = idx ++;
}
void dfs(int u, int fa)
{
f[u] = 1;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
dfs(j, u);
f[u] = max(f[u], f[j] + 1);
}
}
void dfs_g(int u, int fa)
{
int minv = 1e9;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
if (f[j] == f[u] - 1)
minv = min(minv, j);
}
if (minv != 1e9) { // 不是叶子节点时
res.push_back(minv);
dfs_g(minv, u);
}
}
int main()
{
cin >> n;
memset(h, -1, sizeof h);
for (int i = 0; i < n; i ++)
{
int k, x;
cin >> k;
while (k -- )
{
cin >> x;
add(i, x);
d[x] ++;
}
}
int g = -1;
for (int i = 0; i < n; i ++)
if (d[i] == 0)
g = i;
dfs(g, -1);
cout << f[g] << endl; // 最长的链的长度
res.push_back(g);
dfs_g(g, -1);
for (int x : res) cout << x << " ";
return 0;
}
Java代码:
import java.util.*;
public class Main {
static final int N = 10010;
static int[] h = new int[N];
static int[] e = new int[N], ne = new int[N];
static int[] d = new int[N]; // 入度
static int[] f = new int[N]; // 存每个点的最长的链的长度
static List<Integer> res = new ArrayList<>();
static int n, idx;
static void add(int a, int b) {
e[idx] = b;
ne[idx] = h[a];
h[a] = idx ++;
}
static void dfs(int u, int fa) {
f[u] = 1;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
dfs(j, u);
f[u] = Math.max(f[u], f[j] + 1);
}
}
static void dfs_g(int u, int fa) {
int minv = 1_000_000_000;
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
if (f[j] == f[u] - 1)
minv = Math.min(minv, j);
}
if (minv != 1_000_000_000) { // 不是叶子节点时
res.add(minv);
dfs_g(minv, u);
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
Arrays.fill(h, -1);
for (int i = 0; i < n; i ++) {
int k = sc.nextInt();
while (k -- > 0) {
int x = sc.nextInt();
add(i, x);
d[x] ++;
}
}
int g = -1;
for (int i = 0; i < n; i ++)
if (d[i] == 0)
g = i;
dfs(g, -1);
System.out.println(f[g]); // 最长的链的长度
res.add(g);
dfs_g(g, -1);
for (int x : res)
System.out.printf("%d ", x);
}
}
Go代码:
package main
import (
. "fmt"
)
const N = 10010
var (
h = make([] int, N)
e = make([] int, N)
ne = make([] int, N)
d = make([] int, N)
f = make([] int, N)
res []int
n, idx int
)
func add(a, b int) {
e[idx] = b
ne[idx] = h[a]
h[a] = idx
idx ++
}
func dfs(u, fa int) {
f[u] = 1
for i := h[u]; i != -1; i = ne[i] {
j := e[i]
if j == fa {
continue
}
dfs(j, u)
f[u] = max(f[u], f[j] + 1)
}
}
func dfs_g(u, fa int) {
minv := 1_000_000_000
for i := h[u]; i != -1; i = ne[i] {
j := e[i]
if j == fa {
continue
}
if f[j] == f[u] - 1 {
minv = min(minv, j)
}
}
if minv != 1_000_000_000 {
res = append(res, minv)
dfs_g(minv, u)
}
}
func main() {
Scan(&n)
for i := 0; i < n; i ++ {
h[i] = -1
}
for i := 0; i < n; i ++ {
var k, x int
Scan(&k)
for j := 0; j < k; j ++ {
Scan(&x)
add(i, x)
d[x]++
}
}
g := -1
for i := 0; i < n; i ++ {
if d[i] == 0 {
g = i
}
}
dfs(g, -1)
Println(f[g])
res = append(res, g)
dfs_g(g, -1)
for _, x := range res {
Printf("%d ", x)
}
}
func max(a, b int) int { if a > b { return a }; return b }
func min(a, b int) int { if a < b { return a }; return b }