# Source: GitHub gist 96308a8b315da6fc9b263b822b99e209 by @StuartFarmer
# (last active July 20, 2019 22:32).
import random
from unittest import TestCase

import numpy as np
'''
How to use: subclass Agent and override 'get_action_for_state' with a neural network or another piece of logic.
To make learning agents, store some weights as state and update them based on the reward that is returned.
The environment expects a 1D array, but you can modify this easily.
The environment logic has been reimplemented from this: https://drive.google.com/drive/folders/1qCvIeui5dJKMXnjUm9_wiPf65VVHdWwz
It was reimplemented for clarity and to reduce overfitting by (1) randomizing the starting point of each episode and
(2) holding out separate validation data.
The environment only goes long, but keeps an inventory of units that it can sell off at a later step. It's a lot simpler than it sounds.
'''
class Environment:
    """Long-only, single-instrument trading environment over a 1D price series.

    The series is split chronologically into training data (the first
    ``1 - validation_split`` share) and testing data (the rest). Each episode
    walks ``current_episode_length`` consecutive steps of the training data,
    starting at a randomly chosen index, one step per ``act`` call.
    """

    def __init__(self, observation_space,
                 starting_cash=10000,
                 commission=0.0,
                 validation_split=0.3,
                 space_per_episode=0.1,
                 max_steps_per_episode=500,
                 window=40):
        """
        :param observation_space: 1D sequence of prices (must support ``len`` and slicing).
        :param starting_cash: cash held at the start of every episode.
        :param commission: flat fee charged on every buy and every sell.
        :param validation_split: share of ``observation_space`` (taken from the end)
            kept aside as ``testing_data``.
        :param space_per_episode: share of the training data covered by one episode.
        :param max_steps_per_episode: upper bound on episode length regardless of data size.
        :param window: number of consecutive prices in each observation.
        :raises AssertionError: if the training data is too short for ``window`` or the
            resulting episode length is zero.
        """
        # Split along a line: the first (1 - validation_split) share is training data,
        # the remainder is testing data.
        split_idx = int(len(observation_space) * (1 - validation_split))
        self.training_data = observation_space[:split_idx]
        self.testing_data = observation_space[split_idx:]

        # Variables for reward.
        self.starting_cash = starting_cash
        self.commission = commission

        # The 'slice' of each observation step in time.
        assert len(self.training_data) > window + 1, \
            'Not enough data provided for the current window size of {}'.format(window)
        self.window = window

        # Share of the training data observed per episode; together with the random
        # starting index this creates variance between episodes.
        self.space_per_episode = space_per_episode
        # Catch-all bound so switching from a small dataset to a huge one does not
        # require changing anything else.
        self.max_steps_per_episode = max_steps_per_episode

        assert self._episode_length() > 0, \
            'Provide more data or make the self.space_per_episode variable higher.'

        # Sets cash, reward, inventory, episode length and a random starting index.
        self._start_episode()

    def _episode_length(self):
        """Return the episode length: the smaller of the two configured bounds."""
        return min(int(len(self.training_data) * self.space_per_episode),
                   self.max_steps_per_episode)

    def _start_episode(self):
        """Restore cash/inventory and choose a new random starting index."""
        self.current_cash = self.starting_cash
        self.current_reward = 0
        self.current_episode_length = self._episode_length()
        # Upper bound keeps every step's window (start + i .. start + i + window)
        # inside the training data, including the last step of the episode.
        self.current_episode_starting_index = \
            random.randint(0, len(self.training_data) - self.current_episode_length - self.window)
        # Current index in time, relative to the episode's starting index.
        self.current_episode_i = 0
        self.inventory = 0

    def current_state(self):
        """Return ``(diffed_state, state)`` for the current step.

        ``state`` is the slice of ``window`` prices beginning at
        ``current_episode_starting_index + current_episode_i``; ``diffed_state`` is
        its first difference with a leading 0 so both have the same length.
        """
        i = self.current_episode_starting_index + self.current_episode_i
        state = self.training_data[i: i + self.window]  # Time slice for this step.
        diffed_state = np.insert(np.diff(state), 0, 0)
        return diffed_state, state

    def reset(self):
        """Start a new episode: full cash, empty inventory, new random starting index."""
        self._start_episode()

    def act(self, action):
        """Apply ``action`` at the current price and advance one step.

        :param action: 1 buys one unit; 2 sells one unit if any inventory is held;
            any other value holds.
        :returns: ``(done, reward)``. While the episode runs this is ``(False, 0.0)``.
            When it ends (length reached or cash <= 0) the reward is -1 if cash is
            exactly 0, otherwise the percentage change of cash relative to
            ``starting_cash``; the environment is then reset.
        """
        _, state = self.current_state()
        price = state[-1]

        if action == 1:
            if self.current_cash < price:
                # Cannot afford a unit: cash is wiped out, which ends the episode below.
                self.current_cash = 0
            else:
                self.current_cash -= price
                self.current_cash -= self.commission
                self.inventory += 1

        if action == 2 and self.inventory > 0:
            self.inventory -= 1
            self.current_cash += price
            self.current_cash -= self.commission

        self.current_episode_i += 1

        if self.current_episode_i >= self.current_episode_length or self.current_cash <= 0:
            if self.current_cash == 0:
                reward = -1
            else:
                # Percentage return on cash: positive when the episode ends with more
                # cash than it started with. Held inventory is not counted.
                reward = (self.current_cash - self.starting_cash) / self.starting_cash * 100
            self.reset()
            return True, reward
        return False, 0.0
class Agent:
    """Base trading agent that plays episodes against an environment.

    Subclasses override ``get_action_for_state``; this base version always
    returns 0, which matches neither the buy (1) nor the sell (2) action.
    """

    def __init__(self, environment):
        self.environment = environment
        # Reward returned at the end of the most recently finished episode.
        self.last_reward = 0.0

    def run_episode(self):
        """Play until the environment reports the episode is over.

        Each step passes the diffed observation (first element of
        ``current_state()``) to ``get_action_for_state`` and applies the result.
        The final reward is stored in ``last_reward`` and returned.
        """
        while True:
            observation = self.environment.current_state()[0]
            chosen = self.get_action_for_state(observation)
            finished, final_reward = self.environment.act(chosen)
            if finished:
                break
        self.last_reward = final_reward
        return final_reward

    def get_action_for_state(self, state):
        """Choose an action for ``state``; the default never trades."""
        return 0
class TestTrader(TestCase):
    """Unit tests for ``Environment``.

    Most tests use a 10-point series; with the default 0.3 validation split the
    training data is its first 7 points.
    """

    def test_init(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        Environment(observation_space=test_data, window=2, space_per_episode=0.5)

    def test_post_init_setup(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, window=2, space_per_episode=0.5)
        self.assertEqual(e.training_data, [1, 2, 3, 4, 5, 6, 7])
        self.assertEqual(e.testing_data, [8, 9, 0])

    def test_action_one_adds_one_to_inventory(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5)
        e.act(1)
        self.assertEqual(e.inventory, 1)

    def test_action_two_sells_existing_inventory_if_it_exists(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5)
        e.act(1)
        self.assertEqual(e.inventory, 1)
        e.act(2)
        self.assertEqual(e.inventory, 0)

    def test_action_one_buys_and_deducts_from_current_cash(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5)
        _, state = e.current_state()
        e.act(1)
        self.assertEqual(e.current_cash, e.starting_cash - state[-1])

    def test_action_one_buys_and_deducts_with_commission(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5,
                        commission=100)
        _, state = e.current_state()
        e.act(1)
        self.assertEqual(e.current_cash, e.starting_cash - state[-1] - e.commission)

    def test_action_two_without_inventory_does_nothing(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5,
                        commission=100)
        e.act(2)
        self.assertEqual(e.current_cash, e.starting_cash)
        self.assertEqual(e.inventory, 0)

    def test_action_increments_i(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5,
                        commission=100)
        self.assertEqual(e.current_episode_i, 0)
        e.act(1)
        self.assertEqual(e.current_episode_i, 1)
        e.act(1)
        self.assertEqual(e.current_episode_i, 2)

    def test_action_one_does_nothing_if_no_cash_available_to_buy(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=100, window=2, space_per_episode=0.5,
                        commission=100)
        e.current_cash = 0
        res = e.act(1)
        self.assertEqual(e.inventory, 0)
        self.assertEqual(res[0], True)
        self.assertEqual(res[1], -1)

    def test_actions_to_end_returns_reward(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
        e = Environment(observation_space=test_data, starting_cash=100, window=2, space_per_episode=0.5,
                        commission=100)
        res_1 = e.act(0)
        res_2 = e.act(0)
        res_3 = e.act(0)
        self.assertEqual(res_1, (False, 0))
        self.assertEqual(res_2, (False, 0))
        self.assertEqual(res_3, (True, 0.0))

    def test_running_out_of_cash_ends_episode(self):
        test_data = [1000000, 1000000, 1000000, 1000000, 1000000, 1000000, 1000000, 1000000, 1000000,
                     -2000000, -2000000, -2000000, -2000000, -2000000, -2000000, -2000000, -2000000, -2000000]
        e = Environment(observation_space=test_data, starting_cash=1000000, window=2, space_per_episode=0.5,
                        commission=100)
        # Buying at 1,000,000 plus the 100 commission leaves cash below zero,
        # which must end the episode immediately with a negative reward.
        done, reward = e.act(1)
        self.assertTrue(done)
        self.assertLess(reward, 0)
        # Ending an episode resets the environment.
        self.assertEqual(e.inventory, 0)
        self.assertEqual(e.current_cash, e.starting_cash)

    def test_resetting_puts_all_variables_back_to_normal(self):
        test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
        random.seed(1234)
        e = Environment(observation_space=test_data, starting_cash=100, window=2, space_per_episode=0.5,
                        commission=0)
        first_diff, first_window = e.current_state()
        e.act(1)
        e.act(1)
        self.assertEqual(e.inventory, 2)
        # Compare raw windows as lists: numpy arrays cannot be compared with ==
        # inside assertEqual/assertNotEqual.
        self.assertNotEqual(list(first_window), list(e.current_state()[1]))
        self.assertNotEqual(e.starting_cash, e.current_cash)
        # Re-seed so reset() draws the same random starting index as __init__ did.
        random.seed(1234)
        e.reset()
        diff, window = e.current_state()
        self.assertEqual(list(first_window), list(window))
        np.testing.assert_array_equal(first_diff, diff)
        self.assertEqual(e.current_episode_i, 0)
        self.assertEqual(e.inventory, 0)
        self.assertEqual(e.starting_cash, e.current_cash)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment