Last active
February 1, 2019 01:19
-
-
Save shidenggui/933213ff4d1abfa142923ed766544112 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
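"""
There is a list whose elements are sets, structured like this:

    [
        {1, 2, 3},
        {2, 3},
        {3, 5},
        {10, 15},
    ]

Definition: if a set shares at least one element with another set, those
sets belong to the same group. Grouping is transitive: two sets that share
nothing directly still end up in one group when a chain of overlapping sets
connects them (e.g. {1, 2, 3} and {6, 7} are joined by {3, 6}).

For the input:

    [
        {1, 2, 3},
        {2, 3},
        {3, 5},
        {10, 15},
    ]

the output should be:

    [
        {1, 2, 3, 5},
        {10, 15},
    ]

I ran into this problem while writing a "merge groups" operation. It feels a
bit like the "Friend Circles" problem on LeetCode, but that one takes a
matrix while mine is a nested structure. I would appreciate a hint or a
pointer to a similar algorithm so I can borrow the approach -- no need to
give me the answer directly, thanks!
"""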
class DisjointSet:
    """Union-find (disjoint-set forest) over the integers ``0 .. size - 1``.

    ``arr[i]`` is the parent of ``i`` when ``i`` is not a root.  For a root it
    is negative and holds ``-(rank + 1)``: a fresh singleton is ``-1`` and the
    value drops by one each time the root's rank grows.  ``find`` compresses
    paths and ``union`` joins by rank.
    """

    def __init__(self, size):
        # Every element starts as its own root of rank 0 (encoded as -1).
        self.arr = [-1 for _ in range(size)]

    def find(self, s1):
        """Return the root of ``s1``'s set.

        Every node visited on the way is re-pointed directly at the root
        (full path compression), so later lookups are shorter.
        """
        # First pass: walk parent links until a negative (root) entry.
        root = s1
        while self.arr[root] >= 0:
            root = self.arr[root]
        # Second pass: re-point every node on the path straight at the root.
        node = s1
        while node != root:
            parent = self.arr[node]
            self.arr[node] = root
            node = parent
        return root

    def union(self, s1, s2):
        """Join the sets containing ``s1`` and ``s2`` and return the new root.

        The root of higher rank (more negative ``arr`` entry) absorbs the
        other.  On a tie, the root of ``s2`` absorbs the root of ``s1`` and
        its rank grows by one.  If both are already in one set, that set's
        root is returned unchanged.
        """
        first, second = self.find(s1), self.find(s2)
        if first == second:
            return first
        if self.arr[first] < self.arr[second]:
            winner, loser = first, second
        else:
            winner, loser = second, first
            if self.arr[first] == self.arr[second]:
                # Equal ranks: the surviving root's tree gets one level taller.
                self.arr[second] -= 1
        self.arr[loser] = winner
        return winner
from collections import defaultdict | |
class Solution:
    """Merge sets that share elements into groups.

    Two input sets belong to the same group when they have an element in
    common, either directly or through a chain of other overlapping sets.
    """

    # Retained for backward compatibility: callers may read or assign it.
    # ``merge`` sizes its union-find from the input itself, so this value no
    # longer restricts which elements are accepted.
    MAX_SIZE = 16

    def merge(self, lists):
        """Return the groups formed by merging overlapping sets in *lists*.

        Args:
            lists: a re-iterable collection of sets (or other re-iterable
                collections) of hashable elements.  Empty sets share nothing
                with any other set and are skipped.

        Returns:
            A list of sets, one per group, ordered by the first input set
            that belongs to each group.
        """
        non_empty = [s for s in lists if s]

        # Give each distinct element a dense index so DisjointSet can be sized
        # exactly, whatever the element values (large, negative, non-int).
        index = {}
        for s in non_empty:
            for elem in s:
                index.setdefault(elem, len(index))

        disjoint_set = DisjointSet(len(index))
        # Union every element of a set with that set's first element (its
        # "anchor"), remembering the anchor even for single-element sets.
        anchors = []
        for s in non_empty:
            elems = iter(s)
            anchor = index[next(elems)]
            for elem in elems:
                disjoint_set.union(anchor, index[elem])
            anchors.append(anchor)

        # Bucket each set under its final root; dict order keeps groups in
        # order of first appearance.
        groups = defaultdict(set)
        for anchor, s in zip(anchors, non_empty):
            groups[disjoint_set.find(anchor)].update(s)
        return list(groups.values())
def test():
    """Smoke-test ``Solution.merge`` on the two worked examples.

    Each result is printed and also asserted against the expected grouping,
    so a regression fails loudly instead of relying on someone reading the
    output.
    """

    def normalized(groups):
        # Group order is not what is being checked; compare order-insensitively.
        return sorted(sorted(group) for group in groups)

    cases = [
        (
            [{1, 2, 3}, {2, 3}, {3, 5}, {10, 15}],
            [{1, 2, 3, 5}, {10, 15}],
        ),
        (
            # {6, 7} only joins {1, 2, 3} through {3, 6}: grouping is transitive.
            [{1, 2, 3}, {6, 7}, {3, 6}],
            [{1, 2, 3, 6, 7}],
        ),
    ]
    for sets, expected in cases:
        result = Solution().merge(sets)
        print(result)
        assert normalized(result) == normalized(expected), (result, expected)
# Run the example checks only when executed as a script, not on import.
if __name__ == "__main__":
    test()
def test_performance():
    """Time ``Solution.merge`` on random inputs of doubling size.

    Each of 6 rounds builds ``100 * 2 ** i`` random sets, each made from 1000
    random ints in ``[0, 1000)``.  It then prints:

    - the row count;
    - the total element count N;
    - the elapsed time in milliseconds;
    - the time per element in nanoseconds.

    The last column stays roughly constant if ``merge`` scales linearly in N.
    """
    import time

    import numpy as np

    max_size = 1000
    samples_per_row = 1000

    def generate_test(n):
        # Duplicates collapse inside each set, so rows usually hold fewer
        # than `samples_per_row` elements; return the true total as well.
        matrix = np.random.randint(max_size, size=(n, samples_per_row)).tolist()
        sets = [set(row) for row in matrix]
        return sets, sum(len(s) for s in sets)

    print('union-find:')
    solution = Solution()
    # Make MAX_SIZE cover every generated value, in case Solution sizes its
    # union-find from it.
    solution.MAX_SIZE = max_size
    for i in range(6):
        rows = 100 * 2 ** i
        test_data, number = generate_test(rows)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        solution.merge(test_data)
        elapsed = time.perf_counter() - start
        usage = int(elapsed * 1000)
        # Derive ns/element from the unrounded time so fast rounds don't
        # collapse to 0 after truncating to whole milliseconds.
        per_element = int(elapsed / number * 1e9)
        print('K rows: {} N: {} time: {} T(N)/O(N): {}'.format(rows, number, usage, per_element))
# Benchmark only when executed as a script: it needs numpy and takes a while.
if __name__ == "__main__":
    test_performance()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment