Skip to content

Instantly share code, notes, and snippets.

@ckhung
Last active January 31, 2019 15:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ckhung/5024b51595477b4dc5250f69897e21d9 to your computer and use it in GitHub Desktop.
Save ckhung/5024b51595477b4dc5250f69897e21d9 to your computer and use it in GitHub Desktop.
stock analysis: (sana.py) 拿股票三維資料學習 pandas 的 multiindex
# https://newtoypia.blogspot.com/2019/01/pandas-multiindex.html
# 「拿股票三維資料學習 pandas 的 multiindex」
import re, os, time
from warnings import warn

import numpy as np
import pandas as pd
# NumPy 2.0 removed the `NaN` alias; `nan` exists in every NumPy version.
from numpy import nan as NaN
# Directory holding one CSV per trading day, named by a six-digit date stem
# (e.g. 181001.csv). expanduser('~') reads HOME on POSIX like before, but
# also works where HOME is not set.
datapath = os.path.join(os.path.expanduser('~'), 'stock', 'day')
def read_csv(path, index='代號', textcols=('代號', '名稱', 'sid', 'name'), **kwargs):
    """Read one daily-quote CSV into a DataFrame indexed by *index*.

    Columns named in *textcols* are read as strings (so stock ids such as
    '0050' are not turned into numbers). Every other column is coerced to
    numeric, and cells that cannot be parsed become NaN.

    Args:
        path: file path, or anything pandas.read_csv accepts.
        index: column to use as the index (default '代號', the stock id).
        textcols: column names to read as str. A tuple rather than a list,
            so the default is not a shared mutable object.
        **kwargs: forwarded to pandas.read_csv.

    Returns:
        DataFrame indexed by *index* and sorted by it.
    """
    df = pd.read_csv(
        path, index_col=False, comment='#', skipinitialspace=True,
        dtype={col: str for col in textcols},
        **kwargs
    )
    df.set_index(index, inplace=True)
    # A sorted index is what makes label slicing like ".loc[begin:end]" work.
    df.sort_index(inplace=True)
    numeric_cols = [col for col in df.columns if col not in textcols]
    if numeric_cols:
        df[numeric_cols] = df[numeric_cols].apply(pd.to_numeric, errors='coerce')
    return df
def rid_nan(pdseries):
    """Return a gap-free copy of *pdseries* (the input is not modified).

    A NaN in the first position is replaced by the first non-NaN value, and
    a NaN in the last position by the last non-NaN value. The interior gaps
    are then filled by linear interpolation. An all-NaN series stays all
    NaN, and an empty series is returned as an empty copy.
    """
    fixed = pdseries.copy()
    if fixed.empty:
        return fixed
    valid = pdseries.dropna()
    if not valid.empty:
        # Use positional .iloc: on a label index such as date strings, s[0]
        # is only a deprecated positional fallback, and under label
        # semantics s[0] = x would add a new label 0.
        if pd.isna(fixed.iloc[0]):
            fixed.iloc[0] = valid.iloc[0]
        if pd.isna(fixed.iloc[-1]):
            fixed.iloc[-1] = valid.iloc[-1]
    return fixed.interpolate()
class stock:
    """One listed stock: id, name and latest closing price.

    The daily history attribute ``byday`` is not created here. The loaders
    (slow_read_daily / fast_read_daily) attach it as a DataFrame indexed by
    date string.
    """

    def __init__(self, sid, sdict):
        """Build from a stock id and one CSV row given as a mapping.

        Args:
            sid: stock id.
            sdict: mapping with at least '名稱' (name) and '收盤價'
                (closing price).
        """
        self.sid = sid
        self.name = sdict['名稱']
        self.price = sdict['收盤價']

    def need_days(self, days):
        """Ensure ``byday`` has a row for every label in *days*, sorted by date.

        Existing rows are kept, and days that were absent become all-NaN rows.
        ``byday`` must already be set.
        """
        # Reindex with an explicitly ordered list rather than an unordered set.
        # A plain list also lets pandas keep the existing index name.
        wanted = sorted(set(self.byday.index).union(days))
        self.byday = self.byday.reindex(wanted)
        self.byday.sort_index(inplace=True)
def days_since(since, datapath=datapath):
    """Return the sorted date stems of daily CSV files dated on or after *since*.

    Only file names made of exactly six digits plus '.csv' (e.g. '181001.csv')
    count. Names are compared as strings, so *since* should be a six-digit
    stem in the same format.

    Args:
        since: earliest stem to include, e.g. '181001'.
        datapath: directory to scan.

    Returns:
        Ascending list of stems (file names without '.csv').
    """
    # List the directory directly instead of os.chdir-ing into it: chdir
    # changes the whole process's working directory.
    daily_csv = re.compile(r'\d{6}\.csv')
    threshold = since + '.csv'
    csvfiles = sorted(
        fn for fn in os.listdir(datapath)
        # fullmatch, so that names such as '181001.csv.bak' are rejected.
        if daily_csv.fullmatch(fn) and fn >= threshold
    )
    return [fn[:-4] for fn in csvfiles]
def slow_read_daily(table, since, datapath=datapath):
    """Attach a per-stock daily history to every entry of *table*, row by row.

    This is the row-at-a-time loader; fast_read_daily is the concat-based
    alternative. For each usable daily CSV on or after *since*, every row
    whose stock id is a key of *table* is collected. Afterwards each
    ``table[sid].byday`` becomes a DataFrame indexed by the day string
    ('日期'), with the name column ('名稱') dropped. A stock with no rows
    gets an empty frame instead of being left without ``byday``.

    Args:
        table: dict mapping stock id -> object that receives ``byday``.
        since: six-digit date stem; earlier files are skipped.
        datapath: directory containing the daily CSV files.

    Returns:
        The same *table*, mutated in place.
    """
    rows = {sid: [] for sid in table}
    columns = ['日期', '名稱']  # header for stocks that never appear
    for day in days_since(since, datapath=datapath):
        daily_data = read_csv('{}/{}.csv'.format(datapath, day))
        if len(daily_data.keys()) < 3:  # market closed that day (股市休市日)
            continue
        daily_data['日期'] = day
        columns = list(daily_data.columns)
        for sid, row in daily_data.iterrows():
            if sid in rows:
                rows[sid].append(row)
    for sid, sid_rows in rows.items():
        # DataFrame.append was removed in pandas 2.0 (and was quadratic), so
        # build each frame once from the collected rows. infer_objects
        # restores numeric dtypes lost when mixed-type rows were stacked.
        byday = pd.DataFrame(sid_rows) if sid_rows else pd.DataFrame(columns=columns)
        table[sid].byday = byday.infer_objects().set_index('日期').drop(columns='名稱')
    return table
def fast_read_daily(table, since, datapath=datapath):
    """Attach a per-stock daily history to every entry of *table* via one concat.

    All usable daily CSVs on or after *since* are stacked into one frame with
    a ('日期', '代號') MultiIndex (day, stock id), and the name column '名稱'
    is dropped. Each ``table[sid].byday`` is then the cross-section for that
    stock, indexed by day. A stock that never appears gets an all-NaN frame
    over every loaded day with the same columns.

    Args:
        table: dict mapping stock id -> object that receives ``byday``.
        since: six-digit date stem; earlier files are skipped.
        datapath: directory containing the daily CSV files.

    Returns:
        The same *table*, mutated in place.

    Raises:
        ValueError: if no usable daily CSV exists on or after *since*.
    """
    frames = {}
    for day in days_since(since, datapath=datapath):
        daily_data = read_csv('{}/{}.csv'.format(datapath, day))
        if len(daily_data.keys()) < 3:  # market closed that day (股市休市日)
            continue
        frames[day] = daily_data
    if not frames:
        # Otherwise pd.concat fails with an opaque "No objects to concatenate".
        raise ValueError('no daily data in {} since {}'.format(datapath, since))
    all_data = pd.concat(frames, names=['日期', '代號'], sort=False)
    all_data.drop(columns='名稱', inplace=True)
    ok_sids = set(all_data.index.get_level_values('代號'))
    trading_days = all_data.index.get_level_values('日期').unique()
    for sid in table:
        if sid in ok_sids:
            table[sid].byday = all_data.xs(sid, level='代號')
        else:
            # Build the placeholder from the loaded data itself rather than
            # from another stock's byday, which may not be set yet.
            table[sid].byday = pd.DataFrame(
                np.nan, index=trading_days, columns=all_data.columns)
    return table
############################################################
# technical analysis
import tulipy as ti
import cProfile
day0 = '181001'
days = days_since(day0)
lastday = days[-1]
# The most recent daily file supplies the list of stocks to analyse.
twstocks = read_csv('{}/{}.csv'.format(datapath, lastday))
all_sids = twstocks.index.values
#all_sids = ['1101', '1463', '2330', '6189', '8163']
# If the full run takes too long, uncomment the line above to test on a few stocks first.
twstocks = { sid: stock(sid, dict(twstocks.loc[sid])) for sid in all_sids }
t0 = time.perf_counter()
# Profile the bulk load and save the stats to techan.prof. globals() is passed
# explicitly; with None, exec would resolve globals inside the cProfile module.
cProfile.runctx("fast_read_daily(twstocks, day0)", globals(), locals(), datapath+'/techan.prof', sort='cumtime')
print('spent {:.2f} seconds reading daily data'.format(time.perf_counter() - t0))
# Simple-moving-average windows in rows (days); each adds a column 'sma<N>'.
sma_periods = (20, 60)
ta_missing = []
for sto in twstocks.values():
    sto.need_days(days)
    try:
        # Gap-filled closing prices as a float64 ndarray for ti.sma.
        close = np.asarray(rid_nan(sto.byday['收盤價']), dtype=float)
        # After rid_nan, a trailing NaN means there is no closing price at all.
        if np.isnan(close[-1]):
            warn('{} {} 沒有每日收盤價'.format(sto.sid, sto.name))
            ta_missing.append(sto.sid)
            continue
        if len(close) < max(sma_periods):
            warn('{} {} 收盤價資料不足 {} 天'.format(sto.sid, sto.name, max(sma_periods)))
            ta_missing.append(sto.sid)
            continue
        # ti.sma yields period-1 fewer values than its input; left-pad with
        # NaN so each new column lines up with the daily rows.
        sto.byday = sto.byday.assign(**{
            'sma{}'.format(p): np.concatenate((np.full(p - 1, np.nan), ti.sma(close, period=p)))
            for p in sma_periods
        })
    except Exception:
        # Identify the offending stock, then let the error propagate.
        print(sto.sid, sto.name)
        raise
print('技術分析資料不足的個股:')
print(ta_missing)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment