# Gist by @mfalt, created February 18, 2021 14:36
# To add the packages in the REPL: "]add BenchmarkTools ForwardDiff"
using LinearAlgebra, BenchmarkTools
# A simple generic method: no type annotations needed!
# Computes the quadratic form x' * Q * y (`dot` conjugates x when it is complex)
function quadratic(x,Q,y)
    dot(x,Q*y)
end
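# Sketch (not part of the original gist): `quadratic` evaluates the double sum
#     sum over i, j of conj(x[i]) * Q[i,j] * y[j],
# and because it has no type annotations the same code runs on Int, Float64 or Complex inputs.
# `quadratic_generic` mirrors `quadratic`; `quadratic_loops` is a hypothetical explicit-loop
# reference, assuming ordinary 1-based arrays, used only for comparison.
using LinearAlgebra
quadratic_generic(x, Q, y) = dot(x, Q * y)
function quadratic_loops(x, Q, y)
    s = zero(promote_type(eltype(x), eltype(Q), eltype(y)))
    for i in eachindex(x), j in eachindex(y)
        s += conj(x[i]) * Q[i, j] * y[j]   # conj matches dot's treatment of its first argument
    end
    return s
end
quadratic_generic([1, 0], [1 2; 3 4], [0, 1])   # 2, i.e. Q[1,2], computed in Int arithmetic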
Q = randn(5,5)
x = randn(5)
y = randn(5)
v = quadratic(x,Q,y)
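Q = randn(5,5)
# One-hot (standard basis) vectors x = e₄ and y = e₃, so x' * Q * y == Q[4,3]
x = [0.0, 0.0, 0.0, 1.0, 0.0]
y = [0.0, 0.0, 1.0, 0.0, 0.0]
# So far, this looks much like MATLAB
using BenchmarkTools
# `$` interpolates the global variables so that only the call itself is timed
@btime quadratic($x, $Q, $y)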
# Working, but could be faster
# Define our own type
# `<:` declares a subtype (Julia's counterpart to inheritance); e.g. Bool <: Integer is true
struct OneHotVec <: AbstractVector{Bool}
    i::Int       # position of the single `true` entry
    length::Int  # length of the vector
end
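# Basic properties of a vector: size and indexing (the AbstractVector interface)
Base.size(y::OneHotVec) = (y.length,)
function Base.getindex(y::OneHotVec, i::Int)
    @boundscheck checkbounds(y, i)  # throw a BoundsError for out-of-range i
    return i == y.i
end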
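# Sketch (not part of the original gist): `size` and `getindex` are the only required methods of
# Julia's AbstractArray interface, so iteration, `collect`, `sum`, `findfirst` and broadcasting
# all work on the new type without further code. `OneHotDemo` is a hypothetical copy of
# OneHotVec so that this block runs on its own.
struct OneHotDemo <: AbstractVector{Bool}
    i::Int
    n::Int
end
Base.size(v::OneHotDemo) = (v.n,)
Base.getindex(v::OneHotDemo, k::Int) = k == v.i
vd = OneHotDemo(2, 4)
collect(vd)     # [false, true, false, false]
sum(vd)         # 1
findfirst(vd)   # 2
vd .+ 1         # [1, 2, 1, 1]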
xv = OneHotVec(4,5)
yv = OneHotVec(3,5)
quadratic(xv, Q, yv)
# But is it slow? (Q*yv now goes through generic Julia code instead of BLAS)
@btime quadratic($x, $Q, $y)
@btime quadratic($xv, $Q, $yv)
# Still fast. Why?
# We can improve! Q times the one-hot vector e_i is just column i of Q
Base.:*(Q::AbstractMatrix, y::OneHotVec) = Q[:,y.i]
quadratic(xv, Q, yv)
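# Sketch of a design alternative (not what the gist does): `Q[:, i]` copies column i into a new
# Vector, whereas returning `view(Q, :, i)` avoids the copy at the price of aliasing Q (later
# changes to Q show up in the result). `OneHotCol` is a hypothetical stand-in for OneHotVec so
# this block runs alone; the dimension check is an added assumption, not in the original.
struct OneHotCol <: AbstractVector{Bool}
    i::Int
    n::Int
end
Base.size(e::OneHotCol) = (e.n,)
Base.getindex(e::OneHotCol, k::Int) = k == e.i
function Base.:*(Q::AbstractMatrix, e::OneHotCol)
    size(Q, 2) == e.n || throw(DimensionMismatch("Q has $(size(Q, 2)) columns but the vector has length $(e.n)"))
    return view(Q, :, e.i)   # no copy: the result aliases column i of Q
end
Qc = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
Qc * OneHotCol(2, 3)   # [2.0, 5.0, 8.0], same values as the generic Qc * collect(OneHotCol(2, 3))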
# What is happening here?
# Julia picks the most specific method for the argument types (multiple dispatch)
# and compiles specialized code for the given input types
@btime quadratic($xv, $Q, $yv)
# Even better! How would you do this in OOP?
quadratic(x::OneHotVec, Q, y::OneHotVec) = Q[x.i, y.i]
quadratic(xv, Q, yv)
@btime quadratic($xv, $Q, $yv)
# But what about cases we didn't cover?
quadratic(x, Q, yv)
@btime quadratic($x, $Q, $yv)
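# Sketch (not part of the original gist): Julia chooses a method from the types of *all*
# arguments and runs the most specific applicable one; uncovered combinations fall back to
# the generic method. `dispatch_tag` and `OneHotTag` are hypothetical; each method only
# reports which one ran, mirroring the cases above.
struct OneHotTag
    i::Int
end
dispatch_tag(x, Q, y) = :generic
dispatch_tag(x, Q, y::OneHotTag) = :onehot_right
dispatch_tag(x::OneHotTag, Q, y::OneHotTag) = :onehot_both
Qd = zeros(2, 2)
dispatch_tag([1.0, 0.0], Qd, [0.0, 1.0])       # :generic
dispatch_tag([1.0, 0.0], Qd, OneHotTag(2))     # :onehot_right
dispatch_tag(OneHotTag(1), Qd, OneHotTag(2))   # :onehot_both
dispatch_tag(OneHotTag(1), Qd, [0.0, 1.0])     # :generic (no method specializes on x alone)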
# Bigger case
using BenchmarkTools
Q = randn(1000,1000)
x = zeros(1000); x[123] = 1;
y = zeros(1000); y[321] = 1;
xv = OneHotVec(123,1000)
yv = OneHotVec(321,1000)
@btime quadratic($x, $Q, $y)
@btime quadratic($x, $Q, $yv)
@btime quadratic($xv, $Q, $yv)
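# Sketch (not part of the original gist): when x has a single 1 at index 123 and y a single 1
# at index 321, the dense, mixed and one-hot code paths all evaluate Q[123, 321]; they differ
# only in the work done (roughly n^2 multiply-adds, an n-element column copy, one lookup).
# `OneHotBig` and `quad_generic` are hypothetical stand-ins so this block runs alone.
using LinearAlgebra
struct OneHotBig <: AbstractVector{Bool}
    i::Int
    n::Int
end
Base.size(v::OneHotBig) = (v.n,)
Base.getindex(v::OneHotBig, k::Int) = k == v.i
quad_generic(x, Q, y) = dot(x, Q * y)                       # dense matrix-vector product
Base.:*(Q::AbstractMatrix, y::OneHotBig) = Q[:, y.i]        # copies one column
quad_generic(x::OneHotBig, Q, y::OneHotBig) = Q[x.i, y.i]   # single lookup
n = 1000
Qb = randn(n, n)
xd = zeros(n); xd[123] = 1
yd = zeros(n); yd[321] = 1
r_dense = quad_generic(xd, Qb, yd)
r_mixed = quad_generic(xd, Qb, OneHotBig(321, n))
r_onehot = quad_generic(OneHotBig(123, n), Qb, OneHotBig(321, n))
(r_dense, r_mixed, r_onehot, Qb[123, 321])   # all four should agree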
# What about other types?
using SparseArrays
QSparse = sprandn(ComplexF64, 1000, 1000, 0.01)
@btime quadratic($x, $QSparse, $y)
@btime quadratic($x, $QSparse, $yv)
@btime quadratic($xv, $QSparse, $yv)
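# Sketch (not part of the original gist): the generic code needs no changes for a complex
# sparse Q. Because `dot` conjugates its first argument, a complex x gives x' * Q * y rather
# than transpose(x) * Q * y. `qform` and the small matrix `S` are hypothetical examples.
using LinearAlgebra, SparseArrays
qform(x, Q, y) = dot(x, Q * y)
# 3x3 sparse matrix with two stored entries: S[1,2] = 1+2im and S[2,3] = 3-1im
S = sparse([1, 2], [2, 3], [1.0 + 2.0im, 3.0 - 1.0im], 3, 3)
qform([1.0, 0.0, 0.0], S, [0.0, 1.0, 0.0])           # 1.0 + 2.0im, i.e. S[1,2]
qform([1.0im, 0.0, 0.0], S, [0.0, 1.0, 0.0])         # conj(im) * S[1,2] = 2.0 - 1.0im
transpose([1.0im, 0.0, 0.0]) * (S * [0.0, 1.0, 0.0]) # no conjugation: -2.0 + 1.0im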
# REVIEW:
# Generic implementation
function quadratic(x,Q,y)
    dot(x,Q*y)
end
# Fast Matrix*OneHotVec
Base.:*(Q::AbstractMatrix, y::OneHotVec) = Q[:,y.i]
# Really fast quadratic
quadratic(x::OneHotVec, Q, y::OneHotVec) = Q[x.i, y.i]
# Automatic Differentiation slide
# Let's test some more
using ForwardDiff
d = ForwardDiff.Dual(1.0, 1.0) # (1 + 1*ϵ)
2*d
sin(1), cos(1)  # the value and derivative that sin(d) should carry
sin(d)
(d+3)^2         # (4 + ϵ)^2 = 16 + 8ϵ
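# Sketch (not ForwardDiff's actual implementation): a dual number a + b*ϵ with ϵ^2 = 0
# satisfies f(a + b*ϵ) = f(a) + f'(a)*b*ϵ, so carrying the ϵ-part through every operation
# applies the chain rule automatically. `MyDual` is a hypothetical minimal type that supports
# only the operations used above.
struct MyDual
    val::Float64   # a: the value
    der::Float64   # b: the coefficient of ϵ, i.e. the derivative
end
Base.:+(u::MyDual, c::Real) = MyDual(u.val + c, u.der)          # constants have zero derivative
Base.:*(c::Real, u::MyDual) = MyDual(c * u.val, c * u.der)
Base.:*(u::MyDual, w::MyDual) = MyDual(u.val * w.val, u.val * w.der + u.der * w.val)  # product rule
Base.:^(u::MyDual, n::Integer) =
    n == 0 ? MyDual(1.0, 0.0) : MyDual(u.val^n, n * u.val^(n - 1) * u.der)       # power rule
Base.sin(u::MyDual) = MyDual(sin(u.val), cos(u.val) * u.der)     # chain rule
dm = MyDual(1.0, 1.0)   # 1 + 1*ϵ, analogous to ForwardDiff.Dual(1.0, 1.0)
2 * dm                  # MyDual(2.0, 2.0)
sin(dm)                 # MyDual(sin(1), cos(1))
(dm + 3)^2              # MyDual(16.0, 8.0): value 16, derivative 2*(1+3) = 8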
# Back to the quadratic
x = randn(5)
quad(Q) = quadratic(x,Q,OneHotVec(3,5))
Q = randn(5,5)
ForwardDiff.gradient(quad, Q)
# It works!
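# Sketch (not part of the original gist): quad(Q) = x' * Q * e₃ is linear in Q, so its gradient
# is the outer product x * e₃', i.e. zero except column 3, which equals x. A central finite
# difference reproduces this without ForwardDiff; `fd_gradient`, `xg`, `e3` and `quad_check`
# are hypothetical names introduced for this check.
using LinearAlgebra
function fd_gradient(f, Q; h = 1e-6)
    G = zeros(size(Q))
    for k in eachindex(Q)
        E = zeros(size(Q)); E[k] = h         # perturb one entry at a time
        G[k] = (f(Q + E) - f(Q - E)) / (2h)  # central difference
    end
    return G
end
xg = randn(5)
e3 = [0.0, 0.0, 1.0, 0.0, 0.0]
quad_check(Q) = dot(xg, Q * e3)
Qg = randn(5, 5)
isapprox(fd_gradient(quad_check, Qg), xg * e3'; atol = 1e-6)   # expected: true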
# Even more complicated
quad5(Q) = quadratic(OneHotVec(3,5),Q^5,OneHotVec(3,5))
ForwardDiff.gradient(quad5, Q)
# Manual check: a forward finite difference approximates only the [1,1] entry of the gradient
# epsM = zeros(5,5); epsM[1,1] = 1e-8
# (quad5(Q + epsM) - quad5(Q)) / 1e-8
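# Sketch (not part of the original gist): for f(Q) = (Q^5)[3,3] = e₃' * Q^5 * e₃, the product rule
# d(Q^5) = sum over k = 0..4 of Q^k * dQ * Q^(4-k) gives the gradient
#     ∇f(Q) = sum over k = 0..4 of ((Q^k)' * e₃) * (Q^(4-k) * e₃)'
# The code compares this formula with central finite differences over every entry, extending the
# one-entry manual check above. `grad_quad5` and `fd_gradient5` are hypothetical; n >= 3 is assumed.
using LinearAlgebra
function grad_quad5(Q)
    n = size(Q, 1)
    e = zeros(n); e[3] = 1.0
    P = [Matrix{Float64}(I, n, n)]   # P[k+1] holds Q^k
    for k in 1:4
        push!(P, P[end] * Q)
    end
    G = zeros(n, n)
    for k in 0:4
        G .+= (P[k+1]' * e) * (P[5-k] * e)'   # ((Q^k)' e₃) times (Q^(4-k) e₃)'
    end
    return G
end
function fd_gradient5(f, Q; h = 1e-6)
    G = zeros(size(Q))
    for k in eachindex(Q)
        E = zeros(size(Q)); E[k] = h
        G[k] = (f(Q + E) - f(Q - E)) / (2h)
    end
    return G
end
f5(Q) = (Q^5)[3, 3]
Q5 = randn(5, 5) / 3
isapprox(grad_quad5(Q5), fd_gradient5(f5, Q5); rtol = 1e-5)   # expected: true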