-
-
Save poloxue/74fc77c6069a2293c6b776f0ea40a5bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import click | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import tushare as ts | |
from tqdm import tqdm | |
# Module-level Tushare Pro API client used by download() for all data requests.
# NOTE(review): created without an explicit token, so it presumably relies on
# a token configured beforehand — confirm.
pro = ts.pro_api()
def extract_date_range(group):
    """Return the first and last trading day covered by ``group``.

    Args:
        group: Frame (typically one ``groupby`` group) with a ``trade_date``
            column.

    Returns:
        A ``(start_date, end_date)`` tuple holding the minimum and maximum
        ``trade_date`` values of the group.
    """
    trade_dates = group["trade_date"]
    first_day = trade_dates.min()
    last_day = trade_dates.max()
    return first_day, last_day
def extract_date_range_with_next_trade_date(group):
    """Return the group's date range, extended by one trading day if known.

    The range starts at the smallest ``trade_date``. It ends at the
    ``next_trade_date`` stored in the group's last row when that value is
    present; otherwise it falls back to the largest ``trade_date``.

    Args:
        group: Frame with ``trade_date`` and ``next_trade_date`` columns.

    Returns:
        A ``(start_date, end_date)`` tuple.
    """
    trade_dates = group["trade_date"]
    following_day = group["next_trade_date"].iloc[-1]
    last_day = trade_dates.max() if pd.isna(following_day) else following_day
    return trade_dates.min(), last_day
# Rollover methods that compare a rollover-day quote of the new contract with
# the previous row of the continuous series. Maps
# method -> (numerator column, column shifted by one row, column overwritten
# with the shifted values so it becomes continuous across rollovers).
_SHIFT_BASED_METHODS = {
    "pre_close/pre_close": ("pre_close", "close", "pre_close"),
    "open/pre_close": ("open", "close", "pre_close"),
    "pre_settle/pre_settle": ("pre_settle", "settle", "pre_settle"),
}
_SUPPORTED_METHODS = (*_SHIFT_BASED_METHODS, "open/open")
_SUPPORTED_ADJUSTS = ("forward", "backward")


def download(ts_code, start_date, adjust, method):
    """Download a dominant-contract futures series and add adjusted prices.

    The dominant-contract mapping is fetched with ``pro.fut_mapping``, the
    daily quotes of every mapped contract with ``pro.fut_daily``, and the two
    are joined per trading day. On each day where the mapped contract changes
    (a rollover) a price ratio is computed according to ``method``; those
    ratios are accumulated into ``adj_factor`` according to ``adjust``.

    Args:
        ts_code: Code passed to ``pro.fut_mapping`` (e.g. ``"LH.DCE"``).
        start_date: Start date forwarded to ``pro.fut_mapping``; may be
            ``None``.
        adjust: ``"forward"`` (factors of later rollovers are applied to
            earlier rows) or ``"backward"`` (inverse factors are applied from
            each rollover day onwards).
        method: One of ``"pre_close/pre_close"``, ``"open/pre_close"``,
            ``"pre_settle/pre_settle"`` or ``"open/open"``.

    Returns:
        A DataFrame indexed by ``date`` (datetime) holding the mapping and
        quote columns plus ``adj_factor``, ``adj_close``, ``adj_open``,
        ``adj_high`` and ``adj_low``.

    Raises:
        ValueError: If ``method`` or ``adjust`` is unsupported, or if no
            mapping rows / no daily quotes were returned.
    """
    # Validate up front so a typo does not cost a full round of API calls.
    if method not in _SUPPORTED_METHODS:
        raise ValueError(f"{method} 是不支持的复权计算方法")
    if adjust not in _SUPPORTED_ADJUSTS:
        raise ValueError(f"{adjust} 是不支持的复权方式")

    mapping_data = pro.fut_mapping(ts_code=ts_code, start_date=start_date)
    if mapping_data.empty:
        raise ValueError(f"未获取到 {ts_code} 的主力合约映射数据")
    mapping_data.sort_values(by="trade_date", inplace=True, ignore_index=True)

    # Helper columns created along the way and removed before returning.
    helper_columns = ["pre_mapping_ts_code", "rollover"]
    mapping_data["pre_mapping_ts_code"] = mapping_data["mapping_ts_code"].shift(1)
    # A rollover day is a day whose mapped contract differs from the previous
    # row's; the very first row has no predecessor and never counts.
    mapping_data["rollover"] = (
        mapping_data["mapping_ts_code"] != mapping_data["pre_mapping_ts_code"]
    ) & mapping_data["pre_mapping_ts_code"].notna()

    range_columns = ["trade_date"]
    if method == "open/open":
        # The outgoing contract's quote on the rollover day itself is needed,
        # so each contract's range is extended by one trading day.
        helper_columns.append("next_trade_date")
        mapping_data["next_trade_date"] = mapping_data["trade_date"].shift(-1)
        range_columns.append("next_trade_date")
        extract_date_range_func = extract_date_range_with_next_trade_date
    else:
        extract_date_range_func = extract_date_range

    # Selecting the needed columns keeps the grouping column out of ``apply``,
    # which newer pandas versions warn about.
    date_ranges = mapping_data.groupby("mapping_ts_code")[range_columns].apply(
        extract_date_range_func
    )

    contract_frames = []
    for contract_code, (first_day, last_day) in tqdm(
        date_ranges.items(), total=len(date_ranges)
    ):
        quotes = pro.fut_daily(
            ts_code=contract_code, start_date=first_day, end_date=last_day
        )
        if not quotes.empty:
            contract_frames.append(quotes)
    if not contract_frames:
        raise ValueError(f"未获取到 {ts_code} 相关合约的日线行情数据")
    ohlcv_data = pd.concat(contract_frames, ignore_index=True)

    data = mapping_data.merge(
        ohlcv_data,
        left_on=["mapping_ts_code", "trade_date"],
        right_on=["ts_code", "trade_date"],
        how="left",
        suffixes=("", "_after"),
    )
    helper_columns.append("ts_code_after")

    if method == "open/open":
        # Left join: an inner join would silently drop the first row (no
        # previous contract) and any day lacking previous-contract quotes.
        data = data.merge(
            ohlcv_data[["ts_code", "trade_date", "open"]],
            left_on=["pre_mapping_ts_code", "trade_date"],
            right_on=["ts_code", "trade_date"],
            how="left",
            suffixes=("", "_before"),
        )
        helper_columns += ["ts_code_before", "open_before"]
        ratio = data["open"] / data["open_before"]
        data["rollover_factor"] = ratio.where(data["rollover"])
    else:
        numerator_col, reference_col, continuous_col = _SHIFT_BASED_METHODS[method]
        previous_values = data[reference_col].shift(1)
        # Compute the ratio before overwriting, since the numerator may be the
        # very column that is replaced below.
        ratio = data[numerator_col] / previous_values
        data["rollover_factor"] = ratio.where(data["rollover"])
        data[continuous_col] = previous_values
    helper_columns.append("rollover_factor")

    data["trade_date"] = pd.to_datetime(data["trade_date"])
    data.set_index("trade_date", inplace=True)
    data.index.name = "date"

    rollover_factor = data["rollover_factor"]
    if adjust == "forward":
        # Shift each factor onto the day before its rollover, then accumulate
        # from the newest row backwards; assignment realigns on the index.
        data["adj_factor"] = rollover_factor.shift(-1)[::-1].fillna(1).cumprod()
    else:
        data["adj_factor"] = (1 / rollover_factor).fillna(1).cumprod()

    for price_col in ("close", "open", "high", "low"):
        data[f"adj_{price_col}"] = data[price_col] * data["adj_factor"]

    data.drop(columns=helper_columns, inplace=True)
    return data
@click.command()
@click.option(
    "--ts-code",
    required=True,  # mandatory option
    help="期货合约代码(例如:LH.DCE 表示生猪期货)",
)
@click.option(
    "--start-date",
    default=None,
    # The real default is None (no start date is passed on), so the help text
    # must not advertise a concrete date.
    help="起始日期(格式:YYYYMMDD,默认不指定)",
)
@click.option(
    "--adjust",
    type=click.Choice(["forward", "backward"], case_sensitive=False),
    default="backward",
    # Help text states the actual default ("backward").
    help="复权方式(默认:backward):\n"
    "forward: 前复权(以当前价格为基准调整历史数据)\n"
    "backward: 后复权(保持历史价格不变调整未来数据)",
)
@click.option(
    "--method",
    type=click.Choice(
        ["pre_close/pre_close", "open/pre_close", "pre_settle/pre_settle", "open/open"],
        case_sensitive=False,
    ),
    default="pre_close/pre_close",
    help="换月调整方法(默认:pre_close/pre_close):\n"
    "pre_close/pre_close: 使用前收盘价复权\n"
    "open/pre_close: 使用开盘价/前收盘价复权\n"
    "pre_settle/pre_settle: 使用前结算价复权\n"
    "open/open: 使用开盘价复权",
)
@click.option(
    "--no-download",
    is_flag=True,  # boolean flag
    help="是否禁用复权数据下载(默认下载)",
)
@click.option(
    "--plot",
    is_flag=True,  # boolean flag
    help="是否显示价格曲线图(默认不显示)",
)
def main(ts_code, start_date, adjust, method, no_download, plot):
    """Build the adjusted series, then save it to CSV and/or plot it.

    Args:
        ts_code: Futures code forwarded to ``download``.
        start_date: Optional start date (``YYYYMMDD``) forwarded to
            ``download``; also embedded in the CSV file name when given.
        adjust: Adjustment direction, ``"forward"`` or ``"backward"``.
        method: Rollover ratio method forwarded to ``download``.
        no_download: When set, the CSV file is not written.
        plot: When set, ``adj_close`` and ``close`` are plotted.
    """
    if no_download and not plot:
        # Nothing would be produced: show the help text and stop.
        click.echo("提示: 绘图(--plot)与数据下载 (--no-download) 不可同时禁用\n")
        ctx = click.get_current_context()
        click.echo(ctx.get_help())
        ctx.exit()

    data = download(ts_code, start_date, adjust, method)

    if not no_download:
        # "/" in the method name is not valid inside a file name.
        name_parts = [ts_code, adjust]
        if start_date is not None:
            name_parts.append(start_date)
        name_parts.append(method.replace("/", "-"))
        csv_filename = "_".join(name_parts) + ".csv"
        data.to_csv(csv_filename)
        click.echo(f"复权数据下载完成,请查看数据文件 {csv_filename}")

    if plot:
        data[["adj_close", "close"]].plot()
        plt.show()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment