Skip to content

Instantly share code, notes, and snippets.

@michaelriedl
Created September 12, 2023 02:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save michaelriedl/5984af8bb34872b53a430ce2e3e61179 to your computer and use it in GitHub Desktop.
Save michaelriedl/5984af8bb34872b53a430ce2e3e61179 to your computer and use it in GitHub Desktop.
Testing Lightning Fabric logging behavior
import os
import shutil
import torch
import lightning as L
from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing import event_accumulator
# Root folder for all TensorBoard runs written by this experiment.
LOG_DIR = "./logs"


def main():
    """Compare three ways of logging a per-rank value to TensorBoard with Fabric.

    Four CPU processes each hold the scalar ``10 * global_rank`` and log it
    under the tag ``"data"`` at step 0 into three run folders below
    ``LOG_DIR``:

    * ``no_reduce``        -- every rank logs its own, un-gathered value;
    * ``reduce``           -- every rank logs the mean of the all-gathered values;
    * ``reduce_rank_zero`` -- only rank 0 logs the mean of the gathered values.

    Rank 0 then reads each run folder back and prints how many ``"data"``
    entries it holds and their values. ``LOG_DIR`` is deleted and recreated
    on every run.
    """
    log_names = ("no_reduce", "reduce", "reduce_rank_zero")

    fabric = L.Fabric(accelerator="cpu", devices=4)
    fabric.launch()

    # Reset the log folder on rank 0 only, and only after launch. Depending on
    # the launcher Fabric picks, code before `fabric.launch()` may run in every
    # process; resetting there lets the ranks race on rmtree/mkdir and delete
    # event files that another rank has already created.
    if fabric.global_rank == 0:
        try:
            shutil.rmtree(LOG_DIR)
        except FileNotFoundError:
            pass
        os.mkdir(LOG_DIR)
    # No rank may create a writer until the fresh folder exists.
    fabric.barrier()

    # Unique value per rank (0., 10., 20., 30.); dtype=float is torch.float64.
    data = torch.tensor(10 * fabric.global_rank, dtype=float)

    # One writer per strategy; every rank opens all three.
    writers = {
        name: SummaryWriter(os.path.join(LOG_DIR, name)) for name in log_names
    }
    try:
        # all_gather is a collective: every rank must call it.
        result = fabric.all_gather(data)
        # Log without gather and reduction.
        writers["no_reduce"].add_scalar("data", data.mean(), 0)
        # Log with gather and reduction on every rank.
        writers["reduce"].add_scalar("data", result.mean(), 0)
        # Log with gather and reduction on rank zero only (same step 0).
        if fabric.global_rank == 0:
            writers["reduce_rank_zero"].add_scalar("data", result.mean(), 0)
    finally:
        # SummaryWriter buffers events and only flushes periodically; close()
        # flushes them to disk so rank 0 reads complete files below.
        for writer in writers.values():
            writer.close()

    # Wait until every rank has flushed its event files.
    fabric.barrier()

    if fabric.global_rank != 0:
        return

    for log_name in log_names:
        # Load the event file(s) stored in this run folder.
        event_acc = event_accumulator.EventAccumulator(
            os.path.join(LOG_DIR, log_name),
        )
        event_acc.Reload()
        # Guard the lookup: Scalars() raises KeyError for an unknown tag.
        if "data" in event_acc.Tags()[event_accumulator.SCALARS]:
            values = [event.value for event in event_acc.Scalars("data")]
        else:
            values = []
        # Print the number of stored entries and their values.
        print(f"Strategy: {log_name}")
        print(f"Number of log entries: {len(values)}")
        print(f"{values}\n")
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment