Last active
May 11, 2017 23:02
-
-
Save nbecker/1c02aedd99ed5bfa72bf9284b108f57b to your computer and use it in GitHub Desktop.
Histogram in Julia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Histogram

"""
    range_t{flt_t}(start, step, stop)
    range_t(start, step, stop)

Binning interval `[start, stop]` with bin spacing `step`, stored as the triple
`(start, step, stop)`.

Throws `ArgumentError` unless `step > 0` and `stop >= start`.
"""
struct range_t{flt_t<:Real}
    start::flt_t
    step::flt_t
    stop::flt_t
    function range_t{flt_t}(start::flt_t, step::flt_t, stop::flt_t) where flt_t<:Real
        # A non-positive step or an inverted interval makes the bin count
        # computed by `histogram` meaningless (division by zero / negative size).
        step > zero(step) ||
            throw(ArgumentError("range_t: step must be positive, got $step"))
        stop >= start ||
            throw(ArgumentError("range_t: stop ($stop) must not be less than start ($start)"))
        return new(start, step, stop)
    end
end

range_t(start::flt_t, step::flt_t, stop::flt_t) where flt_t<:Real =
    range_t{flt_t}(start, step, stop)

# NOTE: this defines a module-local function named `convert`; it is NOT a method
# of `Base.convert` (which takes a type and one value). It only forwards to the
# `range_t` constructor and is kept because callers import it by name.
convert(::Type{range_t{flt_t}}, start::flt_t, step::flt_t, stop::flt_t) where flt_t<:Real =
    range_t{flt_t}(start, step, stop)

"""
    histogram{flt_t,cnt_t,N}

`N`-dimensional histogram. `ranges[d]` describes the bins along axis `d`,
`buckets` holds the counts (element type `cnt_t`, one array dimension per axis),
and `clip` selects whether out-of-range samples are clamped into the edge bins
(`true`) or rejected with a `DomainError` (`false`).
"""
struct histogram{flt_t<:Real,cnt_t<:Integer,N}
    ranges::NTuple{N,range_t{flt_t}}
    clip::Bool
    buckets::Array{cnt_t,N}
end

# Number of bins along one axis. `round` (instead of `Int(...)`) tolerates float
# noise such as (0.3 - 0.0) / 0.1 == 2.9999999999999996, which made `Int` throw.
_nbins(r::range_t) = round(Int, (r.stop - r.start) / r.step) + 1

# Construct a zeroed histogram from a sequence of range_t, one per axis.
# `N` is taken from the method signature so the result type is inferable.
function histogram{flt_t,cnt_t}(rng::Vararg{range_t{flt_t},N};
                                clip::Bool=true) where {flt_t<:Real,cnt_t<:Integer,N}
    return histogram{flt_t,cnt_t,N}(rng, clip, zeros(cnt_t, map(_nbins, rng)))
end

# Default counter type Int64; the coordinate type follows the ranges.
histogram(rng::range_t{flt_t}...; clip::Bool=true) where flt_t<:Real =
    histogram{flt_t,Int64}(rng...; clip=clip)

# Construct from plain `(start, step, stop)` tuples, one per axis.
function histogram{flt_t,cnt_t}(r::Tuple{flt_t,flt_t,flt_t}...;
                                clip::Bool=true) where {flt_t<:Real,cnt_t<:Integer}
    return histogram{flt_t,cnt_t}(map(t -> range_t(t...), r)...; clip=clip)
end

# Tuple form with the default Int64 counter type.
histogram(r::Tuple{flt_t,flt_t,flt_t}...; clip::Bool=true) where flt_t<:Real =
    histogram{flt_t,Int64}(r...; clip=clip)

"""
    apply(rng, x, clip)

Return the coordinate `x` restricted to `[rng.start, rng.stop]`. When `clip`
is true, out-of-range values are replaced by the nearest endpoint; otherwise a
`DomainError` is thrown. NaN is always rejected with a `DomainError`, since it
cannot be assigned to any bin.
"""
function apply(rng::range_t, x::flt_t, clip::Bool) where flt_t<:Real
    isnan(x) && throw(DomainError(x, "cannot bin a NaN coordinate"))
    if x > rng.stop
        clip || throw(DomainError(x, "coordinate is above range stop $(rng.stop)"))
        return rng.stop
    elseif x < rng.start
        clip || throw(DomainError(x, "coordinate is below range start $(rng.start)"))
        return rng.start
    end
    return x
end

"""
    fit!(h, xs...; weight=1)

Add `weight` to the bin containing the single point `xs` (one coordinate per
axis of `h`) and return `h`. Throws `DomainError` if the number of coordinates
is not `N`, or (see `apply`) for rejected coordinates.
"""
function fit!(h::histogram{flt_t,cnt_t,N}, xs::flt_t...;
              weight::Integer=1) where {flt_t<:Real,cnt_t<:Integer,N}
    length(xs) == N ||
        throw(DomainError(xs, "expected $N coordinates, got $(length(xs))"))
    # Build one 1-based index per axis as a tuple: splatting it addresses a
    # single N-dimensional bin (indexing with a Vector would instead pick
    # several bins by linear index).
    idx = map(h.ranges, xs) do r, x
        round(Int, (apply(r, x, h.clip) - r.start) / r.step) + 1
    end
    h.buckets[idx...] += weight
    return h
end

# Internal: same as `fit!` with the weight moved to a positional slot so that
# it can take part in the broadcast below.
_fit!(h::histogram{flt_t,cnt_t,N}, w::Integer,
      xs::flt_t...) where {flt_t<:Real,cnt_t<:Integer,N} = fit!(h, xs...; weight=w)

"""
    fit!(h, z...; weight=1)

Update `h` from a collection of points: `z[d]` holds the coordinates along
axis `d`, and the arrays are broadcast together. `weight` is broadcast
alongside them, so it may be a scalar or an array matching the points.
Returns `h`.
"""
function fit!(h::histogram{flt_t,cnt_t,N}, z::AbstractArray{flt_t}...;
              weight=1) where {flt_t<:Real,cnt_t<:Integer,N}
    # `Ref(h)` makes the histogram a broadcast scalar rather than an iterable.
    _fit!.(Ref(h), weight, z...)
    return h
end

# `h + z` is an alias for `fit!(h, z)`: it mutates `h` and returns it.
function Base.:+(h::histogram{flt_t,cnt_t,N}, z) where {flt_t<:Real,cnt_t<:Integer,N}
    fit!(h, z)
    return h
end

"""
    axes(h)

Return a `Vector` holding `start:step:stop` for each axis of `h`.

NOTE(review): when `stop - start` is not a whole multiple of `step`, the
length of these ranges may differ from `size(h.buckets)` — confirm before
relying on them as bin coordinates.
"""
function axes(h::histogram{flt_t,cnt_t,N}) where {flt_t<:Real,cnt_t<:Integer,N}
    return [r.start:r.step:r.stop for r in h.ranges]
end

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Histogram: histogram, fit!, range_t, convert | |
# 1-D histograms with ranges from 0 to 1 in steps of 0.01, built either from an
# explicit range_t or from a plain (start, step, stop) tuple.
h = histogram(range_t(0.0, 0.01, 1.0))
h = histogram((0.0, 0.01, 1.0))

# `+` is an alias for `fit!`. With the default clip=true, 2.0 lies above the
# range stop and is clamped into the last bin.
h += 2.0
fit!(h, 2.0)
fit!(h, zeros(10))

# 2-D histogram: one (start, step, stop) tuple per axis.
h3 = histogram((0.0, 0.01, 1.0), (0.0, 0.01, 1.0))
fit!(h3, 2.0, 2.0)

# Each column of `u` is one 2-D point; row 1 holds x values, row 2 y values.
u = reshape(collect(1.0:6.0), 2, 3)
fit!(h3, u[1, :], u[2, :])
fit!(h3, u[1, :], u[2, :]; weight = fill(2, 3))
Thanks for the suggestions. I didn't use a range because I need (start, step, stop), while a range stores (start, step, length). It seems inefficient to keep converting back and forth.
Yes, `buckets` is supposed to be N-dimensional. The previous version of this code didn't actually work for N != 1. This version does.
You should give `ranges` a type parameter, e.g.:
# Suggested layout: parameterize the type of `ranges` (T) so the field is
# concrete for whatever tuple of range objects is stored, e.g.
# `typeof((1:3, 1:2.0:27))`.
struct histogram{flt_t<:Real,cnt_t<:Integer,N,T<:Tuple}
    ranges::T                   # one range object per axis
    clip::Bool                  # clamp out-of-range samples instead of rejecting them
    buckets::Array{cnt_t,N}     # bin counts, one array dimension per axis
end
See this post for more info: https://discourse.julialang.org/t/undefreferror-access-to-undefined-reference/3526/6
The updated version has its own `range_t` type to represent ranges. Maybe this is better?
Looks better. Try `@code_warntype fit!(h3, u[1,:], u[2,:])` to look for any type instability.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Why not just use ranges for ranges? E.g. `rng::T`, where `T` in one case could be `typeof((1:3, 1:2.0:27))`.
Also, it's common in Julia to use a single capital letter for a type parameter.
Do you really want `buckets` to be an `N`-dimensional array? You should probably include some test/example code in the same gist.