思路
本题需要同时求最短路和次短路,传统的一维dis扩展到二维,增加一维来表示最短路还是次短路,
同时需要维护最短路和次短路的数量
避免重复更新的条件是dis[u][t] < d,保证一种类型的最短路就更新其他节点一次
代码
from heapq import heappop, heappush
from math import inf
from collections import defaultdict
def dj():
dis = [[inf,inf] for _ in range(n + 1)]
cnt = [[0,0] for _ in range(n + 1)]
dis[s][0] = 0
cnt[s][0] = 1
q = [(0,s,0)] # d,u,type( 0:最短路 1:次短路)
while q:
d,u,t = heappop(q)
if dis[u][t] < d:
continue
for v,w in g[u]:
# 尝试更新最短路
if dis[v][0] > d + w:
cnt[v] = [cnt[u][t],cnt[v][0]]
dis[v][1] = dis[v][0]
dis[v][0] = d + w
heappush(q,(dis[v][0],v,0))
heappush(q,(dis[v][1],v,1))
elif dis[v][0] == d + w:
cnt[v][0] += cnt[u][t]
# 尝试更新次短路
elif dis[v][1] > d + w:
cnt[v][1] = cnt[u][t]
dis[v][1] = d + w
heappush(q,(dis[v][1],v,1))
elif dis[v][1] == d + w:
cnt[v][1] += cnt[u][t]
if dis[f][1] - dis[f][0] == 1:
return sum(cnt[f])
else:
return cnt[f][0]
for _ in range(int(input())):
n,m = map(int,input().split())
g = defaultdict(list)
for _ in range(m):
a,b,l = map(int,input().split())
g[a].append((b,l))
s,f = map(int,input().split())
print(dj())