Last active
February 4, 2020 07:04
-
-
Save dnola/847a803f2b48223e5b7f88e75a0b499e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ignite.engine import Engine, _prepare_batch | |
from ignite.engine import Events | |
from ignite.contrib.handlers import ProgressBar | |
from ignite.metrics import Accuracy, Loss, RunningAverage | |
def create_keras_supervised_trainer(model, optimizer, loss_fn, metrics=None, device=None, prepare_batch=None):
    """Build an ignite training ``Engine`` whose output carries predictions.

    Unlike ignite's stock supervised trainer, each iteration returns
    ``(loss_value, y_pred, y)`` so that metrics (and the running averages
    attached by ``make_keras_like``) can be computed on training batches.

    Args:
        model: module put into train mode and called as ``model(x)``.
        optimizer: optimizer whose ``zero_grad``/``step`` wrap each update.
        loss_fn: callable ``loss_fn(y_pred, y)`` returning a tensor with
            ``.backward()`` and ``.item()``.
        metrics: optional mapping ``name -> metric``. Each metric's
            ``_output_transform`` is OVERWRITTEN to select ``(y_pred, y)``
            from the engine output before it is attached under ``name``.
            Defaults to no metrics.
        device: forwarded to the batch-preparation function.
        prepare_batch: optional ``prepare_batch(batch, device=...) -> (x, y)``;
            when falsy, ignite's ``_prepare_batch`` is used.

    Returns:
        The configured ``Engine``.
    """
    # Resolve the batch-preparation function once instead of on every iteration.
    batch_fn = prepare_batch or _prepare_batch

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch_fn(batch, device=device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item(), y_pred, y

    def _metrics_transform(output):
        # Drop the scalar loss; metrics expect (y_pred, y).
        return output[1], output[2]

    engine = Engine(_update)
    for name, metric in (metrics or {}).items():
        metric._output_transform = _metrics_transform
        metric.attach(engine, name)
    return engine
def make_keras_like(trainer, evaluator, validation_loader):
    """Attach Keras-style progress bars and per-epoch logging to an ignite trainer.

    Attaches to ``trainer`` a running-average ``'loss'`` (from ``output[0]``) and
    a running-average ``'accuracy'`` (from ``output[1], output[2]``), so the
    trainer output is expected to be ``(loss_value, y_pred, y)`` as produced by
    ``create_keras_supervised_trainer``. After every training epoch it runs
    ``evaluator`` on ``validation_loader`` and logs one combined line.

    The loss reported for an engine is its ``'nll'`` metric when one is
    attached, otherwise its ``'loss'`` metric. The evaluator must therefore
    provide ``'accuracy'`` and at least one of ``'nll'`` / ``'loss'``.

    Args:
        trainer: ignite ``Engine`` running the training loop.
        evaluator: ignite ``Engine`` used for validation.
        validation_loader: data passed to ``evaluator.run`` after each epoch.

    Returns:
        ``(training_history, validation_history)``: dicts holding ``'accuracy'``
        (in percent) and ``'loss'`` lists, appended once per completed epoch.
    """
    from ignite.handlers import Timer

    training_history = {'accuracy': [], 'loss': []}
    validation_history = {'accuracy': [], 'loss': []}

    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
    RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, 'accuracy')
    prog_bar = ProgressBar()
    prog_bar.attach(trainer, ['loss', 'accuracy'])
    prog_bar_vd = ProgressBar()
    prog_bar_vd.attach(evaluator)

    # The timer is attached BEFORE the logging handler below so that, on
    # EPOCH_COMPLETED, it pauses and steps first: validation time is excluded,
    # and with average=True ``timer.value()`` is the mean epoch training time.
    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 step=Events.EPOCH_COMPLETED)

    def _epoch_loss(metrics):
        # Prefer an explicitly attached 'nll' metric; otherwise use 'loss'
        # (always present on the trainer thanks to the RunningAverage above).
        return metrics['nll'] if 'nll' in metrics else metrics['loss']

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        metrics = trainer.state.metrics
        accuracy = metrics['accuracy'] * 100
        loss = _epoch_loss(metrics)
        training_history['accuracy'].append(accuracy)
        training_history['loss'].append(loss)
        train_msg = "Train Epoch {}: acc: {:.2f}% loss: {:.2f}, train time: {:.2f}s".format(trainer.state.epoch, accuracy, loss, timer.value())

        evaluator.run(validation_loader)
        metrics = evaluator.state.metrics
        accuracy = metrics['accuracy'] * 100
        loss = _epoch_loss(metrics)
        validation_history['accuracy'].append(accuracy)
        validation_history['loss'].append(loss)
        val_msg = "Valid Epoch {}: acc: {:.2f}% loss: {:.2f}".format(trainer.state.epoch, accuracy, loss)

        prog_bar_vd.log_message(train_msg + " --- " + val_msg)

    return training_history, validation_history
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from nvidia.dali import pipeline | |
import nvidia.dali.ops as ops | |
import nvidia.dali.types as types | |
from nvidia.dali.plugin.pytorch import DALIGenericIterator | |
from typing import Sequence | |
def _pipelines_sizes(pipes): | |
for p in pipes: | |
p.build() | |
keys = list(p.epoch_size().keys()) | |
if len(keys) > 0: | |
for k in keys: | |
yield p.epoch_size(k) | |
else: | |
yield len(p) | |
class TransformPipeline(pipeline.Pipeline):
    """
    DALI pipeline with optional data augmentation (originally written for COCO).

    Images and labels come either from ``reader`` (a callable returning
    ``(images, labels)`` graph nodes) or, when ``reader`` is None, from two
    ``ops.ExternalSource`` inputs that ``iter_setup`` feeds with host data
    produced by the user-supplied ``iter_setup`` callback.

    Args:
        batch_size, num_threads, device_id: forwarded to ``Pipeline`` (seed=-1).
        size: sample count reported by ``__len__`` when there are no readers.
        transform: optional callable applied to the image node in the graph.
        target_transform: optional callable applied to the label node.
        iter_setup: optional callable ``samples -> (images, labels)`` that
            turns one batch worth of ``samples`` (always a list of exactly
            ``batch_size`` items) into data for the external sources.
        reader: optional callable returning ``(images, labels)`` nodes; when
            given, the external sources are not used.
        samples: sequence sliced batch by batch and passed to ``iter_setup``.
        randomize: if true, ``samples`` is replaced by a shuffled copy once,
            at construction time.
    """

    def __init__(
        self,
        batch_size,
        num_threads,
        device_id,
        size=0,
        transform=None,
        target_transform=None,
        iter_setup=None,
        reader=None,
        samples=None,
        randomize=True,
    ):
        import random

        super().__init__(batch_size, num_threads, device_id, seed=-1)
        self.reader = reader
        # TODO: Take into account the cpu case
        self.transform = transform
        self.target_transform = target_transform
        # Deliberately not named `_iter_setup`: NOTE(review) newer DALI releases
        # define a private `Pipeline._iter_setup` method, which an instance
        # attribute of that name would shadow — confirm against installed DALI.
        self._iter_setup_fn = iter_setup
        self.size = size
        self._jpegs = ops.ExternalSource()
        self._labels = ops.ExternalSource()
        self.randomize = randomize
        if randomize and samples:
            samples = random.sample(samples, len(samples))
        self.samples = samples
        self.slice = slice(0, batch_size)

    def define_graph(self):
        if self.reader is None:
            # Default case: images and labels are fed through `ops.ExternalSource`.
            # `self.jpegs` / `self.labels` must keep referring to these external
            # source outputs, because `iter_setup` passes them to `feed_input`.
            self.jpegs = self._jpegs()
            self.labels = self._labels()
        else:
            self.jpegs, self.labels = self.reader()
        # Transform into locals so the external-source handles above are not
        # replaced by transformed nodes (feed_input cannot target those).
        images = self.jpegs
        targets = self.labels
        if self.transform:
            images = self.transform(images)
        if self.target_transform:
            targets = self.target_transform(targets)
        return images, targets

    def __len__(self):
        keys = list(self.epoch_size().keys())
        if len(keys) > 0:
            size = sum(self.epoch_size(k) for k in keys)
        else:
            size = self.size
        return size

    def iter_setup(self):
        if not self._iter_setup_fn:
            return
        sl = self.slice
        samples = list(self.samples[sl])
        diff = self.batch_size - len(samples)
        if diff > 0:
            # A `Pipeline` expects a list of `batch_size` elements, but at the end
            # of an epoch that is not guaranteed, so the last sample is repeated.
            # (Per the original author, the `Pipeline` still returns
            # `Pipeline.size` elements even if `Pipeline.size` is not a multiple
            # of `batch_size`.)
            samples.extend([self.samples[-1]] * diff)
        jpegs, labels = self._iter_setup_fn(samples)
        self.feed_input(self.jpegs, jpegs)
        self.feed_input(self.labels, labels)
        self.slice = slice(sl.stop, sl.stop + self.batch_size, sl.step)

    def reset(self):
        # Let DALI reset its own iteration state before rewinding our cursor.
        super().reset()
        if self._iter_setup_fn:
            self.slice = slice(0, self.batch_size)
class DALILoader(DALIGenericIterator):
    """
    `DALIGenericIterator` with a `__len__` method, because ignite's `ProgressBar`
    wants an object with a length. The `ProgressBar` advances by 1 per batch.

    Args:
        pipelines: one pipeline or a sequence of pipelines (a non-Sequence is
            wrapped in a list). The epoch size passed to DALI is the sum of
            `_pipelines_sizes(pipelines)`, which also builds the pipelines.
        output_map: names of the pipeline outputs, forwarded to DALI.
        auto_reset: forwarded to DALI.
        stop_at_epoch: forwarded positionally as DALI's 5th argument.
            NOTE(review): in some DALI releases that slot is `fill_last_batch`
            rather than `stop_at_epoch` — confirm against the installed version.
    """

    def __init__(self, pipelines, output_map=("data", "label"), auto_reset=True, stop_at_epoch=True):
        if not isinstance(pipelines, Sequence):
            pipelines = [pipelines]
        # Each iterator step yields one batch from every pipeline.
        self._num_pipelines = len(pipelines)
        size = sum(_pipelines_sizes(pipelines))
        super().__init__(pipelines, output_map, size, auto_reset, stop_at_epoch)
        self.batch_size = pipelines[0].batch_size

    def __len__(self):
        # Round up: a trailing partial step (fewer than `step` samples left)
        # is still produced by the iterator and must be counted.
        step = self.batch_size * self._num_pipelines
        return (self._size + step - 1) // step
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment