22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_format.h"
34 ABSL_FLAG(
int, cp_impact_divider, 10,
"Divider for continuous update.");
40 const int kDefaultNumberOfSplits = 100;
41 const int kDefaultHeuristicPeriod = 100;
42 const int kDefaultHeuristicNumFailuresLimit = 30;
43 const bool kDefaultUseLastConflict =
true;
49 initialization_splits(kDefaultNumberOfSplits),
50 run_all_heuristics(true),
51 heuristic_period(kDefaultHeuristicPeriod),
52 heuristic_num_failures_limit(kDefaultHeuristicNumFailuresLimit),
53 persistent_impact(true),
56 use_last_conflict(kDefaultUseLastConflict),
57 decision_builder(nullptr) {}
66 DomainWatcher(
const std::vector<IntVar*>& vars,
int cache_size)
68 cached_log_.Init(cache_size);
71 double LogSearchSpaceSize() {
74 result += cached_log_.Log2(vars_[
index]->Size());
79 double Log2(
int64 size)
const {
return cached_log_.Log2(size); }
82 std::vector<IntVar*>
vars_;
83 CachedLog cached_log_;
89 class FindVar :
public DecisionVisitor {
91 enum Operation { NONE, ASSIGN, SPLIT_LOW, SPLIT_HIGH };
93 FindVar() : var_(nullptr), value_(0), operation_(NONE) {}
95 ~FindVar()
override {}
97 void VisitSetVariableValue(IntVar*
const var,
int64 value)
override {
103 void VisitSplitVariableDomain(IntVar*
const var,
int64 value,
104 bool start_with_lower_half)
override {
107 operation_ = start_with_lower_half ? SPLIT_LOW : SPLIT_HIGH;
110 void VisitScheduleOrPostpone(IntervalVar*
const var,
int64 est)
override {
114 virtual void VisitTryRankFirst(SequenceVar*
const sequence,
int index) {
118 virtual void VisitTryRankLast(SequenceVar*
const sequence,
int index) {
122 void VisitUnknownDecision()
override { operation_ = NONE; }
125 IntVar*
const var()
const {
136 Operation operation()
const {
return operation_; }
138 std::string DebugString()
const override {
139 return "FindVar decision visitor";
145 Operation operation_;
152 class InitVarImpacts :
public DecisionBuilder {
157 update_impact_callback_(nullptr),
161 update_impact_closure_([this]() { UpdateImpacts(); }),
162 updater_(update_impact_closure_) {
163 CHECK(update_impact_closure_ !=
nullptr);
166 ~InitVarImpacts()
override {}
168 void UpdateImpacts() {
170 update_impact_callback_(var_index_, var_->Min());
173 void Init(IntVar*
const var, IntVarIterator*
const iterator,
int var_index) {
176 var_index_ = var_index;
181 Decision* Next(Solver*
const solver)
override {
182 CHECK(var_ !=
nullptr);
185 active_values_.clear();
187 active_values_.push_back(
value);
191 if (value_index_ == active_values_.size()) {
194 updater_.var_ = var_;
195 updater_.value_ = active_values_[value_index_];
200 void set_update_impact_callback(std::function<
void(
int,
int64)>
callback) {
201 update_impact_callback_ = std::move(
callback);
206 class AssignCallFail :
public Decision {
208 explicit AssignCallFail(
const std::function<
void()>& update_impact_closure)
211 update_impact_closure_(update_impact_closure) {
212 CHECK(update_impact_closure_ !=
nullptr);
214 ~AssignCallFail()
override {}
215 void Apply(Solver*
const solver)
override {
216 CHECK(var_ !=
nullptr);
217 var_->SetValue(value_);
219 update_impact_closure_();
222 void Refute(Solver*
const solver)
override {}
228 const std::function<void()>& update_impact_closure_;
233 std::function<void(
int,
int64)> update_impact_callback_;
237 std::vector<int64> active_values_;
239 std::function<void()> update_impact_closure_;
240 AssignCallFail updater_;
246 class InitVarImpactsWithSplits :
public DecisionBuilder {
249 class AssignIntervalCallFail :
public Decision {
251 explicit AssignIntervalCallFail(
252 const std::function<
void()>& update_impact_closure)
256 update_impact_closure_(update_impact_closure) {
257 CHECK(update_impact_closure_ !=
nullptr);
259 ~AssignIntervalCallFail()
override {}
260 void Apply(Solver*
const solver)
override {
261 CHECK(var_ !=
nullptr);
264 update_impact_closure_();
267 void Refute(Solver*
const solver)
override {}
275 const std::function<void()>& update_impact_closure_;
281 explicit InitVarImpactsWithSplits(
int split_size)
283 update_impact_callback_(nullptr),
288 split_size_(split_size),
290 update_impact_closure_([this]() { UpdateImpacts(); }),
291 updater_(update_impact_closure_) {
292 CHECK(update_impact_closure_ !=
nullptr);
295 ~InitVarImpactsWithSplits()
override {}
297 void UpdateImpacts() {
299 update_impact_callback_(var_index_,
value);
303 void Init(IntVar*
const var, IntVarIterator*
const iterator,
int var_index) {
306 var_index_ = var_index;
312 const int64 length = max_value_ - min_value_ + 1;
313 return (min_value_ + length *
index / split_size_);
316 Decision* Next(Solver*
const solver)
override {
318 min_value_ = var_->Min();
319 max_value_ = var_->Max();
322 if (split_index_ == split_size_) {
325 updater_.var_ = var_;
326 updater_.value_min_ = IntervalStart(split_index_);
328 if (split_index_ == split_size_) {
329 updater_.value_max_ = max_value_;
331 updater_.value_max_ = IntervalStart(split_index_) - 1;
336 void set_update_impact_callback(std::function<
void(
int,
int64)>
callback) {
337 update_impact_callback_ = std::move(
callback);
342 std::function<void(
int,
int64)> update_impact_callback_;
348 const int split_size_;
350 std::function<void()> update_impact_closure_;
351 AssignIntervalCallFail updater_;
359 class ImpactRecorder :
public SearchMonitor {
367 ImpactRecorder(Solver*
const solver, DomainWatcher*
const domain_watcher,
368 const std::vector<IntVar*>& vars,
370 : SearchMonitor(solver),
371 domain_watcher_(domain_watcher),
374 current_log_space_(0.0),
376 original_min_(size_, 0LL),
377 domain_iterators_(new IntVarIterator*[size_]),
378 display_level_(display_level),
382 for (
int i = 0; i < size_; ++i) {
383 domain_iterators_[i] =
vars_[i]->MakeDomainIterator(
true);
384 var_map_[
vars_[i]] = i;
388 void ApplyDecision(Decision*
const d)
override {
392 d->Accept(&find_var_);
393 if (find_var_.operation() == FindVar::ASSIGN &&
395 current_var_ = var_map_[find_var_.var()];
396 current_value_ = find_var_.value();
397 current_log_space_ = domain_watcher_->LogSearchSpaceSize();
404 void AfterDecision(Decision*
const d,
bool apply)
override {
406 if (current_log_space_ > 0.0) {
407 const double log_space = domain_watcher_->LogSearchSpaceSize();
409 const double impact =
kPerfectImpact - log_space / current_log_space_;
410 UpdateImpact(current_var_, current_value_, impact);
414 current_log_space_ = log_space;
419 void BeginFail()
override {
427 void ResetAllImpacts() {
428 for (
int i = 0; i < size_; ++i) {
429 original_min_[i] =
vars_[i]->Min();
433 impacts_[i].resize(vars_[i]->Max() - vars_[i]->Min() + 1,
437 for (
int i = 0; i < size_; ++i) {
438 for (
int j = 0; j < impacts_[i].size(); ++j) {
444 void UpdateImpact(
int var_index,
int64 value,
double impact) {
445 const int64 value_index =
value - original_min_[var_index];
446 const double current_impact = impacts_[var_index][value_index];
447 const double new_impact =
448 (current_impact * (absl::GetFlag(FLAGS_cp_impact_divider) - 1) +
450 absl::GetFlag(FLAGS_cp_impact_divider);
451 impacts_[var_index][value_index] = new_impact;
455 const double log_space = domain_watcher_->LogSearchSpaceSize();
456 const double impact =
kPerfectImpact - log_space / current_log_space_;
457 const int64 value_index =
value - original_min_[var_index];
459 DCHECK_LT(value_index, impacts_[var_index].size());
460 impacts_[var_index][value_index] = impact;
464 void FirstRun(
int64 splits) {
465 Solver*
const s = solver();
466 current_log_space_ = domain_watcher_->LogSearchSpaceSize();
468 LOG(
INFO) <<
" - initial log2(SearchSpace) = " << current_log_space_;
470 const int64 init_time = s->wall_time();
472 int64 removed_counter = 0;
473 FirstRunVariableContainers* container =
474 s->RevAlloc(
new FirstRunVariableContainers(
this, splits));
476 for (
int var_index = 0; var_index < size_; ++var_index) {
477 IntVar*
const var =
vars_[var_index];
481 IntVarIterator*
const iterator = domain_iterators_[var_index];
482 DecisionBuilder* init_decision_builder =
nullptr;
483 const bool no_split =
var->Size() < splits;
486 container->without_split()->set_update_impact_callback(
487 container->update_impact_callback());
488 container->without_split()->Init(
var, iterator, var_index);
489 init_decision_builder = container->without_split();
493 container->with_splits()->set_update_impact_callback(
494 container->update_impact_callback());
495 container->with_splits()->Init(
var, iterator, var_index);
496 init_decision_builder = container->with_splits();
501 s->Solve(init_decision_builder);
506 if (init_count_ !=
var->Size() && no_split) {
507 container->ClearRemovedValues();
508 for (
const int64 value : InitAndGetValues(iterator)) {
509 const int64 value_index =
value - original_min_[var_index];
511 container->PushBackRemovedValue(
value);
514 CHECK(container->HasRemovedValues()) <<
var->DebugString();
515 removed_counter += container->NumRemovedValues();
516 const double old_log = domain_watcher_->Log2(
var->Size());
517 var->RemoveValues(container->removed_values());
518 current_log_space_ += domain_watcher_->Log2(
var->Size()) - old_log;
522 if (removed_counter) {
523 LOG(
INFO) <<
" - init done, time = " << s->wall_time() - init_time
524 <<
" ms, " << removed_counter
525 <<
" values removed, log2(SearchSpace) = "
526 << current_log_space_;
528 LOG(
INFO) <<
" - init done, time = " << s->wall_time() - init_time
532 s->SaveAndSetValue(&init_done_,
true);
538 void ScanVarImpacts(
int var_index,
int64*
const best_impact_value,
539 double*
const var_impacts,
542 CHECK(best_impact_value !=
nullptr);
543 CHECK(var_impacts !=
nullptr);
546 double sum_var_impact = 0.0;
547 int64 min_impact_value = -1;
548 int64 max_impact_value = -1;
549 for (
const int64 value : InitAndGetValues(domain_iterators_[var_index])) {
550 const int64 value_index =
value - original_min_[var_index];
552 DCHECK_LT(value_index, impacts_[var_index].size());
553 const double current_impact = impacts_[var_index][value_index];
554 sum_var_impact += current_impact;
555 if (current_impact > max_impact) {
556 max_impact = current_impact;
557 max_impact_value =
value;
559 if (current_impact < min_impact) {
560 min_impact = current_impact;
561 min_impact_value =
value;
565 switch (var_select) {
567 *var_impacts = sum_var_impact /
vars_[var_index]->Size();
571 *var_impacts = max_impact;
575 *var_impacts = sum_var_impact;
580 switch (value_select) {
582 *best_impact_value = min_impact_value;
586 *best_impact_value = max_impact_value;
592 std::string DebugString()
const override {
return "ImpactRecorder"; }
597 class FirstRunVariableContainers :
public BaseObject {
599 FirstRunVariableContainers(ImpactRecorder* impact_recorder,
int64 splits)
600 : update_impact_callback_(
601 [impact_recorder](int var_index,
int64 value) {
602 impact_recorder->InitImpact(var_index,
value);
606 with_splits_(splits) {}
607 std::function<void(
int,
int64)> update_impact_callback()
const {
608 return update_impact_callback_;
610 void PushBackRemovedValue(
int64 value) { removed_values_.push_back(
value); }
611 bool HasRemovedValues()
const {
return !removed_values_.empty(); }
612 void ClearRemovedValues() { removed_values_.clear(); }
613 size_t NumRemovedValues()
const {
return removed_values_.size(); }
614 const std::vector<int64>& removed_values()
const {
return removed_values_; }
615 InitVarImpacts* without_split() {
return &without_splits_; }
616 InitVarImpactsWithSplits* with_splits() {
return &with_splits_; }
618 std::string DebugString()
const override {
619 return "FirstRunVariableContainers";
623 const std::function<void(
int,
int64)> update_impact_callback_;
624 std::vector<int64> removed_values_;
625 InitVarImpacts without_splits_;
626 InitVarImpactsWithSplits with_splits_;
629 DomainWatcher*
const domain_watcher_;
630 std::vector<IntVar*>
vars_;
632 double current_log_space_;
635 std::vector<std::vector<double> > impacts_;
636 std::vector<int64> original_min_;
637 std::unique_ptr<IntVarIterator*[]> domain_iterators_;
641 int64 current_value_;
643 absl::flat_hash_map<const IntVar*, int> var_map_;
658 ChoiceInfo() : value_(0), var_(nullptr), left_(false) {}
661 : value_(
value), var_(
var), left_(left) {}
663 std::string DebugString()
const {
664 return absl::StrFormat(
"%s %s %d", var_->name(), (left_ ?
"==" :
"!="),
668 IntVar*
var()
const {
return var_; }
670 bool left()
const {
return left_; }
674 void set_left(
bool left) { left_ = left; }
684 class RunHeuristicsAsDives :
public Decision {
686 RunHeuristicsAsDives(Solver*
const solver,
const std::vector<IntVar*>& vars,
688 bool run_all_heuristics,
int random_seed,
689 int heuristic_period,
int heuristic_num_failures_limit)
690 : heuristic_limit_(nullptr),
691 display_level_(level),
692 run_all_heuristics_(run_all_heuristics),
693 random_(random_seed),
694 heuristic_period_(heuristic_period),
695 heuristic_branch_count_(0),
697 Init(solver, vars, heuristic_num_failures_limit);
702 void Apply(Solver*
const solver)
override {
703 if (!RunAllHeuristics(solver)) {
708 void Refute(Solver*
const solver)
override {}
711 if (heuristic_period_ <= 0) {
714 ++heuristic_branch_count_;
715 return heuristic_branch_count_ % heuristic_period_ == 0;
718 bool RunOneHeuristic(Solver*
const solver,
int index) {
719 HeuristicWrapper*
const wrapper = heuristics_[
index];
723 solver->SolveAndCommit(wrapper->phase, heuristic_limit_);
725 LOG(
INFO) <<
" --- solution found by heuristic " << wrapper->name
731 bool RunAllHeuristics(Solver*
const solver) {
732 if (run_all_heuristics_) {
734 for (
int run = 0; run < heuristics_[
index]->runs; ++run) {
735 if (RunOneHeuristic(solver,
index)) {
743 const int index = absl::Uniform<int>(random_, 0, heuristics_.size());
744 return RunOneHeuristic(solver,
index);
748 int Rand32(
int size) {
750 return absl::Uniform<int>(random_, 0, size);
753 void Init(Solver*
const solver,
const std::vector<IntVar*>& vars,
754 int heuristic_num_failures_limit) {
755 const int kRunOnce = 1;
756 const int kRunMore = 2;
757 const int kRunALot = 3;
759 heuristics_.push_back(
new HeuristicWrapper(
763 heuristics_.push_back(
new HeuristicWrapper(
767 heuristics_.push_back(
770 "AssignCenterValueToMinDomainSize", kRunOnce));
772 heuristics_.push_back(
new HeuristicWrapper(
774 "AssignRandomValueToFirstUnbound", kRunALot));
776 heuristics_.push_back(
new HeuristicWrapper(
778 "AssignMinValueToRandomVariable", kRunMore));
780 heuristics_.push_back(
new HeuristicWrapper(
782 "AssignMaxValueToRandomVariable", kRunMore));
784 heuristics_.push_back(
new HeuristicWrapper(
786 "AssignRandomValueToRandomVariable", kRunMore));
788 heuristic_limit_ = solver->MakeFailuresLimit(heuristic_num_failures_limit);
791 int heuristic_runs()
const {
return heuristic_runs_; }
796 struct HeuristicWrapper {
797 HeuristicWrapper(Solver*
const solver,
const std::vector<IntVar*>& vars,
800 const std::string& heuristic_name,
int heuristic_runs)
801 :
phase(solver->MakePhase(vars, var_strategy, value_strategy)),
802 name(heuristic_name),
803 runs(heuristic_runs) {}
815 std::vector<HeuristicWrapper*> heuristics_;
816 SearchMonitor* heuristic_limit_;
818 bool run_all_heuristics_;
819 std::mt19937 random_;
820 const int heuristic_period_;
821 int heuristic_branch_count_;
828 class DefaultIntegerSearch :
public DecisionBuilder {
832 DefaultIntegerSearch(Solver*
const solver,
const std::vector<IntVar*>& vars,
837 impact_recorder_(solver, &domain_watcher_, vars,
839 heuristics_(solver,
vars_, parameters_.display_level,
840 parameters_.run_all_heuristics, parameters_.random_seed,
841 parameters_.heuristic_period,
842 parameters_.heuristic_num_failures_limit),
844 last_int_var_(nullptr),
846 last_operation_(FindVar::NONE),
847 last_conflict_count_(0),
850 ~DefaultIntegerSearch()
override {}
852 Decision* Next(Solver*
const solver)
override {
855 if (heuristics_.ShouldRun()) {
859 Decision*
const decision = parameters_.decision_builder !=
nullptr
860 ? parameters_.decision_builder->Next(solver)
861 : ImpactNext(solver);
864 if (decision ==
nullptr) {
872 decision->Accept(&find_var_);
873 IntVar*
const decision_var =
874 find_var_.operation() != FindVar::NONE ? find_var_.var() :
nullptr;
884 if (parameters_.use_last_conflict && last_int_var_ !=
nullptr &&
885 !last_int_var_->Bound() &&
886 (decision_var ==
nullptr || decision_var != last_int_var_)) {
887 switch (last_operation_) {
888 case FindVar::ASSIGN: {
889 if (last_int_var_->Contains(last_int_value_)) {
890 Decision*
const assign =
891 solver->MakeAssignVariableValue(last_int_var_, last_int_value_);
893 last_conflict_count_++;
898 case FindVar::SPLIT_LOW: {
899 if (last_int_var_->Max() > last_int_value_ &&
900 last_int_var_->Min() <= last_int_value_) {
901 Decision*
const split = solver->MakeVariableLessOrEqualValue(
902 last_int_var_, last_int_value_);
904 last_conflict_count_++;
909 case FindVar::SPLIT_HIGH: {
910 if (last_int_var_->Min() < last_int_value_ &&
911 last_int_var_->Max() >= last_int_value_) {
912 Decision*
const split = solver->MakeVariableGreaterOrEqualValue(
913 last_int_var_, last_int_value_);
915 last_conflict_count_++;
926 if (parameters_.use_last_conflict) {
928 decision->Accept(&find_var_);
929 if (find_var_.operation() != FindVar::NONE) {
930 last_int_var_ = find_var_.var();
931 last_int_value_ = find_var_.value();
932 last_operation_ = find_var_.operation();
939 void ClearLastDecision() {
940 last_int_var_ =
nullptr;
942 last_operation_ = FindVar::NONE;
945 void AppendMonitors(Solver*
const solver,
946 std::vector<SearchMonitor*>*
const extras)
override {
947 CHECK(solver !=
nullptr);
948 CHECK(extras !=
nullptr);
949 if (parameters_.decision_builder ==
nullptr) {
950 extras->push_back(&impact_recorder_);
954 void Accept(ModelVisitor*
const visitor)
const override {
961 std::string DebugString()
const override {
962 std::string out =
"DefaultIntegerSearch(";
964 if (parameters_.decision_builder ==
nullptr) {
965 out.append(
"Impact Based Search, ");
967 out.append(parameters_.decision_builder->DebugString());
975 std::string StatString()
const {
976 const int runs = heuristics_.heuristic_runs();
979 if (!result.empty()) {
983 result.append(
"1 heuristic run");
985 absl::StrAppendFormat(&result,
"%d heuristic runs",
runs);
988 if (last_conflict_count_ > 0) {
989 if (!result.empty()) {
992 if (last_conflict_count_ == 1) {
993 result.append(
"1 last conflict hint");
995 absl::StrAppendFormat(&result,
"%d last conflict hints",
996 last_conflict_count_);
1003 void CheckInit(Solver*
const solver) {
1007 if (parameters_.decision_builder ==
nullptr) {
1009 for (
int i = 0; i <
vars_.size(); ++i) {
1010 if (vars_[i]->Max() - vars_[i]->Min() > 0xFFFFFF) {
1012 LOG(
INFO) <<
"Domains are too large, switching to simple "
1016 reinterpret_cast<void**
>(¶meters_.decision_builder));
1017 parameters_.decision_builder =
1020 solver->SaveAndSetValue(&init_done_,
true);
1027 LOG(
INFO) <<
"Search space is too small, switching to simple "
1031 reinterpret_cast<void**
>(¶meters_.decision_builder));
1032 parameters_.decision_builder = solver->MakePhase(
1034 solver->SaveAndSetValue(&init_done_,
true);
1039 LOG(
INFO) <<
"Init impact based search phase on " <<
vars_.size()
1040 <<
" variables, initialization splits = "
1041 << parameters_.initialization_splits
1042 <<
", heuristic_period = " << parameters_.heuristic_period
1043 <<
", run_all_heuristics = "
1044 << parameters_.run_all_heuristics;
1047 impact_recorder_.FirstRun(parameters_.initialization_splits);
1049 if (parameters_.persistent_impact) {
1052 solver->SaveAndSetValue(&init_done_,
true);
1060 Decision* ImpactNext(Solver*
const solver) {
1061 IntVar*
var =
nullptr;
1064 for (
int i = 0; i <
vars_.size(); ++i) {
1065 if (!vars_[i]->Bound()) {
1066 int64 current_value = 0;
1067 double current_var_impact = 0.0;
1068 impact_recorder_.ScanVarImpacts(i, ¤t_value, ¤t_var_impact,
1069 parameters_.var_selection_schema,
1070 parameters_.value_selection_schema);
1071 if (current_var_impact > best_var_impact) {
1073 value = current_value;
1074 best_var_impact = current_var_impact;
1078 if (
var ==
nullptr) {
1081 return solver->MakeAssignVariableValue(
var,
value);
1087 std::vector<IntVar*>
vars_;
1088 DefaultPhaseParameters parameters_;
1089 DomainWatcher domain_watcher_;
1090 ImpactRecorder impact_recorder_;
1091 RunHeuristicsAsDives heuristics_;
1093 IntVar* last_int_var_;
1094 int64 last_int_value_;
1095 FindVar::Operation last_operation_;
1096 int last_conflict_count_;
1106 DefaultIntegerSearch*
const dis =
dynamic_cast<DefaultIntegerSearch*
>(db);
1107 return dis !=
nullptr ? dis->StatString() :
"";
1116 const std::vector<IntVar*>& vars,