Skip to content

Instantly share code, notes, and snippets.

@kmundnic
Created February 19, 2018 04:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kmundnic/62a9f106712ce0d0e46b32606e514b17 to your computer and use it in GitHub Desktop.
Save kmundnic/62a9f106712ce0d0e46b32606e514b17 to your computer and use it in GitHub Desktop.
"""
    ∇tSTE(X, no_objects, no_dims, no_triplets, triplets, λ, α; use_log=true) -> (C, ∇C)

Compute the t-STE cost `C` and its gradient `∇C` with respect to the embedding `X`.

For each triplet `(a, b, c)` (a row of `triplets`) the model probability is

    p = K[a,b] / (K[a,b] + K[a,c]),   K[i,j] = (1 + ‖X[i,:] - X[j,:]‖² / α)^(-(α+1)/2),

which is p_{ijk} from the lower right of page 3 of the t-STE paper; it grows as `a`
gets closer to `b` than to `c`.

The cost is `λ‖X‖²_F - Σ log p` when `use_log` is true and `λ‖X‖²_F - Σ p`
otherwise; `∇C` is the exact derivative of whichever cost was used.

# Arguments
- `X`: `no_objects × no_dims` embedding, one object per row.
- `no_objects`, `no_dims`: must equal `size(X)`.
- `no_triplets`: number of leading rows of `triplets` to use.
- `triplets`: integer matrix whose first three columns hold object indices `(a, b, c)`.
- `λ`: weight of the L2 regularizer.
- `α`: degrees of freedom of the Student-t kernel; must be positive.

# Keywords
- `use_log`: optimise the log-likelihood (`true`, default) or the plain likelihood.

# Returns
`(C, ∇C)` with `C::Float64` and `∇C::Matrix{Float64}` of the same size as `X`.

# Throws
- `DimensionMismatch` if `size(X) != (no_objects, no_dims)` or `triplets` is too small.
- `ArgumentError` if `α <= 0` or a triplet index lies outside `1:no_objects`.

Triplets are split into chunks processed with `Threads.@threads`; every chunk
accumulates into its own cost and gradient buffer, which are summed afterwards.
"""
function ∇tSTE(X::AbstractMatrix{Float64},
               no_objects::Integer,
               no_dims::Integer,
               no_triplets::Integer,
               triplets::AbstractMatrix{<:Integer},
               λ::Real,
               α::Real;
               use_log::Bool = true)::Tuple{Float64,Array{Float64,2}}
    # Validate up front: the hot loops in the helpers run under @inbounds.
    size(X) == (no_objects, no_dims) || throw(DimensionMismatch(
        "X is $(size(X)), expected ($no_objects, $no_dims)"))
    0 <= no_triplets <= size(triplets, 1) || throw(DimensionMismatch(
        "no_triplets = $no_triplets but triplets has $(size(triplets, 1)) rows"))
    (no_triplets == 0 || size(triplets, 2) >= 3) || throw(DimensionMismatch(
        "triplets must have at least 3 columns, got $(size(triplets, 2))"))
    α > 0 || throw(ArgumentError("α must be positive, got $α"))
    for t in 1:no_triplets, col in 1:3
        idx = triplets[t, col]
        1 <= idx <= no_objects || throw(ArgumentError(
            "triplets[$t, $col] = $idx is outside 1:$no_objects"))
    end

    αf = Float64(α)
    K, Q = _tste_kernels(X, no_objects, no_dims, αf)
    constant = (αf + 1) / αf

    # One private cost/gradient buffer per chunk (indexed by chunk, not threadid()),
    # reduced serially below: concurrent `+=` on a shared C/∇C would be a data race.
    nchunks = max(1, min(Threads.nthreads(), no_triplets))
    chunk_costs = zeros(Float64, nchunks)
    chunk_grads = [zeros(Float64, no_objects, no_dims) for _ in 1:nchunks]
    Threads.@threads for ch in 1:nchunks
        lo = div((ch - 1) * no_triplets, nchunks) + 1
        hi = div(ch * no_triplets, nchunks)
        chunk_costs[ch] = _tste_triplet_chunk!(chunk_grads[ch], X, K, Q, triplets,
                                               lo:hi, no_dims, constant, use_log)
    end

    # L2 regularization cost plus the triplet terms.
    C = λ * sum(abs2, X) + sum(chunk_costs)
    ∇C = chunk_grads[1]
    for ch in 2:nchunks
        ∇C .+= chunk_grads[ch]
    end
    # The accumulated gradient is that of Σ log p (or Σ p); negate it because the
    # cost is its negative, then add 2λX, the derivative of the L2 penalty.
    ∇C .= 2λ .* X .- ∇C
    return C, ∇C
end

# Return (K, Q) for every pair of rows of X, with d² their squared distance:
#   Q[i,j] = 1 / (1 + d²/α)  and  K[i,j] = (1 + d²/α)^(-(α+1)/2).
# K[i,j] is exactly the numerator of p_{ijk} in the lower right of page 3 of the
# t-STE paper. Both matrices are symmetric, so only i <= j is computed and mirrored.
# Callers must ensure size(X) == (no_objects, no_dims) (∇tSTE checks this).
function _tste_kernels(X::AbstractMatrix{Float64}, no_objects::Integer,
                       no_dims::Integer, α::Float64)
    K = zeros(Float64, no_objects, no_objects)
    Q = zeros(Float64, no_objects, no_objects)
    exponent = -(α + 1) / 2
    @inbounds for j in 1:no_objects, i in 1:j
        # Sum squared differences over ALL dimensions before applying the kernel.
        # Differencing (instead of ‖xi‖² + ‖xj‖² - 2xi·xj) keeps d² ≥ 0 under rounding.
        d2 = 0.0
        for k in 1:no_dims
            δ = X[i, k] - X[j, k]
            d2 += δ * δ
        end
        base = 1 + d2 / α
        q = 1 / base
        kij = base ^ exponent
        Q[i, j] = q
        Q[j, i] = q
        K[i, j] = kij
        K[j, i] = kij
    end
    return K, Q
end

# Add into G the gradient of Σ log p_t (use_log) or Σ p_t (otherwise) over the
# triplets in `ts`, and return the matching cost contribution -Σ log p_t or -Σ p_t.
# Callers must have validated every triplet index against size(K) and size(X).
function _tste_triplet_chunk!(G::Matrix{Float64}, X::AbstractMatrix{Float64},
                              K::Matrix{Float64}, Q::Matrix{Float64},
                              triplets::AbstractMatrix{<:Integer},
                              ts::AbstractUnitRange{<:Integer}, no_dims::Integer,
                              constant::Float64, use_log::Bool)
    cost = 0.0
    @inbounds for t in ts
        a, b, c = triplets[t, 1], triplets[t, 2], triplets[t, 3]
        # p_{abc} from the lower right of page 3 of the t-STE paper.
        p = K[a, b] / (K[a, b] + K[a, c])
        # ∂log p carries a factor (1 - p); ∂p = p·∂log p adds a further factor p.
        if use_log
            cost -= log(p)
            w = 1 - p
        else
            cost -= p
            w = p * (1 - p)
        end
        for i in 1:no_dims
            # Gradient of *this triplet* on each of its three points.
            a_to_b = w * Q[a, b] * (X[a, i] - X[b, i])
            a_to_c = w * Q[a, c] * (X[a, i] - X[c, i])
            G[a, i] += constant * (a_to_c - a_to_b)
            G[b, i] += constant * a_to_b
            G[c, i] -= constant * a_to_c
        end
    end
    return cost
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment