
@tanmaykm
Last active August 29, 2015 14:02
sort performance
elapsed time: 0.103189028 seconds (0 bytes allocated)
elapsed time: 0.000888531 seconds (0 bytes allocated)
elapsed time: 0.001684563 seconds (0 bytes allocated)
elapsed time: 0.103114518 seconds (0 bytes allocated)
elapsed time: 0.001014195 seconds (0 bytes allocated)
elapsed time: 0.001589316 seconds (0 bytes allocated)
elapsed time: 0.103195119 seconds (0 bytes allocated)
elapsed time: 0.000961692 seconds (0 bytes allocated)
elapsed time: 0.001577562 seconds (0 bytes allocated)
using Base.Order.Forward
using Base.Sort

# binary (method call)
function f1(rowvalA)
    r2 = length(rowvalA)
    r1 = 1
    for i in 1:10^6
        idx = searchsortedfirst(rowvalA, i, r1, r2, Forward)
        (idx > 0) && (r1 = idx)
    end
end

# binary
function f2(rowvalA)
    last = length(rowvalA)
    for i in 1:10^6
        ridx = 1
        @inbounds while ridx <= last
            mid = (ridx + last) >> 1
            midval = int(rowvalA[mid])
            if midval > i
                last = mid - 1
            elseif midval == i
                ridx = mid
                break
            else
                ridx = mid + 1
            end
        end
    end
end

# linear (with start pointer being advanced)
function f3(rowvalA)
    last = length(rowvalA)
    ridx = 1
    for i in 1:10^6
        @inbounds while (ridx <= last)
            (rowvalA[ridx] >= i) && break
            ridx += 1
        end
    end
end

function f()
    rowvalA = sort(randperm(10^8)[1:10^6])
    @time f1(rowvalA)
    @time f2(rowvalA)
    @time f3(rowvalA)
end

for i in 1:3
    f()
end
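The linear strategy in f3 relies on the needles (1:10^6) arriving in increasing order, so the pointer never has to move backwards. Below is a minimal sketch of the same idea, assuming both the haystack and the needles are sorted ascending. It is written for Julia 1.x rather than the 0.3-era syntax above and returns the found indices so they can be checked against `searchsortedfirst`. The name `linear_advance` is hypothetical.

```julia
# Linear scan with an advancing start pointer (the f3 strategy).
# For each needle, returns the index of the first element >= needle,
# or lastindex(hay) + 1 if every element is smaller.
# Assumption: `hay` and `needles` are both sorted in increasing order.
function linear_advance(hay::AbstractVector, needles::AbstractVector)
    out = Vector{Int}(undef, length(needles))
    ridx = firstindex(hay)
    last = lastindex(hay)
    for (k, x) in enumerate(needles)
        # Move forward until hay[ridx] >= x; the pointer never moves back.
        while ridx <= last && hay[ridx] < x
            ridx += 1
        end
        out[k] = ridx
    end
    return out
end

hay = [2, 3, 5, 8, 13, 21]
println(linear_advance(hay, 1:25) == [searchsortedfirst(hay, x) for x in 1:25])
```

Over all needles the pointer advances at most `length(hay)` times, so the total work is linear in the combined input size rather than logarithmic per needle.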
@mauro3

mauro3 commented Jun 6, 2014

I had a look at this because of the huge performance difference. I think there were a couple of bugs. With them corrected, I get slightly better performance from searchsortedfirst, probably because it does only one comparison per loop iteration:

elapsed time: 0.00099055 seconds (0 bytes allocated)
elapsed time: 0.001322109 seconds (0 bytes allocated)
elapsed time: 0.000988963 seconds (0 bytes allocated)
elapsed time: 0.001316261 seconds (0 bytes allocated)
...
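The "one comparison per loop iteration" point can be illustrated with two sketches (not Base's actual source): a lower-bound search that does a single `<` per iteration, and the three-way loop used in f2, which can do two comparisons (`>` then `==`) per iteration in exchange for stopping early on an exact match. Both function names are hypothetical, and both assume a 1-based, ascending vector. For distinct values they return the same index as `searchsortedfirst`.

```julia
# Lower-bound binary search: one comparison per iteration.
# Invariant: v[lo] < x <= v[hi], treating v[0] as -Inf and v[end+1] as +Inf.
function lower_bound(v::AbstractVector, x)
    lo, hi = 0, length(v) + 1
    while lo < hi - 1
        mid = (lo + hi) >>> 1
        if v[mid] < x
            lo = mid
        else
            hi = mid
        end
    end
    return hi   # first index with v[i] >= x, or length(v) + 1
end

# Three-way binary search as in f2: may stop early on an exact match,
# but pays for a second comparison on every non-matching step.
function three_way(v::AbstractVector, x)
    lo, hi = 1, length(v)
    while lo <= hi
        mid = (lo + hi) >>> 1
        if v[mid] > x
            hi = mid - 1
        elseif v[mid] == x
            return mid
        else
            lo = mid + 1
        end
    end
    return lo   # insertion point when x is absent
end

v = collect(1:2:99)
println((lower_bound(v, 42), three_way(v, 42), searchsortedfirst(v, 42)))
```

With duplicate values the three-way version may return any matching index, whereas the lower-bound version always returns the first one.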

Corrected code:

using Base.Order.Forward
using Base.Sort
function f1(rowvalA, needles)
    r2 = length(rowvalA)
    y = 0
    for i in needles
        y+=searchsortedfirst(rowvalA, i, 1, r2, Forward)
    end
    y
end

function f2(rowvalA, needles)

    y=0
    for i in needles
        last = length(rowvalA)
        ridx = 1
        @inbounds while ridx <= last
            mid = (ridx + last) >>> 1
            midval = int(rowvalA[mid])
            if midval > i
                last = mid - 1
            elseif midval == i
                ridx = mid
                break
            else
                ridx = mid + 1
            end
        end
        y += ridx
    end
    y
end

function f()
    rowvalA = [1:10000] # this needs to be sorted
    needle = randperm(10000)
    @time y1 = f1(rowvalA, needle)
    @time y2 = f2(rowvalA, needle)
    @assert y1 == y2 
end

for i in 1:5
    f()
end
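One change in the corrected f2 is the midpoint computation, `>>> 1` instead of `>> 1`. For index ranges that fit comfortably in an `Int` the two give the same result; they differ only when `ridx + last` overflows. A small illustration with deliberately huge, hypothetical bounds (the names `lo`, `hi`, `arith`, `logical` and `alt` are illustrative):

```julia
# Midpoint of two Int indices whose sum overflows.
lo = typemax(Int) - 1
hi = typemax(Int)

s = lo + hi                 # wraps around to a negative Int
arith = s >> 1              # arithmetic shift keeps the sign: still negative
logical = s >>> 1           # logical shift treats s as unsigned: correct midpoint
alt = lo + ((hi - lo) >> 1) # another overflow-free form

println((s, arith, logical, alt))
```

The true midpoint here is `lo`, which is what both `logical` and `alt` produce; `arith` is a negative index.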

@tanmaykm
Author

tanmaykm commented Jun 7, 2014

Thanks. I had corrected the sorting of rowvalA shortly after I initially posted the results. But the specific case I was comparing is one where the start pointer of the search range is advanced with every result (as in sparsematrix.jl). If I adapt your version to that case, I get results similar to what I got initially.
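In this pattern, each result becomes the lower bound of the next search. It only gives true positions when the needles are visited in increasing order, as the row indices of a sparse column are. The sketch below shows the pattern in Julia 1.x using the public two-argument `searchsortedfirst` on a `view`, rather than the five-argument method f1 calls. The name `advancing_binary` is hypothetical, and the sketch assumes a 1-based haystack.

```julia
# Binary search whose lower bound advances with every result.
# Assumptions: `hay` is 1-based and both `hay` and `needles` are sorted ascending.
function advancing_binary(hay::AbstractVector, needles::AbstractVector)
    out = Vector{Int}(undef, length(needles))
    lo = 1
    for (k, x) in enumerate(needles)
        # Search only hay[lo:end]; lo - 1 maps the view index back into hay.
        idx = lo - 1 + searchsortedfirst(view(hay, lo:length(hay)), x)
        out[k] = idx
        lo = idx    # the next search starts at the last result
    end
    return out
end

hay = [2, 3, 5, 8, 13, 21]
needles = [1, 3, 4, 13, 14, 30]
println(advancing_binary(hay, needles))
```

If a needle is smaller than `hay[lo]`, this returns `lo` instead of the needle's true position, so with unsorted needles the results are clamped rather than correct.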

But in the end there is not much difference for the specific case tested here, and even a linear search (f3, added later) is almost as fast.

Here's what it looked like:

julia> using Base.Order.Forward

julia> using Base.Sort

julia> function f1(rowvalA, needles)
           r2 = length(rowvalA)
           r1 = 1
           y = 0
           for i in needles
               z = searchsortedfirst(rowvalA, i, r1, r2, Forward)
               (z > 0) && (r1 = z)
               y += z
           end
           y
       end
f1 (generic function with 2 methods)

julia> function f2(rowvalA, needles)

           y=0
           ridx = 1
           for i in needles
               last = length(rowvalA)
               @inbounds while ridx <= last
                   mid = (ridx + last) >>> 1
                   midval = int(rowvalA[mid])
                   if midval > i
                       last = mid - 1
                   elseif midval == i
                       ridx = mid
                       break
                   else
                       ridx = mid + 1
                   end
               end
               y += ridx
           end
           y
       end
f2 (generic function with 2 methods)

julia> function f()
           rowvalA = [1:10000] # this needs to be sorted
           needle = randperm(10000)
           @time y1 = f1(rowvalA, needle)
           @time y2 = f2(rowvalA, needle)
           @assert y1 == y2 
       end
f (generic function with 1 method)

julia> for i in 1:5
           f()
       end
elapsed time: 0.000127998 seconds (0 bytes allocated)
elapsed time: 8.0174e-5 seconds (0 bytes allocated)
elapsed time: 0.000108596 seconds (0 bytes allocated)
elapsed time: 5.3706e-5 seconds (0 bytes allocated)
elapsed time: 0.000120776 seconds (0 bytes allocated)
elapsed time: 5.9641e-5 seconds (0 bytes allocated)
elapsed time: 0.000112553 seconds (0 bytes allocated)
elapsed time: 5.5417e-5 seconds (0 bytes allocated)
elapsed time: 0.000114833 seconds (0 bytes allocated)
elapsed time: 5.5032e-5 seconds (0 bytes allocated)
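One way to see why the linear search in f3 keeps up: with sorted needles, the advancing pointer makes at most about `length(hay) + length(needles)` comparisons in total, while independent binary searches cost roughly `log2(length(hay))` comparisons per needle. The following rough instrumented sketch counts comparisons only and ignores cache and branch-prediction effects, so it is not a timing model. The function names and the test input are illustrative.

```julia
# Count element comparisons made by the linear advancing scan (f3 pattern).
function linear_comparisons(hay, needles)
    ncmp = 0
    ridx, last = 1, length(hay)
    for x in needles
        while ridx <= last
            ncmp += 1
            hay[ridx] >= x && break
            ridx += 1
        end
    end
    return ncmp
end

# Count comparisons made by an independent lower-bound search per needle.
function binary_comparisons(hay, needles)
    ncmp = 0
    for x in needles
        lo, hi = 0, length(hay) + 1
        while lo < hi - 1
            mid = (lo + hi) >>> 1
            ncmp += 1
            if hay[mid] < x
                lo = mid
            else
                hi = mid
            end
        end
    end
    return ncmp
end

hay = collect(1:2:1999)   # 1000 sorted values
needles = 1:2000          # dense, sorted needles
println((linear_comparisons(hay, needles), binary_comparisons(hay, needles)))
```

When needles are sparse relative to the haystack (few needles, long vector), the balance flips: the linear scan still walks most of the vector, while binary search pays only a logarithmic cost per needle.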
