- バブルソート
- クイックソート
- マージソート
- ヒープソート
- 選択ソート
- 挿入ソート
Save hkawabata/12061820cfef20172e8a7549464995de to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class BubbleSort:
    def sort(self, array):
        """
        Return a bubble-sorted copy of ``array``.

        array : numpy array to sort (left unmodified)
        """
        result = array.copy()
        # Each pass carries the largest remaining value to the end of the
        # unsorted region, so every following pass can stop one slot earlier.
        for last in range(len(result) - 2, -1, -1):
            for k in range(last + 1):
                if result[k] > result[k + 1]:
                    result[k], result[k + 1] = result[k + 1], result[k]
        return result


BubbleSort().sort(np.array([5, 2, 10, 7, 4, 3, 8, 6, 1, 9]))
# array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class QuickSort:
    def sort(self, array):
        """
        Return a quicksorted copy of ``array``.

        array : numpy array to sort (left unmodified)

        Each segment is partitioned around its last element into
        ``< pivot | pivot and elements equal/unordered to it | > pivot``.
        Segments are kept on an explicit stack instead of being recursed into.
        With a last-element pivot, already-sorted input makes the partition
        depth equal to the array length, which used to raise RecursionError.
        Such input is still O(n^2) in time.
        """
        result = np.array(array)  # independent copy; the caller's array is never returned
        pending = [(0, len(result))]  # half-open segments [lo, hi) still to partition
        while pending:
            lo, hi = pending.pop()
            if hi - lo < 2:
                continue
            pivot = result[hi - 1]
            smaller, same, larger = [], [], []
            for value in result[lo:hi - 1]:
                if value < pivot:
                    smaller.append(value)
                elif pivot < value:
                    larger.append(value)
                else:
                    # Equal to the pivot, or unordered with it (e.g. NaN): keep the
                    # actual element so nothing is lost or duplicated.
                    same.append(value)
            # The list always contains the pivot, so it is never empty.
            result[lo:hi] = smaller + [pivot] + same + larger
            mid_lo = lo + len(smaller)
            mid_hi = mid_lo + 1 + len(same)
            left, right = (lo, mid_lo), (mid_hi, hi)
            # Push the bigger part first so the smaller one is processed next,
            # keeping the pending stack short.
            if mid_lo - lo > hi - mid_hi:
                pending.extend([left, right])
            else:
                pending.extend([right, left])
        return result


QuickSort().sort(np.array([5, 2, 10, 7, 4, 3, 8, 6, 1, 9]))
# array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class HeapSort:
    def sort(self, array):
        """
        Return a heap-sorted copy of ``array``.

        array : numpy array to sort (left unmodified)
        """
        heap = array.copy()
        # Build a max-heap bottom-up: once the subtrees below index k are heaps,
        # sifting k down turns the subtree rooted at k into a heap as well.
        for k in reversed(range(len(heap))):
            HeapSort.down_heap(heap, k)
        # Move the current maximum (the root) behind the unsorted region, then
        # restore the heap property on the now one-shorter prefix (a view).
        for end in range(len(heap) - 1, 0, -1):
            heap[0], heap[end] = heap[end], heap[0]
            HeapSort.down_heap(heap[:end], 0)
        return heap

    @staticmethod
    def down_heap(a, i_root):
        """
        Sift ``a[i_root]`` down so the subtree rooted at ``i_root`` becomes a max-heap.

        Precondition: the subtrees rooted at both children of ``i_root`` are
        already max-heaps. ``a`` is modified in place; nothing is returned.
        """
        size = len(a)
        node = i_root
        while True:
            left = 2 * node + 1
            if left >= size:
                return  # leaf: nothing below to compare with
            right = left + 1
            child = left
            if right < size and a[left] < a[right]:
                child = right  # pick the bigger of two children
            if not a[node] < a[child]:
                return  # heap property already holds here
            a[node], a[child] = a[child], a[node]
            node = child


HeapSort().sort(np.array([8, 9, 5, 4, 2, 3, 6, 7, 0, 1]))
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class MergeSort:
    def sort(self, array):
        """
        Return a merge-sorted copy of ``array``.

        array : numpy array to sort (left unmodified). An empty array yields an
                empty array, and the result never aliases ``array``.
        """
        n = len(array)
        if n <= 1:
            # n == 0 must be handled here as well: otherwise no branch returns
            # and the caller gets None. Copy so the input is never handed back.
            return array.copy()
        half = n // 2
        a_left = self.sort(array[:half])
        a_right = self.sort(array[half:])
        return self.merge_sorted_arrays(a_left, a_right)

    def merge_sorted_arrays(self, a1, a2):
        """
        Merge two already-sorted numpy arrays into one new sorted array.

        a1, a2 : sorted numpy arrays; the result has a1's dtype
        """
        n1, n2 = len(a1), len(a2)
        res = np.empty(n1 + n2, dtype=a1.dtype)
        i = i1 = i2 = 0
        while i1 < n1 and i2 < n2:
            # Take from a1 on ties (`<=`) so equal elements keep their relative
            # order, i.e. the sort is stable.
            if a1[i1] <= a2[i2]:
                res[i] = a1[i1]
                i1 += 1
            else:
                res[i] = a2[i2]
                i2 += 1
            i += 1
        # One input is exhausted; copy the other's remainder as one slice.
        if i1 == n1:
            res[i:] = a2[i2:]
        else:
            res[i:] = a1[i1:]
        return res


MergeSort().sort(np.array([8, 9, 5, 4, 2, 3, 6, 7, 0, 1]))
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class SelectionSort:
    def sort(self, array):
        """
        Return a selection-sorted copy of ``array``.

        array : numpy array to sort (left unmodified)
        """
        out = array.copy()
        size = len(out)
        for pos in range(size - 1):
            # Index of the first smallest element of the unsorted tail out[pos:];
            # min() keeps the earliest candidate, exactly like a strict `<` scan.
            smallest = min(range(pos, size), key=out.__getitem__)
            out[pos], out[smallest] = out[smallest], out[pos]
        return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class InsertionSort:
    def sort(self, array):
        """
        Return an insertion-sorted copy of ``array``.

        array : numpy array to sort (left unmodified)
        """
        a = array.copy()
        n = len(a)
        for i in range(1, n):
            # a[:i] is sorted. Insert a[i] before the first strictly greater
            # value; using `<` (not `<=`) places it after equal values, so the
            # sort is stable.
            for j in range(i):
                if a[i] < a[j]:
                    tmp = a[i]
                    # Shift a[j:i] one slot right. The slices overlap; NumPy
                    # assigns as if the source had been copied first.
                    a[j + 1:i + 1] = a[j:i]
                    a[j] = tmp
                    # a[:i+1] is now sorted. Without this break the scan keeps
                    # re-testing later slots and, with duplicates, performs
                    # redundant shifts (O(n^3) on an all-equal array).
                    break
        return a
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class ArrayGenerator:
    """Factories for test arrays that hold each integer 0..n-1 exactly once."""

    @staticmethod
    def random(n):
        """Return 0..n-1 in a random order (uses the global np.random state)."""
        values = ArrayGenerator.asc(n)
        np.random.shuffle(values)
        return values

    @staticmethod
    def asc(n):
        """Return 0..n-1 in ascending order."""
        return np.array(range(n))

    @staticmethod
    def desc(n):
        """Return n-1..0 in descending order."""
        return ArrayGenerator.asc(n)[::-1]

    @staticmethod
    def asc_random(n, w=2):
        """Return an ascending array locally perturbed by sliding-window shuffles of width w."""
        return ArrayGenerator._shuffle_windows(ArrayGenerator.asc(n), w)

    @staticmethod
    def desc_random(n, w=2):
        """Return a descending array locally perturbed by sliding-window shuffles of width w."""
        return ArrayGenerator._shuffle_windows(ArrayGenerator.desc(n), w)

    @staticmethod
    def _shuffle_windows(values, w):
        """Shuffle, in place and left to right, every length-w window of values; return values."""
        for start in range(len(values) - w + 1):
            np.random.shuffle(values[start:start + w])
        return values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from matplotlib import pyplot as plt | |
def fit_x_line(x, y):
    """
    Least-squares fit of the model y = a x (a line through the origin).

    Minimising sum((y - a x)^2) gives a = sum(x y) / sum(x^2).
    Returns the fitted values a * x, not the coefficient a itself.

    x, y : numpy arrays of equal length
    """
    slope = np.sum(x * y) / np.sum(x ** 2)
    return slope * x
def fit_x_square(x, y):
    """
    Least-squares fit of the model y = a x^2.

    Minimising sum((y - a x^2)^2) gives a = sum(x^2 y) / sum(x^4).
    Returns the fitted values a * x^2, not the coefficient a itself.

    x, y : numpy arrays of equal length
    """
    x_sq = x ** 2
    coef = np.sum(y * x_sq) / np.sum(x ** 4)
    return coef * x_sq
def fit_x_log_x(x, y):
    """
    Least-squares fit of the model y = a x log(b x) and return the fitted values.

    Since a x log(b x) = a x log(x) + c x with c = a log(b), the model is
    linear in (a, c). Its normal equations, built on the basis columns
    g1 = x log(x) and g2 = x, are solved directly.

    x : numpy array of strictly positive values (log(x) is taken)
    y : numpy array of observations, same length as x
    Returns a x log(x) + c x evaluated at x, which equals a x log(b x) with
    b = exp(c / a).
    """
    g1 = x * np.log(x)
    g2 = x
    gram = np.array([
        [(g1 * g1).sum(), (g1 * g2).sum()],
        [(g2 * g1).sum(), (g2 * g2).sum()],
    ])
    rhs = np.array([(g1 * y).sum(), (g2 * y).sum()])
    # solve() is preferred over forming the inverse; plain ndarrays replace the
    # no-longer-recommended np.matrix.
    a, c = np.linalg.solve(gram, rhs)
    # Evaluate with c directly rather than through b = exp(c / a): when a is
    # close to 0 (nearly linear data) exp(c / a) overflows or divides by zero
    # and the fitted curve becomes inf/nan.
    return a * g1 + c * g2
def draw_computing_order(x, y, yerr):
    """
    Plot measured timings with error bars against least-squares reference curves.

    The references are O(n), O(n^2) and O(n log n), each fitted to (x, y),
    and the figure is shown with plt.show().

    x    : array lengths n
    y    : measured computing times [sec]
    yerr : error-bar sizes, labelled as ±σ in the legend
    """
    references = [
        (fit_x_line(x, y), r'$O(n)$'),
        (fit_x_square(x, y), r'$O(n^2)$'),
        (fit_x_log_x(x, y), r'$O(n \log n)$'),
    ]
    plt.xlabel('Array Length $n$')
    plt.ylabel('Computing Time [sec]')
    for y_ref, legend_label in references:
        plt.plot(x, y_ref, label=legend_label)
    plt.errorbar(
        x, y, yerr=yerr,
        capsize=3, fmt='o', markersize=3,
        ecolor='black', markeredgecolor="black", color='w',
        label=r'measured value ($\pm \sigma$)',
    )
    plt.legend()
    plt.grid()
    plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import time | |
def experiment_computing_time(sort_algorithm, Ns, T=10, array_generator=None):
    """
    Measure how long ``sort_algorithm.sort`` takes for each array length in Ns.

    sort_algorithm  : instance of a class whose ``sort`` method sorts a numpy array
    Ns              : iterable of array lengths to try
    T               : number of timed repetitions per array length
    array_generator : optional callable ``n -> array`` that builds each input
                      (e.g. a pattern such as ascending order). When None, a
                      freshly shuffled permutation of 0..n-1 is used.
    Returns (mean, std) as numpy arrays of elapsed seconds, one entry per N.
    """
    time_ave = []
    time_std = []
    for N in Ns:
        times = []
        for _ in range(T):
            # Build the input outside the timed region.
            if array_generator is None:
                a = np.array(range(N))
                np.random.shuffle(a)
            else:
                a = array_generator(N)
            # perf_counter is the clock intended for measuring short durations;
            # time.time() is wall-clock time and may have coarse resolution.
            start = time.perf_counter()
            sort_algorithm.sort(a)
            times.append(time.perf_counter() - start)
        time_ave.append(np.mean(times))
        time_std.append(np.std(times))
    return np.array(time_ave), np.array(time_std)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def test_sort_algorithm(sort_argorithm, N=7):
    """
    Check that a sort implementation sorts every permutation of 1..N correctly.

    sort_argorithm : instance of a class whose ``sort`` method sorts a numpy
                     array (the parameter keeps its original spelling so
                     keyword callers keep working)
    N              : permutation size; all N! orderings are tested
    Raises Exception with the offending input and output on the first wrong
    result. Prints about 10 sample successes and a final summary line.
    """
    import itertools

    ans = list(range(1, N + 1))
    # permutations() yields orderings in lexicographic order of the sorted input.
    all_patterns = [list(p) for p in itertools.permutations(ans)]
    # Print about 10 samples. max(1, ...) prevents a modulo-by-zero
    # (ZeroDivisionError) when there are fewer than 10 patterns (N <= 3).
    report_every = max(1, len(all_patterns) // 10)
    cnt = 0
    for a in all_patterns:
        a_sorted = sort_argorithm.sort(np.array(a))
        if list(a_sorted) != ans:
            # Wrong result: stop with an error.
            raise Exception('Test failed! Sorted result is wrong:\n{} --> {}'.format(a, a_sorted))
        elif (cnt + 1) % report_every == 0:
            print('OK: {} --> {}'.format(a, a_sorted))
        cnt += 1
    print('All {} tests passed.'.format(cnt))


# The name starts with "test_" but the function needs a sorter argument; tell
# pytest not to collect it as a test (it would fail on a missing fixture).
test_sort_algorithm.__test__ = False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment