from ignite.engine import Engine, _prepare_batch
from ignite.engine import Events
from ignite.contrib.handlers import ProgressBar
from ignite.metrics import Accuracy, Loss, RunningAverage


def create_keras_supervised_trainer(model, optimizer, loss_fn, metrics=None, device=None, prepare_batch=None):
    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        # Fall back to Ignite's default batch preparation when no custom
        # prepare_batch callable is supplied.
        if not prepare_batch:
            x, y = _prepare_batch(batch, device=device)
        else:
            x, y = prepare_batch(batch, device=device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        # Return the full triple so metrics can be computed on (y_pred, y).
        return loss.item(), y_pred, y

    def _metrics_transform(output):
        # Drop the loss value; metrics only need (y_pred, y).
        return output[1], output[2]

    engine = Engine(_update)

    for name, metric in (metrics or {}).items():
        metric._output_transform = _metrics_transform
        metric.attach(engine, name)

    return engine
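
# --- Usage sketch (not part of the original gist) ---
# A minimal, hypothetical example of building a trainer with the factory
# above; `model`, `optimizer`, and `loss_fn` are assumed to be defined by the
# caller. Metrics passed here have their output transform overridden so they
# read (y_pred, y) out of the (loss, y_pred, y) tuple returned by _update.
def _example_create_trainer(model, optimizer, loss_fn, device=None):
    metrics = {'accuracy': Accuracy(), 'nll': Loss(loss_fn)}
    return create_keras_supervised_trainer(
        model, optimizer, loss_fn, metrics=metrics, device=device)
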
def make_keras_like(trainer, evaluator, validation_loader):
    training_history = {'accuracy': [], 'loss': []}
    validation_history = {'accuracy': [], 'loss': []}
    last_epoch = []

    # Running averages give smooth, Keras-style per-iteration readouts.
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
    RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, 'accuracy')

    prog_bar = ProgressBar()
    prog_bar.attach(trainer, ['loss', 'accuracy'])

    prog_bar_vd = ProgressBar()
    prog_bar_vd.attach(evaluator)

    from ignite.handlers import Timer
    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 step=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        metrics = trainer.state.metrics
        accuracy = metrics['accuracy'] * 100
        loss = metrics['loss']  # the RunningAverage attached above
        last_epoch.append(0)
        training_history['accuracy'].append(accuracy)
        training_history['loss'].append(loss)
        train_msg = "Train Epoch {}: acc: {:.2f}% loss: {:.2f}, train time: {:.2f}s".format(
            trainer.state.epoch, accuracy, loss, timer.value())

        evaluator.run(validation_loader)
        metrics = evaluator.state.metrics
        accuracy = metrics['accuracy'] * 100
        loss = metrics['nll']  # assumes the evaluator attached a Loss metric named 'nll'
        validation_history['accuracy'].append(accuracy)
        validation_history['loss'].append(loss)
        val_msg = "Valid Epoch {}: acc: {:.2f}% loss: {:.2f}".format(
            trainer.state.epoch, accuracy, loss)

        prog_bar_vd.log_message(train_msg + " --- " + val_msg)
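
# --- Usage sketch (not part of the original gist) ---
# Hypothetical end-to-end wiring, assuming `model`, `optimizer`, `loss_fn`,
# `train_loader`, and `val_loader` exist. The evaluator must expose metrics
# named 'accuracy' and 'nll', which is what log_validation_results reads.
def _example_fit(model, optimizer, loss_fn, train_loader, val_loader, device=None):
    from ignite.engine import create_supervised_evaluator
    trainer = create_keras_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={'accuracy': Accuracy(), 'nll': Loss(loss_fn)},
        device=device)
    make_keras_like(trainer, evaluator, val_loader)
    trainer.run(train_loader, max_epochs=10)
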
from itertools import chain, repeat
from random import sample
from typing import Sequence

from nvidia.dali import pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator
def _pipelines_sizes(pipes):
    # Yield the epoch size of every reader in every pipeline; fall back to
    # len(p) for pipelines fed by ExternalSource (no reader).
    for p in pipes:
        p.build()
        keys = list(p.epoch_size().keys())
        if len(keys) > 0:
            for k in keys:
                yield p.epoch_size(k)
        else:
            yield len(p)
class TransformPipeline(pipeline.Pipeline):
    """
    Pipeline for COCO with data augmentation.
    """

    def __init__(
        self,
        batch_size,
        num_threads,
        device_id,
        size=0,
        transform=None,
        target_transform=None,
        iter_setup=None,
        reader=None,
        samples=None,
        randomize=True,
    ):
        super().__init__(batch_size, num_threads, device_id, seed=-1)
        self.reader = reader
        # TODO: Take into account the cpu case
        self.transform = transform
        self.target_transform = target_transform
        self._iter_setup = iter_setup
        self.size = size
        self._jpegs = ops.ExternalSource()
        self._labels = ops.ExternalSource()
        self.randomize = randomize
        if randomize and samples:
            # random.sample over the full length shuffles without mutating
            # the caller's list.
            samples = sample(samples, len(samples))
        self.samples = samples
        self.slice = slice(0, batch_size)
    def define_graph(self):
        if self.reader is None:
            # Default case: jpegs and labels are fed by `ops.ExternalSource`.
            self.jpegs = self._jpegs()
            self.labels = self._labels()
        else:
            self.jpegs, self.labels = self.reader()
        targets = self.labels
        if self.transform:
            self.jpegs = self.transform(self.jpegs)
        if self.target_transform:
            targets = self.target_transform(targets)
        return self.jpegs, targets
    def __len__(self):
        keys = list(self.epoch_size().keys())
        if len(keys) > 0:
            size = sum(self.epoch_size(k) for k in keys)
        else:
            size = self.size
        return size
    def iter_setup(self):
        if self._iter_setup:
            sl = self.slice
            samples = self.samples[sl]
            diff = self.batch_size - len(samples)
            if diff > 0:
                # A `Pipeline` expects a list of `batch_size` elements, but at
                # the end of an epoch this size is not guaranteed, so we repeat
                # the last element. The `Pipeline` still returns
                # `Pipeline.size` elements even if `Pipeline.size` is not a
                # multiple of `batch_size`.
                s = self.samples[-1]
                samples = chain.from_iterable([samples, repeat(s, diff)])
            jpegs, labels = self._iter_setup(samples)
            self.feed_input(self.jpegs, jpegs)
            self.feed_input(self.labels, labels)
            # Advance the window to the next batch of samples.
            self.slice = slice(sl.stop, sl.stop + self.batch_size, sl.step)

    def reset(self):
        if self._iter_setup:
            self.slice = slice(0, self.batch_size)
class DALILoader(DALIGenericIterator):
    """
    Wrapper around `DALIGenericIterator`: `ProgressBar` needs an object with a
    `__len__` method, and it advances by steps of 1, so `__len__` is expressed
    in batches rather than samples.
    """

    def __init__(self, pipelines, output_map=("data", "label"), auto_reset=True, stop_at_epoch=True):
        if not isinstance(pipelines, Sequence):
            pipelines = [pipelines]
        size = sum(_pipelines_sizes(pipelines))
        super().__init__(pipelines, output_map, size, auto_reset, stop_at_epoch)
        self.batch_size = pipelines[0].batch_size

    def __len__(self):
        # Number of batches, not samples, so the progress bar steps correctly.
        return self._size // self.batch_size
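
# --- Usage sketch (not part of the original gist) ---
# Hypothetical example of feeding a TransformPipeline from an in-memory
# sample list via an iter_setup callback, then wrapping it in DALILoader so
# Ignite's ProgressBar can call len() on it. `decode_batch` is a placeholder
# the caller would implement; it receives an iterable of samples and must
# return (encoded_jpegs, labels) batches suitable for feed_input.
def _example_dali_loader(samples, decode_batch, batch_size=32):
    pipe = TransformPipeline(
        batch_size=batch_size,
        num_threads=2,
        device_id=0,
        size=len(samples),
        iter_setup=decode_batch,  # called on each iteration with a batch of samples
        samples=samples)
    return DALILoader(pipe, output_map=("data", "label"))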