Skip to content

Instantly share code, notes, and snippets.

@Arseny-N
Created April 25, 2018 10:00
Show Gist options
  • Save Arseny-N/476c65a4cefa9b2efc81fd80dd2b8577 to your computer and use it in GitHub Desktop.
Save Arseny-N/476c65a4cefa9b2efc81fd80dd2b8577 to your computer and use it in GitHub Desktop.
Custom ignite engine example
import time
from enum import Enum

from ignite._utils import _to_hours_mins_secs
from ignite.engines import Events, Engine
class CostumEvents(Enum):
    """Event names available to ``CostumEngine``.

    Contains the usual engine lifecycle events plus two extra members,
    ``DISPLAY_VAL_LOSS`` and ``DISPLAY_LOSS``, which ``CostumEngine`` fires
    periodically from its iteration loop.
    """

    # Lifecycle events of a run / epoch / iteration.
    EPOCH_STARTED = "epoch_started"
    EPOCH_COMPLETED = "epoch_completed"
    STARTED = "started"
    COMPLETED = "completed"
    ITERATION_STARTED = "iteration_started"
    ITERATION_COMPLETED = "iteration_completed"

    # Extra events fired every ``val_interval`` / ``display_interval``
    # iterations by ``CostumEngine._run_once_on_dataset``.
    DISPLAY_VAL_LOSS = "display_val_loss"
    DISPLAY_LOSS = "display_loss"

    EXCEPTION_RAISED = "exception_raised"
class CostumEngine(Engine):
    """Engine that also fires periodic "display loss" events while iterating.

    On top of the per-iteration ``ITERATION_STARTED`` / ``ITERATION_COMPLETED``
    events, one pass over the data fires ``DISPLAY_VAL_LOSS`` every
    ``val_interval`` iterations and ``DISPLAY_LOSS`` every
    ``display_interval`` iterations.
    """

    # Custom event set, read back through ``self.__class__.Events`` below.
    # Inside the methods a bare ``Events`` still refers to the module-level
    # name imported from ignite, because class attributes are not visible as
    # plain names in method bodies.
    Events = CostumEvents

    def __init__(self, process_function, val_interval, display_interval):
        """Store the firing intervals and initialise the base engine.

        Args:
            process_function: callable forwarded to ``Engine.__init__``; it
                is invoked as ``process_function(engine, batch)`` for every
                batch.
            val_interval: ``DISPLAY_VAL_LOSS`` is fired whenever
                ``state.iteration % val_interval == 0``.
            display_interval: ``DISPLAY_LOSS`` is fired whenever
                ``state.iteration % display_interval == 0``.
        """
        self.val_interval = val_interval
        self.display_interval = display_interval
        super().__init__(process_function)

    def _run_once_on_dataset(self):
        """Iterate once over ``self.state.dataloader``, firing events.

        Returns:
            The ``(hours, mins, secs)`` tuple produced by
            ``_to_hours_mins_secs`` for the wall-clock duration of the pass.
            If an exception is raised, it is logged and passed to
            ``self._handle_exception``; unless that call re-raises, the
            method then returns ``None`` implicitly.
        """
        try:
            start_time = time.time()
            for batch in self.state.dataloader:
                self.state.batch = batch
                self.state.iteration += 1

                # The periodic events fire after the iteration counter is
                # bumped but before the current batch is processed.
                if self.state.iteration % self.val_interval == 0:
                    self._fire_event(self.__class__.Events.DISPLAY_VAL_LOSS)
                if self.state.iteration % self.display_interval == 0:
                    self._fire_event(self.__class__.Events.DISPLAY_LOSS)

                self._fire_event(Events.ITERATION_STARTED)
                self.state.output = self._process_function(self, batch)
                self._fire_event(Events.ITERATION_COMPLETED)
                if self.should_terminate:
                    break

            time_taken = time.time() - start_time
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            return hours, mins, secs
        except BaseException as e:
            # BaseException is caught deliberately so that every failure,
            # including KeyboardInterrupt, is logged and routed through the
            # engine's exception handling.
            self._logger.error("Current run is terminating due to exception: %s", str(e))
            self._handle_exception(e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment