18 #include "absl/strings/str_format.h"
19 #include "absl/strings/str_join.h"
25 uint64 FprintOfInt32(
int i) {
33 element_.assign(num_elements, -1);
34 index_of_.assign(num_elements, -1);
35 for (
int i = 0; i < num_elements; ++i) {
39 part_of_.assign(num_elements, 0);
41 for (
int i = 0; i < num_elements; ++i) fprint ^= FprintOfInt32(i);
42 part_.push_back(Part(0, num_elements,
48 const std::vector<int>& initial_part_of_element) {
49 if (initial_part_of_element.empty())
return;
50 part_of_ = initial_part_of_element;
51 const int n = part_of_.size();
52 const int num_parts = 1 + *std::max_element(part_of_.begin(), part_of_.end());
53 DCHECK_EQ(0, *std::min_element(part_of_.begin(), part_of_.end()));
54 part_.resize(num_parts);
57 for (
int i = 0; i < n; ++i) part_[part_of_[i]].fprint ^= FprintOfInt32(i);
62 for (
int p = 0; p < num_parts; ++p) {
63 part_[p].end_index = 0;
64 part_[p].parent_part = p;
66 for (
const int p : part_of_) ++part_[p].end_index;
67 int sum_part_sizes = 0;
68 for (
int p = 0; p < num_parts; ++p) {
69 part_[p].start_index = sum_part_sizes;
70 sum_part_sizes += part_[p].end_index;
76 for (Part& part : part_) part.end_index = part.start_index;
77 element_.assign(n, -1);
78 index_of_.assign(n, -1);
79 for (
int element = 0; element < n; ++element) {
80 Part*
const part = &part_[part_of_[element]];
81 element_[part->end_index] = element;
82 index_of_[element] = part->end_index;
91 for (
int p = 1; p <
NumParts(); ++p) {
92 DCHECK_EQ(part_[p - 1].end_index, part_[p].start_index);
99 tmp_counter_of_part_.resize(
NumParts(), 0);
101 tmp_affected_parts_.clear();
102 for (
const int element : distinguished_subset) {
105 const int part = part_of_[element];
106 const int num_distinguished_elements_in_part = ++tmp_counter_of_part_[part];
108 if (num_distinguished_elements_in_part == 1) {
110 tmp_affected_parts_.push_back(part);
113 const int old_index = index_of_[element];
114 const int new_index =
115 part_[part].end_index - num_distinguished_elements_in_part;
117 <<
"Duplicate element given to Refine(): " << element;
119 index_of_[element] = new_index;
120 index_of_[element_[new_index]] = old_index;
121 std::swap(element_[old_index], element_[new_index]);
127 std::sort(tmp_affected_parts_.begin(), tmp_affected_parts_.end());
131 for (
const int part : tmp_affected_parts_) {
132 const int start_index = part_[part].start_index;
133 const int end_index = part_[part].end_index;
134 const int split_index = end_index - tmp_counter_of_part_[part];
135 tmp_counter_of_part_[part] = 0;
140 if (split_index == start_index)
continue;
144 for (
int i = split_index; i < end_index; ++i) {
145 new_fprint ^= FprintOfInt32(element_[i]);
151 part_[part].end_index = split_index;
152 part_[part].fprint ^= new_fprint;
153 part_.push_back(Part( split_index, end_index,
156 part_of_[element] = new_part;
164 while (
NumParts() > original_num_parts) {
165 const int part_index =
NumParts() - 1;
166 const Part& part = part_[part_index];
167 const int parent_part_index = part.parent_part;
168 DCHECK_LT(parent_part_index, part_index) <<
"UndoRefineUntilNumPartsEqual()"
170 "'original_num_parts' too low";
174 part_of_[element] = parent_part_index;
176 Part*
const parent_part = &part_[parent_part_index];
177 DCHECK_EQ(part.start_index, parent_part->end_index);
178 parent_part->end_index = part.end_index;
179 parent_part->fprint ^= part.fprint;
186 return absl::StrFormat(
"Unsupported sorting: %d", sorting);
188 std::vector<std::vector<int>> parts;
189 for (
int i = 0; i <
NumParts(); ++i) {
191 parts.emplace_back(iterable_part.
begin(), iterable_part.
end());
192 std::sort(parts.back().begin(), parts.back().end());
195 std::sort(parts.begin(), parts.end());
198 for (
const std::vector<int>& part : parts) {
199 if (!out.empty()) out +=
" | ";
200 out += absl::StrJoin(part,
" ");
207 part_size_.assign(num_nodes, 1);
208 parent_.assign(num_nodes, -1);
209 for (
int i = 0; i < num_nodes; ++i) parent_[i] = i;
210 tmp_part_bit_.assign(num_nodes,
false);
220 if (root1 == root2)
return -1;
221 int s1 = part_size_[root1];
222 int s2 = part_size_[root2];
224 if (s1 < s2 || (s1 == s2 && root1 > root2)) {
225 std::swap(root1, root2);
231 part_size_[root1] += part_size_[root2];
232 SetParentAlongPathToRoot(node1, root1);
233 SetParentAlongPathToRoot(node2, root1);
240 const int root =
GetRoot(node);
241 SetParentAlongPathToRoot(node, root);
246 int num_nodes_kept = 0;
247 for (
const int node : *nodes) {
251 (*nodes)[num_nodes_kept++] = node;
254 nodes->resize(num_nodes_kept);
258 for (
const int node : *nodes) tmp_part_bit_[
GetRoot(node)] =
false;
262 std::vector<int>* node_equivalence_classes) {
263 node_equivalence_classes->assign(
NumNodes(), -1);
265 for (
int node = 0; node <
NumNodes(); ++node) {
267 if ((*node_equivalence_classes)[root] < 0) {
268 (*node_equivalence_classes)[root] = num_roots;
271 (*node_equivalence_classes)[node] = (*node_equivalence_classes)[root];
277 std::vector<std::vector<int>> sorted_parts(
NumNodes());
278 for (
int i = 0; i <
NumNodes(); ++i) {
281 for (std::vector<int>& part : sorted_parts)
282 std::sort(part.begin(), part.end());
283 std::sort(sorted_parts.begin(), sorted_parts.end());
287 for (
const std::vector<int>& part : sorted_parts) {
288 if (!out.empty()) out +=
" | ";
289 out += absl::StrJoin(part,
" ");