-
-
Save masafumimori/8cbfd606499bea0b78274cbb8649782a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import abc | |
class BaseStrategy:
    """Base class for single-position long/short backtests.

    ``simulate`` replays a DataFrame containing every column listed in
    ``LONG_SHORT_COLUMNS`` bar by bar, opening and closing at most one
    position at a time, and reports the resulting account balance.
    """

    # bybit perp fee
    MAKER_FEE = 0.0001  # 0.01%
    TAKER_FEE = 0.0006  # 0.06%

    # Columns every input DataFrame must contain: the entry/exit signals (read
    # by truthiness) and the price each signal is filled at.
    LONG_SHORT_COLUMNS = ["long_entry_signal", "long_exit_signal", "long_entry_at", "long_exit_at",
                          "short_entry_signal", "short_exit_signal", "short_entry_at", "short_exit_at"]

    def __init__(self, initial_capital, force_stop):
        """Store the backtest parameters.

        Args:
            initial_capital: Starting account balance.
            force_stop: Fraction of ``initial_capital``; ``simulate`` stops
                early once the balance falls below
                ``initial_capital * force_stop``.
        """
        self.initial_capital = initial_capital
        self.force_stop = force_stop

    @abc.abstractmethod
    def prepare(self):
        """Hook for subclasses; not called anywhere in this class.

        NOTE(review): BaseStrategy does not inherit from ``abc.ABC``, so this
        decorator is not enforced -- BaseStrategy (and subclasses that do not
        override ``prepare``) can still be instantiated.
        """
        pass

    def __check_columns(self, df):
        """Raise ValueError unless ``df`` has every column in LONG_SHORT_COLUMNS."""
        if not all(col in df.columns for col in self.LONG_SHORT_COLUMNS):
            raise ValueError("Required columns are missing in the DataFrame")

    @staticmethod
    def _signal_on(value):
        """Return True if ``value`` is a set signal; NaN/NA count as unset.

        Plain truthiness is not enough: ``bool(float('nan'))`` is True, which
        would turn missing signal values into spurious trades.
        """
        return not pd.isna(value) and bool(value)

    def simulate(self, data=None, fee=TAKER_FEE):
        """Run a single-position long/short backtest over ``data``.

        At most one position (long or short) is open at a time. When flat, a
        long entry signal takes precedence over a short one. Both entry and
        exit are charged ``fee``. The balance only changes when a position is
        closed (open positions are not marked to market).

        Args:
            data: DataFrame containing every column in ``LONG_SHORT_COLUMNS``.
                Signal columns are read by truthiness; NaN/NA means "no
                signal".
            fee: Fee rate charged on entry and on exit (default ``TAKER_FEE``).

        Returns:
            DataFrame indexed like ``data`` with ``balance`` (balance after
            each bar) and ``pl`` (P&L realised on that bar, 0.0 otherwise).

        Raises:
            ValueError: if ``data`` is None/empty or lacks required columns.
        """
        # Explicit check rather than ``assert`` (stripped under ``python -O``);
        # also gives a clear error for the ``data=None`` default.
        if data is None or data.empty:
            raise ValueError("Data must be provided")
        self.__check_columns(data)

        long_pos, short_pos = 0.0, 0.0  # open position sizes, in units
        balance = self.initial_capital

        # Results are filled positionally, so duplicate index labels are safe
        # and the float 'pl' column never needs an int -> float upcast.
        balances = np.full(len(data), np.nan)
        pls = np.zeros(len(data))

        rows = data[self.LONG_SHORT_COLUMNS].itertuples(index=False)
        for i, row in enumerate(rows):
            long_entry = self._signal_on(row.long_entry_signal)
            long_exit = self._signal_on(row.long_exit_signal)
            short_entry = self._signal_on(row.short_entry_signal)
            short_exit = self._signal_on(row.short_exit_signal)

            if long_pos == 0 and short_pos == 0:
                # Flat: open a position; the entry fee shrinks the position size.
                if long_entry:
                    long_pos = balance / row.long_entry_at * (1 - fee)
                elif short_entry:
                    short_pos = balance / row.short_entry_at * (1 - fee)
            elif long_pos > 0 and long_exit:
                sold_price = long_pos * row.long_exit_at * (1 - fee)
                pls[i] = sold_price - balance
                long_pos = 0.0
                balance = sold_price
            elif short_pos > 0 and short_exit:
                bought_price = short_pos * row.short_exit_at * (1 + fee)
                # Cash after covering = collateral net of the entry fee
                # + short-sale proceeds - buy-back cost (incl. exit fee).
                # short_pos was sized as balance * (1 - fee) / entry price, so
                # each of the first two terms equals balance * (1 - fee). A
                # flat round trip thus ends at balance * (1 - fee) ** 2,
                # exactly like a long one: fees reduce the balance.
                new_balance = 2 * balance * (1 - fee) - bought_price
                pls[i] = new_balance - balance
                balance = new_balance
                short_pos = 0.0

            balances[i] = balance

            # Stop once the drawdown limit is breached.
            if balance < self.initial_capital * self.force_stop:
                print("Simulation stopped due to over {:.2f}% of loss from initial capital at {}".format((1 - balance / self.initial_capital)*100, data.index[i]))
                balances[i + 1:] = balance  # carry the last balance forward
                break

        df = pd.DataFrame({'balance': balances, 'pl': pls}, index=data.index)

        print("Maximum balance {:.2f} at {}".format(df["balance"].max(), df["balance"].idxmax()))
        print("Realised P&L {:.2f}%".format((balance - self.initial_capital) / self.initial_capital * 100))
        return df

    def check_signal_dist(self, data, window=100):
        """Plot rolling means of the four signal columns in a 2x2 grid.

        Args:
            data: DataFrame containing every column in ``LONG_SHORT_COLUMNS``.
            window: Rolling-window length, in rows.

        Raises:
            ValueError: if ``data`` lacks required columns.
        """
        self.__check_columns(data)
        print("Long/Short Signal Distribution")
        fig, axs = plt.subplots(nrows=2, ncols=2)

        # (subplot, column, legend label, line colour)
        panels = [
            (axs[0, 0], "long_entry_signal", 'Long Entry', 'green'),
            (axs[0, 1], "long_exit_signal", 'Long Exit', 'red'),
            (axs[1, 0], "short_entry_signal", 'Short Entry', 'blue'),
            (axs[1, 1], "short_exit_signal", 'Short Exit', 'orange'),
        ]
        for ax, column, label, color in panels:
            ax.plot(data[column].rolling(window).mean(), label=label, color=color)
            ax.legend()

        fig.suptitle('Long/Short Entry/Exit Signals')
        plt.show()

    def plot(self, data, show_close=True):
        """Plot the ``balance`` column of ``data`` over its index.

        Args:
            data: DataFrame with a ``balance`` column (e.g. the result of
                ``simulate``) and, when ``show_close`` is True, a ``close``
                column -- ``simulate``'s output has no ``close``, so the
                caller must add it.
            show_close: Also draw ``close`` on the same axis.
        """
        fig, ax = plt.subplots()
        ax.plot(data.index, data["balance"], color='green', label='Total balance')
        if show_close:
            ax.plot(data.index, data["close"], color='blue', label='Close price', alpha=0.5)
        ax.legend()
        plt.xticks(rotation=90)
        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment