Skip to content

Instantly share code, notes, and snippets.

@IperGiove
Created January 9, 2024 23:05
Show Gist options
  • Save IperGiove/325832e30df44639ec78a618b84daf3c to your computer and use it in GitHub Desktop.
Baseline example combining stable-baselines3 (RL) with backtesting.py.
import pandas as pd
from backtesting import Backtest, Strategy
from gymnasium import spaces, Env
from stable_baselines3 import PPO
import numpy as np
class periodicStrategy(Strategy):
    """backtesting.py Strategy driven externally by an RL agent.

    ``next`` receives the agent's discrete action each bar:
    0/None = hold, 1 = buy, 2 = close the open position.
    """

    def init(self):
        # Called once by backtesting.py before the bar loop starts.
        print(f"Start with equity={self.equity:.2f}")

    def next(self, action: int | None = None):
        """Advance one bar, applying the agent's action (None/0 = hold)."""
        print(f"Action={action} Equity={self.equity:.2f} Date={self.data.index[-1]}")
        # The outer `if action:` from the original was redundant: 0/None
        # match neither branch below anyway.
        if action == 1:
            self.buy()
        elif action == 2:
            self.position.close()

    def observation(self):
        """Return the last 20 closes min-max normalised to [0, 1], as a
        single-row nested sequence (shape (1, n); n < 20 early in the run)."""
        closes = self.data.Close[-20:]
        span = closes.max() - closes.min()
        # Fix: a flat window (all closes equal) would divide by zero;
        # return all-zeros instead.
        if span == 0:
            return [closes - closes.min()]
        return [(closes - closes.min()) / span]
class CustomEnv(Env):
    """Custom Environment that follows the gym interface.

    Wraps a backtesting.py ``Backtest`` driven bar-by-bar: the observation
    is the strategy's last-20-closes window, the reward is +1/-1 depending
    on whether equity grew since the previous step, and the episode is
    truncated when the market data runs out.
    """

    def __init__(self, bt: Backtest):
        super().__init__()
        # observation (1,20) = (close price, 20 back days)
        self.observation_space = spaces.Box(low=-1, high=1, shape=(1, 20), dtype=np.float32)
        # discrete actions: 0 = hold, 1 = buy, 2 = close the open position
        self.action_space = spaces.Discrete(3)
        self.bt = bt

    def reward_calculation(self):
        """Return +1 if equity increased since the previous step, else -1."""
        current_equity = self.bt._step_strategy.equity
        reward = 1 if self.previous_equity < current_equity else -1
        # BUG FIX: remember this step's equity so the next reward compares
        # consecutive steps. Previously previous_equity stayed at its
        # reset() value forever, so the reward only measured "above start
        # cash", not per-step improvement.
        self.previous_equity = current_equity
        return reward

    def check_done(self):
        """True when fewer than two bars of data remain (episode truncated)."""
        if self.bt._step_time + 2 > len(self.bt._data):
            self.render()
            return True
        return False

    def step(self, action):
        obs = self.bt._step_strategy.observation()
        reward = self.reward_calculation()
        done = self.check_done()
        info = {}
        self.bt.next(action=action)
        # False is done (never finish because the market can not finish)
        # done is the truncate (the market can be truncated)
        return obs, reward, False, done, info

    def reset_backtesting(self):
        # backtesting, give first next because when initialize can return the whole dataset
        self.bt.initialize()
        self.bt.next()
        # Step forward until a full (1, 20) observation window is available.
        while np.shape(self.bt._step_strategy.observation()) != (1, 20):
            self.bt.next()

    def reset(self, seed=None, options=None):
        # `options` added (and ignored) plus seed forwarding for full
        # Gymnasium API compatibility; backward compatible with callers
        # passing only `seed`.
        super().reset(seed=seed)
        self.previous_equity = 10  # starting cash handed to Backtest — TODO confirm against caller
        self.reset_backtesting()
        return self.bt._step_strategy.observation(), {}

    def render(self, mode='human'):
        # Finalise the backtest run and write the plot without opening a browser.
        result = self.bt.next(done=True)
        self.bt.plot(results=result, open_browser=False)

    def close(self):
        pass
def generate_sin_wave(periods: int = 1000, amplitude: float = 1.0,
                      offset: float = 2.0) -> pd.DataFrame:
    """Build a synthetic daily OHLCV frame where every column is a sine wave.

    Parameters
    ----------
    periods : int
        Number of daily rows, indexed from 2023-01-01.
    amplitude : float
        Amplitude of the sine wave (5 full cycles over the range).
    offset : float
        Vertical shift keeping prices positive (generalised from the
        previously hard-coded ``+ 2``; default preserves old behaviour).

    Returns
    -------
    pd.DataFrame
        Columns Open/High/Low/Close/Volume, all equal to the wave.
    """
    index = pd.date_range(start='2023-01-01', periods=periods, freq='D')
    wave = amplitude * pd.Series(
        data=np.sin(np.linspace(0, 10 * np.pi, periods)), index=index
    ) + offset
    return pd.DataFrame(
        {'Open': wave, 'High': wave, 'Low': wave, 'Close': wave, 'Volume': wave}
    )
# --- Driver script: build synthetic data, wrap the backtest in a gym env,
# --- and train a PPO agent on it.
data = generate_sin_wave()
print(data)
# Instantiate the env: backtest the RL-driven strategy with 10 starting cash
# (matches the previous_equity baseline used inside CustomEnv.reset).
bt = Backtest(data, periodicStrategy, cash=10)
env = CustomEnv(bt)
# env = VecNormalize(env)
# Define and Train the agent (TensorBoard logs under ./logs/)
model = PPO("MlpPolicy", env, verbose=0, tensorboard_log="./logs/")
model.learn(total_timesteps=1000000, log_interval=1)
# model.save("")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment