Skip to content

Instantly share code, notes, and snippets.

@sunshineatnoon
Last active March 3, 2019 04:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sunshineatnoon/0a3b531519b28388923b164faba0fca3 to your computer and use it in GitHub Desktop.
Save sunshineatnoon/0a3b531519b28388923b164faba0fca3 to your computer and use it in GitHub Desktop.
PyTorch Tools

Loss Visualization

A loss plotter that visualizes multiple losses in Visdom, drawing all of them as separate curves within the same window. (The classes on this page need `import numpy as np` and `import visdom`.)

class loss_plotter():
    """Plot several named losses as curves in one shared Visdom window.

    Each call to :meth:`plot` appends one value per loss and redraws the full
    history of every loss. The x-axis is the call count (1, 2, 3, ...).
    """

    def __init__(self, port=8097, server="http://localhost"):
        """Connect to a Visdom server.

        Args:
            port: port of the Visdom server.
            server: address of the Visdom server.

        Raises:
            AssertionError: if the connection check fails within 3 seconds
                (note: the check is skipped when Python runs with ``-O``).
        """
        self.vis = visdom.Visdom(port=port, server=server)
        assert self.vis.check_connection(timeout_seconds=3),'No connection could be formed quickly'
        # name -> list of recorded values; insertion order fixes the column
        # (and therefore legend) order in the window.
        self.losses = {}
        # Visdom window id, taken from the first plot() call.
        self.win = None
        # Number of successful plot() calls; x value of the newest point.
        self.cnt = 0

    def plot(self, losses, names):
        """Record one value per loss and redraw all curves in the shared window.

        Args:
            losses: scalar loss values, one per entry in ``names``.
            names: unique loss names. The first call fixes the set of curves;
                later calls must pass the same names, in any order.

        Raises:
            ValueError: if ``names`` is empty or contains duplicates, if
                ``losses`` and ``names`` differ in length, or if the names
                differ from the ones given on the first call.
        """
        losses = list(losses)
        names = list(names)
        if not names or len(set(names)) != len(names):
            raise ValueError('names must be a non-empty sequence of unique names, got %r' % (names,))
        if len(losses) != len(names):
            raise ValueError('got %d losses for %d names' % (len(losses), len(names)))
        # Every curve in the single window must stay the same length as X, so
        # the set of curves cannot change after the first call.
        if self.losses and set(names) != set(self.losses):
            raise ValueError('names %r do not match the losses already plotted %r'
                             % (names, list(self.losses)))

        self.cnt += 1
        for name, loss in zip(names, losses):
            self.losses.setdefault(name, []).append(loss)

        # Build columns in the stored order (not the caller's order) so the
        # legend always labels the right curve.
        series = list(self.losses)
        # np.column_stack requires a list/tuple; NumPy rejects generators.
        Y = np.column_stack([np.array(self.losses[name]) for name in series])
        X = np.column_stack([np.arange(1, self.cnt + 1)] * len(series))

        # win=None on the first call makes Visdom create a new window.
        win = self.vis.line(
            Y=Y,
            X=X,
            opts=dict(markers=False, legend=series),
            win=self.win,
        )
        if self.win is None:
            self.win = win

How to use

# Construct the plotter once: each instance keeps its own Visdom window id,
# so re-creating it would open a fresh window.
lper = loss_plotter()
# Call plot() at every logging step. `self.c_losses.avg` / `self.d_losses.avg`
# are presumably running-average meters inside a trainer method -- adapt to
# whatever scalars your training loop tracks.
lper.plot(losses=[self.c_losses.avg, self.d_losses.avg],
          names=['constraint loss', 'data loss'])

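A loss plotter that visualizes multiple losses in Visdom, drawing each loss in its own window. It is an alternative to the class above and reuses the name `loss_plotter`, so define only one of the two in a given module.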

class loss_plotter():
    """Plot each named loss as a single curve in its own Visdom window.

    Each call to :meth:`plot` appends one value per given loss and redraws
    that loss's window. A point's x value is the index (1, 2, 3, ...) of the
    plot() call that recorded it.
    """

    def __init__(self, port=8097, server="http://localhost"):
        """Connect to a Visdom server.

        Args:
            port: port of the Visdom server.
            server: address of the Visdom server.

        Raises:
            AssertionError: if the connection check fails within 3 seconds
                (note: the check is skipped when Python runs with ``-O``).
        """
        self.vis = visdom.Visdom(port=port, server=server)
        assert self.vis.check_connection(timeout_seconds=3),'No connection could be formed quickly'
        # name -> Visdom window id for that loss.
        self.wins = {}
        # name -> list of recorded values.
        self.losses = {}
        # name -> plot() call index at which each value was recorded.
        self.steps = {}
        # Number of plot() calls so far.
        self.cnt = 0

    def plot(self, losses, names):
        """Record one value per loss and redraw each loss in its own window.

        Args:
            losses: scalar loss values, one per entry in ``names``.
            names: loss names; a name seen for the first time gets a new
                window, and names may be omitted on some calls.

        Raises:
            ValueError: if ``losses`` and ``names`` differ in length.
        """
        losses = list(losses)
        names = list(names)
        if len(losses) != len(names):
            raise ValueError('got %d losses for %d names' % (len(losses), len(names)))

        self.cnt += 1
        for name, loss in zip(names, losses):
            values = self.losses.setdefault(name, [])
            steps = self.steps.setdefault(name, [])
            values.append(loss)
            # Track x per loss so X always matches Y in length, even for a
            # loss first seen (or skipped) on a later call.
            steps.append(self.cnt)
            # win=None for a new loss makes Visdom create a new window.
            win = self.vis.line(
                Y=np.array(values),
                X=np.array(steps),
                opts=dict(markers=False, legend=[name]),
                win=self.wins.get(name),
            )
            # Remember only the window created on this loss's first plot.
            self.wins.setdefault(name, win)

How to use

# Construct the plotter once: each instance keeps its own Visdom window ids
# (one per loss), so re-creating it would open fresh windows.
lper = loss_plotter()
# Call plot() at every logging step. `self.c_losses.avg` / `self.d_losses.avg`
# are presumably running-average meters inside a trainer method -- adapt to
# whatever scalars your training loop tracks.
lper.plot(losses=[self.c_losses.avg, self.d_losses.avg],
          names=['constraint loss', 'data loss'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment