leetcode 847 访问所有节点的最短路径

题意

给出一张无向图,求访问所有节点的最短路径。

做法

法一 状压+广搜

修改常规广搜,队列元素改为包含节点、当前路径长度、节点访问情况的三元组。用二进制数 maskmask 表示每一节点访问情况,如果 maskmask 的第 ii 位是 11,则表示节点 ii 已经访问过。初始时需将所有位于起点的状态入队。

法二 状压dp

考虑整条路径经过的点序列,虽然可能会重复访问某个点,但这个序列必须包含某个子序列,这个子序列是从 00n1n - 1 的一个排列。把这个子序列的节点称为关键节点。考虑到达某个关键节点 uu 的最短路径长度,可以拆成先到达之前某个关键节点 vv ,然后再从 vv 走最短路到达 vv 这两部分。

考虑用 f[u][mask]f[u][mask] 表示从任意一个起点开始,到节点 uu 为止,经过关键节点情况为 maskmask 的最短路径长度。有递推公式:

f[u][mask]=min(f[v][lastmask]+d(v, u))f[u][mask] = \min ( f[v][lastmask] + d(v,\ u))

其中,maskmask 的第 vv 位上为 11,表示路径已经经过了节点 vvlastmasklastmask 是将 maskmask 的第 uu 位从 11 改成 00 的结果,表示上一状态还没走到最后一个节点 uud(v, u)d(v,\ u) 表示从两点之间最短路。

首先用弗洛伊德预处理出任意两点间最短路径,维护好边界之后动规即可。

代码

法一

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution {
public:
int shortestPathLength(vector<vector<int>>& graph) {
int n = graph.size();
queue<tuple<int, int, int>> q;
vector<vector<int>> vis(n, vector<int>(1 << n));
for (int i = 0; i < n; ++i) {
q.emplace(i, 1 << i, 0);
vis[i][1 << i] = true;
}

int ans = 0;
while (!q.empty()) {
auto [u, mask, dist] = q.front();
q.pop();
if (mask == (1 << n) - 1) {
ans = dist;
break;
}
for (int v: graph[u]) {
int mask_v = mask | (1 << v);
if (!vis[v][mask_v]) {
q.emplace(v, mask_v, dist + 1);
vis[v][mask_v] = true;
}
}
}
return ans;
}
};

法二 状压dp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class Solution {
public:
int shortestPathLength(vector<vector<int>>& graph) {
int n = graph.size();
// floyd
vector<vector<int>> d(n, vector<int>(n, n + 1));
for (int i = 0; i < n; ++i) {
for (int j: graph[i]) {
d[i][j] = 1;
}
}
for (int k = 0; k < n; ++k) {
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
d[i][j] = min(d[i][j], d[i][k] + d[k][j]);
}
}
}
// dp
vector<vector<int>> f(n, vector<int>(1 << n, INT_MAX / 2));
for (int mask = 1; mask < (1 << n); ++mask) {
// 如果 mask 只包含一个 1, 表示只经过一个关键节点
if ((mask & (mask - 1)) == 0) {
int u = __builtin_ctz(mask);
f[u][mask] = 0;
}
else {
// 枚举已经经过的关键节点作为中转点
for (int u = 0; u < n; ++u) {
if (mask & (1 << u)) {
for (int v = 0; v < n; ++v) {
if ((mask & (1 << v)) && u != v) {
f[u][mask] = min(f[u][mask], f[v][mask ^ (1 << u)] + d[v][u]);
}
}
}
}
}
}

int ans = INT_MAX;
for (int u = 0; u < n; ++u) {
ans = min(ans, f[u][(1 << n) - 1]);
}
return ans;
}
};