Last active
October 20, 2023 17:55
-
-
Save danielsuo/81780b85af59c7f722ccfecf92f94031 to your computer and use it in GitHub Desktop.
deluca
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Base interfaces and datatypes. | |
NOTES
- Agents are not, generally speaking, interchangeable and neither are
  environments.
- Agents and environments likely need to know how to interpret each other's
  inputs and outputs.
- Any notion of an API is a loose one (as evidenced by the number of `*args` and
  `**kwargs` in `Protocol` signatures); this API serves a more minor ergonomic
  role and provides some consistency.
- There may be subsets of environments (e.g., lungs) or agents that can have a
  more specific API. This should probably be encouraged.
QUESTIONS
- Should we be able to further update agent/controller state when getting the
  control of the agent based on its state?
- Should `Environment`s control what is observable or leave that to `Agent`s to
  responsibly access only what they should be allowed to?
- Should `AgentControlFn`s take `AgentState` and `EnvironmentState` or be more
  permissive?
""" | |
from typing import NamedTuple, Protocol | |
import chex | |
# Aliases of `chex.ArrayTree` naming the roles a tree of arrays plays in the
# agent/environment interfaces.
AgentState = chex.ArrayTree
EnvironmentState = chex.ArrayTree
Control = chex.ArrayTree

# Class-name -> class lookups, filled in by the `__init_subclass__` hooks of
# `Agent` and `Environment`.
AgentRegistry = dict()
EnvironmentRegistry = dict()
class EmptyAgentState(NamedTuple):
  """The empty state of an `Agent`.

  Declared as a field-less `NamedTuple` instead of a subclass of `AgentState`:
  `AgentState` is a type alias (`chex.ArrayTree`, a `typing.Union`) rather than
  a class, so it cannot serve as a base class. An instance is an empty tuple.
  """
class EmptyEnvironmentState(NamedTuple):
  """The empty state of an `Environment`.

  Declared as a field-less `NamedTuple` instead of a subclass of a state alias:
  the state aliases in this module point at `chex.ArrayTree` (a `typing.Union`),
  which is not a class and cannot serve as a base class. An instance is an
  empty tuple.
  """
class AgentInitFn(Protocol):
  """Callable signature for an `Agent`'s `init` step.

  An `init` function accepts arbitrary `*args` / `**kwargs` and builds the
  Agent's initial `state`, which may carry statistics of past updates or any
  other non-static information.
  """

  def __call__(self, *args, **kwargs) -> AgentState:
    """Builds the initial state of an Agent.

    Args:
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The initial `AgentState`.
    """
    ...
class AgentControlFn(Protocol):
  """Callable signature for an `Agent`'s `control` step.

  A `control` function receives an `AgentState`, an `EnvironmentState`, and
  arbitrary `*args` / `**kwargs`; it returns a possibly updated `AgentState`
  together with a `Control`.
  """

  def __call__(
      self,
      state: AgentState,
      obs: EnvironmentState,
      *args,
      **kwargs,
  ) -> tuple[AgentState, Control]:
    """Computes a control from the agent state and an observation.

    Args:
      state: The `AgentState` to update.
      obs: An `EnvironmentState` object.
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      A pair of the updated `AgentState` and the resulting `Control`.
    """
    ...
class AgentUpdateFn(Protocol):
  """Callable signature for an `Agent`'s `update` step.

  An `update` function receives an `AgentState` plus arbitrary `*args` /
  `**kwargs` (typically supplied by the `Environment`) and returns a new
  `AgentState`.
  """

  def __call__(self, state: AgentState, *args, **kwargs) -> AgentState:
    """Produces the next state of an Agent.

    Args:
      state: The `AgentState` to update.
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The updated `AgentState`.
    """
    ...
class Agent(NamedTuple):
  """A set of pure functions implementing an agent's behavior.

  Attributes:
    init: Builds the agent's initial state.
    control: Maps an agent state and an observation to a possibly updated agent
      state and a control.
    update: Produces a new agent state from the current one.
  """

  init: AgentInitFn
  control: AgentControlFn
  update: AgentUpdateFn

  # `typing.NamedTuple` attaches the class body to the generated tuple class
  # with `setattr`, which skips the implicit classmethod conversion a regular
  # class statement applies to `__init_subclass__`. Without the explicit
  # decorator the hook is called with no `cls` when a subclass is created.
  # Zero-argument `super()` is not usable in a `NamedTuple` body, so the hook
  # deliberately does not chain to a parent implementation.
  @classmethod
  def __init_subclass__(cls, *args, **kwargs) -> None:
    """Adds a new `Agent` subclass to the registry.

    Args:
      *args: Arbitrary positional arguments (unused).
      **kwargs: Arbitrary keyword arguments (unused).

    Raises:
      ValueError: If a class with the same name is already registered.
    """
    name = cls.__name__
    if name in AgentRegistry:
      raise ValueError(f'Agent {name} already exists.')
    AgentRegistry[name] = cls
class EnvironmentInitFn(Protocol):
  """Callable signature for an `Environment`'s `init` step.

  An `init` function accepts arbitrary `*args` / `**kwargs` and builds the
  Environment's initial `state`, which may carry statistics of past updates or
  any other non-static information.
  """

  def __call__(self, *args, **kwargs) -> EnvironmentState:
    """Builds the initial state of an Environment.

    Args:
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The initial `EnvironmentState`.
    """
    ...
class EnvironmentUpdateFn(Protocol):
  """Callable signature for an `Environment`'s `update` step.

  An `update` function receives an `EnvironmentState` plus arbitrary `*args` /
  `**kwargs` and returns a new `EnvironmentState`.
  """

  def __call__(
      self,
      state: EnvironmentState,
      *args,
      **kwargs,
  ) -> EnvironmentState:
    """Produces the next state of an Environment.

    Args:
      state: The `EnvironmentState` to update.
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The updated `EnvironmentState`.
    """
    ...
class Environment(NamedTuple):
  """A set of pure functions implementing an environment's behavior.

  Attributes:
    init: Builds the environment's initial state.
    update: Produces a new environment state from the current one.
  """

  init: EnvironmentInitFn
  update: EnvironmentUpdateFn

  # `typing.NamedTuple` attaches the class body to the generated tuple class
  # with `setattr`, which skips the implicit classmethod conversion a regular
  # class statement applies to `__init_subclass__`. Without the explicit
  # decorator the hook is called with no `cls` when a subclass is created.
  # Zero-argument `super()` is not usable in a `NamedTuple` body, so the hook
  # deliberately does not chain to a parent implementation.
  @classmethod
  def __init_subclass__(cls, *args, **kwargs) -> None:
    """Adds a new `Environment` subclass to the registry.

    Args:
      *args: Arbitrary positional arguments (unused).
      **kwargs: Arbitrary keyword arguments (unused).

    Raises:
      ValueError: If a class with the same name is already registered.
    """
    name = cls.__name__
    if name in EnvironmentRegistry:
      raise ValueError(f'Environment {name} already exists.')
    EnvironmentRegistry[name] = cls
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment