Skip to content

Instantly share code, notes, and snippets.

@hkawabata
Last active December 15, 2023 06:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hkawabata/12061820cfef20172e8a7549464995de to your computer and use it in GitHub Desktop.
Save hkawabata/12061820cfef20172e8a7549464995de to your computer and use it in GitHub Desktop.

ソートアルゴリズム

  • バブルソート
  • クイックソート
  • マージソート
  • ...
import numpy as np
class BubbleSort:
    """Bubble sort: repeatedly swap adjacent out-of-order elements."""

    def sort(self, array):
        """Return a sorted copy of ``array``.

        array : 1-D numpy array to sort. It is copied, never modified.

        Each outer pass bubbles the largest remaining element to the end of
        the still-unsorted prefix, so that prefix shrinks by one per pass.
        """
        result = array.copy()
        for last in range(len(result) - 1, 0, -1):
            # Compare/swap every adjacent pair inside result[0..last].
            for k in range(last):
                if result[k] > result[k + 1]:
                    result[k], result[k + 1] = result[k + 1], result[k]
        return result
if __name__ == "__main__":
    # Demo: print the sorted result. A bare expression statement would be
    # computed and then discarded when this file runs as a script.
    print(repr(BubbleSort().sort(np.array([5, 2, 10, 7, 4, 3, 8, 6, 1, 9]))))
    # array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
import numpy as np
class QuickSort:
    """Quick sort with a three-way partition (< pivot, == pivot, > pivot)."""

    def sort(self, array):
        """Return a sorted version of ``array``.

        array : 1-D numpy array to sort. Arrays with fewer than two elements
            are returned as-is (the same object). Otherwise a new array is
            built and the input is not modified.

        The pivot is the median of the first, middle and last elements.
        Always taking the last element degenerates to O(n^2) time and O(n)
        recursion depth on already-sorted or reverse-sorted input. With
        Python's default recursion limit (1000), that raises RecursionError
        once the input has about a thousand elements or more.
        """
        n = len(array)
        if n < 2:
            return array
        pivot = self._median_of_three(array[0], array[n // 2], array[-1])
        # Elements equal to the pivot are never written explicitly. The
        # buffer starts filled with the pivot, so the gap left between the
        # "< pivot" prefix and the "> pivot" suffix already holds them.
        result = np.full(n, pivot, dtype=array.dtype)
        cnt_l = 0
        cnt_r = 0
        for value in array:
            if value < pivot:
                result[cnt_l] = value
                cnt_l += 1
            elif pivot < value:
                cnt_r += 1
                result[-cnt_r] = value
        # Sort both partitions and write them back (the slices are views).
        result[:cnt_l] = self.sort(result[:cnt_l])
        if cnt_r > 0:
            # Guarded because result[-0:] would be the whole array.
            result[-cnt_r:] = self.sort(result[-cnt_r:])
        return result

    @staticmethod
    def _median_of_three(first, middle, last):
        """Return the median of three mutually comparable values."""
        return sorted((first, middle, last))[1]
if __name__ == "__main__":
    # Demo: print the sorted result. A bare expression statement would be
    # computed and then discarded when this file runs as a script.
    print(repr(QuickSort().sort(np.array([5, 2, 10, 7, 4, 3, 8, 6, 1, 9]))))
    # array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
import numpy as np
class HeapSort:
    """Heap sort on a binary max-heap stored in the array itself."""

    def sort(self, array):
        """Return a sorted copy of ``array`` (1-D numpy array; input untouched)."""
        heap = array.copy()
        # Build the max-heap bottom-up, from the last index to the root.
        for node in reversed(range(len(heap))):
            HeapSort.down_heap(heap, node)
        # Repeatedly move the current maximum (the root) to the end of the
        # heap region, then restore the heap on the shrunken prefix view.
        for end in range(len(heap) - 1, 0, -1):
            heap[0], heap[end] = heap[end], heap[0]
            HeapSort.down_heap(heap[:end], 0)
        return heap

    @staticmethod
    def down_heap(a, i_root):
        """Sift ``a[i_root]`` down so the subtree rooted there is a max-heap.

        Assumes the subtrees under both children of ``i_root`` are already
        max-heaps. ``a`` is modified in place and nothing is returned.
        """
        size = len(a)
        parent = i_root
        while True:
            left = 2 * parent + 1
            if left >= size:
                return
            right = left + 1
            # Pick the larger child; on a tie the left one wins.
            child = left
            if right < size and a[left] < a[right]:
                child = right
            if not a[parent] < a[child]:
                return
            a[parent], a[child] = a[child], a[parent]
            parent = child
if __name__ == "__main__":
    # Demo: print the sorted result. A bare expression statement would be
    # computed and then discarded when this file runs as a script.
    print(repr(HeapSort().sort(np.array([8, 9, 5, 4, 2, 3, 6, 7, 0, 1]))))
    # array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
import numpy as np
class MergeSort:
    """Top-down (recursive) merge sort."""

    def sort(self, array):
        """Return a sorted copy of ``array``.

        array : 1-D numpy array to sort. It is never modified, and the
            result never aliases it.

        Fixes in this version:
        - Empty input used to fall through every branch and return None.
        - Inputs of length 1 or 2 were returned as the input itself or as a
          reversed view of it, so mutating the result mutated the input.
        """
        n = len(array)
        if n < 2:
            return array.copy()
        mid = n // 2
        left = self.sort(array[:mid])
        right = self.sort(array[mid:])
        return self.merge_sorted_arrays(left, right)

    def merge_sorted_arrays(self, a1, a2):
        """Merge two sorted 1-D arrays into one new sorted array.

        On ties the element from ``a1`` is taken first, so the merge (and
        hence the sort) is stable. The result dtype is the common type of
        both inputs, so merging e.g. an int array with a float array does
        not truncate.
        """
        n1, n2 = len(a1), len(a2)
        res = np.empty(n1 + n2, dtype=np.result_type(a1, a2))
        i = i1 = i2 = 0
        while i1 < n1 and i2 < n2:
            if a2[i2] < a1[i1]:
                res[i] = a2[i2]
                i2 += 1
            else:
                res[i] = a1[i1]
                i1 += 1
            i += 1
        # One input is exhausted; copy the remainder of the other one.
        if i1 == n1:
            res[i:] = a2[i2:]
        else:
            res[i:] = a1[i1:]
        return res
if __name__ == "__main__":
    # Demo: print the sorted result. A bare expression statement would be
    # computed and then discarded when this file runs as a script.
    print(repr(MergeSort().sort(np.array([8, 9, 5, 4, 2, 3, 6, 7, 0, 1]))))
    # array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
import numpy as np
class SelectionSort:
    """Selection sort: grow a sorted prefix by selecting the minimum of the rest."""

    def sort(self, array):
        """Return a sorted copy of ``array`` (1-D numpy array; input untouched)."""
        out = array.copy()
        size = len(out)
        for pos in range(size - 1):
            # min() only replaces its candidate when a later key is strictly
            # smaller, so it selects the leftmost minimum of out[pos:].
            smallest = min(range(pos, size), key=out.__getitem__)
            out[pos], out[smallest] = out[smallest], out[pos]
        return out
import numpy as np
class InsertionSort:
    """Stable, adaptive insertion sort."""

    def sort(self, array):
        """Return a sorted copy of ``array`` (1-D numpy array; input untouched).

        For each element, scan the sorted prefix from the right and stop at
        the first element that is not greater than it. Then shift the larger
        elements one slot right and drop the element into the gap. Equal
        elements keep their original order (stable), and already-sorted
        input needs only one comparison per element.

        The previous version scanned from the left and had no ``break``.
        After an insertion it kept comparing against the element that had
        been shifted into slot i. It always did Theta(i) work per element
        and placed each element before equal keys (unstable).
        """
        a = array.copy()
        for i in range(1, len(a)):
            key = a[i]
            j = i
            while j > 0 and a[j - 1] > key:
                j -= 1
            if j < i:
                # Shift a[j:i] right by one, then place key at position j.
                a[j + 1:i + 1] = a[j:i]
                a[j] = key
        return a
import numpy as np
class ArrayGenerator:
    """Factory for test input arrays that hold a permutation of 0..n-1."""

    @staticmethod
    def random(n):
        """Uniformly shuffled permutation of 0..n-1 (uses np.random's global state)."""
        values = np.array(range(n))
        np.random.shuffle(values)
        return values

    @staticmethod
    def asc(n):
        """Ascending array 0, 1, ..., n-1."""
        return np.array(range(n))

    @staticmethod
    def desc(n):
        """Descending array n-1, ..., 1, 0 (a reversed view of a fresh ascending array)."""
        return ArrayGenerator.asc(n)[::-1]

    @staticmethod
    def asc_random(n, w=2):
        """Ascending array with local disorder; see ``_shuffle_windows``."""
        return ArrayGenerator._shuffle_windows(ArrayGenerator.asc(n), w)

    @staticmethod
    def desc_random(n, w=2):
        """Descending array with local disorder; see ``_shuffle_windows``."""
        return ArrayGenerator._shuffle_windows(ArrayGenerator.desc(n), w)

    @staticmethod
    def _shuffle_windows(values, w):
        """Shuffle values[s:s+w] in place for s = 0, 1, ..., len(values)-w, then return values.

        The window slides one position at a time, so each element can only
        drift a limited distance from its starting slot.
        """
        for start in range(len(values) - w + 1):
            np.random.shuffle(values[start:start + w])
        return values
import numpy as np
from matplotlib import pyplot as plt
def fit_x_line(x, y):
    """Least-squares fit of y = a*x (a straight line through the origin).

    x, y : 1-D array-likes of equal length (Python lists are accepted).

    The slope minimising sum((y - a*x)**2) is a = sum(x*y) / sum(x**2).

    Returns the fitted values a*x as a float array. The parameter a itself
    is not returned.
    """
    # Convert once to float: this accepts lists and keeps the sums free of
    # integer overflow, consistent with the other fit_* helpers.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    a = (y * x).sum() / (x ** 2).sum()
    return a * x
def fit_x_square(x, y):
    """Least-squares fit of y = a*x^2.

    x, y : 1-D array-likes of equal length (Python lists are accepted).

    The coefficient minimising sum((y - a*x**2)**2) is
    a = sum(y*x**2) / sum(x**4).

    Returns the fitted values a*x^2 as a float array. The parameter a itself
    is not returned.
    """
    # Work in float. With an integer array of lengths (the typical x here),
    # x**4 overflows int64 once x exceeds about 55,000, silently corrupting
    # the fit.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x2 = x ** 2
    a = (y * x2).sum() / (x2 ** 2).sum()
    return a * x2
def fit_x_log_x(x, y):
    """Least-squares fit of y = a*x*log(b*x).

    Since a*x*log(b*x) = a*x*log(x) + c*x with c = a*log(b), the model is
    linear in (a, c). It is solved as an ordinary linear least-squares
    problem on the basis functions x*log(x) and x.

    x : 1-D array-like of positive values (log(x) must be defined).
    y : 1-D array-like of the same length.

    Returns the fitted values as a float array; the parameters are not
    returned. The values are evaluated in the linear form a*x*log(x) + c*x.
    Going through b = exp(c/a) would divide by zero or overflow when a is
    (close to) 0, e.g. for purely linear data.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    basis = np.column_stack([x * np.log(x), x])
    # lstsq solves the problem directly. It avoids forming and inverting the
    # normal-equation matrix, and avoids the deprecated np.matrix type.
    (a, c), *_ = np.linalg.lstsq(basis, y, rcond=None)
    return a * basis[:, 0] + c * basis[:, 1]
def draw_computing_order(x, y, yerr):
    """Plot measured run times against fitted reference complexity curves.

    x : array lengths n.
    y : measured mean run times [sec].
    yerr : their standard deviations, drawn as error bars.

    Curves of the form a*n, a*n^2 and a*n*log(b*n) are fitted to (x, y) by
    fit_x_line, fit_x_square and fit_x_log_x, then drawn in that order. The
    figure is displayed with plt.show().
    """
    # Fit everything before touching the figure.
    fitted_curves = [
        (fit_x_line(x, y), r'$O(n)$'),
        (fit_x_square(x, y), r'$O(n^2)$'),
        (fit_x_log_x(x, y), r'$O(n \log n)$'),
    ]
    plt.xlabel('Array Length $n$')
    plt.ylabel('Computing Time [sec]')
    for y_fit, label in fitted_curves:
        plt.plot(x, y_fit, label=label)
    plt.errorbar(
        x, y, yerr=yerr,
        capsize=3, fmt='o', markersize=3,
        ecolor='black', markeredgecolor="black", color='w',
        label=r'measured value ($\pm \sigma$)',
    )
    plt.legend()
    plt.grid()
    plt.show()
import numpy as np
import time
def experiment_computing_time(sort_algorithm, Ns, T=10, array_generator=None):
    """Measure the mean and standard deviation of sort run times per array length.

    sort_algorithm : instance of a class with a ``sort`` method that sorts a
        numpy array.
    Ns : iterable of array lengths to try.
    T : number of timed runs per array length.
    array_generator : optional callable ``f(N) -> numpy array`` that builds
        the input for each run (e.g. ``ArrayGenerator.asc``). When omitted,
        each run sorts a freshly shuffled permutation of 0..N-1, which was
        the only behaviour before this parameter existed.

    Returns (time_ave, time_std): arrays with one entry per element of Ns,
    in seconds.
    """
    time_ave = []
    time_std = []
    for N in Ns:
        times = []
        for _ in range(T):
            if array_generator is None:
                a = np.array(range(N))
                np.random.shuffle(a)
            else:
                a = array_generator(N)
            # perf_counter is monotonic and high-resolution. time.time() is
            # wall-clock time that can jump and is coarse on some platforms.
            start = time.perf_counter()
            sort_algorithm.sort(a)
            times.append(time.perf_counter() - start)
        time_ave.append(np.mean(times))
        time_std.append(np.std(times))
    return np.array(time_ave), np.array(time_std)
import numpy as np
def test_sort_algorithm(sort_argorithm, N=7):
    """Check that a sorter correctly sorts every permutation of 1..N.

    sort_argorithm : instance of a class with a ``sort`` method that sorts a
        numpy array. The parameter name is misspelled but kept so that
        keyword callers keep working.
    N : permutation size; all N! permutations are tested.

    Raises Exception on the first wrongly sorted permutation. Otherwise it
    prints about ten sample results and a summary line.
    """
    import itertools

    ans = list(range(1, N + 1))
    # itertools.permutations yields the same lexicographic order as the
    # former hand-written recursive generator.
    all_patterns = [list(p) for p in itertools.permutations(ans)]
    # At least 1: for N < 4 there are fewer than 10 patterns, and the old
    # ``len(all_patterns) // 10`` was 0, which raised ZeroDivisionError.
    sample_every = max(1, len(all_patterns) // 10)
    cnt = 0
    for a in all_patterns:
        a_sorted = sort_argorithm.sort(np.array(a))
        if list(a_sorted) != ans:
            # Fail fast on the first wrong result.
            raise Exception('Test failed! Sorted result is wrong:\n{} --> {}'.format(a, a_sorted))
        elif (cnt + 1) % sample_every == 0:
            # Print roughly ten sample results.
            print('OK: {} --> {}'.format(a, a_sorted))
        cnt += 1
    print('All {} tests passed.'.format(cnt))


# The ``test_`` prefix would make pytest collect this helper as a test case
# and fail looking for a ``sort_argorithm`` fixture; opt it out explicitly.
test_sort_algorithm.__test__ = False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment