Created
August 8, 2016 09:24
-
-
Save ChristianSch/6c696240e96f9ee9455cf4a855853d10 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE OverloadedStrings #-} | |
module Main where | |
{-
  This is the main entry point for the multi-label correlation coefficient
  microservice. It listens for RPC calls from applications that need the
  computed label correlation coefficients, e.g. for data visualization in
  dashboards. The data is fetched from a MySQL database.
-}
import Data.Word | |
import Data.List.Split | |
import System.Environment as SE | |
import Control.Exception | |
import System.IO.Error | |
import Database.MySQL.Simple | |
import Database.MySQL.Simple.QueryResults | |
import Database.MySQL.Simple.Result | |
import Control.Monad.Reader | |
import Data.MessagePack | |
import Network.MessagePack.Server | |
import Statistics.Matrix | |
import Statistics.Correlation | |
{-|
  One row of the label query: the raw text of the single selected column
  (@joint_labels@ in the queries used by this module).
-}
data Label = Label { jointLabels :: String } deriving Show

-- Decodes a result row that consists of exactly one column into a 'Label'.
instance QueryResults Label where
  -- Force the converted value so that a conversion failure surfaces while
  -- the query is being run instead of later, when the field is first used.
  convertResults [field] [value] = labels `seq` Label { jointLabels = labels }
    where labels = convert field value
  -- Any other row shape is an error; a 'Label' expects exactly 1 column.
  convertResults fields values = convertError fields values 1
{-|
  Splits the 'jointLabels' text of a 'Label' on commas and reads every piece
  as a 'Double', yielding one indicator value per label position.
  A piece that 'read' cannot parse as a 'Double' raises a run-time error.
-}
toArray :: Label -> [Double]
toArray label = [ read piece | piece <- splitOn "," (jointLabels label) ]
{-|
  Loads every @joint_labels@ value of the @annotated_search_labels@ table
  through the given connection, one 'Label' per row.
-}
fetchLabels :: Connection -> IO [Label]
fetchLabels conn = query_ conn selectJointLabels
  where
    selectJointLabels = "select joint_labels from annotated_search_labels"
{-|
  Keeps only the rows whose sum lies strictly between 0 and the row length.
  For rows of binary indicators this drops the all-zero and all-one rows,
  i.e. the rows that describe a constant function.
  The motivation is that the pearson correlation (as defined by
  `pearsonMatByRow`), where each row is a random variable, is undefined for
  constant functions and therefore yields NaN, which ought to be prevented.
-}
filterLabels :: [[Double]] -> [[Double]]
filterLabels rows = [ row | row <- rows, isNonConstant row ]
  where
    isNonConstant row = total > 0 && total < fromIntegral (length row)
      where
        total = sum row
{-|
  Pearson correlation matrix of a matrix whose columns are the random
  variables. The input is transposed so that `pearsonMatByRow` can be
  applied, and the result is transposed back.
  For n columns a matrix of `(n,n)` is returned.
-}
pearsonMatByCol :: Matrix -> Matrix
pearsonMatByCol mat = transpose byRow
  where
    variablesAsRows = transpose mat
    byRow           = pearsonMatByRow variablesAsRows
{-|
  Fetches all labels from the database and returns the correlation
  coefficients between the columns of the resulting matrix. Each row is an
  observation (one fetched label), each column is a variable.
  NOTE(review): the rows are passed on unfiltered, so NaN entries may show up
  for constant variables -- confirm whether callers expect that.
-}
correlationCoefficients :: Connection -> IO [[Double]]
correlationCoefficients conn =
  toRowLists . pearsonMatByCol . fromRowLists . Prelude.map toArray
    <$> fetchLabels conn
{-|
  Builds the 'ConnectInfo' for a MySQL connection from host, port, user,
  password and database name (typically taken from the environment).
  No connect options, no socket path and no SSL settings are configured.
-}
getConnectInfo :: String -> Word16 -> String -> String -> String -> ConnectInfo
getConnectInfo host port user pass db =
  ConnectInfo
    { connectHost     = host
    , connectPort     = port
    , connectUser     = user
    , connectPassword = pass
    , connectDatabase = db
    , connectOptions  = []
    , connectPath     = ""
    , connectSSL      = Nothing
    }
{-|
  Demo RPC method: returns the sum of its two arguments.
-}
add :: Int -> Int -> Server Int
add x y = pure (x + y)
{-|
  Parses a port number given as text. Returns 'Nothing' when the text is not
  a whole number or when the number does not fit into a 'Word16'; surrounding
  whitespace is tolerated.
-}
parsePort :: String -> Maybe Word16
parsePort raw = case reads raw :: [(Integer, String)] of
  -- Parse as 'Integer' first so that out-of-range values are rejected
  -- instead of silently wrapping around inside 'Word16'.
  [(n, rest)]
    | null (words rest)
    , n >= 0
    , n <= fromIntegral (maxBound :: Word16) -> Just (fromIntegral n)
  _ -> Nothing

{-|
  Reads the MySQL connection settings from the environment variables
  @MYSQL_HOST@, @MYSQL_PORT@, @MYSQL_USER@, @MYSQL_PASS@ and @MYSQL_DB@,
  connects to the database and serves the @correlationCoefficients@ RPC
  method on port 18800.
  Fails with an 'IOError' when a variable is missing or when @MYSQL_PORT@ is
  not a valid port number.
-}
main :: IO ()
main = do
  host    <- SE.getEnv "MYSQL_HOST"
  -- supplied as String, needed as Word16
  portStr <- SE.getEnv "MYSQL_PORT"
  port    <- case parsePort portStr of
    Just p  -> return p
    Nothing -> ioError . userError $
      "MYSQL_PORT is not a valid port number: " ++ show portStr
  user    <- SE.getEnv "MYSQL_USER"
  pass    <- SE.getEnv "MYSQL_PASS"
  db      <- SE.getEnv "MYSQL_DB"
  -- 'bracket' guarantees that the connection is closed again, even when the
  -- server terminates with an exception.
  bracket (connect $ getConnectInfo host port user pass db) close $ \conn -> do
    putStrLn "listening on 127.0.0.1:18800"
    -- To expose the demo method instead: serve 18800 [ method "add" add ]
    serve 18800 [ method "correlationCoefficients"
                    (correlationCoefficients conn) ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment