Created
March 25, 2012 02:45
-
-
Save ompugao/2190890 to your computer and use it in GitHub Desktop.
NArray Extension
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'narray' | |
# Concatenation and trace helpers added to NArray.
class NArray
  class << self
    # Concatenates +narrays+ along dimension +dim+ into a newly allocated array.
    #
    # Borrows the other dimension lengths from the first array and relies on
    # the slice assignment to raise errors (or not) upon concatenation.
    #
    # NOTE: any input whose shape has no entry at +dim+ is reshaped in place
    # with +newdim!(dim)+, so callers' arrays may gain a dimension.
    #
    # @param dim [Integer] index of the dimension to concatenate along
    # @param narrays [Array<NArray>] arrays to join, in order
    # @return [NArray] new array typed with the largest typecode among the inputs
    # @raise [ArgumentError] if +dim+ is not an Integer or no arrays are given
    def cat(dim = 0, *narrays)
      raise ArgumentError, "'dim' must be an integer (did you forget your dim arg?)" unless dim.is_a?(Integer)
      raise ArgumentError, "must have narrays to cat" if narrays.size == 0

      new_typecode = narrays.map(&:typecode).max
      narrays.uniq.each { |narray| narray.newdim!(dim) if narray.shape[dim].nil? }

      shapes = narrays.map(&:shape)
      new_shape = shapes.first.dup
      new_shape[dim] = shapes.inject(0) { |sum, shape| sum + shape[dim] }
      result = NArray.new(new_typecode, *new_shape)

      # Copy each input into its own consecutive slot along +dim+; every other
      # dimension is selected whole (+true+).
      offset = 0
      narrays.zip(shapes) do |narray, shape|
        stop = offset + shape[dim]
        index = shape.map { true }
        index[dim] = offset...stop
        result[*index] = narray
        offset = stop
      end
      result
    end

    # Concatenates along dimension 1.
    def vcat(*narrays)
      cat(1, *narrays)
    end

    # Concatenates along dimension 0.
    def hcat(*narrays)
      cat(0, *narrays)
    end

    # Sums the diagonal elements +narray[i, i]+ for +i+ in +0...narray.shape[0]+.
    #
    # Takes the array as an explicit argument; the instance-level +trace+
    # delegates here with +self+.
    #
    # @param narray [NArray] array indexable with two indices
    # @return [Numeric] sum of the diagonal (0 when the first dimension is empty)
    def trace(narray)
      (0...narray.shape[0]).inject(0) { |sum, i| sum + narray[i, i] }
    end
  end

  # object method interface

  # Concatenates self followed by +narrays+ along +dim+ (see NArray.cat).
  def cat(dim = 0, *narrays)
    NArray.cat(dim, self, *narrays)
  end

  # Concatenates self followed by +narrays+ along dimension 1.
  def vcat(*narrays)
    NArray.vcat(self, *narrays)
  end

  # Concatenates self followed by +narrays+ along dimension 0.
  def hcat(*narrays)
    NArray.hcat(self, *narrays)
  end

  # Sum of this array's diagonal elements (see NArray.trace).
  def trace
    NArray.trace(self)
  end
end
# Determinant support for NMatrix.
class NMatrix
  class << self
    # Computes the determinant of +nmat+ by recursive cofactor expansion over
    # the elements +nmat[0, k]+, using +nmat.delete_at(0, k)+ as each minor.
    #
    # @param nmat [NMatrix] matrix whose first two shape entries are equal
    # @return [Numeric] the determinant
    # @raise [IndexError] if the first two shape entries differ
    def determinant(nmat)
      order, other = nmat.shape
      raise IndexError, "must be a square matrix" unless order == other
      return nmat[0, 0] if nmat.shape == [1, 1]

      # Signs alternate +, -, +, ... with the position k of the expanded element.
      order.times.inject(0) do |acc, k|
        sign = k.even? ? 1 : -1
        acc + nmat[0, k] * NMatrix.determinant(nmat.delete_at(0, k)) * sign
      end
    end
    alias det determinant
  end

  # Determinant of this matrix (see NMatrix.determinant).
  def determinant
    NMatrix.determinant(self)
  end
  alias det determinant
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment