@aligusnet
Created January 21, 2017 11:37
Process outputs for Multiclass Classification.
```haskell
import qualified Data.Vector.Storable as V
import qualified Numeric.LinearAlgebra as LA

-- Type synonyms assumed by the signatures below (real-valued hmatrix types).
type Vector = LA.Vector LA.R
type Matrix = LA.Matrix LA.R

-- | Process outputs for Multiclass Classification.
-- Takes the number of labels and the output vector y.
-- Returns a matrix of binary outputs (One-vs-All Classification).
-- Labels are assumed to be integers starting at 0.
processOutputMulti :: Int -> Vector -> Matrix
processOutputMulti numLabels y = LA.fromColumns $ map f [0 .. numLabels-1]
  where f sample = V.map (\a -> if round a == sample then 1 else 0) y

-- | Alternative: place a 1 at (sample index, label) for every sample.
processOutputMulti' :: Int -> Vector -> Matrix
processOutputMulti' numLabels y = LA.assoc (V.length y, numLabels) 0 assocList
  where assocList = zipWith (\index label -> ((index, round label), 1)) [0..] (V.toList y)
```
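For readers without hmatrix installed, the same One-vs-All encoding can be sketched with plain lists. This is an illustrative reimplementation, not the gist's code: `oneVsAll` and `oneVsAllCols` are hypothetical names, and labels are assumed to be whole numbers in `[0, numLabels - 1]` stored as `Double`. Row i of the result holds a 1 in the column of sample i's label and 0 elsewhere; `oneVsAllCols` builds the same matrix one class column at a time, mirroring the `LA.fromColumns` version.

```haskell
import Data.List (transpose)

-- | Row-oriented sketch (hypothetical helper): one row per sample,
-- one column per class, with a 1 in the column matching the label.
-- Assumes labels are integers in [0, numLabels - 1] stored as Doubles.
oneVsAll :: Int -> [Double] -> [[Double]]
oneVsAll numLabels ys =
  [ [ if round y == k then 1 else 0 | k <- [0 .. numLabels - 1] ] | y <- ys ]

-- | Column-oriented sketch (hypothetical helper): build one binary
-- column per class, as LA.fromColumns does, then transpose into rows.
oneVsAllCols :: Int -> [Double] -> [[Double]]
oneVsAllCols numLabels ys =
  transpose [ [ if round y == k then 1 else 0 | y <- ys ] | k <- [0 .. numLabels - 1] ]
```

In GHCi, `oneVsAll 3 [0, 2, 1, 2]` evaluates to `[[1.0,0.0,0.0],[0.0,0.0,1.0],[0.0,1.0,0.0],[0.0,0.0,1.0]]`, and `oneVsAllCols 3 [0, 2, 1, 2]` gives the same rows.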