Skip to content

Instantly share code, notes, and snippets.

@odanado
Last active October 18, 2022 14:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save odanado/982884668b2c4b39d63ca820399ba24a to your computer and use it in GitHub Desktop.
Save odanado/982884668b2c4b39d63ca820399ba24a to your computer and use it in GitHub Desktop.
import jax
import jax.numpy as jnp
@jax.jit
def get_under_sample_index_one(y: jnp.ndarray):
"""
y は 0, 1 のみを取る
先頭から 0, 1 を選んでいくため y は事前にシャッフルされている必要がある
"""
labels = jnp.array([0, 1])
# 0, 1 の数を数えて少ない方を取得
cap = jnp.array([(y == 0).sum(), (y == 1).sum()]).min()
# 累積和で先頭からの 0,1 の数を数える
y_cumsum = jnp.cumsum(y == labels[:, None], axis=1)
# 0, 1 の数が cap 以下のインデックスを取得
index = (y == labels[:, None]) & (y_cumsum <= cap)
# 列方向に or を取る
index = index.T.sum(axis=1) > 0
return index
get_under_sample_index = jax.vmap(get_under_sample_index_one)
y = jnp.array([
[0, 1, 0, -1, -1],
[1, 0, 0, 1, 1],
])
index = get_under_sample_index(y)
jnp.where(index, y, -1)
# => DeviceArray([[ 0, 1, -1, -1, -1], [ 1, 0, 0, 1, -1]], dtype=int32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment