Created
December 5, 2022 18:56
-
-
Save alecmerdler/cf57d805ee29740daf4777480cc40667 to your computer and use it in GitHub Desktop.
Training Lightning App
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! python -m pip install streamlit | |
import lightning as L, torch, torch.nn as nn | |
import requests, time, threading, os | |
from lightning.app.utilities.state import AppState | |
from lightning.app.frontend import StreamlitFrontend | |
class PyTorchComponent(L.LightningWork):
    """Lightning work that trains a small two-layer MLP on random data.

    The workload is synthetic: inputs and targets are freshly sampled random
    vectors on every step, so it only serves to keep the device busy.
    """

    def run(self):
        """Run one million SGD steps, printing the loss at every step."""
        # Prefer the first CUDA device when one is visible, else fall back to CPU.
        use_gpu = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_gpu else "cpu")

        net = nn.Sequential(
            nn.Linear(1000, 1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
        ).to(device)
        sgd = torch.optim.SGD(net.parameters(), lr=0.1)

        for step in range(1_000_000):
            net.zero_grad()
            # Samples are drawn on the CPU and then moved to the target device.
            inputs = torch.rand(1000).to(device)
            expected = torch.rand(1000).to(device)
            loss = nn.functional.mse_loss(net(inputs), expected)
            print(f'step: {step}. loss {loss}')
            loss.backward()
            sgd.step()
class MetricCollector(PyTorchComponent):
    """Training work that runs in parallel and records the pod it executes on."""

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the parent work, always requesting ``parallel=True``."""
        super().__init__(parallel=True, **kwargs)
        # Filled in by ``run``; stays ``None`` until the work has started.
        self.pod = None

    def run(self):
        """Store the ``LIGHTNING_POD_NAME`` environment value, then train."""
        # ``None`` when the variable is not set in this environment.
        self.pod = os.getenv('LIGHTNING_POD_NAME')
        super().run()
class DCGM_Visualizer(L.LightningFlow):
    """Flow that runs a child work and polls Prometheus for its DCGM metric.

    Once the child has published its pod name, a background thread queries
    Prometheus once per second for ``DCGM_FI_PROF_GR_ENGINE_ACTIVE`` filtered
    to the works' pods and stores the latest result in
    ``graphics_engine_activity``, which the Streamlit dashboard reads.
    """

    def __init__(self, child):
        """Store the child work and initialise the flow state.

        Args:
            child: The work to run; expected to expose a ``pod`` attribute
                once it has started.
        """
        super().__init__()
        self.graphics_engine_activity = []
        self.started_metrics = False
        self.child = child
        # Handle of the poller thread. ``run`` is invoked repeatedly, so the
        # handle must outlive a single call.
        # NOTE(review): underscore-prefixed on the assumption that Lightning
        # keeps private attributes out of the JSON-serialised flow state;
        # confirm against the installed Lightning version.
        self._metrics_thread = None

    def metric_collection(self):
        """Poll Prometheus until every work has succeeded.

        Each iteration sleeps one second, issues an instant query and, on
        success, replaces ``graphics_engine_activity`` with the ``result``
        list from the response. A failed request or malformed response is
        logged and skipped so that one bad poll does not end collection.
        """
        # see https://arxiv.org/pdf/2209.06018.pdf
        prom_url = 'http://grid-prometheus.grid-system.svc.cluster.local:9090' + '/api/v1/query'
        # Skip works without a pod name (missing attribute or still None);
        # joining a None would raise TypeError.
        pod_selector = "|".join(
            w.pod for w in self.works() if getattr(w, "pod", None)
        )
        query = {
            'query': f'DCGM_FI_PROF_GR_ENGINE_ACTIVE{{exported_pod=~"{pod_selector}"}}'
        }
        while not all(w.has_succeeded for w in self.works()):
            time.sleep(1)
            print("making request", f"{prom_url=}, {query=}")
            try:
                # A timeout keeps a stalled Prometheus from blocking this
                # thread (and the ``join`` in ``run``) indefinitely.
                response = requests.get(prom_url, params=query, timeout=10)
                response.raise_for_status()
                payload = response.json()  # parse once, reuse below
                result = payload['data']['result']
            except (requests.RequestException, ValueError, KeyError, TypeError) as err:
                print("metrics request failed:", err)
                continue
            print("raw:", payload)
            self.graphics_engine_activity = result
            print(self.graphics_engine_activity)

    def run(self):
        """Drive the child work, start the poller once, and exit when done."""
        print("run start")
        self.child.run()
        print("did super")
        if not self.started_metrics and getattr(self.child, "pod", None):
            print("starting metrics collection")
            # Daemon thread: the poller only stops when all works succeed, so
            # it must not keep the process alive if that never happens.
            self._metrics_thread = threading.Thread(
                target=self.metric_collection, daemon=True
            )
            self._metrics_thread.start()
            self.started_metrics = True
        if all(w.has_succeeded for w in self.works()):
            # The poller may never have been started (no pod name available),
            # so only join when a thread actually exists.
            if self._metrics_thread is not None:
                self._metrics_thread.join()
            self._exit("work has succeeded")

    def configure_layout(self):
        """Return the Streamlit frontend that renders ``dashboard_layout``."""
        return StreamlitFrontend(render_fn=dashboard_layout)
def dashboard_layout(state: AppState):
    """Render the Streamlit dashboard for the flow.

    Args:
        state: App state of the flow; its ``graphics_engine_activity``
            attribute is passed to the line chart.
    """
    # Imported lazily so streamlit is only needed where the UI is rendered.
    import streamlit as st

    st.write("GPU utilization")
    st.line_chart(state.graphics_engine_activity)
# The training work requests a 'gpu' cloud compute instance.
component = MetricCollector(
    cloud_compute=L.CloudCompute('gpu'),
)
# Application entry point: the visualizer flow wraps the training work.
app = L.LightningApp(
    DCGM_Visualizer(child=component),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment