Skip to content

Instantly share code, notes, and snippets.

@louis030195
Created July 19, 2023 18:58
Show Gist options
  • Save louis030195/85fa9f93bc0adb0d9174a42b532b5539 to your computer and use it in GitHub Desktop.
Save louis030195/85fa9f93bc0adb0d9174a42b532b5539 to your computer and use it in GitHub Desktop.
"""Live 3D PCA view of websocket samples, rendered with Streamlit.

Messages are read from a websocket, their ``theta`` payloads are batched,
fed to an incremental PCA, and each batch's projection is drawn as a 3D
scatter plot.
"""
import json

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import IncrementalPCA as PCA
from websockets import connect

# NOTE(review): print_messages builds its own local figure for each batch, so
# this module-level figure appears unused; kept only for backward compatibility.
fig = go.Figure()

# Incremental PCA that projects each sample onto 3 principal components (the
# 3D scatter plot consumes exactly three). ``PCA`` is an alias for
# sklearn's IncrementalPCA, which supports batch-wise ``partial_fit``.
pca = PCA(n_components=3)

# Buffer of received samples not yet fed to the PCA.
data = []

# Number of samples accumulated before each partial_fit + replot. It must be
# at least n_components (3): IncrementalPCA.partial_fit rejects batches with
# fewer samples than components.
batch_size = 10
async def print_messages(uri="ws://localhost:8080"):
    """Stream samples from a websocket into the incremental PCA and plot them.

    Each incoming message is parsed as JSON, and its ``['data']['theta']``
    value is appended to the module-level ``data`` buffer. Once
    ``batch_size`` samples have accumulated, the batch is processed:

    - it is flattened to one row per sample;
    - ``pca`` is updated with ``partial_fit``;
    - the batch is projected onto the 3 principal components;
    - the projection is drawn as a 3D scatter chart that replaces the
      previous batch's chart.

    Args:
        uri: Websocket endpoint to read from. Defaults to the previously
            hard-coded ``ws://localhost:8080``.

    The loop never exits normally. Any exception propagates to the caller,
    for example the websocket closing or a message lacking
    ``data``/``theta``.
    """
    global data
    # One reusable placeholder slot: each batch replaces the previous chart
    # instead of appending yet another chart to the Streamlit page.
    chart_slot = st.empty()
    async with connect(uri) as ws:
        while True:
            msg = await ws.recv()
            # NOTE(review): assumes every message is JSON shaped like
            # {"data": {"theta": <numeric array of fixed size>}} — confirm.
            data.append(json.loads(msg)['data']['theta'])
            if len(data) < batch_size:
                continue

            # One row per sample; multi-dimensional theta values are flattened.
            X = np.array(data).reshape(len(data), -1)
            # Update the PCA model with this batch.
            pca.partial_fit(X)
            # Project once and reuse for both the console output and the plot.
            projected = pca.transform(X)
            # Log the projection of the most recent sample in the batch.
            print(projected[-1])
            # Start a fresh buffer for the next batch.
            data = []

            df = pd.DataFrame(projected, columns=['PC1', 'PC2', 'PC3'])
            fig = go.Figure()
            fig.add_trace(go.Scatter3d(
                x=df['PC1'],
                y=df['PC2'],
                z=df['PC3'],
                mode='markers',
            ))
            chart_slot.plotly_chart(fig, use_container_width=True)
import asyncio

# Start the endless websocket consumer only when this file runs as a script.
# `streamlit run` executes the script with __name__ == "__main__", so the app
# still starts there. A plain import no longer blocks on the connection.
if __name__ == "__main__":
    asyncio.run(print_messages())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment