195 #ifndef OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
196 #define OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
207 #include "absl/strings/str_format.h"
225 template <
typename GraphType>
248 DCHECK(graph_ ==
nullptr);
281 inline const GraphType&
Graph()
const {
return *graph_; }
292 DCHECK_EQ(0, scaled_arc_cost_[arc] % cost_scaling_factor_);
293 return scaled_arc_cost_[arc] / cost_scaling_factor_;
324 if (graph_ ==
nullptr) {
329 return graph_->num_nodes();
340 return matched_arc_[left_node];
353 DCHECK_NE(GraphType::kNilArc, matching_arc);
354 return Head(matching_arc);
357 std::string
StatsString()
const {
return total_stats_.StatsString(); }
362 : num_left_nodes_(num_left_nodes), node_iterator_(0) {}
365 : num_left_nodes_(assignment.
NumLeftNodes()), node_iterator_(0) {}
369 bool Ok()
const {
return node_iterator_ < num_left_nodes_; }
371 void Next() { ++node_iterator_; }
380 Stats() : pushes_(0), double_pushes_(0), relabelings_(0), refinements_(0) {}
387 void Add(
const Stats& that) {
388 pushes_ += that.pushes_;
389 double_pushes_ += that.double_pushes_;
390 relabelings_ += that.relabelings_;
391 refinements_ += that.refinements_;
394 return absl::StrFormat(
395 "%d refinements; %d relabelings; "
396 "%d double pushes; %d pushes",
397 refinements_, relabelings_, double_pushes_, pushes_);
400 int64 double_pushes_;
406 class ActiveNodeContainerInterface {
408 virtual ~ActiveNodeContainerInterface() {}
409 virtual bool Empty()
const = 0;
414 class ActiveNodeStack :
public ActiveNodeContainerInterface {
416 ~ActiveNodeStack()
override {}
418 bool Empty()
const override {
return v_.empty(); }
420 void Add(
NodeIndex node)
override { v_.push_back(node); }
430 std::vector<NodeIndex> v_;
433 class ActiveNodeQueue :
public ActiveNodeContainerInterface {
435 ~ActiveNodeQueue()
override {}
437 bool Empty()
const override {
return q_.empty(); }
439 void Add(
NodeIndex node)
override { q_.push_front(node); }
449 std::deque<NodeIndex> q_;
460 typedef std::pair<ArcIndex, CostValue> ImplicitPriceSummary;
464 bool EpsilonOptimal()
const;
468 bool AllMatched()
const;
475 inline ImplicitPriceSummary BestArcAndGap(
NodeIndex left_node)
const;
479 void ReportAndAccumulateStats();
490 bool UpdateEpsilon();
494 inline bool IsActive(
NodeIndex left_node)
const;
501 inline bool IsActiveForDebugging(
NodeIndex node)
const;
508 void InitializeActiveNodeContainer();
516 void SaturateNegativeArcs();
525 return scaled_arc_cost_[arc] - price_[
Head(arc)];
530 const GraphType* graph_;
539 bool incidence_precondition_satisfied_;
866 bool* in_range)
const {
879 const double result =
880 static_cast<double>(std::max<CostValue>(1, n / 2 - 1)) *
881 (
static_cast<double>(old_epsilon) +
static_cast<double>(new_epsilon));
884 if (result > limit) {
886 if (in_range !=
nullptr) *in_range =
false;
907 CostValue largest_scaled_cost_magnitude_;
920 ZVector<CostValue> price_;
925 std::vector<ArcIndex> matched_arc_;
933 ZVector<NodeIndex> matched_node_;
938 std::vector<CostValue> scaled_arc_cost_;
943 std::unique_ptr<ActiveNodeContainerInterface> active_nodes_;
951 Stats iteration_stats_;
959 template <
typename GraphType>
960 const CostValue LinearSumAssignment<GraphType>::kMinEpsilon = 1;
962 template <
typename GraphType>
964 const GraphType& graph,
const NodeIndex num_left_nodes)
966 num_left_nodes_(num_left_nodes),
968 cost_scaling_factor_(1 + num_left_nodes),
969 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
971 price_lower_bound_(0),
972 slack_relabeling_price_(0),
973 largest_scaled_cost_magnitude_(0),
975 price_(num_left_nodes, 2 * num_left_nodes - 1),
976 matched_arc_(num_left_nodes, 0),
977 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
978 scaled_arc_cost_(graph.max_end_arc_index(), 0),
979 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
980 ? static_cast<ActiveNodeContainerInterface*>(
981 new ActiveNodeStack())
982 : static_cast<ActiveNodeContainerInterface*>(
983 new ActiveNodeQueue())) {}
985 template <
typename GraphType>
989 num_left_nodes_(num_left_nodes),
991 cost_scaling_factor_(1 + num_left_nodes),
992 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
994 price_lower_bound_(0),
995 slack_relabeling_price_(0),
996 largest_scaled_cost_magnitude_(0),
998 price_(num_left_nodes, 2 * num_left_nodes - 1),
999 matched_arc_(num_left_nodes, 0),
1000 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
1001 scaled_arc_cost_(num_arcs, 0),
1002 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
1003 ? static_cast<ActiveNodeContainerInterface*>(
1004 new ActiveNodeStack())
1005 : static_cast<ActiveNodeContainerInterface*>(
1006 new ActiveNodeQueue())) {}
1008 template <
typename GraphType>
1010 if (graph_ !=
nullptr) {
1016 cost *= cost_scaling_factor_;
1018 largest_scaled_cost_magnitude_ =
1019 std::max(largest_scaled_cost_magnitude_, cost_magnitude);
1020 scaled_arc_cost_[arc] =
cost;
1023 template <
typename ArcIndexType>
1027 : temp_(0), cost_(
cost) {}
1030 temp_ = (*cost_)[source];
1034 ArcIndexType destination)
const override {
1035 (*cost_)[destination] = (*cost_)[source];
1039 (*cost_)[destination] = temp_;
1046 std::vector<CostValue>*
const cost_;
1056 template <
typename GraphType>
1066 return ((graph_.Tail(
a) < graph_.Tail(
b)) ||
1067 ((graph_.Tail(
a) == graph_.Tail(
b)) &&
1068 (graph_.Head(
a) < graph_.Head(
b))));
1072 const GraphType& graph_;
1080 template <
typename GraphType>
1081 PermutationCycleHandler<typename GraphType::ArcIndex>*
1087 template <
typename GraphType>
1098 graph->GroupForwardArcsByFunctor(compare, &cycle_handler);
1102 template <
typename GraphType>
1104 const CostValue current_epsilon)
const {
1105 return std::max(current_epsilon / alpha_, kMinEpsilon);
1108 template <
typename GraphType>
1109 bool LinearSumAssignment<GraphType>::UpdateEpsilon() {
1110 CostValue new_epsilon = NewEpsilon(epsilon_);
1111 slack_relabeling_price_ = PriceChangeBound(epsilon_, new_epsilon,
nullptr);
1112 epsilon_ = new_epsilon;
1113 VLOG(3) <<
"Updated: epsilon_ == " << epsilon_;
1114 VLOG(4) <<
"slack_relabeling_price_ == " << slack_relabeling_price_;
1123 template <
typename GraphType>
1124 inline bool LinearSumAssignment<GraphType>::IsActive(
1127 return matched_arc_[left_node] == GraphType::kNilArc;
1133 template <
typename GraphType>
1134 inline bool LinearSumAssignment<GraphType>::IsActiveForDebugging(
1136 if (node < num_left_nodes_) {
1137 return IsActive(node);
1139 return matched_node_[node] == GraphType::kNilNode;
1143 template <
typename GraphType>
1144 void LinearSumAssignment<GraphType>::InitializeActiveNodeContainer() {
1145 DCHECK(active_nodes_->Empty());
1146 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1147 node_it.Ok(); node_it.Next()) {
1149 if (IsActive(node)) {
1150 active_nodes_->Add(node);
1165 template <
typename GraphType>
1166 void LinearSumAssignment<GraphType>::SaturateNegativeArcs() {
1168 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1169 node_it.Ok(); node_it.Next()) {
1171 if (IsActive(node)) {
1179 matched_arc_[node] = GraphType::kNilArc;
1180 matched_node_[mate] = GraphType::kNilNode;
1186 template <
typename GraphType>
1187 bool LinearSumAssignment<GraphType>::DoublePush(
NodeIndex source) {
1189 DCHECK(IsActive(source)) <<
"Node " << source
1190 <<
"must be active (unmatched)!";
1191 ImplicitPriceSummary summary = BestArcAndGap(source);
1192 const ArcIndex best_arc = summary.first;
1197 if (best_arc == GraphType::kNilArc) {
1200 const NodeIndex new_mate = Head(best_arc);
1201 const NodeIndex to_unmatch = matched_node_[new_mate];
1202 if (to_unmatch != GraphType::kNilNode) {
1205 matched_arc_[to_unmatch] = GraphType::kNilArc;
1206 active_nodes_->Add(to_unmatch);
1208 iteration_stats_.double_pushes_ += 1;
1213 iteration_stats_.pushes_ += 1;
1215 matched_arc_[source] = best_arc;
1216 matched_node_[new_mate] = source;
1218 iteration_stats_.relabelings_ += 1;
1219 const CostValue new_price = price_[new_mate] - gap - epsilon_;
1220 price_[new_mate] = new_price;
1221 return new_price >= price_lower_bound_;
1224 template <
typename GraphType>
1225 bool LinearSumAssignment<GraphType>::Refine() {
1226 SaturateNegativeArcs();
1227 InitializeActiveNodeContainer();
1228 while (total_excess_ > 0) {
1231 const NodeIndex node = active_nodes_->Get();
1232 if (!DoublePush(node)) {
1240 LOG_IF(DFATAL, total_stats_.refinements_ > 0)
1241 <<
"Infeasibility detection triggered after first iteration found "
1242 <<
"a feasible assignment!";
1246 DCHECK(active_nodes_->Empty());
1247 iteration_stats_.refinements_ += 1;
1265 template <
typename GraphType>
1266 inline typename LinearSumAssignment<GraphType>::ImplicitPriceSummary
1267 LinearSumAssignment<GraphType>::BestArcAndGap(
NodeIndex left_node)
const {
1268 DCHECK(IsActive(left_node))
1269 <<
"Node " << left_node <<
" must be active (unmatched)!";
1271 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1272 ArcIndex best_arc = arc_it.Index();
1273 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1279 const CostValue max_gap = slack_relabeling_price_ - epsilon_;
1280 CostValue second_min_partial_reduced_cost =
1281 min_partial_reduced_cost + max_gap;
1282 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1283 const ArcIndex arc = arc_it.Index();
1284 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1285 if (partial_reduced_cost < second_min_partial_reduced_cost) {
1286 if (partial_reduced_cost < min_partial_reduced_cost) {
1288 second_min_partial_reduced_cost = min_partial_reduced_cost;
1289 min_partial_reduced_cost = partial_reduced_cost;
1291 second_min_partial_reduced_cost = partial_reduced_cost;
1295 const CostValue gap = std::min<CostValue>(
1296 second_min_partial_reduced_cost - min_partial_reduced_cost, max_gap);
1298 return std::make_pair(best_arc, gap);
1305 template <
typename GraphType>
1306 inline CostValue LinearSumAssignment<GraphType>::ImplicitPrice(
1310 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1313 ArcIndex best_arc = arc_it.Index();
1314 if (best_arc == matched_arc_[left_node]) {
1317 best_arc = arc_it.Index();
1320 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1326 return -(min_partial_reduced_cost + slack_relabeling_price_);
1328 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1329 const ArcIndex arc = arc_it.Index();
1330 if (arc != matched_arc_[left_node]) {
1331 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1332 if (partial_reduced_cost < min_partial_reduced_cost) {
1333 min_partial_reduced_cost = partial_reduced_cost;
1337 return -min_partial_reduced_cost;
1341 template <
typename GraphType>
1342 bool LinearSumAssignment<GraphType>::AllMatched()
const {
1343 for (
NodeIndex node = 0; node < graph_->num_nodes(); ++node) {
1344 if (IsActiveForDebugging(node)) {
1352 template <
typename GraphType>
1353 bool LinearSumAssignment<GraphType>::EpsilonOptimal()
const {
1354 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1355 node_it.Ok(); node_it.Next()) {
1356 const NodeIndex left_node = node_it.Index();
1359 CostValue left_node_price = ImplicitPrice(left_node);
1360 for (
typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1361 arc_it.Ok(); arc_it.Next()) {
1362 const ArcIndex arc = arc_it.Index();
1363 const CostValue reduced_cost = left_node_price + PartialReducedCost(arc);
1368 if (matched_arc_[left_node] == arc) {
1372 if (reduced_cost > epsilon_) {
1378 if (reduced_cost < 0) {
1387 template <
typename GraphType>
1389 incidence_precondition_satisfied_ =
true;
1393 epsilon_ =
std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1);
1394 VLOG(2) <<
"Largest given cost magnitude: "
1395 << largest_scaled_cost_magnitude_ / cost_scaling_factor_;
1398 for (
NodeIndex node = 0; node < num_left_nodes_; ++node) {
1399 matched_arc_[node] = GraphType::kNilArc;
1400 typename GraphType::OutgoingArcIterator arc_it(*graph_, node);
1402 incidence_precondition_satisfied_ =
false;
1407 for (
NodeIndex node = num_left_nodes_; node < graph_->num_nodes(); ++node) {
1409 matched_node_[node] = GraphType::kNilNode;
1411 bool in_range =
true;
1412 double double_price_lower_bound = 0.0;
1414 CostValue old_error_parameter = epsilon_;
1416 new_error_parameter = NewEpsilon(old_error_parameter);
1417 double_price_lower_bound -=
1418 2.0 *
static_cast<double>(PriceChangeBound(
1419 old_error_parameter, new_error_parameter, &in_range));
1420 old_error_parameter = new_error_parameter;
1421 }
while (new_error_parameter != kMinEpsilon);
1422 const double limit =
1424 if (double_price_lower_bound < limit) {
1428 price_lower_bound_ =
static_cast<CostValue>(double_price_lower_bound);
1430 VLOG(4) <<
"price_lower_bound_ == " << price_lower_bound_;
1433 LOG(
WARNING) <<
"Price change bound exceeds range of representable "
1434 <<
"costs; arithmetic overflow is not ruled out and "
1435 <<
"infeasibility might go undetected.";
1440 template <
typename GraphType>
1442 total_stats_.Add(iteration_stats_);
1443 VLOG(3) <<
"Iteration stats: " << iteration_stats_.StatsString();
1444 iteration_stats_.Clear();
1447 template <
typename GraphType>
1449 CHECK(graph_ !=
nullptr);
1450 bool ok = graph_->num_nodes() == 2 * num_left_nodes_;
1451 if (!ok)
return false;
1458 ok = ok && incidence_precondition_satisfied_;
1459 DCHECK(!ok || EpsilonOptimal());
1460 while (ok && epsilon_ > kMinEpsilon) {
1461 ok = ok && UpdateEpsilon();
1462 ok = ok && Refine();
1463 ReportAndAccumulateStats();
1464 DCHECK(!ok || EpsilonOptimal());
1465 DCHECK(!ok || AllMatched());
1468 VLOG(1) <<
"Overall stats: " << total_stats_.StatsString();
1472 template <
typename GraphType>
1479 cost += GetAssignmentCost(node_it.Index());
1486 #endif // OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_