-
-
Save Atlas7/e49a517615d944855773a5e873c9d612 to your computer and use it in GitHub Desktop.
numpy_fancy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# (Challenge 3): for each row of a 2-D array, pick the element closest to 0.5
# using argsort/argmin plus NumPy fancy (integer-array) indexing.
import numpy as np

# Seed the legacy global RNG with a fixed value so trial and error is
# reproducible. Switching to np.random.default_rng(0) would produce
# *different* numbers, and the arrays shown in the comments below would no
# longer match.
np.random.seed(0)

# Create a 10 x 3 NumPy array of uniform samples in [0, 1).
a = np.random.rand(10, 3)
# array([[0.5488135 , 0.71518937, 0.60276338],
#        [0.54488318, 0.4236548 , 0.64589411],
#        [0.43758721, 0.891773  , 0.96366276],
#        [0.38344152, 0.79172504, 0.52889492],
#        [0.56804456, 0.92559664, 0.07103606],
#        [0.0871293 , 0.0202184 , 0.83261985],
#        [0.77815675, 0.87001215, 0.97861834],
#        [0.79915856, 0.46147936, 0.78052918],
#        [0.11827443, 0.63992102, 0.14335329],
#        [0.94466892, 0.52184832, 0.41466194]])

# For each element, compute its absolute distance to 0.5
# (small means closer, big means further away).
b = np.abs(a - 0.5)
# array([[0.0488135 , 0.21518937, 0.10276338],
#        [0.04488318, 0.0763452 , 0.14589411],
#        [0.06241279, 0.391773  , 0.46366276],
#        [0.11655848, 0.29172504, 0.02889492],
#        [0.06804456, 0.42559664, 0.42896394],
#        [0.4128707 , 0.4797816 , 0.33261985],
#        [0.27815675, 0.37001215, 0.47861834],
#        [0.29915856, 0.03852064, 0.28052918],
#        [0.38172557, 0.13992102, 0.35664671],
#        [0.44466892, 0.02184832, 0.08533806]])

# For each row, the column indices that would sort `b` (closest to furthest
# from 0.5). kind="stable" keeps equal distances in their original column
# order, so on ties the lowest column index comes first -- the same
# tie-breaking np.argmin uses. Without it, ind[:, 0] is not guaranteed to
# match np.argmin(b, axis=-1) when two columns are equally close.
ind = np.argsort(b, axis=-1, kind="stable")
# array([[0, 2, 1],
#        [0, 1, 2],
#        [0, 1, 2],
#        [2, 0, 1],
#        [0, 1, 2],
#        [2, 0, 1],
#        [0, 1, 2],
#        [1, 2, 0],
#        [1, 2, 0],
#        [1, 2, 0]])

# For each row we only care about the column closest to 0.5, which is the
# first column of `ind`.
ind2 = ind[:, 0]
# array([0, 0, 0, 2, 0, 2, 0, 1, 1, 1])

# Note that we could have used argmin() and skipped the full sort in `ind`:
#   ind2 = np.argmin(b, axis=-1)
#   -> array([0, 0, 0, 2, 0, 2, 0, 1, 1, 1])
# Same result, because the sort above is stable.

# Fancy indexing: pair row i with column ind2[i] to extract, for each row,
# the element closest to 0.5. The result has one entry per row of `a`.
a2 = a[np.arange(a.shape[0]), ind2]
# array([0.5488135 , 0.54488318, 0.43758721, 0.52889492, 0.56804456,
#        0.83261985, 0.77815675, 0.46147936, 0.63992102, 0.52184832])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment