Skip to content

Instantly share code, notes, and snippets.

Last active April 11, 2021 23:29
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericspod/f4da372d22cc8da420ee74b8968303cd to your computer and use it in GitHub Desktop.
Base Engine, Trainer, and Evaluator
import torch
import warnings
import threading
import numpy as np
from ignite.engine.engine import Engine, Events
def ensure_tuple(vals):
    """
    Return a tuple containing just `vals` if it is not a list or tuple, or `vals` converted to a tuple otherwise.

    Parameters
    ----------
    vals : object
        A single value, or a list/tuple of values

    Returns
    -------
    tuple
        The elements of `vals` as a tuple when it is a list or tuple, otherwise the single-element tuple ``(vals,)``
    """
    if isinstance(vals, (list, tuple)):
        return tuple(vals)
    return (vals,)
class BaseEngine(Engine):
    """
    Base training/evaluating engine inheriting from Ignite's Engine. This manages a single network's
    train/eval/infer process. Setups with multiple networks should have multiple instances of subtypes of this class.

    Attributes
    ----------
    net : torch.nn.Module
        The network to train or evaluate
    loss : torch.nn.modules.loss._Loss
        Loss object
    opt : torch.optim.Optimizer or None
        Optimizer object for training or None
    device_ids : list of int
        CUDA device ID numbers stating which devices to compute on, empty for CPU computation
    useCuda : bool
        True if CUDA is to be used (ie. `device_ids` is non-empty) and CUDA is available
    device : torch.device
        The device this object uses to create tensors
    nonBlocking : bool
        Passed as `non_blocking` when moving tensors to `device`, default is False
    lock : threading.RLock
        Anytime `net` is used for training or inference, or anytime tensors it relies on are accessed, this lock
        should be used to ensure exclusive access
    """

    def __init__(self, net, loss, opt=None, device_ids=(0,), step_func=None):
        """
        Initialize the engine with network, loss, and optimizer provided.

        Parameters
        ----------
        net : torch.nn.Module
            The network to train or evaluate
        loss : torch.nn.modules.loss._Loss
            Loss object
        opt : torch.optim.Optimizer, optional
            Optimizer object for training or None
        device_ids : int or iterable of int, optional
            CUDA device ID numbers stating which devices to compute on, empty sequence for CPU computation
        step_func : callable, optional
            Callable defining the training/evaluation iteration step behaviour, passed to the super-constructor
            call, if None `self.step` is passed instead allowing engine behaviour to be determined through
            inheritance
        """
        if isinstance(device_ids, int):  # a single device ID is accepted as documented
            device_ids = (device_ids,)

        self.device_ids = list(device_ids)
        # `is_available` must be *called*: the bare function object is always truthy and would select CUDA
        # even on machines without it
        self.useCuda = len(self.device_ids) > 0 and torch.cuda.is_available()
        self.device = torch.device("cuda:%i" % self.device_ids[0] if self.useCuda else "cpu")
        self.nonBlocking = False
        self.lock = threading.RLock()
        self.net = net
        self.loss = loss
        self.opt = opt
        # NOTE(review): `net` is not moved to `self.device` here; callers presumably place it there -- confirm

        super().__init__(step_func if step_func is not None else self.step)

    def to_tensor(self, arr):
        """
        Convert the array or sequence of arrays to tensor(s) on `self.device` if necessary.

        Parameters
        ----------
        arr : object
            A numpy array, tensor, or a list, tuple, or dictionary (possibly nested) containing them

        Returns
        -------
        object
            The same structure with arrays converted to tensors and tensors moved to `self.device`; lists become
            tuples. Values of any other type are returned unchanged.
        """
        if isinstance(arr, np.ndarray):
            return torch.from_numpy(arr).to(device=self.device, non_blocking=self.nonBlocking)
        if isinstance(arr, torch.Tensor):
            return arr.to(device=self.device, non_blocking=self.nonBlocking)
        if isinstance(arr, (list, tuple)):
            return tuple(map(self.to_tensor, arr))
        if isinstance(arr, dict):
            return {k: self.to_tensor(v) for k, v in arr.items()}
        return arr

    def to_numpy(self, tensor):
        """
        Convert the tensor or sequence of tensors to numpy array(s) if necessary.

        Parameters
        ----------
        tensor : object
            A tensor, numpy array, or a list, tuple, or dictionary (possibly nested) containing them

        Returns
        -------
        object
            The same structure with tensors converted to numpy arrays; lists become tuples. Numpy arrays and values
            of any other type are returned unchanged.
        """
        if isinstance(tensor, torch.Tensor):
            # detach from the autograd graph and copy to host memory, numpy cannot view device/grad tensors
            return tensor.detach().cpu().numpy()
        if isinstance(tensor, (list, tuple)):
            return tuple(map(self.to_numpy, tensor))
        if isinstance(tensor, dict):
            return {k: self.to_numpy(v) for k, v in tensor.items()}
        return tensor

    def set_requires_grad(self, grad=True):
        """
        Set `requires_grad` for every parameter of `self.net` to `grad`.

        Parameters
        ----------
        grad : bool
            Value to set each `requires_grad` to
        """
        for p in self.net.parameters():
            p.requires_grad = grad

    def net_forward(self, inputs):
        """
        Apply the values from `inputs` to the network and return the results. If multiple CUDA devices are being
        computed on, `torch.nn.parallel.data_parallel` is used to broadcast the values to per-device replicas of
        `self.net`.

        Parameters
        ----------
        inputs : list or tuple of tensors
            The input parameters for `self.net`

        Returns
        -------
        tuple of tensors
            The output(s) of the network contained in a tuple, if a single tensor is produced by the network this
            is placed in a single-value tuple
        """
        # TODO: add event for this method?
        if self.useCuda and len(self.device_ids) > 1:
            result = torch.nn.parallel.data_parallel(self.net, tuple(inputs), self.device_ids)
        else:
            result = self.net(*inputs)

        return ensure_tuple(result)

    def loss_forward(self, predictions, ground):
        """
        Compute the loss value using `self.loss` with input expanded from `predictions` and `ground`.

        Parameters
        ----------
        predictions : list or tuple of tensors
            The prediction values for `self.loss` which will be expanded as the first series of positional arguments
        ground : list or tuple of tensors
            The ground truth values for `self.loss` which will be expanded as the second series of positional
            arguments

        Returns
        -------
        object
            Whatever `self.loss` returns, unchanged (not wrapped in a tuple); the default step methods expect a
            scalar tensor since they call `.item()` on it
        """
        # TODO: add event for this method?
        args = tuple(predictions) + tuple(ground)
        return self.loss(*args)

    def infer(self, infer_src):
        """
        Apply inference to the batches taken from `infer_src`, returning a list of results from the network. The
        input source is expected to be finite otherwise this method will not return. The `self.net_forward` method
        is called with every item in a generated batch passed as the argument tuple.

        Parameters
        ----------
        infer_src : iterable
            Iterable yielding tuples of numpy arrays containing the inputs to `self.net_forward`

        Returns
        -------
        list of tuples
            Returns a list of results from applying each tuple from the source to the network
        """
        # TODO: add event for this method?
        return list(self.infer_gen(infer_src))

    def infer_gen(self, infer_src):
        """
        Apply inference to the batches taken from `infer_src`, yielding the results from the network. The input
        source is expected to be finite otherwise this generator will not return and will rely on the consumer to
        stop the iteration. `self.net_forward` is called with the whole of a generated batch passed as the argument
        tuple. The network is put in eval mode for each forward pass and its previous mode is restored afterwards.

        Parameters
        ----------
        infer_src : iterable
            Iterable yielding tuples of numpy arrays containing the inputs to `self.net_forward`

        Yields
        ------
        tuple of np.ndarray
            The result from applying each tuple from the source to the network
        """
        # TODO: add event for this method?
        for batch in infer_src:
            with self.lock, torch.no_grad():
                net_inputs = self.to_tensor(ensure_tuple(batch))
                was_training = self.net.training
                self.net.eval()  # inference must not use dropout or update batch-norm running statistics
                try:
                    net_outputs = self.net_forward(net_inputs)
                finally:
                    self.net.train(was_training)

                result = self.to_numpy(net_outputs)

            # yield outside the lock so the consumer's work between batches does not keep `net` locked
            yield result

    def step(self, engine, batch):
        """
        Train/eval/infer step function, accepts the engine (which is `self`) and current batch as input. By default
        this method only asserts that `engine` is `self`.

        Parameters
        ----------
        engine : BaseEngine
            This is the same object as `self`
        batch : tuple of np.ndarray
            The batch tuple for the current iteration

        Returns
        -------
        object
            None here; the loss result should be returned in overrides
        """
        assert engine is self
class Trainer(BaseEngine):
    """
    The basic engine subtype for training a network. The given `self.step` method is for training a network
    accepting inputs from a batch, the results from which are passed to a loss function whose output is
    back-propagated and applied with `self.opt`. During training the converted batch is stored in
    `state.net_inputs`, network outputs in `state.net_outputs`, and loss function outputs in `state.loss_outputs`.
    These members of the state object can be accessed to inspect the training parameters.

    Attributes
    ----------
    net_input_indices : tuple of int
        Indices of network input tensors in each batch
    loss_pred_indices : tuple of int
        Indices of the prediction tensors in each network output
    loss_ground_indices : tuple of int
        Indices of the ground truth tensors in each batch
    """

    def __init__(self, net, loss, opt=None, device_ids=(0,), net_input_indices=(0,),
                 loss_pred_indices=(0,), loss_ground_indices=(-1,), step_func=None):
        """
        Create the trainer object with the given network, loss function, optimizer, and parameters stating which
        members of each batch tuple or network output are inputs for the network or loss function, and which are
        ground truth values. Changing these values allows various configurations of training a network whose
        output is passed to a loss function along with ground truth values, eg. a simple supervised environment.

        Parameters
        ----------
        net : torch.nn.Module
            The network to train or evaluate
        loss : torch.nn.modules.loss._Loss
            Loss object
        opt : torch.optim.Optimizer, optional
            Optimizer object for training, if None Adam is used instead with default parameters
        device_ids : int or iterable of int, optional
            CUDA device ID numbers stating which devices to compute on, empty sequence for CPU computation
        net_input_indices : iterable of int, optional
            Indices of network input tensors in each batch
        loss_pred_indices : iterable of int, optional
            Indices of the prediction tensors in each network output
        loss_ground_indices : iterable of int, optional
            Indices of the ground truth tensors in each batch
        step_func : callable, optional
            Callable defining the training/evaluation iteration step behaviour, passed to the super-constructor
            call, if None `self.step` is passed instead allowing engine behaviour to be determined through
            inheritance
        """
        # only fall back to Adam when no optimizer was given; never discard a caller-supplied one
        if opt is None:
            opt = torch.optim.Adam(net.parameters())

        super().__init__(net, loss, opt, device_ids, step_func)
        self.net_input_indices = tuple(net_input_indices)
        self.loss_pred_indices = tuple(loss_pred_indices)
        self.loss_ground_indices = tuple(loss_ground_indices)

    def step(self, engine, batch):
        """
        The default network training step for any training process with paired input and ground truth datasets.
        This trains the network for the number of substeps given in `self.state.num_substeps` (1 if unset), that is
        the network is trained for that many optimizer steps with the same input data and ground truths derived
        from `batch`. `net_input_indices` determines which arrays from `batch` are inputs to the network,
        `loss_pred_indices` which outputs from the network are passed to the loss function as prediction values,
        and `loss_ground_indices` which arrays from `batch` are ground truth.

        Parameters
        ----------
        engine : BaseEngine
            This is the same object as `self`
        batch : tuple of np.ndarray
            The batch tuple for the current iteration

        Returns
        -------
        float
            The loss value of the final substep
        """
        with self.lock:
            self.net.train()
            self.state.net_inputs = self.to_tensor(ensure_tuple(batch))
            inputs = [self.state.net_inputs[i] for i in self.net_input_indices]
            ground = [self.state.net_inputs[i] for i in self.loss_ground_indices]
            # `num_substeps` is only set by `train`; default to one step when the engine is run directly
            num_substeps = getattr(self.state, "num_substeps", 1)

            for _ in range(num_substeps):
                self.opt.zero_grad()
                self.state.net_outputs = self.net_forward(inputs)
                pred = [self.state.net_outputs[i] for i in self.loss_pred_indices]
                self.state.loss_outputs = self.loss_forward(pred, ground)
                self.state.loss_outputs.backward()
                self.opt.step()

            return self.state.loss_outputs.item()

    def train(self, src, max_iterations=None, max_epochs=1, num_substeps=1):
        """
        Train the network with the given data source, maximum iteration, epoch, and substep counts.

        Parameters
        ----------
        src : iterable
            Iterable data source yielding batches of data
        max_iterations : int, optional
            Number of iterations to train for each epoch, if None iterations are performed until `src` is exhausted
            for each epoch
        max_epochs : int, optional
            Number of epochs (sets of iterations) to train
        num_substeps : int, optional
            Number of substeps to train, default of 1 implies the commonplace behaviour of training only once per
            batch

        Returns
        -------
        object
            The state object from the training run

        Raises
        ------
        ValueError
            If `num_substeps` is less than 1, in which case no loss would ever be computed
        """
        if num_substeps < 1:
            raise ValueError("num_substeps must be at least 1, got %r" % (num_substeps,))

        def _set_state(_):
            self.state.num_substeps = num_substeps

        # the handler is registered only for the duration of this run
        with self.add_event_handler(Events.STARTED, _set_state):
            return self.run(src, max_epochs, max_iterations)

    def get_evaluator(self, loss=None, step_func=None):
        """
        Return an Evaluator object referencing this object's network and loss objects, and configured with the same
        devices and index tuples.

        Parameters
        ----------
        loss : torch.nn.modules.loss._Loss, optional
            Loss object for the evaluator, if None this object's loss is used
        step_func : callable, optional
            Evaluation step function to pass to Evaluator object, if None the default `step` method is used

        Returns
        -------
        Evaluator
            Evaluation object for this object's network
        """
        # explicit None test: `loss or ...` would depend on the truthiness of an nn.Module (eg. empty Sequential)
        eval_loss = loss if loss is not None else self.loss
        return Evaluator(self.net, eval_loss, self.device_ids, self.net_input_indices,
                         self.loss_pred_indices, self.loss_ground_indices, step_func)
class Evaluator(BaseEngine):
    """
    Engine subclass for evaluating a network for validation or other analysis. The default `step` method implements
    a simple evaluation step which does a forward pass on the network and loss function, and returns the loss value.
    During evaluation the converted batch is stored in `state.net_inputs`, network outputs in `state.net_outputs`,
    and loss function outputs in `state.loss_outputs`. These members of the state object can be accessed to inspect
    the evaluation parameters.

    Attributes
    ----------
    net_input_indices : tuple of int
        Indices of network input tensors in each batch
    loss_pred_indices : tuple of int
        Indices of the prediction tensors in each network output
    loss_ground_indices : tuple of int
        Indices of the ground truth tensors in each batch
    """

    def __init__(self, net, loss, device_ids=(0,), net_input_indices=(0,),
                 loss_pred_indices=(0,), loss_ground_indices=(-1,), step_func=None):
        """
        Create the evaluator object with the given network, loss function, and parameters stating which members of
        each batch tuple or network output are inputs for the network or loss function. Changing these values allows
        various configurations of evaluating a network whose output is passed to a loss function with ground truth
        values.

        Parameters
        ----------
        net : torch.nn.Module
            The network to evaluate
        loss : torch.nn.modules.loss._Loss
            Loss object
        device_ids : int or iterable of int, optional
            CUDA device ID numbers stating which devices to compute on, empty sequence for CPU computation
        net_input_indices : iterable of int, optional
            Indices of network input tensors in each batch
        loss_pred_indices : iterable of int, optional
            Indices of the prediction tensors in each network output
        loss_ground_indices : iterable of int, optional
            Indices of the ground truth tensors in each batch
        step_func : callable, optional
            Callable defining the training/evaluation iteration step behaviour, passed to the super-constructor
            call, if None `self.step` is passed instead allowing engine behaviour to be determined through
            inheritance
        """
        super().__init__(net, loss, None, device_ids, step_func)
        self.net_input_indices = tuple(net_input_indices)
        self.loss_pred_indices = tuple(loss_pred_indices)
        self.loss_ground_indices = tuple(loss_ground_indices)

    @staticmethod
    def _iteration_limit(max_iterations):
        """Return `max_iterations` if it is a positive int, otherwise None meaning "until the source is exhausted"."""
        if max_iterations is not None and max_iterations > 0:
            return max_iterations
        return None

    def step(self, engine, batch):
        """
        The default network evaluation step for any process with paired input and ground truth datasets. This
        evaluates the network (in eval mode, restoring its previous mode afterwards) and loss function, returning
        the loss result. `net_input_indices` determines which arrays from `batch` are inputs to the network,
        `loss_pred_indices` which outputs from the network are passed to the loss function as prediction values,
        and `loss_ground_indices` which arrays from `batch` are ground truth.

        Parameters
        ----------
        engine : BaseEngine
            This is the same object as `self`
        batch : tuple of np.ndarray
            The batch tuple for the current iteration

        Returns
        -------
        float
            The result of the loss function
        """
        with self.lock, torch.no_grad():
            was_training = self.net.training
            # evaluation must not use dropout or update batch-norm running statistics; the network may be shared
            # with a Trainer (see `Trainer.get_evaluator`) so its previous mode is restored afterwards
            self.net.eval()
            try:
                self.state.net_inputs = self.to_tensor(ensure_tuple(batch))
                inputs = [self.state.net_inputs[i] for i in self.net_input_indices]
                ground = [self.state.net_inputs[i] for i in self.loss_ground_indices]
                self.state.net_outputs = self.net_forward(inputs)
                pred = [self.state.net_outputs[i] for i in self.loss_pred_indices]
                self.state.loss_outputs = self.loss_forward(pred, ground)
            finally:
                self.net.train(was_training)

            return self.state.loss_outputs.item()

    def evaluate(self, src, max_iterations=None):
        """
        Evaluates the network for each batch in `src`, which must be finite or `max_iterations` must be a positive
        int. For each batch, the returned list will contain a pair storing the network output and loss output.

        Parameters
        ----------
        src : iterable
            Input batch source
        max_iterations : int, optional
            Maximum number of evaluation iterations, if None (or not positive) iterations are performed until `src`
            is exhausted

        Returns
        -------
        list of tuples
            List of each (network output, loss output) pair, as numpy arrays, for each batch from `src`
        """
        results = []

        def _collect_results(_):
            out = self.state.net_outputs
            loss = self.state.loss_outputs
            results.append(self.to_numpy((out, loss)))

        # the handler is registered only for the duration of this run
        with self.add_event_handler(Events.ITERATION_COMPLETED, _collect_results):
            self.run(src, 1, self._iteration_limit(max_iterations))

        return results

    def evaluate_gen(self, src, max_iterations=None):
        """
        Evaluates the network for each batch in `src`, which must be finite or `max_iterations` must be a positive
        int. For each batch, this generator yields a pair storing the network output and loss output.

        Parameters
        ----------
        src : iterable
            Input batch source
        max_iterations : int, optional
            Maximum number of evaluation iterations, if None (or not positive) iterations are performed until `src`
            is exhausted

        Yields
        ------
        tuple
            The (network output, loss output) pair, as numpy arrays, for each batch from `src`
        """
        limit = self._iteration_limit(max_iterations)

        for count, batch in enumerate(src):
            if limit is not None and count >= limit:
                break

            self.run([batch], 1)  # one single-iteration run per batch so results can be yielded lazily
            out = self.state.net_outputs
            loss = self.state.loss_outputs
            yield self.to_numpy((out, loss))

    def evaluate_mean_loss(self, src, max_iterations=0):
        """
        Calculate the mean loss over all of the inputs in `src`.

        Each batch's loss output is reduced to its mean, and the result is the unweighted mean of those per-batch
        values.

        Parameters
        ----------
        src : iterable
            Input batch source
        max_iterations : int, optional
            Maximum number of evaluation iterations, if None or 0 iterations are performed until `src` is exhausted

        Returns
        -------
        float
            The mean loss value

        Raises
        ------
        ValueError
            If `src` yields no batches
        """
        total_loss = 0.0
        total_size = 0

        for _output, eloss in self.evaluate_gen(src, max_iterations):
            total_loss += float(np.mean(eloss))
            total_size += 1

        if total_size == 0:
            raise ValueError("cannot compute mean loss: `src` yielded no batches")

        return total_loss / total_size
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment