Skip to content

Instantly share code, notes, and snippets.

@danielsuo
Last active October 20, 2023 17:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danielsuo/81780b85af59c7f722ccfecf92f94031 to your computer and use it in GitHub Desktop.
Save danielsuo/81780b85af59c7f722ccfecf92f94031 to your computer and use it in GitHub Desktop.
deluca
"""Base interfaces and datatypes.
NOTES
- Agents are not, generally speaking, interchangeable and neither are
environments.
- Agents and environments likely need to know how to interpret each other's
inputs and outputs.
- Any notion of an API is a loose one (as evidenced by the number of `*args` and
`**kwargs` in `Protocol` signatures); this API serves a more minor ergonomic
role and provides some consistency.
- There may be subsets of environments (e.g., lungs) or agents that can have a
more specific API. This should probably be encouraged.
QUESTIONS
- Should we be able to further update agent/controller state when getting the
control of the agent based on its state?
- Should `Environment`s control what is observable or leave that to `Agent`s to
responsibly access only what they should be allowed to?
- Should `AgentControlFn`s take `AgentState` and `EnvironmentState` or be more
permissive?
"""
from typing import NamedTuple, Protocol
import chex
# Aliases of `chex.ArrayTree` used in the `Protocol` signatures of this module.
AgentState = chex.ArrayTree
EnvironmentState = chex.ArrayTree
Control = chex.ArrayTree

# Registries keyed by a subclass's `__name__`; filled in by
# `Agent.__init_subclass__` and `Environment.__init_subclass__` respectively.
AgentRegistry: dict[str, type] = {}
EnvironmentRegistry: dict[str, type] = {}
class EmptyAgentState(NamedTuple):
  """The empty (field-less) state of an `Agent` that keeps no state.

  NOTE(review): this class previously used `AgentState` (an alias of
  `chex.ArrayTree`) as its base class. That alias is presumably a
  `typing.Union`, which cannot be subclassed and raises `TypeError` when the
  class statement runs -- confirm against the installed chex version. A
  field-less `NamedTuple` behaves as an empty tuple, i.e. a tree with no
  leaves.
  """
class EmptyEnvironmentState(NamedTuple):
  """The empty (field-less) state of an `Environment` that keeps no state.

  NOTE(review): this class previously used `AgentState` (an alias of
  `chex.ArrayTree`) as its base class. That alias is presumably a
  `typing.Union`, which cannot be subclassed and raises `TypeError` when the
  class statement runs -- confirm against the installed chex version. A
  field-less `NamedTuple` behaves as an empty tuple, i.e. a tree with no
  leaves.
  """
class AgentInitFn(Protocol):
  """Callable signature for the `init` step of an `Agent`.

  An `init` function accepts arbitrary `*args` / `**kwargs` and builds the
  initial `state` of the Agent. That state may hold statistics of past updates
  or any other non-static information.
  """

  def __call__(self, *args, **kwargs) -> AgentState:
    """Builds the initial state of an Agent.

    Args:
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The initial state of an Agent.
    """
    ...
class AgentControlFn(Protocol):
  """Callable signature for the `control` step of an `Agent`.

  A `control` function receives an `AgentState`, an `EnvironmentState`, and
  arbitrary `*args` / `**kwargs`; it returns a possibly updated `AgentState`
  together with a `Control`.
  """

  def __call__(
      self,
      state: AgentState,
      obs: EnvironmentState,
      *args,
      **kwargs,
  ) -> tuple[AgentState, Control]:
    """Computes a control from the agent state and an observation.

    Args:
      state: An `AgentState` object to update.
      obs: An `EnvironmentState` object.
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      An updated `AgentState` for an `Agent` and a `Control`.
    """
    ...
class AgentUpdateFn(Protocol):
  """Callable signature for the `update` step of an `Agent`.

  An `update` function receives an `AgentState` plus arbitrary `*args` /
  `**kwargs` (typically coming from the `Environment`) and returns a new
  `AgentState`.
  """

  def __call__(self, state: AgentState, *args, **kwargs) -> AgentState:
    """Produces the next state of an Agent.

    Args:
      state: An `AgentState` object to update.
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      An updated `AgentState` for an `Agent`.
    """
    ...
class Agent(NamedTuple):
  """A set of pure functions implementing an agent's behavior.

  Subclasses are recorded in `AgentRegistry` under their class name.
  """

  # Builds the initial `AgentState`.
  init: AgentInitFn
  # Maps an `AgentState` and an observation to a new state and a `Control`.
  control: AgentControlFn
  # Maps an `AgentState` (plus extra arguments) to a new `AgentState`.
  update: AgentUpdateFn

  # `typing.NamedTuple` attaches the methods of this body to the generated
  # class with `setattr` after that class already exists, so the implicit
  # classmethod conversion normally applied to `__init_subclass__` does not
  # happen. The explicit decorator is needed; without it, subclassing fails
  # with a `TypeError` about a missing `cls` argument.
  @classmethod
  def __init_subclass__(cls, *args, **kwargs) -> None:
    """Adds a new `Agent` subclass to the registry.

    Args:
      *args: Arbitrary positional arguments (unused).
      **kwargs: Arbitrary keyword arguments (unused).

    Raises:
      ValueError: If a class with the same name is already registered.
    """
    name = cls.__name__
    if name in AgentRegistry:
      raise ValueError(f'Agent {name} already exists.')
    AgentRegistry[name] = cls
class EnvironmentInitFn(Protocol):
  """Callable signature for the `init` step of an `Environment`.

  An `init` function accepts arbitrary `*args` / `**kwargs` and builds the
  initial `state` of the Environment. That state may hold statistics of past
  updates or any other non-static information.
  """

  def __call__(self, *args, **kwargs) -> EnvironmentState:
    """Builds the initial state of an Environment.

    Args:
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      The initial state of an Environment.
    """
    ...
class EnvironmentUpdateFn(Protocol):
  """Callable signature for the `update` step of an `Environment`.

  An `update` function receives an `EnvironmentState` plus arbitrary `*args` /
  `**kwargs` and returns a new `EnvironmentState`.
  """

  def __call__(
      self,
      state: EnvironmentState,
      *args,
      **kwargs,
  ) -> EnvironmentState:
    """Produces the next state of an Environment.

    Args:
      state: An `EnvironmentState` object to update.
      *args: Arbitrary positional arguments.
      **kwargs: Arbitrary keyword arguments.

    Returns:
      An updated `EnvironmentState` for an `Environment`.
    """
    ...
class Environment(NamedTuple):
  """A set of pure functions implementing an environment's behavior.

  Subclasses are recorded in `EnvironmentRegistry` under their class name.
  """

  # Builds the initial `EnvironmentState`.
  init: EnvironmentInitFn
  # Maps an `EnvironmentState` (plus extra arguments) to a new state.
  update: EnvironmentUpdateFn

  # `typing.NamedTuple` attaches the methods of this body to the generated
  # class with `setattr` after that class already exists, so the implicit
  # classmethod conversion normally applied to `__init_subclass__` does not
  # happen. The explicit decorator is needed; without it, subclassing fails
  # with a `TypeError` about a missing `cls` argument.
  @classmethod
  def __init_subclass__(cls, *args, **kwargs) -> None:
    """Adds a new `Environment` subclass to the registry.

    Args:
      *args: Arbitrary positional arguments (unused).
      **kwargs: Arbitrary keyword arguments (unused).

    Raises:
      ValueError: If a class with the same name is already registered.
    """
    name = cls.__name__
    if name in EnvironmentRegistry:
      raise ValueError(f'Environment {name} already exists.')
    EnvironmentRegistry[name] = cls
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment