# Gist by @shidenggui (gist id 933213ff4d1abfa142923ed766544112),
# last active February 1, 2019.
"""
有个 list 内嵌了 set, 结构如下:
[
{1,2,3},
{2,3},
{3,5},
{10,15},
]
定义:如果一个 set 和另外的 set 有重复的元素,则表示这几个 set 是同一 group;该关系具有传递性,例如 {1,2,3} 与 {3,6} 相交、{3,6} 与 {6,7} 相交,则三者属于同一 group。
对于输入:
[
{1,2,3},
{2,3},
{3,5},
{10,15},
]
输出应该是:
[
{1,2,3,5},
{10,15},
]
写合并 group 的操作时遇到了这个问题,感觉和 LeetCode 上的 Friend Circles(朋友圈)有点像,但它的输入是矩阵,我的是嵌套的结构。求助大佬指点一二,或者提醒一下类似的算法,我去借鉴一下解决思路,不用直接给我答案,谢谢~
"""
class DisjointSet:
    """Union-find (disjoint-set forest) over the integers ``0 .. size-1``.

    ``arr[i]`` holds the parent of ``i`` when ``i`` is not a root. For a root
    it holds a negative value ``-(rank + 1)``: every element starts as a
    singleton root storing ``-1``, and a more negative value means a higher
    rank. Uses union by rank and path compression.
    """

    def __init__(self, size):
        self.arr = [-1] * size

    def find(self, s1):
        """Return the root of ``s1``, re-pointing its whole path at that root."""
        root = s1
        while self.arr[root] >= 0:
            root = self.arr[root]
        # Second pass: hang every node on the walked path directly off the root.
        node = s1
        while self.arr[node] >= 0:
            parent = self.arr[node]
            self.arr[node] = root
            node = parent
        return root

    def union(self, s1, s2):
        """Merge the sets holding ``s1`` and ``s2`` and return the resulting root.

        The lower-rank root is attached under the higher-rank one. On a tie,
        the root of ``s2`` becomes the parent and its rank grows by one.
        """
        root1, root2 = self.find(s1), self.find(s2)
        if root1 != root2:
            if self.arr[root1] < self.arr[root2]:
                # root1 outranks root2: swap so root2 is always the new parent.
                root1, root2 = root2, root1
            elif self.arr[root1] == self.arr[root2]:
                self.arr[root2] -= 1
            self.arr[root1] = root2
        return root2
from collections import defaultdict
class Solution:
    """Merge sets that share elements, directly or transitively, into groups."""

    # Retained for backward compatibility: callers may still assign it, but
    # merge() sizes its union-find from the input, so it no longer limits
    # which element values are accepted.
    MAX_SIZE = 16

    def merge(self, lists):
        """Group the sets in ``lists`` into connected components.

        Two sets are in the same group when they share an element. The
        relation is transitive, so ``{1, 2}``, ``{2, 3}`` and ``{3, 4}`` form
        one group.

        Args:
            lists: a sequence of sets. Elements may be any hashable values.
                Each set is iterated more than once.

        Returns:
            A list of sets, one per group, each the union of its member sets.
            Groups appear in the order of their first member set in ``lists``.
            Empty input sets share nothing with anyone and contribute nothing.
        """
        # Give every distinct element a dense index. The union-find can then
        # be sized to the input and accepts any hashable, not just small
        # non-negative ints.
        index = {}
        for items in lists:
            for elem in items:
                index.setdefault(elem, len(index))

        disjoint_set = DisjointSet(len(index))
        # For each set, remember one representative element index (None for an
        # empty set). Single-element sets still get a valid anchor.
        anchors = []
        for items in lists:
            ids = [index[elem] for elem in items]
            for other in ids[1:]:
                disjoint_set.union(ids[0], other)
            anchors.append(ids[0] if ids else None)

        groups = defaultdict(set)
        for anchor, items in zip(anchors, lists):
            if anchor is None:
                continue
            groups[disjoint_set.find(anchor)].update(items)
        return list(groups.values())
def test():
    """Print merge() results for two small hand-checked inputs.

    Expected output:
        [{1, 2, 3, 5}, {10, 15}]
        [{1, 2, 3, 6, 7}]
    """
    cases = (
        [{1, 2, 3}, {2, 3}, {3, 5}, {10, 15}],
        # {1, 2, 3} and {6, 7} are disjoint but get joined through {3, 6}.
        [{1, 2, 3}, {6, 7}, {3, 6}],
    )
    for case in cases:
        print(Solution().merge(case))


test()
def test_performance():
    """Benchmark Solution.merge on random inputs of doubling size.

    Prints one line per run:
    - the number of rows (sets);
    - the total element count N across all sets;
    - the elapsed time in whole milliseconds;
    - that time scaled per element (milliseconds * 1e6 / N, i.e. nanoseconds
      per element), so roughly linear scaling shows up as a flat last column.
    """
    import numpy as np
    import time

    max_size = 1000  # element values are drawn from range(max_size)
    row_width = 1000  # values drawn per row before de-duplication into a set

    def generate_test(n):
        """Return ``n`` random sets and the total number of elements in them."""
        matrix = np.random.randint(max_size, size=(n, row_width)).tolist()
        sets = [set(row) for row in matrix]
        return sets, sum(len(s) for s in sets)

    print('union-find:')
    solution = Solution()
    # Let Solution accept element values up to max_size - 1 in case it sizes
    # its union-find from MAX_SIZE.
    solution.MAX_SIZE = max_size
    for i in range(6):
        rows = 100 * 2 ** i
        test_data, number = generate_test(rows)
        # perf_counter is monotonic and high-resolution. time.time() is the
        # wall clock and can jump.
        start_time = time.perf_counter()
        solution.merge(test_data)
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        # Derive the per-element figure from the unrounded time, not the
        # already-truncated milliseconds.
        print('K rows: {} N: {} time: {} T(N)/O(N): {}'.format(
            rows, number, int(elapsed_ms), int(elapsed_ms / number * 1e6)))


test_performance()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment