Last active
March 31, 2024 06:15
-
-
Save iglesias/4cbe1ccfef73693daf2b666184fabfda to your computer and use it in GitHub Desktop.
Porting Deep Symbolic Optimization to a Python version > 3.7
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
import torch | |
from torch.nn.functional import one_hot | |
from dso.program import Program | |
class StateManager(ABC):
    """
    Abstract base for objects that turn Task observations into the Tensor
    inputs consumed by the Policy.
    """

    def setup_manager(self, policy):
        """
        Hook invoked from inside the policy so the manager can initialize.

        Keeps a reference to the policy and copies its ``max_length`` onto
        the manager.

        :param policy the policy class
        """
        self.policy = policy
        self.max_length = policy.max_length

    @abstractmethod
    def get_tensor_input(self, obs):
        """
        Turn an observation produced by a Task into a Tensor the Policy can
        consume, e.g. through one-hot encoding or an embedding lookup.

        Parameters
        ----------
        obs : np.ndarray (dtype=np.float32)
            Observation coming from the Task.

        Returns
        --------
        input_ : torch.Tensor (dtype=torch.float32)
            Tensor to be used as input to the Policy.
        """
        # Abstract: subclasses must override. The base body yields None.
        return None

    def process_state(self, obs):
        """
        Entry point for adding information to the state tuple.

        Unless a subclass overrides it, this is the identity function and
        ``obs`` is handed back unchanged.
        """
        return obs
def make_state_manager(config):
    """
    Build the StateManager described by ``config``.

    Parameters
    ----------
    config : dict or None
        Parameters for this StateManager. The optional ``"type"`` entry
        selects the manager class (default: ``"hierarchical"``); every other
        entry is forwarded as a keyword argument to that class. ``None`` is
        treated as an empty dict. The dict passed in is left unmodified.

    Returns
    -------
    state_manager : StateManager
        The StateManager to be used by the policy.

    Raises
    ------
    KeyError
        If the requested ``"type"`` is not a known manager type.
    """
    manager_dict = {
        "hierarchical": HierarchicalStateManager,
    }

    # Work on a shallow copy: popping "type" directly from the caller's dict
    # would silently alter a config that may be reused or logged afterwards.
    params = dict(config) if config is not None else {}

    # Use HierarchicalStateManager by default
    manager_type = params.pop("type", "hierarchical")
    try:
        manager_class = manager_dict[manager_type]
    except KeyError:
        # Same exception type as a plain dict lookup, with an actionable message.
        raise KeyError(
            "Unknown state manager type {!r}; expected one of {}".format(
                manager_type, sorted(manager_dict))
        ) from None

    return manager_class(**params)
class HierarchicalStateManager(StateManager):
    """
    Class that uses the previous action, parent, sibling, and/or dangling as
    observations.
    """

    def __init__(self, observe_parent=True, observe_sibling=True,
                 observe_action=False, observe_dangling=False, embedding=False,
                 embedding_size=8):
        """
        Parameters
        ----------
        observe_parent : bool
            Observe the parent of the Token being selected?

        observe_sibling : bool
            Observe the sibling of the Token being selected?

        observe_action : bool
            Observe the previously selected Token?

        observe_dangling : bool
            Observe the number of dangling nodes?

        embedding : bool
            Use embeddings for categorical inputs?

        embedding_size : int
            Size of embeddings for each categorical input if embedding=True.
        """
        self.observe_parent = observe_parent
        self.observe_sibling = observe_sibling
        self.observe_action = observe_action
        self.observe_dangling = observe_dangling
        self.library = Program.library

        # Parameter assertions/warnings
        assert self.observe_action + self.observe_parent + self.observe_sibling + self.observe_dangling > 0, \
            "Must include at least one observation."

        self.embedding = embedding
        self.embedding_size = embedding_size

    def _new_embedding(self, n_inputs):
        """
        Create a trainable embedding table with ``n_inputs`` rows and
        ``embedding_size`` columns, initialized uniformly in [-1, 1).
        """
        # Rescale BEFORE wrapping in Parameter. Arithmetic applied to a
        # Parameter returns an ordinary non-leaf Tensor, so the attribute
        # would no longer be the trainable leaf and could not be handed to
        # an optimizer.
        initial = torch.rand(n_inputs, self.embedding_size) * 2 - 1
        return torch.nn.Parameter(initial, requires_grad=True)

    def setup_manager(self, policy):
        """
        Run the base-class setup, then create one embedding table per
        observed categorical input when ``embedding`` is enabled.

        :param policy the policy class
        """
        super().setup_manager(policy)

        # Create embeddings if needed
        if not self.embedding:
            return
        if self.observe_action:
            self.action_embeddings = self._new_embedding(self.library.n_action_inputs)
        if self.observe_parent:
            self.parent_embeddings = self._new_embedding(self.library.n_parent_inputs)
        if self.observe_sibling:
            self.sibling_embeddings = self._new_embedding(self.library.n_sibling_inputs)

    def _encode(self, values, table_name, n_inputs):
        """
        Encode one categorical observation, either by embedding lookup
        (``embedding=True``) or by one-hot encoding with ``n_inputs`` classes.

        ``table_name`` is the attribute holding the embedding table created
        by ``setup_manager``.
        """
        # int64 is required by one_hot and valid for indexing; unlike
        # .type(torch.IntTensor), .long() keeps the tensor on its device.
        indices = values.long()
        if self.embedding:
            return getattr(self, table_name)[indices]
        # one_hot yields int64; cast so the result matches the float32
        # input documented for the Policy.
        return one_hot(indices, num_classes=n_inputs).to(torch.float32)

    def get_tensor_input(self, obs):
        """
        Convert an observation batch into the Tensor input for the Policy.

        Parameters
        ----------
        obs : torch.Tensor
            Observation tensor. It is unbound along dim 1; the first four
            slices are read as action, parent, sibling and dangling, and any
            further slices are appended to the result unchanged.

        Returns
        -------
        input_ : torch.Tensor
            Concatenation along the last dim of the enabled observations
            (one-hot or embedded action/parent/sibling, raw dangling),
            followed by the extra slices of ``obs``, if any.
        """
        unstacked_obs = torch.unbind(obs, dim=1)
        action, parent, sibling, dangling = unstacked_obs[:4]

        # Action, parent, and sibling inputs are either one-hot or embeddings
        observations = []
        if self.observe_action:
            observations.append(
                self._encode(action, "action_embeddings", self.library.n_action_inputs))
        if self.observe_parent:
            observations.append(
                self._encode(parent, "parent_embeddings", self.library.n_parent_inputs))
        if self.observe_sibling:
            observations.append(
                self._encode(sibling, "sibling_embeddings", self.library.n_sibling_inputs))

        # Dangling input is just the value of dangling
        if self.observe_dangling:
            observations.append(torch.unsqueeze(dangling, -1))

        input_ = torch.cat(observations, -1)

        # possibly concatenates additional observations (e.g., bert embeddings)
        if len(unstacked_obs) > 4:
            extras = torch.stack(unstacked_obs[4:], dim=-1)
            input_ = torch.cat([input_, extras], -1)
        return input_
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment