Skip to content

Instantly share code, notes, and snippets.

@ttumiel
ttumiel / symbolic.py
Created November 23, 2019 07:56
Symbolic Differentiation in Python
"Symbolic differentiator"
# Possible things to work on in interview:
# - Add substitution into expression for real numbers
# - If instance is the exact same symbol, use power instead of multiply. Use __eq__
# - Add a simplify method to simplify the mess that results from ugly expressions
# - Convert an expression from string
# - Add support for a^x
class Symbol():
@ttumiel
ttumiel / table_logger.py
Last active June 28, 2020 17:35
Pytorch Lightning TableLogger
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
import pandas as pd
from IPython.display import DisplayHandle
import numpy as np
import re
from pathlib import Path
import yaml
# Jax Version of Complex Observation Space
import pytest
from gym import spaces
import numpy as np
import jax.numpy as jnp
import jax.tree_util as tree
def batch_obs(o):
    """Merge a sequence of observation pytrees into one batched pytree.

    Every element of ``o`` is handed to ``tree_map`` as a separate tree, so
    the leaves found at the same position in each element are gathered and
    joined with ``jnp.concatenate`` (which concatenates along axis 0 by
    default).

    Args:
        o: Non-empty sequence of pytrees that all share the same structure.

    Returns:
        A pytree with the structure of ``o[0]`` whose leaves are the
        concatenation of the matching leaves of every element of ``o``.
    """
    # Imported locally under a distinct alias: this file later executes
    # ``import tree``, which rebinds the module-level name ``tree`` to a
    # different package that would then be used at call time.
    from jax import tree_util as jax_tree_util

    return jax_tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves), *o
    )
# Dummy wrappers for testing a complex observation space
### Dummy Complex Obs Wrapper ###
class DummyComplex(gym.ObservationWrapper):
    """Observation wrapper that turns any observation into a two-entry dict.

    The wrapped environment's observation is passed through unchanged under
    the ``"original"`` key, and a freshly sampled random array is added under
    the ``"dummy"`` key. ``observation_space`` is replaced with a matching
    ``gym.spaces.Dict``.
    """

    # Shape and dtype of the synthetic "dummy" entry. They are shared by the
    # declared space and by the sampled observation so the two cannot drift
    # apart (float32 is made explicit here and used for both).
    _DUMMY_SHAPE = (12, 12)
    _DUMMY_DTYPE = np.float32

    def __init__(self, env):
        """Wrap ``env`` and declare the dict-valued observation space."""
        super().__init__(env)
        self.observation_space = gym.spaces.Dict(
            {
                "dummy": gym.spaces.Box(
                    0, 255, self._DUMMY_SHAPE, dtype=self._DUMMY_DTYPE
                ),
                "original": self.env.observation_space,
            }
        )

    def observation(self, observation):
        """Return ``observation`` nested in a dict next to a random array.

        Args:
            observation: Observation produced by the wrapped environment.

        Returns:
            A dict with ``"dummy"`` (random values scaled by 255, cast to the
            dtype declared for the ``"dummy"`` space) and ``"original"`` (the
            input observation, unmodified).
        """
        # np.random.random yields float64; cast so the sample matches the
        # dtype declared in observation_space.
        dummy = (255 * np.random.random(self._DUMMY_SHAPE)).astype(
            self._DUMMY_DTYPE
        )
        return {"dummy": dummy, "original": observation}
## Base Implementation from `ppo_atari_envpool_async_jax_scan_impalanet_machado.py`
## https://github.com/vwxyzjn/cleanrl/blob/52e263887744e022c6b6d0c0de2591c85212c86a/cleanrl/ppo_atari_envpool_async_jax_scan_impalanet_machado.py
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
# https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
import argparse
import os
import random
import time
from distutils.util import strtobool
# Numpy/torch Complex Observation Handling
import pytest
import tree
import gym
from gym import spaces
import numpy as np
import torch
def make_storage(obs_space: gym.Space, batch_dims: tuple, device: torch.device):