Skip to content

Instantly share code, notes, and snippets.

@usmcamp0811
Created September 29, 2019 19:09
Show Gist options
  • Save usmcamp0811/0e367e88bacd5bef109f59c9ff0888ad to your computer and use it in GitHub Desktop.
Save usmcamp0811/0e367e88bacd5bef109f59c9ff0888ad to your computer and use it in GitHub Desktop.
Something to visualize the Gaussian components of a GMM (Gaussian mixture model)
import functools
import operator

import holoviews as hv
import numpy as np
import pandas as pd
from holoviews import opts
from sklearn.preprocessing import MinMaxScaler

# NOTE(review): the DataFrame `.hvplot` accessor used below is registered by
# `import hvplot.pandas`; add that import if it is not done elsewhere.
hv.extension("bokeh")

# Fit the project's GMM discretizer with 20 components on column "X".
# NOTE(review): DiscretizerGMM and X are not defined in this file; they are
# expected to come from the surrounding notebook/project.
test = DiscretizerGMM(
    data=X["X"].values,
    merge_right_of_elbow=False,
    n_components=20,
    name="X",
)
gmm_df = test.gmm_df(sort="weights")

# Min-max rescale the component weights; the scaler wants a single 2-D column.
scale = MinMaxScaler()
gmm_df["weights"] = scale.fit_transform(
    np.reshape(gmm_df["weights"].values, (-1, 1))
)
def gmm_mean_plot(gmm_df, n_std=1, width_scale=1000):
    """Overlay each GMM component as a horizontal segment around its mean.

    Every row of ``gmm_df`` (processed in ascending order of ``"means"``) is
    drawn as a line at ``y = 0`` running from ``mean - n_std * sqrt(cov)`` to
    ``mean + n_std * sqrt(cov)``. The line width is ``weight * width_scale``.
    All component means are then marked with red dots.

    Parameters
    ----------
    gmm_df : pandas.DataFrame
        One row per component with columns ``"means"``, ``"covariances"``
        and ``"weights"``. NOTE(review): assumes scalar (1-D) variances in
        ``"covariances"`` -- confirm for multivariate models.
    n_std : float, default 1
        Half-width of each segment, in standard deviations.
    width_scale : float, default 1000
        Multiplier applied to ``"weights"`` to obtain the line width.

    Returns
    -------
    holoviews overlay
        The segments and mean markers combined with ``*``, legend hidden.

    ``gmm_df`` itself is not modified; the helper ``"y"`` column is added to
    a copy.
    """
    plots = []
    for _, row in gmm_df.sort_values(by="means").iterrows():
        mean = row["means"]
        std = np.sqrt(row["covariances"]) * n_std
        weight = row["weights"] * width_scale
        # Two-point frame describing the segment [mean - std, mean + std] at y=0.
        segment = pd.DataFrame({"y": [0, 0], "mean": [mean - std, mean + std]})
        plots.append(
            segment.hvplot(
                x="mean",
                line_width=weight,
                # NOTE(review): int() discards the 2-decimal rounding; labels
                # are hidden below anyway, so the original text is kept.
                label=str(int(round(mean, 2))),
                alpha=0.5,
            )
        )

    # Mark every component mean. assign() returns a copy, so the caller's
    # frame no longer gains a "y" column as a side effect.
    means = gmm_df.assign(y=0)
    plots.append(means.hvplot.scatter(x="means", y="y", color="red", size=10))

    # Left-to-right ``a * b * c`` (HoloViews overlay) without eval().
    overlay = functools.reduce(operator.mul, plots)
    overlay.opts(show_legend=False, height=500, legend_position="bottom")
    return overlay
# Render the overlay of the rescaled components at +/- 1 standard deviation.
gmm_mean_plot(gmm_df, n_std=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment