根据@ S7former佬的题解理解了一下,打了点注释,存一下思路
/*
用dijkstra算出原来每个点到 1的最短距离
然后遍历所有点,找到中转点,
使得中转点到当前点距离 + 中转点到 1距离 == 原来的最短距离
其中 d[i]表示 i号点到 1号点的距离
*/
#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
#define f first
#define s second
typedef pair<int, int> PII;
const int N = 1e4+10, M = 2e5+10;
int h[N], ne[M], e[M], idx, w[M];
int n, m;
int d[N];
bool st[N];
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
void dijkstra()
{
memset(d, 0x3f, sizeof d);
d[1] = 0;
priority_queue<PII, vector<PII>, greater<PII> > q;
q.push({0, 1});
while(q.size())
{
auto t = q.top();
q.pop();
if(st[t.s]) continue;
st[t.s] = true;
for(int i = h[t.s]; i != -1; i = ne[i])
{
int j = e[i];
if(d[j] > w[i] + t.f)
{
d[j] = w[i] + t.f;
q.push({d[j], j});
}
}
}
}
int main()
{
cin >> n >> m;
memset(h, -1, sizeof h);
while(m--)
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
dijkstra();
int res = 0;
for(int k = 2; k <= n; k++)
{
int road = 0x3f3f3f3f;
for(int i = h[k]; i != -1; i = ne[i])
{
int j = e[i];
// j作为中转点, 1-k的距离 == 1-j的距离 + j-k的距离
if(d[k] == d[j] + w[i]) road = min(road, w[i]);
}
res += road;
}
cout << res;
return 0;
}