Skip to content

Instantly share code, notes, and snippets.

@dieggsy
Last active October 10, 2018 02:15
Show Gist options
  • Save dieggsy/04f1bd7074c84674573a1366428d5695 to your computer and use it in GitHub Desktop.
Save dieggsy/04f1bd7074c84674573a1366428d5695 to your computer and use it in GitHub Desktop.
;; generic-vector: functor that turns a module M of low-level vector
;; primitives operating on raw pointers into a user-facing vector API.
;; Vectors are wrapped in a `vector` record holding M's pointer; each
;; exported operation unwraps its arguments, calls the matching M
;; primitive and, when a new vector is produced, wraps the result again.
;;
;; The interface for M lists every M binding the body references.  v*c
;; (used by vector-scale / vector-scale!) is part of it so that an
;; argument module exporting exactly this interface is sufficient.
(functor (generic-vector (M (valloc
                             vset!
                             vlength
                             vref
                             vfill!
                             vcopy!
                             vswap!
                             vreverse!
                             vnull?
                             vpositive?
                             vnegative?
                             vnonnegative?
                             vequal?
                             vreal
                             vimag
                             v+
                             v+c
                             v-
                             v*
                             v*c
                             v/
                             vmax
                             vmin
                             vargmax
                             vargmin
                             vbasis!
                             vsubvector)))
    (vector?
     list->vector
     vector->list
     vector
     make-vector
     vector-length
     vector-map
     vector-map!
     vector-ref
     vector-set!
     vector-fill!
     vector-copy
     vector-swap!
     vector-reverse
     vector-reverse!
     vector-real-part
     vector-imag-part
     vector+
     vector+!
     vector-
     vector-!
     vector*
     vector*!
     vector/
     vector/!
     vector-scale
     vector-scale!
     vector-add-constant
     vector-add-constant!
     vector-zero?
     vector-positive?
     vector-negative?
     vector-nonnegative?
     vector=
     vector-max
     vector-min
     vector-argmax
     vector-argmin
     vector-basis!
     subvector)
  (import (except scheme
                  vector
                  vector?
                  list->vector
                  vector->list
                  make-vector
                  vector-length
                  vector-ref
                  vector-fill!
                  vector-set!)
          (except (chicken base) subvector)
          M)

  ;; Record wrapping one of M's raw vector pointers:
  ;; ptr->vector wraps, vector? tests, vector->ptr unwraps.
  (define-record-type vector
    (ptr->vector ptr)
    vector?
    (ptr vector->ptr))

  ;; Build a vector from LST, element i of LST going to index i.
  ;; COMPLEX is accepted but not used (kept for interface compatibility).
  (define (list->vector lst #!optional complex)
    ;; Measure the list once and reuse the length for both the
    ;; allocation and the loop bound.
    (let* ((len (length lst))
           (v (valloc len)))
      (do ((i 0 (+ i 1))
           (lst lst (cdr lst)))
          ((= i len) (ptr->vector v))
        (vset! v i (car lst)))))

  ;; List of V's elements in index order.  Iterates from the last index
  ;; down to 0 so that consing produces the list already in order.
  (define (vector->list v)
    (let* ((ptr (vector->ptr v))
           (len (vlength ptr)))
      (do ((i (- len 1) (- i 1))
           (res '() (cons (vref ptr i) res)))
          ((= i -1) res))))

  ;; (vector x ...): a vector holding the arguments, in order.
  (define (vector . args)
    (list->vector args))

  ;; Allocate an N-element vector.  When FILL is supplied (and not #f)
  ;; every element is set to it; otherwise the contents are not
  ;; initialized here.
  (define (make-vector n #!optional fill)
    (let ((v (valloc n)))
      (when fill
        (vfill! v fill))
      (ptr->vector v)))

  ;; Number of elements in V.
  (define (vector-length v)
    (vlength (vector->ptr v)))

  ;; Apply F elementwise across one or more vectors, returning a new
  ;; vector whose length is that of the shortest argument.
  (define (vector-map f . v)
    (let* ((len (apply min (map vector-length v)))
           (d (map vector->ptr v))
           (r (valloc len)))
      (do ((i 0 (+ i 1)))
          ((= i len) (ptr->vector r))
        (vset! r i (apply f (map (cut vref <> i) d))))))

  ;; Like vector-map, but stores each result back into the first vector.
  (define (vector-map! f . v)
    (let* ((len (apply min (map vector-length v)))
           (d (map vector->ptr v)))
      (do ((i 0 (+ i 1)))
          ((= i len))
        (vset! (car d) i (apply f (map (cut vref <> i) d))))))

  ;; Element I of V.
  (define (vector-ref v i)
    (vref (vector->ptr v) i))

  ;; Sub-vector of V starting at index A and taking every STRIDE-th
  ;; element below B.  The element count handed to M's vsubvector is
  ;; ceiling((B - A) / STRIDE).
  (define (subvector v a b #!optional (stride 1))
    (ptr->vector (vsubvector (vector->ptr v)
                             a
                             stride
                             (inexact->exact
                              (ceiling (/ (- b a) stride))))))

  ;; Store N at index I of V.
  (define (vector-set! v i n)
    (vset! (vector->ptr v) i n))

  ;; Set every element of V to N.
  (define (vector-fill! v n)
    (vfill! (vector->ptr v) n))

  ;; Fresh vector with a copy of V's elements.
  (define (vector-copy v)
    (let ((c (valloc (vector-length v)))
          (d (vector->ptr v)))
      (vcopy! c d)
      (ptr->vector c)))

  ;; Exchange the elements at indices N1 and N2 of V.
  (define (vector-swap! v n1 n2)
    (vswap! (vector->ptr v) n1 n2))

  ;; New vector with V's elements in reverse order; V is not modified.
  (define (vector-reverse v)
    (let ((r (valloc (vector-length v))))
      (vcopy! r (vector->ptr v))
      (vreverse! r)
      (ptr->vector r)))

  ;; Reverse V in place.
  (define (vector-reverse! v)
    (vreverse! (vector->ptr v)))

  ;; New vector built by M's vreal from V.
  (define (vector-real-part v)
    (ptr->vector (vreal (vector->ptr v))))

  ;; New vector built by M's vimag from V.
  (define (vector-imag-part v)
    (ptr->vector (vimag (vector->ptr v))))

  ;; Arithmetic.  The non-! forms copy V1 (or V), apply M's in-place
  ;; primitive to the copy and return it, leaving the inputs untouched.
  ;; The ! forms apply the primitive directly to the first argument and
  ;; return (void).
  (define (vector+ v1 v2)
    (let ((c (vector-copy v1)))
      (v+ (vector->ptr c) (vector->ptr v2))
      c))

  (define (vector+! v1 v2)
    (v+ (vector->ptr v1) (vector->ptr v2))
    (void))

  (define (vector- v1 v2)
    (let ((c (vector-copy v1)))
      (v- (vector->ptr c) (vector->ptr v2))
      c))

  (define (vector-! v1 v2)
    (v- (vector->ptr v1) (vector->ptr v2))
    (void))

  (define (vector* v1 v2)
    (let ((c (vector-copy v1)))
      (v* (vector->ptr c) (vector->ptr v2))
      c))

  (define (vector*! v1 v2)
    (v* (vector->ptr v1) (vector->ptr v2))
    (void))

  (define (vector/ v1 v2)
    (let ((c (vector-copy v1)))
      (v/ (vector->ptr c) (vector->ptr v2))
      c))

  (define (vector/! v1 v2)
    (v/ (vector->ptr v1) (vector->ptr v2))
    (void))

  ;; Scale by the constant N (uses M's v*c).
  (define (vector-scale v n)
    (let ((c (vector-copy v)))
      (v*c (vector->ptr c) n)
      c))

  (define (vector-scale! v n)
    (v*c (vector->ptr v) n)
    (void))

  ;; Add the constant N to every element (uses M's v+c).
  (define (vector-add-constant v n)
    (let ((c (vector-copy v)))
      (v+c (vector->ptr c) n)
      c))

  (define (vector-add-constant! v n)
    (v+c (vector->ptr v) n)
    (void))

  ;; Extremes and their indices, as computed by M.
  (define (vector-max v)
    (vmax (vector->ptr v)))

  (define (vector-min v)
    (vmin (vector->ptr v)))

  (define (vector-argmax v)
    (vargmax (vector->ptr v)))

  (define (vector-argmin v)
    (vargmin (vector->ptr v)))

  ;; Predicates delegated to M.
  (define (vector-zero? v)
    (vnull? (vector->ptr v)))

  (define (vector-positive? v)
    (vpositive? (vector->ptr v)))

  (define (vector-negative? v)
    (vnegative? (vector->ptr v)))

  (define (vector-nonnegative? v)
    (vnonnegative? (vector->ptr v)))

  ;; True when M's vequal? reports V1 and V2 equal.
  (define (vector= v1 v2)
    (vequal? (vector->ptr v1) (vector->ptr v2)))

  ;; Set V to basis vector N via M's vbasis!; returns (void).
  (define (vector-basis! v n)
    (vbasis! (vector->ptr v) n)
    (void)))
;; _csl.zdvector: low-level CHICKEN bindings to GSL complex-double vectors
;; (gsl_vector_complex).  Provides the primitives (valloc, vref, v+, ...)
;; that the generic-vector functor expects from its argument module.
;; Every definition is exported (`*`).
(module _csl.zdvector *
  (import scheme (chicken base) (chicken gc) (chicken foreign) foreigners)
  (import (chicken format))

  ;; --- Error handling ---------------------------------------------------
  ;; Install csl_err as GSL's error handler so that GSL failures surface as
  ;; Scheme errors carrying GSL's file, line and reason.
  (foreign-declare "#include <gsl/gsl_errno.h>")
  (define gsl_set_error_handler
    (foreign-lambda void "gsl_set_error_handler" (c-pointer void)))
  ;; C-callable handler.  gsl_errno is received but not included in the
  ;; message.
  (define-external (csl_err (c-string reason) (c-string file) (int line) (int gsl_errno)) void
    (error (format "gsl: ~a:~a: ERROR: ~a" file line reason)))
  (gsl_set_error_handler (location csl_err))

  (import-for-syntax (srfi 1)
                     (srfi 13)
                     (chicken format)
                     (chicken foreign)
                     (chicken string))
  (import (only (rename scheme (make-rectangular %make-rectangular)) %make-rectangular))

  ;; --- complex-foreign-lambda --------------------------------------------
  ;; (complex-foreign-lambda RET-TYPE "c_function" ARG-TYPE ...)
  ;;
  ;; Like foreign-safe-lambda, except RET-TYPE and any ARG-TYPE may be
  ;; (complex T).  Arguments are named x0, x1, ... (zN for complex ones).
  ;; - A complex argument zN is passed to C as two T values rzN / izN
  ;;   (obtained with real-part / imag-part) and rebuilt in C as
  ;;   `gsl_complex... zN = gsl_complex_rect(rzN,izN);`.
  ;; - A complex return value is unpacked with GSL_REAL / GSL_IMAG and
  ;;   converted by scheme_make_rect; the Scheme return type becomes
  ;;   scheme-object.
  ;; - With no complex arguments the expansion is a bare
  ;;   foreign-safe-lambda*; otherwise it is a lambda that splits the
  ;;   complex arguments before calling one.
  ;; NOTE(review): the init line assigns gsl_complex_rect(...) to
  ;; gsl_complex_T; for T other than double this likely needs a different
  ;; initializer -- confirm before using (complex float) etc.
  (define-syntax complex-foreign-lambda
    (ir-macro-transformer
     (lambda (e i c)
       ;; Concatenate ARGS (via conc) into a symbol.
       (define (conc* . args)
         (string->symbol
          (apply conc args)))
       (let* ((ret-type (strip-syntax (cadr e)))
              (fn (caddr e))
              (args (cdddr e))
              ;; ((name . type) ...) with name zN for complex args, xN else.
              (named-args
               (let loop ((a args)
                          (i 0))
                 (if (null? a)
                     '()
                     (let ((arg (car a)))
                       (cons
                        (cons (if (and (pair? arg) (eq? (strip-syntax (car arg)) 'complex))
                                  (conc* 'z i)
                                  (conc* 'x i))
                              (if (pair? arg)
                                  (map strip-syntax arg)
                                  (strip-syntax arg)))
                        (loop (cdr a) (+ i 1))))))))
         ;; let bindings splitting each complex arg into rzN / izN.
         (define (make-letvars named-args)
           (let ((only-complex (filter (lambda (x)
                                         (and (pair? (cdr x))
                                              (eq? 'complex (cadr x))))
                                       named-args)))
             (append-map (lambda (x)
                           `((,(conc* 'r (car x)) (real-part ,(car x)))
                             (,(conc* 'i (car x)) (imag-part ,(car x)))))
                         only-complex)))
         ;; foreign-safe-lambda* parameter list: (T rzN) (T izN) for
         ;; complex args, (TYPE xN) otherwise.
         (define (make-foreign-args named-args)
           (append-map (lambda (x)
                         (if (and (pair? (cdr x)) (eq? (cadr x) 'complex))
                             `((,(caddr x) ,(conc* 'r (car x)))
                               (,(caddr x) ,(conc* 'i (car x))))
                             `((,(cdr x) ,(car x)))))
                       named-args))
         ;; C declarations rebuilding each complex arg from its parts.
         (define (make-inits named-args)
           (let* ((only-complex (filter (lambda (x)
                                          (and (pair? (cdr x))
                                               (eq? 'complex (cadr x))))
                                        named-args)))
             (fold (lambda (x y)
                     (cons
                      (let ((name (car x))
                            (type (caddr x)))
                        (format "gsl_complex~a ~a = gsl_complex_rect(~a,~a);"
                                (if (eq? type 'double)
                                    ""
                                    (conc "_" type))
                                (symbol->string name)
                                (conc "r" name)
                                (conc "i" name)))
                      y))
                   '()
                   only-complex)))
         ;; C statements calling FN and returning its result.
         (define (make-return ret-type fn named-args)
           (let ((strargs (string-join (map (lambda (x)
                                              (symbol->string (car x)))
                                            named-args) ",")))
             (cond ((and (pair? ret-type)
                         (eq? 'complex (car ret-type)))
                    `(,(format "gsl_complex~a out = ~a(~a);"
                               (if (eq? (cadr ret-type) 'double)
                                   ""
                                   (conc "_" (cadr ret-type)))
                               fn strargs)
                      ,(format "C_return(scheme_make_rect(GSL_REAL(out),GSL_IMAG(out)));")))
                   ((eq? ret-type 'void)
                    `(,(format "~a(~a);" fn strargs)))
                   (else `(,(format "C_return(~a(~a));" fn strargs))))))
         (let* ((foreign-args (make-foreign-args named-args))
                (letvars (make-letvars named-args))
                (inits (make-inits named-args))
                (return (make-return ret-type fn named-args))
                (return-type (if (and (pair? ret-type)
                                      (eq? (car ret-type) 'complex))
                                 'scheme-object
                                 (strip-syntax (cadr e)))))
           (if (null? letvars)
               `(foreign-safe-lambda* ,return-type ,foreign-args
                                      ,@inits
                                      ,@return)
               `(lambda ,(map car named-args)
                  (let ,letvars
                    ((foreign-safe-lambda* ,return-type ,foreign-args
                                           ,@inits
                                           ,@return)
                     ,@(map cadr foreign-args))))))))))

  ;; C-callable helper used by complex-foreign-lambda expansions to build
  ;; a Scheme complex number from its real and imaginary parts.
  (define-external
    (scheme_make_rect (double r) (double i))
    scheme-object
    (%make-rectangular r i))

  ;; --- GSL struct accessors ----------------------------------------------
  (foreign-declare "#include <gsl/gsl_block_complex_double.h>")
  (define-foreign-record-type
   (gsl_block "gsl_block_complex")
   (unsigned-int size gsl_block.size)
   ((c-pointer double) data gsl_block.data))
  (foreign-declare "#include <gsl/gsl_vector_complex_double.h>")
  (foreign-declare "#include <gsl/gsl_complex.h>")
  (foreign-declare "#include <gsl/gsl_complex_math.h>")
  ;; vlength (the `size` field) is the element count used by the functor.
  (define-foreign-record-type
   (gsl_vector "gsl_vector_complex")
   (unsigned-int size vlength)
   (unsigned-int stride gsl_vector.stride)
   ((c-pointer double) data gsl_vector.data)
   (gsl_block block gsl_vector.block)
   (int owner gsl_vector.owner))

  ;; --- Allocation --------------------------------------------------------
  (define gsl_vector_free
    (foreign-safe-lambda void "gsl_vector_complex_free" gsl_vector))
  (define gsl_vector_alloc
    (foreign-safe-lambda gsl_vector "gsl_vector_complex_alloc" unsigned-int))
  ;; Allocate an N-element complex vector, released by
  ;; gsl_vector_complex_free once the pointer object is garbage-collected.
  (define (valloc n) (set-finalizer! (gsl_vector_alloc n) gsl_vector_free))

  ;; --- Element access ----------------------------------------------------
  (define vref
    (complex-foreign-lambda
     (complex double)
     "gsl_vector_complex_get"
     (const gsl_vector)
     (const unsigned-int)))
  (define vset!
    (complex-foreign-lambda
     void
     "gsl_vector_complex_set"
     gsl_vector
     (const unsigned-int)
     (complex double)))
  (define vfill!
    (complex-foreign-lambda
     void
     "gsl_vector_complex_set_all"
     gsl_vector
     (complex double)))
  (define vbasis!
    (foreign-safe-lambda
     void
     "gsl_vector_complex_set_basis"
     gsl_vector
     unsigned-int))

  ;; Sub-vector of V starting at OFFSET with N elements STRIDE apart.
  ;; The result is a heap copy of a GSL vector view: it aliases V's data
  ;; and does not own V's block.  Only the header struct is allocated, so
  ;; gsl_vector_complex_free on it releases just that header.  Its
  ;; finalizer closes over V so that V (and the block it owns) cannot be
  ;; finalized while the sub-vector is still reachable.
  (define vsubvector
    (let ((%vsubvector
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v)
             (unsigned-int offset)
             (unsigned-int stride)
             (unsigned-int n))
            "gsl_vector_complex_view p1 = gsl_vector_complex_subvector_with_stride(v,offset,stride,n);"
            "gsl_vector_complex *p0 = (gsl_vector_complex *) malloc(sizeof(gsl_vector_complex));"
            "if (p0 == NULL) { gsl_error(\"failed to allocate space for subvector struct\", __FILE__, __LINE__, GSL_ENOMEM); C_return(NULL); }"
            "memcpy(p0, &p1.vector, sizeof(gsl_vector_complex));"
            "C_return(p0);")))
      (lambda (v offset stride n)
        (set-finalizer! (%vsubvector v offset stride n)
                        (lambda (sub)
                          (gsl_vector_free sub)
                          ;; Referencing V keeps the parent alive until
                          ;; this sub-vector has been finalized.
                          v)))))

  ;; --- Whole-vector operations (in place on the first argument) -----------
  (define vcopy!
    (foreign-safe-lambda
     int
     "gsl_vector_complex_memcpy"
     gsl_vector
     gsl_vector))
  (define vswap!
    (foreign-safe-lambda
     int
     "gsl_vector_complex_swap_elements"
     gsl_vector
     unsigned-int
     unsigned-int))
  (define vreverse!
    (foreign-safe-lambda int "gsl_vector_complex_reverse" gsl_vector))
  (define v+
    (foreign-safe-lambda int "gsl_vector_complex_add" gsl_vector gsl_vector))
  (define v-
    (foreign-safe-lambda int "gsl_vector_complex_sub" gsl_vector gsl_vector))
  (define v*
    (foreign-safe-lambda int "gsl_vector_complex_mul" gsl_vector gsl_vector))
  (define v/
    (foreign-safe-lambda int "gsl_vector_complex_div" gsl_vector gsl_vector))
  (define v*c
    (complex-foreign-lambda
     int
     "gsl_vector_complex_scale"
     gsl_vector
     (complex double)))
  (define v+c
    (complex-foreign-lambda
     int
     "gsl_vector_complex_add_constant"
     gsl_vector
     (complex double)))

  ;; Ordering-based operations are not defined for complex vectors.
  (define (vmax v) (error "Complex numbers have no ordering."))
  (define (vmin v) (error "Complex numbers have no ordering."))
  (define (vargmax v) (error "Complex numbers have no ordering."))
  (define (vargmin v) (error "Complex numbers have no ordering."))

  ;; --- Predicates ----------------------------------------------------------
  (define vnull?
    (foreign-safe-lambda bool "gsl_vector_complex_isnull" gsl_vector))
  (define vpositive?
    (foreign-safe-lambda bool "gsl_vector_complex_ispos" gsl_vector))
  (define vnegative?
    (foreign-safe-lambda bool "gsl_vector_complex_isneg" gsl_vector))
  (define vnonnegative?
    (foreign-safe-lambda bool "gsl_vector_complex_isnonneg" gsl_vector))
  (define vequal?
    (foreign-safe-lambda
     bool
     "gsl_vector_complex_equal"
     gsl_vector
     gsl_vector))

  ;; New complex vector holding 0 + i*Im(v_k) for each element of V.
  ;; The result is freshly allocated, so it gets the same finalizer as
  ;; valloc (previously it was never freed).
  (define vimag
    (let ((%vimag
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v))
            "gsl_vector_complex *p0 = gsl_vector_complex_alloc(v->size);"
            "gsl_vector_view p1 = gsl_vector_complex_imag(v);"
            "for (size_t i = 0; i < v->size; i++) { "
            "double iz = gsl_vector_get(&p1.vector, i);"
            "gsl_vector_complex_set(p0,i,gsl_complex_rect(0, iz));"
            "}"
            "C_return(p0);")))
      (lambda (v)
        (set-finalizer! (%vimag v) gsl_vector_free))))

  ;; New complex vector holding Re(v_k) + 0i for each element of V.
  ;; Freshly allocated, hence finalized like valloc's vectors.
  (define vreal
    (let ((%vreal
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v))
            "gsl_vector_complex *p0 = gsl_vector_complex_alloc(v->size);"
            "gsl_vector_view p1 = gsl_vector_complex_real(v);"
            "for (size_t i = 0; i < v->size; ++i) { "
            "double rz = gsl_vector_get(&p1.vector, i);"
            "gsl_vector_complex_set(p0,i,gsl_complex_rect(rz, 0));"
            "}"
            "C_return(p0);")))
      (lambda (v)
        (set-finalizer! (%vreal v) gsl_vector_free)))))
;; Public module: the generic-vector functor instantiated over the
;; GSL complex-double primitives defined in _csl.zdvector.
(module csl.zdvector = (generic-vector _csl.zdvector))
;; Local Variables:
;; compile-command: "csc -s csl.zdvector.scm -j csl.zdvector -L -lgsl -L -lgslcblas"
;; End:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment