Last active
January 31, 2019 15:03
-
-
Save ckhung/5024b51595477b4dc5250f69897e21d9 to your computer and use it in GitHub Desktop.
stock analysis: (sana.py) 拿股票三維資料學習 pandas 的 multiindex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://newtoypia.blogspot.com/2019/01/pandas-multiindex.html
# "Learning pandas MultiIndex with three-dimensional stock data"
# (original title: 「拿股票三維資料學習 pandas 的 multiindex」)
import re, os, time | |
import pandas as pd | |
import numpy as np | |
from numpy import NaN | |
from warnings import warn | |
# Default directory holding the per-day CSV files (six-digit day names such
# as 181001.csv).  expanduser('~') is used instead of os.environ['HOME'] so
# that importing this module does not raise KeyError where HOME is unset
# (e.g. on Windows).
datapath = os.path.join(os.path.expanduser('~'), 'stock', 'day')
def read_csv(path, index=u'代號', textcols=(u'代號', u'名稱', 'sid', 'name'), **kwargs):
    """Load a CSV file into a DataFrame indexed and sorted by ``index``.

    ``#`` starts a comment and whitespace following a delimiter is skipped.
    Columns named in ``textcols`` are read as strings; every remaining
    column is passed through ``pd.to_numeric(errors='coerce')``, so cells
    that cannot be parsed as numbers become NaN.

    Args:
        path: File path or file-like object accepted by ``pd.read_csv``.
        index: Name of the column that becomes the (sorted) index.
        textcols: Names of columns to keep as text.  Names that do not
            occur in the file are ignored.
        **kwargs: Extra keyword arguments forwarded to ``pd.read_csv``.

    Returns:
        The parsed DataFrame, sorted by its index so that label slicing
        such as ``.loc[begin:end]`` works.
    """
    # Work on a private list so callers may pass any iterable and the
    # default is never shared mutable state.
    textcols = list(textcols)
    df = pd.read_csv(
        path, index_col=False, comment='#', skipinitialspace=True,
        dtype={col: str for col in textcols},
        **kwargs
    )
    df = df.set_index(index).sort_index()
    # Preserve the file's column order instead of going through a set.
    numeric_cols = [col for col in df.columns if col not in textcols]
    if numeric_cols:
        df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')
    return df
def rid_nan(pdseries):
    """Return a copy of ``pdseries`` with its NaN gaps filled in.

    Leading NaNs take the first valid value, trailing NaNs take the last
    valid value, and interior gaps are filled by ``Series.interpolate()``
    with its default settings.  The input series is not modified.

    Args:
        pdseries: A numeric pandas Series; its index may be of any type.

    Returns:
        A new Series of the same length.  An empty or all-NaN series is
        returned as an unchanged copy because there is no value to fill
        from.
    """
    filled = pdseries.copy()
    valid = filled.dropna()
    if valid.empty:
        return filled
    # Use .iloc: plain series[0] / series[-1] is label-based on an integer
    # index and deprecated as positional access on other index types.
    if pd.isna(filled.iloc[0]):
        filled.iloc[0] = valid.iloc[0]
    if pd.isna(filled.iloc[-1]):
        filled.iloc[-1] = valid.iloc[-1]
    return filled.interpolate()
class stock:
    """A single stock: its ID, its name, a price, and (once loaded) daily rows.

    The ``byday`` attribute is intentionally not created by ``__init__``;
    it is attached later by the daily-data readers in this file, one of
    which checks for its presence with ``hasattr``.
    """

    def __init__(self, sid, sdict):
        """Store ``sid`` and copy the name and closing price out of ``sdict``.

        Args:
            sid: Stock ID.
            sdict: Mapping with at least the keys '名稱' (name) and
                '收盤價' (closing price).
        """
        self.sid = sid
        self.name = sdict['名稱']
        self.price = sdict['收盤價']

    def need_days(self, days):
        """Make sure ``self.byday`` has a row for every label in ``days``.

        Labels missing from the current index are added as rows filled
        with NaN, and the result is sorted by index.  Requires
        ``self.byday`` to have been assigned already.
        """
        wanted = set(days).union(self.byday.index)
        self.byday = self.byday.reindex(wanted).sort_index()
def days_since(since, datapath=datapath):
    """List the day names of the CSV files in ``datapath`` from ``since`` on.

    Only files whose whole name is six digits followed by ``.csv`` are
    considered.  A file is kept when its name compares >= ``since + '.csv'``
    as a string.

    Args:
        since: Six-digit day string, e.g. '181001'.
        datapath: Directory to scan.

    Returns:
        Sorted list of file names with the ``.csv`` suffix removed.
    """
    first_wanted = since + '.csv'
    # List the directory directly instead of chdir-ing into it: chdir
    # changes process-wide state and was not restored if listing failed.
    # fullmatch (not match) rejects names such as '181001.csv.bak', which
    # would otherwise be stripped to the bogus day name '181001.csv'.
    return sorted(
        fn[:-4] for fn in os.listdir(datapath)
        if re.fullmatch(r'\d{6}\.csv', fn) and fn >= first_wanted
    )
def slow_read_daily(table, since, datapath=datapath):
    """Attach per-day data to every stock in ``table``, one stock at a time.

    Every daily CSV from ``since`` on is read with ``read_csv``; for each
    stock ID that is a key of ``table`` the matching row is collected, and
    the rows are finally combined into ``table[sid].byday``, indexed by
    '日期' (date) with the '名稱' (name) column removed.  Any ``byday``
    already present on a stock is replaced.

    Stocks that never appear in the daily files get an empty ``byday``
    instead of being left without the attribute.

    Args:
        table: Mapping of stock ID to an object that accepts a ``byday``
            attribute.
        since: Six-digit day string of the first day to read.
        datapath: Directory holding the daily CSV files.

    Returns:
        ``table`` itself, updated in place.
    """
    rows_by_sid = {}   # sid -> list of one-day frames, in day order
    columns = None     # column layout, used for stocks with no rows at all
    for day in days_since(since, datapath=datapath):
        daily_data = read_csv('{}/{}.csv'.format(datapath, day))
        # Fewer than three columns marks a day the market was closed
        # (per the original author's note).
        if len(daily_data.keys()) < 3:
            continue
        daily_data['日期'] = day
        if columns is None:
            columns = daily_data.columns
        # unique(): a duplicated ID must contribute its rows only once.
        for sid in daily_data.index.unique():
            if sid not in table:
                continue
            # Collect now and concatenate once at the end;
            # DataFrame.append was removed in pandas 2.0 and re-copied
            # the whole frame on every call.
            rows_by_sid.setdefault(sid, []).append(daily_data.loc[[sid]])
    for sid in list(table.keys()):
        rows = rows_by_sid.get(sid)
        if rows:
            byday = pd.concat(rows, sort=False)
        elif columns is not None:
            byday = pd.DataFrame(columns=columns)
        else:
            # No trading day was read at all, so no column layout is known.
            table[sid].byday = pd.DataFrame(index=pd.Index([], name='日期'))
            continue
        table[sid].byday = byday.set_index('日期').drop(columns='名稱')
    return table
def fast_read_daily(table, since, datapath=datapath):
    """Attach per-day data to every stock in ``table`` via one MultiIndex frame.

    Every daily CSV from ``since`` on is read with ``read_csv`` and the
    frames are concatenated into a single DataFrame indexed by
    ('日期', '代號') (date, stock ID).  Each stock's ``byday`` is the
    cross-section of that frame for its ID, indexed by date, with the
    '名稱' (name) column removed.

    Stocks that never appear in the daily files get a NaN-filled ``byday``
    covering every trading day that was read.  If no trading day was read
    at all, every stock gets an empty ``byday``.

    Args:
        table: Mapping of stock ID to an object that accepts a ``byday``
            attribute.
        since: Six-digit day string of the first day to read.
        datapath: Directory holding the daily CSV files.

    Returns:
        ``table`` itself, updated in place.
    """
    frames = {}
    for day in days_since(since, datapath=datapath):
        daily_data = read_csv('{}/{}.csv'.format(datapath, day))
        # Fewer than three columns marks a day the market was closed
        # (per the original author's note).
        if len(daily_data.keys()) < 3:
            continue
        frames[day] = daily_data
    if not frames:
        # pd.concat raises ValueError on an empty mapping.
        for sto in table.values():
            sto.byday = pd.DataFrame(index=pd.Index([], name='日期'))
        return table
    all_data = pd.concat(frames, names=['日期', '代號'], sort=False)
    all_data = all_data.drop(columns='名稱')
    ok_sids = set(all_data.index.get_level_values('代號'))
    # Build the placeholder from the combined data rather than from the
    # first stock's byday, which is unset when that stock has no data.
    placeholder = pd.DataFrame(
        np.nan,
        index=pd.Index(list(frames), name='日期'),
        columns=all_data.columns,
    )
    for sid, sto in table.items():
        if sid in ok_sids:
            sto.byday = all_data.xs(sid, level='代號')
        else:
            sto.byday = placeholder.copy()
    return table
############################################################ | |
# technical analysis | |
import tulipy as ti | |
import cProfile | |
# Script body: load daily data for every stock listed in the most recent
# daily file, then add 20-day and 60-day simple moving averages of the
# closing price to each stock's byday frame.
day0 = '181001'
days = days_since(day0)
lastday = days[-1]
twstocks = read_csv('{}/{}.csv'.format(datapath, lastday))
all_sids = twstocks.index.values
#all_sids = ['1101', '1463', '2330', '6189', '8163']
# If the full run takes too long, switch to the small sample above to test.
# From here on twstocks maps stock ID -> stock object (it was a DataFrame).
twstocks = {sid: stock(sid, dict(twstocks.loc[sid])) for sid in all_sids}

t0 = time.time()
# Profile the bulk read; the statistics are written to techan.prof in datapath.
cProfile.runctx("fast_read_daily(twstocks, day0)", None, locals(), datapath+'/techan.prof', sort='cumtime')
print('spent {:.2f} seconds reading daily data'.format(time.time() - t0))

ta_missing = []  # IDs of stocks with no usable closing prices
for sto in twstocks.values():
    sto.need_days(days)
    try:
        close = rid_nan(sto.byday['收盤價']).values
        # rid_nan leaves NaN in place only when the series has no valid
        # value at all, so a NaN last element means there is no price data.
        if np.isnan(close[-1]):
            warn('{} {} 沒有每日收盤價'.format(sto.sid, sto.name))
            ta_missing.append(sto.sid)
            continue
        sma20 = ti.sma(close, period=20)
        sma60 = ti.sma(close, period=60)
        # Left-pad each average with period-1 NaNs before attaching it as
        # a column of byday.
        sto.byday = sto.byday.assign(
            sma20=np.insert(sma20, 0, np.full(20 - 1, np.nan)),
            sma60=np.insert(sma60, 0, np.full(60 - 1, np.nan)),
        )
    except Exception:
        # Report which stock failed, then let the error propagate.
        # Catching Exception (not a bare except) leaves KeyboardInterrupt
        # and SystemExit untouched.
        print(sto.sid, sto.name)
        raise
print('技術分析資料不足的個股:')
print(ta_missing)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment