Skip to content

Instantly share code, notes, and snippets.

@iglesias
Last active March 31, 2024 06:15
Show Gist options
  • Save iglesias/4cbe1ccfef73693daf2b666184fabfda to your computer and use it in GitHub Desktop.
Save iglesias/4cbe1ccfef73693daf2b666184fabfda to your computer and use it in GitHub Desktop.
Porting Deep Symbolic Optimization to a Python version > 3.7
from abc import ABC, abstractmethod
import torch
from torch.nn.functional import one_hot
from dso.program import Program
class StateManager(ABC):
    """
    Abstract interface that turns Task observations into Tensor inputs for
    the Policy.
    """

    def setup_manager(self, policy):
        """
        Hook called by the policy so the manager can initialize itself.

        :param policy the policy object; it is stored on ``self.policy`` and
            its ``max_length`` attribute is cached on ``self.max_length``.
        """
        self.policy = policy
        self.max_length = self.policy.max_length

    @abstractmethod
    def get_tensor_input(self, obs):
        """
        Turn a Task observation into the Tensor fed to the Policy (for
        example via one-hot encoding or an embedding lookup).

        Parameters
        ----------
        obs : np.ndarray (dtype=np.float32)
            Observation produced by the Task.

        Returns
        -------
        input_ : torch.Tensor (dtype=torch.float32)
            Tensor used as the Policy's input.
        """

    def process_state(self, obs):
        """
        Hook for adding information to the state tuple.

        The base implementation is the identity: ``obs`` is returned as-is.
        Subclasses may override it.
        """
        return obs
def make_state_manager(config):
    """
    Build the StateManager described by ``config``.

    Parameters
    ----------
    config : dict or None
        Parameters for this StateManager. The optional ``"type"`` key selects
        the manager class (default ``"hierarchical"``). Every other key is
        forwarded as a keyword argument to that class's constructor. The
        caller's dict is not modified.

    Returns
    -------
    state_manager : StateManager
        The StateManager to be used by the policy.

    Raises
    ------
    KeyError
        If ``"type"`` does not name a known state manager.
    """
    manager_dict = {
        "hierarchical": HierarchicalStateManager
    }

    # Work on a shallow copy. Popping "type" from the caller's dict would
    # silently strip it from their configuration.
    config = {} if config is None else dict(config)

    # Use HierarchicalStateManager by default
    manager_type = config.pop("type", "hierarchical")
    try:
        manager_class = manager_dict[manager_type]
    except KeyError:
        raise KeyError(
            f"Unknown state manager type {manager_type!r}; "
            f"expected one of {sorted(manager_dict)}") from None

    return manager_class(**config)
class HierarchicalStateManager(StateManager):
    """
    Class that uses the previous action, parent, sibling, and/or dangling as
    observations.
    """

    def __init__(self, observe_parent=True, observe_sibling=True,
                 observe_action=False, observe_dangling=False, embedding=False,
                 embedding_size=8):
        """
        Parameters
        ----------
        observe_parent : bool
            Observe the parent of the Token being selected?
        observe_sibling : bool
            Observe the sibling of the Token being selected?
        observe_action : bool
            Observe the previously selected Token?
        observe_dangling : bool
            Observe the number of dangling nodes?
        embedding : bool
            Use embeddings for categorical inputs?
        embedding_size : int
            Size of embeddings for each categorical input if embedding=True.
        """
        self.observe_parent = observe_parent
        self.observe_sibling = observe_sibling
        self.observe_action = observe_action
        self.observe_dangling = observe_dangling
        # The token library is captured once, at construction time.
        self.library = Program.library

        # Parameter assertions/warnings
        assert self.observe_action + self.observe_parent + self.observe_sibling + self.observe_dangling > 0, \
            "Must include at least one observation."

        self.embedding = embedding
        self.embedding_size = embedding_size

    def _make_embeddings(self, num_inputs):
        """
        Create a trainable ``(num_inputs, embedding_size)`` lookup table whose
        entries are drawn uniformly from [-1, 1).

        The affine rescaling is applied to the raw tensor *before* wrapping it
        in ``torch.nn.Parameter``, so the returned object is the leaf
        Parameter itself.

        Rescaling after wrapping (``Parameter(...) * 2 - 1``) would instead
        return an ordinary non-leaf tensor. The actual Parameter would then be
        unreachable through any attribute and could not be handed to an
        optimizer.
        """
        return torch.nn.Parameter(torch.rand(num_inputs, self.embedding_size) * 2 - 1)

    def setup_manager(self, policy):
        """
        Initialize from the policy. If ``embedding`` is enabled, also create
        one embedding table per observed categorical input.

        :param policy the policy class
        """
        super().setup_manager(policy)
        if not self.embedding:
            return
        # NOTE(review): StateManager is not a torch.nn.Module, so these
        # Parameters are not registered automatically. Confirm that the
        # policy collects them into its optimizer.
        if self.observe_action:
            self.action_embeddings = self._make_embeddings(self.library.n_action_inputs)
        if self.observe_parent:
            self.parent_embeddings = self._make_embeddings(self.library.n_parent_inputs)
        if self.observe_sibling:
            self.sibling_embeddings = self._make_embeddings(self.library.n_sibling_inputs)

    def _encode_categorical(self, indices, embeddings_name, num_inputs_name):
        """
        Encode integer token indices for the Policy.

        If ``self.embedding`` is set, the indices are looked up in an
        embedding table. Otherwise they are one-hot encoded as float32.

        :param indices int64 index tensor (``one_hot`` requires int64)
        :param embeddings_name name of the attribute of ``self`` holding the
            embedding table
        :param num_inputs_name name of the attribute of ``self.library``
            giving the one-hot depth
        """
        if self.embedding:
            return getattr(self, embeddings_name)[indices]
        # one_hot returns int64. Cast it so it can be concatenated with float
        # inputs and passed to the Policy as a floating-point tensor.
        num_classes = getattr(self.library, num_inputs_name)
        return one_hot(indices, num_classes=num_classes).to(torch.float32)

    def get_tensor_input(self, obs):
        """
        Convert a batch of observations into the Policy's input tensor.

        Parameters
        ----------
        obs : torch.Tensor
            Observations stacked along dim 1 as (action, parent, sibling,
            dangling, *extra). Any columns beyond the first four (e.g. BERT
            embeddings) are appended to the result unchanged.
            NOTE(review): presumably shape (batch, n_obs) and float32; confirm
            against the caller.

        Returns
        -------
        input_ : torch.Tensor
            Concatenation, along the last dim, of the enabled encodings in the
            order action, parent, sibling, dangling, followed by any extra
            columns.
        """
        unstacked_obs = torch.unbind(obs, dim=1)
        action, parent, sibling, dangling = unstacked_obs[:4]

        # Action, parent, and sibling are categorical. Cast them with .long()
        # rather than to torch.IntTensor for two reasons: one_hot requires
        # int64, and .long() keeps the tensor on its current device.
        categorical = (
            (self.observe_action, action, "action_embeddings", "n_action_inputs"),
            (self.observe_parent, parent, "parent_embeddings", "n_parent_inputs"),
            (self.observe_sibling, sibling, "sibling_embeddings", "n_sibling_inputs"),
        )
        observations = [
            self._encode_categorical(indices.long(), embeddings_name, num_inputs_name)
            for observed, indices, embeddings_name, num_inputs_name in categorical
            if observed
        ]

        # Dangling input is just the value of dangling
        if self.observe_dangling:
            observations.append(torch.unsqueeze(dangling, -1))

        input_ = torch.cat(observations, -1)

        # possibly concatenates additional observations (e.g., bert embeddings)
        if len(unstacked_obs) > 4:
            input_ = torch.cat([input_, torch.stack(unstacked_obs[4:], dim=-1)], -1)

        return input_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment