Skip to content

Instantly share code, notes, and snippets.

@levimcclenny
Last active January 12, 2021 20:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save levimcclenny/e87dd0979e339ea89a9885ec05fe7c10 to your computer and use it in GitHub Desktop.
Save levimcclenny/e87dd0979e339ea89a9885ec05fe7c10 to your computer and use it in GitHub Desktop.
multidimensional meshgrid, can be used for 3D meshgrids or beyond. Takes a list of linspaces as input, one for each desired dimension, then creates an array of all possible combinations of those points. Useful for plotting functions or, when reshaped, inputting all points on a grid into a neural network
# Higher-order meshgrid that takes a list of np.linspace-type objects and creates an output array
# Created to generate meshes of input points for PINN training
# Adapted from https://stackoverflow.com/questions/1827489/numpy-meshgrid-in-3d
# and modified to work in N dimensions, as well as fixed for python3
def multimesh(arrs):
    """Build an N-dimensional mesh grid from a sequence of 1-D coordinate arrays.

    Each input array is laid along its own axis (first array along axis 0,
    second along axis 1, ...) and expanded across every other axis, so all
    returned grids share the shape ``(len(arrs[0]), ..., len(arrs[-1]))``.

    Args:
        arrs: Sequence of 1-D array-likes (e.g. ``np.linspace`` outputs),
            one per desired dimension.

    Returns:
        A list with one array per input; ``result[i][idx]`` equals
        ``arrs[i][idx[i]]`` for every grid index ``idx``. An empty input
        yields an empty list.
    """
    lens = [len(arr) for arr in arrs]
    dim = len(lens)
    grid_shape = tuple(lens)
    ans = []
    for i, arr in enumerate(arrs):
        # Shape that is 1 everywhere except along this array's own axis.
        slc = [1] * dim
        slc[i] = lens[i]
        axis_view = np.asarray(arr).reshape(slc)
        # Broadcasting expands the singleton axes in one step; the copy makes
        # each grid an independent, writable array rather than a read-only view.
        ans.append(np.broadcast_to(axis_view, grid_shape).copy())
    return ans  # returns like np.meshgrid
# if desired, this flattens each output grid and stacks them row-wise (one row per dimension) for feeding into a tf/keras type neural network
def flatten_and_stack(mesh):
    """Flatten every grid in ``mesh`` and stack the results as rows.

    Args:
        mesh: Sequence of equally shaped arrays, such as the list returned
            by ``multimesh``.

    Returns:
        A float array of shape ``(len(mesh), number of grid points)`` whose
        row ``i`` is ``mesh[i].flatten()``. Transpose it (``.T``) to get one
        grid point per row.
    """
    # Leading entry of the stacked shape is len(mesh); the rest is the grid shape.
    grid_shape = np.shape(mesh)[1:]
    stacked = np.zeros((len(mesh), np.prod(grid_shape)))
    # Iterating a 2-D array yields row views, so writes land in `stacked`.
    for row, grid in zip(stacked, mesh):
        row[...] = grid.flatten()
    return stacked
# used in the following way
# Three 1-D coordinate axes with four points each -> a 4x4x4 grid.
x = np.arange(1, 5)
y = np.arange(5, 9)
z = np.arange(9, 13)
array_list = [x, y, z]
out = multimesh(array_list)
# out # np.meshgrid-like
# [array([[[1, 1, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1],
# [1, 1, 1, 1]],
# [[2, 2, 2, 2],
# [2, 2, 2, 2],
# [2, 2, 2, 2],
# [2, 2, 2, 2]],
# [[3, 3, 3, 3],
# [3, 3, 3, 3],
# [3, 3, 3, 3],
# [3, 3, 3, 3]],
# [[4, 4, 4, 4],
# [4, 4, 4, 4],
# [4, 4, 4, 4],
# [4, 4, 4, 4]]]), array([[[5, 5, 5, 5],
# [6, 6, 6, 6],
# [7, 7, 7, 7],
# [8, 8, 8, 8]],
# [[5, 5, 5, 5],
# [6, 6, 6, 6],
# [7, 7, 7, 7],
# [8, 8, 8, 8]],
# [[5, 5, 5, 5],
# [6, 6, 6, 6],
# [7, 7, 7, 7],
# [8, 8, 8, 8]],
# [[5, 5, 5, 5],
# [6, 6, 6, 6],
# [7, 7, 7, 7],
# [8, 8, 8, 8]]]), array([[[ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12]],
# [[ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12]],
# [[ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12]],
# [[ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12],
# [ 9, 10, 11, 12]]])]
# Flatten each coordinate grid in `out` and stack them into a single 2-D array.
nn_input = flatten_and_stack(out)
# nn_input.T  (flatten_and_stack returns shape (3, 64); shown transposed so each row is one grid point)
# array([[ 1., 5., 9.],
# [ 1., 5., 10.],
# [ 1., 5., 11.],
# [ 1., 5., 12.],
# [ 1., 6., 9.],
# [ 1., 6., 10.],
# [ 1., 6., 11.],
# [ 1., 6., 12.],
# [ 1., 7., 9.],
# [ 1., 7., 10.],
# [ 1., 7., 11.],
# [ 1., 7., 12.],
# [ 1., 8., 9.],
# [ 1., 8., 10.],
# [ 1., 8., 11.],
# [ 1., 8., 12.],
# [ 2., 5., 9.],
# [ 2., 5., 10.],
# [ 2., 5., 11.],
# [ 2., 5., 12.],
# [ 2., 6., 9.],
# [ 2., 6., 10.],
# [ 2., 6., 11.],
# [ 2., 6., 12.],
# [ 2., 7., 9.],
# [ 2., 7., 10.],
# [ 2., 7., 11.],
# [ 2., 7., 12.],
# [ 2., 8., 9.],
# [ 2., 8., 10.],
# [ 2., 8., 11.],
# [ 2., 8., 12.],
# [ 3., 5., 9.],
# [ 3., 5., 10.],
# [ 3., 5., 11.],
# [ 3., 5., 12.],
# [ 3., 6., 9.],
# [ 3., 6., 10.],
# [ 3., 6., 11.],
# [ 3., 6., 12.],
# [ 3., 7., 9.],
# [ 3., 7., 10.],
# [ 3., 7., 11.],
# [ 3., 7., 12.],
# [ 3., 8., 9.],
# [ 3., 8., 10.],
# [ 3., 8., 11.],
# [ 3., 8., 12.],
# [ 4., 5., 9.],
# [ 4., 5., 10.],
# [ 4., 5., 11.],
# [ 4., 5., 12.],
# [ 4., 6., 9.],
# [ 4., 6., 10.],
# [ 4., 6., 11.],
# [ 4., 6., 12.],
# [ 4., 7., 9.],
# [ 4., 7., 10.],
# [ 4., 7., 11.],
# [ 4., 7., 12.],
# [ 4., 8., 9.],
# [ 4., 8., 10.],
# [ 4., 8., 11.],
# [ 4., 8., 12.]])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment