A trial implementation of SCW (Soft Confidence Weighted Learning)
; SCW Classifier
;
; This is a straightforward, experimental implementation of SCW by fanannan _at_ forestwinds.com, based on "Exact Soft Confidence-Weighted Learning" by J. Wang, P. Zhao and S. C. H. Hoi, 2012.
; A full covariance matrix is used for sigma, so this may run slowly on high-dimensional data.
(ns wagtail.scw
  (:require [clatrix.core :as cl])
  (:import [org.apache.commons.math3.special Erf]))
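; The ns form above assumes clatrix and Apache commons-math3 are already on
; the classpath. A minimal project.clj dependency vector might look like the
; following sketch (the version numbers are illustrative assumptions, not
; pinned by this gist):
;
;   :dependencies [[org.clojure/clojure "1.6.0"]
;                  [clatrix "0.4.0"]
;                  [org.apache.commons/commons-math3 "3.2"]]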
; inverse cumulative distribution function of the standard normal distribution
(defn probit [p]
  (* (Math/sqrt 2.0) (Erf/erfInv (- (* p 2.0) 1.0))))
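; Rough REPL sanity check for probit (returned values are approximate):
(comment
  (probit 0.5)    ; => 0.0
  (probit 0.975)  ; => ~1.96
  (probit 0.9))   ; => ~1.28, the phi used by the demo at the bottom (eta 0.9)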
(defn calc-phi [eta]
  (probit eta))
(defn calc-psi [eta]
  (let [p (calc-phi eta)]
    (+ 1.0 (* p p 0.5))))
(defn calc-zeta [eta]
  (let [p (calc-phi eta)]
    (+ 1.0 (* p p))))
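; For reference, these constants follow the closed forms in the SCW paper:
;   phi  = probit(eta)     (inverse normal CDF at eta)
;   psi  = 1 + phi^2 / 2
;   zeta = 1 + phi^2
; e.g. for eta = 0.9, phi is roughly 1.28, psi roughly 1.82 and zeta roughly 2.64.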
(defn estimate [mu, feature]
  ; WtXt
  (cl/dot mu feature))
(defn predict [mu, feature]
  ; sign(WtXt) <-> Yt
  (let [r (estimate mu feature)]
    (cond (> r 0.0) 1
          (< r 0.0) -1
          :else 0)))
(defn margin [mu, feature, label]
  ; Yt(WtXt)
  (* label (estimate mu feature)))
(defn confidence [sigma, feature]
  ; Xt' Sigma Xt
  (cl/get (cl/* (cl/t feature) sigma feature) 0 0))
(defn calc-loss [mu, sigma, phi, feature, label]
  (max 0.0
       (- (* phi (Math/sqrt (confidence sigma feature)))
          (margin mu feature label))))
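; The loss above is the paper's hinge-like loss on the distribution:
;   loss = max(0, phi * sqrt(x' Sigma x) - y * (mu . x))
; i.e. an example is loss-free only when its margin exceeds phi standard
; deviations of the score.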
(defn calc-alpha-I [mu, sigma, phi, psi, zeta, c, feature, label]
  ; step size for the SCW-I variant of the update
  (let [v (confidence sigma feature)
        m (margin mu feature label)
        j (* m m phi phi phi phi 0.25)
        k (* v phi phi zeta)
        t (/ (- (Math/sqrt (+ j k)) (* m psi))
             (* v zeta))]
    (min c (max 0.0 t))))
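; calc-alpha-I is the closed-form SCW-I step size:
;   alpha = min(C, max(0, (sqrt(m^2*phi^4/4 + v*phi^2*zeta) - m*psi) / (v*zeta)))
; where v = x' Sigma x (confidence) and m = y * (mu . x) (margin),
; corresponding to j, k and t in the let-binding above.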
(defn calc-beta [mu, sigma, alpha, phi, psi, zeta, feature, label]
  (let [v (confidence sigma feature)
        m (margin mu feature label)
        j (* -1 alpha v phi)
        k (Math/sqrt (+ (* (* alpha v phi) (* alpha v phi)) (* 4 v)))
        u (* (+ j k) (+ j k) 0.25)]
    (/ (* alpha phi)
       (+ (Math/sqrt u) (* v alpha phi)))))
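; calc-beta matches the paper's
;   u    = (-alpha*v*phi + sqrt(alpha^2*v^2*phi^2 + 4*v))^2 / 4
;   beta = alpha*phi / (sqrt(u) + v*alpha*phi)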
(defn calc-next-mu [mu, sigma, alpha, feature, label]
  (cl/+ mu (cl/* alpha label sigma feature)))
(defn calc-next-sigma [mu, sigma, beta, feature, label]
  (cl/- sigma (cl/* beta sigma feature (cl/t feature) sigma)))
(defn calc-next-params
  [mu, sigma, phi, psi, zeta, loss-fn, alpha-fn, c, feature, label]
  (let [loss (loss-fn mu sigma phi feature label)]
    (if-not (pos? loss)
      [mu sigma]
      (let [alpha (alpha-fn mu sigma phi psi zeta c feature label)
            beta (calc-beta mu sigma alpha phi psi zeta feature label)
            mu (calc-next-mu mu sigma alpha feature label)
            sigma (calc-next-sigma mu sigma beta feature label)]
        [mu sigma]))))
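; When the loss is positive, the distribution is updated as
;   mu    <- mu    + alpha * y * Sigma * x
;   Sigma <- Sigma - beta * Sigma * x * x' * Sigma
; otherwise the example leaves mu and Sigma unchanged.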
(defn train-init [c, eta, num-fields]
  (let [mu (cl/zeros num-fields)
        sigma (cl/eye num-fields)
        phi (calc-phi eta)
        psi (calc-psi eta)
        zeta (calc-zeta eta)]
    [mu sigma phi psi zeta]))
(defn train-core [[mu, sigma], phi, psi, zeta, loss-fn, alpha-fn, c, features, labels]
  (reduce (fn [[mu sigma] [feature label]]
            (calc-next-params mu sigma phi psi zeta loss-fn alpha-fn c feature label))
          [mu sigma]
          ; pair features with labels in input order
          ; (zipmap would build an unordered map and drop duplicate feature vectors)
          (map vector features labels)))
(defn looper [iterations func init]
  (reduce (fn [r _] (func r)) init (range iterations)))
(defn train [loss-fn, alpha-fn, c, eta, features, labels, iterations]
  (assert (and (> c 0.0) (> 1.0 eta 0.5)))
  (let [[mu0 sigma0 phi psi zeta] (train-init c eta (count (first features)))]
    (looper iterations
            #(train-core % phi psi zeta loss-fn alpha-fn c features labels)
            [mu0, sigma0])))
(defn test [mu features]
  ; note: this name shadows clojure.core/test
  (map (fn [feature] [feature (predict mu feature)]) features))
;;;;;;; iris data (http://archive.ics.uci.edu/ml/datasets/Iris)
(def iris
[[1 5.1 3.5 1.4 0.2 :setosa]
[2 4.9 3.0 1.4 0.2 :setosa]
[3 4.7 3.2 1.3 0.2 :setosa]
[4 4.6 3.1 1.5 0.2 :setosa]
[5 5.0 3.6 1.4 0.2 :setosa]
[6 5.4 3.9 1.7 0.4 :setosa]
[7 4.6 3.4 1.4 0.3 :setosa]
[8 5.0 3.4 1.5 0.2 :setosa]
[9 4.4 2.9 1.4 0.2 :setosa]
[10 4.9 3.1 1.5 0.1 :setosa]
[11 5.4 3.7 1.5 0.2 :setosa]
[12 4.8 3.4 1.6 0.2 :setosa]
[13 4.8 3.0 1.4 0.1 :setosa]
[14 4.3 3.0 1.1 0.1 :setosa]
[15 5.8 4.0 1.2 0.2 :setosa]
[16 5.7 4.4 1.5 0.4 :setosa]
[17 5.4 3.9 1.3 0.4 :setosa]
[18 5.1 3.5 1.4 0.3 :setosa]
[19 5.7 3.8 1.7 0.3 :setosa]
[20 5.1 3.8 1.5 0.3 :setosa]
[21 5.4 3.4 1.7 0.2 :setosa]
[22 5.1 3.7 1.5 0.4 :setosa]
[23 4.6 3.6 1.0 0.2 :setosa]
[24 5.1 3.3 1.7 0.5 :setosa]
[25 4.8 3.4 1.9 0.2 :setosa]
[26 5.0 3.0 1.6 0.2 :setosa]
[27 5.0 3.4 1.6 0.4 :setosa]
[28 5.2 3.5 1.5 0.2 :setosa]
[29 5.2 3.4 1.4 0.2 :setosa]
[30 4.7 3.2 1.6 0.2 :setosa]
[31 4.8 3.1 1.6 0.2 :setosa]
[32 5.4 3.4 1.5 0.4 :setosa]
[33 5.2 4.1 1.5 0.1 :setosa]
[34 5.5 4.2 1.4 0.2 :setosa]
[35 4.9 3.1 1.5 0.2 :setosa]
[36 5.0 3.2 1.2 0.2 :setosa]
[37 5.5 3.5 1.3 0.2 :setosa]
[38 4.9 3.6 1.4 0.1 :setosa]
[39 4.4 3.0 1.3 0.2 :setosa]
[40 5.1 3.4 1.5 0.2 :setosa]
[41 5.0 3.5 1.3 0.3 :setosa]
[42 4.5 2.3 1.3 0.3 :setosa]
[43 4.4 3.2 1.3 0.2 :setosa]
[44 5.0 3.5 1.6 0.6 :setosa]
[45 5.1 3.8 1.9 0.4 :setosa]
[46 4.8 3.0 1.4 0.3 :setosa]
[47 5.1 3.8 1.6 0.2 :setosa]
[48 4.6 3.2 1.4 0.2 :setosa]
[49 5.3 3.7 1.5 0.2 :setosa]
[50 5.0 3.3 1.4 0.2 :setosa]
[51 7.0 3.2 4.7 1.4 :versicolor]
[52 6.4 3.2 4.5 1.5 :versicolor]
[53 6.9 3.1 4.9 1.5 :versicolor]
[54 5.5 2.3 4.0 1.3 :versicolor]
[55 6.5 2.8 4.6 1.5 :versicolor]
[56 5.7 2.8 4.5 1.3 :versicolor]
[57 6.3 3.3 4.7 1.6 :versicolor]
[58 4.9 2.4 3.3 1.0 :versicolor]
[59 6.6 2.9 4.6 1.3 :versicolor]
[60 5.2 2.7 3.9 1.4 :versicolor]
[61 5.0 2.0 3.5 1.0 :versicolor]
[62 5.9 3.0 4.2 1.5 :versicolor]
[63 6.0 2.2 4.0 1.0 :versicolor]
[64 6.1 2.9 4.7 1.4 :versicolor]
[65 5.6 2.9 3.6 1.3 :versicolor]
[66 6.7 3.1 4.4 1.4 :versicolor]
[67 5.6 3.0 4.5 1.5 :versicolor]
[68 5.8 2.7 4.1 1.0 :versicolor]
[69 6.2 2.2 4.5 1.5 :versicolor]
[70 5.6 2.5 3.9 1.1 :versicolor]
[71 5.9 3.2 4.8 1.8 :versicolor]
[72 6.1 2.8 4.0 1.3 :versicolor]
[73 6.3 2.5 4.9 1.5 :versicolor]
[74 6.1 2.8 4.7 1.2 :versicolor]
[75 6.4 2.9 4.3 1.3 :versicolor]
[76 6.6 3.0 4.4 1.4 :versicolor]
[77 6.8 2.8 4.8 1.4 :versicolor]
[78 6.7 3.0 5.0 1.7 :versicolor]
[79 6.0 2.9 4.5 1.5 :versicolor]
[80 5.7 2.6 3.5 1.0 :versicolor]
[81 5.5 2.4 3.8 1.1 :versicolor]
[82 5.5 2.4 3.7 1.0 :versicolor]
[83 5.8 2.7 3.9 1.2 :versicolor]
[84 6.0 2.7 5.1 1.6 :versicolor]
[85 5.4 3.0 4.5 1.5 :versicolor]
[86 6.0 3.4 4.5 1.6 :versicolor]
[87 6.7 3.1 4.7 1.5 :versicolor]
[88 6.3 2.3 4.4 1.3 :versicolor]
[89 5.6 3.0 4.1 1.3 :versicolor]
[90 5.5 2.5 4.0 1.3 :versicolor]
[91 5.5 2.6 4.4 1.2 :versicolor]
[92 6.1 3.0 4.6 1.4 :versicolor]
[93 5.8 2.6 4.0 1.2 :versicolor]
[94 5.0 2.3 3.3 1.0 :versicolor]
[95 5.6 2.7 4.2 1.3 :versicolor]
[96 5.7 3.0 4.2 1.2 :versicolor]
[97 5.7 2.9 4.2 1.3 :versicolor]
[98 6.2 2.9 4.3 1.3 :versicolor]
[99 5.1 2.5 3.0 1.1 :versicolor]
[100 5.7 2.8 4.1 1.3 :versicolor]
[101 6.3 3.3 6.0 2.5 :virginica]
[102 5.8 2.7 5.1 1.9 :virginica]
[103 7.1 3.0 5.9 2.1 :virginica]
[104 6.3 2.9 5.6 1.8 :virginica]
[105 6.5 3.0 5.8 2.2 :virginica]
[106 7.6 3.0 6.6 2.1 :virginica]
[107 4.9 2.5 4.5 1.7 :virginica]
[108 7.3 2.9 6.3 1.8 :virginica]
[109 6.7 2.5 5.8 1.8 :virginica]
[110 7.2 3.6 6.1 2.5 :virginica]
[111 6.5 3.2 5.1 2.0 :virginica]
[112 6.4 2.7 5.3 1.9 :virginica]
[113 6.8 3.0 5.5 2.1 :virginica]
[114 5.7 2.5 5.0 2.0 :virginica]
[115 5.8 2.8 5.1 2.4 :virginica]
[116 6.4 3.2 5.3 2.3 :virginica]
[117 6.5 3.0 5.5 1.8 :virginica]
[118 7.7 3.8 6.7 2.2 :virginica]
[119 7.7 2.6 6.9 2.3 :virginica]
[120 6.0 2.2 5.0 1.5 :virginica]
[121 6.9 3.2 5.7 2.3 :virginica]
[122 5.6 2.8 4.9 2.0 :virginica]
[123 7.7 2.8 6.7 2.0 :virginica]
[124 6.3 2.7 4.9 1.8 :virginica]
[125 6.7 3.3 5.7 2.1 :virginica]
[126 7.2 3.2 6.0 1.8 :virginica]
[127 6.2 2.8 4.8 1.8 :virginica]
[128 6.1 3.0 4.9 1.8 :virginica]
[129 6.4 2.8 5.6 2.1 :virginica]
[130 7.2 3.0 5.8 1.6 :virginica]
[131 7.4 2.8 6.1 1.9 :virginica]
[132 7.9 3.8 6.4 2.0 :virginica]
[133 6.4 2.8 5.6 2.2 :virginica]
[134 6.3 2.8 5.1 1.5 :virginica]
[135 6.1 2.6 5.6 1.4 :virginica]
[136 7.7 3.0 6.1 2.3 :virginica]
[137 6.3 3.4 5.6 2.4 :virginica]
[138 6.4 3.1 5.5 1.8 :virginica]
[139 6.0 3.0 4.8 1.8 :virginica]
[140 6.9 3.1 5.4 2.1 :virginica]
[141 6.7 3.1 5.6 2.4 :virginica]
[142 6.9 3.1 5.1 2.3 :virginica]
[143 5.8 2.7 5.1 1.9 :virginica]
[144 6.8 3.2 5.9 2.3 :virginica]
[145 6.7 3.3 5.7 2.5 :virginica]
[146 6.7 3.0 5.2 2.3 :virginica]
[147 6.3 2.5 5.0 1.9 :virginica]
[148 6.5 3.0 5.2 2.0 :virginica]
[149 6.2 3.4 5.4 2.3 :virginica]
[150 5.9 3.0 5.1 1.8 :virginica]])
; Quick-and-dirty demo: train a virginica-vs-rest classifier on 75% of the
; shuffled iris data and report training/test accuracy.
(clojure.pprint/pprint
  (let [c 2.5
        eta 0.9
        records (shuffle iris)
        num-records (count records)
        num-train (int (* num-records 0.75))
        features (map #(cl/matrix (rest (butlast %))) records)
        labels (map #(if (= :virginica (last %)) 1 -1) records)
        training-features (take num-train features)
        test-features (drop num-train features)
        training-labels (take num-train labels)
        test-labels (drop num-train labels)
        iterations 2
        [mu sigma] (train calc-loss calc-alpha-I c eta training-features training-labels iterations)
        result (map (fn [x y] (vector x y (= (last x) y))) (test mu training-features) training-labels)
        corrects (count (filter last result))
        _ (println "training performance: " (float (/ corrects num-train)))
        result (map (fn [x y] (vector x y (= (last x) y))) (test mu test-features) test-labels)
        corrects (count (filter last result))
        _ (println "test performance: " (float (/ corrects (- num-records num-train))))]
    result))
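; Once trained, mu can classify an individual measurement vector via predict.
; A minimal sketch, assuming the same virginica-vs-rest labelling as the demo
; above (the sample values are made up for illustration):
(comment
  (let [features (map #(cl/matrix (rest (butlast %))) iris)
        labels (map #(if (= :virginica (last %)) 1 -1) iris)
        [mu _] (train calc-loss calc-alpha-I 2.5 0.9 features labels 2)]
    (predict mu (cl/matrix [6.5 3.0 5.5 2.0]))))  ; returns 1, -1 or 0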