Skip to content

Instantly share code, notes, and snippets.

@alecmerdler
Created December 5, 2022 18:56
Show Gist options
  • Save alecmerdler/cf57d805ee29740daf4777480cc40667 to your computer and use it in GitHub Desktop.
Save alecmerdler/cf57d805ee29740daf4777480cc40667 to your computer and use it in GitHub Desktop.
Training Lightning App
#! python -m pip install streamlit
import lightning as L, torch, torch.nn as nn
import requests, time, threading, os
from lightning.app.utilities.state import AppState
from lightning.app.frontend import StreamlitFrontend
class PyTorchComponent(L.LightningWork):
    """LightningWork that runs a synthetic PyTorch training loop.

    A two-layer MLP is fit to freshly drawn uniform-random targets for one
    million SGD steps, printing the loss at every step. Since the targets are
    random there is nothing to learn; presumably the loop exists to generate
    device load for the metrics dashboard.
    """

    def run(self):
        # Use the first GPU when CUDA is visible, otherwise fall back to CPU.
        use_gpu = torch.cuda.is_available()
        device = torch.device("cuda:0") if use_gpu else torch.device("cpu")

        width = 1000
        net = nn.Sequential(nn.Linear(width, width), nn.ReLU(), nn.Linear(width, width))
        net.to(device)
        sgd = torch.optim.SGD(net.parameters(), lr=0.1)

        for step in range(1_000_000):
            net.zero_grad()
            # Inputs and targets are drawn on the CPU and then copied to the device.
            inputs = torch.rand(width).to(device)
            targets = torch.rand(width).to(device)
            loss = nn.functional.mse_loss(net(inputs), targets)
            print(f"step: {step}. loss {loss}")
            loss.backward()
            sgd.step()
class MetricCollector(PyTorchComponent):
    """Training work that records the pod it runs in before training starts.

    The pod name is read from the ``LIGHTNING_POD_NAME`` environment variable.
    The parent flow uses it to scope its DCGM/Prometheus query to this work.
    """

    def __init__(self, **kwargs):
        # Always parallel, so the parent flow keeps looping (and can poll
        # metrics) while training runs; callers cannot override this.
        super().__init__(parallel=True, **kwargs)
        # Unknown until run() executes on the worker machine.
        self.pod = None

    def run(self):
        pod_name = os.environ.get('LIGHTNING_POD_NAME')
        self.pod = pod_name
        super().run()
class DCGM_Visualizer(L.LightningFlow):
    """Flow that runs a training work and exposes its GPU activity to a dashboard.

    Once the child work reports its pod name, a background thread polls
    Prometheus for ``DCGM_FI_PROF_GR_ENGINE_ACTIVE`` on the works' pods. It
    stores the raw ``data.result`` list in ``graphics_engine_activity`` for
    the Streamlit UI. The app exits once every work has succeeded.
    """

    def __init__(self, child):
        super().__init__()
        # Latest Prometheus instant-query result (list of series dicts); read by the UI.
        self.graphics_engine_activity = []
        # Ensures only one collector thread is started across repeated run() calls.
        self.started_metrics = False
        self.child = child

    def metric_collection(self, poll_interval=1.0, request_timeout=10.0):
        """Poll Prometheus for GPU engine activity until every work has succeeded.

        Args:
            poll_interval: Seconds to sleep between queries.
            request_timeout: Per-request timeout in seconds, so an unreachable
                Prometheus cannot block the collector indefinitely.
        """
        # see https://arxiv.org/pdf/2209.06018.pdf
        prom_url = 'http://grid-prometheus.grid-system.svc.cluster.local:9090' + '/api/v1/query'
        # Skip works whose pod is not known yet: joining a None would raise TypeError.
        pod_selector = "|".join(
            w.pod for w in self.works() if getattr(w, "pod", None))
        query = {
            'query': f"DCGM_FI_PROF_GR_ENGINE_ACTIVE{{exported_pod=~\"{pod_selector}\"}}"
        }
        while not all(w.has_succeeded for w in self.works()):
            time.sleep(poll_interval)
            print("making request", f"{prom_url=}, {query=}")
            try:
                response = requests.get(prom_url, params=query, timeout=request_timeout)
                response.raise_for_status()
                payload = response.json()
                print("raw:", payload)
                result = payload['data']['result']
            except (requests.RequestException, ValueError, KeyError, TypeError) as err:
                # A transient Prometheus failure must not kill the collector thread:
                # keep the last good result and retry on the next tick.
                print("metrics request failed:", err)
                continue
            self.graphics_engine_activity = result
            print(self.graphics_engine_activity)

    def run(self):
        print("run start")
        self.child.run()
        print("did super")
        if not self.started_metrics and getattr(self.child, "pod", None):
            print("starting metrics collection")
            # run() is re-entered on every loop iteration, so a local thread handle
            # would not exist when the works finish. The collector loop ends on its
            # own once all works succeed. daemon=True guarantees it never keeps the
            # process alive after the app exits, so no join is needed.
            threading.Thread(target=self.metric_collection, daemon=True).start()
            self.started_metrics = True
        if all(w.has_succeeded for w in self.works()):
            self._exit("work has succeeded")

    def configure_layout(self):
        return StreamlitFrontend(render_fn=dashboard_layout)
def prometheus_samples(result):
    """Flatten a Prometheus instant-query ``data.result`` list into numbers.

    Each entry is expected to look like
    ``{"metric": {...labels}, "value": [unix_time, "value"]}``. Series are named
    by their ``exported_pod`` label, suffixed with ``/gpu<N>`` when a ``gpu``
    label is present. Entries without a well-formed sample are skipped.

    Returns:
        dict mapping series name to ``(timestamp_seconds, value)`` floats;
        empty for ``None`` or an empty list.
    """
    samples = {}
    for series in result or []:
        labels = series.get("metric") or {}
        name = labels.get("exported_pod", "unknown")
        if "gpu" in labels:
            name = f"{name}/gpu{labels['gpu']}"
        try:
            timestamp, raw_value = series["value"]
            samples[name] = (float(timestamp), float(raw_value))
        except (KeyError, TypeError, ValueError):
            continue
    return samples


def dashboard_layout(state: "AppState"):
    """Render a line chart of GPU graphics-engine activity from the flow state.

    ``state.graphics_engine_activity`` holds the latest Prometheus instant-query
    result, i.e. one sample per series. Earlier samples are therefore kept in
    Streamlit's per-session state so that a time series can be drawn.
    """
    import pandas as pd
    import streamlit as st

    st.write("GPU utilization")
    samples = prometheus_samples(state.graphics_engine_activity)
    # Keyed by sample timestamp, so a rerun with unchanged data adds nothing.
    history = st.session_state.setdefault("graphics_engine_activity_history", {})
    for name, (timestamp, value) in samples.items():
        history.setdefault(timestamp, {})[name] = value
    if not history:
        st.write("Waiting for metrics...")
        return
    frame = pd.DataFrame.from_dict(history, orient="index").sort_index()
    frame.index = pd.to_datetime(frame.index, unit="s")
    st.line_chart(frame)
# Training work on a GPU machine; DCGM_Visualizer runs it, polls its pod's
# DCGM metrics and serves the Streamlit dashboard.
component = MetricCollector(cloud_compute=L.CloudCompute(name="gpu"))
app = L.LightningApp(root=DCGM_Visualizer(child=component))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment