Last active
August 10, 2017 14:36
-
-
Save regonn/8d2f041ea56132695025004d4af24229 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'bigdecimal'
require 'json'
require 'httpclient'
require 'jiji/model/agents/agent'
# Endpoint of the TensorFlow estimator service (not referenced by the code shown here).
TENSORFLOW_API_URL = 'http://tensorflow:5000/api/estimator'.freeze
# Order size, in units, passed to broker.buy / broker.sell for every trade.
CURRENCY_UNIT = 10_000
# Jiji agent that, once per clock hour outside market holidays, lets every
# tracked Currency update its indicator signals and open a new position.
# The selected mode decides what happens when positions close.
class KerasAgent
  include Jiji::Model::Agents::Agent

  # Agent description shown by Jiji (runtime string; Japanese for "agent using Keras").
  def self.description
    "Kerasを使ったエージェント"
  end

  # Configurable properties. exec_mode ("operation mode") is "collect" (default)
  # or "trade".
  def self.property_infos
    [Property.new('exec_mode', '動作モード("collect" or "trade")', "collect")]
  end

  # Called after creation: builds the mode object and the tracked currencies.
  def post_create
    @mode = create_mode(@exec_mode)
    @currencies = [Currency.new(:USDJPY, broker, @mode)]
  end

  # Handles one tick. At most one tick per calendar hour is processed, and
  # ticks during market holidays are skipped.
  #
  # @param tick tick supplied by Jiji (must respond to #timestamp and #[])
  def next_tick(tick)
    timestamp = tick.timestamp
    return if already_check?(timestamp)

    @current_timestamp = timestamp
    @currencies.each do |currency|
      currency.next_tick(tick)
      currency.do_trade
    end
  end

  # Returns true when +timestamp+ should be skipped: it falls in a market
  # holiday, or in the same calendar hour as the last processed tick.
  # Year and month are compared as well as day and hour, so a tick exactly
  # one month later is not mistaken for the hour already handled.
  def already_check?(timestamp)
    return true if is_market_holiday?(timestamp)
    return false unless @current_timestamp

    same_hour?(@current_timestamp, timestamp)
  end

  # True on Sunday, on Monday up to 06:59, and on Saturday from 07:00 onwards.
  # NOTE(review): the hour boundaries presumably assume JST timestamps — confirm.
  def is_market_holiday?(timestamp)
    timestamp.wday == 0 ||
      (timestamp.wday == 1 && timestamp.hour <= 6) ||
      (timestamp.wday == 6 && timestamp.hour > 6)
  end

  # Builds the mode object for +mode+.
  # For now only CollectMode is implemented, so +mode+ is ignored and
  # "trade" also yields a CollectMode.
  def create_mode(mode)
    CollectMode.new
  end

  # Data-collection mode: always allows trading, and persists the signals
  # recorded at entry together with the result of every closed position.
  class CollectMode
    # Always permits the trade.
    def do_trade?(signal, sell_or_buy)
      true
    end

    # Saves a TradeAndSignals record for the closed +position+.
    def after_position_closed(signal, position)
      TradeAndSignals.create_from(signal, position).save
    end
  end

  private

  # True when both times fall in the same year, month, day and hour.
  def same_hour?(a, b)
    a.year == b.year && a.month == b.month && a.mday == b.mday && a.hour == b.hour
  end
end
# Mongoid document storing, for one closed position, the signal values that
# were current when it was opened, together with the position's outcome.
class TradeAndSignals
  include Mongoid::Document
  store_in collection: 'tensorflow_example_trade_and_signals'

  # Signal keys copied by .create_from. Any other key in the signal hash
  # (e.g. :ma5, :ma10) has no field here and is ignored.
  SIGNAL_FIELDS = %i[
    macd_difference rsi
    slope_10 slope_25 slope_50
    ma_10_estrangement ma_25_estrangement ma_50_estrangement
  ].freeze

  field :macd_difference, type: Float # macd - macd_signal
  field :rsi, type: Float
  field :slope_10, type: Float # slope of the 10-period moving average
  field :slope_25, type: Float # slope of the 25-period moving average
  field :slope_50, type: Float # slope of the 50-period moving average
  field :ma_10_estrangement, type: Float # deviation (%) of price from the 10-period moving average
  field :ma_25_estrangement, type: Float
  field :ma_50_estrangement, type: Float
  field :profit_or_loss, type: Float
  field :sell_or_buy, type: Symbol
  field :entered_at, type: Time
  field :exited_at, type: Time

  # Builds (does not save) a document from a signal hash and a closed position.
  #
  # @param signal_data [Hash] signal name => value, as built by Currency#calculate_signals
  # @param position closed position responding to profit_or_loss, sell_or_buy,
  #   entered_at and exited_at
  # @return [TradeAndSignals] unsaved document
  def self.create_from(signal_data, position)
    new do |ts|
      signal_data.each do |key, value|
        # Assign only declared signal fields. An unexpected key is skipped
        # instead of raising NoMethodError while a position is being closed.
        next unless SIGNAL_FIELDS.include?(key.to_sym)

        ts.public_send("#{key}=", value)
      end
      ts.profit_or_loss = position.profit_or_loss
      ts.sell_or_buy = position.sell_or_buy
      ts.entered_at = position.entered_at
      ts.exited_at = position.exited_at
    end
  end
