Skip to content

Instantly share code, notes, and snippets.

@arsalanaf
Forked from yashpatel5400/mountaincar.py
Last active June 15, 2023 04:10
Show Gist options
  • Star 41 You must be signed in to star a gist
  • Fork 28 You must be signed in to fork a gist
  • Save arsalanaf/d10e0c9e2422dba94c91e478831acb12 to your computer and use it in GitHub Desktop.
Save arsalanaf/d10e0c9e2422dba94c91e478831acb12 to your computer and use it in GitHub Desktop.
from btgym import BTgymEnv
import IPython.display as Display
import PIL.Image as Image
from gym import spaces
import gym
import numpy as np
import random
'''
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import Adam
'''
from keras.models import Sequential, load_model
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.recurrent import LSTM
from keras.optimizers import RMSprop, Adam
from collections import deque
class DQN:
    """Epsilon-greedy agent holding an online network and a target network.

    Both networks are built by :meth:`create_model`. The online network is
    used to pick actions in :meth:`act`; the target network is moved towards
    the online one by :meth:`target_train`.

    NOTE(review): this class defines no remember/replay/fit step, so the
    online network is never trained here and ``memory``, ``gamma`` and
    ``learning_rate`` are stored but not read by any method of the class.
    """

    def __init__(self, env):
        """Store the environment and hyperparameters, then build both networks.

        Args:
            env: Environment object. The class reads ``env.action_space.n``,
                ``env.action_space.sample()``, ``env.metadata`` and
                ``env.render(mode)`` from it.
        """
        self.env = env
        self.memory = deque(maxlen=20000)

        self.gamma = 0.85
        # Exploration probability: multiplied by ``epsilon_decay`` on every
        # ``act`` call and never allowed below ``epsilon_min``.
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.005
        # Blend factor used by ``target_train``.
        self.tau = .125

        self.model = self.create_model()
        self.target_model = self.create_model()

    def create_model(self):
        """Build and compile the network used for both online and target models.

        Layers, in order: ``LSTM(64)`` with ``input_shape=(4, 1)``, two
        ``Dropout(0.5)`` layers, a ``Dense`` layer with one unit per action
        and a linear activation. Compiled with MSE loss and a default
        ``Adam`` optimizer.

        Returns:
            The compiled ``Sequential`` model.
        """
        model = Sequential()
        model.add(LSTM(64, input_shape=(4, 1), stateful=False))
        model.add(Dropout(0.5))
        # NOTE(review): this second Dropout directly follows the first; it
        # looks like a leftover from a disabled second LSTM layer. Kept so the
        # network is unchanged -- confirm whether it is intended.
        model.add(Dropout(0.5))
        # NOTE(review): ``init=`` is the Keras 1 spelling of what Keras 2
        # calls ``kernel_initializer=``; left as-is because the installed
        # Keras version cannot be determined from this file.
        model.add(Dense(self.env.action_space.n, init='lecun_uniform'))
        # Linear output so the network can emit unbounded real values.
        model.add(Activation('linear'))
        # NOTE(review): the optimizer uses Keras defaults;
        # ``self.learning_rate`` is not passed to it.
        model.compile(loss='mse', optimizer=Adam())
        return model

    def act(self, state):
        """Choose an action epsilon-greedily and decay epsilon.

        Epsilon is decayed (and clamped to ``epsilon_min``) *before* the
        random draw, on every call.

        Args:
            state: Array handed unchanged to ``self.model.predict``.

        Returns:
            A random action from ``env.action_space.sample()`` with
            probability ``epsilon``; otherwise the argmax over the first row
            of the model's prediction.
        """
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        # Only row 0 of the prediction is consulted; any further rows in the
        # batch dimension are ignored.
        return np.argmax(self.model.predict(state)[0])

    def target_train(self):
        """Move the target network towards the online network.

        Each target weight array becomes
        ``tau * online + (1 - tau) * target``.
        """
        online_weights = self.model.get_weights()
        target_weights = self.target_model.get_weights()
        blended = [
            online * self.tau + target * (1 - self.tau)
            for online, target in zip(online_weights, target_weights)
        ]
        self.target_model.set_weights(blended)

    def save_model(self, fn):
        """Save the online model to the path ``fn`` via ``model.save``."""
        self.model.save(fn)

    def show_rendered_image(self, rgb_array):
        """Convert a numpy array to an image with Pillow and display it inline.

        Args:
            rgb_array: Array accepted by ``PIL.Image.fromarray``.
        """
        Display.display(Image.fromarray(rgb_array))

    def render_all_modes(self, env=None):
        """Render and display the environment in every supported mode.

        Args:
            env: Environment to render. Defaults to the agent's own
                environment when omitted or ``None``.
        """
        target_env = self.env if env is None else env
        for mode in target_env.metadata['render.modes']:
            print('[{}] mode:'.format(mode))
            self.show_rendered_image(target_env.render(mode))
def main():
    """Run the DQN agent against a BTgym EUR/USD environment.

    Creates the environment, then for each trial resets it and steps it with
    actions chosen by the agent, updating the target network after every
    step. After each trial the environment is rendered in all modes and the
    online model is saved to ``model.model``. The environment is closed when
    the run ends, including when it ends with an exception.
    """
    # NOTE(review): a commenter on this gist reports that newer BTgym
    # releases expect the key 'raw' instead of 'raw_state' -- verify against
    # the installed version.
    env = BTgymEnv(
        filename='./data/DAT_ASCII_EURUSD_M1_2016.csv',
        state_shape={'raw_state': spaces.Box(low=-100, high=100, shape=(30, 4))},
        skip_frame=5,
        start_cash=100000,
        broker_commission=0.02,
        fixed_stake=100,
        drawdown_call=90,
        render_ylabel='Price Lines',
        render_size_episode=(12, 8),
        render_size_human=(8, 3.5),
        render_size_state=(10, 3.5),
        render_dpi=75,
        verbose=0,
    )
    trials = 100
    trial_len = 1000

    try:
        dqn_agent = DQN(env=env)
        for trial in range(trials):
            # Observations are dicts; only the first entry's value is used,
            # reshaped to (30, 4, 1) for the agent's network.
            cur_state = np.array(list(env.reset().items())[0][1])
            cur_state = np.reshape(cur_state, (30, 4, 1))

            for step in range(trial_len):
                action = dqn_agent.act(cur_state)
                new_state, reward, done, _ = env.step(action)

                # NOTE(review): the shaped reward is computed but not
                # consumed -- the agent has no replay/training step.
                reward = reward * 10 if not done else -10
                new_state = list(new_state.items())[0][1]
                new_state = np.reshape(new_state, (30, 4, 1))

                dqn_agent.target_train()  # iterates target model

                cur_state = new_state
                if done:
                    break

            # NOTE(review): assumed to run once per trial; the original
            # indentation of these three statements was lost.
            print("Completed trial #{} ".format(trial))
            dqn_agent.render_all_modes(env)
            # Every trial overwrites the same file.
            dqn_agent.save_model("model.model")
    finally:
        # Release the environment's resources even if a trial fails.
        env.close()


if __name__ == "__main__":
    main()
@lukucz
Copy link

lukucz commented Mar 23, 2019

I'm trying to run your code, but I'm hitting the same error:
AttributeError: 'Lines_LineSeries_LineIterator_DataAccessor_Strateg' object has no attribute 'get_raw_state_state'

To fix it, replace this line:
state_shape={'raw_state': spaces.Box(low=-100, high=100,shape=(30,4))},
with:
state_shape={'raw': spaces.Box(low=-100, high=100,shape=(30,4))},

@itsvaibhav01
Copy link

Can you suggest a way to import and buy/sell the stocks of more than one company in btgym?

@saremeskandary
Copy link

Hi, I have this error "BTgymDataset class is DEPRECATED, use btgym.datafeed.derivative.BTgymDataset2 instead."

@mkraman2
Copy link

mkraman2 commented Jun 4, 2020

@sareme I encounter the same error. Have you found a resolution? This error is encountered even when running the simplest example provided by the developer, link here, so I am led to believe it is a dependency problem.

@asuliman17
Copy link

I am having the same problem as @mkraman2 and others on this thread. If someone in the community would be kind enough to post a solution, that would be highly appreciated.

@mkraman2
Copy link

mkraman2 commented Mar 2, 2021

Hi Adham, it has been a while since I've worked on this, but iirc, the problem was related to attempting to run on Windows, and resolved when running on a recent distro of Ubuntu. The developer mentioned elsewhere that BTGym was developed on and tested exclusively in Linux.

@asuliman17
Copy link

I'm currently running the example files on my Mac, so the Windows issue shouldn't be a problem. I'm getting the following error when trying to run setting_up_environment_basic.ipynb. Any idea why this could be happening, @Kismuz?

[2021-03-02 18:27:55.075162] DEBUG: SimpleDataSet_0: Start time adjusted to <00:00> Process BTgymDataFeedServer-7: Traceback (most recent call last): File "/opt/anaconda3/envs/btgym_env/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap self.run() File "/Users/adhamsuliman/Documents/personal_projects/bot_stock_trader/bt_gym/btgym/btgym/dataserver.py", line 176, in run sample = self.get_data(sample_config=service_input['kwargs']) File "/Users/adhamsuliman/Documents/personal_projects/bot_stock_trader/bt_gym/btgym/btgym/dataserver.py", line 88, in get_data sample = self.dataset.sample(**sample_config) File "/Users/adhamsuliman/Documents/personal_projects/bot_stock_trader/bt_gym/btgym/btgym/datafeed/base.py", line 539, in sample return self._sample(**kwargs) File "/Users/adhamsuliman/Documents/personal_projects/bot_stock_trader/bt_gym/btgym/btgym/datafeed/base.py", line 617, in _sample **kwargs File "/Users/adhamsuliman/Documents/personal_projects/bot_stock_trader/bt_gym/btgym/btgym/datafeed/base.py", line 882, in _sample_interval first_row = self.data.index.get_loc(adj_timedate, method='nearest') File "/opt/anaconda3/envs/btgym_env/lib/python3.6/site-packages/pandas/core/indexes/datetimes.py", line 622, in get_loc raise KeyError(key) KeyError: datetime.date(2016, 1, 11)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment