下面提供一份使用 Tarjan 点双连通分量算法建立圆方树，并在圆方树上回答仙人掌两点间最短路询问的代码。
需要注意的是，当图中仅有 $2$ 个点时（两点之间可能有多条重边），此算法不再适用，需要单独特判：此时两个不同点之间的最短路就是所有边中的最小边权。
C++ 代码
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
using namespace std;
// Result type of lca(): (first, second) tree nodes; second == -1 marks the
// case where the answer is a single round node.
using PII = std::pair<int, int>;

// Array capacities. Node arrays are indexed by round and square nodes alike
// (square nodes are numbered after n), and the edge pool is shared by the
// input graph h[] and the round-square tree nh[].
constexpr int N = 400010;
constexpr int M = N << 3;

// Input sizes; cnt is the last node id handed out (square ids follow n).
int n, m, Q, cnt;

// Adjacency lists with head arrays h (input graph) and nh (round-square tree)
// over one shared edge pool: e = target, w = weight, ne = next edge,
// idx = next free edge slot.
int h[N], nh[N];
int e[M], w[M], ne[M];
int idx;

// Tarjan state: discovery time, low-link, and the running timestamp.
int low[N], dfn[N];
int time_stamp;

// wgh: last recorded incident edge weight per vertex; cyc: cycle length per
// square node; q: BFS queue; stk/top: Tarjan vertex stack.
int wgh[N], cyc[N], q[N], stk[N];
int top;

// Tree data: depth, DFS tree-path distance, round-square-tree distance from
// vertex 1, and the binary-lifting ancestor table (2^0 .. 2^20).
int dep[N], dist[N], ndist[N];
int fa[N][21];
// Appends the directed edge a -> b with weight c to the adjacency list whose
// head array is `h` (either the input graph or the round-square tree); all
// lists share the e/w/ne edge pool and the idx counter.
inline void add(int *h, int a, int b, int c)
{
    const int slot = idx;
    ++idx;
    e[slot] = b;      // target vertex
    w[slot] = c;      // edge weight
    ne[slot] = h[a];  // link in front of a's current first edge
    h[a] = slot;
}
// Tarjan vertex-biconnected-component DFS that builds the round-square tree
// of the cactus into the nh[] adjacency lists.
//
// Parameters:
//   u         - vertex being visited.
//   father    - DFS parent of u, or -1 for a DFS root.
//   from_edge - index (in the h[] edge pool) of the tree edge father -> u,
//               or -1 when not supplied. main() inserts every undirected edge
//               as two consecutive directed edges starting at index 0, so the
//               reverse of edge i is i ^ 1. Skipping exactly that reverse
//               edge, instead of every edge whose endpoint is `father`, lets a
//               parallel edge back to the parent act as a back edge, so a
//               two-vertex cycle formed by multi-edges is recognised. With
//               from_edge == -1 the old rule (skip all edges to `father`) is
//               kept, so existing two-argument calls behave as before.
//
// Writes: dfn/low (Tarjan timestamps), dist (tree-path distance from the DFS
// root), wgh (weight of the last edge recorded for a vertex: its tree edge,
// overwritten by its back edge if it has one), cyc (cycle length per square
// node), and appends square nodes cnt+1, cnt+2, ... with their edges to nh[].
void tarjan(int u, int father, int from_edge = -1)
{
    dfn[u] = low[u] = ++ time_stamp;
    stk[ ++ top] = u;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        // Do not walk straight back along the edge we arrived through.
        if (from_edge != -1 ? i == (from_edge ^ 1) : j == father) continue;
        if (!dfn[j])
        {
            wgh[j] = w[i];
            dist[j] = dist[u] + w[i];
            tarjan(j, u, i);
            low[u] = min(low[u], low[j]);
            if (low[j] >= dfn[u])
            {
                // u closes a biconnected component: create square node p.
                // In a cactus the stack top y is the deepest vertex of the
                // component and wgh[y] is the edge closing it back to u, so
                // cyc[p] is the full cycle length (2 * weight for a bridge,
                // which still gives min(dis, cyc - dis) == weight below).
                int p = ++ cnt, y = stk[top];
                cyc[p] = wgh[y] + dist[y] - dist[u];
                add(nh, p, u, 0), add(nh, u, p, 0);
                do
                {
                    y = stk[top -- ];
                    // Shorter way around the cycle between u and y.
                    int dis = dist[y] - dist[u];
                    dis = min(dis, cyc[p] - dis);
                    add(nh, p, y, dis), add(nh, y, p, dis);
                } while (j != y);
            }
        }
        else if (dfn[j] < dfn[u])
            // Back edge to an ancestor: keep its weight for cyc[].
            wgh[u] = w[i], low[u] = min(low[u], dfn[j]);
    }
}
inline void bfs()
{
memset(dep, 0x3f, sizeof dep);
int hh = 0, tt = 0;
dep[0] = 0, dep[1] = 1, q[0] = 1;
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = nh[t]; ~i; i = ne[i])
{
int j = e[i];
if (dep[j] > dep[t] + 1)
{
dep[j] = dep[t] + 1;
ndist[j] = ndist[t] + w[i];
fa[j][0] = t, q[ ++ tt] = j;
for (int k = 1; k <= 20; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
// Lowest common ancestor of round nodes a and b in the round-square tree,
// using the binary-lifting table built by bfs().
// Returns:
//   (l, -1) - when the LCA l is a round node (including a == b, or one being
//             an ancestor of the other);
//   (x, y)  - when the LCA is a square node; x and y are its two children on
//             the paths towards a and b (after the internal depth swap), which
//             lie on the same cycle.
// Uses PII(...) construction instead of the non-standard (PII){...} compound
// literal, which is a C99 feature accepted in C++ only as a GNU extension.
inline PII lca(int a, int b)
{
    if (dep[a] < dep[b]) swap(a, b);
    // Lift a to b's depth; the sentinel node 0 has depth 0 and is never taken.
    for (int k = 20; k >= 0; k -- )
        if (dep[fa[a][k]] >= dep[b])
            a = fa[a][k];
    if (a == b) return PII(a, -1);
    for (int k = 20; k >= 0; k -- )
        if (fa[a][k] != fa[b][k])
            a = fa[a][k], b = fa[b][k];
    // Ids <= n are round nodes; larger ids are square nodes.
    if (fa[a][0] <= n) return PII(fa[a][0], -1);
    return PII(a, b);
}
// Reads a cactus (n vertices, m weighted undirected edges) and Q queries
// "a b", and prints the shortest-path distance for each query.
int main()
{
    int a, b, c, mw = 1e9;
    memset(h, -1, sizeof h);
    memset(nh, -1, sizeof nh);
    scanf("%d%d%d", &n, &m, &Q);
    while (m -- )
    {
        scanf("%d%d%d", &a, &b, &c);
        // Each undirected edge becomes two consecutive directed edges.
        add(h, a, b, c), add(h, b, a, c);
        mw = min(mw, c);  // lightest edge, for the two-vertex case
    }
    // Two vertices: a shortest path between distinct vertices is a single
    // (possibly parallel) edge, so the answer is the lightest edge, or 0 when
    // a == b. Every query must still be read and answered; previously only
    // one line was printed regardless of Q.
    if (n == 2)
    {
        while (Q -- )
        {
            scanf("%d%d", &a, &b);
            printf("%d\n", a == b ? 0 : mw);
        }
        return 0;
    }
    cnt = n;  // square nodes are numbered n+1, n+2, ...
    for (int i = 1; i <= n; i ++ )
        if (!dfn[i]) tarjan(i, -1);
    bfs();
    while (Q -- )
    {
        scanf("%d%d", &a, &b);
        PII t = lca(a, b);
        int res = ndist[a] + ndist[b];
        int r1 = t.first, r2 = t.second;
        if (r2 == -1) printf("%d\n", res - 2 * ndist[r1]);
        else
        {
            // LCA is a square node: r1 and r2 share its cycle, so replace
            // their two tree hops with the shorter way around that cycle.
            int dis = abs(dist[r1] - dist[r2]);
            dis = min(dis, cyc[fa[r1][0]] - dis);
            printf("%d\n", res - ndist[r1] - ndist[r2] + dis);
        }
    }
    return 0;
}