-
Notifications
You must be signed in to change notification settings - Fork 20
/
grouper.py
executable file
·103 lines (83 loc) · 2.52 KB
/
grouper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
Disjoint set data structure <http://code.activestate.com/recipes/387776/>
Author: Michael Droettboom
"""
class Grouper(object):
    """
    This class provides a lightweight way to group arbitrary objects
    together into disjoint sets when a full-blown graph data structure
    would be overkill.

    Objects can be joined using .join(), tested for connectedness
    using .joined(), and all disjoint sets can be retrieved using list(g).
    The objects being joined must be hashable.

    Internally, every known object maps to the list of members of its set,
    and all members of one set share the very same list object. Two objects
    are therefore in the same set exactly when they map to the same list.

    >>> g = Grouper()
    >>> g.join('a', 'b')
    >>> g.join('b', 'c')
    >>> g.join('d', 'e')
    >>> list(g)
    [['a', 'b', 'c'], ['d', 'e']]
    >>> g.joined('a', 'b')
    True
    >>> g.joined('a', 'c')
    True
    >>> 'f' in g
    False
    >>> g.joined('a', 'd')
    False
    >>> len(g)
    2
    >>> g['c']
    ('a', 'b', 'c')
    """

    def __init__(self, init=()):
        """
        Create a grouper, optionally seeded with objects.

        :param init: iterable of hashable objects; each one starts out in its
            own singleton set. Defaults to an empty tuple (an immutable default
            avoids the shared-mutable-default-argument pitfall).
        """
        mapping = self._mapping = {}
        for x in init:
            mapping[x] = [x]

    def join(self, a, *args):
        """
        Join the given arguments into the same set.

        Accepts one or more hashable arguments. Objects not seen before are
        added; any existing sets containing one of the arguments are merged.
        """
        mapping = self._mapping
        set_a = mapping.setdefault(a, [a])
        for arg in args:
            set_b = mapping.get(arg)
            if set_b is None:
                # Previously unseen object: add it to a's set.
                set_a.append(arg)
                mapping[arg] = set_a
            elif set_b is not set_a:
                # Merge the smaller set into the larger one so as few
                # mapping entries as possible have to be re-pointed.
                if len(set_b) > len(set_a):
                    set_a, set_b = set_b, set_a
                set_a.extend(set_b)
                for elem in set_b:
                    mapping[elem] = set_a

    def joined(self, a, b):
        """
        Return True if a and b are members of the same set.

        Returns False if either object has never been joined.
        """
        mapping = self._mapping
        try:
            return mapping[a] is mapping[b]
        except KeyError:
            return False

    def __iter__(self):
        """
        Iterate over the disjoint sets, yielding each one as a new list.

        Copies are yielded so that callers cannot corrupt the internal
        bookkeeping by mutating the returned lists.
        """
        seen = set()
        # .items() works on both Python 2 and 3 (.iteritems() is 2-only).
        for elem, group in self._mapping.items():
            if elem not in seen:
                seen.update(group)
                yield list(group)

    def __getitem__(self, key):
        """
        Return the members of the set that *key* belongs to, as a tuple.

        :raises KeyError: if *key* has never been joined.
        """
        return tuple(self._mapping[key])

    def __contains__(self, key):
        """Return True if *key* has been added to any set."""
        return key in self._mapping

    def __len__(self):
        """Return the number of disjoint sets."""
        # All members of a set share one list object, so counting distinct
        # list identities counts the sets in O(n). Every list is kept alive
        # by the mapping while this runs, so the ids are unique.
        return len({id(group) for group in self._mapping.values()})
if __name__ == '__main__':
    import doctest
    # Run the module's doctests and propagate failures to the exit status
    # so shells/CI can detect them (testmod returns (failed, attempted)).
    failures, _ = doctest.testmod()
    raise SystemExit(1 if failures else 0)