Skip to content

Instantly share code, notes, and snippets.

@poloxue
Last active March 10, 2025 07:35
Show Gist options
  • Save poloxue/74fc77c6069a2293c6b776f0ea40a5bf to your computer and use it in GitHub Desktop.
Save poloxue/74fc77c6069a2293c6b776f0ea40a5bf to your computer and use it in GitHub Desktop.
import sys
import click
import matplotlib.pyplot as plt
import pandas as pd
import tushare as ts
from tqdm import tqdm
# Shared tushare Pro API client; download() issues all of its data requests
# (fut_mapping / fut_daily) through this object.
# NOTE(review): pro_api() is called without a token argument, so it presumably
# relies on a token configured beforehand (e.g. via ts.set_token) — confirm.
pro = ts.pro_api()
def extract_date_range(group):
    """Return the first and last trade date found in a group of rows.

    Args:
        group: Table-like object (e.g. one ``groupby`` group) whose
            ``"trade_date"`` column supports ``min()`` and ``max()``.

    Returns:
        A ``(start, end)`` tuple holding the smallest and the largest
        ``trade_date`` value of the group.
    """
    trade_dates = group["trade_date"]
    earliest = trade_dates.min()
    latest = trade_dates.max()
    return earliest, latest
def extract_date_range_with_next_trade_date(group):
    """Return a group's date range, extended to the following trade date.

    The range starts at the smallest ``trade_date``. It ends at the
    ``next_trade_date`` of the group's last row when that value is present;
    otherwise it falls back to the largest ``trade_date``.

    Args:
        group: Table-like object with ``"trade_date"`` and
            ``"next_trade_date"`` columns. The last row (``iloc[-1]``) is
            taken as the most recent one, so rows are expected to already be
            in chronological order.

    Returns:
        A ``(start, end)`` tuple.
    """
    trade_dates = group["trade_date"]
    following_date = group["next_trade_date"].iloc[-1]
    # A missing "next" date means there is nothing to extend the range to.
    range_end = trade_dates.max() if pd.isna(following_date) else following_date
    return trade_dates.min(), range_end
def download(ts_code, start_date, adjust, method):
    """Download a futures product's main-contract series and price-adjust it.

    The main-contract mapping is fetched with ``pro.fut_mapping``; daily bars
    for every mapped contract are fetched with ``pro.fut_daily`` and joined
    onto the mapping. A rollover is a row whose ``mapping_ts_code`` differs
    from the previous row's. For each rollover a price ratio
    (new contract / old contract) is computed according to ``method`` and
    accumulated into ``adj_factor``, which then scales close/open/high/low
    into ``adj_close``/``adj_open``/``adj_high``/``adj_low``.

    Args:
        ts_code: Product code passed to ``pro.fut_mapping``.
        start_date: Start date passed to ``pro.fut_mapping`` (may be None).
        adjust: ``"forward"`` (rows before a rollover are scaled by the
            product of all later ratios) or ``"backward"`` (rows from a
            rollover onwards are scaled by the running product of inverse
            ratios).
        method: How the rollover ratio is computed:
            ``"pre_close/pre_close"`` - new contract's pre_close / previous
            row's close; ``"open/pre_close"`` - new contract's open /
            previous row's close; ``"pre_settle/pre_settle"`` - new
            contract's pre_settle / previous row's settle; ``"open/open"`` -
            new contract's open / old contract's open on the rollover day.

    Returns:
        A DataFrame indexed by ``date`` (parsed from ``trade_date``) holding
        the mapping columns, the main contract's daily bar columns,
        ``adj_factor`` and the four ``adj_*`` price columns. Helper columns
        used during the computation are dropped.

    Raises:
        ValueError: If ``method`` or ``adjust`` is not supported, or if no
            mapping rows / no daily bars were returned.
    """
    # Ratio rules for the methods based on the previous row's price:
    # method -> (numerator column on the rollover row,
    #            column whose previous-row value is the denominator,
    #            column overwritten with that previous-row value).
    previous_row_rules = {
        "pre_close/pre_close": ("pre_close", "close", "pre_close"),
        "open/pre_close": ("open", "close", "pre_close"),
        "pre_settle/pre_settle": ("pre_settle", "settle", "pre_settle"),
    }
    # Validate up front so a bad argument fails before any network request.
    if method != "open/open" and method not in previous_row_rules:
        raise ValueError(f"{method} 是不支持的复权计算方法")
    if adjust not in ("forward", "backward"):
        raise ValueError(f"{adjust} 是不支持的复权方式")

    mapping_data = pro.fut_mapping(ts_code=ts_code, start_date=start_date)
    if mapping_data.empty:
        raise ValueError(f"未获取到 {ts_code} 的主力合约映射数据")
    mapping_data.sort_values(by="trade_date", inplace=True, ignore_index=True)

    # Helper columns are collected here and removed just before returning.
    drop_columns = ["pre_mapping_ts_code", "rollover"]
    mapping_data["pre_mapping_ts_code"] = mapping_data["mapping_ts_code"].shift(1)
    # The first row has no predecessor and therefore is never a rollover.
    mapping_data["rollover"] = (
        mapping_data["mapping_ts_code"] != mapping_data["pre_mapping_ts_code"]
    ) & mapping_data["pre_mapping_ts_code"].notna()

    if method == "open/open":
        # The old contract's open on the rollover day is needed, so each
        # contract's download range is extended by one trade date.
        drop_columns.append("next_trade_date")
        mapping_data["next_trade_date"] = mapping_data["trade_date"].shift(-1)
        extract_date_range_func = extract_date_range_with_next_trade_date
    else:
        extract_date_range_func = extract_date_range
    date_ranges = mapping_data.groupby("mapping_ts_code").apply(extract_date_range_func)

    all_ohlcvs = []
    # Distinct loop names: the original shadowed the ts_code/start_date args.
    for contract_code, (range_start, range_end) in tqdm(
        date_ranges.items(), total=len(date_ranges)
    ):
        ohlcvs = pro.fut_daily(
            ts_code=contract_code, start_date=range_start, end_date=range_end
        )
        if not ohlcvs.empty:
            all_ohlcvs.append(ohlcvs)
    if not all_ohlcvs:
        # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
        raise ValueError(f"未获取到 {ts_code} 的日线行情数据")
    ohlcv_data = pd.concat(all_ohlcvs, ignore_index=True)

    # Attach the main contract's bar to every mapping row.
    data = mapping_data.merge(
        ohlcv_data,
        left_on=["mapping_ts_code", "trade_date"],
        right_on=["ts_code", "trade_date"],
        how="left",
        suffixes=("", "_after"),
    )
    drop_columns.append("ts_code_after")

    if method == "open/open":
        # how="left" keeps rows without a previous-contract bar (notably the
        # very first row); the default inner join silently dropped them.
        data = data.merge(
            ohlcv_data[["ts_code", "trade_date", "open"]],
            left_on=["pre_mapping_ts_code", "trade_date"],
            right_on=["ts_code", "trade_date"],
            how="left",
            suffixes=("", "_before"),
        )
        drop_columns += ["ts_code_before", "open_before"]
        ratio = data["open"] / data["open_before"]
    else:
        numerator_col, source_col, replaced_col = previous_row_rules[method]
        previous_value = data[source_col].shift(1)
        # Compute the ratio before the numerator column may be overwritten.
        ratio = data[numerator_col] / previous_value
        data[replaced_col] = previous_value
    # The ratio only has meaning on rollover rows; everything else is NaN.
    data["rollover_factor"] = ratio.where(data["rollover"])
    drop_columns.append("rollover_factor")

    data["trade_date"] = pd.to_datetime(data["trade_date"])
    data.set_index("trade_date", inplace=True)
    data.index.name = "date"

    if adjust == "forward":
        # shift(-1) moves each ratio onto the row before its rollover; the
        # reversed cumulative product then applies it to all earlier rows.
        # Assignment re-aligns the reversed series on the date index.
        data["adj_factor"] = data["rollover_factor"].shift(-1)[::-1].fillna(1).cumprod()
    else:
        data["adj_factor"] = (1 / data["rollover_factor"]).fillna(1).cumprod()

    for price_col in ("close", "open", "high", "low"):
        data[f"adj_{price_col}"] = data[price_col] * data["adj_factor"]

    data.drop(columns=drop_columns, inplace=True)
    return data
@click.command()
@click.option(
    "--ts-code",
    required=True,  # mandatory option
    help="期货合约代码(例如:LH.DCE 表示生猪期货)",
)
@click.option(
    "--start-date",
    default=None,
    # Help text fixed: the real default is None, not 20220101.
    help="起始日期(格式:YYYYMMDD,默认不指定)",
)
@click.option(
    "--adjust",
    type=click.Choice(["forward", "backward"], case_sensitive=False),
    default="backward",
    # Help text fixed: it used to advertise "forward" as the default.
    help="复权方式(默认:backward):\n"
    "forward: 前复权(以当前价格为基准调整历史数据)\n"
    "backward: 后复权(保持历史价格不变调整未来数据)",
)
@click.option(
    "--method",
    type=click.Choice(
        ["pre_close/pre_close", "open/pre_close", "pre_settle/pre_settle", "open/open"],
        case_sensitive=False,
    ),
    default="pre_close/pre_close",
    help="换月调整方法(默认:pre_close/pre_close):\n"
    "pre_close/pre_close: 使用前收盘价复权\n"
    "open/pre_close: 使用开盘价/前收盘价复权\n"
    "pre_settle/pre_settle: 使用前结算价复权\n"
    "open/open: 使用开盘价复权",
)
@click.option(
    "--no-download",
    is_flag=True,  # boolean flag
    help="是否禁用复权数据下载(默认下载)",
)
@click.option(
    "--plot",
    is_flag=True,  # boolean flag
    help="是否显示价格曲线图(默认不显示)",
)
def main(ts_code, start_date, adjust, method, no_download, plot):
    # Command-line entry point: builds the adjusted series via download(),
    # writes it to a CSV file unless --no-download is given, and plots
    # adj_close against close when --plot is given.
    # NOTE: documented with comments instead of a docstring on purpose —
    # click shows a command function's docstring in the --help output, so
    # adding one would change what users see.
    if no_download and not plot:
        # With neither output requested there is nothing to do: explain,
        # show the usage text and stop.
        print("提示: 绘图(--plot)与数据下载 (--no-download) 不可同时禁用\n")
        ctx = click.get_current_context()
        click.echo(ctx.get_help())
        ctx.exit()

    data = download(ts_code, start_date, adjust, method)

    if not no_download:
        # "/" in the method name is replaced so it cannot act as a path
        # separator in the file name.
        method_tag = method.replace("/", "-")
        if start_date is None:
            csv_filename = f"{ts_code}_{adjust}_{method_tag}.csv"
        else:
            csv_filename = f"{ts_code}_{adjust}_{start_date}_{method_tag}.csv"
        data.to_csv(csv_filename)
        print(f"复权数据下载完成,请查看数据文件 {csv_filename}")

    if plot:
        data[["adj_close", "close"]].plot()
        plt.show()
# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment