Created
April 25, 2018 10:00
-
-
Save Arseny-N/476c65a4cefa9b2efc81fd80dd2b8577 to your computer and use it in GitHub Desktop.
Custom ignite engine example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from enum import Enum

from ignite._utils import _to_hours_mins_secs
from ignite.engines import Events, Engine
class CostumEvents(Enum):
    """Event identifiers used by ``CostumEngine`` (exposed as ``CostumEngine.Events``).

    Contains run/epoch/iteration lifecycle names, an exception event, and two
    display events (``DISPLAY_VAL_LOSS`` and ``DISPLAY_LOSS``) that
    ``CostumEngine`` fires at fixed iteration intervals. Each member's value is
    its lowercase name.
    """

    # Epoch boundaries.
    EPOCH_STARTED = "epoch_started"
    EPOCH_COMPLETED = "epoch_completed"

    # Whole-run boundaries.
    STARTED = "started"
    COMPLETED = "completed"

    # Per-batch boundaries.
    ITERATION_STARTED = "iteration_started"
    ITERATION_COMPLETED = "iteration_completed"

    # Periodic display hooks fired by CostumEngine.
    DISPLAY_VAL_LOSS = "display_val_loss"
    DISPLAY_LOSS = "display_loss"

    # Error path.
    EXCEPTION_RAISED = "exception_raised"
class CostumEngine(Engine):
    """Ignite ``Engine`` that additionally fires periodic DISPLAY_* events.

    While iterating over the dataloader it fires
    ``CostumEvents.DISPLAY_VAL_LOSS`` every ``val_interval`` iterations and
    ``CostumEvents.DISPLAY_LOSS`` every ``display_interval`` iterations. This is
    in addition to ``ITERATION_STARTED`` / ``ITERATION_COMPLETED`` around each
    call to the process function.

    Args:
        process_function: Callable invoked as ``process_function(engine, batch)``;
            its return value is stored in ``engine.state.output``.
        val_interval: Iteration period for ``DISPLAY_VAL_LOSS``; must be non-zero.
        display_interval: Iteration period for ``DISPLAY_LOSS``; must be non-zero.

    Raises:
        ValueError: If ``val_interval`` or ``display_interval`` is zero.
    """

    # Class-level alias so the DISPLAY_* members are reachable as
    # ``CostumEngine.Events`` / ``type(engine).Events``.
    Events = CostumEvents

    def __init__(self, process_function, val_interval, display_interval):
        # A zero period would otherwise only surface later, as a
        # ZeroDivisionError from the modulo checks in _run_once_on_dataset.
        if val_interval == 0 or display_interval == 0:
            raise ValueError(
                "val_interval and display_interval must be non-zero, got %r and %r"
                % (val_interval, display_interval)
            )
        self.val_interval = val_interval
        self.display_interval = display_interval
        super().__init__(process_function)

    def _run_once_on_dataset(self):
        """Run one pass over ``self.state.dataloader``.

        For each batch, this method:

        * stores the batch in ``self.state.batch`` and increments
          ``self.state.iteration``;
        * fires each DISPLAY_* event whose period divides the new iteration
          count;
        * fires ``ITERATION_STARTED``, stores the process function's result in
          ``self.state.output``, then fires ``ITERATION_COMPLETED``.

        The loop stops early once ``self.should_terminate`` is truthy.

        Returns:
            ``(hours, mins, secs)`` from ``_to_hours_mins_secs`` for the time
            spent in this pass. If any exception is raised, it is logged and
            passed to ``self._handle_exception``. If that call returns normally,
            this method returns ``None``.
        """
        try:
            # perf_counter is monotonic, so the measured duration is not skewed
            # by wall-clock adjustments the way time.time() can be.
            start_time = time.perf_counter()
            for batch in self.state.dataloader:
                self.state.batch = batch
                self.state.iteration += 1

                # NOTE(review): these fire before the process function runs for
                # this batch, so handlers see the state.output left by the
                # previous iteration. Confirm this is intended.
                if self.state.iteration % self.val_interval == 0:
                    self._fire_event(self.__class__.Events.DISPLAY_VAL_LOSS)
                if self.state.iteration % self.display_interval == 0:
                    self._fire_event(self.__class__.Events.DISPLAY_LOSS)

                # Bare ``Events`` here is the module-level ignite import, not the
                # class attribute above: class-body names are not visible inside
                # methods. So iteration events use ignite's enum members.
                self._fire_event(Events.ITERATION_STARTED)
                self.state.output = self._process_function(self, batch)
                self._fire_event(Events.ITERATION_COMPLETED)

                if self.should_terminate:
                    break

            time_taken = time.perf_counter() - start_time
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            return hours, mins, secs

        except BaseException as e:
            self._logger.error("Current run is terminating due to exception: %s", str(e))
            self._handle_exception(e)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment