23 #include "absl/strings/str_format.h"
33 class BaseAllDifferent :
public Constraint {
35 BaseAllDifferent(Solver*
const s,
const std::vector<IntVar*>& vars)
36 : Constraint(s),
vars_(vars) {}
37 ~BaseAllDifferent()
override {}
38 std::string DebugStringInternal(
const std::string&
name)
const {
43 const std::vector<IntVar*>
vars_;
50 class ValueAllDifferent :
public BaseAllDifferent {
52 ValueAllDifferent(Solver*
const s,
const std::vector<IntVar*>& vars)
53 : BaseAllDifferent(s, vars) {}
54 ~ValueAllDifferent()
override {}
57 void InitialPropagate()
override;
58 void OneMove(
int index);
61 std::string DebugString()
const override {
62 return DebugStringInternal(
"ValueAllDifferent");
64 void Accept(ModelVisitor*
const visitor)
const override {
73 RevSwitch all_instantiated_;
76 void ValueAllDifferent::Post() {
77 for (
int i = 0; i < size(); ++i) {
85 void ValueAllDifferent::InitialPropagate() {
86 for (
int i = 0; i < size(); ++i) {
87 if (
vars_[i]->Bound()) {
93 void ValueAllDifferent::OneMove(
int index) {
96 for (
int j = 0; j < size(); ++j) {
98 if (
vars_[j]->Size() < 0xFFFFFF) {
99 vars_[j]->RemoveValue(val);
101 solver()->AddConstraint(solver()->MakeNonEquality(
vars_[j], val));
108 bool ValueAllDifferent::AllMoves() {
109 if (all_instantiated_.
Switched() || size() == 0) {
112 for (
int i = 0; i < size(); ++i) {
113 if (!
vars_[i]->Bound()) {
117 std::unique_ptr<int64[]> values(
new int64[size()]);
118 for (
int i = 0; i < size(); ++i) {
119 values[i] =
vars_[i]->Value();
121 std::sort(values.get(), values.get() + size());
122 for (
int i = 0; i < size() - 1; ++i) {
123 if (values[i] == values[i + 1]) {
128 all_instantiated_.
Switch(solver());
135 class RangeBipartiteMatching {
144 RangeBipartiteMatching(Solver*
const solver,
int size)
147 intervals_(new Interval[size + 1]),
148 min_sorted_(new Interval*[size]),
149 max_sorted_(new Interval*[size]),
150 bounds_(new
int64[2 * size + 2]),
151 tree_(new int[2 * size + 2]),
152 diff_(new
int64[2 * size + 2]),
153 hall_(new int[2 * size + 2]),
155 for (
int i = 0; i < size; ++i) {
156 max_sorted_[i] = &intervals_[i];
157 min_sorted_[i] = max_sorted_[i];
162 intervals_[
index].min = imin;
163 intervals_[
index].max = imax;
169 const bool modified1 = PropagateMin();
170 const bool modified2 = PropagateMax();
171 return modified1 || modified2;
182 std::sort(min_sorted_.get(), min_sorted_.get() + size_,
183 CompareIntervalMin());
184 std::sort(max_sorted_.get(), max_sorted_.get() + size_,
185 CompareIntervalMax());
188 int64 max = max_sorted_[0]->max + 1;
196 if (i < size_ &&
min <=
max) {
199 bounds_[++nb] = last;
201 min_sorted_[i]->min_rank = nb;
203 min = min_sorted_[i]->min;
208 bounds_[++nb] = last;
210 max_sorted_[j]->max_rank = nb;
214 max = max_sorted_[j]->max + 1;
218 bounds_[nb + 1] = bounds_[nb] + 2;
222 bool PropagateMin() {
223 bool modified =
false;
225 for (
int i = 1; i <= active_size_ + 1; ++i) {
228 diff_[i] = bounds_[i] - bounds_[i - 1];
231 for (
int i = 0; i < size_; ++i) {
232 const int x = max_sorted_[i]->min_rank;
233 const int y = max_sorted_[i]->max_rank;
234 int z = PathMax(tree_.get(), x + 1);
236 if (--diff_[z] == 0) {
238 z = PathMax(tree_.get(), z + 1);
241 PathSet(x + 1, z, z, tree_.get());
242 if (diff_[z] < bounds_[z] - bounds_[y]) {
246 int w = PathMax(hall_.get(), hall_[x]);
247 max_sorted_[i]->min = bounds_[w];
248 PathSet(x, w, w, hall_.get());
251 if (diff_[z] == bounds_[z] - bounds_[y]) {
252 PathSet(hall_[y], j - 1, y, hall_.get());
259 bool PropagateMax() {
260 bool modified =
false;
262 for (
int i = 0; i <= active_size_; i++) {
265 diff_[i] = bounds_[i + 1] - bounds_[i];
268 for (
int i = size_ - 1; i >= 0; --i) {
269 const int x = min_sorted_[i]->max_rank;
270 const int y = min_sorted_[i]->min_rank;
271 int z = PathMin(tree_.get(), x - 1);
273 if (--diff_[z] == 0) {
275 z = PathMin(tree_.get(), z - 1);
278 PathSet(x - 1, z, z, tree_.get());
279 if (diff_[z] < bounds_[y] - bounds_[z]) {
284 int w = PathMin(hall_.get(), hall_[x]);
285 min_sorted_[i]->max = bounds_[w] - 1;
286 PathSet(x, w, w, hall_.get());
289 if (diff_[z] == bounds_[y] - bounds_[z]) {
290 PathSet(hall_[y], j + 1, y, hall_.get());
301 struct CompareIntervalMin {
302 bool operator()(
const Interval* i1,
const Interval* i2)
const {
303 return (i1->min < i2->min);
308 struct CompareIntervalMax {
309 bool operator()(
const Interval* i1,
const Interval* i2)
const {
310 return (i1->max < i2->max);
314 void PathSet(
int start,
int end,
int to,
int*
const tree) {
323 int PathMin(
const int*
const tree,
int index) {
325 while (tree[i] < i) {
331 int PathMax(
const int*
const tree,
int index) {
333 while (tree[i] > i) {
339 Solver*
const solver_;
341 std::unique_ptr<Interval[]> intervals_;
342 std::unique_ptr<Interval*[]> min_sorted_;
343 std::unique_ptr<Interval*[]> max_sorted_;
346 std::unique_ptr<int64[]> bounds_;
347 std::unique_ptr<int[]> tree_;
348 std::unique_ptr<int64[]> diff_;
349 std::unique_ptr<int[]> hall_;
353 class BoundsAllDifferent :
public BaseAllDifferent {
355 BoundsAllDifferent(Solver*
const s,
const std::vector<IntVar*>& vars)
356 : BaseAllDifferent(s, vars), matching_(s, vars.size()) {}
358 ~BoundsAllDifferent()
override {}
360 void Post()
override {
362 solver(),
this, &BoundsAllDifferent::IncrementalPropagate,
363 "IncrementalPropagate");
365 for (
int i = 0; i < size(); ++i) {
366 vars_[i]->WhenRange(range);
368 &BoundsAllDifferent::PropagateValue,
369 "PropagateValue", i);
374 void InitialPropagate()
override {
375 IncrementalPropagate();
376 for (
int i = 0; i < size(); ++i) {
377 if (
vars_[i]->Bound()) {
383 virtual void IncrementalPropagate() {
384 for (
int i = 0; i < size(); ++i) {
385 matching_.SetRange(i,
vars_[i]->Min(),
vars_[i]->Max());
388 if (matching_.Propagate()) {
389 for (
int i = 0; i < size(); ++i) {
390 vars_[i]->SetRange(matching_.Min(i), matching_.Max(i));
395 void PropagateValue(
int index) {
397 for (
int j = 0; j <
index; j++) {
398 if (
vars_[j]->Size() < 0xFFFFFF) {
399 vars_[j]->RemoveValue(to_remove);
401 solver()->AddConstraint(solver()->MakeNonEquality(
vars_[j], to_remove));
404 for (
int j =
index + 1; j < size(); j++) {
405 if (
vars_[j]->Size() < 0xFFFFFF) {
406 vars_[j]->RemoveValue(to_remove);
408 solver()->AddConstraint(solver()->MakeNonEquality(
vars_[j], to_remove));
413 std::string DebugString()
const override {
414 return DebugStringInternal(
"BoundsAllDifferent");
417 void Accept(ModelVisitor*
const visitor)
const override {
426 RangeBipartiteMatching matching_;
429 class SortConstraint :
public Constraint {
431 SortConstraint(Solver*
const solver,
432 const std::vector<IntVar*>& original_vars,
433 const std::vector<IntVar*>& sorted_vars)
434 : Constraint(solver),
435 ovars_(original_vars),
437 mins_(original_vars.size(), 0),
438 maxs_(original_vars.size(), 0),
439 matching_(solver, original_vars.size()) {}
441 ~SortConstraint()
override {}
443 void Post()
override {
445 solver()->MakeDelayedConstraintInitialPropagateCallback(
this);
446 for (
int i = 0; i < size(); ++i) {
447 ovars_[i]->WhenRange(demon);
448 svars_[i]->WhenRange(demon);
452 void InitialPropagate()
override {
453 for (
int i = 0; i < size(); ++i) {
456 ovars_[i]->Range(&vmin, &vmax);
461 std::sort(mins_.begin(), mins_.end());
462 std::sort(maxs_.begin(), maxs_.end());
463 for (
int i = 0; i < size(); ++i) {
464 svars_[i]->SetRange(mins_[i], maxs_[i]);
467 for (
int i = 0; i < size() - 1; ++i) {
468 svars_[i + 1]->SetMin(svars_[i]->Min());
470 for (
int i = size() - 1; i > 0; --i) {
471 svars_[i - 1]->SetMax(svars_[i]->Max());
474 for (
int i = 0; i < size(); ++i) {
477 FindIntersectionRange(i, &imin, &imax);
478 matching_.SetRange(i, imin, imax);
480 matching_.Propagate();
481 for (
int i = 0; i < size(); ++i) {
482 const int64 vmin = svars_[matching_.Min(i)]->Min();
483 const int64 vmax = svars_[matching_.Max(i)]->Max();
484 ovars_[i]->SetRange(vmin, vmax);
488 void Accept(ModelVisitor*
const visitor)
const override {
497 std::string DebugString()
const override {
503 int64 size()
const {
return ovars_.size(); }
505 void FindIntersectionRange(
int index,
int64*
const range_min,
506 int64*
const range_max)
const {
510 while (imin < size() && NotIntersect(
index, imin)) {
513 if (imin == size()) {
516 int64 imax = size() - 1;
517 while (imax > imin && NotIntersect(
index, imax)) {
524 bool NotIntersect(
int oindex,
int sindex)
const {
525 return ovars_[oindex]->Min() > svars_[sindex]->Max() ||
526 ovars_[oindex]->Max() < svars_[sindex]->Min();
529 const std::vector<IntVar*> ovars_;
530 const std::vector<IntVar*> svars_;
531 std::vector<int64> mins_;
532 std::vector<int64> maxs_;
533 RangeBipartiteMatching matching_;
538 class AllDifferentExcept :
public Constraint {
540 AllDifferentExcept(Solver*
const s, std::vector<IntVar*> vars,
542 : Constraint(s),
vars_(std::move(vars)), escape_value_(escape_value) {}
544 ~AllDifferentExcept()
override {}
546 void Post()
override {
547 for (
int i = 0; i <
vars_.size(); ++i) {
550 solver(),
this, &AllDifferentExcept::Propagate,
"Propagate", i);
555 void InitialPropagate()
override {
556 for (
int i = 0; i <
vars_.size(); ++i) {
557 if (
vars_[i]->Bound()) {
563 void Propagate(
int index) {
565 if (val != escape_value_) {
566 for (
int j = 0; j <
vars_.size(); ++j) {
568 vars_[j]->RemoveValue(val);
574 std::string DebugString()
const override {
575 return absl::StrFormat(
"AllDifferentExcept([%s], %d",
579 void Accept(ModelVisitor*
const visitor)
const override {
588 std::vector<IntVar*>
vars_;
589 const int64 escape_value_;
597 class NullIntersectArrayExcept :
public Constraint {
599 NullIntersectArrayExcept(Solver*
const s, std::vector<IntVar*> first_vars,
600 std::vector<IntVar*> second_vars,
int64 escape_value)
602 first_vars_(std::move(first_vars)),
603 second_vars_(std::move(second_vars)),
604 escape_value_(escape_value),
605 has_escape_value_(true) {}
607 NullIntersectArrayExcept(Solver*
const s, std::vector<IntVar*> first_vars,
608 std::vector<IntVar*> second_vars)
610 first_vars_(std::move(first_vars)),
611 second_vars_(std::move(second_vars)),
613 has_escape_value_(false) {}
615 ~NullIntersectArrayExcept()
override {}
617 void Post()
override {
618 for (
int i = 0; i < first_vars_.size(); ++i) {
619 IntVar*
const var = first_vars_[i];
621 solver(),
this, &NullIntersectArrayExcept::PropagateFirst,
622 "PropagateFirst", i);
625 for (
int i = 0; i < second_vars_.size(); ++i) {
626 IntVar*
const var = second_vars_[i];
628 solver(),
this, &NullIntersectArrayExcept::PropagateSecond,
629 "PropagateSecond", i);
634 void InitialPropagate()
override {
635 for (
int i = 0; i < first_vars_.size(); ++i) {
636 if (first_vars_[i]->Bound()) {
640 for (
int i = 0; i < second_vars_.size(); ++i) {
641 if (second_vars_[i]->Bound()) {
647 void PropagateFirst(
int index) {
648 const int64 val = first_vars_[
index]->Value();
649 if (!has_escape_value_ || val != escape_value_) {
650 for (
int j = 0; j < second_vars_.size(); ++j) {
651 second_vars_[j]->RemoveValue(val);
656 void PropagateSecond(
int index) {
657 const int64 val = second_vars_[
index]->Value();
658 if (!has_escape_value_ || val != escape_value_) {
659 for (
int j = 0; j < first_vars_.size(); ++j) {
660 first_vars_[j]->RemoveValue(val);
665 std::string DebugString()
const override {
666 return absl::StrFormat(
"NullIntersectArray([%s], [%s], escape = %d",
672 void Accept(ModelVisitor*
const visitor)
const override {
683 std::vector<IntVar*> first_vars_;
684 std::vector<IntVar*> second_vars_;
685 const int64 escape_value_;
686 const bool has_escape_value_;
695 bool stronger_propagation) {
696 const int size = vars.size();
697 for (
int i = 0; i < size; ++i) {
702 }
else if (size == 2) {
704 const_cast<IntVar* const
>(vars[1]));
706 if (stronger_propagation) {
707 return RevAlloc(
new BoundsAllDifferent(
this, vars));
709 return RevAlloc(
new ValueAllDifferent(
this, vars));
715 const std::vector<IntVar*>& sorted) {
716 CHECK_EQ(vars.size(), sorted.size());
717 return RevAlloc(
new SortConstraint(
this, vars, sorted));
721 int64 escape_value) {
722 int escape_candidates = 0;
723 for (
int i = 0; i < vars.size(); ++i) {
724 escape_candidates += (vars[i]->Contains(escape_value));
726 if (escape_candidates <= 1) {
729 return RevAlloc(
new AllDifferentExcept(
this, vars, escape_value));
734 const std::vector<IntVar*>& second_vars) {
735 return RevAlloc(
new NullIntersectArrayExcept(
this, first_vars, second_vars));
739 const std::vector<IntVar*>& first_vars,
740 const std::vector<IntVar*>& second_vars,
int64 escape_value) {
741 int first_escape_candidates = 0;
742 for (
int i = 0; i < first_vars.size(); ++i) {
743 first_escape_candidates += (first_vars[i]->Contains(escape_value));
745 int second_escape_candidates = 0;
746 for (
int i = 0; i < second_vars.size(); ++i) {
747 second_escape_candidates += (second_vars[i]->Contains(escape_value));
749 if (first_escape_candidates == 0 || second_escape_candidates == 0) {
751 new NullIntersectArrayExcept(
this, first_vars, second_vars));
753 return RevAlloc(
new NullIntersectArrayExcept(
this, first_vars, second_vars,