end
# Tracks one currency pair. It keeps technical-indicator state, opens a
# position in a random direction on each #do_trade (when the mode allows it),
# and reports every closed position to the mode.
class Currency
  # Lookback window, in seconds, of daily rates used to warm up the indicators (60 days).
  RATE_HISTORY_SECONDS = 60 * 60 * 24 * 60

  # @param currency_pair [Symbol] pair to trade, e.g. :USDJPY
  # @param broker Jiji broker used to fetch rates and place orders
  # @param mode mode object responding to do_trade? and after_position_closed
  def initialize(currency_pair, broker, mode)
    @currency_pair = currency_pair
    @broker = broker
    @mode = mode
  end

  # Feeds +tick+ into the indicators. On the first call the indicators are
  # first created and warmed up from historical rates.
  def next_tick(tick)
    prepare_signals(tick) unless @macd
    @current_signals = calculate_signals(tick[@currency_pair])
  end

  # Picks :buy or :sell at random and opens that position if the mode
  # permits it for the current signals.
  def do_trade
    side = %i[buy sell].sample
    public_send(side) if @mode.do_trade?(@current_signals, side)
  end

  # Closes any open position, then buys CURRENCY_UNIT units.
  def buy
    open_position(:buy)
  end

  # Closes any open position, then sells CURRENCY_UNIT units.
  def sell
    open_position(:sell)
  end

  # Closes the current position, if any, and hands it to the mode together
  # with the signals recorded when it was opened.
  def close_exist_positions
    return unless @current_position

    @current_position.close
    @mode.after_position_closed(@current_hold_signals, @current_position)
    @current_position = nil
    @current_hold_signals = nil
  end

  # Pushes the bid price of +tick+ into every indicator and returns the
  # resulting signal hash. Derived entries are nil while the underlying
  # MACD / moving average still returns nil (not yet warmed up).
  def calculate_signals(tick)
    price = tick.bid
    macd = @macd.next_data(price)
    ma10 = @ma10.next_data(price)
    ma25 = @ma25.next_data(price)
    ma50 = @ma50.next_data(price)
    {
      ma5: @ma5.next_data(price),
      ma10: ma10,
      macd_difference: macd ? macd[:macd] - macd[:signal] : nil,
      rsi: @rsi.next_data(price),
      slope_10: ma10 ? @ma10v.next_data(ma10) : nil,
      slope_25: ma25 ? @ma25v.next_data(ma25) : nil,
      slope_50: ma50 ? @ma50v.next_data(ma50) : nil,
      ma_10_estrangement: ma10 ? calculate_estrangement(price, ma10) : nil,
      ma_25_estrangement: ma25 ? calculate_estrangement(price, ma25) : nil,
      ma_50_estrangement: ma50 ? calculate_estrangement(price, ma50) : nil
    }
  end

  # Creates the indicators and feeds them the closing values of recent daily rates.
  def prepare_signals(tick)
    create_signals
    retrieve_rates(tick.timestamp).each do |rate|
      calculate_signals(rate.close)
    end
  end

  # Instantiates all indicator objects.
  def create_signals
    @macd = Signals::MACD.new
    @ma5 = Signals::MovingAverage.new(5)
    @ma10 = Signals::MovingAverage.new(10)
    @ma25 = Signals::MovingAverage.new(25)
    @ma50 = Signals::MovingAverage.new(50)
    @ma5v = Signals::Vector.new(5)
    @ma10v = Signals::Vector.new(10)
    @ma25v = Signals::Vector.new(25)
    @ma50v = Signals::Vector.new(50)
    @rsi = Signals::RSI.new(9)
  end

  # Daily rates for the pair covering RATE_HISTORY_SECONDS before +time+.
  def retrieve_rates(time)
    @broker.retrieve_rates(@currency_pair, :one_day, time - RATE_HISTORY_SECONDS, time)
  end

  # Percentage deviation of +price+ from the moving average +ma+:
  # (price - ma) / ma * 100, returned as a Float.
  # Kernel#BigDecimal is used because BigDecimal.new was removed in
  # bigdecimal 2.0 (Ruby 2.7+). Converting a Float requires an explicit
  # precision, here 10 significant digits.
  def calculate_estrangement(price, ma)
    ((BigDecimal(price, 10) - ma) / ma * 100).to_f
  end

  private

  # Shared body of #buy / #sell: closes the current position, places an
  # order via broker.buy or broker.sell (+side+), then remembers the new
  # position and the signals current at entry.
  def open_position(side)
    close_exist_positions
    result = @broker.public_send(side, @currency_pair, CURRENCY_UNIT)
    @current_position = @broker.positions[result.trade_opened.internal_id]
    @current_hold_signals = @current_signals
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment