This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| authorName: default | |
| experimentName: example_mnist | |
| trialConcurrency: 4 | |
| maxExecDuration: 1h | |
| maxTrialNum: 10 | |
| #choice: local, remote, pai | |
| trainingServicePlatform: local | |
| #choice: true, false | |
| useAnnotation: true | |
| tuner: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Map each supported ``custom_args.environment`` value to the indicator
# function whose result is used to rescale the fee rate.
fn = dict(
    stochastic=get_stochastic,
    rsi=fnRSI,
    macd=fnMACD,
    bollinger=get_bollinger_diffs,
)
# NOTE(review): ``.get`` yields None for any other environment name; the
# fee-update code calls this attribute unconditionally, so confirm that
# other environment names never reach that call.
self.get_derivative_diffs = fn.get(custom_args.environment, None)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main(n_episode=500):
    """Load a trained DQN agent and roll it out in the custom trading env.

    Relies on names defined elsewhere in the module: ``args``,
    ``obs_data_len``, ``step_len``, ``sample_len``, ``df``, ``fee``,
    ``n_action_intervals``, ``hyperparams``, ``load_location``, ``device``,
    ``TradingEnv`` and ``dqn_agent`` (plus ``os`` and ``torch``).

    Args:
        n_episode: Number of episodes to run. Previously written as
            ``range(n_episode=500)``, which raises ``TypeError`` because
            ``range`` accepts no keyword arguments.
    """
    env = TradingEnv(
        custom_args=args, env_id='custom_trading_env', obs_data_len=obs_data_len,
        step_len=step_len, sample_len=sample_len, df=df, fee=fee,
        initial_budget=1, n_action_intervals=n_action_intervals,
        deal_col_name='c', sell_at_end=True,
        feature_names=['o', 'h', 'l', 'c', 'v', 'num_trades', 'taker_base_vol'],
    )
    # One action per interval in each direction plus a "hold" action.
    agent = dqn_agent.Agent(
        action_size=2 * n_action_intervals + 1,
        obs_len=obs_data_len,
        num_features=env.reset().shape[-1],
        **hyperparams,
    )
    checkpoint = os.path.join(load_location, 'TradingGym_Rainbow_1000.pth')
    agent.qnetwork_local.load_state_dict(torch.load(checkpoint, map_location=device))
    agent.qnetwork_local.to(device)

    for episode in range(n_episode):
        # Each episode starts from a fresh observation; the original referenced
        # ``state`` without ever assigning it.
        state = env.reset()
        done = False
        # NOTE(review): assumes the env eventually reports done=True
        # (e.g. via sell_at_end) — confirm against TradingEnv.step.
        while not done:
            next_state, reward, done, _ = env.step(agent.act(state))
            state = next_state
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def step(self):
    """Rescale ``self.fee_rate`` by the change in the derivative indicator.

    The indicator (``self.get_derivative_diffs``) is evaluated on the current
    observation window. The fee rate is multiplied by the ratio of the new
    value to ``self.previous_diff`` and clipped to
    ``[self.min_fee_rate, self.max_fee_rate]``.

    The indicator functions return NaN when the window is shorter than their
    lookback, and ``previous_diff`` may be 0 or NaN. A non-finite candidate
    is therefore ignored and the current rate kept. Otherwise NaN would be
    stored in ``fee_rate`` and, since ``NaN * x`` is NaN, never recover.
    """
    ...
    window = self.df_sample.iloc[self.step_st: self.step_st + self.obs_len]
    derivative_diff = self.get_derivative_diffs(window)
    # np.true_divide avoids ZeroDivisionError for plain Python floats;
    # errstate silences the warning for 0 / NaN denominators.
    with np.errstate(divide='ignore', invalid='ignore'):
        candidate = np.true_divide(self.fee_rate * derivative_diff, self.previous_diff)
    if np.isfinite(candidate):
        self.fee_rate = np.clip(candidate, self.min_fee_rate, self.max_fee_rate)
    ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fnMACD(m_Df, m_NumFast=12, m_NumSlow=26, m_NumSignal=9):
    """Return the mean MACD histogram (MACD line minus signal line) of ``'c'``.

    Each EMA uses ``span`` = its period and ``min_periods`` = period - 1, so
    early rows are NaN and are skipped by ``mean()``. If no row has a defined
    histogram value, the result is NaN.

    Args:
        m_Df: DataFrame with a ``'c'`` column.
        m_NumFast: Span of the fast EMA.
        m_NumSlow: Span of the slow EMA.
        m_NumSignal: Span of the signal-line EMA applied to the MACD line.
    """
    def _ema(series, span):
        return series.ewm(span=span, min_periods=span - 1).mean()

    close = m_Df['c']
    macd_line = _ema(close, m_NumFast) - _ema(close, m_NumSlow)
    histogram = macd_line - _ema(macd_line, m_NumSignal)
    return histogram.mean()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_bollinger_diffs(df, n=20, k=2):
    """Return the mean Bollinger Band width of the close column ``'c'``.

    The bands are ``rolling_mean(n) +/- k * rolling_std(n)``, so the width
    (upper - lower) is exactly ``2 * k * rolling_std(n)``. The first
    ``n - 1`` windows are NaN and are skipped by ``mean()``. If no window is
    complete, the result is NaN.

    Args:
        df: DataFrame with a ``'c'`` column.
        n: Rolling window length.
        k: Band half-width in standard deviations.

    Returns:
        Average band width over the complete windows (NaN if there are none).
    """
    # Compute the width directly. This avoids the unused rolling mean, two
    # redundant rolling-std passes, and the cancellation error of subtracting
    # two nearly equal band values when prices are large relative to std.
    rolling_std = df['c'].rolling(n).std()
    return (2 * k * rolling_std).mean()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fnRSI(m_Df, m_N=15):
    """Return the mean RSI ratio ``avg_gain / (avg_gain + avg_loss)`` of ``c``.

    The result is on a 0-1 scale (not 0-100). Averages are simple rolling
    means over ``m_N`` one-step moves; incomplete windows, and windows with
    no movement at all (0/0), are NaN and skipped by ``mean()``.

    Args:
        m_Df: DataFrame with a ``c`` column.
        m_N: Rolling window length.
    """
    delta = m_Df.c.diff(1)
    # Split moves into gains and (positive) losses. The leading NaN produced
    # by diff() fails both comparisons and therefore counts as a zero move.
    gains = delta.where(delta > 0, 0)
    losses = (-delta).where(delta < 0, 0)
    avg_gain = gains.rolling(window=m_N, min_periods=m_N).mean()
    avg_loss = losses.rolling(window=m_N, min_periods=m_N).mean()
    return avg_gain.div(avg_loss + avg_gain).mean()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def get_stochastic(df, n=15, m=5, t=3): | |
| # highest price during n days | |
| ndays_high = df.h.rolling(window=n, min_periods=1).max() | |
| # lowest price during n days | |
| ndays_low = df.l.rolling(window=n, min_periods=1).min() | |
| # Fast%K | |
| kdj_k = ((df.c - ndays_low) / (ndays_high - ndays_low)) | |
| # Fast%D (=Slow%K) | |
| kdj_d = kdj_k.ewm(span=m).mean() | |
| # Slow%D |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TradingEnv:
    """Outline of the custom trading environment; method bodies are elided."""

    def _long(self):
        """Buy. Implementation elided in this excerpt."""

    def _long_cover(self, current_price_mean, current_mkt_position, action):
        """Sell the held possession. Implementation elided in this excerpt."""

    def step(self, action):
        """Advance one step (implementation elided in this excerpt).

        Outline: process the buy/sell ``action``, update the agent's
        position, and return the next state and reward.
        """
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Agent:
    """Outline of the risk-aware agent; most method bodies are elided.

    NOTE(review): ``__init__`` requires ``risk_aversion`` positionally; confirm
    that callers constructing the agent with only keyword hyperparameters
    supply it.
    """

    def __init__(self, risk_aversion, **args):
        ...

    def model(self, state):
        """Map ``state`` to the agent's output (implementation elided).

        Previously declared as ``def model():`` with no ``self``/``state``.
        That made it uncallable on an instance and inconsistent with how
        ``act`` invokes it.
        """
        ...

    def act(self, state, eps=0.):
        """Return the model's output for ``state``.

        ``eps`` is accepted for interface compatibility; any use of it is in
        the elided part of the body.
        """
        ...
        # Must go through ``self``: a bare ``model`` name is not in scope inside
        # a method body (class scope is not enclosing), so it raised NameError.
        return self.model(state)

    def learn(self, experiences, is_weights, gamma):
        """Update the agent from ``experiences`` (implementation elided)."""
        ...