Skip to content

Instantly share code, notes, and snippets.

@ivankeller
Created May 22, 2019 11:59
Show Gist options
  • Save ivankeller/4b26b1b5092465da17a43504982f59ab to your computer and use it in GitHub Desktop.
Save ivankeller/4b26b1b5092465da17a43504982f59ab to your computer and use it in GitHub Desktop.
Subsample two-class sets into subsets with a given ratio of class-1 items.
import pytest
import numpy as np
def r_cut(x0: float, x1: float, r: float) -> "tuple[float, float]":
    """Return (y0, y1) with y1 / (y0 + y1) == r, y0 <= x0, y1 <= x1 and y0 + y1 maximal.

    Given available amounts ``x0`` (class 0) and ``x1`` (class 1), find how
    much of each to keep so that class 1 makes up the fraction ``r`` of the
    total while keeping the total as large as possible. Whichever class is
    scarcer relative to the target ratio is kept entirely, and the other one
    is scaled to match.

    Parameters
    ----------
    x0, x1 : float
        Available amounts of class 0 and class 1 (expected non-negative).
    r : float
        Target fraction of class 1 in the result, ``0 <= r <= 1``.

    Returns
    -------
    (float, float)
        ``(y0, y1)``. ``(0, 0)`` when both inputs are zero (nothing to keep).

    Raises
    ------
    ValueError
        If ``r`` is not within ``[0, 1]`` (this includes NaN).
    """
    if not 0 <= r <= 1:
        raise ValueError(f"r must be in [0, 1], got {r!r}")
    # Single-class targets: keep everything of the requested class.
    if r == 0:
        return x0, 0
    if r == 1:
        return 0, x1
    total = x0 + x1
    if total == 0:
        # Nothing available; also avoids a division by zero below.
        return 0, 0
    actual_ratio = x1 / total
    if r <= actual_ratio:
        # Class 0 is the limiting class: keep all of it and scale class 1.
        y0 = x0
        y1 = y0 * r / (1 - r)
    else:
        # Class 1 is the limiting class: keep all of it and scale class 0.
        y1 = x1
        y0 = y1 * (1 - r) / r
    return y0, y1
def _to_count(value, limit):
    """Convert a non-negative amount to a whole item count no larger than ``limit``.

    Plain ``int()`` truncation is not enough: float round-off can turn an
    exact integer into something like ``0.9999999999999998`` (for instance
    ``9 * (1 - 0.9) / 0.9``), which would truncate to 0. Values within a
    relative tolerance of 1e-9 of an integer are therefore snapped to it;
    all other values are rounded down.
    """
    nearest = round(value)
    if abs(value - nearest) <= 1e-9 * max(1.0, abs(value)):
        count = nearest
    else:
        count = int(value)  # truncation == floor for non-negative values
    # int() also normalises numpy scalars; min() guards against round-off above limit.
    return min(int(count), limit)


def r_sample(set0, set1, r):
    """Return subset sizes for two classes with a given fraction of class-1 items.

    Given two collections of items corresponding to two classes, return how
    many items to keep from each, ``(n0, n1)``, so that
    ``n1 / (n0 + n1)`` matches ``r`` (up to rounding down to whole items)
    while ``n0 + n1`` is as large as possible.

    Parameters
    ----------
    set0, set1 : sized collections
        Items of class 0 and class 1; only their lengths are used.
    r : float
        Target fraction of class-1 items, expected in ``[0, 1]``.

    Returns
    -------
    (int, int)
        ``(n0, n1)`` with ``n0 <= len(set0)`` and ``n1 <= len(set1)``.
    """
    x0, x1 = len(set0), len(set1)
    y0, y1 = r_cut(x0, x1, r)
    return _to_count(y0, x0), _to_count(y1, x1)
@pytest.mark.parametrize(
    ("test_input", "expected"),
    [
        ((10, 5, 0), (10, 0)),  # degenerate target r == 0
        ((10, 5, 1), (0, 5)),  # degenerate target r == 1
        ((10, 5, 0.5), (5, 5)),  # more class 0 than class 1, r above R
        ((5, 10, 0.5), (5, 5)),  # more class 1 than class 0, r below R
        ((9, 10, 0.1), (9, 1)),  # more class 1 than class 0, r below R
        ((10, 9, 0.1), (10, 10 / 9)),  # more class 0 than class 1, r below R
        ((5, 10, 0.9), (10 / 9, 10)),  # more class 1 than class 0, r above R
    ],
)
def test_r_cut(test_input, expected):
    """Check r_cut against hand-computed (y0, y1) pairs, up to float tolerance.

    R denotes the available class-1 fraction x1 / (x0 + x1).
    """
    actual = r_cut(*test_input)
    np.testing.assert_allclose(actual, expected)
@pytest.mark.parametrize(
    ("test_input", "expected"),
    [
        (([1] * 10, [1] * 5, 0), (10, 0)),
        (([1] * 10, [1] * 5, 1), (0, 5)),
        (([1] * 10, [1] * 5, 0.5), (5, 5)),
        (([1] * 5, [1] * 10, 0.5), (5, 5)),
        (([1] * 9, [1] * 10, 0.1), (9, 1)),
        (([1] * 10, [1] * 9, 0.1), (10, 1)),
        (([1] * 5, [1] * 10, 0.9), (1, 10)),
    ],
)
def test_r_sample(test_input, expected):
    """Check that r_sample returns the expected pair of whole-item subset sizes."""
    set0, set1, ratio = test_input
    sizes = r_sample(set0, set1, ratio)
    assert sizes == expected
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment