Skip to content

Instantly share code, notes, and snippets.

@ptmcg
Last active October 2, 2023 19:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ptmcg/10a873d04a7dbca0aa6b235cb5b7bb48 to your computer and use it in GitHub Desktop.
Save ptmcg/10a873d04a7dbca0aa6b235cb5b7bb48 to your computer and use it in GitHub Desktop.
State pattern in Python
import contextlib
class InvalidStateActionError(NotImplementedError):
    """Raised when the job's current state does not support the requested action.

    It derives from ``NotImplementedError`` because the base state class leaves
    every action unimplemented. Concrete states override only the actions they
    allow.
    """
class InvalidStateTransitionError(InvalidStateActionError):
    """An invalid action that would have changed the job's state.

    Examples are starting a job that is not ready, or resuming one that is not
    paused.
    """
class InvalidProcessStepError(Exception):
    """Raised when a process step arrives that is not the next expected one."""
class JobState:
    """Base class for the states of a job (State pattern).

    Every action defaults to raising an error. ``setup``, ``start``, ``pause``
    and ``resume`` raise ``InvalidStateTransitionError``; ``process`` raises
    ``InvalidStateActionError``. Concrete subclasses override only the actions
    that are legal in their state.

    Constructing a state installs it on its job (``job.state = self``).
    """

    def __init__(self, job):
        self.job = job
        # Registering here means that creating a state object switches the job to it.
        self.job.state = self

    @property
    def cur_state(self):
        """Return the lowercase state name, e.g. ``"ready"`` for ``ReadyJobState``."""
        return type(self).__name__.removesuffix("JobState").lower()

    @cur_state.setter
    def cur_state(self, new_state):
        # new_state is a JobState subclass. Build it for *this* state's job.
        # Using any other (e.g. module-global) name here would re-bind the
        # wrong job.
        self.job.state = new_state(self.job)

    @property
    def completed(self):
        """Return whether the job has finished; False unless a subclass overrides it."""
        return False

    def raise_invalid_state_action(self, action, exception_class=InvalidStateActionError):
        """Raise *exception_class*, naming *action* and the current state."""
        raise exception_class(f"{action!r} not permitted in {self.cur_state} state")

    def raise_invalid_state_transition(self, action):
        """Raise ``InvalidStateTransitionError`` for *action* in this state."""
        self.raise_invalid_state_action(action, InvalidStateTransitionError)

    def setup(self, **kwargs):
        self.raise_invalid_state_transition("setup")

    def start(self):
        self.raise_invalid_state_transition("start")

    def pause(self):
        self.raise_invalid_state_transition("pause")

    def resume(self):
        self.raise_invalid_state_transition("resume")

    def process(self, *args, **kwargs):
        # Processing never changes state by itself, so it is an "action" error
        # rather than a "transition" error.
        self.raise_invalid_state_action("process")
class UninitializedJobState(JobState):
    """Initial state of a new job; only ``setup`` is permitted."""

    def setup(self, **kwargs):
        """Configure the job and, on success, move it to the ready state.

        The keyword arguments are passed to ``job.setup_impl``. If that raises,
        the job stays uninitialized. The failure is printed instead of
        propagated, but it is no longer silently discarded.

        Returns:
            The job, so that calls can be chained.
        """
        try:
            self.job.setup_impl(**kwargs)
        except Exception as exc:
            # Best-effort by design: do not crash, but say why setup failed.
            print(f"setup failed: {type(exc).__name__}: {exc}")
        else:
            self.cur_state = ReadyJobState
            # Announce completion only after setup_impl has actually succeeded.
            print("setup complete")
        return self.job
class ReadyJobState(JobState):
    """State after a successful setup; only ``start`` is permitted."""

    def start(self):
        """Start the job, moving it to the active state if ``job.start_impl`` succeeds.

        If ``start_impl`` raises, the job stays ready. The failure is printed
        rather than propagated or silently discarded.

        Returns:
            The job, so that calls can be chained.
        """
        print("starting")
        try:
            self.job.start_impl()
        except Exception as exc:
            print(f"start failed: {type(exc).__name__}: {exc}")
        else:
            self.cur_state = ActiveJobState
        return self.job
class RunningJobState(JobState):
    """Shared base of the active and paused states; it adds no behaviour of its own."""
class CompletedJobState(JobState):
    """Terminal state, entered once every configured process step has been run.

    It overrides no actions, so every action raises. Only ``completed``
    differs from the base class.
    """

    @property
    def completed(self):
        """Always True: a job in this state has finished."""
        return True
class ActiveJobState(RunningJobState):
    """A started job that may process steps or be paused."""

    def pause(self):
        """Pause the job, moving it to the paused state if ``job.pause_impl`` succeeds.

        If ``pause_impl`` raises, the job stays active. The failure is printed
        rather than propagated or silently discarded.

        Returns:
            The job, so that calls can be chained.
        """
        print("pausing")
        try:
            self.job.pause_impl()
        except Exception as exc:
            print(f"pause failed: {type(exc).__name__}: {exc}")
        else:
            self.cur_state = PausedJobState
        return self.job

    def process(self, *args, **kwargs):
        """Run one process step, given as the first positional argument.

        Any exception from ``job.process_impl`` (e.g. a step received out of
        order) propagates to the caller. Extra positional and keyword arguments
        are ignored.

        Returns:
            The job, so that calls can be chained.

        Raises:
            TypeError: if no step was supplied.
        """
        if not args:
            raise TypeError("process() requires a process step argument")
        self.job.process_impl(args[0])
        return self.job
class PausedJobState(RunningJobState):
    """A started job that is temporarily halted; only ``resume`` is permitted."""

    def resume(self):
        """Resume the job, moving it back to the active state if ``job.resume_impl`` succeeds.

        If ``resume_impl`` raises, the job stays paused. The failure is printed
        rather than propagated or silently discarded.

        Returns:
            The job, so that calls can be chained.
        """
        print("resuming")
        try:
            self.job.resume_impl()
        except Exception as exc:
            print(f"resume failed: {type(exc).__name__}: {exc}")
        else:
            self.cur_state = ActiveJobState
        return self.job
class Job:
    """A job whose lifecycle behaviour is delegated to its current state object.

    ``setup``, ``start``, ``pause``, ``resume``, ``process``, ``completed`` and
    ``cur_state`` are looked up on ``self.state`` (see ``__getattr__``). The
    ``*_impl`` methods hold the job-specific work that the states call back into.
    """

    # Names that __getattr__ forwards to the current state object.
    _STATE_ATTRIBUTES = frozenset(
        ("setup", "start", "pause", "resume", "process", "completed", "cur_state")
    )

    def __init__(self):
        self.state = UninitializedJobState(self)
        self.config = {}
        # Index into config["process"] of the next step that process_impl expects.
        self.next_step = 0

    @classmethod
    def make_job(cls, **kwargs):
        """Create a job, call ``setup(**kwargs)`` on it, and return it."""
        return cls().setup(**kwargs)

    def __getattr__(self, action):
        # Python calls this only when normal lookup fails, so only the
        # state-dependent names reach the forwarding below.
        if action in self._STATE_ATTRIBUTES:
            return getattr(self.state, action)
        raise AttributeError(f"No such attribute {action}")

    def validate_setup_config(self, config):
        """Return ``(ok, message)``; the config must contain a ``"process"`` key."""
        if "process" not in config:
            return False, "must contain 'process' attribute"
        return True, ""

    def setup_impl(self, **kwargs):
        """Validate *kwargs* and merge them into ``self.config``.

        Raises:
            ValueError: if the config is invalid. It is still an ``Exception``,
                so existing ``except Exception`` handlers keep working.
        """
        valid, msg = self.validate_setup_config(kwargs)
        if not valid:
            raise ValueError(f"setup config not valid, {msg}")
        self.config.update(kwargs)

    def start_impl(self):
        """Hook run when the job starts; does nothing by default."""

    def pause_impl(self):
        """Hook run when the job pauses; does nothing by default."""

    def resume_impl(self):
        """Hook run when the job resumes; does nothing by default."""

    def set_completed(self):
        """Switch the job to the completed state."""
        self.state.cur_state = CompletedJobState

    def process_impl(self, step):
        """Accept *step* if it is the next configured step; raise otherwise.

        After the last configured step is accepted, the job is marked completed.

        Raises:
            InvalidProcessStepError: if *step* is not the expected next step.
        """
        print(f"processing {step!r}")
        expected = self.config["process"][self.next_step]
        if step != expected:
            raise InvalidProcessStepError(f"Invalid step, expected {expected!r}, received {step!r}")
        print(f"process {step!r} complete")
        self.next_step += 1
        if self.next_step == len(self.config["process"]):
            self.set_completed()

    def __repr__(self):
        # Describe *this* job. The original read a module-global `job` here,
        # which is wrong (or a NameError) for any other instance.
        completed = self.config.get("process", "")[:self.next_step]
        return f"Job {id(self):x}: state={self.cur_state}: completed={completed!r}"
@contextlib.contextmanager
def ignore_exception():
    """Run the ``with`` body, printing any ``Exception`` it raises instead of propagating it.

    ``BaseException`` subclasses that are not ``Exception`` subclasses
    (e.g. ``KeyboardInterrupt``) still propagate.
    """
    try:
        yield
    except Exception as err:
        kind = type(err).__name__
        print(f"raised exception {kind}: {err}")
if __name__ == '__main__':
    # `job` stays a module-level name (not inside a function) to keep the
    # script's original global binding.

    # A fresh job cannot be started before it has been set up.
    job = Job()
    print(job)
    with ignore_exception():
        job.start()

    # Set up a four-step job. Processing is refused until it is started.
    process_steps = list("ABCD")
    job = Job.make_job(process=process_steps)
    print(job)
    with ignore_exception():
        job.process("B")
    print(job)

    job.start()
    print(job)

    # Steps must arrive in order: "B" before "A" is rejected.
    with ignore_exception():
        job.process("B")
    for step in ("A", "B"):
        job.process(step)

    # A paused job refuses to process until it is resumed.
    job.pause()
    with ignore_exception():
        job.process("C")
    job.resume()
    job.process("C")
    print(job)

    # The final step completes the job; any further step is refused.
    job.process("D")
    with ignore_exception():
        job.process("E")
    print(job)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment