Skip to content

Instantly share code, notes, and snippets.

@cdfmlr

cdfmlr/sort.c Secret

Created August 20, 2021 08:20
Show Gist options
  • Save cdfmlr/07b9b4880de4f3a457f1aa90eacb55ad to your computer and use it in GitHub Desktop.
Save cdfmlr/07b9b4880de4f3a457f1aa90eacb55ad to your computer and use it in GitHub Desktop.
Sort algorithms in C
#include <stdio.h>
#include <stdlib.h>
#pragma region PrintArray
// 写这个 region 和 endregion 是为了让 VS Code 把这部分代码折叠起来,
// VS code 自动不能折叠这东西
// helper: 打印数组的
#define print_array(array, offset, length, formatter) \
{ \
printf("["); \
for (int i = offset; i < length; ++i) { \
printf(formatter, array[i]); \
if (i < length - 1) \
printf(", "); \
} \
printf("]"); \
}
#pragma endregion PrintArray
#pragma region SortAlgorithms
// sort 是排序算法的接口:对数组 A 的前 n 个元素进行原址排序
typedef void (*sort)(int A[], int n);
// insert_sort 直接插入排序
//
// 遍历,往前找到合适的位置,逐个元素后移腾出空间,插入进去。
//
// 复杂度:
// - 时间 O(n^2)
// - 空间 O(1)
void
insert_sort(int A[], int n)
{
for (int i = 0; i < n; ++i) {
int curr = A[i];
int j = i - 1;
while (j >= 0 && curr < A[j]) {
A[j + 1] = A[j];
--j;
}
A[j + 1] = curr;
}
}
// binary_insert_sort 二分插入排序
//
// 就是直接插入里往前找合适位置那里用个二分查找
//
// 复杂度
// - 时间:较好 O(n*log(n)),较坏 O(n^2),平均 O(n^2)
// - 空间 O(1)
void
binary_insert_sort(int A[], int n)
{
for (int i = 0; i < n; ++i) {
int curr = A[i];
// 二分查找
int l = 0, r = i - 1;
while (r >= 0 && l < i && l <= r) {
int m = (l + r) / 2;
if (curr < A[m]) {
r = m - 1;
} else {
l = m + 1;
}
}
// 后移
for (int j = i - 1; j >= l; --j) {
A[j + 1] = A[j];
}
// 插入
A[l] = curr;
}
}
#pragma region ShellSort
// Shell 步长序列: n/2^i,最坏情况下时间复杂度 O(n^2)
#define for_shell_gap(gap, n, f) \
for (gap = n >> 1; gap > 0; gap >>= 1) \
f;
// Papernov-Stasevich 步长序列:2^k-1,最坏情况下时间复杂度 O(n^(3/2))
// 总之,用这些奇怪的步长序列,最后一步务必为 1.
#pragma region PapernovStasevichGaps
// 这个要在宏里调用,一定要 static 不然编译出错
// 如果不用优化,加 always_inline 才能保证 inline
// https://gcc.gnu.org/onlinedocs/gcc/Inline.html
static inline int
ps_start(int) __attribute__((always_inline));
// papernov_stasevich_start find the max n s.t. (2^k+1) < n
// return 2^k+1
static inline int
ps_start(int n)
{
int k = 1, nn = n - 1;
while (nn > 0 && nn != 1) {
nn >>= 1;
k <<= 1;
}
return k + 1;
}
static inline int
ps_next(int) __attribute__((always_inline));
// 2^k+1, ..., 65, 33, 17, 9, 5, 3, 1
static inline int
ps_next(int gap)
{
switch (gap) {
case 1:
return -1; // to stop
case 3:
return 1;
default:
return ((gap - 1) >> 1) + 1;
}
}
#define for_papernov_stasevich_gap(gap, n, f) \
for (gap = ps_start(n); gap > 0; gap = ps_next(gap)) \
f;
#pragma endregion PapernovStasevichGaps
// 遍历 shell 排序的 gaps (步长|增量)
//
// usage: foreach_gaps(gap, n, {/* do something for each gap in gaps*/})
//
// 也可以用 for_papernov_stasevich_gap
#define foreach_gaps for_papernov_stasevich_gap
// 希尔排序
//
// 递减增量(gap)的排序算法
// (非稳定)
// 空间复杂度是 O(1),时间复杂度依赖于步长序列
void
shell_sort(int A[], int n)
{
int gap;
foreach_gaps(gap, n, {
// 最外层循环里面就是个插入排序
for (int i = gap; i < n; i++) {
int curr = A[i];
int j = i - gap;
for (; j >= 0 && A[j] > curr; j -= gap) {
A[j + gap] = A[j];
}
A[j + gap] = curr;
}
});
}
#pragma endregion ShellSort
// 原址交换数组 A 中下标 i 与 j 的值
static inline void
swap(int A[], int i, int j)
{
// int tmp = A[i];
// A[i] = A[j];
// A[j] = tmp;
// 还是用 xor swap 方便
if (i != j) { // ⚠️ 这个一定要有,有坑
// A[i] = A[i] ^ A[j];
// A[j] = A[i] ^ A[j];
// A[i] = A[i] ^ A[j];
// 左 iji,右 jij
A[i] ^= A[j];
A[j] ^= A[i];
A[i] ^= A[j];
}
// 同一个内存地址用 xor swap 就会爆炸(变成 0 ):
// x = x ^ x -> x = 0
// x = x ^ x -> x = 0
// x = x ^ x -> x = 0
}
// bubble_sort 冒泡排序:交换排序
//
// 从后向前形成序列(i=n-1...1):
// 每次从 0 检查至 i-1,后一个比前一个大的就交换一下:冒泡
//
// 时间复杂度:最坏 O(n^2),最好 O(n)
// 空间复杂度:O(1)
void
bubble_sort(int A[], int n)
{
for (int i = n - 1; i > 0; i--) {
int swap_flag = 0;
for (int j = 0; j < i; j++) {
if (A[j] > A[j + 1]) {
swap(A, j, j + 1);
swap_flag = 1;
}
}
if (!swap_flag) { // 这一轮都没交换,已经有序了,提前结束
return;
}
}
}
#pragma region QuickSort
int
partition(int A[], int l, int r);
int
partition_place(int A[], int l, int r);
int
partition_swap(int A[], int l, int r);
void
quick_sort(int A[], int l, int r);
// quick_sort 对 A[l...r] (闭区间) 快速排序
//
// 选个轴,序列中比轴小的放轴左边,比轴大的在其右边
// 然后把轴左右分别做两个子序列,递归。
//
// 快排:序列越无序越快
//
// 时间复杂度:最好 O(n*log(n)),最坏 O(n^2),平均 O(n*log(n))
// 空间复杂度 O(log(n))
void
quick_sort(int A[], int l, int r)
{
if (l < r) {
int p = partition(A, l, r);
quick_sort(A, l, p - 1);
quick_sort(A, p + 1, r);
}
}
// quick_sort_all 对整个长度为 n 的序列 A 执行快速排序
void
quick_sort_all(int A[], int n)
{
return quick_sort(A, 0, n - 1);
}
// partition 做快排的交换工作
//
// 以第一个元素 A[l] 为轴
// 序列中比轴小的放轴左边,比轴大的在其右边
//
// 这个有两种实现,一种是 partition_place,一种是 partition_swap
// 无论是从好理解还是从方便记,我都喜欢后者。
//
// 返回轴的索引
int
partition(int A[], int l, int r)
{
// return partition_place(A, l, r);
return partition_swap(A, l, r);
}
int
partition_place(int A[], int l, int r)
{
int pv = A[l];
while (l < r) {
while (r > l && A[r] > pv)
--r;
if (l < r)
A[l++] = A[r];
while (l < r && A[l] < pv)
++l;
if (l < r)
A[r--] = A[l];
}
A[l] = pv;
return l;
}
int
partition_swap(int A[], int l, int r)
{
int p = l;
for (int i = l; i < r; i++) {
if (A[i] <= A[r]) {
swap(A, p, i);
p++;
}
}
swap(A, p, r);
return p;
}
#pragma endregion QuickSort
// select_sort 选择排序
//
// 从下标 0 到 n,每个位置 i 选择 A[i...n] (闭区间)里最小的一个放上去。
//
// 时间复杂度 O(n^2)
// 空间复杂度 O(1)
void
select_sort(int A[], int n)
{
for (int i = 0; i < n; i++) {
int smallest = i;
for (int j = i + 1; j < n; j++) {
if (A[j] < A[smallest]) {
smallest = j;
}
}
swap(A, i, smallest);
}
}
#pragma region HeapSort
void
sift(int A[], int l, int r);
void
max_heapify(int A[], int root, int heap_size);
// heap_sort 堆排序
//
// 把序列搞成个大根堆(根结点比页大的那种)
// 然后依次出根节点,重新调整堆
//
// 时间复杂度 O(n * log n)
// 空间复杂度 O(1)
void
heap_sort(int A[], int n)
{
// 这个 sift 的我看不懂
// // 建立堆
// int heap_size = n;
// for (int i = n / 2 - 1; i >= 0; i--) {
// sift(A, i, n - 1);
// }
// // 调整堆,依次出根节点
// for (int i = n - 1; i >= 0; i--) {
// swap(A, 0, i);
// sift(A, 0, i - 1);
// }
// 下面用 CLRS 里面的 maxHeapify,这个容易理解
// 建立堆
for (int i = n / 2 - 1; i >= 0; i--) {
max_heapify(A, i, n - 1);
}
// 调整堆,依次出根节点
for (int i = n - 1; i >= 0; i--) {
swap(A, 0, i);
max_heapify(A, 0, i - 1);
}
}
// sift 调整,建堆
void
sift(int A[], int l, int r)
{
int i = l, j = 2 * i + 1;
int root = A[l];
while (j <= r) {
if (j + 1 <= r && A[j] < A[j + 1])
++j;
if (A[j] > root) {
A[i] = A[j];
i = j;
j = 2 * i + 1;
} else {
break;
}
}
A[i] = root;
}
// max_heapify 建立大根堆的调整函数
void
max_heapify(int A[], int i, int n)
{
// l, r are the children of root i
int l = (i << 1) + 1;
int r = l + 1;
// printf(" (%d=%d %d=%d %d=%d) ", i, A[i], l, A[l], r, A[r]); // debug
// max = max_idx(A[root], A[l], A[r])
int max = i;
if (l <= n && A[l] > A[max])
max = l;
if (r <= n && A[r] > A[max])
max = r;
// 更改根,继续向后调整
if (max != i) {
swap(A, i, max);
max_heapify(A, max, n);
}
}
#pragma endregion HeapSort
#pragma region MergeSort
// merge 归并:把数组的 A[l: m+1] 和 A[m: r+1] 两个已排序部分按升序合并
// 先备份数组,顺序从左右两半中选出较小者,放入原数组,完成归并
// 需要辅助数组,空间复杂度 O(n)
void
merge(int A[], int l, int m, int r)
{
const int INF = 1 << 30;
// 先备份左右两半切片:
// b = A[l: m+1] + [INF]
int n1 = (m - l + 1) + 1;
int* b = malloc(sizeof(int) * n1);
for (int i = 0, j = l; j <= m; i++, j++) {
b[i] = A[j];
}
b[n1 - 1] = INF;
// c = A[m: r+1] + [INF]
int n2 = (r - m) + 1;
int* c = malloc(sizeof(int) * n2);
for (int i = 0, j = m + 1; j <= r; i++, j++) {
c[i] = A[j];
}
c[n2 - 1] = INF;
// 从左右切片里逐个取小的出来,凑成排序数组:
int i = 0, j = 0;
for (int k = l; k <= r; k++) {
if (b[i] <= c[j]) {
A[k] = b[i++];
} else { // c[i] < b[i]
A[k] = c[j++];
}
}
}
// merge_sort 归并排序
//
// 递归完成左右两半的排序,然后归并
//
// 时间复杂度 O(n * log(n))
// 空间复杂度 O(n)
void
merge_sort(int A[], int l, int r)
{
if (l >= r)
return;
int m = (l + r) / 2;
merge_sort(A, l, m);
merge_sort(A, m + 1, r);
merge(A, l, m, r);
}
void
merge_sort_all(int A[], int n)
{
merge_sort(A, 0, n - 1);
}
#pragma endregion MergeSort
#pragma endregion SortAlgorithms
#pragma region Tests
// test a sort algorithm
//
// return 0: passed
// -1: failed, errors will be loged.
int
test(sort algo)
{
const int n = 11;
const int sorted[] = { 1, 2, 3, 4, 5, 6, 6, 6, 7, 8, 9 };
int s[] = { 2, 1, 3, 6, 7, 9, 6, 8, 5, 4, 6 };
// run sort
algo(s, n);
// check
for (int i = 0; i < n; i++) {
if (s[i] != sorted[i]) {
printf("FAILED: %p\n", algo);
printf(" excepted:");
print_array(sorted, 0, n, "%d");
printf("\n got:");
print_array(s, 0, n, "%d");
printf("\n----\n");
return -1;
}
}
printf("PASS: %p\n", algo);
return 0;
}
void
test_sort_algos()
{
int algo_num = 8;
sort algo[] = { insert_sort, binary_insert_sort, shell_sort,
bubble_sort, quick_sort_all, select_sort,
heap_sort, merge_sort_all };
int _resultsum = 0;
for (int i = 0; i < algo_num; ++i) {
_resultsum += test(algo[i]);
}
if (_resultsum != 0) {
printf("SOME FAILED!\n");
} else {
printf("ALL PASSED!\n");
}
}
int
main()
{
test_sort_algos();
return 0;
}
#pragma endregion Tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment