Created
February 10, 2024 06:45
-
-
Save matteobertozzi/9b79884947a6c0e092e4ae9bd7e51c5a to your computer and use it in GitHub Desktop.
Amazon States Language Demo Execution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Amazon States Language - https://states-language.net/ | |
import jsonpath_ng | |
from collections import deque, namedtuple | |
import asyncio | |
import time | |
import json | |
class State:
    """Base class for every Amazon States Language state.

    Attributes are named after the ASL fields they mirror.
    """

    def __init__(self, specs) -> None:
        """Read the fields shared by all states from the definition dict *specs*."""
        self.Comment = specs.get('Comment')
        self.Type = specs.get('Type')
        self.Next = specs.get('Next')
        # A state without a "Next" transition is treated as terminal.
        self.End = self.Next is None

    def has_next(self):
        """Return True when this state transitions to another state."""
        # Fixed: this used to return ``self.End``, i.e. the exact opposite.
        return not self.End

    async def execute(self):
        """Run the state; concrete subclasses must override this.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class OperationState(State):
    """State that additionally carries the input/output processing fields."""

    # Spec fields copied verbatim onto the instance (None when absent).
    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field in ('InputPath', 'OutputPath', 'ResultPath', 'Parameters', 'ResultSelector'):
            setattr(self, field, specs.get(field))
class PassState(State):
    """Pass state: moves straight on to ``Next`` without spawning any task."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        # Each attribute mirrors the spec field of the same name (None when absent).
        for field in ('InputPath', 'OutputPath', 'ResultPath', 'Parameters', 'Result'):
            setattr(self, field, specs.get(field))

    async def execute(self):
        """Return ``(next_state_name, None)``; there is nothing to wait for."""
        next_state, pending = self.Next, None
        return next_state, pending
class TaskState(OperationState):
    """Task state: starts one task through ``new_task()`` and waits on its id."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        read = specs.get
        self.TimeoutSeconds = read('TimeoutSeconds')
        self.TimeoutSecondsPath = read('TimeoutSecondsPath')
        self.HeartbeatSeconds = read('HeartbeatSeconds')
        self.HeartbeatSecondsPath = read('HeartbeatSecondsPath')

    async def execute(self):
        """Return ``(next_state_name, [task_id])`` for the newly started task."""
        pending = [await new_task()]
        return self.Next, pending
class WaitState(OperationState):
    """Wait state.

    The wait fields are stored but not evaluated by ``execute()``, which
    simply starts a task through ``new_task()``.
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field in ('Seconds', 'SecondsPath', 'Timestamp', 'TimestampPath'):
            setattr(self, field, specs.get(field))

    async def execute(self):
        """Return ``(next_state_name, [task_id])`` for the newly started task."""
        started = await new_task()
        pending = [started]
        return self.Next, pending
class ChoiceState(OperationState):
    """Choice state; no rule is evaluated, the first choice is always taken."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        self.Choices = specs.get('Choices')

    async def execute(self):
        """Return ``(Next of the first choice, None)``."""
        first_choice = self.Choices[0]
        target = first_choice.get('Next')
        return target, None
class ParallelState(OperationState):
    """Parallel state: every branch becomes its own state machine task."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        self.Branches = specs.get('Branches')

    async def execute(self):
        """Spawn one state machine per branch.

        Returns ``(next_state_name, branch_task_ids)``.
        """
        spawned = []
        for branch_specs in self.Branches:
            print(' --->', branch_specs)
            branch_task = new_sm_task(branch_specs)
            spawned.append(branch_task)
        return self.Next, spawned
class MapState(OperationState):
    """Map state; the fields are parsed but item processing is not implemented."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field in ('ItemProcessor', 'MaxConcurrency', 'ItemSelector', 'ItemsPath'):
            setattr(self, field, specs.get(field))
        # "ItemProcessor" must be present: its ProcessorConfig is read eagerly.
        processor = self.ItemProcessor
        self._processor_config = processor.get('ProcessorConfig')

    async def execute(self):
        """Return ``(next_state_name, [])``; no per-item task is spawned yet."""
        # for each item spawn a new task
        no_tasks = []
        return self.Next, no_tasks
class SucceedState(State):
    """Succeed state: reports neither a next state nor pending tasks."""

    def __init__(self, specs) -> None:
        super().__init__(specs)

    async def execute(self):
        """Return ``(None, None)``."""
        return (None, None)
class FailState(State):
    """Fail state: stores ``Error``/``Cause`` and reports no next state."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        read = specs.get
        self.Error = read('Error')
        self.Cause = read('Cause')

    async def execute(self):
        """Return ``(None, None)``; Error and Cause are not propagated here."""
        return (None, None)
# Maps the "Type" field of a state definition to the class implementing it.
STATE_TYPES = dict(
    Task=TaskState,
    # Flow
    Choice=ChoiceState,
    Parallel=ParallelState,
    Map=MapState,
    Pass=PassState,
    Wait=WaitState,
    Succeed=SucceedState,
    Fail=FailState,
)
def new_state(state):
    """Instantiate the state class registered for ``state['Type']``.

    Raises:
        KeyError: if "Type" is missing or not present in ``STATE_TYPES``.
    """
    return STATE_TYPES[state['Type']](state)
class StateMachine:
    """Runs one Amazon States Language state machine definition step by step."""

    def __init__(self, task_id, specs) -> None:
        """Build the machine from the definition dict *specs*.

        Args:
            task_id: id under which this machine is tracked as a task.
            specs: parsed definition; must contain a "States" mapping.
        """
        self.Comment = specs.get('Comment')
        self.Version = specs.get('Version')
        self.StartAt = specs.get('StartAt')
        self.TimeoutSeconds = specs.get('TimeoutSeconds')
        self.States = {name: new_state(state) for name, state in specs['States'].items()}
        self.task_id = task_id
        self._next_state = self.StartAt
        self._current_state = None
        self._pending_tasks = []
        self._last_pending_check = 0
        self._context = {}

    async def step(self):
        """Advance the machine as far as possible without blocking.

        States are executed one after another until one of them leaves
        pending tasks behind or there is no next state.

        Returns:
            True when the machine has finished: no next state and nothing
            pending. False when it must be stepped again later.
        """
        self._pending_tasks = self._wait_pending_tasks()
        while not self._pending_tasks and self._next_state:
            self._current_state = self._next_state
            state = self.States[self._current_state]
            print(' -> execute', self._current_state, state.Type)
            self._next_state, self._pending_tasks = await state.execute()
        if self._pending_tasks:
            # Still waiting: give the event loop a turn so the awaited tasks
            # can make progress even if the caller polls in a tight loop.
            await asyncio.sleep(0)
        # A final state whose tasks are still running is not complete yet.
        return self._next_state is None and not self._pending_tasks

    def _wait_pending_tasks(self):
        """Return the pending task ids that are not known to be done yet.

        Task status is polled at most once per second. Between polls the
        current pending list is returned unchanged, so that nothing waited on
        is forgotten.
        """
        if not self._pending_tasks:
            return []
        now = time.time()
        if now - self._last_pending_check < 1:
            # Throttled: keep waiting on the same tasks.
            return self._pending_tasks
        self._last_pending_check = now
        still_pending = []
        for task_id in self._pending_tasks:
            result = get_task_result(task_id)
            # A task without a recorded result counts as not done.
            task_state = result[0] if result else None
            print(' ---> CHECK PENDING', task_id, task_state)
            if task_state != 'DONE':
                still_pending.append(task_id)
        return still_pending
# ========================================================================================== | |
# Task registry: task id -> (status, result) tuple.
_task_state = {}
# Id that the next call to next_task_id() will hand out.
_task_next_id = 0


def next_task_id():
    """Return a fresh task id; ids are consecutive integers starting at 0."""
    global _task_next_id
    allocated, _task_next_id = _task_next_id, _task_next_id + 1
    return allocated
def get_task_result(task_id):
    """Return the ``(status, result)`` tuple recorded for *task_id*.

    Ids without a recorded state yield ``(None, None)`` instead of ``None``,
    so callers can always unpack the return value into two names.
    """
    return _task_state.get(task_id, (None, None))
# State machines waiting for their next step() call.
_state_machines = deque()


def new_sm_task(sm_specs):
    """Create a state machine from *sm_specs* and queue it for execution.

    Args:
        sm_specs: parsed state machine definition (dict).

    Returns:
        The task id assigned to the new state machine.
    """
    task_id = next_task_id()
    task_sm = StateMachine(task_id, sm_specs)
    # Register the machine as running so anyone polling this id gets a
    # (status, result) tuple before the machine has completed.
    _task_state[task_id] = ('RUNNING', None)
    _state_machines.append(task_sm)
    print('-> SM CREATED', task_id)
    return task_id
def new_sm_task_from_file(filename):
    """Load a JSON state machine definition from *filename* and queue it.

    Returns:
        The task id assigned to the new state machine.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON is UTF-8 by convention; do not depend on the platform's locale encoding.
    with open(filename, encoding='utf-8') as fd:
        sm = json.load(fd)
    return new_sm_task(sm)
async def busy_task(task_id):
    """Simulated worker: sleep for three seconds, then mark *task_id* as done."""
    outcome = {'r': task_id}
    await asyncio.sleep(3)
    _task_state[task_id] = ('DONE', outcome)
# Strong references to in-flight busy_task() tasks. The event loop keeps only
# weak references to tasks, so an otherwise unreferenced task may be garbage
# collected before it finishes.
_background_tasks = set()


async def new_task():
    """Start a simulated background task and return its id.

    The task is recorded as ``('RUNNING', None)`` and ``busy_task()`` is
    scheduled on the running event loop.
    """
    task_id = next_task_id()
    _task_state[task_id] = ('RUNNING', None)
    worker = asyncio.create_task(busy_task(task_id))
    _background_tasks.add(worker)
    # Drop the reference once the task has finished.
    worker.add_done_callback(_background_tasks.discard)
    return task_id
def forge_input(context_data, input_data, params_data):
    """Build a state's effective input from a "Parameters"-style template.

    Keys ending in ``.$`` are resolved and stored without that suffix:

    * values that are ``$`` or start with ``$.`` / ``$[`` are JSONPath
      expressions evaluated against *input_data*;
    * values that are ``$$`` or start with ``$$.`` / ``$$[`` are evaluated
      against *context_data*;
    * anything else is copied unchanged.

    Path results are always the list of matched values. Nested dicts are
    processed recursively under their key as written; all other entries are
    copied as-is.

    Args:
        context_data: the context object (target of ``$$`` paths).
        input_data: the state input (target of ``$`` paths).
        params_data: the template dict.

    Returns:
        A new dict; *params_data* is not modified.
    """
    result = {}
    for key, val in params_data.items():
        if isinstance(val, dict):
            result[key] = forge_input(context_data, input_data, val)
        elif not key.endswith('.$'):
            result[key] = val
        elif not isinstance(val, str):
            # Not a path string: keep the value instead of failing on str methods.
            result[key[:-2]] = val
        elif val == '$$' or val.startswith(('$$.', '$$[')):
            # Strip the first "$" so the rest is a plain JSONPath on the context.
            json_expr = jsonpath_ng.parse(val[1:])
            result[key[:-2]] = [m.value for m in json_expr.find(context_data)]
        elif val == '$' or val.startswith(('$.', '$[')):
            json_expr = jsonpath_ng.parse(val)
            result[key[:-2]] = [m.value for m in json_expr.find(input_data)]
        else:
            result[key[:-2]] = val
    return result
async def main(filename='hello.json'):
    """Load a state machine from *filename* and run the scheduler loop forever.

    Queued state machines are stepped round-robin: each is popped from the
    front of the queue, advanced with ``step()`` and re-queued at the back
    unless it has completed. A completed machine is recorded as
    ``('DONE', None)`` under its task id.

    Args:
        filename: path of the JSON state machine definition to start with.
    """
    new_sm_task_from_file(filename)
    while True:
        if not _state_machines:
            await asyncio.sleep(1)
            continue
        sm = _state_machines.popleft()
        completed = await sm.step()
        if completed:
            print('-> SM COMPLETE', sm.task_id)
            _task_state[sm.task_id] = ('DONE', None)
        else:
            _state_machines.append(sm)
        # A step is not guaranteed to suspend, so yield explicitly; otherwise
        # tasks started with asyncio.create_task() would never get to run
        # while state machines are queued.
        await asyncio.sleep(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment