-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
linked_hash_map.h
636 lines (541 loc) · 21.7 KB
/
linked_hash_map.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
// Copyright 2010-2024 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This is a simplistic insertion-ordered map. It behaves similarly to an STL
// map, but only implements a small subset of the map's methods. Internally, we
// just keep a map and a list going in parallel.
//
// This class provides no thread safety guarantees, beyond what you would
// normally see with std::list.
//
// Iterators point into the list and should be stable in the face of
// mutations, except for an iterator pointing to an element that was just
// deleted.
//
// This class supports heterogeneous lookups.
//
#ifndef OR_TOOLS_BASE_LINKED_HASH_MAP_H_
#define OR_TOOLS_BASE_LINKED_HASH_MAP_H_
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <list>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/container/internal/common.h"
#include "ortools/base/logging.h"
namespace gtl {
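// This holds a std::list of pair<const Key, Value> items in insertion order.
// This list is what gets traversed, and it is iterators into this list that
// we return from begin/end/find.
//
// We also keep an absl::flat_hash_set of list iterators, hashed and compared
// by the key each iterator points at, for fast lookup. Because std::list is
// node-based, its iterators stay valid when other elements are inserted or
// erased, so the set's entries never need to be refreshed.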
template <typename Key, typename Value,
          typename KeyHash = typename absl::flat_hash_set<Key>::hasher,
          typename KeyEq =
              typename absl::flat_hash_set<Key, KeyHash>::key_equal,
          typename Alloc = std::allocator<std::pair<const Key, Value>>>
class linked_hash_map {
  // Every specialization befriends every other one so that `merge()` can move
  // nodes out of a source map whose hasher / key_equal types differ from ours.
  template <typename, typename, typename, typename, typename>
  friend class linked_hash_map;

  using KeyArgImpl = absl::container_internal::KeyArg<
      absl::container_internal::IsTransparent<KeyEq>::value &&
      absl::container_internal::IsTransparent<KeyHash>::value>;
  // Alias used for heterogeneous lookup functions.
  // `key_arg<K>` evaluates to `K` when the functors are transparent and to
  // `key_type` otherwise. It permits template argument deduction on `K` for the
  // transparent case.
  template <class K>
  using key_arg = typename KeyArgImpl::template type<K, Key>;

 public:
  using key_type = Key;
  using mapped_type = Value;
  using hasher = KeyHash;
  using key_equal = KeyEq;
  using value_type = std::pair<const key_type, mapped_type>;
  using allocator_type = Alloc;
  using difference_type = std::ptrdiff_t;

 private:
  // Storage for the elements, in insertion order.
  using ListType = std::list<value_type, Alloc>;

  // Adapts a key functor `Fn` (the hasher or the key_equal) so that each
  // argument may be either a key or a list iterator; iterators are replaced by
  // the key they point at before `fn_` is invoked. It is declared transparent
  // so the set of list iterators can be probed directly with a key.
  template <class Fn>
  class Wrapped {
    template <typename K>
    static const K& ToKey(const K& k) {
      return k;
    }
    static const key_type& ToKey(typename ListType::const_iterator it) {
      return it->first;
    }
    static const key_type& ToKey(typename ListType::iterator it) {
      return it->first;
    }

    Fn fn_;

    friend linked_hash_map;

   public:
    using is_transparent = void;

    Wrapped() = default;
    explicit Wrapped(Fn fn) : fn_(std::move(fn)) {}

    template <class... Args>
    auto operator()(Args&&... args) const
        -> decltype(this->fn_(ToKey(args)...)) {
      return fn_(ToKey(args)...);
    }
  };

  // Index over the list: one entry per element, hashed/compared by key.
  using SetType =
      absl::flat_hash_set<typename ListType::iterator, Wrapped<hasher>,
                          Wrapped<key_equal>, Alloc>;

  // Node handle produced by `extract()` and consumed by
  // `insert(node_type&&)`. It owns at most one element, held in a
  // single-element list so the node can be spliced between lists without
  // copying the element.
  class NodeHandle {
   public:
    using key_type = linked_hash_map::key_type;
    using mapped_type = linked_hash_map::mapped_type;
    using allocator_type = linked_hash_map::allocator_type;

    constexpr NodeHandle() noexcept = default;
    NodeHandle(NodeHandle&& nh) noexcept = default;
    ~NodeHandle() = default;
    NodeHandle& operator=(NodeHandle&& node) noexcept = default;

    bool empty() const noexcept { return list_.empty(); }
    explicit operator bool() const noexcept { return !empty(); }
    allocator_type get_allocator() const { return list_.get_allocator(); }
    // key() and mapped() require a non-empty handle.
    const key_type& key() const { return list_.front().first; }
    mapped_type& mapped() { return list_.front().second; }
    void swap(NodeHandle& nh) noexcept { list_.swap(nh.list_); }

   private:
    friend linked_hash_map;

    explicit NodeHandle(ListType list) : list_(std::move(list)) {}

    ListType list_;
  };

  // Result of `insert(node_type&&)`.
  template <class Iterator, class NodeType>
  struct InsertReturnType {
    Iterator position;
    bool inserted;
    NodeType node;
  };

 public:
  using iterator = typename ListType::iterator;
  using const_iterator = typename ListType::const_iterator;
  using reverse_iterator = typename ListType::reverse_iterator;
  using const_reverse_iterator = typename ListType::const_reverse_iterator;
  using reference = typename ListType::reference;
  using const_reference = typename ListType::const_reference;
  using size_type = typename ListType::size_type;
  using pointer = typename std::allocator_traits<allocator_type>::pointer;
  using const_pointer =
      typename std::allocator_traits<allocator_type>::const_pointer;
  using node_type = NodeHandle;
  using insert_return_type = InsertReturnType<iterator, node_type>;

  // --- Construction -------------------------------------------------------

  linked_hash_map() = default;

  // `bucket_count` only sizes the hash index; the list has no buckets.
  explicit linked_hash_map(size_t bucket_count, const hasher& hash = hasher(),
                           const key_equal& eq = key_equal(),
                           const allocator_type& alloc = allocator_type())
      : set_(bucket_count, Wrapped<hasher>(hash), Wrapped<key_equal>(eq),
             alloc),
        list_(alloc) {}
  linked_hash_map(size_t bucket_count, const hasher& hash,
                  const allocator_type& alloc)
      : linked_hash_map(bucket_count, hash, key_equal(), alloc) {}
  linked_hash_map(size_t bucket_count, const allocator_type& alloc)
      : linked_hash_map(bucket_count, hasher(), key_equal(), alloc) {}
  explicit linked_hash_map(const allocator_type& alloc)
      : linked_hash_map(0, hasher(), key_equal(), alloc) {}

  // Range constructors: elements are inserted in iteration order; for
  // duplicate keys the first occurrence wins.
  template <class InputIt>
  linked_hash_map(InputIt first, InputIt last, size_t bucket_count = 0,
                  const hasher& hash = hasher(),
                  const key_equal& eq = key_equal(),
                  const allocator_type& alloc = allocator_type())
      : linked_hash_map(bucket_count, hash, eq, alloc) {
    insert(first, last);
  }
  template <class InputIt>
  linked_hash_map(InputIt first, InputIt last, size_t bucket_count,
                  const hasher& hash, const allocator_type& alloc)
      : linked_hash_map(first, last, bucket_count, hash, key_equal(), alloc) {}
  template <class InputIt>
  linked_hash_map(InputIt first, InputIt last, size_t bucket_count,
                  const allocator_type& alloc)
      : linked_hash_map(first, last, bucket_count, hasher(), key_equal(),
                        alloc) {}
  template <class InputIt>
  linked_hash_map(InputIt first, InputIt last, const allocator_type& alloc)
      : linked_hash_map(first, last, /*bucket_count=*/0, hasher(), key_equal(),
                        alloc) {}

  linked_hash_map(std::initializer_list<value_type> init,
                  size_t bucket_count = 0, const hasher& hash = hasher(),
                  const key_equal& eq = key_equal(),
                  const allocator_type& alloc = allocator_type())
      : linked_hash_map(init.begin(), init.end(), bucket_count, hash, eq,
                        alloc) {}
  linked_hash_map(std::initializer_list<value_type> init, size_t bucket_count,
                  const hasher& hash, const allocator_type& alloc)
      : linked_hash_map(init, bucket_count, hash, key_equal(), alloc) {}
  linked_hash_map(std::initializer_list<value_type> init, size_t bucket_count,
                  const allocator_type& alloc)
      : linked_hash_map(init, bucket_count, hasher(), key_equal(), alloc) {}
  linked_hash_map(std::initializer_list<value_type> init,
                  const allocator_type& alloc)
      : linked_hash_map(init, /*bucket_count=*/0, hasher(), key_equal(),
                        alloc) {}

  linked_hash_map(const linked_hash_map& other)
      : linked_hash_map(other.bucket_count(), other.hash_function(),
                        other.key_eq(), other.get_allocator()) {
    CopyFrom(other);
  }

  linked_hash_map(const linked_hash_map& other, const allocator_type& alloc)
      : linked_hash_map(other.bucket_count(), other.hash_function(),
                        other.key_eq(), alloc) {
    CopyFrom(other);
  }

  linked_hash_map(linked_hash_map&& other) noexcept
      : set_(std::move(other.set_)), list_(std::move(other.list_)) {
    // Since the list and set must agree for other to end up "valid",
    // explicitly clear them.
    other.set_.clear();
    other.list_.clear();
  }

  // With an equal allocator the nodes are stolen; otherwise each element is
  // re-created with `alloc` (mapped values moved, keys copied since const).
  linked_hash_map(linked_hash_map&& other, const allocator_type& alloc)
      : linked_hash_map(0, other.hash_function(), other.key_eq(), alloc) {
    if (get_allocator() == other.get_allocator()) {
      *this = std::move(other);
    } else {
      CopyFrom(std::move(other));
    }
  }

  // --- Assignment ---------------------------------------------------------

  // Copy assignment adopts other's hasher, key_equal and allocator.
  linked_hash_map& operator=(const linked_hash_map& other) {
    if (this == &other) return *this;
    // Make a new set, with other's hash/eq/alloc.
    set_ = SetType(other.bucket_count(), other.set_.hash_function(),
                   other.set_.key_eq(), other.get_allocator());
    // Copy the list, with other's allocator.
    list_ = ListType(other.get_allocator());
    CopyFrom(other);
    return *this;
  }

  linked_hash_map& operator=(linked_hash_map&& other) noexcept {
    // The underlying containers handle the
    // propagate_on_container_move_assignment details.
    set_ = std::move(other.set_);
    list_ = std::move(other.list_);
    other.set_.clear();
    other.list_.clear();
    return *this;
  }

  // Replaces the contents with `values`; first occurrence of a key wins.
  linked_hash_map& operator=(std::initializer_list<value_type> values) {
    clear();
    insert(values.begin(), values.end());
    return *this;
  }

  // --- Capacity -----------------------------------------------------------

  // set_ and list_ always hold the same elements; the size is reported from
  // the set.
  size_type size() const { return set_.size(); }
  size_type max_size() const noexcept { return ~size_type{}; }
  bool empty() const { return set_.empty(); }

  // --- Iteration (insertion order, forwarded to the list) -----------------

  iterator begin() { return list_.begin(); }
  iterator end() { return list_.end(); }
  const_iterator begin() const { return list_.begin(); }
  const_iterator end() const { return list_.end(); }
  const_iterator cbegin() const { return list_.cbegin(); }
  const_iterator cend() const { return list_.cend(); }
  reverse_iterator rbegin() { return list_.rbegin(); }
  reverse_iterator rend() { return list_.rend(); }
  const_reverse_iterator rbegin() const { return list_.rbegin(); }
  const_reverse_iterator rend() const { return list_.rend(); }
  const_reverse_iterator crbegin() const { return list_.crbegin(); }
  const_reverse_iterator crend() const { return list_.crend(); }

  // Oldest / newest element. The map must be non-empty.
  reference front() { return list_.front(); }
  reference back() { return list_.back(); }
  const_reference front() const { return list_.front(); }
  const_reference back() const { return list_.back(); }

  // Remove the oldest / newest element. The map must be non-empty.
  void pop_front() { erase(begin()); }
  void pop_back() { erase(std::prev(end())); }

  ABSL_ATTRIBUTE_REINITIALIZES void clear() {
    set_.clear();
    list_.clear();
  }

  // --- Hash-index policy (forwarded to the set) ---------------------------

  void reserve(size_t n) { set_.reserve(n); }
  size_t capacity() const { return set_.capacity(); }
  size_t bucket_count() const { return set_.bucket_count(); }
  float load_factor() const { return set_.load_factor(); }
  hasher hash_function() const { return set_.hash_function().fn_; }
  key_equal key_eq() const { return set_.key_eq().fn_; }
  allocator_type get_allocator() const { return list_.get_allocator(); }

  // --- Erasure ------------------------------------------------------------

  // Removes the element with key `key`, if present. Returns the number of
  // elements removed (0 or 1).
  template <class K = key_type>
  size_type erase(const key_arg<K>& key) {
    auto found = set_.find(key);
    if (found == set_.end()) return 0;
    auto list_it = *found;
    // Erase set entry first since it refers to the list element.
    set_.erase(found);
    list_.erase(list_it);
    return 1;
  }

  // Removes the element at `position` (which must be dereferenceable) and
  // returns the iterator following it. CHECK-fails if `position` is not an
  // element of this map.
  iterator erase(const_iterator position) {
    auto found = set_.find(position);
    // Test for a miss before dereferencing: `*set_.end()` is undefined.
    CHECK(found != set_.end() && *found == position)
        << "Inconsistent iterator for set and list, "
           "or the iterator is invalid.";
    set_.erase(found);
    return list_.erase(position);
  }
  iterator erase(iterator position) {
    return erase(static_cast<const_iterator>(position));
  }

  // Removes the elements in [first, last) and returns `last`.
  iterator erase(iterator first, iterator last) {
    while (first != last) first = erase(first);
    return first;
  }
  iterator erase(const_iterator first, const_iterator last) {
    while (first != last) first = erase(first);
    // Erasing the empty range [last, last) is the O(1) standard idiom for
    // turning a const_iterator into an iterator, with no hash lookup.
    return list_.erase(last, last);
  }

  // --- Lookup -------------------------------------------------------------

  // Returns an iterator to the element with key `key`, or end() if absent.
  template <class K = key_type>
  iterator find(const key_arg<K>& key) {
    auto found = set_.find(key);
    if (found == set_.end()) return end();
    return *found;
  }
  template <class K = key_type>
  const_iterator find(const key_arg<K>& key) const {
    auto found = set_.find(key);
    if (found == set_.end()) return end();
    return *found;
  }

  template <class K = key_type>
  size_type count(const key_arg<K>& key) const {
    return contains(key) ? 1 : 0;
  }

  template <class K = key_type>
  bool contains(const key_arg<K>& key) const {
    return set_.contains(key);
  }

  // Returns the value mapped to `key`; LOG(FATAL)s if `key` is absent.
  template <class K = key_type>
  mapped_type& at(const key_arg<K>& key) {
    auto it = find(key);
    if (ABSL_PREDICT_FALSE(it == end())) {
      LOG(FATAL) << "linked_hash_map::at failed bounds check";
    }
    return it->second;
  }
  template <class K = key_type>
  const mapped_type& at(const key_arg<K>& key) const {
    // Uses the const find() directly instead of casting away constness.
    auto it = find(key);
    if (ABSL_PREDICT_FALSE(it == end())) {
      LOG(FATAL) << "linked_hash_map::at failed bounds check";
    }
    return it->second;
  }

  // Returns [it, next(it)) for the matching element, or [end(), end()).
  template <class K = key_type>
  std::pair<iterator, iterator> equal_range(const key_arg<K>& key) {
    auto iter = set_.find(key);
    if (iter == set_.end()) return {end(), end()};
    return {*iter, std::next(*iter)};
  }
  template <class K = key_type>
  std::pair<const_iterator, const_iterator> equal_range(
      const key_arg<K>& key) const {
    auto iter = set_.find(key);
    if (iter == set_.end()) return {end(), end()};
    return {*iter, std::next(*iter)};
  }

  // Returns the value for `key`, appending a value-initialized entry first
  // if `key` is absent.
  template <class K = key_type>
  mapped_type& operator[](const key_arg<K>& key) {
    return LazyEmplaceInternal(key).first->second;
  }
  template <class K = key_type, K* = nullptr>
  mapped_type& operator[](key_arg<K>&& key) {
    return LazyEmplaceInternal(std::forward<K>(key)).first->second;
  }

  // --- Insertion ----------------------------------------------------------
  // New elements are always appended at the end of the iteration order;
  // inserting an existing key leaves the map unchanged. Hint iterators are
  // accepted for interface compatibility and ignored.

  std::pair<iterator, bool> insert(const value_type& v) {
    return InsertInternal(v);
  }
  std::pair<iterator, bool> insert(value_type&& v) {  // NOLINT(build/c++11)
    return InsertInternal(std::move(v));
  }
  iterator insert(const_iterator, const value_type& v) {
    return insert(v).first;
  }
  iterator insert(const_iterator, value_type&& v) {
    return insert(std::move(v)).first;
  }
  void insert(std::initializer_list<value_type> ilist) {
    insert(ilist.begin(), ilist.end());
  }
  template <class InputIt>
  void insert(InputIt first, InputIt last) {
    for (; first != last; ++first) insert(*first);
  }

  // Inserts the element owned by `node`. An empty node yields
  // {end(), false, empty}; if the key already exists, the node is returned
  // untouched in `.node` and `.position` points at the existing element.
  insert_return_type insert(node_type&& node) {
    if (!node) return {end(), false, node_type()};
    auto itr = find(node.key());
    if (itr != end()) return {itr, false, std::move(node)};
    list_.splice(list_.end(), node.list_);
    set_.insert(--list_.end());
    return {--list_.end(), true, node_type()};
  }
  iterator insert(const_iterator, node_type&& node) {
    return insert(std::move(node)).first;
  }

  // Assigns to the existing value for `k`, or appends {k, v} if absent.
  // Returns {iterator, true-if-inserted}.
  //
  // The last two template parameters ensure that both arguments are rvalues
  // (lvalue arguments are handled by the overloads below). This is necessary
  // for supporting bitfield arguments.
  //
  //   union { int n : 1; };
  //   linked_hash_map<int, int> m;
  //   m.insert_or_assign(n, n);
  template <class K = key_type, class V = mapped_type, K* = nullptr,
            V* = nullptr>
  std::pair<iterator, bool> insert_or_assign(key_arg<K>&& k, V&& v) {
    return InsertOrAssignInternal(std::forward<K>(k), std::forward<V>(v));
  }
  template <class K = key_type, class V = mapped_type, K* = nullptr>
  std::pair<iterator, bool> insert_or_assign(key_arg<K>&& k, const V& v) {
    return InsertOrAssignInternal(std::forward<K>(k), v);
  }
  template <class K = key_type, class V = mapped_type, V* = nullptr>
  std::pair<iterator, bool> insert_or_assign(const key_arg<K>& k, V&& v) {
    return InsertOrAssignInternal(k, std::forward<V>(v));
  }
  template <class K = key_type, class V = mapped_type>
  std::pair<iterator, bool> insert_or_assign(const key_arg<K>& k, const V& v) {
    return InsertOrAssignInternal(k, v);
  }
  template <class K = key_type, class V = mapped_type, K* = nullptr,
            V* = nullptr>
  iterator insert_or_assign(const_iterator, key_arg<K>&& k, V&& v) {
    return insert_or_assign(std::forward<K>(k), std::forward<V>(v)).first;
  }
  template <class K = key_type, class V = mapped_type, K* = nullptr>
  iterator insert_or_assign(const_iterator, key_arg<K>&& k, const V& v) {
    return insert_or_assign(std::forward<K>(k), v).first;
  }
  template <class K = key_type, class V = mapped_type, V* = nullptr>
  iterator insert_or_assign(const_iterator, const key_arg<K>& k, V&& v) {
    return insert_or_assign(k, std::forward<V>(v)).first;
  }
  template <class K = key_type, class V = mapped_type>
  iterator insert_or_assign(const_iterator, const key_arg<K>& k, const V& v) {
    return insert_or_assign(k, v).first;
  }

  // Constructs a value_type from `args` and appends it unless its key is
  // already present. The element is built in a scratch list first because
  // its key is only known after construction; on a duplicate it is
  // destroyed, otherwise it is spliced (not copied) into place.
  template <typename... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    ListType node_donor;
    auto list_iter =
        node_donor.emplace(node_donor.end(), std::forward<Args>(args)...);
    auto ins = set_.insert(list_iter);
    if (!ins.second) return {*ins.first, false};
    list_.splice(list_.end(), node_donor, list_iter);
    return {list_iter, true};
  }

  template <class K = key_type, class... Args, K* = nullptr>
  iterator try_emplace(const_iterator, key_arg<K>&& k, Args&&... args) {
    return try_emplace(std::forward<K>(k), std::forward<Args>(args)...).first;
  }

  template <typename... Args>
  iterator emplace_hint(const_iterator, Args&&... args) {
    return emplace(std::forward<Args>(args)...).first;
  }

  // Appends {key, mapped_type(args...)} only if `key` is absent; `args` are
  // not consumed when the key already exists.
  template <class K = key_type, typename... Args, K* = nullptr>
  std::pair<iterator, bool> try_emplace(key_arg<K>&& key, Args&&... args) {
    return LazyEmplaceInternal(std::forward<key_arg<K>>(key),
                               std::forward<Args>(args)...);
  }

  // Moves every element of `src` whose key is not already present here to
  // the end of this map, preserving `src`'s relative order; elements with
  // duplicate keys stay in `src`. Nodes are spliced rather than copied, so
  // (as with std::list::splice) both maps must use equal allocators.
  // `src` may use different hasher / key_equal types.
  template <typename H, typename E>
  void merge(linked_hash_map<Key, Value, H, E, Alloc>& src) {
    auto itr = src.list_.begin();
    while (itr != src.list_.end()) {
      if (contains(itr->first)) {
        ++itr;
        continue;
      }
      auto next = std::next(itr);
      // Drop src's index entry before the node leaves src's list.
      src.set_.erase(itr->first);
      list_.splice(list_.end(), src.list_, itr);
      set_.insert(std::prev(list_.end()));
      itr = next;
    }
  }
  template <typename H, typename E>
  void merge(linked_hash_map<Key, Value, H, E, Alloc>&& src) {
    merge(src);
  }

  // Unlinks the element at `position` (must be dereferenceable) and returns
  // it as a node handle without copying it.
  node_type extract(const_iterator position) {
    set_.erase(position->first);
    ListType extracted_node_list;
    extracted_node_list.splice(extracted_node_list.end(), list_, position);
    return node_type(std::move(extracted_node_list));
  }

  // Extracts the element with key `key`, or returns an empty node if absent.
  template <class K = key_type,
            std::enable_if_t<!std::is_same_v<K, iterator>, int> = 0>
  node_type extract(const key_arg<K>& key) {
    auto it = find(key);
    return it == end() ? node_type() : extract(const_iterator{it});
  }

  template <class K = key_type, typename... Args>
  std::pair<iterator, bool> try_emplace(const key_arg<K>& key, Args&&... args) {
    return LazyEmplaceInternal(key, std::forward<Args>(args)...);
  }

  void swap(linked_hash_map& other) {
    using std::swap;
    swap(set_, other.set_);
    swap(list_, other.list_);
  }

  // ADL-visible non-member swap so `using std::swap; swap(a, b);` forwards to
  // the member swap instead of the generic move-based std::swap.
  friend void swap(linked_hash_map& a, linked_hash_map& b) { a.swap(b); }

  // Equality ignores insertion order: the maps are equal when they have the
  // same size and every key of one maps to an equal (`==`) value in the
  // other. The map with the smaller capacity is iterated, the other probed.
  friend bool operator==(const linked_hash_map& a, const linked_hash_map& b) {
    if (a.size() != b.size()) return false;
    const linked_hash_map* outer = &a;
    const linked_hash_map* inner = &b;
    if (outer->capacity() > inner->capacity()) std::swap(outer, inner);
    for (const value_type& elem : *outer) {
      auto it = inner->find(elem.first);
      if (it == inner->end()) return false;
      if (it->second != elem.second) return false;
    }
    return true;
  }

  friend bool operator!=(const linked_hash_map& a, const linked_hash_map& b) {
    return !(a == b);
  }

  void rehash(size_t n) { set_.rehash(n); }

 private:
  // Appends every element of `other` in order. When `other` is an rvalue the
  // mapped values are moved (keys are const, so they are copied); when it is
  // a const lvalue, elements are copied.
  template <typename Other>
  void CopyFrom(Other&& other) {
    for (auto& elem : other.list_) {
      set_.insert(list_.insert(list_.end(), std::move(elem)));
    }
    DCHECK_EQ(set_.size(), list_.size()) << "Set and list are inconsistent.";
  }

  // Appends `pair` unless its key already exists. Returns the element's
  // iterator and whether it was inserted.
  template <typename U>
  std::pair<iterator, bool> InsertInternal(U&& pair) {  // NOLINT(build/c++11)
    auto iter = set_.find(pair.first);
    if (iter != set_.end()) return {*iter, false};
    auto list_iter = list_.insert(list_.end(), std::forward<U>(pair));
    auto inserted = set_.insert(list_iter);
    DCHECK(inserted.second);
    return {list_iter, true};
  }

  // Assigns `v` to an existing entry for `k`, otherwise appends {k, v}.
  template <class K, class V>
  std::pair<iterator, bool> InsertOrAssignInternal(K&& k, V&& v) {
    auto iter = set_.find(k);
    if (iter != set_.end()) {
      (*iter)->second = std::forward<V>(v);
      return {*iter, false};
    }
    return LazyEmplaceInternal(std::forward<K>(k), std::forward<V>(v));
  }

  // Single-probe find-or-append: if `key` is absent, constructs
  // {key, mapped_type(args...)} piecewise at the end of the list and indexes
  // it. `key` and `args` are only forwarded when the element is created.
  template <typename K, typename... Args>
  std::pair<iterator, bool> LazyEmplaceInternal(K&& key, Args&&... args) {
    bool constructed = false;
    auto set_iter =
        set_.lazy_emplace(key, [this, &constructed, &key, &args...](auto ctor) {
          auto list_iter =
              list_.emplace(list_.end(), std::piecewise_construct,
                            std::forward_as_tuple(std::forward<K>(key)),
                            std::forward_as_tuple(std::forward<Args>(args)...));
          constructed = true;
          ctor(list_iter);
        });
    return {*set_iter, constructed};
  }

  // The set component, used for speedy lookups.
  SetType set_;
  // The list component, used for maintaining insertion order.
  ListType list_;
};
} // namespace gtl
#endif // OR_TOOLS_BASE_LINKED_HASH_MAP_H_