|
| 1 | +#pragma once |
| 2 | +#include <algorithm> |
| 3 | +#include <cassert> |
| 4 | +#include <optional> |
| 5 | +#include <vector> |
| 6 | + |
| 7 | +// Solve assignment problem by Hungarian algorithm |
| 8 | +// dual problem: maximize sum(f) - sum(g) s.t. f_i - g_j <= C_ij |
| 9 | +// Requires: n == m |
| 10 | +// Ccomplexity: O(n^2 m) |
| 11 | +// https://www.slideshare.net/joisino/ss-249394573 |
| 12 | +// Todo: |
| 13 | +// - generalize: n != m |
| 14 | +// - Reduce to O(nm min(n, m)) |
| 15 | +template <class T> std::pair<T, std::vector<int>> hungarian(const std::vector<std::vector<T>> &C) { |
| 16 | + if (C.empty()) return {T(), {}}; |
| 17 | + const int n = C.size(), m = C.front().size(); |
| 18 | + assert(n == m); |
| 19 | + std::vector<T> f(n, T()), g(m, T()); |
| 20 | + |
| 21 | + auto chmin = [](T &x, T y) { return (x > y ? (x = y, true) : false); }; |
| 22 | + |
| 23 | + // Make dual variables feasible |
| 24 | + for (int j = 0; j < m; ++j) { |
| 25 | + g[j] = f[0] - C[0][j]; |
| 26 | + for (int i = 0; i < n; ++i) g[j] = std::max(g[j], f[i] - C[i][j]); |
| 27 | + } |
| 28 | + |
| 29 | + std::vector<int> lmate(n, -1); |
| 30 | + std::vector<std::optional<int>> rmate(m, std::nullopt); |
| 31 | + |
| 32 | + std::vector<int> rreach(m, -1); |
| 33 | + std::vector<int> rprv(m, -1); |
| 34 | + |
| 35 | + std::vector<int> lvisited; |
| 36 | + |
| 37 | + for (int i = 0; i < n; ++i) { |
| 38 | + lvisited = {i}; |
| 39 | + int cur = 0; |
| 40 | + std::optional<int> reachable_r = std::nullopt; |
| 41 | + |
| 42 | + while (!reachable_r.has_value()) { |
| 43 | + |
| 44 | + auto check_l = [&]() -> void { |
| 45 | + int l = lvisited[cur++]; |
| 46 | + for (int j = 0; j < m; ++j) { |
| 47 | + if (rreach[j] == i) continue; |
| 48 | + if (f[l] - g[j] == C[l][j]) { |
| 49 | + rreach[j] = i; |
| 50 | + rprv[j] = l; |
| 51 | + if (rmate[j].has_value()) { |
| 52 | + lvisited.push_back(rmate[j].value()); |
| 53 | + } else { |
| 54 | + reachable_r = j; |
| 55 | + cur = lvisited.size(); |
| 56 | + return; |
| 57 | + } |
| 58 | + } |
| 59 | + } |
| 60 | + }; |
| 61 | + while (cur < int(lvisited.size())) check_l(); |
| 62 | + |
| 63 | + if (!reachable_r.has_value()) { |
| 64 | + T min_diff = T(); |
| 65 | + int min_l = -1, min_r = -1; |
| 66 | + for (int l : lvisited) { |
| 67 | + for (int j = 0; j < m; ++j) { |
| 68 | + if (rreach[j] == i) continue; |
| 69 | + T diff = C[l][j] + g[j] - f[l]; |
| 70 | + if (min_l < 0) { |
| 71 | + min_diff = diff; |
| 72 | + min_l = l; |
| 73 | + min_r = j; |
| 74 | + } else if (chmin(min_diff, diff)) { |
| 75 | + min_l = l; |
| 76 | + min_r = j; |
| 77 | + } |
| 78 | + } |
| 79 | + } |
| 80 | + for (int l : lvisited) f[l] += min_diff; |
| 81 | + for (int j = 0; j < m; ++j) { |
| 82 | + if (rreach[j] == i) g[j] += min_diff; |
| 83 | + } |
| 84 | + rreach[min_r] = i; |
| 85 | + rprv.at(min_r) = min_l; |
| 86 | + |
| 87 | + if (rmate[min_r].has_value()) { |
| 88 | + lvisited.push_back(rmate[min_r].value()); |
| 89 | + } else { |
| 90 | + reachable_r = min_r; |
| 91 | + } |
| 92 | + } |
| 93 | + } |
| 94 | + for (int h = reachable_r.value(); h >= 0;) { |
| 95 | + int l = rprv.at(h); |
| 96 | + int nxth = lmate.at(l); |
| 97 | + rmate.at(h) = l; |
| 98 | + lmate.at(l) = h; |
| 99 | + h = nxth; |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + T sol = T(); |
| 104 | + for (int i = 0; i < n; ++i) sol += C.at(i).at(lmate.at(i)); |
| 105 | + return {sol, lmate}; |
| 106 | +} |
0 commit comments