Created
February 18, 2021 14:36
-
-
Save mfalt/3d2804354078f1287523806f8d40d30d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# To add packages in the REPL: "]add BenchmarkTools" | |
using LinearAlgebra, BenchmarkTools | |
# Simple method and syntax, no types! | |
"""
    quadratic(x, Q, y)

Return the bilinear form `x' * Q * y`, evaluated as `dot(x, Q * y)`.
No argument types are required: any `x`, `Q`, `y` for which `*` and `dot`
have methods will work, and Julia compiles a specialization per type combination.
"""
quadratic(x, Q, y) = dot(x, Q * y)
# Random dense inputs; `v` is the value of the bilinear form x' * Q * y.
Q = randn(5, 5)
x = randn(5)
y = randn(5)
v = quadratic(x, Q, y)
# Dense vectors that happen to be one-hot: x = e_4, y = e_3.
Q = randn(5, 5)
x = [0.0, 0.0, 0.0, 1.0, 0.0]
y = [0.0, 0.0, 1.0, 0.0, 0.0]
# Similar to MATLAB
using BenchmarkTools
# Interpolate the globals with `$` so @btime times the call itself rather than
# untyped-global lookup and dynamic dispatch.
@btime quadratic($x, $Q, $y)
# Working, but could be faster | |
# Define our own type | |
# Note: Bool <: Integer (but Bool is not a subtype of Int)
# Subtype/inheritance | |
"""
    OneHotVec(i, length)

A length-`length` vector of `Bool`s that is `true` only at index `i`.
Only the hot index and the length are stored; elements are computed on access.

Throws `ArgumentError` unless `1 <= i <= length`.
"""
struct OneHotVec <: AbstractVector{Bool}
    i::Int
    length::Int
    function OneHotVec(i, length)
        # An out-of-range hot index would make getindex report an all-false vector
        # while code that reads `y.i` directly (e.g. `Q[:, y.i]`) still uses it.
        1 <= i <= length ||
            throw(ArgumentError("hot index $i is outside 1:$length"))
        return new(i, length)
    end
end
# Basic properties of a vector (the minimal AbstractArray interface)
Base.size(y::OneHotVec) = (y.length,)
# @inline lets callers using @inbounds elide the @boundscheck block.
@inline function Base.getindex(y::OneHotVec, i::Int)
    @boundscheck checkbounds(y, i)
    return i == y.i
end
# One-hot counterparts of the dense x and y above (length 5, hot at 4 and 3).
xv = OneHotVec(4, 5)
yv = OneHotVec(3, 5)
quadratic(xv, Q, yv)
# But is it slow? (The generic Q*y fallback cannot use BLAS for a OneHotVec.)
# `$` interpolation keeps global-variable overhead out of the timings.
@btime quadratic($x, $Q, $y)
@btime quadratic($xv, $Q, $yv)
# Fast, why!? | |
# We can improve! | |
# Q * e_i is just column i of Q, so return (a copy of) that column instead of
# doing a full matrix-vector product. The size check makes a mismatched Q fail
# with DimensionMismatch, as generic `*` does, instead of silently returning a column.
function Base.:*(Q::AbstractMatrix, y::OneHotVec)
    size(Q, 2) == y.length || throw(DimensionMismatch(
        "matrix has $(size(Q, 2)) columns but vector has length $(y.length)"))
    return Q[:, y.i]
end
quadratic(xv, Q, yv)
# What is happening here?
# Julia picks the most specific method (multiple dispatch)
# and compiles it for the given input types.
@btime quadratic($xv, $Q, $yv)
# Even better! How would you do this in OOP? | |
# Bilinear form of two one-hot vectors: e_i' * Q * e_j is the single entry Q[i, j],
# so no matrix-vector product is needed.
function quadratic(x::OneHotVec, Q::AbstractMatrix, y::OneHotVec)
    # Fail like the generic method does on a size mismatch, instead of silently
    # returning an entry of a wrongly-sized Q.
    size(Q) == (x.length, y.length) || throw(DimensionMismatch(
        "need a $(x.length)×$(y.length) matrix, got size $(size(Q))"))
    return Q[x.i, y.i]
end
# Operators that are not AbstractMatrix (e.g. `I`) have no size to check; just index.
quadratic(x::OneHotVec, Q, y::OneHotVec) = Q[x.i, y.i]
quadratic(xv, Q, yv)
@btime quadratic($xv, $Q, $yv)
# But what about cases we didn't cover? (dense x, one-hot y)
quadratic(x, Q, yv)
@btime quadratic($x, $Q, $yv)
# Bigger case | |
using BenchmarkTools
Q = randn(1000, 1000)
# Dense one-hot vectors matching xv and yv below.
x = zeros(1000); x[123] = 1;
y = zeros(1000); y[321] = 1;  # set y's hot entry (not x's) so y matches yv
xv = OneHotVec(123, 1000)
yv = OneHotVec(321, 1000)
# All three calls evaluate Q[123, 321]; only the speed differs.
@btime quadratic($x, $Q, $y)
@btime quadratic($x, $Q, $yv)
@btime quadratic($xv, $Q, $yv)
# What about other types?
using SparseArrays
QSparse = sprandn(ComplexF64, 1000, 1000, 0.01)
@btime quadratic($x, $QSparse, $y)
@btime quadratic($x, $QSparse, $yv)
@btime quadratic($xv, $QSparse, $yv)
# Recap of the key definitions:
# Generic implementation | |
# Generic implementation: the bilinear form x' * Q * y evaluated as dot(x, Q * y).
# It relies on whatever `*` and `dot` methods exist for the argument types.
quadratic(x, Q, y) = dot(x, Q * y)
# Fast Matrix*OneHotVec | |
# Fast Matrix * OneHotVec: Q * e_i is column i of Q (returned as a copy).
# A size mismatch throws DimensionMismatch, as generic `*` does.
function Base.:*(Q::AbstractMatrix, y::OneHotVec)
    size(Q, 2) == y.length || throw(DimensionMismatch(
        "matrix has $(size(Q, 2)) columns but vector has length $(y.length)"))
    return Q[:, y.i]
end
# Really fast quadratic | |
# Really fast quadratic: e_i' * Q * e_j is the single entry Q[i, j].
function quadratic(x::OneHotVec, Q::AbstractMatrix, y::OneHotVec)
    # Fail like the generic method does on a size mismatch.
    size(Q) == (x.length, y.length) || throw(DimensionMismatch(
        "need a $(x.length)×$(y.length) matrix, got size $(size(Q))"))
    return Q[x.i, y.i]
end
# Operators that are not AbstractMatrix (e.g. `I`) have no size to check; just index.
quadratic(x::OneHotVec, Q, y::OneHotVec) = Q[x.i, y.i]
# Automatic Differentiation slide | |
# Let's test some more
using ForwardDiff | |
# Dual number 1 + 1ϵ (with ϵ^2 = 0): value 1.0, derivative part 1.0.
d = ForwardDiff.Dual(1.0, 1.0)
2d              # 2 + 2ϵ: the derivative part is scaled too
(sin(1), cos(1))
sin(d)          # sin(1) + cos(1)ϵ: the derivative of sin comes out automatically
(d + 3)^2       # (4 + ϵ)^2 = 16 + 8ϵ
# Back to the quadratic | |
x = randn(5)
# Scalar function of Q alone (x and the one-hot e_3 are fixed), so it can be
# differentiated with respect to every entry of Q.
function quad(Q)
    e3 = OneHotVec(3, 5)
    return quadratic(x, Q, e3)
end
Q = randn(5, 5)
ForwardDiff.gradient(quad, Q)
# It works! | |
# Even more complicated | |
# Entry (3, 3) of Q^5, written as e_3' * Q^5 * e_3.
function quad5(Q)
    e3 = OneHotVec(3, 5)
    return quadratic(e3, Q^5, e3)
end
ForwardDiff.gradient(quad5, Q)
# # Try manually | |
# epsM = zeros(5,5); epsM[1,1] = 1e-8 | |
# (quad5(Q+epsM)- quad5(Q))./1e-8 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment