Created
June 9, 2022 20:07
-
-
Save jlazovskis/c480cc66d2cd2104eaf5dd200c563ccc to your computer and use it in GitHub Desktop.
Group data points into similar-sized groups and use them as a grid for a heatmap.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
################################
##
## Code
##
################################
# Load packages
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
from scipy.spatial import KDTree | |
def return_partition(tree, verbose=True, save_split=False):
    """Split the points of a kd-tree into the buckets held by its leaves.

    Parameters
    ----------
    tree : scipy.spatial.KDTree
        A built tree. Its ``tree`` root node is walked and its ``indices``
        array is sliced into one run per leaf.
    verbose : bool
        If True, print the number of leaves and the range of their sizes.
    save_split : bool
        If True, also return each leaf's path from the root.

    Returns
    -------
    list of numpy.ndarray, or tuple (list, list of str)
        One array of point indices per leaf, in less-before-greater
        depth-first order. With ``save_split=True``, returns
        ``(split, split_order)``; ``split_order[k]`` spells the route to
        leaf ``k`` using 'l' (less child) and 'g' (greater child).

    Side effects
    ------------
    Rebinds the module globals ``partition_size`` (leaf sizes) and
    ``split_order`` (leaf paths). Earlier versions exposed them this way,
    so they are kept for backward compatibility.
    """
    global partition_size
    global split_order
    partition_size, split_order = _collect_leaves(tree.tree)
    if verbose:
        print('Partitioned into {0} bins, of {1} different sizes ({2} to {3})'.format(
            len(partition_size),
            len(np.unique(np.array(partition_size))),
            min(partition_size),
            max(partition_size)
        ))
    # Leaf k owns tree.indices[offsets[k]:offsets[k + 1]]. This is the same
    # contiguous-run assumption the original cumulative slicing relied on.
    # A single cumsum replaces re-summing every prefix (quadratic in #leaves).
    offsets = np.concatenate(([0], np.cumsum(partition_size)))
    split = [tree.indices[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    if save_split:
        return (split, split_order)
    return split


def _collect_leaves(root):
    """Return ``(sizes, paths)`` for every leaf under ``root``, less child first.

    ``sizes[k]`` is the leaf's ``children`` count. ``paths[k]`` is its route
    from ``root`` written with 'l' (less) and 'g' (greater).
    """
    sizes = []
    paths = []
    # Explicit stack instead of recursion. "greater" is pushed before "less"
    # so "less" is popped first, giving the same leaf order as a recursive
    # less-then-greater walk.
    stack = [(root, '')]
    while stack:
        node, path = stack.pop()
        if hasattr(node, "greater"):
            stack.append((node.greater, path + 'g'))
            stack.append((node.less, path + 'l'))
        else:
            sizes.append(node.children)
            paths.append(path)
    return sizes, paths
def leaf_size(tree, order_string, sizes=None, orders=None):
    """Record the size and path of every leaf below a kd-tree node.

    Walks the subtree rooted at ``tree`` depth-first, visiting ``less``
    before ``greater``. A node counts as inner when it has a ``greater``
    attribute. For each leaf it appends ``leaf.children`` to ``sizes`` and
    the leaf's path to ``orders``. The path is ``order_string`` extended by
    'l' or 'g' per step down.

    Parameters
    ----------
    tree : kd-tree node
        Inner nodes expose ``less``/``greater``; leaves expose ``children``.
    order_string : str
        Path of ``tree`` itself; pass '' for the root.
    sizes, orders : list, optional
        Accumulators to append to. When omitted, the module-level lists
        ``partition_size`` and ``split_order`` are used; those must already
        exist (``return_partition`` creates them), otherwise NameError.
    """
    if sizes is None:
        sizes = partition_size
    if orders is None:
        orders = split_order
    if hasattr(tree, "greater"):
        leaf_size(tree.less, order_string + 'l', sizes, orders)
        leaf_size(tree.greater, order_string + 'g', sizes, orders)
    else:
        sizes.append(tree.children)
        orders.append(order_string)
def draw_partition(tree, axis, lastxmin=0, lastxmax=1, lastymin=0, lastymax=1):
    """Draw the splitting lines of a 2-D kd-tree node and its descendants.

    Parameters
    ----------
    tree : kd-tree node
        Inner nodes expose ``split_dim``, ``split``, ``less`` and
        ``greater``. A node without ``greater`` is a leaf and draws nothing.
    axis : matplotlib.axes.Axes
        Target axes.
    lastxmin, lastxmax, lastymin, lastymax : float
        Bounds of the cell owned by ``tree``. The defaults span the unit
        square, which suits coordinates normalised to [0, 1].

    A split on dimension 0 is drawn as a black vertical line at
    ``x = split``; any other dimension is treated as y and drawn as a
    horizontal line. Each line spans only the current cell. The greater
    side is then processed before the less side.
    """
    if not hasattr(tree, "greater"):
        return
    cut = tree.split
    if tree.split_dim == 0:
        axis.vlines(cut, lastymin, lastymax, colors=(0, 0, 0), linewidth=1)
        greater_cell = (cut, lastxmax, lastymin, lastymax)
        less_cell = (lastxmin, cut, lastymin, lastymax)
    else:
        axis.hlines(cut, lastxmin, lastxmax, colors=(0, 0, 0), linewidth=1)
        greater_cell = (lastxmin, lastxmax, cut, lastymax)
        less_cell = (lastxmin, lastxmax, lastymin, cut)
    draw_partition(tree.greater, axis, *greater_cell)
    draw_partition(tree.less, axis, *less_cell)
def draw_boxes(x, y, partition, axis, colors=None, draw_points=False):
    """Fill the bounding box of each bin of ``partition`` on ``axis``.

    Parameters
    ----------
    x, y : sequence of float
        Point coordinates, indexed by position (the same positions the
        kd-tree was built on).
    partition : iterable of sequence of int
        Bins of point positions, e.g. the list returned by ``return_partition``.
    axis : matplotlib.axes.Axes
        Target axes. One opaque, borderless ``Rectangle`` is added per
        non-empty bin, spanning the min/max of that bin's points.
    colors : sequence, optional
        Colour of bin ``i`` is ``colors[i]``. When None or empty (the
        default), bin ``i`` uses the colour-cycle entry ``'C{i % 10}'``.
        An explicit ``[]`` still works as before.
    draw_points : bool
        If True, also scatter each bin's points in its colour at alpha 0.5.
    """
    # Positional arrays: fancy indexing replaces per-point Python loops and
    # avoids label-based lookups if e.g. a pandas Series is passed.
    xs = np.asarray(x)
    ys = np.asarray(y)
    # Decided once. The old `colors == []` test compares element-wise (or
    # raises) when colors is a NumPy array instead of testing emptiness.
    use_cycle = colors is None or len(colors) == 0
    for i, p in enumerate(partition):
        idx = np.asarray(p, dtype=int)
        # An empty bin has no bounding box (np.min would raise); skip it but
        # keep its index so the colour assignment of later bins is unchanged.
        if idx.size == 0:
            continue
        color = 'C' + str(i % 10) if use_cycle else colors[i]
        points_x = xs[idx]
        points_y = ys[idx]
        minx = np.min(points_x)
        miny = np.min(points_y)
        width = np.max(points_x) - minx
        height = np.max(points_y) - miny
        rect = patches.Rectangle((minx, miny), width, height, color=color, alpha=1, linewidth=0)
        axis.add_patch(rect)
        if draw_points:
            axis.scatter(points_x, points_y, color=color, alpha=.5, linewidth=0)
################################
##
## Example
##
################################
def normalize_to_unit_interval(values):
    """Affinely rescale ``values`` onto [0, 1] as a float array.

    The minimum maps to 0 and the maximum to 1. A constant input, where the
    original code divided by zero, maps to all zeros.
    """
    values = np.asarray(values, dtype=float)
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        return np.zeros_like(values)
    return (values - lo) / span


if __name__ == "__main__":
    # Input data. The original left these as `..` placeholders (invalid
    # syntax). This 20x20 integer grid is demo data only; replace it with
    # your own coordinates.
    grid_x, grid_y = np.meshgrid(np.arange(20), np.arange(20))
    x_original = grid_x.ravel()
    y_original = grid_y.ravel()

    # Add uniform noise in [-param/2, param/2), in data units. Presumably
    # this separates coincident coordinates so the kd-tree can split them.
    # TODO: tune for your data.
    noise_x_param = 0.5
    noise_y_param = 0.5
    x = np.asarray(x_original, dtype=float)
    y = np.asarray(y_original, dtype=float)
    x = x + (np.random.rand(len(x)) - .5) * noise_x_param
    y = y + (np.random.rand(len(y)) - .5) * noise_y_param

    # Normalize each coordinate to [0, 1].
    x = normalize_to_unit_interval(x)
    y = normalize_to_unit_interval(y)

    # Create tree and partition.
    tree = KDTree(np.column_stack((x, y)), leafsize=50)
    partition = return_partition(tree, verbose=True, save_split=False)

    # Draw and show. Bins are chosen on the jittered, normalised copy but
    # drawn using the original coordinates of their member points.
    fig, ax = plt.subplots()
    draw_boxes(x_original, y_original, partition, ax, draw_points=False)
    # add_patch does not by itself re-fit the view limits, so fit the view
    # to the patches' data limits before showing.
    ax.autoscale_view()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment