Created May 8, 2023 05:38
-
-
Save wassname/6829b2dc4f5b5bf9a58f3ab32bf8ac27 to your computer and use it in GitHub Desktop.
lightning_utils.py: utilities to read training metrics from PyTorch Lightning's CSVLogger
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning as pl | |
from pytorch_lightning.loggers import CSVLogger, WandbLogger | |
from pathlib import Path | |
import pandas as pd | |
def read_metrics_csv(metrics_file_path):
    """Load a Lightning ``metrics.csv`` file and average its rows per epoch.

    Rows that share an epoch are collapsed into one row. Examples are rows
    from separate train and validation log calls, where each row may fill
    only some columns.

    Args:
        metrics_file_path: Path (or any source ``pandas.read_csv`` accepts)
            of the metrics CSV. It must contain an ``epoch`` column.

    Returns:
        A ``DataFrame`` indexed by ``epoch`` (sorted ascending). Its columns
        are the per-epoch mean of every numeric column. NaNs are skipped by
        the mean, so a metric logged on only some rows of an epoch still gets
        a value. Non-numeric columns are dropped.
    """
    df_hist = pd.read_csv(metrics_file_path)
    # Rows with a missing epoch inherit the nearest epoch above them. Rows
    # before the first non-missing epoch remain NaN and are dropped by groupby.
    df_hist["epoch"] = df_hist["epoch"].ffill()
    # numeric_only=True: since pandas 2.0, mean() raises TypeError on
    # non-numeric columns instead of silently dropping them.
    df_histe = df_hist.groupby("epoch").mean(numeric_only=True)
    return df_histe
def read_hist(trainer: pl.Trainer):
    """Return the per-epoch metrics history recorded by the trainer's CSVLogger.

    This is best-effort: failures are printed and ``None`` is returned
    instead of raising.

    Args:
        trainer: The trainer whose ``loggers`` are searched for a
            ``CSVLogger``. The first one found is used.

    Returns:
        The per-epoch ``DataFrame`` from ``read_metrics_csv``, or ``None``
        if the trainer has no ``CSVLogger`` or its metrics file could not be
        read.
    """
    csv_logger = next(
        (lg for lg in trainer.loggers if isinstance(lg, CSVLogger)), None
    )
    if csv_logger is None:
        # Report this case explicitly rather than via an opaque IndexError.
        print("read_hist: trainer has no CSVLogger; no metrics to read")
        return None
    try:
        metrics_file_path = Path(csv_logger.experiment.metrics_file_path)
        return read_metrics_csv(metrics_file_path)
    except Exception as e:
        # Deliberately broad: this is a convenience reader, so any failure
        # (missing/empty file, unexpected columns) is reported, not raised.
        print(f"read_hist: could not read metrics: {e!r}")
        return None
import logging | |
class PLSilent:
    """Context manager that temporarily raises Lightning's loggers to ERROR.

    Usage::

        with PLSilent():
            ...  # noisy Lightning calls

    On exit every affected logger's previous level is restored, including
    when the body raises. Exceptions are not suppressed.

    Args:
        names: Logger names to silence. Defaults to ``DEFAULT_NAMES``, which
            covers the standalone ``pytorch_lightning`` / ``lightning_fabric``
            packages and the unified ``lightning`` package.
        level: Level applied to those loggers while inside the ``with`` block.
    """

    # Setting only "lightning" is not enough. "pytorch_lightning" and
    # "lightning_fabric" are not its children in the logging hierarchy, and
    # child loggers that set their own level (as Lightning's package loggers
    # may) do not inherit it. So each name is set explicitly.
    DEFAULT_NAMES = (
        "lightning",
        "lightning.pytorch",
        "lightning.fabric",
        "pytorch_lightning",
        "lightning_fabric",
    )

    def __init__(self, names=None, level=logging.ERROR):
        names = self.DEFAULT_NAMES if names is None else tuple(names)
        # dict.fromkeys de-duplicates while keeping order, so each logger is
        # saved and restored exactly once.
        self.loggers = [logging.getLogger(n) for n in dict.fromkeys(names)]
        self.target_level = level
        # Kept for backward compatibility: the original exposed the
        # "lightning" logger as ``self.l``.
        self.l = logging.getLogger("lightning")
        self._saved = []

    def __enter__(self):
        self._saved = [(lg, lg.level) for lg in self.loggers]
        # Backward compatibility: ``self.level`` is the pre-entry level of ``self.l``.
        self.level = self.l.level
        for lg in self.loggers:
            lg.setLevel(self.target_level)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for lg, saved_level in self._saved:
            lg.setLevel(saved_level)
        # Implicitly returns None, so an exception from the body propagates.
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.