Skip to content

Instantly share code, notes, and snippets.

@matteobertozzi
Created February 10, 2024 06:45
Show Gist options
  • Save matteobertozzi/9b79884947a6c0e092e4ae9bd7e51c5a to your computer and use it in GitHub Desktop.
Save matteobertozzi/9b79884947a6c0e092e4ae9bd7e51c5a to your computer and use it in GitHub Desktop.
Amazon States Language Demo Execution
#!/usr/bin/env python3
# Amazon States Language - https://states-language.net/
import jsonpath_ng
from collections import deque, namedtuple
import asyncio
import time
import json
class State:
    """Base class for every state of an Amazon States Language machine.

    The constructor copies the ``Comment``, ``Type`` and ``Next`` fields out
    of the state's spec dict.  ``End`` is derived from them: it is true when
    the spec names no ``Next`` state.
    """

    def __init__(self, specs) -> None:
        self.Comment = specs.get('Comment')
        self.Type = specs.get('Type')
        self.Next = specs.get('Next')
        # A state without a "Next" transition is treated as terminal.
        self.End = self.Next is None

    def has_next(self):
        """Return True when this state names a successor state."""
        # ``End`` means "no successor", so the answer is its negation
        # (returning ``self.End`` directly gave the opposite result).
        return not self.End

    async def execute(self):
        """Run the state.

        Subclasses in this file return a ``(next_state_name, pending_task_ids)``
        pair; the base class has no behavior of its own.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class OperationState(State):
    """State that also carries the input/output processing fields.

    On top of what ``State`` reads, the constructor copies ``InputPath``,
    ``OutputPath``, ``ResultPath``, ``Parameters`` and ``ResultSelector``
    from the spec dict (``None`` when a field is absent).
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        io_fields = (
            'InputPath',
            'OutputPath',
            'ResultPath',
            'Parameters',
            'ResultSelector',
        )
        for field_name in io_fields:
            setattr(self, field_name, specs.get(field_name))
class PassState(State):
    """"Pass" state: moves straight on to ``Next`` without spawning work.

    The constructor copies ``InputPath``, ``OutputPath``, ``ResultPath``,
    ``Parameters`` and ``Result`` from the spec dict; ``execute`` does not
    use them.
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field_name in ('InputPath', 'OutputPath', 'ResultPath',
                           'Parameters', 'Result'):
            setattr(self, field_name, specs.get(field_name))

    async def execute(self):
        """Return ``(Next, None)``: the successor name and no pending tasks."""
        no_pending = None
        return self.Next, no_pending
class TaskState(OperationState):
    """"Task" state: spawns one background task through ``new_task()``.

    The timeout and heartbeat fields are copied from the spec dict but are
    not consulted by ``execute``.
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field_name in ('TimeoutSeconds', 'TimeoutSecondsPath',
                           'HeartbeatSeconds', 'HeartbeatSecondsPath'):
            setattr(self, field_name, specs.get(field_name))

    async def execute(self):
        """Start a task and return ``(Next, [task_id])``."""
        spawned_id = await new_task()
        pending = [spawned_id]
        return self.Next, pending
class WaitState(OperationState):
    """"Wait" state: modelled here as one background task from ``new_task()``.

    ``Seconds``, ``SecondsPath``, ``Timestamp`` and ``TimestampPath`` are
    copied from the spec dict but are not consulted by ``execute``.
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field_name in ('Seconds', 'SecondsPath',
                           'Timestamp', 'TimestampPath'):
            setattr(self, field_name, specs.get(field_name))

    async def execute(self):
        """Start a task and return ``(Next, [task_id])``."""
        spawned_id = await new_task()
        pending = [spawned_id]
        return self.Next, pending
class ChoiceState(OperationState):
    """"Choice" state.

    The choice rules are not evaluated: ``execute`` always follows the first
    entry of ``Choices``.  When the spec has no (or an empty) ``Choices``
    list, the state's ``Default`` transition is used instead of crashing.
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        self.Choices = specs.get('Choices')
        # Fallback transition used when there is no choice rule to follow.
        self.Default = specs.get('Default')

    async def execute(self):
        """Return ``(next_state_name, None)``; no tasks are spawned."""
        if self.Choices:
            return self.Choices[0].get('Next'), None
        # No rule available: take the default transition (None if absent).
        return self.Default, None
class ParallelState(OperationState):
    """"Parallel" state: starts one child state machine per branch."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        self.Branches = specs.get('Branches')

    async def execute(self):
        """Queue every branch via ``new_sm_task`` and return ``(Next, ids)``."""
        child_ids = []
        for branch_specs in self.Branches:
            # Print before creating so the log order stays branch-then-created.
            print(' --->', branch_specs)
            child_id = new_sm_task(branch_specs)
            child_ids.append(child_id)
        return self.Next, child_ids
class MapState(OperationState):
    """"Map" state.

    The constructor copies ``ItemProcessor``, ``MaxConcurrency``,
    ``ItemSelector`` and ``ItemsPath`` from the spec dict.  Item processing
    is not implemented yet: ``execute`` spawns nothing.
    """

    def __init__(self, specs) -> None:
        super().__init__(specs)
        self.ItemProcessor = specs.get('ItemProcessor')
        self.MaxConcurrency = specs.get('MaxConcurrency')
        self.ItemSelector = specs.get('ItemSelector')
        self.ItemsPath = specs.get('ItemsPath')
        # A spec without "ItemProcessor" leaves the attribute as None; guard
        # the lookup so construction does not fail with AttributeError.
        processor = self.ItemProcessor or {}
        self._processor_config = processor.get('ProcessorConfig')

    async def execute(self):
        """Return ``(Next, [])``: the successor name and no pending tasks."""
        # for each item spawn a new task
        return self.Next, []
class SucceedState(State):
    """"Succeed" state: ends the machine without a successor or tasks."""

    def __init__(self, specs) -> None:
        super().__init__(specs)

    async def execute(self):
        """Return ``(None, None)``: no next state, no pending tasks."""
        outcome = (None, None)
        return outcome
class FailState(State):
    """"Fail" state: ends the machine; keeps ``Error`` and ``Cause`` from the spec."""

    def __init__(self, specs) -> None:
        super().__init__(specs)
        for field_name in ('Error', 'Cause'):
            setattr(self, field_name, specs.get(field_name))

    async def execute(self):
        """Return ``(None, None)``: no next state, no pending tasks."""
        outcome = (None, None)
        return outcome
# Maps the "Type" field of a state spec to the class that implements it;
# new_state() looks the class up here.
STATE_TYPES = {
    'Task': TaskState,
    # Flow
    'Choice': ChoiceState,
    'Parallel': ParallelState,
    'Map': MapState,
    'Pass': PassState,
    'Wait': WaitState,
    'Succeed': SucceedState,
    'Fail': FailState,
}
def new_state(state):
    """Build the state object for one spec dict.

    The class is chosen by looking ``state['Type']`` up in ``STATE_TYPES``;
    a missing or unregistered type raises ``KeyError``.
    """
    type_name = state['Type']
    factory = STATE_TYPES[type_name]
    return factory(state)
class StateMachine:
    """One running instance of a state-machine definition.

    The machine is advanced cooperatively by calling ``step()`` repeatedly.
    Each state's ``execute()`` yields the name of the next state plus the
    ids of tasks that must be ``'DONE'`` before that next state may run.
    """

    def __init__(self, task_id, specs) -> None:
        self.Comment = specs.get('Comment')
        self.Version = specs.get('Version')
        self.StartAt = specs.get('StartAt')
        self.TimeoutSeconds = specs.get('TimeoutSeconds')
        self.States = {name: new_state(state) for name, state in specs['States'].items()}
        self.task_id = task_id
        self._next_state = self.StartAt
        self._current_state = None
        self._pending_tasks = []
        self._last_pending_check = 0
        self._context = {}

    async def step(self):
        """Advance the machine as far as possible without blocking.

        Returns:
            True once there is no next state AND no task is still pending;
            False while the machine has more to do.
        """
        self._pending_tasks = self._wait_pending_tasks()
        while not self._pending_tasks and self._next_state:
            current_state = self._next_state
            self._current_state = current_state
            state = self.States[current_state]
            print(' -> execute', current_state, state.Type)
            self._next_state, self._pending_tasks = await state.execute()
        if self._pending_tasks:
            # Still waiting on tasks: hand control back to the event loop so
            # the background coroutines we are waiting for can actually run,
            # and do not report completion even if this was the last state.
            await asyncio.sleep(0)
            return False
        return self._next_state is None

    def _wait_pending_tasks(self):
        """Return the subset of pending task ids that are not ``'DONE'`` yet.

        Polling is throttled to once per second; between polls the current
        pending list is returned unchanged, so outstanding tasks are never
        forgotten.
        """
        if not self._pending_tasks:
            return []
        now = time.time()
        if now - self._last_pending_check < 1:
            # Too soon to poll again: keep waiting on the same tasks.
            return self._pending_tasks
        self._last_pending_check = now
        pending_tasks = []
        for task_id in self._pending_tasks:
            task_info = get_task_result(task_id)
            # An id with no recorded status yet is treated as still pending.
            task_state = task_info[0] if task_info else None
            print(' ---> CHECK PENDING', task_id, task_state)
            if task_state != 'DONE':
                pending_tasks.append(task_id)
        return pending_tasks
# ==========================================================================================
# Task registry: task id -> (status, result) tuple.  Statuses written in this
# file are 'RUNNING' and 'DONE'; get_task_result() reads from here.
_task_state = {}
# Next id to hand out; advanced by next_task_id().
_task_next_id = 0


def next_task_id():
    """Return a fresh task id (0, 1, 2, ...) and advance the counter."""
    global _task_next_id
    allocated, _task_next_id = _task_next_id, _task_next_id + 1
    return allocated
def get_task_result(task_id):
    """Return the ``_task_state`` entry for *task_id*, or None if absent."""
    try:
        return _task_state[task_id]
    except KeyError:
        return None
# Run queue of StateMachine instances: new_sm_task() appends on the right,
# main() pops from the left and re-appends machines that are not complete.
_state_machines = deque()
def new_sm_task(sm_specs):
    """Create a state machine from *sm_specs*, queue it, and return its task id.

    The machine is recorded as ``('RUNNING', None)`` in ``_task_state`` right
    away, so a parent polling ``get_task_result()`` for this id gets a status
    tuple rather than ``None`` while the child is still executing.
    """
    task_id = next_task_id()
    task_sm = StateMachine(task_id, sm_specs)
    _task_state[task_id] = ('RUNNING', None)
    _state_machines.append(task_sm)
    print('-> SM CREATED', task_id)
    return task_id
def new_sm_task_from_file(filename):
    """Load a JSON state-machine definition from *filename* and queue it.

    Returns the task id assigned by ``new_sm_task``.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(filename, encoding='utf-8') as fd:
        sm = json.load(fd)
    return new_sm_task(sm)
async def busy_task(task_id):
    """Sleep for three seconds, then record *task_id* as ``'DONE'``.

    The stored result is the dict ``{'r': task_id}``.
    """
    await asyncio.sleep(3)
    outcome = {'r': task_id}
    _task_state[task_id] = ('DONE', outcome)
# Strong references to in-flight background tasks.  The event loop only keeps
# weak references, so a task nobody else holds may be garbage-collected
# before it finishes.
_background_tasks = set()


async def new_task():
    """Start a ``busy_task`` in the background and return its task id.

    The id is recorded as ``('RUNNING', None)`` in ``_task_state`` before the
    coroutine is scheduled.
    """
    task_id = next_task_id()
    _task_state[task_id] = ('RUNNING', None)
    worker = asyncio.create_task(busy_task(task_id))
    _background_tasks.add(worker)
    # Drop the reference once the task is done so the set does not grow.
    worker.add_done_callback(_background_tasks.discard)
    return task_id
def forge_input(context_data, input_data, params_data):
    """Build a payload from a ``Parameters``-style template.

    Args:
        context_data: object that ``$$``-prefixed paths are evaluated against.
        input_data: object that ``$``-prefixed paths are evaluated against.
        params_data: template dict.  Nested dicts are processed recursively.
            A key ending in ``.$`` is emitted without that suffix; if its
            value is a JSONPath (``$``, ``$.…``, ``$[…]``, or the same with
            a leading extra ``$`` for the context) the emitted value is the
            list of all matches, otherwise the value is copied as is.

    Returns:
        A new dict; *params_data* is not modified.
    """

    def _is_json_path(text):
        # Root alone, or root followed by a child / subscript step.
        return text == '$' or text.startswith(('$.', '$['))

    result = {}
    for key, val in params_data.items():
        if isinstance(val, dict):
            result[key] = forge_input(context_data, input_data, val)
        elif key.endswith('.$'):
            target = key[:-2]
            source, path = input_data, val
            if isinstance(val, str) and val.startswith('$$'):
                # "$$…" addresses the context: strip one "$" to get a JSONPath.
                source, path = context_data, val[1:]
            if isinstance(path, str) and _is_json_path(path):
                json_expr = jsonpath_ng.parse(path)
                result[target] = [m.value for m in json_expr.find(source)]
            else:
                # Not a path (e.g. an intrinsic call or a non-string): copy it.
                result[target] = val
        else:
            result[key] = val
    return result
async def main():
    """Load ``hello.json`` and drive all queued state machines forever.

    Machines are stepped round-robin: each is popped from the left of
    ``_state_machines``, stepped once, and re-appended unless it reported
    completion, in which case its id is marked ``('DONE', None)``.
    The loop never returns; it idles with a one-second sleep when the queue
    is empty.
    """
    new_sm_task_from_file('hello.json')
    while True:
        if not _state_machines:
            await asyncio.sleep(1)
            continue
        sm = _state_machines.popleft()
        completed = await sm.step()
        if completed:
            print('-> SM COMPLETE', sm.task_id)
            _task_state[sm.task_id] = ('DONE', None)
        else:
            _state_machines.append(sm)
        # Awaiting step() does not necessarily suspend; yield explicitly so
        # background tasks get to run while machines are still queued.
        await asyncio.sleep(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment