自己的一点理解
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=5e5;
int e[N],ne[N],h[N],idx;
int maxu,res[N],maxd;
int d1[N],d2[N],up[N],next1[N];
//d1表示向下最长的距离,d2表示次长,up表示向上最长距离,next1表示向下最长下一个点是哪个
void add(int a,int b)
{
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void dfs_d(int u,int fa)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j!=fa)
{
dfs_d(j,u);
int dis=d1[j]+1;
if(dis>d1[u])
{
d2[u]=d1[u];
d1[u]=dis;
next1[u]=j;
}
else if(dis>d2[u])
d2[u]=dis;
}
}
maxd=max(maxd,d1[u]+d2[u]);
}
void dfs_up(int u,int fa)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j!=fa)
{
up[j]=up[u]+1;
//如果在向下最长路径上,与经过点u再走到次长路径的路径比较
if(next1[u]==j)
up[j]=max(up[j],d2[u]+1);
else
up[j]=max(up[j],d1[u]+1);
dfs_up(j,u);
}
}
}
int main()
{
int n;
scanf("%d",&n);
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs_d(0,-1);
dfs_up(0,-1);
for(int i=0;i<n;i++)
{
int d[3]={d1[i],d2[i],up[i]};
sort(d,d+3);
if(d[1]+d[2]==maxd)
printf("%d\n",i);
}
return 0;
}