21 #include "absl/strings/str_format.h"
// RangeEquality: constraint `left == right` between two integer expressions,
// propagated on bounds only (each side is narrowed to the other's [Min, Max]).
// NOTE(review): fragmentary extraction — the embedded original line numbers
// skip (e.g. 40 -> 45, 56 -> 66), so closing braces, the rest of Post(), the
// Accept() body and the left_ member declaration are not visible here.
32 class RangeEquality :
public Constraint {
34 RangeEquality(Solver*
const s, IntExpr*
const l, IntExpr*
const r)
35 : Constraint(s), left_(l), right_(r) {}
37 ~RangeEquality()
override {}
// Post(): builds a demon from MakeConstraintInitialPropagateCallback(this);
// the statements that register it (original lines 41-44) are missing here.
39 void Post()
override {
40 Demon*
const d = solver()->MakeConstraintInitialPropagateCallback(
this);
// Intersects each side's domain bounds with the other side's current bounds.
45 void InitialPropagate()
override {
46 left_->SetRange(right_->Min(), right_->Max());
47 right_->SetRange(left_->Min(), left_->Max());
// Renders as "<left> == <right>".
50 std::string DebugString()
const override {
51 return left_->DebugString() +
" == " + right_->DebugString();
// Reifies the constraint as a 0/1 variable via Solver::MakeIsEqualVar.
54 IntVar* Var()
override {
return solver()->MakeIsEqualVar(left_, right_); }
// Model-visitor hook; its body (original lines 57-65) is not visible here.
56 void Accept(ModelVisitor*
const visitor)
const override {
66 IntExpr*
const right_;
// RangeLessOrEqual: constraint `left <= right` between two integer
// expressions, propagated on bounds only. The constructor, Post(),
// InitialPropagate() and DebugString() are defined out of line.
// NOTE(review): fragmentary extraction — embedded line numbers skip
// (e.g. 75 -> 77, 82 -> 92), so the Post() declaration, the left_ and demon_
// members, the Accept() body and closing braces are not visible here.
72 class RangeLessOrEqual :
public Constraint {
74 RangeLessOrEqual(Solver*
const s, IntExpr*
const l, IntExpr*
const r);
75 ~RangeLessOrEqual()
override {}
77 void InitialPropagate()
override;
78 std::string DebugString()
const override;
// Reifies the constraint as a 0/1 variable via Solver::MakeIsLessOrEqualVar.
79 IntVar* Var()
override {
80 return solver()->MakeIsLessOrEqualVar(left_, right_);
// Model-visitor hook; its body (original lines 83-91) is not visible here.
82 void Accept(ModelVisitor*
const visitor)
const override {
92 IntExpr*
const right_;
// Stores both operands. demon_ starts as nullptr and is created in Post().
// NOTE(review): original line 97 (presumably the `IntExpr* const r)`
// parameter) is missing from this extraction.
96 RangeLessOrEqual::RangeLessOrEqual(Solver*
const s, IntExpr*
const l,
98 : Constraint(s), left_(l), right_(r), demon_(nullptr) {}
// Creates one demon from MakeConstraintInitialPropagateCallback(this) and
// registers it with WhenRange on both operands. A bound change on either side
// presumably re-runs InitialPropagate(), as the factory name suggests.
// demon_ is kept so InitialPropagate() can inhibit it.
100 void RangeLessOrEqual::Post() {
101 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
102 left_->WhenRange(demon_);
103 right_->WhenRange(demon_);
// Bounds propagation for `left <= right`: cap left's max at right's max,
// then raise right's min to left's min.
// Once left.Max <= right.Min, every remaining pair of values satisfies the
// relation, so the demon is inhibited and no further work is needed.
106 void RangeLessOrEqual::InitialPropagate() {
107 left_->SetMax(right_->Max());
108 right_->SetMin(left_->Min());
109 if (left_->Max() <= right_->Min()) {
110 demon_->inhibit(solver());
// Renders as "<left> <= <right>".
114 std::string RangeLessOrEqual::DebugString()
const {
115 return left_->DebugString() +
" <= " + right_->DebugString();
// RangeLess: strict constraint `left < right` between two integer
// expressions, propagated on bounds only. The constructor, Post(),
// InitialPropagate() and DebugString() are defined out of line.
// NOTE(review): fragmentary extraction — embedded line numbers skip, so the
// demon_ member, part of Accept() and closing braces are not visible here.
121 class RangeLess :
public Constraint {
123 RangeLess(Solver*
const s, IntExpr*
const l, IntExpr*
const r);
124 ~RangeLess()
override {}
125 void Post()
override;
126 void InitialPropagate()
override;
127 std::string DebugString()
const override;
// Reifies the constraint as a 0/1 variable via Solver::MakeIsLessVar.
128 IntVar* Var()
override {
return solver()->MakeIsLessVar(left_, right_); }
// Reports the constraint to a ModelVisitor as kLess with left/right
// arguments. The continuation of the kRightArgument call (original line 133)
// is missing here.
129 void Accept(ModelVisitor*
const visitor)
const override {
130 visitor->BeginVisitConstraint(ModelVisitor::kLess,
this);
131 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
132 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
134 visitor->EndVisitConstraint(ModelVisitor::kLess,
this);
138 IntExpr*
const left_;
139 IntExpr*
const right_;
// Stores both operands. demon_ starts as nullptr and is created in Post().
143 RangeLess::RangeLess(Solver*
const s, IntExpr*
const l, IntExpr*
const r)
144 : Constraint(s), left_(l), right_(r), demon_(nullptr) {}
// Creates one demon from MakeConstraintInitialPropagateCallback(this) and
// registers it with WhenRange on both operands. demon_ is kept so that
// InitialPropagate() can inhibit it once the relation is entailed.
146 void RangeLess::Post() {
147 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
148 left_->WhenRange(demon_);
149 right_->WhenRange(demon_);
// Bounds propagation for `left < right` on integers: left's max is capped at
// right.Max - 1, and right's min is raised to left.Min + 1.
// Once left.Max < right.Min, every remaining pair of values satisfies the
// relation, so the demon is inhibited.
152 void RangeLess::InitialPropagate() {
153 left_->SetMax(right_->Max() - 1);
154 right_->SetMin(left_->Min() + 1);
155 if (left_->Max() < right_->Min()) {
156 demon_->inhibit(solver());
// Renders as "<left> < <right>".
160 std::string RangeLess::DebugString()
const {
161 return left_->DebugString() +
" < " + right_->DebugString();
// DiffVar: constraint `left != right` between two integer variables (IntVar,
// not general expressions). It is propagated when either side becomes bound.
// NOTE(review): fragmentary extraction — embedded line numbers skip
// (e.g. 174 -> 178, 183 -> 188), so the LeftBound()/RightBound()
// declarations, the left_ member, part of Accept() and closing braces are not
// visible here.
167 class DiffVar :
public Constraint {
169 DiffVar(Solver*
const s, IntVar*
const l, IntVar*
const r);
170 ~DiffVar()
override {}
171 void Post()
override;
172 void InitialPropagate()
override;
173 std::string DebugString()
const override;
// Reifies the constraint as a 0/1 variable via Solver::MakeIsDifferentVar.
174 IntVar* Var()
override {
return solver()->MakeIsDifferentVar(left_, right_); }
// Reports the constraint to a ModelVisitor as kNonEqual with left/right
// arguments. The continuation of the kRightArgument call (original line 182)
// is missing here.
178 void Accept(ModelVisitor*
const visitor)
const override {
179 visitor->BeginVisitConstraint(ModelVisitor::kNonEqual,
this);
180 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
181 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
183 visitor->EndVisitConstraint(ModelVisitor::kNonEqual,
this);
188 IntVar*
const right_;
// Stores the two variables constrained to differ.
191 DiffVar::DiffVar(Solver*
const s, IntVar*
const l, IntVar*
const r)
192 : Constraint(s), left_(l), right_(r) {}
// Registers one demon per side with WhenBound, so propagation happens only
// once a variable is fixed to a single value.
// NOTE(review): the initializers of left_demon/right_demon (original lines 196
// and 198) are missing from this extraction; presumably they wrap LeftBound()
// and RightBound() respectively — confirm against the full source.
194 void DiffVar::Post() {
195 Demon*
const left_demon =
197 Demon*
const right_demon =
199 left_->WhenBound(left_demon);
200 right_->WhenBound(right_demon);
// Called with left_ bound: forbids left's value in right_.
// If right_'s domain has fewer than 0xFFFFFF values, the value is removed
// directly. The AddConstraint(MakeNonEquality(...)) path is the alternative
// for very large domains.
// NOTE(review): the `else` line (original 207) is missing from this
// extraction; the 0xFFFFFF threshold is an unexplained magic constant.
204 void DiffVar::LeftBound() {
205 if (right_->Size() < 0xFFFFFF) {
206 right_->RemoveValue(left_->Min());
208 solver()->AddConstraint(solver()->MakeNonEquality(right_, left_->Min()));
// Mirror of LeftBound(): called with right_ bound, forbids right's value in
// left_. It uses direct removal for domains below 0xFFFFFF values, otherwise
// posts a MakeNonEquality constraint.
// NOTE(review): the `else` line (original 215) is missing from this extraction.
212 void DiffVar::RightBound() {
213 if (left_->Size() < 0xFFFFFF) {
214 left_->RemoveValue(right_->Min());
216 solver()->AddConstraint(solver()->MakeNonEquality(left_, right_->Min()));
// Handles sides that are already bound when the constraint is posted.
// NOTE(review): the branch bodies (original lines 222-223 and 225-226) are
// missing; presumably they call LeftBound()/RightBound() — confirm.
220 void DiffVar::InitialPropagate() {
221 if (left_->Bound()) {
224 if (right_->Bound()) {
// Renders as "<left> != <right>".
229 std::string DiffVar::DebugString()
const {
230 return left_->DebugString() +
" != " + right_->DebugString();
// IsEqualCt: reified equality. It links a 0/1 target variable (presumably
// stored by CastConstraint as target_var_) to the truth of `left == right`.
// NOTE(review): fragmentary extraction — embedded line numbers skip, so the
// `IntVar* const b` parameter, several branch bodies, the target test in
// PropagateTarget(), closing braces and member declarations are missing.
240 class IsEqualCt :
public CastConstraint {
242 IsEqualCt(Solver*
const s, IntExpr*
const l, IntExpr*
const r,
244 : CastConstraint(s,
b), left_(l), right_(r), range_demon_(nullptr) {}
246 ~IsEqualCt()
override {}
// Post(): range_demon_ re-triggers propagation on bound changes of either
// side. The trailing fragment (original 252-255) builds a second demon
// bound to PropagateTarget(); its head and the registration on the target
// are not visible here.
248 void Post()
override {
249 range_demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
250 left_->WhenRange(range_demon_);
251 right_->WhenRange(range_demon_);
253 solver(),
this, &IsEqualCt::PropagateTarget,
"PropagateTarget");
// InitialPropagate(): tries to decide the truth value from the operands.
257 void InitialPropagate()
override {
// Disjoint ranges: equality is impossible. The statement on original line
// 263 (presumably fixing the target to 0) is missing here; the demon is
// inhibited because nothing more can change.
262 if (left_->Min() > right_->Max() || left_->Max() < right_->Min()) {
264 range_demon_->inhibit(solver());
265 }
else if (left_->Bound()) {
// Both sides fixed: the target is exactly (left == right).
266 if (right_->Bound()) {
267 target_var_->SetValue(left_->Min() == right_->Min());
268 }
// right_ is a variable whose domain lacks left's value: equality is
// impossible (the line setting the target, ~270, is missing here).
else if (right_->IsVar() && !right_->Var()->Contains(left_->Min())) {
269 range_demon_->inhibit(solver());
272 }
// Symmetric case with right_ bound and left_ a variable.
else if (right_->Bound() && left_->IsVar() &&
273 !left_->Var()->Contains(right_->Min())) {
274 range_demon_->inhibit(solver());
// PropagateTarget(): enforces the decided truth value on the operands.
// NOTE(review): the test on the target value (original line 280) is missing.
// The first two branches enforce `left != right` once one side is bound
// (RemoveValue on a variable, otherwise a posted MakeNonEquality). The final
// SetRange pair enforces `left == right` by bounds, as RangeEquality does.
279 void PropagateTarget() {
281 if (left_->Bound()) {
282 range_demon_->inhibit(solver());
283 if (right_->IsVar()) {
284 right_->Var()->RemoveValue(left_->Min());
286 solver()->AddConstraint(
287 solver()->MakeNonEquality(right_, left_->Min()));
289 }
else if (right_->Bound()) {
290 range_demon_->inhibit(solver());
291 if (left_->IsVar()) {
292 left_->Var()->RemoveValue(right_->Min());
294 solver()->AddConstraint(
295 solver()->MakeNonEquality(left_, right_->Min()));
299 left_->SetRange(right_->Min(), right_->Max());
300 right_->SetRange(left_->Min(), left_->Max());
// Renders as "IsEqualCt(<left>, <right>, <target>)".
304 std::string DebugString()
const override {
305 return absl::StrFormat(
"IsEqualCt(%s, %s, %s)", left_->DebugString(),
306 right_->DebugString(),
target_var_->DebugString());
// Reports the constraint as kIsEqual with left/right/target arguments. The
// continuation lines 313 and 315 are missing here.
309 void Accept(ModelVisitor*
const visitor)
const override {
310 visitor->BeginVisitConstraint(ModelVisitor::kIsEqual,
this);
311 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
312 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
314 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
316 visitor->EndVisitConstraint(ModelVisitor::kIsEqual,
this);
320 IntExpr*
const left_;
321 IntExpr*
const right_;
// IsDifferentCt: reified disequality. It links a 0/1 target variable
// (presumably stored by CastConstraint as target_var_) to the truth of
// `left != right`.
// NOTE(review): fragmentary extraction — the `IntVar* const b` parameter,
// several branch bodies, the target test in PropagateTarget(), closing braces
// and member declarations are missing.
327 class IsDifferentCt :
public CastConstraint {
329 IsDifferentCt(Solver*
const s, IntExpr*
const l, IntExpr*
const r,
331 : CastConstraint(s,
b), left_(l), right_(r), range_demon_(nullptr) {}
333 ~IsDifferentCt()
override {}
// Post(): range_demon_ re-triggers propagation on bound changes of either
// side. The trailing fragment (original 339-342) builds a second demon bound
// to PropagateTarget(); its head is not visible here.
335 void Post()
override {
336 range_demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
337 left_->WhenRange(range_demon_);
338 right_->WhenRange(range_demon_);
340 solver(),
this, &IsDifferentCt::PropagateTarget,
"PropagateTarget");
// InitialPropagate(): same case analysis as IsEqualCt, with the truth value
// negated. Disjoint ranges or a missing value make the operands necessarily
// different; the lines that set the target in those branches are missing here.
344 void InitialPropagate()
override {
349 if (left_->Min() > right_->Max() || left_->Max() < right_->Min()) {
351 range_demon_->inhibit(solver());
352 }
else if (left_->Bound()) {
353 if (right_->Bound()) {
354 target_var_->SetValue(left_->Min() != right_->Min());
355 }
else if (right_->IsVar() && !right_->Var()->Contains(left_->Min())) {
356 range_demon_->inhibit(solver());
359 }
else if (right_->Bound() && left_->IsVar() &&
360 !left_->Var()->Contains(right_->Min())) {
361 range_demon_->inhibit(solver());
// PropagateTarget(): the SetRange pair enforces `left == right` by bounds.
// The remaining branches enforce `left != right` by posting MakeNonEquality
// once one side is bound.
// NOTE(review): the test on the target value (original line 367) and the
// `else` separating the two cases are missing. Unlike IsEqualCt, this always
// posts a constraint rather than calling RemoveValue on a variable operand.
366 void PropagateTarget() {
368 left_->SetRange(right_->Min(), right_->Max());
369 right_->SetRange(left_->Min(), left_->Max());
371 if (left_->Bound()) {
372 range_demon_->inhibit(solver());
373 solver()->AddConstraint(
374 solver()->MakeNonEquality(right_, left_->Min()));
375 }
else if (right_->Bound()) {
376 range_demon_->inhibit(solver());
377 solver()->AddConstraint(
378 solver()->MakeNonEquality(left_, right_->Min()));
// Renders as "IsDifferentCt(<left>, <right>, <target>)".
383 std::string DebugString()
const override {
384 return absl::StrFormat(
"IsDifferentCt(%s, %s, %s)", left_->DebugString(),
385 right_->DebugString(),
target_var_->DebugString());
// Reports the constraint as kIsDifferent with left/right/target arguments.
// The continuation lines 392 and 394 are missing here.
388 void Accept(ModelVisitor*
const visitor)
const override {
389 visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent,
this);
390 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
391 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
393 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
395 visitor->EndVisitConstraint(ModelVisitor::kIsDifferent,
this);
399 IntExpr*
const left_;
400 IntExpr*
const right_;
// IsLessOrEqualCt: reified `left <= right`, linked to a 0/1 target variable
// (presumably target_var_ from CastConstraint).
// NOTE(review): fragmentary extraction — the `IntVar* const b` parameter, the
// conditions guarding each InitialPropagate() branch, the lines that fix the
// target, closing braces and member declarations are missing.
404 class IsLessOrEqualCt :
public CastConstraint {
406 IsLessOrEqualCt(Solver*
const s, IntExpr*
const l, IntExpr*
const r,
408 : CastConstraint(s,
b), left_(l), right_(r), demon_(nullptr) {}
410 ~IsLessOrEqualCt()
override {}
// Post(): demon_ re-triggers propagation on bound changes of either operand.
// Original lines 416-417, presumably also watching the target, are missing.
412 void Post()
override {
413 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
414 left_->WhenRange(demon_);
415 right_->WhenRange(demon_);
419 void InitialPropagate()
override {
// Enforces the negation `right < left` (right.Max <= left.Max - 1,
// left.Min >= right.Min + 1); presumably the target-is-0 branch.
422 right_->SetMax(left_->Max() - 1);
423 left_->SetMin(right_->Min() + 1);
// Enforces `left <= right` by bounds; presumably the target-is-1 branch.
425 right_->SetMin(left_->Min());
426 left_->SetMax(right_->Max());
428 }
// right.Min >= left.Max: the relation always holds.
else if (right_->Min() >= left_->Max()) {
429 demon_->inhibit(solver());
431 }
// right.Max < left.Min: the relation never holds.
else if (right_->Max() < left_->Min()) {
432 demon_->inhibit(solver());
// Renders as "IsLessOrEqualCt(<left>, <right>, <target>)".
437 std::string DebugString()
const override {
438 return absl::StrFormat(
"IsLessOrEqualCt(%s, %s, %s)", left_->DebugString(),
439 right_->DebugString(),
target_var_->DebugString());
// Reports the constraint as kIsLessOrEqual with left/right/target arguments.
442 void Accept(ModelVisitor*
const visitor)
const override {
443 visitor->BeginVisitConstraint(ModelVisitor::kIsLessOrEqual,
this);
444 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
445 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
447 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
449 visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual,
this);
453 IntExpr*
const left_;
454 IntExpr*
const right_;
// IsLessCt: reified strict `left < right`, linked to a 0/1 target variable
// (presumably target_var_ from CastConstraint).
// NOTE(review): fragmentary extraction — the conditions guarding each
// InitialPropagate() branch, the lines that fix the target, closing braces
// and member declarations are missing.
458 class IsLessCt :
public CastConstraint {
460 IsLessCt(Solver*
const s, IntExpr*
const l, IntExpr*
const r, IntVar*
const b)
461 : CastConstraint(s,
b), left_(l), right_(r), demon_(nullptr) {}
463 ~IsLessCt()
override {}
// Post(): demon_ re-triggers propagation on bound changes of either operand.
465 void Post()
override {
466 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
467 left_->WhenRange(demon_);
468 right_->WhenRange(demon_);
472 void InitialPropagate()
override {
// Enforces the negation `right <= left` by bounds; presumably the
// target-is-0 branch.
475 right_->SetMax(left_->Max());
476 left_->SetMin(right_->Min());
// Enforces `left < right` (right.Min >= left.Min + 1, left.Max <=
// right.Max - 1); presumably the target-is-1 branch.
478 right_->SetMin(left_->Min() + 1);
479 left_->SetMax(right_->Max() - 1);
481 }
// right.Min > left.Max: the relation always holds.
else if (right_->Min() > left_->Max()) {
482 demon_->inhibit(solver());
484 }
// right.Max <= left.Min: the relation never holds.
else if (right_->Max() <= left_->Min()) {
485 demon_->inhibit(solver());
// Renders as "IsLessCt(<left>, <right>, <target>)".
490 std::string DebugString()
const override {
491 return absl::StrFormat(
"IsLessCt(%s, %s, %s)", left_->DebugString(),
492 right_->DebugString(),
target_var_->DebugString());
// Reports the constraint as kIsLess with left/right/target arguments.
495 void Accept(ModelVisitor*
const visitor)
const override {
496 visitor->BeginVisitConstraint(ModelVisitor::kIsLess,
this);
497 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
498 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
500 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
502 visitor->EndVisitConstraint(ModelVisitor::kIsLess,
this);
506 IntExpr*
const left_;
507 IntExpr*
const right_;
// Posts `l == r` between two integer expressions.
// NOTE(review): the signature is missing from this extraction (by the
// RangeEquality it allocates this is presumably Solver::MakeEquality(l, r)),
// as are original lines 515-517 and 521.
// Each CHECK message names the operand it actually validates.
513 CHECK(l !=
nullptr) <<
"left expression nullptr, maybe a bad cast";
514 CHECK(r !=
nullptr) <<
"right expression nullptr, maybe a bad cast";
// A bound side is a constant: delegate to the expression-vs-constant overload.
// The guard for this first branch (presumably l->Bound()) is not visible here.
518 return MakeEquality(r, l->
Min());
519 }
else if (r->
Bound()) {
520 return MakeEquality(l, r->
Min());
// Neither side bound: bounds-based propagation via RangeEquality.
522 return RevAlloc(
new RangeEquality(
this, l, r));
// Posts `l <= r` between two integer expressions.
// NOTE(review): the signature is missing from this extraction (by the
// RangeLessOrEqual it allocates this is presumably
// Solver::MakeLessOrEqual(l, r)), as are original lines 529-531 and 537.
// Each CHECK message names the operand it actually validates.
527 CHECK(l !=
nullptr) <<
"left expression nullptr, maybe a bad cast";
528 CHECK(r !=
nullptr) <<
"right expression nullptr, maybe a bad cast";
// Trivially satisfied case; its guard (original line 531) is not visible.
532 return MakeTrueConstraint();
533 }
// A bound side is a constant: delegate to the expression-vs-constant forms.
else if (l->
Bound()) {
534 return MakeGreaterOrEqual(r, l->
Min());
535 }
else if (r->
Bound()) {
536 return MakeLessOrEqual(l, r->
Min());
// Neither side bound: bounds-based propagation via RangeLessOrEqual.
538 return RevAlloc(
new RangeLessOrEqual(
this, l, r));
// `l >= r` is expressed as `r <= l` by swapping the operands.
// NOTE(review): the enclosing signature (presumably
// Solver::MakeGreaterOrEqual(l, r)) is missing from this extraction.
543 return MakeLessOrEqual(r, l);
// Posts strict `l < r` between two integer expressions.
// NOTE(review): the signature is missing from this extraction (by the
// RangeLess it allocates this is presumably Solver::MakeLess(l, r)), as are
// original lines 549-551 and 555.
// Each CHECK message names the operand it actually validates.
547 CHECK(l !=
nullptr) <<
"left expression nullptr, maybe a bad cast";
548 CHECK(r !=
nullptr) <<
"right expression nullptr, maybe a bad cast";
// A bound side is a constant: delegate to the expression-vs-constant forms.
// The guard for this first branch (presumably l->Bound()) is not visible.
552 return MakeGreater(r, l->
Min());
553 }
else if (r->
Bound()) {
554 return MakeLess(l, r->
Min());
// Neither side bound: bounds-based propagation via RangeLess.
556 return RevAlloc(
new RangeLess(
this, l, r));
// `l > r` is expressed as `r < l` by swapping the operands.
// NOTE(review): the enclosing signature (presumably Solver::MakeGreater(l, r))
// is missing from this extraction.
561 return MakeLess(r, l);
// Posts `l != r` between two integer expressions.
// NOTE(review): the signature is missing from this extraction (presumably
// Solver::MakeNonEquality(l, r)), as are original lines 567-569 and 573.
// Each CHECK message names the operand it actually validates.
565 CHECK(l !=
nullptr) <<
"left expression nullptr, maybe a bad cast";
566 CHECK(r !=
nullptr) <<
"right expression nullptr, maybe a bad cast";
// A bound side is a constant: delegate to the expression-vs-constant overload.
// The guard for this first branch (presumably l->Bound()) is not visible.
570 return MakeNonEquality(r, l->
Min());
571 }
else if (r->
Bound()) {
572 return MakeNonEquality(l, r->
Min());
// Neither side bound: DiffVar works on IntVars, so both expressions are
// converted with Var() first.
574 return RevAlloc(
new DiffVar(
this, l->
Var(), r->
Var()));
// Returns a 0/1 variable equal to (v1 == v2), reusing cached results.
// NOTE(review): the signature and several lines (e.g. the first guard, the
// cache-hit return, the `else` lines and the `boolvar =` assignment head) are
// missing from this extraction.
// A bound operand reduces to the expression-vs-constant form.
581 return MakeIsEqualCstVar(v2, v1->
Min());
582 }
else if (v2->
Bound()) {
583 return MakeIsEqualCstVar(v1, v2->
Min());
// Equality is symmetric, so the cache is probed in both argument orders.
585 IntExpr* cache = model_cache_->FindExprExprExpression(
586 v1, v2, ModelCache::EXPR_EXPR_IS_EQUAL);
587 if (cache ==
nullptr) {
588 cache = model_cache_->FindExprExprExpression(
589 v2, v1, ModelCache::EXPR_EXPR_IS_EQUAL);
591 if (cache !=
nullptr) {
// Otherwise, derive it from a cached "is not equal" variable if one exists
// (presumably 1 - reverse_cache via MakeDifference).
594 IntVar* boolvar =
nullptr;
595 IntExpr* reverse_cache = model_cache_->FindExprExprExpression(
596 v1, v2, ModelCache::EXPR_EXPR_IS_NOT_EQUAL);
597 if (reverse_cache ==
nullptr) {
598 reverse_cache = model_cache_->FindExprExprExpression(
599 v2, v1, ModelCache::EXPR_EXPR_IS_NOT_EQUAL);
601 if (reverse_cache !=
nullptr) {
602 boolvar = MakeDifference(1, reverse_cache)->
Var();
// No cached relative: create a named bool var, post IsEqualCt through
// MakeIsEqualCt, and record the result in the cache.
604 std::string name1 = v1->
name();
608 std::string name2 = v2->
name();
613 MakeBoolVar(absl::StrFormat(
"IsEqualVar(%s, %s)", name1, name2));
614 AddConstraint(MakeIsEqualCt(v1, v2, boolvar));
615 model_cache_->InsertExprExprExpression(boolvar, v1, v2,
616 ModelCache::EXPR_EXPR_IS_EQUAL);
// Builds the constraint b == (v1 == v2).
// NOTE(review): the signature and the guards before the MakeNonEquality /
// MakeEquality returns (original lines ~630-636) are missing; presumably they
// test whether b is already fixed to 0 or 1 — confirm against the full source.
// A bound operand reduces to the expression-vs-constant form.
627 return MakeIsEqualCstCt(v2, v1->
Min(),
b);
628 }
else if (v2->
Bound()) {
629 return MakeIsEqualCstCt(v1, v2->
Min(),
b);
633 return MakeNonEquality(v1, v2);
635 return MakeEquality(v1, v2);
// General case: the reified IsEqualCt constraint.
638 return RevAlloc(
new IsEqualCt(
this, v1, v2,
b));
// Returns a 0/1 variable equal to (v1 != v2), reusing cached results.
// NOTE(review): the signature and several lines (first guard, cache-hit
// return, `else` lines, `boolvar =` head) are missing from this extraction.
// A bound operand reduces to the expression-vs-constant form.
645 return MakeIsDifferentCstVar(v2, v1->
Min());
646 }
else if (v2->
Bound()) {
647 return MakeIsDifferentCstVar(v1, v2->
Min());
// Disequality is symmetric, so the cache is probed in both argument orders.
649 IntExpr* cache = model_cache_->FindExprExprExpression(
650 v1, v2, ModelCache::EXPR_EXPR_IS_NOT_EQUAL);
651 if (cache ==
nullptr) {
652 cache = model_cache_->FindExprExprExpression(
653 v2, v1, ModelCache::EXPR_EXPR_IS_NOT_EQUAL);
655 if (cache !=
nullptr) {
// Otherwise, derive it from a cached "is equal" variable if one exists
// (presumably 1 - reverse_cache via MakeDifference).
658 IntVar* boolvar =
nullptr;
659 IntExpr* reverse_cache = model_cache_->FindExprExprExpression(
660 v1, v2, ModelCache::EXPR_EXPR_IS_EQUAL);
661 if (reverse_cache ==
nullptr) {
662 reverse_cache = model_cache_->FindExprExprExpression(
663 v2, v1, ModelCache::EXPR_EXPR_IS_EQUAL);
665 if (reverse_cache !=
nullptr) {
666 boolvar = MakeDifference(1, reverse_cache)->
Var();
// No cached relative: create a named bool var, post MakeIsDifferentCt, and
// record the result in the cache.
668 std::string name1 = v1->
name();
672 std::string name2 = v2->
name();
677 MakeBoolVar(absl::StrFormat(
"IsDifferentVar(%s, %s)", name1, name2));
678 AddConstraint(MakeIsDifferentCt(v1, v2, boolvar));
680 model_cache_->InsertExprExprExpression(boolvar, v1, v2,
681 ModelCache::EXPR_EXPR_IS_NOT_EQUAL);
// Builds the constraint b == (v1 != v2).
// NOTE(review): the signature and the first guard are missing from this
// extraction.
// A bound operand reduces to the expression-vs-constant form.
691 return MakeIsDifferentCstCt(v2, v1->
Min(),
b);
692 }
else if (v2->
Bound()) {
693 return MakeIsDifferentCstCt(v1, v2->
Min(),
b);
// General case: the reified IsDifferentCt constraint.
695 return RevAlloc(
new IsDifferentCt(
this, v1, v2,
b));
// Returns a 0/1 variable equal to (left <= right), reusing a cached result.
// NOTE(review): the signature, first guard, cache-hit return, `else` lines
// and `boolvar =` head are missing from this extraction.
// A bound left means `right >= constant`; a bound right means
// `left <= constant`.
703 return MakeIsGreaterOrEqualCstVar(right, left->
Min());
704 }
else if (right->
Bound()) {
705 return MakeIsLessOrEqualCstVar(left, right->
Min());
// The relation is ordered, so the cache is probed only as (left, right).
707 IntExpr*
const cache = model_cache_->FindExprExprExpression(
708 left, right, ModelCache::EXPR_EXPR_IS_LESS_OR_EQUAL);
709 if (cache !=
nullptr) {
// Cache miss: create a named bool var, post IsLessOrEqualCt, and cache it.
712 std::string name1 = left->
name();
716 std::string name2 = right->
name();
721 MakeBoolVar(absl::StrFormat(
"IsLessOrEqual(%s, %s)", name1, name2));
723 AddConstraint(RevAlloc(
new IsLessOrEqualCt(
this, left, right, boolvar)));
724 model_cache_->InsertExprExprExpression(
725 boolvar, left, right, ModelCache::EXPR_EXPR_IS_LESS_OR_EQUAL);
// Builds the constraint b == (left <= right).
// NOTE(review): the signature and first guard are missing from this extraction.
// A bound operand reduces to the expression-vs-constant forms.
735 return MakeIsGreaterOrEqualCstCt(right, left->
Min(),
b);
736 }
else if (right->
Bound()) {
737 return MakeIsLessOrEqualCstCt(left, right->
Min(),
b);
// General case: the reified IsLessOrEqualCt constraint.
739 return RevAlloc(
new IsLessOrEqualCt(
this, left, right,
b));
// Returns a 0/1 variable equal to (left < right), reusing a cached result.
// NOTE(review): the signature, first guard, cache-hit return, `else` lines
// and `boolvar =` head are missing from this extraction.
// A bound left means `right > constant`; a bound right means
// `left < constant`.
746 return MakeIsGreaterCstVar(right, left->
Min());
747 }
else if (right->
Bound()) {
748 return MakeIsLessCstVar(left, right->
Min());
// The relation is ordered, so the cache is probed only as (left, right).
750 IntExpr*
const cache = model_cache_->FindExprExprExpression(
751 left, right, ModelCache::EXPR_EXPR_IS_LESS);
752 if (cache !=
nullptr) {
// Cache miss: create a named bool var, post IsLessCt, and cache it.
// The name reflects the strict relation ("IsLess"), not "IsLessOrEqual".
755 std::string name1 = left->
name();
759 std::string name2 = right->
name();
764 MakeBoolVar(absl::StrFormat(
"IsLess(%s, %s)", name1, name2));
766 AddConstraint(RevAlloc(
new IsLessCt(
this, left, right, boolvar)));
767 model_cache_->InsertExprExprExpression(boolvar, left, right,
768 ModelCache::EXPR_EXPR_IS_LESS);
// Builds the constraint b == (left < right).
// NOTE(review): the signature and first guard are missing from this extraction.
// A bound operand reduces to the expression-vs-constant forms.
778 return MakeIsGreaterCstCt(right, left->
Min(),
b);
779 }
else if (right->
Bound()) {
780 return MakeIsLessCstCt(left, right->
Min(),
b);
// General case: the reified IsLessCt constraint.
782 return RevAlloc(
new IsLessCt(
this, left, right,
b));
// (left >= right) is computed as (right <= left) by swapping the operands.
// NOTE(review): the enclosing signature is missing from this extraction.
787 return MakeIsLessOrEqualVar(right, left);
// b == (left >= right) is built as b == (right <= left) by swapping operands.
// NOTE(review): the enclosing signature is missing from this extraction.
793 return MakeIsLessOrEqualCt(right, left,
b);
// (left > right) is computed as (right < left) by swapping the operands.
// NOTE(review): the enclosing signature is missing from this extraction.
797 return MakeIsLessVar(right, left);
// b == (left > right) is built as b == (right < left) by swapping operands.
// NOTE(review): the enclosing signature is missing from this extraction.
802 return MakeIsLessCt(right, left,
b);