Skip to content

Instantly share code, notes, and snippets.

@masafumimori
Created March 5, 2023 11:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masafumimori/8cbfd606499bea0b78274cbb8649782a to your computer and use it in GitHub Desktop.
Save masafumimori/8cbfd606499bea0b78274cbb8649782a to your computer and use it in GitHub Desktop.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import abc
class BaseStrategy:
    """Base class for long/short signal backtests.

    A strategy supplies a DataFrame containing every column listed in
    ``LONG_SHORT_COLUMNS``; :meth:`simulate` replays those entry/exit signals
    bar by bar and tracks the account balance.

    NOTE(review): ``prepare`` is decorated with ``abc.abstractmethod`` but this
    class does not use ``abc.ABC``/``abc.ABCMeta``, so the decorator is not
    enforced and the class can be instantiated directly. Switching to
    ``abc.ABC`` would break such direct use — confirm intent before changing.
    """

    # bybit perp fee
    MAKER_FEE = 0.0001  # 0.01%
    TAKER_FEE = 0.0006  # 0.06%
    # Columns that simulate() and check_signal_dist() require in their input.
    LONG_SHORT_COLUMNS = ["long_entry_signal", "long_exit_signal", "long_entry_at", "long_exit_at",
                          "short_entry_signal", "short_exit_signal", "short_entry_at", "short_exit_at"]

    def __init__(self, initial_capital, force_stop):
        """Store the backtest parameters.

        Args:
            initial_capital: starting balance used by simulate().
            force_stop: fraction of ``initial_capital``; simulate() stops once
                the balance falls below ``initial_capital * force_stop``.
        """
        self.initial_capital = initial_capital
        self.force_stop = force_stop

    @abc.abstractmethod
    def prepare(self):
        """Hook for subclasses (not enforced; see the class docstring).

        Presumably builds the signal columns consumed by simulate() — confirm
        against concrete subclasses.
        """

    def __check_columns(self, df):
        """Raise ValueError if ``df`` lacks any column in LONG_SHORT_COLUMNS."""
        if not all(col in df.columns for col in self.LONG_SHORT_COLUMNS):
            raise ValueError("Required columns are missing in the DataFrame")

    def simulate(self, data=None, fee=TAKER_FEE):
        """Replay the signals in ``data`` and track the account balance.

        At most one position (long or short) is open at a time, and the whole
        balance is committed to each trade. Entries are only considered while
        flat, and an exit can only fire on a bar after the entry bar. ``fee``
        is charged on both the entry and the exit leg.

        Args:
            data: DataFrame containing every column in LONG_SHORT_COLUMNS.
            fee: proportional fee per trade leg (defaults to TAKER_FEE).

        Returns:
            DataFrame indexed like ``data`` with columns ``balance`` (balance
            after each bar) and ``pl`` (realised P&L on bars where a position
            was closed, 0.0 otherwise). If the force stop triggers, the
            remaining rows carry the last balance.

        Raises:
            ValueError: if ``data`` is None/empty or lacks required columns.
        """
        # Explicit check instead of ``assert`` (stripped under -O); also covers
        # the ``data=None`` default, which previously raised AttributeError.
        if data is None or data.empty:
            raise ValueError("Data must be provided")
        self.__check_columns(data)

        n_rows = len(data)
        balances = np.full(n_rows, np.nan)
        # Float from the start so assigning fractional P&L never changes dtype.
        pls = np.zeros(n_rows)

        long_pos, short_pos = 0.0, 0.0
        short_entry_price = 0.0
        balance = self.initial_capital

        # itertuples over only the needed columns is far cheaper than iterrows.
        for i, row in enumerate(data[self.LONG_SHORT_COLUMNS].itertuples()):
            if long_pos == 0 and short_pos == 0:
                if row.long_entry_signal:
                    # Entry fee is paid by receiving fewer units.
                    long_pos = balance / row.long_entry_at * (1 - fee)
                elif row.short_entry_signal:
                    short_pos = balance / row.short_entry_at * (1 - fee)
                    short_entry_price = row.short_entry_at
            elif long_pos > 0 and row.long_exit_signal:
                sold_price = long_pos * row.long_exit_at * (1 - fee)
                pls[i] = sold_price - balance
                balance = sold_price
                long_pos = 0.0
            elif short_pos > 0 and row.short_exit_signal:
                # Short P&L = proceeds of the entry sale net of fee minus the
                # buy-back cost including fee. The previous ``balance -
                # bought_price`` ignored the entry fee and reported a profit on
                # a round trip at an unchanged price.
                sale_proceeds = short_pos * short_entry_price * (1 - fee)
                bought_price = short_pos * row.short_exit_at * (1 + fee)
                pl = sale_proceeds - bought_price
                pls[i] = pl
                balance += pl
                short_pos = 0.0

            balances[i] = balance

            # Stop once the loss exceeds the allowed fraction of initial capital.
            if balance < self.initial_capital * self.force_stop:
                print("Simulation stopped due to over {:.2f}% of loss from initial capital at {}".format(
                    (1 - balance / self.initial_capital) * 100, row.Index))
                balances[i + 1:] = balance  # Fill the rest rows with the last balance
                break

        result = pd.DataFrame({"balance": balances, "pl": pls}, index=data.index)
        print("Maximum balance {:.2f} at {}".format(result["balance"].max(), result["balance"].idxmax()))
        print("Realised P&L {:.2f}%".format((balance - self.initial_capital) / self.initial_capital * 100))
        return result

    def check_signal_dist(self, data, window=100):
        """Plot the rolling frequency of each entry/exit signal in a 2x2 grid.

        Args:
            data: DataFrame containing every column in LONG_SHORT_COLUMNS.
            window: rolling-mean window length, in rows.

        Raises:
            ValueError: if ``data`` lacks required columns.
        """
        self.__check_columns(data)
        print("Long/Short Signal Distribution")
        fig, axs = plt.subplots(nrows=2, ncols=2)

        # (subplot, signal column, legend label, line colour)
        panels = [
            (axs[0, 0], "long_entry_signal", "Long Entry", "green"),
            (axs[0, 1], "long_exit_signal", "Long Exit", "red"),
            (axs[1, 0], "short_entry_signal", "Short Entry", "blue"),
            (axs[1, 1], "short_exit_signal", "Short Exit", "orange"),
        ]
        for ax, column, label, color in panels:
            ax.plot(data[column].rolling(window).mean(), label=label, color=color)
            ax.legend()

        fig.suptitle('Long/Short Entry/Exit Signals')
        plt.show()

    def plot(self, data, show_close=True):
        """Plot the balance curve, optionally overlaid with the close price.

        Args:
            data: DataFrame with a ``balance`` column and, when ``show_close``
                is true, a ``close`` column. The frame returned by simulate()
                has no ``close`` column, so join one in first or pass
                ``show_close=False``.
            show_close: whether to overlay ``data["close"]``.
        """
        _, ax = plt.subplots()
        ax.plot(data.index, data["balance"], color='green', label='Total balance')
        if show_close:
            ax.plot(data.index, data["close"], color='blue', label='Close price', alpha=0.5)
        ax.legend()
        plt.xticks(rotation=90)
        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment