Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Python simple persistent database with dataclass, JSON and Redis.
"""A simple trade execution report database by using Python dataclass, Redis and JSON.
- A very simple database for arbitrage trade execution reports.
- Dataclasses are serialised and deserialised to JSON that is stored in the Redis.
- There is type validation for members and class-types.
- Class-types like Decimal are converted back to their original format upon deserialisation.
- Optional members are supported and the member presence is validated.
- Past trades can be iterated in either order by creation.
- A simple CSV exported is provided.
Ps. I checked couple of existing dataclass validator packages on PyPi and did not find anything simple and suited
for the purpose, especially when nested items are not needed. (They cannot be needed, data goes to a spreadsheet
in the end).
"""
import datetime
import json
import socket
import time
import typing
from decimal import Decimal
from io import TextIOBase
from typing import Optional, List
from dataclasses import dataclass, asdict
from redis import StrictRedis
from dataclass_csv import DataclassWriter
@dataclass
class ArbTradeExecution:
"""Define arbitrage trade related data for later analysis."""
# Internal ticker symbols for the pair we are trading
asset_1: str # Always crypto
asset_2: str # Always fiat
# Exchange names this trade was performed on
exchange_1: str
exchange_2: str
version: str # Strategy version that procuced this trade
order_id: str # Trade id
session_name: str # Trading session name
server_name: str # Which server runs the strategy
# Are we in testing mode
unit_test: bool
#
# Trade ordering related
#
session_started_at: float # How long the bot has been running
started_at: float # Wall clock when the order was recorded
clock: float # Framework clock signal when the order was started
# Which way was this trade, e.g. buy_2_sell_1
kind: str
# How much was the original target execution quantity in asset_1
quantity_1: Decimal
quantity_2: Decimal
# Market state when the decision of a trade was made
market_1_bid_price: Decimal
market_2_bid_price: Decimal
market_1_ask_price: Decimal
market_2_ask_price: Decimal
profitability_buy_1_sell_2: float
profitability_buy_2_sell_1: float
expected_profitability: float
# Exchange 1 execution information.
# Note this is always present, as if there is no exchange 1 trade there is no
# trade at all in our data model.
exchange_1_id: str # Exchange 1 internal id for the order
executed_quantity_1: Decimal
executed_price_1: Decimal
# Was this trade closed successful. Timestamps of closing.
closed_at: Optional[float] = None
succeed_at: Optional[float] = None # Set if both trades are done
failed_at: Optional[float] = None # Set if only the first trade is done
# How fast we were in seconds
order_1_latency: Optional[float] = None
order_2_latency: Optional[float] = None
# Raw exchange execution response results
order_1_data: Optional[str] = None
order_2_data: Optional[str] = None
# How good counter fill we got on the counter exchange
executed_price_2: Optional[Decimal] = None
executed_quantity_2: Optional[Decimal] = None
exchange_2_id: Optional[str] = None # Exchange 2 internal id for the order
# Partial execution dust handling
# For failed trades
dust_gathered: Optional[Decimal] = None # For poorly executioned trades
dust_cleared: Optional[Decimal] = None # For completed trades
# Fees
exchange_1_fees_paid: Optional[Decimal] = None
exchange_2_fees_paid: Optional[Decimal] = None
# How good was our performance
realised_arbitration: Optional[Decimal] = None # In fiat
realised_profitability: Optional[float] = None
realised_profitability_with_fees: Optional[float] = None
# FOK order management
exchange_2_accepted_slippage: Optional[int] = None # In BPS
exchange_2_fok_attempts: Optional[int] = None # How many round of attempts we did to get a fill
# Asset balances at the end of the trade.
# This allows following the balance development and easily
# pop up the last trade to read exchange balances.
exchange_1_balance_1: Optional[Decimal] = None
exchange_1_balance_2: Optional[Decimal] = None
exchange_2_balance_1: Optional[Decimal] = None
exchange_2_balance_2: Optional[Decimal] = None
# Internal json encoder that handles the Decimal instances
# https://stackoverflow.com/a/3885198/315168
class _Encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Decimal):
return str(o)
return super(ArbTradeExecution._Encoder, self).default(o)
@property
def crypto(self) -> str:
"""Get crypto asset name of this trading paid."""
return self.asset_1
@property
def fiat(self) -> str:
"""Get fiat asset name of this trading paid."""
return self.asset_2
@property
def quantity_crypto(self) -> Decimal:
return self.quantity_1
@property
def quantity_fiat(self) -> Decimal:
return self.quantity_2
def validate(self):
"""Validate instance contents against the Python type hints."""
for name in self.__dataclass_fields__:
optional, expected_type = self.get_member_info(name)
value = self.__dict__.get(name)
if value is None and optional:
# Optional member was not filled in
continue
actual_type = type(value)
if expected_type != actual_type:
raise ValueError(f"Field {name} expected type {expected_type}, got {actual_type}")
def to_json(self) -> str:
return json.dumps(asdict(self), cls=ArbTradeExecution._Encoder)
@classmethod
def get_member_info(cls, name) -> typing.Tuple[bool, Optional[type]]:
"""Introspect what type we expect for a dataclass member.
:return: (optional, type)
"""
member = cls.__dataclass_fields__.get(name)
if member is None:
# Unknown member
return False, None
expected_type = member.type
# https://stackoverflow.com/a/66117226/315168
optional = False
expected_type_internal = expected_type
if typing.get_origin(expected_type) == typing.Union:
# Handle Optional typing
# Internally Python expresses Optional[float] as Typing.Union[float, NoneType]
args = typing.get_args(expected_type)
if len(args) == 2 and args[1] == type(None):
optional = True
expected_type_internal = args[0]
return optional, expected_type_internal
@classmethod
def from_json(cls, text: str) -> "ArbTradeExecution":
"""Load JSON blob back to Python data and convert all memberes to objects."""
data = json.loads(text)
prepared_data = {}
for name in data:
if name in data:
# Convert back to decimals
optional, member_type = cls.get_member_info(name)
if member_type is None:
raise ValueError(f"JSON had an unknown member: {name}")
value = data[name]
if optional and value is None:
prepared_data[name] = None
else:
prepared_data[name] = member_type(value)
report = ArbTradeExecution(**prepared_data)
report.validate()
return report
class ExecutionReportManager:
"""Persistently store trade data in Redis.
Uses Redis storage, where data is partitioned by server name and order id using HSETs.
We use hash sets (HSET), one per server. This allows us move and manipulate per-server
data more easily.
"""
#: Hash map for the reports
HKEY_PREFIX = "execution_report"
#: Time index for the reports
ZKEY_PREFIX = "execution_report_sorted"
def __init__(self, redis: StrictRedis):
self.redis = redis
def create_execution_report(self, report: ArbTradeExecution) -> str:
"""Store an execution report in the database.
Simple append only data structure - overwrites especially unwanted.
"""
server_name = report.server_name
assert server_name
hkey = f"{self.HKEY_PREFIX}:{server_name}"
zkey = f"{self.ZKEY_PREFIX}:{server_name}"
key = report.order_id
data = report.to_json()
if self.redis.hexists(hkey, key):
# Check the existince of the report
# to make problematic code bark out loud
raise RuntimeError(f"Execution report {key} has already been written to the databased")
# Store the actual report content
# TODO: We really do not care about atomicity guarantees here,
# as we assume only one writer
self.redis.hset(hkey, key, data)
# Manage sorted index of keys
# https://redis.io/commands/ZADD
self.redis.zadd(zkey, {key: report.clock})
return key
def load_execution_report(self, server_name: str, look_up_key: str) -> ArbTradeExecution:
hkey = f"{self.HKEY_PREFIX}:{server_name}"
text = self.redis.hget(hkey, look_up_key)
return ArbTradeExecution.from_json(text)
def load_execution_reports_by_time(self, server_name, start, end, desc=False) -> typing.Iterable[ArbTradeExecution]:
"""Allow iterating through all execution reports in the database in order they were created.
You can itereate either oldest first or newest first.
The trades are sorted by clock.
"""
zkey = f"{self.ZKEY_PREFIX}:{server_name}"
# Load a span of reports based by the time index
keys = self.redis.zrange(zkey, start, end, desc=desc)
for key in keys:
# Load individual reports
yield self.load_execution_report(server_name, key)
def get_first_trade(self, server_name: str) -> typing.Optional[ArbTradeExecution]:
"""Convenience method to the peek the first executed trade."""
for report in self.load_execution_reports_by_time(server_name, start=0, end=1, desc=False):
return report
return None
def get_last_trade(self, server_name: str) -> typing.Optional[ArbTradeExecution]:
"""Convenience method to the peek the last executed trade."""
for report in self.load_execution_reports_by_time(server_name, start=0, end=1, desc=True):
return report
return None
def export_all_data(self) -> List[ArbTradeExecution]:
"""Export all servers and all execution reports."""
entries = []
# All servers
for hkey in self.redis.keys(pattern=f"{self.HKEY_PREFIX}:*"):
# All trades for a server
for order_id in self.redis.hgetall(hkey):
text = self.redis.hget(hkey, order_id)
entries.append(ArbTradeExecution.from_json(text))
return entries
@staticmethod
def get_server_name():
return socket.gethostname()
def create_trade_session_name(exchange_1: str, exchange_2: str, asset_1: str, asset_2: str, started_at: float):
"""Add human readable name for trading sessions."""
human_date = datetime.datetime.fromtimestamp(int(started_at)).strftime('%Y-%m-%d %H:%M:%S')
return f"Arb session {exchange_1}-{exchange_2} {asset_1}-{asset_2} started at {human_date}"
def export_csv(stream: TextIOBase, trades: List[ArbTradeExecution]):
"""Create a CSV export of all trades."""
# Sort trades by start time
trades = sorted(trades, key=lambda t: t.clock)
# https://pypi.org/project/dataclass-csv/
w = DataclassWriter(stream, trades, ArbTradeExecution)
w.write()
def generate_export(manager: ExecutionReportManager, stream: TextIOBase) -> typing.Tuple[Optional[str], Optional[str]]:
"""Export CSV data to given Python file-like stream.
Useful e.g. to generate Telegram document upload.
:return: Tuple (fname, caption)
"""
# Generate in-memory CSV
trades = manager.export_all_data()
if not trades:
# Nothing to export
return None, None
export_csv(stream, trades)
first_date = datetime.datetime.fromtimestamp(int(trades[0].clock)).strftime('%Y-%m-%d')
last_date = datetime.datetime.fromtimestamp(int(trades[-1].clock)).strftime('%Y-%m-%d')
caption = f"Export of {len(trades)} trades in {first_date} - {last_date}"
# Send file to Telegram chat via multipart/form-data
human_date = datetime.datetime.fromtimestamp(int(time.time())).strftime('%Y-%m-%d')
fname = f"trade-export-{human_date}.csv"
return fname, caption
#!/usr/bin/env python
import asyncio
import json
import time
import uuid
from decimal import Decimal
from io import StringIO
from os.path import join, realpath
from nose.plugins.attrib import attr
import logging
import unittest
@attr('stable')
class ExecutionReportUnitTest(unittest.TestCase):
def setUp(self):
self.tredis = db.get_client()
self.redis = self.tredis.redis
self.redis.flushdb()
def get_server_name(self):
return "unit-test"
def create_sample_data(self, order_id, full=False, clock=0.0) -> dict:
server_name = self.get_server_name()
exchange_1 = "Test1"
exchange_2 = "Test2"
asset_1 = "BTC"
asset_2 = "GBP"
session_started_at = time.time()
session_name = create_trade_session_name(exchange_1, exchange_2, asset_1, asset_2, session_started_at)
order_data = {
"order_id": order_id,
"version": "0.0",
"session_name": session_name,
"server_name": server_name,
"session_started_at": session_started_at,
"started_at": time.time(),
"clock": clock,
"kind": "buy_2_sell_1",
"quantity_1": Decimal(1),
"quantity_2": Decimal(1),
"market_1_bid_price": Decimal(1),
"market_2_bid_price": Decimal(1),
"market_1_ask_price": Decimal(1),
"market_2_ask_price": Decimal(1),
"expected_profitability": 1.0,
"unit_test": True,
"profitability_buy_1_sell_2": 1.0,
"profitability_buy_2_sell_1": 1.0,
"executed_quantity_1": Decimal(1),
"executed_price_1": Decimal(1),
"exchange_1_id": "aaa",
"asset_1": asset_1,
"asset_2": asset_2,
"exchange_1": exchange_1,
"exchange_2": exchange_2,
}
if full:
order_data.update({
"exchange_2_id": "bbb",
"closed_at": time.time(),
"succeed_at": time.time(),
"order_1_latency": 100.0,
"order_2_latency": 100.0,
"executed_price_2": Decimal(1),
"executed_quantity_2": Decimal(1),
"dust_gathered": Decimal(0),
"exchange_1_fees_paid": Decimal("1.00"),
"exchange_2_fees_paid": Decimal("2.00"),
"realised_profitability": 0.03,
"realised_profitability_with_fees": 0.0175,
"exchange_2_accepted_slippage": 3,
"exchange_2_fok_attempts": 3,
})
return order_data
def test_export_import_roundtrip(self):
"""Serialise and deserialise JSON data"""
data = self.create_sample_data(str(uuid.uuid4()))
report = ArbTradeExecution(**data)
report.validate()
json_export = report.to_json()
json.loads(json_export) # JSON parses correctly
report_deserialised = ArbTradeExecution.from_json(json_export)
# Test a single member
assert report_deserialised.market_1_bid_price == report.market_1_bid_price
def test_import_bad_member(self):
"""Import bad member in JSON"""
data = self.create_sample_data(str(uuid.uuid4()))
report = ArbTradeExecution(**data)
json_export = report.to_json()
bad_data = json.loads(json_export)
bad_data["foo"] = "bar"
bad_json = json.dumps(bad_data)
with self.assertRaises(ValueError):
ArbTradeExecution.from_json(bad_json)
def test_store_bad_decimal(self):
"""Storing float is not allowed when Decimal is required"""
data = self.create_sample_data(str(uuid.uuid4()))
data["market_1_bid_price"] = 1.0
report = ArbTradeExecution(**data)
with self.assertRaises(ValueError):
report.validate()
def test_report_partial_trade(self):
"""See we get a stream of incoming web socket messages."""
order_id = str(uuid.uuid4())
order_data = self.create_sample_data(order_id)
report = ArbTradeExecution(**order_data)
report.validate()
server_name = order_data["server_name"]
order_id = order_data["order_id"]
# Do a round trip in database
reporter = ExecutionReportManager(self.redis)
persistent_id = reporter.create_execution_report(report)
execution_report_deserialised = reporter.load_execution_report(server_name, persistent_id)
assert execution_report_deserialised.order_id == order_id
def test_report_full_trade(self):
"""See we get a stream of incoming web socket messages."""
order_id = str(uuid.uuid4())
order_data = self.create_sample_data(order_id, full=True)
report = ArbTradeExecution(**order_data)
report.validate()
server_name = order_data["server_name"]
order_id = order_data["order_id"]
# Do a round trip in database
reporter = ExecutionReportManager(self.redis)
persistent_id = reporter.create_execution_report(report)
execution_report_deserialised = reporter.load_execution_report(server_name, persistent_id)
assert execution_report_deserialised.exchange_2_accepted_slippage == 3
def test_load_by_time(self):
"""See our time based indexing works."""
server_name = self.get_server_name()
trades = [
ArbTradeExecution(**self.create_sample_data('x', clock=1)),
ArbTradeExecution(**self.create_sample_data('y', clock=2)),
]
reporter = ExecutionReportManager(self.redis)
for t in trades:
reporter.create_execution_report(t)
# Load in ascending order
trades = reporter.load_execution_reports_by_time(server_name, 0, 9999, desc=False)
trades = list(trades)
assert len(trades) == 2
assert trades[0].clock == 1
assert trades[1].clock == 2
# Load in descending order
trades = reporter.load_execution_reports_by_time(server_name, 0, 9999, desc=True)
trades = list(trades)
assert len(trades) == 2
assert trades[0].clock == 2
assert trades[1].clock == 1
def test_last_trade(self):
"""We can load the last trade ordered by clock."""
server_name = self.get_server_name()
reporter = ExecutionReportManager(self.redis)
last_trade = reporter.get_last_trade(server_name)
assert last_trade is None
trades = [
ArbTradeExecution(**self.create_sample_data('x', clock=1)),
ArbTradeExecution(**self.create_sample_data('z', clock=3)),
ArbTradeExecution(**self.create_sample_data('y', clock=2)),
]
for t in trades:
reporter.create_execution_report(t)
last_trade = reporter.get_last_trade(server_name)
assert last_trade.order_id == "z"
def test_export_csv(self):
"""Exporting full and partial orders to CSV works."""
trade_list = [
ArbTradeExecution(**self.create_sample_data(str(uuid.uuid4()), full=False)),
ArbTradeExecution(**self.create_sample_data(str(uuid.uuid4()), full=True)),
]
stream = StringIO()
export_csv(stream, trade_list)
data = stream.getvalue()
assert trade_list[0].order_id in data
assert trade_list[1].order_id in data
def test_generate_export(self):
"""Check we generate export data correctly.."""
reporter = ExecutionReportManager(self.redis)
trades = [
ArbTradeExecution(**self.create_sample_data('x', clock=1)),
ArbTradeExecution(**self.create_sample_data('z', clock=3)),
ArbTradeExecution(**self.create_sample_data('y', clock=2)),
]
for t in trades:
reporter.create_execution_report(t)
stream = StringIO()
fname, caption = generate_export(reporter, stream)
assert fname
assert caption
stream.seek(0)
assert len(stream.read()) > 500
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment