Skip to content

Instantly share code, notes, and snippets.

@sseveran
Created October 17, 2014 04:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sseveran/3aa2d9afbbba8bb4883d to your computer and use it in GitHub Desktop.
Save sseveran/3aa2d9afbbba8bb4883d to your computer and use it in GitHub Desktop.
CUDA Black-Scholes
{-# LANGUAGE ScopedTypeVariables #-}
import Control.Applicative
import Data.Array.Accelerate as A
import Data.Array.Accelerate.CUDA
import qualified Data.Array.Accelerate.Interpreter as I
import Data.Array.Accelerate.Pretty
import Data.Array.Unboxed
import qualified Data.List as L
import System.Random
import Prelude as P
-- | Standard normal cumulative distribution function, N(d), computed with the
-- five-term polynomial approximation (Abramowitz & Stegun 26.2.17; the
-- "Hull" suffix presumably refers to the textbook presentation of it).
--
-- The polynomial gives the upper tail 1 - N(|d|); the final conditional
-- reflects it so the result is N(d) for both signs of @d@.
cumulativeNormalDistributionHull :: Exp Float -> Exp Float
cumulativeNormalDistributionHull d =
  let k        = 1 / (1.0 + 0.2316419 * abs d)
      a1       = 0.31938153
      a2       = -0.356563782
      a3       = 1.781477937
      a4       = -1.821255978
      a5       = 1.330274429
      -- The standard normal density is exp(-d^2/2) / sqrt(2*pi): the
      -- polynomial must be scaled by 1/sqrt(2*pi), not sqrt(2*pi).
      rsqrt2pi = 1 / sqrt (2 * pi)
      -- Upper-tail probability 1 - N(|d|).
      cnd      = rsqrt2pi * exp (-0.5 * d * d)
                   * (k * (a1 + k * (a2 + k * (a3 + k * (a4 + k * a5)))))
  in d >* 0 ? (1 - cnd, cnd)
-- | Price European call and put options with the closed-form Black-Scholes
-- formula, element-wise over the input array.
--
-- Each input element is @(price, strike, years)@: the underlying's price, the
-- strike price, and the time to expiry in years. @riskFree@ and @volatility@
-- are shared by every option. Each output element is @(call, put)@.
blackScholes :: Acc (A.Array DIM1 (Float,Float,Float)) -> Exp Float -> Exp Float -> Acc (A.Array DIM1 (Float,Float))
blackScholes xs riskFree volatility = A.map go xs
  where
    -- Standard normal CDF, N(d), used directly (1 - N(1 - d) is not N(d)).
    cnd = cumulativeNormalDistributionHull

    go x =
      let (price, strike, years) = A.unlift x :: (Exp Float, Exp Float, Exp Float)
          volSqrtT = volatility * sqrt years
          d1       = (log (price / strike)
                       + (riskFree + (volatility * volatility) / 2) * years)
                     / volSqrtT
          -- Algebraically identical to the (r - sigma^2/2) form.
          d2       = d1 - volSqrtT
          -- Present value of the strike, K * exp(-r * T).
          pvStrike = strike * exp (negate riskFree * years)
          call     = price * cnd d1 - pvStrike * cnd d2
          -- The put needs N(-d2) and N(-d1), i.e. negation, not (1 - d).
          put      = pvStrike * cnd (negate d2) - price * cnd (negate d1)
      in A.lift (call, put)
-- | Driver: build a random batch of options, price them with 'blackScholes'
-- through the CUDA backend's 'run', and print how many results came back
-- (which forces the whole computation).
main :: IO ()
main = do
  -- Both market parameters are fractions (0.30 == 30%), consistent with the
  -- riskFree range; a 10..50 range would mean 1000%..5000% volatility.
  volatility :: Float <- randomRIO (0.10, 0.50)
  riskFree   :: Float <- randomRIO (0.0, 0.07)
  -- For each price: strikes within +/-10 of it, and expiries of 1..10 years.
  let optionsFor :: Float -> [(Float, Float, Float)]
      optionsFor price =
        [ (price, strike, year)
        | strike <- [(price - 10) .. (price + 10)]
        , year   <- [1 .. 10]
        ]
  prices <- mapM (\_ -> randomRIO (11, 100)) [1 .. 4500 :: Int]
  let flattened   = L.concatMap optionsFor prices
      numOptions  = L.length flattened
      optionsData = fromList (Z :. numOptions) flattened
  -- putStrLn (not print) so the message is not rendered as a quoted string.
  putStrLn $ "Calculating " P.++ show numOptions P.++ " options prices"
  let x = run $ blackScholes (A.use optionsData) (A.constant riskFree) (A.constant volatility)
  print $ L.length $ toList x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment