Skip to content

Instantly share code, notes, and snippets.

@atimin
Created July 8, 2012 03:38
Show Gist options
  • Save atimin/3069186 to your computer and use it in GitHub Desktop.
Save atimin/3069186 to your computer and use it in GitHub Desktop.
Slice an N-dimensional matrix
#!/usr/bin/env ruby
require 'pp'
require 'test/unit/assertions.rb'
class NMatrix
  attr_reader :shape, :elements, :count, :strides

  # Creates a dense N-dimensional matrix backed by a flat, row-major array.
  #
  # @param shape [Array<Integer>] extent of each dimension, outermost first
  # @param elements [Array, nil] flat row-major contents (assumed to hold
  #   exactly shape.reduce(1, :*) items); when nil, the matrix is filled with
  #   1, 2, 3, ... in row-major order
  def initialize(shape, elements = nil)
    @shape = shape
    # Number of elements auto-generated by init_el; stays 0 when the caller
    # supplies +elements+.
    @count = 0
    @elements = elements || init_el(shape)
    @strides = init_s(shape)
  end

  # Extracts a rectangular (box-shaped) sub-matrix.
  #
  # The i-th selector restricts dimension i:
  # * a Range selects its integer members (inclusive or exclusive end);
  # * an Integer +k+ selects +k..k+ (the dimension is kept, with extent 1);
  # * nil, or no selector at all for trailing dimensions, selects the whole
  #   dimension.
  #
  # @param slices [Array<Range, Integer, nil>] selectors for leading dimensions
  # @return [NMatrix] a new matrix whose shape is the extent of each selection
  # @raise [ArgumentError] if more selectors than dimensions are given
  # @raise [IndexError] if a selector is empty or reaches outside its dimension
  def [](*slices)
    if slices.size > shape.size
      raise ArgumentError,
            "#{slices.size} selectors given for a #{shape.size}-dimensional matrix"
    end

    bounds = shape.each_with_index.map { |extent, dim| slice_bounds(slices[dim], extent, dim) }
    offsets = bounds.map(&:first)
    lens = bounds.map(&:last)
    # Flat index of the box's first (lowest-index) element.
    start = offsets.zip(strides).inject(0) { |acc, (off, stride)| acc + off * stride }
    # The result's shape is the box's extents, not the source shape, so its
    # strides (and any further slicing) match the elements it actually holds.
    NMatrix.new(lens, slice(lens, strides, start))
  end

  private

  # Row-major strides: the stride of dimension i is the product of the extents
  # of all dimensions after it, so the last dimension always has stride 1.
  def init_s(shape)
    shape.each_index.map { |i| shape[(i + 1)..-1].reduce(1, :*) }
  end

  # Generates the default contents 1..N (N = product of the extents) in
  # row-major order and adds N to @count.
  def init_el(shape)
    total = shape.reduce(1, :*)
    @count += total
    (1..total).to_a
  end

  # Normalizes one selector into [offset, length] for a dimension of size
  # +extent+ (+dim+ is only used in the error message).
  def slice_bounds(selector, extent, dim)
    return [0, extent] if selector.nil?

    range = selector.is_a?(Integer) ? selector..selector : selector
    first = range.min
    last = range.max
    # Empty ranges yield nil bounds; negative offsets would silently wrap
    # around the flat array, so both are rejected explicitly.
    unless first && first >= 0 && last < extent
      raise IndexError,
            "selector #{selector.inspect} out of bounds for dimension #{dim} (extent #{extent})"
    end
    [first, last - first + 1]
  end

  # Gathers, in row-major order, the elements of the box whose first element
  # is at flat index +start+ and whose per-dimension extents are +lens+.
  def slice(lens, strides, start)
    # 0-dimensional matrix: exactly one element.
    return [elements[start]] if lens.empty?
    # The innermost dimension has stride 1, so its run is contiguous.
    return elements[start, lens.first] if lens.size == 1

    inner_lens = lens.drop(1)
    inner_strides = strides.drop(1)
    (0...lens.first).flat_map { |i| slice(inner_lens, inner_strides, start + strides.first * i) }
  end
end
# Demo: build a 3x3x3 matrix filled with 1..27 and check one 3-D slice.
# Test::Unit::Assertions is what 'test/unit/assertions' defines on every Ruby
# version; MiniTest::Assertions is not loaded by it on modern Ruby.
include Test::Unit::Assertions
n = NMatrix.new([3,3,3])
pp n
# assert_equal takes (expected, actual), so failure messages read correctly.
assert_equal([2, 3, 5, 6, 11, 12, 14, 15, 20, 21, 23, 24], n[0..2, 0..1, 1..2].elements)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment