Last active
August 29, 2015 14:01
-
-
Save fanannan/b27564453c5018fa1f43 to your computer and use it in GitHub Desktop.
A trial implementation of SCW (Soft Confidence Weighted Learning)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
; SCW Classifier | |
; | |
; This is a straightforward, experimental implementation of SCW by fanannan _at_ forestwinds.com, based on "Exact Soft Confidence-Weighted Learning" by J. Wang, P. Zhao and S. C. H. Hoi, 2012.
; A full covariance matrix is used for sigma, so this may run slowly on high-dimensional data.
(ns wagtail.scw | |
(:require [clatrix.core :as cl]) | |
(:import [org.apache.commons.math3.special Erf])) | |
(defn probit
  "Probit of p, computed as sqrt(2) * erfInv(2p - 1)."
  [p]
  (let [centered (- (* p 2.0) 1.0)
        root-two (Math/sqrt 2.0)]
    (* root-two (Erf/erfInv centered))))
(defn calc-phi
  "phi for the confidence parameter eta: simply (probit eta)."
  [eta]
  (-> eta probit))
(defn calc-psi
  "psi = 1 + phi^2 / 2, where phi = (calc-phi eta)."
  [eta]
  (let [phi (calc-phi eta)
        phi-squared (* phi phi)]
    (+ 1.0 (* phi-squared 0.5))))
(defn calc-zeta
  "zeta = 1 + phi^2, where phi = (calc-phi eta)."
  [eta]
  (let [phi (calc-phi eta)
        phi-squared (* phi phi)]
    (+ 1.0 phi-squared)))
(defn estimate
  "Raw score of feature under the weights mu: their dot product (cl/dot)."
  [mu, feature]
  (-> mu (cl/dot feature)))
(defn predict
  "Sign of (estimate mu feature): 1 when the score is positive, -1 when it is
  negative, and 0 otherwise."
  [mu, feature]
  (let [score (estimate mu feature)]
    (if (> score 0.0)
      1
      (if (< score 0.0)
        -1
        0))))
(defn margin
  "Signed margin: label multiplied by (estimate mu feature)."
  [mu, feature, label]
  (let [score (estimate mu feature)]
    (* label score)))
(defn cofidence
  "Element (0, 0) of the product feature^T * sigma * feature."
  [sigma, feature]
  (-> (cl/t feature)
      (cl/* sigma feature)
      (cl/get 0 0)))
(defn calc-loss
  "Loss for one sample: phi * sqrt(cofidence) minus the margin, floored at 0.0."
  [mu, sigma, phi, feature, label]
  (let [spread (Math/sqrt (cofidence sigma feature))
        signed-margin (margin mu feature label)
        raw-loss (- (* phi spread) signed-margin)]
    (max 0.0 raw-loss)))
(defn calc-alpha-I
  "Step size alpha for the SCW-I update.

  With v = (cofidence sigma feature) and m = (margin mu feature label), this is
    (sqrt(m^2 * phi^4 / 4 + v * phi^2 * zeta) - m * psi) / (v * zeta)
  clipped to the interval [0.0, c]."
  [mu, sigma, phi, psi, zeta, c, feature, label]
  (let [variance (cofidence sigma feature)
        signed-margin (margin mu feature label)
        ; the two terms under the square root
        discriminant (+ (* signed-margin signed-margin phi phi phi phi 0.25)
                        (* variance phi phi zeta))
        numerator (- (Math/sqrt discriminant) (* signed-margin psi))
        unclipped (/ numerator (* variance zeta))]
    (min c (max 0.0 unclipped))))
(defn calc-beta
  "Covariance shrink factor beta for the sigma update.

  With v = (cofidence sigma feature) and a = alpha * v * phi:
    u    = ((sqrt(a^2 + 4v) - a)^2) / 4
    beta = (alpha * phi) / (sqrt(u) + a)

  mu, psi, zeta and label are accepted for signature compatibility with
  existing callers but do not enter the computation."
  [mu, sigma, alpha, phi, psi, zeta, feature, label]
  (let [v (cofidence sigma feature)
        avp (* alpha v phi)
        root (Math/sqrt (+ (* avp avp) (* 4 v)))
        gap (- root avp)
        u (* gap gap 0.25)]
    (/ (* alpha phi)
       (+ (Math/sqrt u) avp))))
(defn calc-next-mu
  "Updated weights: mu + alpha * label * sigma * feature."
  [mu, sigma, alpha, feature, label]
  (let [step (cl/* alpha label sigma feature)]
    (cl/+ mu step)))
(defn calc-next-sigma
  "Updated covariance: sigma - beta * sigma * feature * feature^T * sigma.
  mu and label are accepted but not used."
  [mu, sigma, beta, feature, label]
  (let [feature-t (cl/t feature)
        shrink (cl/* beta sigma feature feature-t sigma)]
    (cl/- sigma shrink)))
(defn calc-next-params
  "One update step for a single (feature, label) sample.

  Returns [mu sigma] unchanged when (loss-fn mu sigma phi feature label) is not
  positive; otherwise returns the pair produced by calc-next-mu and
  calc-next-sigma, using alpha from alpha-fn and beta from calc-beta."
  [mu, sigma, phi, psi, zeta, loss-fn, alpha-fn, c, feature, label]
  (let [loss (loss-fn mu sigma phi feature label)]
    (if (pos? loss)
      (let [alpha (alpha-fn mu sigma phi psi zeta c feature label)
            beta (calc-beta mu sigma alpha phi psi zeta feature label)
            next-mu (calc-next-mu mu sigma alpha feature label)
            ; sigma is still the pre-update covariance here
            next-sigma (calc-next-sigma next-mu sigma beta feature label)]
        [next-mu next-sigma])
      [mu sigma])))
(defn train-init
  "Initial training state as the vector [mu sigma phi psi zeta].

  mu is (cl/zeros num-fields), sigma is (cl/eye num-fields), and phi, psi and
  zeta are derived from eta. c is accepted but not used."
  [c, eta, num-fields]
  [(cl/zeros num-fields)
   (cl/eye num-fields)
   (calc-phi eta)
   (calc-psi eta)
   (calc-zeta eta)])
(defn train-core
  "One pass over the training set, starting from the state [mu sigma].

  Feeds each (feature, label) pair to calc-next-params in the order the
  samples appear in features/labels and returns the final [mu sigma].

  The pairs are built with (map vector ...) rather than zipmap: a map would
  merge samples whose feature vectors are equal and would not preserve the
  input order, and an online learner depends on both."
  [[mu, sigma], phi, psi, zeta, loss-fn, alpha-fn, c, features, labels]
  (reduce (fn [[mu sigma] [feature label]]
            (calc-next-params mu sigma phi psi zeta loss-fn alpha-fn c feature label))
          [mu sigma]
          (map vector features labels)))
(defn looper
  "Applies func to init repeatedly, once per element of (range iterations),
  threading the result through, and returns the final value."
  [iterations func init]
  (letfn [(step [acc _round] (func acc))]
    (reduce step init (range iterations))))
(defn train
  "Trains on features/labels for the given number of passes and returns the
  final [mu sigma].

  Asserts that c > 0.0 and 0.5 < eta < 1.0. The number of fields is taken
  from (count (first features))."
  [loss-fn, alpha-fn, c, eta, features, labels, iterations]
  (assert (and (> c 0.0)(> 1.0 eta 0.5)))
  (let [num-fields (count (first features))
        [mu0 sigma0 phi psi zeta] (train-init c eta num-fields)
        one-pass (fn [params]
                   (train-core params phi psi zeta loss-fn alpha-fn c features labels))]
    (looper iterations one-pass [mu0, sigma0])))
(defn test
  "Pairs every feature with its prediction under mu: a lazy sequence of
  [feature (predict mu feature)] vectors."
  [mu features]
  (for [feature features]
    [feature (predict mu feature)]))
;;;;;;; iris data (http://archive.ics.uci.edu/ml/datasets/Iris)
; Each row is [id v1 v2 v3 v4 species]: a running record id, four numeric
; measurements, and the species keyword (:setosa, :versicolor or :virginica).
(def iris
  [[1 5.1 3.5 1.4 0.2 :setosa]
   [2 4.9 3.0 1.4 0.2 :setosa]
   [3 4.7 3.2 1.3 0.2 :setosa]
   [4 4.6 3.1 1.5 0.2 :setosa]
   [5 5.0 3.6 1.4 0.2 :setosa]
   [6 5.4 3.9 1.7 0.4 :setosa]
   [7 4.6 3.4 1.4 0.3 :setosa]
   [8 5.0 3.4 1.5 0.2 :setosa]
   [9 4.4 2.9 1.4 0.2 :setosa]
   [10 4.9 3.1 1.5 0.1 :setosa]
   [11 5.4 3.7 1.5 0.2 :setosa]
   [12 4.8 3.4 1.6 0.2 :setosa]
   [13 4.8 3.0 1.4 0.1 :setosa]
   [14 4.3 3.0 1.1 0.1 :setosa]
   [15 5.8 4.0 1.2 0.2 :setosa]
   [16 5.7 4.4 1.5 0.4 :setosa]
   [17 5.4 3.9 1.3 0.4 :setosa]
   [18 5.1 3.5 1.4 0.3 :setosa]
   [19 5.7 3.8 1.7 0.3 :setosa]
   [20 5.1 3.8 1.5 0.3 :setosa]
   [21 5.4 3.4 1.7 0.2 :setosa]
   [22 5.1 3.7 1.5 0.4 :setosa]
   [23 4.6 3.6 1.0 0.2 :setosa]
   [24 5.1 3.3 1.7 0.5 :setosa]
   [25 4.8 3.4 1.9 0.2 :setosa]
   [26 5.0 3.0 1.6 0.2 :setosa]
   [27 5.0 3.4 1.6 0.4 :setosa]
   [28 5.2 3.5 1.5 0.2 :setosa]
   [29 5.2 3.4 1.4 0.2 :setosa]
   [30 4.7 3.2 1.6 0.2 :setosa]
   [31 4.8 3.1 1.6 0.2 :setosa]
   [32 5.4 3.4 1.5 0.4 :setosa]
   [33 5.2 4.1 1.5 0.1 :setosa]
   [34 5.5 4.2 1.4 0.2 :setosa]
   [35 4.9 3.1 1.5 0.2 :setosa]
   [36 5.0 3.2 1.2 0.2 :setosa]
   [37 5.5 3.5 1.3 0.2 :setosa]
   [38 4.9 3.6 1.4 0.1 :setosa]
   [39 4.4 3.0 1.3 0.2 :setosa]
   [40 5.1 3.4 1.5 0.2 :setosa]
   [41 5.0 3.5 1.3 0.3 :setosa]
   [42 4.5 2.3 1.3 0.3 :setosa]
   [43 4.4 3.2 1.3 0.2 :setosa]
   [44 5.0 3.5 1.6 0.6 :setosa]
   [45 5.1 3.8 1.9 0.4 :setosa]
   [46 4.8 3.0 1.4 0.3 :setosa]
   [47 5.1 3.8 1.6 0.2 :setosa]
   [48 4.6 3.2 1.4 0.2 :setosa]
   [49 5.3 3.7 1.5 0.2 :setosa]
   [50 5.0 3.3 1.4 0.2 :setosa]
   [51 7.0 3.2 4.7 1.4 :versicolor]
   [52 6.4 3.2 4.5 1.5 :versicolor]
   [53 6.9 3.1 4.9 1.5 :versicolor]
   [54 5.5 2.3 4.0 1.3 :versicolor]
   [55 6.5 2.8 4.6 1.5 :versicolor]
   [56 5.7 2.8 4.5 1.3 :versicolor]
   [57 6.3 3.3 4.7 1.6 :versicolor]
   [58 4.9 2.4 3.3 1.0 :versicolor]
   [59 6.6 2.9 4.6 1.3 :versicolor]
   [60 5.2 2.7 3.9 1.4 :versicolor]
   [61 5.0 2.0 3.5 1.0 :versicolor]
   [62 5.9 3.0 4.2 1.5 :versicolor]
   [63 6.0 2.2 4.0 1.0 :versicolor]
   [64 6.1 2.9 4.7 1.4 :versicolor]
   [65 5.6 2.9 3.6 1.3 :versicolor]
   [66 6.7 3.1 4.4 1.4 :versicolor]
   [67 5.6 3.0 4.5 1.5 :versicolor]
   [68 5.8 2.7 4.1 1.0 :versicolor]
   [69 6.2 2.2 4.5 1.5 :versicolor]
   [70 5.6 2.5 3.9 1.1 :versicolor]
   [71 5.9 3.2 4.8 1.8 :versicolor]
   [72 6.1 2.8 4.0 1.3 :versicolor]
   [73 6.3 2.5 4.9 1.5 :versicolor]
   [74 6.1 2.8 4.7 1.2 :versicolor]
   [75 6.4 2.9 4.3 1.3 :versicolor]
   [76 6.6 3.0 4.4 1.4 :versicolor]
   [77 6.8 2.8 4.8 1.4 :versicolor]
   [78 6.7 3.0 5.0 1.7 :versicolor]
   [79 6.0 2.9 4.5 1.5 :versicolor]
   [80 5.7 2.6 3.5 1.0 :versicolor]
   [81 5.5 2.4 3.8 1.1 :versicolor]
   [82 5.5 2.4 3.7 1.0 :versicolor]
   [83 5.8 2.7 3.9 1.2 :versicolor]
   [84 6.0 2.7 5.1 1.6 :versicolor]
   [85 5.4 3.0 4.5 1.5 :versicolor]
   [86 6.0 3.4 4.5 1.6 :versicolor]
   [87 6.7 3.1 4.7 1.5 :versicolor]
   [88 6.3 2.3 4.4 1.3 :versicolor]
   [89 5.6 3.0 4.1 1.3 :versicolor]
   [90 5.5 2.5 4.0 1.3 :versicolor]
   [91 5.5 2.6 4.4 1.2 :versicolor]
   [92 6.1 3.0 4.6 1.4 :versicolor]
   [93 5.8 2.6 4.0 1.2 :versicolor]
   [94 5.0 2.3 3.3 1.0 :versicolor]
   [95 5.6 2.7 4.2 1.3 :versicolor]
   [96 5.7 3.0 4.2 1.2 :versicolor]
   [97 5.7 2.9 4.2 1.3 :versicolor]
   [98 6.2 2.9 4.3 1.3 :versicolor]
   [99 5.1 2.5 3.0 1.1 :versicolor]
   [100 5.7 2.8 4.1 1.3 :versicolor]
   [101 6.3 3.3 6.0 2.5 :virginica]
   [102 5.8 2.7 5.1 1.9 :virginica]
   [103 7.1 3.0 5.9 2.1 :virginica]
   [104 6.3 2.9 5.6 1.8 :virginica]
   [105 6.5 3.0 5.8 2.2 :virginica]
   [106 7.6 3.0 6.6 2.1 :virginica]
   [107 4.9 2.5 4.5 1.7 :virginica]
   [108 7.3 2.9 6.3 1.8 :virginica]
   [109 6.7 2.5 5.8 1.8 :virginica]
   [110 7.2 3.6 6.1 2.5 :virginica]
   [111 6.5 3.2 5.1 2.0 :virginica]
   [112 6.4 2.7 5.3 1.9 :virginica]
   [113 6.8 3.0 5.5 2.1 :virginica]
   [114 5.7 2.5 5.0 2.0 :virginica]
   [115 5.8 2.8 5.1 2.4 :virginica]
   [116 6.4 3.2 5.3 2.3 :virginica]
   [117 6.5 3.0 5.5 1.8 :virginica]
   [118 7.7 3.8 6.7 2.2 :virginica]
   [119 7.7 2.6 6.9 2.3 :virginica]
   [120 6.0 2.2 5.0 1.5 :virginica]
   [121 6.9 3.2 5.7 2.3 :virginica]
   [122 5.6 2.8 4.9 2.0 :virginica]
   [123 7.7 2.8 6.7 2.0 :virginica]
   [124 6.3 2.7 4.9 1.8 :virginica]
   [125 6.7 3.3 5.7 2.1 :virginica]
   [126 7.2 3.2 6.0 1.8 :virginica]
   [127 6.2 2.8 4.8 1.8 :virginica]
   [128 6.1 3.0 4.9 1.8 :virginica]
   [129 6.4 2.8 5.6 2.1 :virginica]
   [130 7.2 3.0 5.8 1.6 :virginica]
   [131 7.4 2.8 6.1 1.9 :virginica]
   [132 7.9 3.8 6.4 2.0 :virginica]
   [133 6.4 2.8 5.6 2.2 :virginica]
   [134 6.3 2.8 5.1 1.5 :virginica]
   [135 6.1 2.6 5.6 1.4 :virginica]
   [136 7.7 3.0 6.1 2.3 :virginica]
   [137 6.3 3.4 5.6 2.4 :virginica]
   [138 6.4 3.1 5.5 1.8 :virginica]
   [139 6.0 3.0 4.8 1.8 :virginica]
   [140 6.9 3.1 5.4 2.1 :virginica]
   [141 6.7 3.1 5.6 2.4 :virginica]
   [142 6.9 3.1 5.1 2.3 :virginica]
   [143 5.8 2.7 5.1 1.9 :virginica]
   [144 6.8 3.2 5.9 2.3 :virginica]
   [145 6.7 3.3 5.7 2.5 :virginica]
   [146 6.7 3.0 5.2 2.3 :virginica]
   [147 6.3 2.5 5.0 1.9 :virginica]
   [148 6.5 3.0 5.2 2.0 :virginica]
   [149 6.2 3.4 5.4 2.3 :virginica]
   [150 5.9 3.0 5.1 1.8 :virginica]])
; Demo (quick and dirty): train a binary SCW-I classifier (virginica = 1,
; everything else = -1) on a random 75/25 split of the iris records, print the
; accuracy on both partitions, and pretty-print the per-record test results.
(require 'clojure.pprint) ; the ns form does not load clojure.pprint, so load it before use
(clojure.pprint/pprint
 (let [c 2.5
       eta 0.9
       iterations 2
       records (shuffle iris)
       num-records (count records)
       num-train (int (* num-records 0.75))
       ; drop the leading record id and the trailing species keyword
       features (map #(cl/matrix (rest (butlast %))) records)
       labels (map #(if (= :virginica (last %)) 1 -1) records)
       [training-features test-features] (split-at num-train features)
       [training-labels test-labels] (split-at num-train labels)
       [mu _sigma] (train calc-loss calc-alpha-I c eta training-features training-labels iterations)
       ; each result row is [[feature prediction] label correct?]
       evaluate (fn [fs ls]
                  (map (fn [x y] (vector x y (= (last x) y))) (test mu fs) ls))
       accuracy (fn [result total]
                  (float (/ (count (filter last result)) total)))
       training-result (evaluate training-features training-labels)
       test-result (evaluate test-features test-labels)]
   (println "training performance: " (accuracy training-result num-train))
   (println "test performance: " (accuracy test-result (- num-records num-train)))
   test-result))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment