Skip to content

Instantly share code, notes, and snippets.

@antoine-levitt
Created October 25, 2017 05:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save antoine-levitt/413f3251e50c18f9544ae26ad5ea651f to your computer and use it in GitHub Desktop.
Save antoine-levitt/413f3251e50c18f9544ae26ad5ea651f to your computer and use it in GitHub Desktop.
using MacroTools:postwalk,@capture
"""
    @with_precision T expr

Rewrite `expr` so that its floating-point values are built in the type `T`
(for instance `BigFloat`) instead of the default float type.

The rewrite is purely syntactic and covers:

- floating-point literals, which are re-parsed in `T` from their printed form;
- the names `pi`, `e`, `γ`, `catalan`, `φ`, `Inf` and `NaN`;
- `a/b` where both operands are integer literals;
- calls to `ones`, `zeros`, `rand`, `randn` and `eye`, which receive `T` as first
  argument unless their first argument is already a type.

All conversions to `T` happen when the rewritten code runs. `T` itself is evaluated
when the macro is expanded, so it cannot be a local variable or a type parameter of an
enclosing function. Everything else in `expr` is left untouched and runs in the
caller's scope.

NOTE(review): the listed names are replaced wherever they occur, including places
where they name a user variable (e.g. the left-hand side of an assignment); avoid
reusing them inside `expr`.
"""
macro with_precision(flt_type, ex)
    # The target type is needed while rewriting, so it is resolved at expansion time.
    T = eval(flt_type)
    known_constants = (:pi, :e, :γ, :catalan, :φ, :Inf, :NaN)
    array_builders = (:ones, :zeros, :rand, :randn, :eye)

    # True for `a/b` where both operands are integer literals.
    function is_integer_ratio(call)
        return length(call.args) == 3 && call.args[1] == :/ &&
               isa(call.args[2], Integer) && isa(call.args[3], Integer)
    end

    # True for a keyword argument, in either the `f(x, k=v)` or the `f(x; k=v)` form.
    is_keyword(arg) = isa(arg, Expr) && (arg.head == :kw || arg.head == :parameters)

    # True for a splatted argument `xs...`.
    is_splat(arg) = isa(arg, Expr) && arg.head == Symbol("...")

    # Give an array-builder call `T` as element type unless the caller already chose one.
    function retype_builder(call)
        f = call.args[1]
        args = call.args[2:end]
        if any(is_keyword, args)
            # The rewrite is positional; calls carrying keywords are left alone.
            return call
        elseif isempty(args)
            # `rand()`-style call: there is no first argument to inspect.
            return :($f($T))
        elseif isa(args[1], Number)
            # A literal dimension can never be a type, so no run-time check is needed.
            return :($f($T, $(args...)))
        elseif is_splat(args[1])
            # Number and kind of arguments are only known at run time: collect them
            # once and look at the first one, if there is any.
            packed = gensym(:args)
            return quote
                let $packed = ($(args...),)
                    if !isempty($(packed)) && isa($(packed)[1], Type)
                        $f($(packed)...)
                    else
                        $f($T, $(packed)...)
                    end
                end
            end
        else
            # Evaluate the first argument exactly once, then decide at run time whether
            # it already is an element type (`Type` also covers e.g. `Complex`).
            first_arg = gensym(:arg)
            rest = args[2:end]
            return quote
                let $first_arg = $(args[1])
                    if isa($first_arg, Type)
                        $f($first_arg, $(rest...))
                    else
                        $f($T, $first_arg, $(rest...))
                    end
                end
            end
        end
    end

    # Rewrite a single node; `postwalk` hands over nodes whose children are already done.
    function repl(x)
        if isa(x, AbstractFloat)
            # Re-parse the printed form of the literal in `T` at run time.
            return :(parse($T, $(string(x))))
        elseif isa(x, Irrational)
            return :($T($x))
        elseif isa(x, Symbol) && x in known_constants
            # Embed the constant itself and convert it at run time, like the literals.
            return :($T($(eval(x))))
        elseif isa(x, Expr) && x.head == :call
            if is_integer_ratio(x)
                # 1/2 -> T(1)/T(2)
                return :($T($(x.args[2])) / $T($(x.args[3])))
            elseif x.args[1] in array_builders
                return retype_builder(x)
            end
        end
        return x
    end

    return esc(postwalk(repl, ex))
end
# Demo of `@with_precision`: each statement exercises one rewrite rule of the macro.
@with_precision BigFloat begin
    # Floating-point literals are re-parsed as BigFloat.
    a = 1e-2
    println(typeof(a))
    a = 3.2
    println(typeof(a))

    # A ratio of two integer literals is computed in BigFloat.
    a = 2/3
    println(typeof(a))

    # `pi` is one of the names the macro converts.
    a = pi+1
    println(typeof(a))
    a = 2/3+1
    println(typeof(a))

    # Each literal inside an array literal is converted individually.
    x = [1., 2.]
    println(x)

    # Array builders get BigFloat as element type ...
    x = ones(2,2)
    println(x)
    # ... unless an element type is already given.
    x = ones(Float64,2,2)
    println(x)

    # `Inf` and `NaN` are converted as well.
    x = Inf
    println(typeof(x))
    x = NaN
    println(typeof(x))

    # Integer literals and ordinary function calls are not rewritten.
    x = sin(1)
    println(typeof(x))

    # The "is the first argument a type?" check happens at run time, so a type held in
    # a variable is respected too.
    cust_type = Float32
    println(randn(cust_type,2))

    # Same check with a type parameter; `where` replaces the deprecated `f{T}(...)` form.
    function f(x::T) where {T}
        return ones(T,2,2)
    end
    println(typeof(f(Float32(2.))))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment