22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
34 "Initial size of the array of the hash "
35 "table of caches for objects of type Var(x == 3)");
43 class EqualityExprCst :
public Constraint {
45 EqualityExprCst(Solver*
const s, IntExpr*
const e,
int64 v);
46 ~EqualityExprCst()
override {}
48 void InitialPropagate()
override;
49 IntVar* Var()
override {
50 return solver()->MakeIsEqualCstVar(
expr_->Var(), value_);
52 std::string DebugString()
const override;
54 void Accept(ModelVisitor*
const visitor)
const override {
// Constructs the constraint: forwards the solver `s` to the Constraint base
// and stores the expression `e` in expr_ and the constant `v` in value_.
67 EqualityExprCst::EqualityExprCst(Solver*
const s, IntExpr*
const e,
int64 v)
68 : Constraint(s),
expr_(e), value_(v) {}
70 void EqualityExprCst::Post() {
71 if (!
expr_->IsVar()) {
72 Demon* d = solver()->MakeConstraintInitialPropagateCallback(
this);
// Propagation step: calls expr_->SetValue(value_), i.e. asks the expression
// to take the stored constant.
77 void EqualityExprCst::InitialPropagate() {
expr_->SetValue(value_); }
79 std::string EqualityExprCst::DebugString()
const {
80 return absl::StrFormat(
"(%s == %d)",
expr_->DebugString(), value_);
88 if (IsADifference(e, &left, &right)) {
89 return MakeEquality(left, MakeSum(right, v));
91 return MakeFalseConstraint();
92 }
else if (e->
Min() == e->
Max() && e->
Min() == v) {
93 return MakeTrueConstraint();
95 return RevAlloc(
new EqualityExprCst(
this, e, v));
103 if (IsADifference(e, &left, &right)) {
104 return MakeEquality(left, MakeSum(right, v));
106 return MakeFalseConstraint();
107 }
else if (e->
Min() == e->
Max() && e->
Min() == v) {
108 return MakeTrueConstraint();
110 return RevAlloc(
new EqualityExprCst(
this, e, v));
// Empty destructor override; no cleanup is performed in this body.
121 ~GreaterEqExprCst()
override {}
122 void Post()
override;
125 IntVar*
Var()
override {
129 void Accept(ModelVisitor*
const visitor)
const override {
130 visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual,
this);
131 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
133 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
134 visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual,
this);
138 IntExpr*
const expr_;
// Constructs the constraint: forwards the solver `s` to the Constraint base,
// stores the expression `e` in expr_ and the constant `v` in value_, and
// starts with demon_ == nullptr. Post() assigns demon_ later, and
// InitialPropagate() tests it against nullptr before using it.
143 GreaterEqExprCst::GreaterEqExprCst(Solver*
const s, IntExpr*
const e,
int64 v)
144 : Constraint(s),
expr_(e), value_(v), demon_(nullptr) {}
146 void GreaterEqExprCst::Post() {
148 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
156 void GreaterEqExprCst::InitialPropagate() {
158 if (demon_ !=
nullptr &&
expr_->
Min() >= value_) {
163 std::string GreaterEqExprCst::DebugString()
const {
172 }
else if (e->
Max() < v) {
175 return RevAlloc(
new GreaterEqExprCst(
this, e, v));
183 }
else if (e->
Max() < v) {
186 return RevAlloc(
new GreaterEqExprCst(
this, e, v));
194 }
else if (e->
Max() <= v) {
197 return RevAlloc(
new GreaterEqExprCst(
this, e, v + 1));
205 }
else if (e->
Max() <= v) {
208 return RevAlloc(
new GreaterEqExprCst(
this, e, v + 1));
// Empty destructor override; no cleanup is performed in this body.
219 ~LessEqExprCst()
override {}
220 void Post()
override;
221 void InitialPropagate()
override;
222 std::string DebugString()
const override;
223 IntVar* Var()
override {
224 return solver()->MakeIsLessOrEqualCstVar(
expr_->Var(), value_);
226 void Accept(ModelVisitor*
const visitor)
const override {
235 IntExpr*
const expr_;
// Constructs the constraint: forwards the solver `s` to the Constraint base,
// stores the expression `e` in expr_ and the constant `v` in value_, and
// starts with demon_ == nullptr. Post() assigns demon_ later, and
// InitialPropagate() tests it against nullptr before using it.
240 LessEqExprCst::LessEqExprCst(Solver*
const s, IntExpr*
const e,
int64 v)
241 : Constraint(s),
expr_(e), value_(v), demon_(nullptr) {}
243 void LessEqExprCst::Post() {
245 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
253 void LessEqExprCst::InitialPropagate() {
255 if (demon_ !=
nullptr &&
expr_->
Max() <= value_) {
260 std::string LessEqExprCst::DebugString()
const {
269 }
else if (e->
Min() > v) {
272 return RevAlloc(
new LessEqExprCst(
this, e, v));
280 }
else if (e->
Min() > v) {
283 return RevAlloc(
new LessEqExprCst(
this, e, v));
291 }
else if (e->
Min() >= v) {
294 return RevAlloc(
new LessEqExprCst(
this, e, v - 1));
302 }
else if (e->
Min() >= v) {
305 return RevAlloc(
new LessEqExprCst(
this, e, v - 1));
// Empty destructor override; no cleanup is performed in this body.
316 ~DiffCst()
override {}
// Empty override: this constraint does nothing at post time.
317 void Post()
override {}
318 void InitialPropagate()
override;
319 void BoundPropagate();
320 std::string DebugString()
const override;
321 IntVar* Var()
override {
322 return solver()->MakeIsDifferentCstVar(var_, value_);
324 void Accept(ModelVisitor*
const visitor)
const override {
333 bool HasLargeDomain(IntVar*
var);
// Constructs the constraint: forwards the solver `s` to the Constraint base,
// stores the variable in var_ and the constant in value_, and initializes
// demon_ to nullptr.
340 DiffCst::DiffCst(Solver*
const s, IntVar*
const var,
int64 value)
341 : Constraint(s), var_(
var), value_(
value), demon_(nullptr) {}
343 void DiffCst::InitialPropagate() {
344 if (HasLargeDomain(var_)) {
353 void DiffCst::BoundPropagate() {
356 if (var_min > value_ || var_max < value_) {
358 }
else if (var_min == value_) {
360 }
else if (var_max == value_) {
362 }
else if (!HasLargeDomain(var_)) {
368 std::string DiffCst::DebugString()
const {
369 return absl::StrFormat(
"(%s != %d)", var_->
DebugString(), value_);
372 bool DiffCst::HasLargeDomain(IntVar*
var) {
381 if (IsADifference(e, &left, &right)) {
385 }
else if (e->
Bound() && e->
Min() == v) {
396 if (IsADifference(e, &left, &right)) {
400 }
else if (e->
Bound() && e->
Min() == v) {
413 void Post()
override {
414 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
415 var_->WhenDomain(demon_);
418 void InitialPropagate()
override {
419 bool inhibit = var_->Bound();
421 int64 l = inhibit ? u : 0;
425 if (var_->Size() <= 0xFFFFFF) {
426 var_->RemoveValue(
cst_);
430 var_->SetValue(
cst_);
435 demon_->inhibit(solver());
438 std::string DebugString()
const override {
439 return absl::StrFormat(
"IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
443 void Accept(ModelVisitor*
const visitor)
const override {
463 if (IsADifference(
var, &left, &right)) {
501 if (boolvar->
Bound()) {
502 if (boolvar->
Min() == 0) {
510 model_cache_->InsertExprConstantExpression(
514 if (IsADifference(
var, &left, &right)) {
529 void Post()
override {
530 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
531 var_->WhenDomain(demon_);
535 void InitialPropagate()
override {
536 bool inhibit = var_->Bound();
538 int64 u = inhibit ? l : 1;
542 if (var_->Size() <= 0xFFFFFF) {
543 var_->RemoveValue(
cst_);
547 var_->SetValue(
cst_);
552 demon_->inhibit(solver());
556 std::string DebugString()
const override {
557 return absl::StrFormat(
"IsDiffCstCt(%s, %d, %s)", var_->DebugString(),
cst_,
561 void Accept(ModelVisitor*
const visitor)
const override {
581 if (IsADifference(
var, &left, &right)) {
584 return var->Var()->IsDifferent(
value);
597 if (
var->IsVar() && !
var->Var()->Contains(
value)) {
603 if (boolvar->
Bound()) {
604 if (boolvar->
Min() == 0) {
610 model_cache_->InsertExprConstantExpression(
614 if (IsADifference(
var, &left, &right)) {
629 void Post()
override {
630 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
631 expr_->WhenRange(demon_);
634 void InitialPropagate()
override {
635 bool inhibit =
false;
651 demon_->inhibit(solver());
654 std::string DebugString()
const override {
655 return absl::StrFormat(
"IsGreaterEqualCstCt(%s, %d, %s)",
660 void Accept(ModelVisitor*
const visitor)
const override {
671 IntExpr*
const expr_;
685 return var->Var()->IsGreaterOrEqual(
value);
700 if (boolvar->
Bound()) {
701 if (boolvar->
Min() == 0) {
709 model_cache_->InsertExprConstantExpression(
727 void Post()
override {
728 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
729 expr_->WhenRange(demon_);
733 void InitialPropagate()
override {
734 bool inhibit =
false;
750 demon_->inhibit(solver());
754 std::string DebugString()
const override {
755 return absl::StrFormat(
"IsLessEqualCstCt(%s, %d, %s)",
expr_->DebugString(),
759 void Accept(ModelVisitor*
const visitor)
const override {
770 IntExpr*
const expr_;
784 return var->Var()->IsLessOrEqual(
value);
799 if (boolvar->
Bound()) {
800 if (boolvar->
Min() == 0) {
808 model_cache_->InsertExprConstantExpression(
826 void Post()
override {
827 if (!
expr_->IsVar()) {
828 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
829 expr_->WhenRange(demon_);
833 void InitialPropagate()
override {
834 expr_->SetRange(min_, max_);
837 expr_->Range(&emin, &emax);
838 if (demon_ !=
nullptr && emin >= min_ && emax <= max_) {
839 demon_->inhibit(solver());
843 std::string DebugString()
const override {
844 return absl::StrFormat(
"BetweenCt(%s, %d, %d)",
expr_->DebugString(), min_,
848 void Accept(ModelVisitor*
const visitor)
const override {
858 IntExpr*
const expr_;
866 class NotBetweenCt :
public Constraint {
868 NotBetweenCt(Solver*
const s, IntExpr*
const v,
int64 l,
int64 u)
869 : Constraint(s),
expr_(v), min_(l), max_(u), demon_(nullptr) {}
871 void Post()
override {
872 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
873 expr_->WhenRange(demon_);
876 void InitialPropagate()
override {
879 expr_->Range(&emin, &emax);
881 expr_->SetMin(max_ + 1);
882 }
else if (emax <= max_) {
883 expr_->SetMax(min_ - 1);
886 if (!
expr_->IsVar() && (emax < min_ || emin > max_)) {
887 demon_->inhibit(solver());
891 std::string DebugString()
const override {
892 return absl::StrFormat(
"NotBetweenCt(%s, %d, %d)",
expr_->DebugString(),
896 void Accept(ModelVisitor*
const visitor)
const override {
906 IntExpr*
const expr_;
912 int64 ExtractExprProductCoeff(IntExpr** expr) {
915 while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
929 expr->
Range(&emin, &emax);
937 int64 coeff = ExtractExprProductCoeff(&expr);
949 return RevAlloc(
new BetweenCt(
this, expr, l, u));
962 expr->
Range(&emin, &emax);
968 if (emax <= u)
return MakeLess(expr, l);
971 return RevAlloc(
new NotBetweenCt(
this, expr, l, u));
988 void Post()
override {
989 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
990 expr_->WhenRange(demon_);
991 boolvar_->WhenBound(demon_);
994 void InitialPropagate()
override {
995 bool inhibit =
false;
998 expr_->Range(&emin, &emax);
999 int64 u = 1 - (emin > max_ || emax < min_);
1000 int64 l = emax <= max_ && emin >= min_;
1001 boolvar_->SetRange(l, u);
1002 if (boolvar_->Bound()) {
1004 if (boolvar_->Min() == 0) {
1005 if (
expr_->IsVar()) {
1006 expr_->Var()->RemoveInterval(min_, max_);
1008 }
else if (emin > min_) {
1009 expr_->SetMin(max_ + 1);
1010 }
else if (emax < max_) {
1011 expr_->SetMax(min_ - 1);
1014 expr_->SetRange(min_, max_);
1017 if (inhibit &&
expr_->IsVar()) {
1018 demon_->inhibit(solver());
1023 std::string DebugString()
const override {
1024 return absl::StrFormat(
"IsBetweenCt(%s, %d, %d, %s)",
expr_->DebugString(),
1025 min_, max_, boolvar_->DebugString());
1028 void Accept(ModelVisitor*
const visitor)
const override {
1040 IntExpr*
const expr_;
1043 IntVar*
const boolvar_;
1059 expr->
Range(&emin, &emax);
1067 int64 coeff = ExtractExprProductCoeff(&expr);
1080 return RevAlloc(
new IsBetweenCt(
this, expr, l, u,
b));
1100 const std::vector<int64>& sorted_values)
1101 :
Constraint(s), var_(v), values_(sorted_values) {
// Empty override: this constraint does nothing at post time.
1106 void Post()
override {}
// Propagation step: hands the stored value list to var_->SetValues(values_).
// NOTE(review): presumably this restricts var_ to the listed values; confirm
// against the IntVar::SetValues contract.
1108 void InitialPropagate()
override { var_->SetValues(values_); }
1110 std::string DebugString()
const override {
1111 return absl::StrFormat(
"Member(%s, %s)", var_->DebugString(),
1112 absl::StrJoin(values_,
", "));
1115 void Accept(ModelVisitor*
const visitor)
const override {
1125 const std::vector<int64> values_;
1128 class NotMemberCt :
public Constraint {
1130 NotMemberCt(Solver*
const s, IntVar*
const v,
1131 const std::vector<int64>& sorted_values)
1132 : Constraint(s), var_(v), values_(sorted_values) {
// Empty override: this constraint does nothing at post time.
1137 void Post()
override {}
// Propagation step: hands the stored value list to
// var_->RemoveValues(values_).
// NOTE(review): presumably this removes every listed value from var_'s
// domain; confirm against the IntVar::RemoveValues contract.
1139 void InitialPropagate()
override { var_->RemoveValues(values_); }
1141 std::string DebugString()
const override {
1142 return absl::StrFormat(
"NotMember(%s, %s)", var_->DebugString(),
1143 absl::StrJoin(values_,
", "));
1146 void Accept(ModelVisitor*
const visitor)
const override {
1156 const std::vector<int64> values_;
1161 const std::vector<int64>& values) {
1162 const int64 coeff = ExtractExprProductCoeff(&expr);
1164 return std::find(values.begin(), values.end(), 0) == values.end()
1168 std::vector<int64> copied_values = values;
1173 for (
const int64 v : copied_values) {
1174 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1176 copied_values.resize(num_kept);
1182 expr->
Range(&emin, &emax);
1183 for (
const int64 v : copied_values) {
1184 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1186 copied_values.resize(num_kept);
1192 if (copied_values.size() == 1)
return MakeEquality(expr, copied_values[0]);
1194 if (copied_values.size() ==
1195 copied_values.back() - copied_values.front() + 1) {
1197 return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1202 if (emax - emin < 2 * copied_values.size()) {
1204 std::vector<bool> is_among_input_values(emax - emin + 1,
false);
1205 for (
const int64 v : copied_values) is_among_input_values[v - emin] =
true;
1208 copied_values.clear();
1209 for (
int64 v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1210 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1215 if (copied_values.size() == 1) {
1218 return RevAlloc(
new NotMemberCt(
this, expr->
Var(), copied_values));
1221 return RevAlloc(
new MemberCt(
this, expr->
Var(), copied_values));
1225 const std::vector<int>& values) {
1230 const std::vector<int64>& values) {
1231 const int64 coeff = ExtractExprProductCoeff(&expr);
1233 return std::find(values.begin(), values.end(), 0) == values.end()
1237 std::vector<int64> copied_values = values;
1242 for (
const int64 v : copied_values) {
1243 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1245 copied_values.resize(num_kept);
1251 expr->
Range(&emin, &emax);
1252 for (
const int64 v : copied_values) {
1253 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1255 copied_values.resize(num_kept);
1261 if (copied_values.size() == 1)
return MakeNonEquality(expr, copied_values[0]);
1263 if (copied_values.size() ==
1264 copied_values.back() - copied_values.front() + 1) {
1265 return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1270 if (emax - emin < 2 * copied_values.size()) {
1272 std::vector<bool> is_among_input_values(emax - emin + 1,
false);
1273 for (
const int64 v : copied_values) is_among_input_values[v - emin] =
true;
1276 copied_values.clear();
1277 for (
int64 v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1278 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1283 if (copied_values.size() == 1) {
1286 return RevAlloc(
new MemberCt(
this, expr->
Var(), copied_values));
1289 return RevAlloc(
new NotMemberCt(
this, expr->
Var(), copied_values));
1293 const std::vector<int>& values) {
1303 const std::vector<int64>& sorted_values,
IntVar*
const b)
1306 values_as_set_(sorted_values.begin(), sorted_values.end()),
1307 values_(sorted_values),
1311 domain_(var_->MakeDomainIterator(true)),
1321 void Post()
override {
1324 if (!var_->Bound()) {
1325 var_->WhenDomain(demon_);
1327 if (!boolvar_->Bound()) {
1329 solver(),
this, &IsMemberCt::TargetBound,
"TargetBound");
1330 boolvar_->WhenBound(bdemon);
1334 void InitialPropagate()
override {
1335 boolvar_->SetRange(0, 1);
1336 if (boolvar_->Bound()) {
1343 std::string DebugString()
const override {
1344 return absl::StrFormat(
"IsMemberCt(%s, %s, %s)", var_->DebugString(),
1345 absl::StrJoin(values_,
", "),
1346 boolvar_->DebugString());
1349 void Accept(ModelVisitor*
const visitor)
const override {
1361 if (boolvar_->Bound()) {
1364 for (
int offset = 0; offset < values_.size(); ++offset) {
1365 const int candidate = (support_ + offset) % values_.size();
1366 if (var_->Contains(values_[candidate])) {
1367 support_ = candidate;
1368 if (var_->Bound()) {
1369 demon_->inhibit(solver());
1370 boolvar_->SetValue(1);
1375 if (var_->Contains(neg_support_)) {
1379 for (
const int64 value : InitAndGetValues(domain_)) {
1381 neg_support_ =
value;
1387 demon_->inhibit(solver());
1388 boolvar_->SetValue(1);
1393 demon_->inhibit(solver());
1394 boolvar_->SetValue(0);
1398 void TargetBound() {
1399 DCHECK(boolvar_->Bound());
1400 if (boolvar_->Min() == 1LL) {
1401 demon_->inhibit(solver());
1402 var_->SetValues(values_);
1404 demon_->inhibit(solver());
1405 var_->RemoveValues(values_);
1410 absl::flat_hash_set<int64> values_as_set_;
1411 std::vector<int64> values_;
1412 IntVar*
const boolvar_;
1415 IntVarIterator*
const domain_;
1420 Constraint* BuildIsMemberCt(Solver*
const solver, IntExpr*
const expr,
1421 const std::vector<T>& values,
1422 IntVar*
const boolvar) {
1425 IntExpr* sub =
nullptr;
1427 if (solver->IsProduct(expr, &sub, &
coef) &&
coef != 0 &&
coef != 1) {
1428 std::vector<int64> new_values;
1429 new_values.reserve(values.size());
1435 return BuildIsMemberCt(solver, sub, new_values, boolvar);
1438 std::set<T> set_of_values(values.begin(), values.end());
1439 std::vector<int64> filtered_values;
1440 bool all_values =
false;
1441 if (expr->IsVar()) {
1442 IntVar*
const var = expr->
Var();
1443 for (
const T
value : set_of_values) {
1445 filtered_values.push_back(
value);
1448 all_values = (filtered_values.size() ==
var->
Size());
1452 expr->Range(&emin, &emax);
1453 for (
const T
value : set_of_values) {
1455 filtered_values.push_back(
value);
1458 all_values = (filtered_values.size() == emax - emin + 1);
1460 if (filtered_values.empty()) {
1461 return solver->MakeEquality(boolvar,
Zero());
1462 }
else if (all_values) {
1463 return solver->MakeEquality(boolvar, 1);
1464 }
else if (filtered_values.size() == 1) {
1465 return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1466 }
else if (filtered_values.back() ==
1467 filtered_values.front() + filtered_values.size() - 1) {
1469 return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1470 filtered_values.back(), boolvar);
1472 return solver->RevAlloc(
1473 new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1479 const std::vector<int64>& values,
1481 return BuildIsMemberCt(
this, expr, values, boolvar);
1485 const std::vector<int>& values,
1487 return BuildIsMemberCt(
this, expr, values, boolvar);
1491 const std::vector<int64>& values) {
1498 const std::vector<int>& values) {
1505 class SortedDisjointForbiddenIntervalsConstraint :
public Constraint {
1507 SortedDisjointForbiddenIntervalsConstraint(
1510 :
Constraint(solver), var_(
var), intervals_(std::move(intervals)) {}
// Empty destructor override; no cleanup is performed in this body.
1512 ~SortedDisjointForbiddenIntervalsConstraint()
override {}
1514 void Post()
override {
1515 Demon*
const demon = solver()->MakeConstraintInitialPropagateCallback(
this);
1516 var_->WhenRange(demon);
1519 void InitialPropagate()
override {
1520 const int64 vmin = var_->Min();
1521 const int64 vmax = var_->Max();
1522 const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1523 if (first_interval_it == intervals_.end()) {
1527 const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1528 if (last_interval_it == intervals_.end()) {
1534 if (vmin >= first_interval_it->start) {
1537 var_->SetMin(
CapAdd(first_interval_it->end, 1));
1539 if (vmax <= last_interval_it->end) {
1541 var_->SetMax(
CapSub(last_interval_it->start, 1));
1545 std::string DebugString()
const override {
1546 return absl::StrFormat(
"ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1547 intervals_.DebugString());
1550 void Accept(ModelVisitor*
const visitor)
const override {
1554 std::vector<int64> starts;
1555 std::vector<int64> ends;
1556 for (
auto&
interval : intervals_) {
1567 const SortedDisjointIntervalList intervals_;
1572 std::vector<int64> starts,
1573 std::vector<int64> ends) {
1574 return RevAlloc(
new SortedDisjointForbiddenIntervalsConstraint(
1575 this, expr->
Var(), {starts, ends}));
1579 std::vector<int> starts,
1580 std::vector<int> ends) {
1581 return RevAlloc(
new SortedDisjointForbiddenIntervalsConstraint(
1582 this, expr->
Var(), {starts, ends}));
1587 return RevAlloc(
new SortedDisjointForbiddenIntervalsConstraint(
1588 this, expr->
Var(), std::move(intervals)));