Created
July 19, 2023 18:58
-
-
Save louis030195/85fa9f93bc0adb0d9174a42b532b5539 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import IncrementalPCA as PCA
from websockets import connect

# Create initial figure.
# NOTE(review): print_messages builds its own figure for every batch, so this
# module-level figure appears to be unused within this file.
fig = go.Figure()

# Incremental PCA model reducing each sample to 3 components; it is updated
# batch by batch via partial_fit as messages arrive.
pca = PCA(n_components=3)

# Buffer of 'theta' payloads received since the last PCA update.
data = []

# Number of buffered samples that triggers a PCA update and a redraw.
# Must stay >= n_components, since each batch is fitted on its own.
batch_size = 10
async def print_messages():
    """Consume theta samples from the websocket and plot their PCA projection.

    Each incoming message is parsed as JSON and its ``['data']['theta']``
    entry is appended to the module-level ``data`` buffer. Once the buffer
    holds at least ``batch_size`` samples, the batch is used to update the
    module-level incremental ``pca`` model, projected onto the three fitted
    components, printed (last sample only) and drawn as a 3D scatter plot.

    The loop never returns on its own; it ends only when the connection or
    one of the processing steps raises.
    """
    # A single placeholder is reused for every batch so the chart is replaced
    # in place instead of a new chart being appended to the page each time.
    chart_slot = st.empty()

    async with connect("ws://localhost:8080") as ws:
        while True:
            msg = await ws.recv()
            # Extract theta
            data.append(json.loads(msg)['data']['theta'])
            if len(data) < batch_size:
                continue

            # One row per buffered sample, flattened to a feature vector.
            # NOTE(review): assumes every theta has the same number of
            # elements; otherwise the reshape fails - confirm with the sender.
            X = np.array(data).reshape(len(data), -1)

            # Update PCA model, then project the batch once and reuse the
            # result for both the console output and the plot.
            pca.partial_fit(X)
            projected = pca.transform(X)

            # Print components of the most recent sample
            print(projected[-1])

            # Reset the buffer in place (X already holds a copy of its contents)
            data.clear()

            # Create dataframe
            df = pd.DataFrame(projected, columns=['PC1', 'PC2', 'PC3'])

            # Fresh figure holding only the points of this batch
            batch_fig = go.Figure()
            batch_fig.add_trace(go.Scatter3d(
                x=df['PC1'],
                y=df['PC2'],
                z=df['PC3'],
                mode='markers'
            ))

            # Streamlit
            chart_slot.plotly_chart(batch_fig, use_container_width=True)
import asyncio

# Script entry point: run the websocket consumer on a new event loop.
# This call blocks for as long as print_messages keeps receiving messages.
asyncio.run(print_messages())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment