FX Prediction with BERT in PyTorch
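Fine-tunes a pre-trained bert-base-uncased model to classify financial news headlines by whether USD/JPY is higher six hours after publication. The gist consists of five scripts: the training and evaluation script, dataloader.py, an Oanda candle downloader, a news-JSON aggregator, and the labeling and train/val/test split script.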
# 1. Import packages and define constants
import random
import math
import numpy as np
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
from dataloader import get_DataLoaders_and_TEXT
from pytorch_transformers import BertForSequenceClassification

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
max_length = 256
batch_size = 32
pre_trained_weights = 'bert-base-uncased'

# 2. Get the data loaders
train_dl, val_dl, test_dl, TEXT = get_DataLoaders_and_TEXT(
    max_length=max_length,
    batch_size=batch_size
)
dataloaders_dict = {"train": train_dl, "val": val_dl}
# 3. Load the BERT model
net = BertForSequenceClassification.from_pretrained(pre_trained_weights, num_labels=2)
net.to(device)

# Leave BERT encoder layers 1-11 frozen; train only layer 12 and the
# sequence-classification head.
# First, set requires_grad to False for every parameter
for name, param in net.named_parameters():
    param.requires_grad = False
# Unfreeze the final layer of the BERT encoder
for name, param in net.bert.encoder.layer[-1].named_parameters():
    param.requires_grad = True
# Unfreeze the classification head
for name, param in net.classifier.named_parameters():
    param.requires_grad = True
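# Optional sanity check (an assumption, not in the original gist): list the
# parameter tensors that remain trainable after freezing; only the last
# encoder layer and the classifier head should appear.
trainable = [name for name, param in net.named_parameters() if param.requires_grad]
print(len(trainable), 'trainable tensors, e.g.', trainable[:2])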
# 4. Set up the optimizer; only the unfrozen parameter groups are passed in
optimizer = optim.Adam([
    {'params': net.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
    {'params': net.classifier.parameters(), 'lr': 5e-5}], betas=(0.9, 0.999))
def train_model(net, dataloaders_dict, optimizer, num_epochs):
    net.to(device)
    torch.backends.cudnn.benchmark = True
    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()
            else:
                net.eval()
            epoch_loss = 0.0
            epoch_corrects = 0
            batch_processed_num = 0
            # Fetch mini-batches from the data loader
            for batch in dataloaders_dict[phase]:
                inputs = batch.title[0].to(device)
                labels = batch.label.to(device)
                # Reset the gradients
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    # 5. Predict with BERT, compute the loss, and run backprop
                    outputs = net(inputs, token_type_ids=None, attention_mask=None, labels=labels)
                    # loss and accuracy
                    loss, logits = outputs[:2]
                    _, preds = torch.max(logits, 1)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    curr_loss = loss.item() * inputs.size(0)
                    epoch_loss += curr_loss
                    curr_corrects = torch.sum(preds == labels.data).item() / inputs.size(0)
                    epoch_corrects += torch.sum(preds == labels.data)
                batch_processed_num += 1
                if batch_processed_num % 10 == 0:
                    print('Processed : ', batch_processed_num * batch_size, ' Loss : ', curr_loss, ' Accuracy : ', curr_corrects)
            # Loss and accuracy per epoch
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
            print('Epoch {}/{} | {:^5} | Loss:{:.4f} Acc:{:.4f}'.format(epoch + 1, num_epochs, phase, epoch_loss, epoch_acc))
    return net
# Run training
num_epochs = 3
net_trained = train_model(net, dataloaders_dict, optimizer, num_epochs=num_epochs)

# 6. Evaluate on the test data
net_trained.eval()
net_trained.to(device)
epoch_corrects = 0
for batch in test_dl:
    inputs = batch.title[0].to(device)
    labels = batch.label.to(device)
    with torch.set_grad_enabled(False):
        # Feed the inputs to BertForSequenceClassification
        outputs = net_trained(inputs, token_type_ids=None, attention_mask=None, labels=labels)
        # loss and accuracy
        loss, logits = outputs[:2]
        _, preds = torch.max(logits, 1)
        epoch_corrects += torch.sum(preds == labels.data)
epoch_acc = epoch_corrects.double() / len(test_dl.dataset)
print('Correct rate {} records : {:.4f}'.format(len(test_dl.dataset), epoch_acc))

# 7. Save the trained torch model
torch.save(net_trained.state_dict(), 'weights/bert_net_trained.model')
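The saved state dict can be restored later for inference. A minimal sketch, assuming the same pytorch_transformers version and weights file as above (not part of the original gist):

net_restored = BertForSequenceClassification.from_pretrained(pre_trained_weights, num_labels=2)
net_restored.load_state_dict(torch.load('weights/bert_net_trained.model', map_location=device))
net_restored.eval()

The get_DataLoaders_and_TEXT function imported above is defined in dataloader.py: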
import pandas as pd
import torchtext
import pickle
import string
import re
from torchtext.vocab import Vectors
from pytorch_transformers import BertTokenizer

pre_trained_weights = 'bert-base-uncased'
tokenizer_bert = BertTokenizer.from_pretrained(pre_trained_weights)

def tokenizer_with_preprocessing(text, tokenizer=tokenizer_bert.tokenize):
    # Remove line breaks
    text = re.sub('\r', '', text)
    text = re.sub('\n', '', text)
    # Map every digit to 0
    text = re.sub(r'[0-9]', '0', text)
    # Replace all punctuation except commas and periods with spaces
    for p in string.punctuation:
        if (p == '.') or (p == ","):
            continue
        else:
            text = text.replace(p, " ")
    # Pad periods and commas with surrounding spaces
    text = text.replace(".", " . ")
    text = text.replace(",", " , ")
    # Split into tokens and return
    return tokenizer(text.lower())
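# Hypothetical usage example (not in the original gist):
#   tokenizer_with_preprocessing('Dollar rises 0.5% as U.S. yields climb')
# Digits collapse to 0, punctuation other than '.' and ',' becomes whitespace,
# and the lower-cased text is split with BERT's WordPiece tokenizer.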
def get_DataLoaders_and_TEXT(max_length, batch_size):
    # Text preprocessing
    TEXT = torchtext.data.Field(sequential=True,
                                tokenize=tokenizer_with_preprocessing,
                                use_vocab=True,
                                include_lengths=True,
                                batch_first=True,
                                fix_length=max_length,
                                init_token='[CLS]',
                                eos_token='[SEP]',
                                pad_token='[PAD]',
                                unk_token='[UNK]',
                                )
    LABEL = torchtext.data.Field(sequential=False, use_vocab=False)
    # Build the datasets
    train_ds, val_ds, test_ds = torchtext.data.TabularDataset.splits(
        path='./data/',
        train='dataset_for_torchtext_train.tsv',
        validation='dataset_for_torchtext_val.tsv',
        test='dataset_for_torchtext_test.tsv',
        format='tsv',
        skip_header=True,
        fields=[('title', TEXT), ('label', LABEL)]
    )
    # Build the vocabulary.
    # Build a provisional vocab first to avoid errors, then overwrite it with BERT's vocab.
    TEXT.build_vocab(train_ds, min_freq=1)
    TEXT.vocab.stoi = tokenizer_bert.vocab
    # Build the data loaders
    train_dl = torchtext.data.Iterator(train_ds, batch_size=batch_size, train=True)
    val_dl = torchtext.data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)
    test_dl = torchtext.data.Iterator(test_ds, batch_size=batch_size, train=False, sort=False)
    return train_dl, val_dl, test_dl, TEXT
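With the loaders in hand, a quick shape check (a sketch assuming the TSV files produced by the preprocessing scripts below already exist) confirms that each batch carries (batch_size, max_length) token ids plus a label vector:

train_dl, val_dl, test_dl, TEXT = get_DataLoaders_and_TEXT(max_length=256, batch_size=32)
batch = next(iter(train_dl))
print(batch.title[0].shape)  # token ids, e.g. torch.Size([32, 256]); title is a (data, lengths) pair because include_lengths=True
print(batch.label.shape)     # one label per example

The remaining scripts prepare the data: download USD/JPY candles from Oanda, aggregate the news JSON files, and build the labeled train/val/test TSVs.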
from oandapyV20 import API
from oandapyV20.exceptions import V20Error
from oandapyV20.endpoints.pricing import PricingStream
import oandapyV20.endpoints.orders as orders
import oandapyV20.endpoints.instruments as instruments
import json
import datetime
import pandas as pd

# Replace accountID and access_token with your own credentials.
accountID = my_accountID
access_token = my_access_token
api = API(access_token=access_token, environment="practice")

# Fetch candle data from Oanda.
def getCandleDataFromOanda(instrument, api, date_from, date_to, granularity):
    params = {
        "from": date_from.isoformat(),
        "to": date_to.isoformat(),
        "granularity": granularity,
    }
    r = instruments.InstrumentsCandles(instrument=instrument, params=params)
    return api.request(r)

# Convert the Oanda response (JSON) into a Python list.
def oandaJsonToPythonList(JSONRes):
    data = []
    for res in JSONRes['candles']:
        data.append([
            datetime.datetime.fromisoformat(res['time'][:19]),
            res['volume'],
            res['mid']['o'],
            res['mid']['h'],
            res['mid']['l'],
            res['mid']['c'],
        ])
    return data

all_data = []
date_from = datetime.datetime(2017, 12, 1)
date_to = datetime.datetime(2018, 6, 1)
ret = getCandleDataFromOanda("USD_JPY", api, date_from, date_to, "H1")
all_data = oandaJsonToPythonList(ret)

# Convert to a pandas DataFrame
df = pd.DataFrame(all_data)
df.columns = ['Datetime', 'Volume', 'Open', 'High', 'Low', 'Close']
df = df.set_index('Datetime')

# Write out to a pickle file
df.to_pickle('./pickle/USD_JPY_201712_201805.pkl')
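Oanda returns the o/h/l/c mid prices as JSON strings, which is why the labeling script below casts Close to float64 before differencing. A minimal inspection sketch (not in the original gist):

df_check = pd.read_pickle('./pickle/USD_JPY_201712_201805.pkl')
print(df_check.dtypes)  # the price columns come back as object (string) dtype
print(df_check.head())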
import json
import glob
import pandas as pd

# Blogs are out of scope, so pick up only the JSON files whose names start with "news".
# Store just the fields we need in a list.
json_news_files = glob.glob('./data/*/news_*.json')
data = []
for json_file in json_news_files:
    json_data = json.load(open(json_file, 'r'))
    data.append([
        json_data['uuid'],
        pd.to_datetime(json_data['published'], utc=True),  # convert to a datetime in UTC
        json_data['language'],
        json_data['thread']['country'],
        json_data['thread']['site'],
        json_data['title'],
        json_data['text'],
    ])

# Convert to a pandas DataFrame, set the column names, and index by uuid.
df = pd.DataFrame(data)
df.columns = ['uuid', 'published', 'language', 'country', 'site', 'title', 'text']
df = df.set_index('uuid')
df.to_pickle('./pickle/all_news_df.pkl')
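The labeling script below keeps only reuters.com articles, so it can help to inspect the source distribution first. A hypothetical check (not in the original gist):

print(df['site'].value_counts().head(10))  # articles per news site
print(df.shape)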
import re
import torchtext
import pandas as pd
import numpy as np

# read text data
news_df = pd.read_pickle('./pickle/all_news_df.pkl')
# read candle data
candle_df = pd.read_pickle('./pickle/USD_JPY_201712_201805.pkl')

################## Label construction ###################
# The label says whether the price is higher 6 hours later, so take the diff
# against the close 6 candles ahead: positive if the price rose, negative if it fell.
candle_df['diff_6hours'] = -candle_df['Close'].astype('float64').diff(-6)
candle_df['label'] = 0
# Set the label to 1 when the price 6 hours later is higher
candle_df.loc[candle_df['diff_6hours'] > 0, 'label'] = 1

# Look up the label of the next candle after each news timestamp.
# Example: for a news item stamped 2017-12-04 19:35:51, the label says whether
# the close of the 2017-12-04 20:00:00 candle is higher 6 hours later.
def get_label_from_candle(x):
    tmp_idx = candle_df.loc[candle_df.index > x.tz_localize(None)].index.min()
    return candle_df.loc[tmp_idx, 'label']

# Attach a label to every news item. With this many records it takes a few minutes.
news_df['label'] = news_df['published'].map(get_label_from_candle)

# The full corpus is too large to train BERT on, so keep 30% of the Reuters news.
news_df = news_df[news_df.site == 'reuters.com']
news_df = news_df.sample(frac=0.3)

# Train / validation / test split ratios
train_size = 0.6
validation_size = 0.2
test_size = 1 - train_size - validation_size
total_num = news_df.shape[0]
train_df = news_df.iloc[:int(total_num * train_size)][['title', 'label']]
val_df = news_df.iloc[int(total_num * train_size):int(total_num * (train_size + validation_size))][['title', 'label']]
test_df = news_df.iloc[int(total_num * (train_size + validation_size)):][['title', 'label']]

# Save as tab-separated files so torchtext can load them as datasets.
# (There are probably better ways, but the reference book used CSV-style files.)
train_df.to_csv('data/dataset_for_torchtext_train.tsv', index=False, sep='\t')
val_df.to_csv('data/dataset_for_torchtext_val.tsv', index=False, sep='\t')
test_df.to_csv('data/dataset_for_torchtext_test.tsv', index=False, sep='\t')
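Since accuracy is the only metric reported during training, it is worth confirming the up/down classes are roughly balanced. A minimal sketch (an assumption, not in the original gist):

print(news_df['label'].value_counts(normalize=True))  # fraction of up (1) vs down (0) labels
print(train_df.shape[0], val_df.shape[0], test_df.shape[0])  # split sizes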