Last active
June 28, 2023 04:02
-
-
Save TronicLT/13569164a6206c7b0b53781976909928 to your computer and use it in GitHub Desktop.
PyTorch model trainer with a Keras-like training interface.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
import types | |
import torch | |
import torch.nn as nn | |
from torch.utils import data | |
from saga.data.datasets import ArrayDataset | |
from saga.utils.callbacks import Callback, History, ProgressBar, Callbacks | |
from saga.utils.general_utils import check_attribute | |
from saga.utils.metrics import check_metric | |
from saga.utils.torch_utils import check_optimiser, check_loss, moving_average | |
# Public API of this module: only the trainer class is exported.
__all__ = [
    'ModelTrainer'
]
class ModelTrainer(object):
    """PyTorch model trainer based on a Keras-like interface.

    Wraps a `torch.nn.Module` with `compile`/`fit`/`evaluate`/`predict`
    methods, driving training through the project's callback machinery.

    Parameters
    ----------
    model : `torch.nn.Module`
        Model to train
    device : str or `torch.device`
        Device used to train the model, `cpu` or `cuda`.
        Defaults to `cuda` when available, otherwise `cpu`.
    """

    def __init__(self, model, device=None):
        if not isinstance(model, nn.Module):
            raise ValueError('model argument must inherit from torch.nn.Module')
        self.model = model
        self.history_ = None
        self.callbacks_ = list()
        # The `device` property setter below also moves the model onto the
        # chosen device, which is why `self.model` must be assigned first.
        self.device = device if device is not None else 'cuda' if torch.cuda.is_available() else 'cpu'

    @property
    def loss(self):
        # `loss_` only exists after `compile`/`set_loss`; report None before that.
        return None if not hasattr(self, 'loss_') else self.loss_

    @property
    def device(self):
        return self.device_

    @device.setter
    def device(self, device):
        """ Set device to load the model

        Parameters
        ----------
        device : str or `torch.device`
            Device to load the model on

        Returns
        -------
        None
        """
        # Accept torch.device instances, or any string containing 'cpu'/'cuda'
        # (e.g. 'cuda:1'); anything else is rejected up front.
        if not isinstance(device, torch.device) and 'cpu' not in str(device).lower() \
                and 'cuda' not in str(device).lower():
            raise TypeError(
                'Device should be an instance of `torch.device` or `cpu` or `cuda`, {0} given'.format(device)
            )
        if isinstance(device, str):
            device = torch.device(device)
        self.device_ = device
        self.model.to(self.device_)

    @property
    def history(self):
        # `utils.callbacks.History` populated by `fit`/`fit_generator`; None before training.
        return self.history_

    @loss.setter
    def loss(self, loss):
        """ Set model loss

        Parameters
        ----------
        loss: str or callable
            Model loss

        Returns
        -------
        None
        """
        self.loss_ = check_loss(loss)

    @property
    def optimiser(self):
        # Set by `compile`/`set_optimizer`; None before that.
        return None if not hasattr(self, 'optimiser_') else self.optimiser_

    @property
    def metrics(self):
        # Set by `compile`/`set_metrics`; empty list before that.
        return list() if not hasattr(self, 'metrics_') else self.metrics_

    def set_loss(self, loss):
        """ Set model loss

        Parameters
        ----------
        loss: str or callable
            Model loss

        Returns
        -------
        None
        """
        self.loss = loss

    def set_metrics(self, metrics):
        """ Set evaluation metrics.

        The loss is always tracked as the first metric under the name 'loss';
        additional metrics are resolved through `check_metric`.

        Parameters
        ----------
        metrics : iterable or None
            Metric names or callables understood by `check_metric`

        Returns
        -------
        None
        """
        metrics = metrics or list()
        self.metrics_ = [self.loss]
        self.metrics_names_ = ['loss']
        for metric in metrics:
            self.metrics_.append(check_metric(metric))
            self.metrics_names_.append(self.metrics_[-1].name)

    def set_optimizer(self, optimizer, **kwargs):
        """ Set model optimiser.

        Parameters
        ----------
        optimizer : str or optimiser
            Optimiser specification resolved through `check_optimiser`
        **kwargs
            Extra arguments forwarded to `check_optimiser`. The special key
            ``parameters`` overrides ``self.model.parameters()``.

        Returns
        -------
        None
        """
        # pop() so that `parameters` is not ALSO forwarded inside **kwargs
        # (the original passed it twice: positionally and as a keyword).
        parameters = kwargs.pop('parameters', None)
        if parameters is None:
            parameters = self.model.parameters()
        self.optimiser_ = check_optimiser(optimizer, parameters, **kwargs)

    def set_callbacks(self, callbacks):
        """ Set callbacks

        Parameters
        ----------
        callbacks : iterable
            List of 'utils.callbacks.Callback`

        Returns
        -------
        self
        """
        callbacks = callbacks or list()
        if not isinstance(callbacks, (list, tuple)):
            callbacks = [callbacks]
        else:
            callbacks = list(callbacks)
        for callback in callbacks:
            if callback is not None and not isinstance(callback, Callback):
                raise TypeError('{0} not an instance of Callback'.format(callback))
        # A fresh History is always the first callback so it records every epoch.
        self.history_ = History()
        self.callbacks_ = [self.history_] + callbacks
        return self

    def compile(self, loss, optimizer='adam', metrics=None, loss_kwargs=None):
        """ Configure the trainer for fitting.

        Parameters
        ----------
        loss : str or callable
            Loss resolved through `check_loss`
        optimizer : str or optimiser
            Optimiser resolved through `check_optimiser` (default 'adam')
        metrics : iterable, optional
            Extra metrics resolved through `check_metric`
        loss_kwargs : dict, optional
            Keyword arguments forwarded to the loss on every call

        Returns
        -------
        None
        """
        self.set_optimizer(optimizer)
        self.set_loss(loss)
        self.set_metrics(metrics)
        # 'elementwise_mean' was removed in PyTorch 1.0; 'mean' is the
        # equivalent, supported reduction.
        self.loss_kwargs_ = {'reduction': 'mean'} if loss_kwargs is None else loss_kwargs

    @staticmethod
    def __get_data_loader(x, y=None, batch_size=32, shuffle=False, num_workers=0):
        # Training loader: shuffling is caller-controlled.
        return data.DataLoader(ArrayDataset(x, y),
                               batch_size=batch_size,
                               shuffle=shuffle,
                               num_workers=num_workers)

    @staticmethod
    def __validation_loader(x, y=None, batch_size=32, num_workers=0):
        # Evaluation loader: never shuffled so outputs align with inputs.
        return data.DataLoader(ArrayDataset(x, y),
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=num_workers)

    def fit(self,
            x,
            y=None,
            validation_data=None,
            n_epoch=10,
            batch_size=32,
            callbacks=None,
            shuffle=False,
            n_workers=0,
            verbose=1):
        """ Fit model

        Parameters
        ----------
        x : array-like, shape=(n_samples, ...)
            Predictor variable
        y : array-like, shape=(n_samples, ...)
            dependent variables
        validation_data : iterable, shape=(2,)
            Validation data (predictor, dependent variable)
        n_epoch : int
            Number of epochs to train the model
        batch_size : int
            Batch size to use during training
        callbacks : iterable
            An iterable of `utils.callbacks.Callback`
        shuffle : bool
            Set to ``True`` to have the data reshuffled at every epoch (default: False).
        n_workers : int, (optional)
            How many subprocesses to use for data loading.
            0 means that the data will be loaded in the main process.
        verbose : int
            verbosity level

        Returns
        -------
        `utils.callbacks.History`
        """
        # --------------------------------------------------
        check_attribute(self, ['optimiser_', 'loss_'], 'Call `compile` function first')
        generator = self.__get_data_loader(x, y, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
        return self.fit_generator(generator,
                                  validation_data=validation_data,
                                  n_epoch=n_epoch,
                                  callbacks=callbacks,
                                  verbose=verbose)

    def fit_generator(self,
                      generator,
                      n_batches=None,
                      validation_data=None,
                      n_epoch=10,
                      callbacks=None,
                      verbose=1):
        """ Fit model on generator

        Parameters
        ----------
        generator : `torch.utils.data.DataLoader` or types.GeneratorType
            Generator loader yielding tuples (predictors, dependent) variables.
            NOTE(review): a plain generator is exhausted after the first epoch;
            with n_epoch > 1 pass a DataLoader (or a re-iterable) instead.
        validation_data : iterable, shape=(2,) or `torch.utils.data.DataLoader` or generator
            Validation data (predictor, dependent variable) or generator yielding (predictor, dependent variable)
        n_batches : int
            Number of batches per epoch
        n_epoch : int
            Number of epochs to train the model
        callbacks : iterable
            An iterable of `utils.callbacks.Callback`
        verbose : int
            verbosity level

        Returns
        -------
        `saga.utils.callbacks.History`
        """
        # --------------------------------------------------
        check_attribute(self, ['optimiser_', 'loss_'], 'Call `compile` function first')
        # --------------------------------------------------
        callbacks = callbacks or list()
        self.set_callbacks(callbacks)
        # --------------------------------------------------
        if n_batches is None:
            try:
                n_batches = len(generator)
            # len() on a generator raises TypeError (the original only caught
            # ValueError, leaving this friendly message unreachable).
            except (TypeError, ValueError) as e:
                raise ValueError('n_batches cannot be inferred from generator. `n_batches=None`. '
                                 'Please specify `n_batches` or use the `torch.data.DataLoader` class.') from e
        # --------------------------------------------------
        if isinstance(validation_data, (list, tuple)):
            x_val, y_val = validation_data
            validation_generator = self.__validation_loader(x_val, y_val, batch_size=64)
            validate = True
        elif isinstance(validation_data, data.DataLoader) or isinstance(validation_data, types.GeneratorType):
            validation_generator = validation_data
            validate = True
        else:
            validation_generator = None
            validate = False
        # ---------------------------------------------------
        with ProgressBar() as p_bar:
            if verbose > 0:
                self.callbacks_.append(p_bar)
            callback_container = Callbacks(self.callbacks_)
            callback_container.set_model(self.model)
            callback_container.on_train_begin({'n_batches': n_batches, 'n_epoch': n_epoch})
            # --------------------------------------------------------------
            batch_logs, epoch_logs = dict(), dict()
            for idx_epoch in range(1, n_epoch + 1):
                self.model.train(True)
                callback_container.on_epoch_begin(idx_epoch, epoch_logs)
                # --------------------------------------------------------------------
                for idx_batch, (x_batch, y_batch) in enumerate(generator, start=1):
                    callback_container.on_batch_begin(idx_batch, batch_logs)
                    # ----------------------------------------------------------------
                    x_batch = x_batch.to(self.device)
                    y_batch = y_batch if y_batch is None else y_batch.to(self.device)
                    # ----------------------------------------------------------------
                    self.optimiser_.zero_grad()
                    y_pred = self.model.forward(x_batch)
                    loss = self.loss_(y_pred, y_batch, **self.loss_kwargs_)
                    loss.backward()
                    self.optimiser_.step()
                    # --------------------------------------------------------------------
                    outs = self.evaluate_tensor(x_batch, y_batch)
                    # evaluate_tensor switches the model to eval(); restore
                    # training mode so dropout/batch-norm keep training
                    # behaviour for the remaining batches of the epoch.
                    self.model.train(True)
                    outs = outs if isinstance(outs, (tuple, list)) else [outs]
                    # Skip index 0: that is the loss, logged separately below.
                    for out, name in zip(outs[1:], self.metrics_names_[1:]):
                        batch_logs[name + '_metric'] = moving_average(batch_logs.get(name + '_metric', 0.), out)
                    # ----------------------------------------------------------------
                    batch_logs['loss'] = loss.item()
                    batch_logs['size'] = len(x_batch)
                    callback_container.on_batch_end(idx_batch, batch_logs)
                # --------------------------------------------------------------------
                # TODO: Add metric aggregation callback
                for key in batch_logs:
                    if 'loss' == key or key.endswith('_metric'):
                        epoch_logs[key] = batch_logs[key]
                # --------------------------------------------------------------------
                if validate and validation_generator is not None:
                    val_outs = self.evaluate_generator(validation_generator)
                    val_outs = val_outs if isinstance(val_outs, (tuple, list)) else [val_outs]
                    for out, name in zip(val_outs, self.metrics_names_):
                        epoch_logs['val_' + name + '_metric'] = out
                # ---------------------------------------------------------------------
                callback_container.on_epoch_end(idx_epoch, epoch_logs)
            # --------------------------------------------------------------
            callback_container.on_train_end()
        self.model.train(mode=False)
        return self.history_

    def predict(self, x, batch_size=32, as_tensor=False):
        """ Run model inference on `x`

        Parameters
        ----------
        x : array-like, shape=(n_samples, ...)
            Predictor variable
        batch_size : int
            The number of samples to use per prediction call
        as_tensor : bool
            If `True` the result is a `torch.tensor` otherwise an `array-like` object is returned

        Returns
        -------
        array-like or `torch.tensor`
        """
        generator = self.__validation_loader(x, batch_size=batch_size, num_workers=0)
        return self.predict_generator(generator=generator, as_tensor=as_tensor)

    @torch.no_grad()
    def predict_generator(self, generator, as_tensor=False):
        """ Run model inference on generator

        Parameters
        ----------
        generator : `torch.utils.data.DataLoader` or types.GeneratorType
            Generator loader yielding predictors or tuples (predictors, dependent).
        as_tensor : bool
            If `True` the result is a `torch.tensor` otherwise an `array-like` object is returned

        Returns
        -------
        array-like or `torch.tensor`
        """
        self.model.eval()
        x_res = list()
        for batch in generator:
            # Tolerate (x, y) pairs by using only the predictors.
            if isinstance(batch, (list, tuple)):
                x_res.append(self.model.forward(batch[0].to(self.device)))
            else:
                x_res.append(self.model.forward(batch.to(self.device)))
        x_res = torch.cat(x_res, 0)
        return x_res if as_tensor else x_res.cpu().numpy()

    @torch.no_grad()
    def predict_tensor(self, x):
        """ Run model inference on tensor

        Parameters
        ----------
        x : 'torch.tensor` shape=(n_samples, ...)
            Predictor variable

        Returns
        -------
        `torch.tensor`
        """
        self.model.eval()
        return self.model.forward(x.to(self.device))

    def evaluate(self, x, y, batch_size=32):
        """ Evaluate the model on data `x` and `y`

        Parameters
        ----------
        x : array-like, shape=(n_samples, ...)
            Predictor variable
        y : array-like, shape=(n_samples, ...)
            Dependent variables
        batch_size : int
            The number of samples to use per prediction call

        Returns
        -------
        array-like
        """
        generator = self.__validation_loader(x, y, batch_size, num_workers=0)
        return self.evaluate_generator(generator=generator)

    @torch.no_grad()
    def evaluate_generator(self, generator):
        """ Evaluate model on generator

        Parameters
        ----------
        generator : `torch.utils.data.DataLoader` or types.GeneratorType
            Generator loader yielding tuples (predictors, dependent) variables

        Returns
        -------
        array-like
            A scalar when only the loss is tracked, otherwise one value per metric.
        """
        self.model.eval()
        x_res, y_res = list(), list()
        for x_batch, y_batch in generator:
            x_res.append(self.model.forward(x_batch.to(self.device)))
            y_res.append(y_batch.to(self.device))
        x_res = torch.cat(x_res, 0)
        y_res = torch.cat(y_res, 0)
        res = [met(x_res, y_res).item() for met in self.metrics_]
        return res[0] if len(res) == 1 else res

    @torch.no_grad()
    def evaluate_tensor(self, x, y):
        """ Evaluate model using on tensor

        Parameters
        ----------
        x : 'torch.tensor` shape=(n_samples, ...)
            Predictor variable
        y : `torch.tensor`, shape=(n_samples, ...)
            Dependent variables

        Returns
        -------
        `torch.tensor`
            A scalar when only the loss is tracked, otherwise one value per metric.
        """
        self.model.eval()
        y_pred = self.model.forward(x.to(self.device))
        res = [met(y_pred, y.to(self.device)).item() for met in self.metrics_]
        return res[0] if len(res) == 1 else res
def __example():
    """Smoke-test example: train a small classifier on the iris dataset."""
    from sklearn.datasets import load_iris
    from torch.nn.functional import relu, log_softmax, cross_entropy

    # `load_iris(True)` relied on a positional flag that sklearn has since
    # removed; the keyword form is the supported spelling.
    X, y = load_iris(return_X_y=True)

    class Net(nn.Module):
        # Two-layer MLP producing log-probabilities over the 3 iris classes.
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(4, 16)
            self.fc2 = nn.Linear(16, 3)

        def forward(self, x):
            x = relu(self.fc1(x), inplace=True)
            x = self.fc2(x)
            return log_softmax(x, dim=1)

    model = Net()
    trainer = ModelTrainer(model)
    trainer.compile(cross_entropy, metrics=['acc'])
    history = trainer.fit(X, y, validation_data=(X, y), shuffle=True, batch_size=10, n_epoch=20, verbose=1)
    acc = trainer.evaluate(X, y, 200)
    y_pred = trainer.predict(X)
    # NOTE(review): `bach_history` looks like a typo for `batch_history` --
    # confirm against saga.utils.callbacks.History before renaming.
    print(history.bach_history)
# Run the usage example only when executed as a script, not on import.
if __name__ == '__main__':
    __example()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment