Skip to content

Instantly share code, notes, and snippets.

@saschatimme
Created February 22, 2018 08:31
Show Gist options
  • Save saschatimme/3178ece0d141e1a65c5c852dd0874c1a to your computer and use it in GitHub Desktop.
Save saschatimme/3178ece0d141e1a65c5c852dd0874c1a to your computer and use it in GitHub Desktop.
"""
    ∇tSTE_threaded(X, no_objects, no_dims, no_triplets, triplets, λ, α)

Compute the t-STE cost and its gradient with respect to the embedding `X`,
splitting the work over the triplets across the available threads.

# Arguments
- `X`: embedding, indexed as `X[i, k]` for object `i in 1:no_objects` and
  dimension `k in 1:no_dims`.
- `no_objects`, `no_dims`: number of objects and embedding dimensions.
- `no_triplets`: number of triplet rows to process; the rows `1:no_triplets`
  are partitioned with `partition_work`.
- `triplets`: triplet index matrix, passed unchanged to `thread_kernel`.
- `λ`: L2 regularization weight.
- `α`: degrees of freedom of the Student-t kernel.

# Returns
A tuple `(C, ∇C)`: the regularized cost and a `no_objects × no_dims` gradient.
"""
function ∇tSTE_threaded(X::Array{Float64,2},
                        no_objects::Int64,
                        no_dims::Int64,
                        no_triplets::Int64,
                        triplets::Array{Int64,2},
                        λ::Float64,
                        α::Float64)
    # The cost starts with the L2 regularization term.
    C = λ * sum(abs2, X)
    constant = (α + 1) / α

    # Squared norm of every point (i ranges over points, k over dimensions).
    sum_X = zeros(Float64, no_objects)
    for k in 1:no_dims, i in 1:no_objects
        @inbounds sum_X[i] += X[i, k] * X[i, k]
    end

    # Student-t kernel for every pair of points:
    #   K[i,j] = (1 + sqdist(i,j)/α) ^ (-(α+1)/2)
    # which is the numerator of p_{i,j} in the lower right of page 3 of the
    # t-STE paper, and Q[i,j] = (1 + sqdist(i,j)/α) ^ -1.
    K = zeros(Float64, no_objects, no_objects)
    Q = zeros(Float64, no_objects, no_objects)
    for j in 1:no_objects, i in 1:no_objects
        # sqdist(a,b) = (a-b)(a-b) = a^2 + b^2 - 2ab
        @inbounds sqdist = sum_X[i] + sum_X[j]
        for k in 1:no_dims
            @inbounds sqdist += -2 * X[i, k] * X[j, k]
        end
        # The kernel transforms must only be applied once the squared distance
        # has been accumulated over *all* dimensions.
        base = 1 + sqdist / α
        @inbounds Q[i, j] = inv(base)
        @inbounds K[i, j] = base ^ ((α + 1) / -2)
    end

    # Each chunk of triplets gets its own cost slot and gradient buffer, so the
    # loop iterations never write to shared state.
    work_ranges = partition_work(no_triplets)
    nchunks = length(work_ranges)
    ∇Cs = [zeros(Float64, no_objects, no_dims) for _ in 1:nchunks]
    Cs = zeros(Float64, nchunks)
    Threads.@threads for tid in 1:nchunks
        Cs[tid] = thread_kernel(work_ranges[tid], triplets, K, Q, X, ∇Cs[tid],
                                constant, no_dims)
    end

    # Reduce the per-chunk results; the first buffer is reused as the output.
    C += sum(Cs)
    ∇C = ∇Cs[1]
    for i in 2:length(∇Cs)
        ∇C .+= ∇Cs[i]
    end

    for i in 1:no_dims, n in 1:no_objects
        # The accumulated values are negated to obtain the gradient of the
        # cost; the 2λX term is the derivative of the L2 regularization.
        @inbounds ∇C[n, i] = -∇C[n, i] + 2λ * X[n, i]
    end
    return C, ∇C
end
"""
    partition_work(N, k=Threads.nthreads())

Split the index range `1:N` into `k` contiguous, non-overlapping `UnitRange`s
that together cover `1:N` exactly, and return them as a vector of length `k`.

Chunk sizes differ by at most one. When `N < k` some of the returned ranges
are empty; when `N <= 0` all of them are.

# Arguments
- `N`: number of work items.
- `k`: number of chunks; defaults to the number of threads.

# Throws
- `ArgumentError` if `k < 1`.
"""
function partition_work(N, k=Threads.nthreads())
    k >= 1 || throw(ArgumentError("number of chunks must be positive, got $k"))
    return map(1:k) do i
        # Integer breakpoints: chunk i ends at floor(i*N/k) and starts right
        # after the end of chunk i-1, so the chunks tile 1:N without gaps.
        a = div((i - 1) * N, k) + 1
        b = div(i * N, k)
        a:b
    end
end
"""
    thread_kernel(range, triplets, K, Q, X, ∇C, constant, no_dims)

Process the triplets whose row indices are in `range`: sum their negative
log-probabilities and add their contributions to the buffer `∇C` in place.

Row `t` of `triplets` holds three object indices `(a, b, c)`. `K` and `Q` are
indexed by pairs of objects, `X[a, d]` is coordinate `d` of object `a`, and
`constant` scales every gradient contribution.

Returns the summed cost of the processed triplets (`0.0` for an empty range).
"""
function thread_kernel(range, triplets, K, Q, X, ∇C, constant, no_dims)
    cost = 0.0
    @inbounds for t in range
        a = triplets[t, 1]
        b = triplets[t, 2]
        c = triplets[t, 3]

        # p_{ijk} from the lower-right of page 3 of the t-STE paper.
        k_ab = K[a, b]
        k_ac = K[a, c]
        P = k_ab / (k_ab + k_ac)
        cost -= log(P)

        # Per-triplet weights shared by every coordinate.
        w_ab = (1 - P) * Q[a, b]
        w_ac = (1 - P) * Q[a, c]

        # Contribution of *this triplet* to the gradient on its three points.
        for d in 1:no_dims
            pull_b = w_ab * (X[a, d] - X[b, d])
            pull_c = w_ac * (X[a, d] - X[c, d])
            ∇C[a, d] += constant * (pull_c - pull_b)
            ∇C[b, d] += constant * pull_b
            ∇C[c, d] -= constant * pull_c
        end
    end
    return cost
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment