Skip to content

Instantly share code, notes, and snippets.

import jax.numpy as jnp
from jax import random
class JaxCartPole:
    """
    Cart-pole environment implemented with JAX.

    Based on OpenAI Gym Cartpole
    https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
    """
    def __init__(self):
        # Gravitational acceleration; presumably m/s^2 as in the Gym
        # reference implementation — confirm against the dynamics code.
        self.gravity = 9.8
def fori_body(i, val):
    """Take one environment step with a uniformly random action {0, 1}.

    Presumably used as the body function of ``jax.lax.fori_loop`` (the
    ``(i, val) -> val`` signature matches) — confirm against the caller.

    Args:
        i: Loop index; this step's results are written at position ``i``.
        val: Carry tuple
            ``(env_state, action_key, all_obsv, all_reward, all_done)``.

    Returns:
        The updated carry tuple, with the same structure as ``val``.
    """
    env_state, action_key, all_obsv, all_reward, all_done = val
    # Split first, then sample from the fresh subkey. The original code used
    # the key for sampling and then split that same key to get the next one,
    # which reuses a PRNG key.
    action_key, sample_key = random.split(action_key)
    action = random.randint(sample_key, (1,), 0, 2)[0]
    # NOTE(review): relies on a module-level ``env`` exposing
    # ``step(state, action)``; it is not defined in this view.
    env_state, obsv, reward, done, _info = env.step(env_state, action)
    all_obsv = all_obsv.at[i].set(obsv)
    all_reward = all_reward.at[i].set(reward)
    all_done = all_done.at[i].set(done)
    return env_state, action_key, all_obsv, all_reward, all_done
@ngoodger
ngoodger / jax_skeleton_env.py
Last active November 13, 2021 08:23
jax_skeleton_env.py
import jax.numpy as jnp
from jax import random
class SkeletonEnv:
    """Minimal skeleton environment whose observation is its raw state."""

    def __init__(self):
        # Randomness limit parameter; how it is applied is outside this view.
        self.random_limit = 0.05

    def _get_obsv(self, state):
        """Return the observation for ``state``, which is the state itself."""
        observation = state
        return observation
@ngoodger
ngoodger / torch_env.py
Created November 9, 2020 11:58
PyTorch Gym Environments
import gym
import math
import torch
from gym import spaces, logger
import numpy as np
class Pendulum(gym.Env):
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 30
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import numpy as np
import requests
import time
import pyarrow
import torchvision
import torch
# Batch-size limit (reading based on the name; its use is outside this view).
MAX_BATCH_SIZE = 1
# When True, a reference model is loaded (presumably to check outputs).
TEST_CORRECT_OUTPUT = False
if TEST_CORRECT_OUTPUT:
    # Pretrained ResNet-18 switched to eval (inference) mode.
    MODEL = torchvision.models.resnet18(pretrained=True).eval()
import numpy as np
import time
import torchvision
import torch
# Target device for the model; "cuda" requires a CUDA-capable GPU.
DEVICE = "cuda"
# Pretrained ResNet-18 in eval (inference) mode, moved onto DEVICE.
MODEL = torchvision.models.resnet18(pretrained=True).eval().to(DEVICE)
BATCH_SIZE = 64
# Number of batches; presumably the benchmark iteration count (loop not
# visible here).
BATCHES = 100
# Random float32 batch with values in [0, 1), shape (BATCH_SIZE, 3, 256, 256),
# held in pinned (page-locked) host memory; pinning presumably serves faster
# host-to-GPU copies later in the script.
x_cpu = torch.from_numpy(np.random.random((BATCH_SIZE, 3, 256, 256))
                         .astype(np.float32)).pin_memory()
from aiohttp import web
import argparse
import asyncio
from dataclasses import dataclass
from io import BytesIO
import numpy as np
import pyarrow
import prometheus_client as pc
import time
import torch
@ngoodger
ngoodger / split_csv.py
Created August 27, 2017 03:31
Split a CSV file into a user-specified number of smaller CSV files. Example: `python3 split_csv.py measurement_by_imsi_cell_2017-07-22_v4.csv 5`
import argparse
def main(input_csv, file_count):
with open(input_csv, "r") as f:
csv_text = f.readlines()
header_line = csv_text[0]
body_lines = csv_text[1:]
line_count = len(body_lines)
lines_per_file = int(line_count / output_file_count)
# First file should take extra lines