题意
给出一张无向图,求访问所有节点的最短路径。
做法
法一 状压+广搜
修改常规广搜,队列元素改为包含节点、当前路径长度、节点访问情况的三元组。用二进制数 mask 表示每一节点访问情况,如果 mask 的第 i 位是 1,则表示节点 i 已经访问过。初始时需将所有位于起点的状态入队。
法二 状压dp
考虑整条路径经过的点序列,虽然可能会重复访问某个点,但这个序列必须包含某个子序列,这个子序列是从 0 到 n−1 的一个排列。把这个子序列的节点称为关键节点。考虑到达某个关键节点 u 的最短路径长度,可以拆成先到达之前某个关键节点 v ,然后再从 v 走最短路到达 v 这两部分。
考虑用 f[u][mask] 表示从任意一个起点开始,到节点 u 为止,经过关键节点情况为 mask 的最短路径长度。有递推公式:
f[u][mask]=min(f[v][lastmask]+d(v, u))
其中,mask 的第 v 位上为 1,表示路径已经经过了节点 v 。lastmask 是将 mask 的第 u 位从 1 改成 0 的结果,表示上一状态还没走到最后一个节点 u 。d(v, u) 表示从两点之间最短路。
首先用弗洛伊德预处理出任意两点间最短路径,维护好边界之后动规即可。
代码
法一
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| class Solution { public: int shortestPathLength(vector<vector<int>>& graph) { int n = graph.size(); queue<tuple<int, int, int>> q; vector<vector<int>> vis(n, vector<int>(1 << n)); for (int i = 0; i < n; ++i) { q.emplace(i, 1 << i, 0); vis[i][1 << i] = true; }
int ans = 0; while (!q.empty()) { auto [u, mask, dist] = q.front(); q.pop(); if (mask == (1 << n) - 1) { ans = dist; break; } for (int v: graph[u]) { int mask_v = mask | (1 << v); if (!vis[v][mask_v]) { q.emplace(v, mask_v, dist + 1); vis[v][mask_v] = true; } } } return ans; } };
|
法二 状压dp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
| class Solution { public: int shortestPathLength(vector<vector<int>>& graph) { int n = graph.size(); vector<vector<int>> d(n, vector<int>(n, n + 1)); for (int i = 0; i < n; ++i) { for (int j: graph[i]) { d[i][j] = 1; } } for (int k = 0; k < n; ++k) { for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { d[i][j] = min(d[i][j], d[i][k] + d[k][j]); } } } vector<vector<int>> f(n, vector<int>(1 << n, INT_MAX / 2)); for (int mask = 1; mask < (1 << n); ++mask) { if ((mask & (mask - 1)) == 0) { int u = __builtin_ctz(mask); f[u][mask] = 0; } else { for (int u = 0; u < n; ++u) { if (mask & (1 << u)) { for (int v = 0; v < n; ++v) { if ((mask & (1 << v)) && u != v) { f[u][mask] = min(f[u][mask], f[v][mask ^ (1 << u)] + d[v][u]); } } } } } }
int ans = INT_MAX; for (int u = 0; u < n; ++u) { ans = min(ans, f[u][(1 << n) - 1]); } return ans; } };
|