Skip to content

Instantly share code, notes, and snippets.

@sebjai
Last active April 16, 2024 11:11
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 8 You must be signed in to fork a gist
  • Save sebjai/7ab81e65abab30186e70b2ddf1b27f1a to your computer and use it in GitHub Desktop.
Save sebjai/7ab81e65abab30186e70b2ddf1b27f1a to your computer and use it in GitHub Desktop.
Algo Trading Book Price Limiter (Chapter 7.2 of Algorithmic and High-Frequency Trading by Cartea, Jaimungal, Penalva, published by Cambridge University Press)
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimal Execution Strategy with Price Limiter (Chap 7.2) Helper Function File\n",
"This file contains the PDE solver function and all plotting functionality in the main file."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Importing Packges and defining plot parameters\n",
"# Importing Packges\n",
"import time\n",
"import math\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from mpl_toolkits import mplot3d\n",
"from scipy import interpolate\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"np.random.seed(20) # Setting random seed\n",
"np.seterr(divide='ignore', invalid='ignore')\n",
"import matplotlib.pylab as pylab\n",
"params = {'legend.fontsize': 10,\n",
" 'figure.figsize': (8, 4),\n",
" 'axes.labelsize': 20,\n",
" 'axes.titlesize': 20,\n",
" 'xtick.labelsize': 15,\n",
" 'ytick.labelsize': 15}\n",
"pylab.rcParams.update(params)\n",
"font = {'family': 'serif',\n",
" 'style': 'italic',\n",
" 'weight': 1,\n",
" 'size': 16,\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following function solves the PDE for $h(t,S)$, which is used in the optimal execution strategy, using a Crank-Nicholson scheme. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def solve_pde_hjb(Fmin, Fmax, NdF, alpha, phi, k, sigma, T, Ndt):\n",
" # Initialization\n",
" dt = T/Ndt # Time change increment\n",
" dF = (Fmax - Fmin)/(NdF - 1)\n",
" h = np.full([NdF, Ndt+1], np.nan) # Solution\n",
"\n",
" h[:, -1] = alpha # Boundary/Terminal Condition\n",
" ainv = 1/k\n",
" dtanddF = 0.25*dt/(dF**2) # coefficient that shows up in the FD method\n",
"\n",
" # the M matrix in the Crank-Nicholson scheme\n",
" M = np.zeros([NdF-2, NdF-2])\n",
" M[0, 0] = 1\n",
" M[NdF-3, NdF-3] = 1 + 2*dtanddF*sigma**2\n",
" M[NdF-3, NdF-4] = -dtanddF*sigma**2\n",
"\n",
" for i in range(2, NdF-2):\n",
" i_0 = i-1\n",
" M[i_0, i_0-1] = - dtanddF*sigma**2\n",
" M[i_0, i_0] = 1 + 2*dtanddF*sigma**2\n",
" M[i_0, i_0+1] = - dtanddF*sigma**2\n",
" Minv = np.linalg.inv(M)\n",
" \n",
" \n",
" for i in range(Ndt-1, -1, -1):\n",
"\n",
" H = h[:, i+1].reshape(h.shape[0], 1)\n",
" E = np.full([NdF, 1], np.nan)\n",
" E[1:NdF-1] = H[1:NdF-1] \\\n",
" + dtanddF*(sigma**2)*(H[2:NdF] - 2*H[1:NdF-1] + H[:NdF-2]) \\\n",
" - dt*ainv*np.power(H[1:NdF-1], 2) + dt*phi\n",
" \n",
" v = E[1: NdF-1]\n",
"\n",
" v[NdF-3] = v[NdF-3] + dtanddF * sigma**2 * alpha\n",
"\n",
" h[1:NdF-1, i] = (Minv@v).reshape(NdF-2,)\n",
"\n",
" h[NdF-1, i] = alpha\n",
" h[0, i] = 2*h[1, i] - h[2, i]\n",
"\n",
" t = np.arange(0, T+10**-9, dt)\n",
" F = np.arange(Fmin, Fmax+10**-9, dF)\n",
" return h, F, t"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following function plots the rate of aquisition as a functin of time and fundamental price base on the approximated HJB equation solution."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def plot_stategy(h, t, Fgrid, F0, T, Fmin, Fmax):\n",
" \n",
" fig = plt.figure()\n",
" ax = plt.gca(projection='3d')\n",
" axes = fig.gca()\n",
" axes.set_xlim([0, T])\n",
" axes.set_ylim([F0, Fmax])\n",
" indx = Fgrid >= F0\n",
" Z_plot = h[indx, :]\n",
" x = t\n",
" y = Fgrid[indx]\n",
" X_plot, Y_plot = np.meshgrid(x, y)\n",
"\n",
" # Scale each axis for visualization\n",
" x_scale = 1\n",
" y_scale = 1\n",
" z_scale = 2/3\n",
" scale = np.diag([x_scale, y_scale, z_scale, 1.0])\n",
" scale = scale * (1.0 / scale.max())\n",
" scale[3, 3] = 1.0\n",
" \n",
" # Helper function for scaling each axis on plot\n",
" def short_proj():\n",
" return np.dot(Axes3D.get_proj(ax), scale)\n",
" ax.get_proj = short_proj\n",
"\n",
" surf = ax.plot_surface(X_plot, Y_plot, Z_plot, cmap='viridis', linewidth=0, antialiased=False)\n",
" fig.colorbar(surf, shrink=0.5, aspect=5)\n",
" ax.set_xlabel('t', fontsize=10)\n",
" ax.set_xticks([0, 0.5, 1])\n",
" ax.set_ylabel('S', fontsize=10)\n",
" ax.set_yticks([20, 20.05, 20.1])\n",
" ax.set_zlabel(r'$h(t,S)$', fontsize=10)\n",
" ax.set_zticks([0, 50, 100])\n",
" ax.tick_params(axis='both', which='major', labelsize=10)\n",
" ax.view_init(30, 225)\n",
" plt.tight_layout()\n",
" plt.subplots_adjust(left=0.001, bottom=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Plot Path for Price VS Time\n",
"def PlotPricePathMap(T, t, F, Fmin, Fmax, S, idxfig, lw):\n",
" fig_1 = plt.figure()\n",
" axes = fig_1.gca()\n",
" axes.set_xlim((0, T))\n",
" axes.set_ylim([Fmin-0.05, Fmax+0.05])\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
"\n",
" F[S==0] = np.nan\n",
" for i in range(len(idxfig)):\n",
" plt.plot(t, F[idxfig[i]], linewidth=lw, label=i+1)\n",
" plt.hlines(Fmax, 0, T, linestyles='dotted')\n",
" plt.legend(['1', '2', '3'], loc='lower right')\n",
" plt.ylabel('Fundamental Price ' + r'$(S_t)$', fontdict=font)\n",
" plt.xlabel('Time ' + r'$(T)$', fontdict=font)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Plot Path for Inventory VS Time\n",
"def PlotInvPathMap(t, Q_AC, Q, idxfig, lw):\n",
" fig_2 = plt.figure()\n",
" axes = fig_2.gca()\n",
" axes.set_xlim(left=0)\n",
" axes.set_ylim(bottom=0)\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
"\n",
" plt.plot(t, Q_AC[0, :], linestyle='dotted', linewidth=3, label='AC')\n",
" for i in range(len(idxfig)):\n",
" plt.plot(t, Q[idxfig[i], :], linewidth=lw, label=i+1)\n",
"\n",
" plt.legend()\n",
" plt.ylabel('Inventory ' +r'$(Q_t)$', fontdict=font)\n",
" plt.xlabel('Time ' + r'$(T)$', fontdict=font)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Plot Path for Trade Speed VS Time\n",
"def PlotTradeSpeedPathMap(T, t, nu_AC, nu, idxfig, lw):\n",
" fig_3 = plt.figure()\n",
" axes = fig_3.gca()\n",
" axes.set_xlim((0, T))\n",
" axes.set_ylim([0, max(nu_AC[0,:])*1.25])\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
"\n",
" plt.plot(t, nu_AC[0, :], linestyle='dotted', linewidth=3, label='AC')\n",
" for i in range(len(idxfig)):\n",
" plt.plot(t, nu[idxfig[i]], linewidth=lw, label=i+1)\n",
"\n",
" plt.ylabel('Trading Speed ' +r'$(v_t)$', fontdict=font)\n",
" plt.xlabel('Time ' + r'$(T)$', fontdict=font)\n",
" plt.legend()\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Plot Path for Cost per Share VS Time\n",
"def PlotCostPathMap(t, X, Q, F0, Fmax, idxfig, lw):\n",
" fig_4 = plt.figure()\n",
" axes = fig_4.gca()\n",
" axes.set_xlim(left=0)\n",
" axes.set_ylim([F0*0.99, Fmax])\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
" for i in range(len(idxfig)):\n",
" plt.plot(t, np.divide(X[idxfig[i]], Q[idxfig[i], :]), linewidth=lw, label=i+1)\n",
" plt.ylabel('Cost per Share', fontdict=font)\n",
" plt.xlabel('Time ' + r'$(T)$', fontdict=font)\n",
" plt.legend()\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Plot Heat Map for Inventory VS Time\n",
"def PlotInvHeatMap(t, Q, Q_AC, lower_cutoff, Nsims):\n",
" fig_5 = plt.figure()\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
"\n",
" count_mat_i = np.full([np.arange(0, 1, 0.01).shape[0], len(t)], np.nan)\n",
" for i in range(len(t)):\n",
" count_mat_i[:, i] = np.histogram(Q[:, i], bins=np.arange(0, 1.001, 0.01))[0]\n",
" x_cord_i, y_cord_i = np.meshgrid(t, np.arange(0, 1, 0.01))\n",
" count_mat_i[count_mat_i <= lower_cutoff] = 0\n",
" z = count_mat_i/Nsims\n",
" cmap = plt.get_cmap('YlOrRd')\n",
" plt.contour(x_cord_i, y_cord_i, z, 100, cmap=cmap, levels=np.linspace(z.min(), z.max(), 1000))\n",
" plt.colorbar(ticks=np.arange(0, 1, 0.1))\n",
"\n",
" plt.plot(t, Q.mean(axis=0), linewidth=2, linestyle='-', color='blue', label = 'Optimal Inv')\n",
" plt.plot(t, Q_AC.mean(axis=0), linewidth=2, linestyle='--', color='black', label = 'AC Inv')\n",
" plt.xlim(0, 1)\n",
" plt.ylim(0, 1)\n",
"\n",
" plt.ylabel('Inventory ' +r'$(Q_t)$', fontdict=font)\n",
" plt.xlabel('Time ' + r'$(T)$', fontdict=font)\n",
" plt.legend()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Plot Heat Map for Trading Speed VS Time\n",
"def PlotTradeSpeedHeatMap(t, nu, nu_AC, lower_cutoff, Nsims):\n",
" fig_6 = plt.figure()\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
"\n",
" count_mat_ts = np.full([np.arange(0, 5, 0.025).shape[0], len(t)], np.nan)\n",
"\n",
" for i in range(len(t)):\n",
" count_mat_ts[:, i] = np.histogram(nu[:, i], bins=np.arange(0, 5+0.0001, 0.025))[0]\n",
" x_cord_ts, y_cord_ts = np.meshgrid(t, np.arange(0, 5, 0.025))\n",
" count_mat_ts[count_mat_ts <= lower_cutoff] = 0\n",
" z_ts = count_mat_ts/Nsims\n",
" cmap = plt.get_cmap('YlOrRd')\n",
" plt.contour(x_cord_ts, y_cord_ts, z_ts, 100, cmap=cmap, levels=np.linspace(z_ts.min(), z_ts.max(), 1000))\n",
" plt.colorbar(ticks=np.arange(0, 1, 0.1))\n",
"\n",
" plt.plot(t, nu.mean(axis=0), '-b', linewidth=2, label = 'Optimal Average Trade Speed')\n",
" plt.plot(t, nu_AC[0, :], '--k', linewidth=2, label = 'AC Average Trade Speed')\n",
" plt.ylabel('Trading Speed ' +r'$(v_t)$', fontdict=font)\n",
" plt.xlabel('Time ' + r'$(T)$', fontdict=font)\n",
" plt.legend()\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Plot Histagram for Price\n",
"def PlotPriceHist(X, Q):\n",
" fig_6 = plt.figure()\n",
" axes = fig_6.gca()\n",
" data = np.divide(X[:, -1], Q[:, -1])\n",
" axes.set_xlim(xmin = np.percentile(data, 1), xmax = np.percentile(data, 99))\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
"\n",
" upper = np.percentile(data, 99)\n",
" lower = np.percentile(data, 1)\n",
" plt.hist(data, np.arange(lower, upper, (upper-lower)/50))\n",
" plt.tick_params(axis=\"x\", width=(upper-lower)/5)\n",
"\n",
" plt.ylabel('Frequency', fontdict=font)\n",
" plt.xlabel('Price', fontdict=font) \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Plot Histagram for Price of AC Strategy\n",
"def PlotPriceACHist(X_AC, Q_AC):\n",
" fig_7 = plt.figure()\n",
" data = np.divide(X_AC[:, -1], Q_AC[:, -1])\n",
" axes = fig_7.gca()\n",
" axes.set_xlim(xmin = np.percentile(data, 1), xmax = np.percentile(data, 99))\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
" \n",
" upper = np.percentile(data, 99)\n",
" lower = np.percentile(data, 1)\n",
" plt.hist(data, np.arange(lower, upper, (upper-lower)/50))\n",
" plt.tick_params(axis=\"x\", width=(upper-lower)/5)\n",
"\n",
" plt.ylabel('Frequency', fontdict=font)\n",
" plt.xlabel('Price_AC', fontdict=font)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Plot Histagram for Price Difference between optimal and AC Strategy\n",
"def PlotPriceDifHist(X, X_AC):\n",
" fig_9 = plt.figure()\n",
" data = np.subtract(X[:, -1], X_AC[:, -1])\n",
" axes = fig_9.gca()\n",
" axes.set_xlim(xmin = np.percentile(data, 1), xmax = np.percentile(data, 99))\n",
" \n",
" upper = np.percentile(data, 99)\n",
" lower = np.percentile(data, 1)\n",
" plt.hist(data, np.arange(lower, upper, (upper-lower)/50))\n",
" plt.tick_params(axis=\"x\", width=(upper-lower)/5)\n",
" \n",
" plt.ylabel('Frequency', fontdict=font)\n",
" plt.xlabel('Price_Diff', fontdict=font)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Plot Histagram for I (Name to be finalized)\n",
"def PlotIHist(I):\n",
" fig_11 = plt.figure()\n",
" axes = fig_11.gca()\n",
" axes.set_xlim(xmin = np.percentile(I, 1), xmax = np.percentile(I, 99))\n",
" plt.tick_params(direction='in', bottom=True, top=True, left=True, right=True)\n",
" \n",
" upper = np.percentile(I, 99)\n",
" lower = np.percentile(I, 1)\n",
" plt.hist(I, np.arange(lower, upper, (upper-lower)/50))\n",
" plt.tick_params(axis=\"x\", width=(upper-lower)/5)\n",
"\n",
" plt.ylabel('Frequency', fontdict=font)\n",
" plt.xlabel('Cost', fontdict=font)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment