Skip to content

Instantly share code, notes, and snippets.

@RuolinZheng08
Created January 31, 2021 20:45
Show Gist options
  • Save RuolinZheng08/86518fbf0d7ef2cc7c7cb63cb49fd32b to your computer and use it in GitHub Desktop.
Save RuolinZheng08/86518fbf0d7ef2cc7c7cb63cb49fd32b to your computer and use it in GitHub Desktop.
[Python, Statistics] Permutation Test
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from itertools import combinations\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def two_sample_permutation_test(arr, start, size):\n",
" target = arr[start : start + size]\n",
" arr_counter = Counter(arr)\n",
" target_diff = sum(target) - sum((arr_counter - Counter(target)).elements())\n",
" print('target: ', target_diff)\n",
" count = 0\n",
" for curr in combinations(arr, size):\n",
" curr_counter = Counter(curr)\n",
" complement = list((arr_counter - curr_counter).elements())\n",
" diff = sum(curr) - sum(complement)\n",
" if diff >= target_diff:\n",
" print(curr, diff)\n",
" count += 1\n",
" return count"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"arr = [111, 56, 86, 92, 104, 118, 117, 111]\n",
"two_sample_permutation_test(arr, 4, 4)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(4.5, 3, 7, 6) 13.0\n",
"(4.5, 7, 6, 4.5) 16.0\n",
"(3, 7, 6, 4.5) 13.0\n"
]
},
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"two_sample_permutation_test([4.5, 0, 1, 2, 3, 7, 6, 4.5], 4, 4)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"target: 42\n",
"(25, 31, 46) 42\n",
"(25, 46, 31) 42\n",
"(31, 46, 31) 54\n"
]
},
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"two_sample_permutation_test([25, 31, 46, 10, 19, 31], 0, 3)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"def two_sample_permutation_test_two_sided(arr, start, size):\n",
" target = arr[start : start + size]\n",
" arr_counter = Counter(arr)\n",
" target_diff = abs(sum(target) - sum((arr_counter - Counter(target)).elements()))\n",
" print('target: ', target_diff)\n",
" count = 0\n",
" for curr in combinations(arr, size):\n",
" curr_counter = Counter(curr)\n",
" complement = list((arr_counter - curr_counter).elements())\n",
" diff = abs(sum(curr) - sum(complement))\n",
" if diff >= target_diff:\n",
" print(curr, diff)\n",
" count += 1\n",
" return count"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"target: 3.299999999999999\n",
"(3.4, 2.8, 1.9, 2.6) 3.299999999999999\n",
"(3.4, 2.8, 2.6, 2.4) 4.299999999999999\n",
"(3.4, 2.8, 2.6, 2.1) 3.6999999999999993\n",
"(3.4, 2.8, 2.4, 2.1) 3.299999999999999\n",
"(1.9, 2.6, 1.4, 1.5) 3.299999999999999\n",
"(1.9, 1.4, 2.4, 1.5) 3.6999999999999993\n",
"(1.9, 1.4, 2.1, 1.5) 4.299999999999999\n",
"(1.4, 2.4, 2.1, 1.5) 3.299999999999999\n"
]
},
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"two_sample_permutation_test_two_sided([3.4, 2.8, 1.9, 2.6, 1.4, 2.4, 2.1, 1.5], 4, 4)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"target: 12\n",
"(8, 7, 6, 3) 12\n",
"(8, 7, 6, 5) 16\n",
"(8, 7, 6, 4) 14\n",
"(8, 7, 5, 4) 12\n",
"(6, 3, 2, 1) 12\n",
"(3, 5, 2, 1) 14\n",
"(3, 4, 2, 1) 16\n",
"(5, 4, 2, 1) 12\n"
]
},
{
"data": {
"text/plain": [
"8"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"two_sample_permutation_test_two_sided([8, 7, 6, 3, 5, 4, 2, 1], 4, 4)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def matched_pair_permutation_test(arr):\n",
" count = 0\n",
" neg_sum = sum([elm for elm in arr if elm < 0])\n",
" if neg_sum < 0: \n",
" # at least one entry negative, should count the all positive array\n",
" print('[], 0')\n",
" count += 1\n",
" arr = sorted([-abs(elm) for elm in arr], reverse=True)\n",
" \n",
" for i in range(1, len(arr)):\n",
" for nums in combinations(arr, i):\n",
" curr_sum = sum(nums)\n",
" if curr_sum >= neg_sum:\n",
" print(nums, curr_sum)\n",
" count += 1\n",
" return count"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[], 0\n",
"(-1,) -1\n",
"(-2,) -2\n",
"(-3,) -3\n",
"(-4,) -4\n",
"(-1, -2) -3\n",
"(-1, -3) -4\n"
]
},
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"arr = [-1, 6, 4, 6, 2, -3, 5]\n",
"matched_pair_permutation_test(arr)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[], 0\n",
"[-1.0] -1.0\n",
"[-2.0] -2.0\n",
"[-3.0] -3.0\n",
"[-4.0] -4.0\n",
"[-1.0, -2.0] -3.0\n"
]
},
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matched_pair_permutation_test([-1.0, 6.5, 4.0, 6.5, 2.0, -3.0, 5.0])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[], 0\n",
"[-0.25] -0.25\n",
"[-0.33] -0.33\n",
"[-0.25, -0.33] -0.5800000000000001\n"
]
},
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matched_pair_permutation_test([1.85, -0.25, 0.88, 1.46, 1.05, 1.67, 1.74, -0.33])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[], 0\n",
"[-1] -1\n",
"[-2] -2\n",
"[-3] -3\n",
"[-1, -2] -3\n"
]
},
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matched_pair_permutation_test([-1, -2, 3, 4, 5, 6, 7, 8])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[], 0\n",
"(-1,) -1\n",
"(-1,) -1\n",
"(-1,) -1\n"
]
},
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matched_pair_permutation_test([1, 6, 1, 7, -1, 2, 8])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[], 0\n",
"(-2,) -2\n",
"(-2,) -2\n",
"(-2,) -2\n"
]
},
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"matched_pair_permutation_test([2, 5, 2, 6, -2, 4, 7])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment