-
-
Save kongmunist/78a39d793eec8c4c73e6112175cc2739 to your computer and use it in GitHub Desktop.
fetcher
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from datetime import datetime | |
from logging import INFO, Formatter, StreamHandler, getLogger | |
from urllib.error import HTTPError | |
import polars as pl | |
class Fetcher:
    """
    Fetches the trading data from the Bybit public API and converts it to OHLCV format.

    One gzip-compressed CSV per symbol and day is read from ``base_url``,
    aggregated into fixed-width candles and, unless running in debug mode,
    written as CSV below ``save_path``.

    NOTE(review): the polars calls used here (``DataFrame.groupby``,
    ``pl.date_range(...).to_list()`` and ``read_csv(dtypes=...)``) appear to
    target an older polars release; newer releases are believed to rename or
    change them -- confirm against the pinned polars version before upgrading.
    """

    # Columns (and dtypes) read from each raw trade file.
    trading_schema = {
        "timestamp": pl.Float64,
        "side": pl.Utf8,
        "size": pl.Float64,
        "price": pl.Float64,
    }
    # Layout of the aggregated candles; also used to build the empty result frame.
    ohlcv_schema = {
        "truncated_dt": pl.Datetime,
        "open": pl.Float64,
        "high": pl.Float64,
        "low": pl.Float64,
        "close": pl.Float64,
        "volume": pl.Float64,
        "buy_volume": pl.Float64,
        "sell_volume": pl.Float64,
    }

    def __init__(self, save_path: str) -> None:
        """
        Creates the output directory and configures logging.
        Parameters:
        - save_path (str): The base directory under which CSV files are written.
        """
        self.base_url = "https://public.bybit.com/trading/"
        self.__setup_save_path(save_path)
        self.__setup_logger()

    def __setup_logger(self) -> None:
        """
        Sets up the logger for the class.

        The logger is the shared module-level logger, so a stream handler is
        attached only if the logger has none yet; otherwise every additional
        Fetcher instance would duplicate each log line.
        """
        self.logger = getLogger(__name__)
        self.logger.setLevel(INFO)
        if self.logger.handlers:
            return
        handler = StreamHandler()
        handler.setLevel(INFO)
        handler.setFormatter(Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
        self.logger.addHandler(handler)

    def __setup_save_path(self, save_path: str, identify_date: bool = True) -> None:
        """
        Creates the save directory and stores its path in ``self.save_path``.
        Parameters:
        - save_path (str): The base save path.
        - identify_date (bool): Whether to append the current date (YYYYMMDD) to the save path. Default is True.
        Returns:
        - None: The final save path is stored on the instance, not returned.
        """
        if identify_date:
            save_path = os.path.join(save_path, datetime.now().strftime("%Y%m%d"))
        os.makedirs(save_path, exist_ok=True)
        self.save_path = save_path

    def __convert_to_ohlcv(self, trading: pl.DataFrame, interval_sec: int) -> pl.DataFrame:
        """
        Converts the trading data to OHLCV format based on the given interval.
        Parameters:
        - trading (pl.DataFrame): The input trading data (columns of ``trading_schema``).
        - interval_sec (int): The interval in seconds for the OHLCV data.
        Returns:
        - pl.DataFrame: The OHLCV formatted data, sorted by ``truncated_dt``.
        """
        self.logger.info(f"----> Converting TRADING to OHLCV with interval {interval_sec} sec")
        trading = trading.with_columns(
            # "timestamp" is interpreted as epoch seconds and floored to the candle start.
            pl.from_epoch(pl.col("timestamp"), "s").dt.truncate(f"{interval_sec}s").alias("truncated_dt"),
            # Split the traded size by side so both can be summed independently.
            pl.when(pl.col("side") == "Buy").then(pl.col("size")).otherwise(0).alias("buy_size"),
            pl.when(pl.col("side") == "Sell").then(pl.col("size")).otherwise(0).alias("sell_size"),
        )
        # NOTE(review): "open"/"close" are the first/last rows of each group as
        # they appear in the input, so they are only correct if the source file
        # is in chronological order -- confirm for the files being downloaded.
        ohlcv = trading.groupby("truncated_dt").agg(
            pl.first("price").alias("open"),
            pl.max("price").alias("high"),
            pl.min("price").alias("low"),
            pl.last("price").alias("close"),
            pl.sum("size").alias("volume"),
            pl.sum("buy_size").alias("buy_volume"),
            pl.sum("sell_size").alias("sell_volume"),
        )
        return ohlcv.sort("truncated_dt")

    def __fetch(self, symbol: str, from_dt: datetime, to_dt: datetime, interval_sec: int) -> pl.DataFrame | None:
        """
        Fetches the trading data for a given symbol within the specified date range and converts it to OHLCV format.
        Parameters:
        - symbol (str): The trading symbol.
        - from_dt (datetime): The start date for fetching data.
        - to_dt (datetime): The end date for fetching data.
        - interval_sec (int): The interval in seconds for the OHLCV data.
        Returns:
        - pl.DataFrame | None: The OHLCV formatted data (empty if no day was available),
          or None if a day failed with anything other than an HTTP error.
        """
        # Start from an empty, correctly typed frame so the result is well-formed
        # even when no day could be downloaded.
        daily_frames: list[pl.DataFrame] = [pl.DataFrame(schema=self.ohlcv_schema)]
        date_range: list[datetime] = pl.date_range(from_dt, to_dt, interval="1d").to_list()
        columns = list(self.trading_schema.keys())
        dtypes = list(self.trading_schema.values())
        for dt in date_range:
            self.logger.info(f"[{symbol}] working on {dt}")
            day = dt.strftime("%Y-%m-%d")
            # Build the URL with "/" explicitly: os.path.join would insert the
            # OS path separator (a backslash on Windows) into the URL.
            url = "/".join([self.base_url.rstrip("/"), symbol, f"{symbol}{day}.csv.gz"])
            try:
                trading = pl.read_csv(url, columns=columns, dtypes=dtypes)
            except HTTPError:
                # A missing day is expected (e.g. before listing); skip it.
                self.logger.info(f"[{symbol}] {dt} is not available")
                continue
            except Exception as e:
                # Anything else aborts the whole symbol; keep the traceback in the log.
                self.logger.exception(f"[{symbol}] {dt} raised {e}")
                return None
            daily_frames.append(self.__convert_to_ohlcv(trading, interval_sec))
        # Concatenate once instead of re-copying the accumulated frame every day.
        return pl.concat(daily_frames)

    def __save_to_csv(self, file_name: str, df: pl.DataFrame) -> None:
        """
        Saves the provided DataFrame to a CSV file inside ``self.save_path``.
        Parameters:
        - file_name (str): The name of the file to save.
        - df (pl.DataFrame): The DataFrame to be saved.
        """
        self.logger.info(f"Saving {file_name} to {self.save_path}")
        save_file_path = os.path.join(self.save_path, file_name)
        df.write_csv(save_file_path)

    def download(self, symbols: list[str], interval_secs: list[int], from_dt_str: str, to_dt_str: str, is_debug: bool = False) -> tuple[list[str], list[str]]:
        """
        Executes the fetch operation for the provided symbols within the specified date range.

        The raw data is fetched once per (symbol, interval) combination.
        Parameters:
        - symbols (list[str]): The list of trading symbols.
        - interval_secs (list[int]): The intervals in seconds for the OHLCV data.
        - from_dt_str (str): The start date as a string in YYYY-MM-DD format.
        - to_dt_str (str): The end date as a string in YYYY-MM-DD format.
        - is_debug (bool): If True, the fetched data is not saved. Default is False.
        Returns:
        - tuple[list[str], list[str]]: ``(successes, failures)``. Each list holds one
          entry per (symbol, interval) combination, so a symbol may appear several
          times and may appear in both lists.
        Raises:
        - ValueError: If a date string does not match YYYY-MM-DD.
        """
        dt_format = "%Y-%m-%d"
        from_dt = datetime.strptime(from_dt_str, dt_format)
        to_dt = datetime.strptime(to_dt_str, dt_format)
        successes: list[str] = []
        failures: list[str] = []
        for symbol in symbols:
            for interval_sec in interval_secs:
                self.logger.info(f"Fetching {symbol} from {from_dt} to {to_dt}")
                ohlcv: pl.DataFrame | None = self.__fetch(symbol, from_dt, to_dt, interval_sec)
                if ohlcv is None:
                    failures.append(symbol)
                    continue
                successes.append(symbol)
                # save
                if not is_debug:
                    save_file_name = f'{symbol}_{from_dt.strftime("%Y%m%d")}_{interval_sec}.csv'
                    self.__save_to_csv(save_file_name, ohlcv)
        return successes, failures
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fetcher import Fetcher | |
if __name__ == "__main__":
    # Download one year of BTCUSDT/ETHUSDT trades as 1-minute and 30-minute
    # candles into ./temp (Fetcher appends a dated sub-directory).
    downloader = Fetcher(save_path="./temp")
    outcome = downloader.download(
        symbols=["BTCUSDT", "ETHUSDT"],
        interval_secs=[60, 1800],
        from_dt_str="2021-01-01",
        to_dt_str="2022-01-01",
    )
    # download() returns (successes, failures); only the failures are reported.
    failed_symbols = outcome[1]
    # print symbols that failed to download and save.
    print(failed_symbols)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment