Last active
August 30, 2021 13:06
-
-
Save lewapek/6e98f034d92e05e9182515e9a898f20b to your computer and use it in GitHub Desktop.
Arbitrage solution (Ammonite)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Tested on ammonite with Scala 2.13.1 | |
// import $ivy.`io.estatico::newtype:0.4.4` | |
import $ivy.`io.monix::monix:3.3.0` | |
import $ivy.`io.circe::circe-parser:0.13.0` | |
import cats.Show | |
import cats.instances.double._ | |
import cats.instances.int._ | |
import cats.syntax.either._ | |
import cats.syntax.option._ | |
import cats.syntax.show._ | |
import io.circe.parser.decode | |
// import io.estatico.newtype.macros.newtype | |
import monix.eval.Task | |
import monix.execution.Scheduler | |
import Model._ | |
import scala.annotation.tailrec | |
import scala.language.{implicitConversions, postfixOps} | |
import scala.util.Try | |
// How it works: | |
// Input data forms a graph (vertices = currencies, edges = exchange rates) | |
// Arbitrage is when rate(Vs, V1) * rate(V1, V2) * rate(V2, ?) * ... * rate(?, Vs) > 1 // s = source currency | |
// After applying a trick: | |
// log(rate(Vs, V1) * rate(V1, V2) * rate(V2, ?) * ... * rate(?, Vs)) > log(1) | |
// log(rate(Vs, V1)) + log(rate(V1, V2)) + log(rate(V2, ?)) + ... + log(rate(?, Vs)) > 0 | |
// -log(rate(Vs, V1)) + -log(rate(V1, V2)) + -log(rate(V2, ?)) + ... + -log(rate(?, Vs)) < 0 | |
// we can use Bellman-Ford algorithm to find negative cycle in a graph. The negative cycle will represent arbitrage | |
// | |
// Algorithm (cycle detection) time complexity is O(VE), where V - number of vertices, E - number of edges | |
// space complexity is O(V) // distances and predecessors | |
// | |
// I run Bellman-Ford algorithm for detecting negative cycles from each node (each currency) in parallel | |
// I could run it from only one node with assumption that I can reach all nodes from the source | |
// | |
// Alternatively I could run Bellman-Ford for random node and have a set of unreachable nodes | |
// later on I could rerun it using one of the unreachable nodes as source until all nodes are visited | |
// Thanks to that I should be sure that if there is an arbitrage, it will be found | |
// | |
// By running Bellman-Ford for all possible sources I may encounter more arbitrages (however probably not all of them,
// because there can be exponentially many cycles). I then detect duplicated cycles and convert the distinct ones to arbitrages
// | |
// Because I run cycle detection |V| times for each currency: | |
// time complexity is O(V^2 * E), having p available processors we can have O(V^2 * E / p) | |
// space complexity is O(V^2) // in case all sources are running in parallel at the same time | |
// | |
// before the parallel run I also need O(V^2) memory to hold lookup data (Map[Currency, Map[Currency, Double]])
// | |
/** Domain model: error ADT, currencies, graph representation, cycles and arbitrages. */
object Model {

  /** Everything that can go wrong while fetching, decoding or post-processing exchange data. */
  sealed trait ExchangeError
  final case class FailedToFetchData(msg: String) extends ExchangeError
  final case class FailedToParseJson(msg: String) extends ExchangeError
  final case class IncorrectJsonKey(key: String) extends ExchangeError
  final case class IncorrectJsonValue(value: String) extends ExchangeError
  final case class DuplicatedExchangeData(exchange: Exchange) extends ExchangeError
  final case class CouldNotGatherResults(reason: String) extends ExchangeError

  // normally I would make Currency a newtype using io.estatico.newtype.macros.newtype
  // however I experienced some issues trying to run it on different ammonite / scala version
  // I do not use value classes here because Currency is used as element of collections
  // alternatively I could use tagged types here
  /*@newtype*/ case class Currency(name: String)

  /** A directed currency pair (`from` is exchanged into `to`). */
  final case class Exchange(from: Currency, to: Currency)

  /** Rates indexed as `lookup(from)(to)` together with the set of all known currencies. */
  final case class ExchangeData(lookup: Map[Currency, Map[Currency, Double]], currencies: Set[Currency])

  /** A weighted, directed graph edge; `rate` is whatever weight the caller stored in the lookup. */
  final case class Edge(from: Currency, to: Currency, rate: Double)

  /** Edges selected by the cycle detection step plus the predecessor map used to walk back from them. */
  final case class CycleEdges(edges: Set[Edge], predecessors: Map[Currency, Currency])

  /** A cycle given by its vertices in traversal order.
    *
    * `equals` is overridden so that two cycles which differ only by rotation are equal,
    * example: a -> b -> c == b -> c -> a (same cycle, different starting point).
    * `hashCode` is derived from the vertex set and `size`, which is rotation-invariant
    * and therefore consistent with `equals`.
    *
    * @param vertices vertices of the cycle in order
    * @param size     number of vertices; also the number of rotations tried by `equals`
    */
  final case class Cycle(vertices: Vector[Currency], size: Int) {
    private lazy val asSet = Cycle.asSet(this)

    override def hashCode(): Int = asSet.hashCode()

    override def equals(obj: Any): Boolean = obj match {
      case that: Cycle =>
        // cheap set comparison first; only then try the rotations
        asSet == that.asSet && Cycle.areEqualByRotation(size, vertices, that.vertices)
      case _ => false
    }
  }

  object Cycle {
    // rotation-invariant summary of a cycle, used to simplify equals and hashCode
    private case class AsSet(set: Set[Currency], size: Int)

    private def asSet(cycle: Cycle): AsSet = AsSet(cycle.vertices.toSet, cycle.size)

    /** True when `b`, rotated left between 0 and `rotationsLeft` times, equals `a`. */
    @tailrec
    private def areEqualByRotation(rotationsLeft: Int, a: Vector[Currency], b: Vector[Currency]): Boolean =
      if (a == b) true
      else if (rotationsLeft <= 0) false
      else areEqualByRotation(rotationsLeft - 1, a, b.tail :+ b.head)
  }

  /** A closed chain of exchanges.
    *
    * @param exchanges visited currencies; the first one is repeated at the end to close the loop
    * @param rates     rate of every hop, in order
    * @param profit    product of `rates`
    */
  case class Arbitrage(exchanges: Vector[Currency], rates: Vector[Double], profit: Double)

  object Arbitrage {

    /** Converts a cycle into an arbitrage by looking up the rate of every hop.
      *
      * @param cycle  cycle to convert
      * @param lookup rates indexed as `lookup(from)(to)`
      * @return `None` when the cycle has no vertices or when the rate of any hop is missing
      *         from `lookup`; otherwise the arbitrage with its rates and their product
      */
    def from(cycle: Cycle, lookup: Map[Currency, Map[Currency, Double]]): Option[Arbitrage] =
      cycle.vertices.headOption.flatMap { start =>
        // close the loop: a -> b -> c becomes a -> b -> c -> a
        val path = cycle.vertices :+ start
        val hops = path.zip(path.tail)
        val maybeRates = hops.foldLeft(Option(Vector.empty[Double])) {
          case (collected, (src, dst)) =>
            for {
              rates <- collected
              // `get` instead of `apply`: a missing rate yields None rather than an exception
              rate <- lookup.get(src).flatMap(_.get(dst))
            } yield rates :+ rate
        }
        maybeRates.map(rates => Arbitrage(path, rates, rates.product))
      }
  }

  /** Edge-list graph; `verticesQuantity` is the number of distinct currencies seen in the edges' source data. */
  case class Graph(edges: Set[Edge], verticesQuantity: Int)

  object Graph {

    /** Builds the graph from the lookup table: one edge per `(from, to, rate)` entry,
      * vertex count = number of distinct currencies appearing as an outer or inner key.
      */
    def from(exchangeData: ExchangeData): Graph = {
      val lookup = exchangeData.lookup
      val vertices: Set[Currency] = lookup.keySet ++ lookup.values.flatMap(_.keySet)
      val edges: Set[Edge] = {
        for {
          (src, targets) <- lookup
          (dst, rate) <- targets
        } yield Edge(src, dst, rate)
      }.toSet
      Graph(edges, vertices.size)
    }
  }
}
/** cats `Show` instances used to print the input data and the arbitrages that were found. */
object ShowInstances {

  /** A currency is printed as its bare name. */
  implicit val currencyShow: Show[Currency] = Show.show(currency => currency.name)

  /** Prints the list of currencies followed by one line per known exchange rate. */
  implicit val exchangeDataShow: Show[ExchangeData] = Show.show { exchangeData =>
    val exchangeLines = exchangeData.lookup.flatMap {
      case (source, targets) =>
        targets.map {
          case (target, rate) => s" ${source.show} -> ${target.show}: ${rate.show}"
        }
    }
    val currencyList = exchangeData.currencies.map(_.show).mkString(", ")
    val exchangeList = exchangeLines.mkString("\n")
    s"""Currencies: $currencyList
       |Exchanges:
       |$exchangeList
       |""".stripMargin
  }

  /** Prints the exchange route on the first line and `profit = rate * rate * ...` on the second. */
  implicit val arbitrageShow: Show[Arbitrage] = Show.show { arbitrage =>
    val route = arbitrage.exchanges.map(_.show).mkString(" -> ")
    val factors = arbitrage.rates.mkString(" * ")
    s""" $route
       | ${arbitrage.profit} = $factors""".stripMargin
  }

  /** Prints every arbitrage with a 1-based ordinal, or a fixed message when there is none. */
  implicit val arbitrageVectorShow: Show[Vector[Arbitrage]] = Show.show { arbitrages =>
    arbitrages match {
      case none if none.isEmpty => "No arbitrage opportunity found"
      case some =>
        val sections = some.zipWithIndex.map {
          case (arbitrage, index) =>
            s"""---
               |Arbitrage ${(index + 1).show}:
               |${arbitrage.show}""".stripMargin
        }
        sections.mkString("\n")
    }
  }
}
/** Fetches exchange rates, searches the rate graph for negative cycles (arbitrages) and prints the result. */
object Runner {
  type ExchangeErrorOr[T] = Either[ExchangeError, T]

  // Throwable.getMessage may be null; fall back to toString so the error stays informative
  private def errorMessage(e: Throwable): String = Option(e.getMessage).getOrElse(e.toString)

  /** Performs `requests.get(url)`, decodes the body as a JSON object of string values and
    * converts it into [[ExchangeData]].
    *
    * @param url address of the rates endpoint
    * @return the parsed data, or the first [[ExchangeError]] encountered
    */
  def retrieveData(url: String): ExchangeErrorOr[ExchangeData] =
    for {
      response <- Try(requests.get(url).text()).toEither.left.map(e => FailedToFetchData(errorMessage(e)))
      decoded <- decode[Map[String, String]](response).left.map(e => FailedToParseJson(errorMessage(e)))
      exchangeMap <- transformIntoExchangeMap(decoded)
    } yield exchangeDataFrom(exchangeMap)

  /** Parses raw `"FROM_TO" -> "rate"` entries.
    *
    * @param input raw key/value pairs
    * @return parsed rates, or
    *         - `IncorrectJsonKey` when a key is not exactly two non-empty parts separated by `_`,
    *         - `DuplicatedExchangeData` when two keys denote the same exchange after trimming,
    *         - `IncorrectJsonValue` when a value is not a valid `Double`
    */
  def transformIntoExchangeMap(input: Map[String, String]): ExchangeErrorOr[Map[Exchange, Double]] =
    input.foldLeft(Map.empty[Exchange, Double].asRight[ExchangeError]) {
      case (ratesEither, (key, value)) =>
        val exchangeEither: ExchangeErrorOr[Exchange] =
          key.split('_').map(_.trim) match {
            // an empty part (e.g. "_USD") would otherwise become Currency("")
            case Array(from, to) if from.nonEmpty && to.nonEmpty =>
              Exchange(Currency(from), Currency(to)).asRight
            case _ => IncorrectJsonKey(key).asLeft
          }
        for {
          exchange <- exchangeEither
          rates <- ratesEither
          _ <- if (rates.contains(exchange)) DuplicatedExchangeData(exchange).asLeft else exchangeEither
          rate <- value.toDoubleOption.toRight(IncorrectJsonValue(value))
        } yield rates + (exchange -> rate)
    }

  /** Groups flat `Exchange -> rate` entries into a `from -> (to -> rate)` lookup and collects all currencies. */
  def exchangeDataFrom(exchangeMap: Map[Exchange, Double]): ExchangeData = {
    // default value lets us read the (possibly not yet present) inner map of `from` directly
    val initLookupData = Map.empty[Currency, Map[Currency, Double]].withDefaultValue(Map.empty)
    val initCurrencies = Set.empty[Currency]
    exchangeMap.foldLeft(ExchangeData(initLookupData, initCurrencies)) {
      case (ExchangeData(data, currencies), (Exchange(from, to), rate)) =>
        val fromMap = data(from) + (to -> rate)
        ExchangeData(data + (from -> fromMap), currencies + from + to)
    }
  }

  /** Replaces every rate `r` with `-log(r)`, so that a product of rates greater than 1
    * becomes a sum of weights smaller than 0 (see the explanation at the top of the file).
    */
  def negateLogarithms(exchangeData: ExchangeData): ExchangeData = {
    val newLookup =
      exchangeData.lookup.view.mapValues {
        _.view.mapValues(rate => -Math.log(rate)).toMap
      }.toMap
    exchangeData.copy(lookup = newLookup)
  }

  /** Runs the negative-cycle search once per currency (as source) and converts every distinct
    * cycle found into an [[Arbitrage]] using the original (non-logarithmic) rates.
    *
    * The per-source searches are combined with `Task.parSequence` so that they may run in
    * parallel, as described in the header of this file.
    *
    * @param exchangeData original rates
    * @return task producing the arbitrages; empty when the graph has no edges
    */
  def findArbitrages(exchangeData: ExchangeData): Task[Vector[Arbitrage]] = {
    val bellmanFordPreparedData = negateLogarithms(exchangeData)
    val graph = Graph.from(bellmanFordPreparedData)
    if (graph.edges.isEmpty) Task.now(Vector.empty)
    else {
      val cycleSearches = exchangeData.currencies.map { source =>
        Task {
          val cycleEdges = findCycleEdges(source, graph)
          for {
            edge <- cycleEdges.edges
            cycle <- Set.empty[Cycle] ++
              followCycle(edge, cycleEdges.predecessors) ++
              followCycle(edge.to, cycleEdges.predecessors)
          } yield cycle
        }
      }
      // Cycle.equals is rotation-insensitive, so flattening into a Set removes duplicates
      Task.parSequence(cycleSearches).map(_.flatten).map { cycles =>
        cycles
          .map(Arbitrage.from(_, exchangeData.lookup))
          .collect { case Some(arbitrage) => arbitrage }
          .toVector
      }
    }
  }

  /** Bellman-Ford algorithm: relaxes all edges `verticesQuantity - 1` times starting from `source`
    * and returns the edges that can still be relaxed afterwards together with the predecessor map.
    */
  def findCycleEdges(source: Currency, graph: Graph): CycleEdges = {
    val distances = Map.empty[Currency, Double].withDefaultValue(Double.PositiveInfinity) + (source -> 0.0)
    val predecessors = Map.empty[Currency, Currency]
    val (relaxedDistances, relaxedPredecessors) =
      relaxNTimes(
        n = graph.verticesQuantity - 1,
        graph.edges,
        distances,
        predecessors
      )
    // an edge that can still be relaxed after V - 1 rounds indicates a negative cycle
    val cycleEdges = graph.edges.filter { edge =>
      relaxedDistances(edge.from) + edge.rate < relaxedDistances(edge.to)
    }
    CycleEdges(cycleEdges, relaxedPredecessors)
  }

  /** Performs `n` relaxation rounds over all `edges`.
    *
    * Within a round every candidate is compared against the distances accumulated so far in that
    * round, so a better distance found earlier in the round is never overwritten by a worse one.
    *
    * @param n                number of rounds still to perform
    * @param edges            all graph edges (`rate` is used as the edge weight)
    * @param initDistances    distances at the start of the round; expected to have a default value
    *                         for vertices not reached yet
    * @param initPredecessors predecessors at the start of the round
    * @return distances and predecessors after the last round
    */
  @tailrec
  def relaxNTimes(n: Int,
                  edges: Set[Edge],
                  initDistances: Map[Currency, Double],
                  initPredecessors: Map[Currency, Currency]): (Map[Currency, Double], Map[Currency, Currency]) =
    if (n <= 0) (initDistances, initPredecessors)
    else {
      val (relaxedDistances, relaxedPredecessors) =
        edges.foldLeft((initDistances, initPredecessors)) {
          case ((distances, predecessors), edge) =>
            val candidate = distances(edge.from) + edge.rate
            if (candidate < distances(edge.to))
              (distances + (edge.to -> candidate), predecessors + (edge.to -> edge.from))
            else (distances, predecessors)
        }
      relaxNTimes(n - 1, edges, relaxedDistances, relaxedPredecessors)
    }

  /** Walks the predecessor chain backwards from `endVertex` until a vertex repeats.
    *
    * @return the cycle that was closed, or `None` when the chain ends before any vertex repeats
    */
  def followCycle(endVertex: Currency, predecessors: Map[Currency, Currency]): Option[Cycle] =
    if (predecessors.contains(endVertex)) {
      buildCycleRec(predecessors, 1, predecessors(endVertex), Map(endVertex -> 0), Vector(endVertex))
    } else None

  /** Same as the vertex variant, but the walk starts with the edge `endEdge.from -> endEdge.to`.
    *
    * A self-loop is returned as a one-vertex cycle, which is the same representation the vertex
    * variant produces for it, so both results collapse into one when collected in a `Set`.
    */
  def followCycle(endEdge: Edge, predecessors: Map[Currency, Currency]): Option[Cycle] =
    if (endEdge.from == endEdge.to) Cycle(Vector(endEdge.from), 1).some
    else if (predecessors.contains(endEdge.from)) {
      buildCycleRec(predecessors,
                    2,
                    predecessors(endEdge.from),
                    Map(endEdge.from -> 1, endEdge.to -> 0),
                    Vector(endEdge.from, endEdge.to))
    } else None

  /** Recursive step of the backwards walk.
    *
    * @param predecessors predecessor of each vertex
    * @param i            number of vertices collected so far (= index the current vertex would get)
    * @param current      vertex being visited
    * @param visited      index at which each collected vertex was visited
    * @param result       collected vertices, most recently visited first
    * @return the cycle closed by revisiting `current`, or `None` when `current` has no predecessor
    */
  @tailrec
  def buildCycleRec(predecessors: Map[Currency, Currency],
                    i: Int,
                    current: Currency,
                    visited: Map[Currency, Int],
                    result: Vector[Currency]): Option[Cycle] =
    if (visited.contains(current)) {
      // everything collected since the first visit of `current` forms the cycle
      val size = i - visited(current)
      Cycle(current +: result.take(size - 1), size).some
    } else if (predecessors.contains(current)) {
      buildCycleRec(predecessors, i + 1, predecessors(current), visited + (current -> i), current +: result)
    } else None

  /** Entry point: downloads the rates, prints them, then prints the arbitrages sorted by descending profit. */
  def run(): Unit = {
    import ShowInstances._
    implicit val scheduler: Scheduler = Scheduler.global
    val dataUrl = "https://fx.priceonomics.com/v1/rates/"
    val arbitrages = for {
      exchangeData <- retrieveData(dataUrl)
    } yield (exchangeData, findArbitrages(exchangeData).map(_.sortBy(-_.profit)))
    arbitrages match {
      case Right((input, results)) =>
        println(input.show)
        println(results.runSyncUnsafe().show)
      case Left(error) => println(error)
    }
  }
}
// Script entry point: executed when the file is run with Ammonite
Runner.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment