Skip to content

Instantly share code, notes, and snippets.

@Atlas7
Last active May 23, 2017 10:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Atlas7/e49a517615d944855773a5e873c9d612 to your computer and use it in GitHub Desktop.
Save Atlas7/e49a517615d944855773a5e873c9d612 to your computer and use it in GitHub Desktop.
numpy_fancy.py
# (Challenge 3)
import numpy as np
# Seed the legacy global RNG with a fixed value so every run produces the
# same "random" array (makes trial and error reproducible).
np.random.seed(0)

# The value each row's selected element should be closest to.
TARGET = 0.5

# Create a 10 x 3 NumPy array of uniform samples in [0, 1).
a = np.random.rand(10, 3)
# array([[ 0.5488135 , 0.71518937, 0.60276338],
# [ 0.54488318, 0.4236548 , 0.64589411],
# [ 0.43758721, 0.891773 , 0.96366276],
# [ 0.38344152, 0.79172504, 0.52889492],
# [ 0.56804456, 0.92559664, 0.07103606],
# [ 0.0871293 , 0.0202184 , 0.83261985],
# [ 0.77815675, 0.87001215, 0.97861834],
# [ 0.79915856, 0.46147936, 0.78052918],
# [ 0.11827443, 0.63992102, 0.14335329],
# [ 0.94466892, 0.52184832, 0.41466194]])

# For each element, compute its absolute distance to TARGET
# (small means closer, big means further away).
b = np.abs(a - TARGET)
# array([[ 0.0488135 , 0.21518937, 0.10276338],
# [ 0.04488318, 0.0763452 , 0.14589411],
# [ 0.06241279, 0.391773 , 0.46366276],
# [ 0.11655848, 0.29172504, 0.02889492],
# [ 0.06804456, 0.42559664, 0.42896394],
# [ 0.4128707 , 0.4797816 , 0.33261985],
# [ 0.27815675, 0.37001215, 0.47861834],
# [ 0.29915856, 0.03852064, 0.28052918],
# [ 0.38172557, 0.13992102, 0.35664671],
# [ 0.44466892, 0.02184832, 0.08533806]])

# For each row, get the column indices that would sort that row's distances
# (closest to furthest from TARGET); argsort works along the last axis by
# default. kind="mergesort" is a stable sort, so equal distances keep their
# original column order. This keeps the argsort route consistent with
# argmin (which returns the first occurrence) if ties ever occur; the
# default quicksort gives no such guarantee.
ind = np.argsort(b, kind="mergesort")
# array([[0, 2, 1],
# [0, 1, 2],
# [0, 1, 2],
# [2, 0, 1],
# [0, 1, 2],
# [2, 0, 1],
# [0, 1, 2],
# [1, 2, 0],
# [1, 2, 0],
# [1, 2, 0]])

# For each row we only need the column closest to TARGET: the first index.
ind2 = ind[:, 0]
# array([0, 0, 0, 2, 0, 2, 0, 1, 1, 1])

# Note: we could have used argmin() and skipped the `ind` step entirely;
# it needs only a single linear scan per row instead of a full sort:
#   ind2 = np.argmin(b, axis=-1)
#   -> array([0, 0, 0, 2, 0, 2, 0, 1, 1, 1])
# Same result.

# Fancy indexing: pair each row index (0..n-1) with that row's chosen column
# index to extract, per row, the element closest to TARGET.
a2 = a[np.arange(a.shape[0]), ind2]
# array([ 0.5488135 , 0.54488318, 0.43758721, 0.52889492, 0.56804456,
# 0.83261985, 0.77815675, 0.46147936, 0.63992102, 0.52184832])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment