Skip to content

Instantly share code, notes, and snippets.

@dieggsy
Created October 9, 2018 23:08
Show Gist options
Save dieggsy/e06d5dcf1c3844a9abe8255f88ade556 to your computer and use it in GitHub Desktop.
vectors issue
(import-for-syntax (srfi 1)
(srfi 13)
(chicken format)
(chicken foreign)
(chicken string))
(import (only (rename scheme (make-rectangular %make-rectangular)) %make-rectangular))
;; complex-foreign-lambda: like foreign-safe-lambda, but argument types and the
;; return type may be written as (complex T), e.g. (complex double).
;;
;; Usage: (complex-foreign-lambda RET-TYPE "c_function" ARG-TYPE ...)
;;
;; How the expansion works (all visible in the code below):
;;  - Every argument gets a generated name: zN for a (complex ...) argument,
;;    xN for anything else (N is the argument position).
;;  - A complex argument zN is split on the Scheme side into (real-part zN) and
;;    (imag-part zN), bound as rzN / izN. Both are passed to C as two values of
;;    type T. They are rebuilt in the C body as a local
;;    `gsl_complex[_T] zN = gsl_complex_rect(rzN,izN);`
;;    (no suffix when T is double).
;;  - A (complex T) return value is captured as `gsl_complex[_T] out` and handed
;;    back through the C-callable `scheme_make_rect`. The foreign return type is
;;    therefore scheme-object.
;;  - With no complex arguments, the expansion is a bare foreign-safe-lambda*.
;;    Otherwise it is a lambda that splits the complex arguments and immediately
;;    calls a foreign-safe-lambda*.
;; NOTE(review): the generated C refers to gsl_complex_rect, GSL_REAL/GSL_IMAG
;; and scheme_make_rect, so every user of this macro must have those in scope.
;; The _csl.zdvector module below defines scheme_make_rect; confirm for any
;; other user.
(define-syntax complex-foreign-lambda
  (ir-macro-transformer
   (lambda (e i c)
     ;; Build a symbol by concatenating the printed forms of ARGS,
     ;; e.g. (conc* 'z 0) => z0.
     (define (conc* . args)
       (string->symbol
        (apply conc args)))
     (let* ((ret-type (strip-syntax (cadr e)))
            (fn (caddr e))
            (args (cdddr e))
            ;; Alist of (generated-name . stripped-type-spec), in argument order.
            ;; A (complex double) argument becomes (zN complex double); a plain
            ;; symbol type such as gsl_vector becomes (xN . gsl_vector).
            (named-args
             (let loop ((a args)
                        (i 0))
               (if (null? a)
                   '()
                   (let ((arg (car a)))
                     (cons
                      (cons (if (and (pair? arg) (eq? (strip-syntax (car arg)) 'complex))
                                (conc* 'z i)
                                (conc* 'x i))
                            (if (pair? arg)
                                (map strip-syntax arg)
                                (strip-syntax arg)))
                      (loop (cdr a) (+ i 1))))))))
       ;; Scheme-side let bindings that split each complex argument zN into
       ;; rzN = (real-part zN) and izN = (imag-part zN).
       (define (make-letvars named-args)
         (let ((only-complex (filter (lambda (x)
                                       (and (pair? (cdr x))
                                            (eq? 'complex (cadr x))))
                                     named-args)))
           (append-map (lambda (x)
                         `((,(conc* 'r (car x)) (real-part ,(car x)))
                           (,(conc* 'i (car x)) (imag-part ,(car x)))))
                       only-complex)))
       ;; The (type name) list for foreign-safe-lambda*. A complex argument
       ;; contributes two entries of its component type (T rzN) (T izN); any
       ;; other argument contributes (type xN) unchanged.
       (define (make-foreign-args named-args)
         (append-map (lambda (x)
                       (if (and (pair? (cdr x)) (eq? (cadr x) 'complex))
                           `((,(caddr x) ,(conc* 'r (car x)))
                             (,(caddr x) ,(conc* 'i (car x))))
                           `((,(cdr x) ,(car x)))))
                     named-args))
       ;; C statements that rebuild each complex argument from its two parts.
       ;; fold+cons emits them in reverse argument order. That is harmless,
       ;; because each is an independent declaration.
       (define (make-inits named-args)
         (let* ((only-complex (filter (lambda (x)
                                        (and (pair? (cdr x))
                                             (eq? 'complex (cadr x))))
                                      named-args)))
           (fold (lambda (x y)
                   (cons
                    (let ((name (car x))
                          (type (caddr x)))
                      (format "gsl_complex~a ~a = gsl_complex_rect(~a,~a);"
                              (if (eq? type 'double)
                                  ""
                                  (conc "_" type))
                              (symbol->string name)
                              (conc "r" name)
                              (conc "i" name)))
                    y))
                 '()
                 only-complex)))
       ;; C statements that call FN with the generated argument names:
       ;;  - (complex T) return: store the result in `out` and return
       ;;    scheme_make_rect(GSL_REAL(out),GSL_IMAG(out));
       ;;  - void return: plain call;
       ;;  - anything else: C_return the call's value.
       (define (make-return ret-type fn named-args)
         (let ((strargs (string-join (map (lambda (x)
                                            (symbol->string (car x)))
                                          named-args) ",")))
           (cond ((and (pair? ret-type)
                       (eq? 'complex (car ret-type)))
                  `(,(format "gsl_complex~a out = ~a(~a);"
                             (if (eq? (cadr ret-type) 'double)
                                 ""
                                 (conc "_" (cadr ret-type)))
                             fn strargs)
                    ,(format "C_return(scheme_make_rect(GSL_REAL(out),GSL_IMAG(out)));")))
                 ((eq? ret-type 'void)
                  `(,(format "~a(~a);" fn strargs)))
                 (else `(,(format "C_return(~a(~a));" fn strargs))))))
       (let* ((foreign-args (make-foreign-args named-args))
              (letvars (make-letvars named-args))
              (inits (make-inits named-args))
              (return (make-return ret-type fn named-args))
              ;; A complex result is built by scheme_make_rect, so the
              ;; foreign-level return type is scheme-object.
              (return-type (if (and (pair? ret-type)
                                    (eq? (car ret-type) 'complex))
                               'scheme-object
                               (strip-syntax (cadr e)))))
         (if (null? letvars)
             ;; No complex arguments: the foreign lambda's parameters already
             ;; match the caller's arguments one-to-one.
             `(foreign-safe-lambda* ,return-type ,foreign-args
                                    ,@inits
                                    ,@return)
             ;; Complex arguments: accept the original arguments, split each
             ;; complex one into real/imag parts, then call the foreign lambda
             ;; with the expanded argument list.
             `(lambda ,(map car named-args)
                (let ,letvars
                  ((foreign-safe-lambda* ,return-type ,foreign-args
                                         ,@inits
                                         ,@return)
                   ,@(map cadr foreign-args))))))))))
(import (chicken format))
;; GSL error-handling declarations for the generated C code.
(foreign-declare "#include <gsl/gsl_errno.h>")

;; Binding to GSL's gsl_set_error_handler. The handler is passed as a raw C
;; function pointer.
(define gsl_set_error_handler
  (foreign-lambda void "gsl_set_error_handler" (c-pointer void)))

;; C-callable handler that turns a GSL error report (reason, source file,
;; line, GSL error code) into a Scheme `error`. The numeric error code is
;; accepted but not included in the message.
(define-external (csl_err (c-string reason) (c-string file) (int line) (int gsl_errno)) void
  (let ((message (sprintf "gsl: ~a:~a: ERROR: ~a" file line reason)))
    (error message)))

;; Install the handler as soon as this file is loaded.
(gsl_set_error_handler (location csl_err))
(##core#begin
 ;; Backend module for real (double) vectors. Thin bindings over GSL's
 ;; gsl_vector API, exporting every operation the generic-vector functor
 ;; expects from its argument module.
 (module
  _csl.dvector
  *
  (import scheme (chicken base) (chicken gc) (chicken foreign) foreigners)
  (include "../csl-error.scm")

  (foreign-declare "#include <gsl/gsl_block.h>")
  (define-foreign-record-type
   (gsl_block "gsl_block")
   (unsigned-int size gsl_block.size)
   ((c-pointer double) data gsl_block.data))

  (foreign-declare "#include <gsl/gsl_vector_double.h>")
  ;; Field accessors for the C gsl_vector struct; `vlength` reads `size`.
  (define-foreign-record-type
   (gsl_vector "gsl_vector")
   (unsigned-int size vlength)
   (unsigned-int stride gsl_vector.stride)
   ((c-pointer double) data gsl_vector.data)
   (gsl_block block gsl_vector.block)
   (int owner gsl_vector.owner))

  (define gsl_vector_free
    (foreign-safe-lambda void "gsl_vector_free" gsl_vector))
  (define gsl_vector_alloc
    (foreign-safe-lambda gsl_vector "gsl_vector_alloc" unsigned-int))

  ;; Allocate an N-element vector. Its C memory is released by
  ;; gsl_vector_free once the Scheme pointer object is garbage collected.
  (define (valloc n) (set-finalizer! (gsl_vector_alloc n) gsl_vector_free))

  ;; Element access.
  (define vref
    (foreign-safe-lambda
     double
     "gsl_vector_get"
     (const gsl_vector)
     (const unsigned-int)))
  (define vset!
    (foreign-safe-lambda
     void
     "gsl_vector_set"
     gsl_vector
     (const unsigned-int)
     (const double)))
  (define vfill!
    (foreign-safe-lambda void "gsl_vector_set_all" gsl_vector double))
  (define vbasis!
    (foreign-safe-lambda
     void
     "gsl_vector_set_basis"
     gsl_vector
     unsigned-int))

  ;; Strided view into V: N elements starting at OFFSET, STRIDE apart.
  ;; The gsl_vector_view lives on the C stack, so its header is copied into a
  ;; malloc'ed struct that shares V's data and carries the view's `owner`
  ;; flag. The header used to be leaked. It now gets a finalizer:
  ;; gsl_vector_free frees the underlying block only for owning vectors, so
  ;; for a view it releases just this header.
  ;; NOTE(review): the view does not keep V alive by itself; the caller must
  ;; keep V reachable while the view is in use.
  (define vsubvector
    (let ((%vsubvector
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v)
             (unsigned-int offset)
             (unsigned-int stride)
             (unsigned-int n))
            "gsl_vector *p0 = malloc(sizeof(gsl_vector));"
            "gsl_vector_view p1 = gsl_vector_subvector_with_stride(v,offset,stride,n);"
            "memcpy(p0, &p1.vector, sizeof(gsl_vector));"
            "C_return(p0);")))
      (lambda (v offset stride n)
        (set-finalizer! (%vsubvector v offset stride n) gsl_vector_free))))

  ;; Copy / permutation. vcopy! takes (destination source).
  (define vcopy!
    (foreign-safe-lambda int "gsl_vector_memcpy" gsl_vector gsl_vector))
  (define vswap!
    (foreign-safe-lambda
     int
     "gsl_vector_swap_elements"
     gsl_vector
     unsigned-int
     unsigned-int))
  (define vreverse!
    (foreign-safe-lambda int "gsl_vector_reverse" gsl_vector))

  ;; In-place arithmetic: the first vector is overwritten with the result.
  (define v+
    (foreign-safe-lambda int "gsl_vector_add" gsl_vector gsl_vector))
  (define v-
    (foreign-safe-lambda int "gsl_vector_sub" gsl_vector gsl_vector))
  (define v*
    (foreign-safe-lambda int "gsl_vector_mul" gsl_vector gsl_vector))
  (define v/
    (foreign-safe-lambda int "gsl_vector_div" gsl_vector gsl_vector))
  (define v*c (foreign-safe-lambda int "gsl_vector_scale" gsl_vector double))
  (define v+c
    (foreign-safe-lambda int "gsl_vector_add_constant" gsl_vector double))

  ;; Extremes.
  (define vmax (foreign-safe-lambda double "gsl_vector_max" gsl_vector))
  (define vmin (foreign-safe-lambda double "gsl_vector_min" gsl_vector))
  (define vargmax
    (foreign-safe-lambda int "gsl_vector_max_index" gsl_vector))
  (define vargmin
    (foreign-safe-lambda int "gsl_vector_min_index" gsl_vector))

  ;; Predicates.
  (define vnull? (foreign-safe-lambda bool "gsl_vector_isnull" gsl_vector))
  (define vpositive?
    (foreign-safe-lambda bool "gsl_vector_ispos" gsl_vector))
  (define vnegative?
    (foreign-safe-lambda bool "gsl_vector_isneg" gsl_vector))
  (define vnonnegative?
    (foreign-safe-lambda bool "gsl_vector_isnonneg" gsl_vector))
  (define vequal?
    (foreign-safe-lambda bool "gsl_vector_equal" gsl_vector gsl_vector))

  ;; Imaginary part of a real vector: a freshly allocated all-zero vector of
  ;; the same length.
  (define (vimag v)
    (let ((zeros (valloc (vlength v))))
      (vfill! zeros 0)
      zeros))

  ;; Real part of a real vector: a fresh copy of V. It previously returned V
  ;; itself, so mutating the result of vector-real-part mutated the input.
  ;; That was unlike vimag and unlike the complex backend's vreal, which both
  ;; return new storage.
  (define (vreal v)
    (let ((copy (valloc (vlength v))))
      (vcopy! copy v)
      copy)))
 (module csl.dvector = (generic-vector _csl.dvector)))
;; generic-vector: builds the user-facing vector API on top of a backend
;; module M. M supplies pointer-level operations: allocation, element access,
;; in-place arithmetic, predicates and views. Every exported vector is a
;; record wrapping one backend pointer.
;; v*c is used by vector-scale / vector-scale! below, so it must be part of
;; M's interface. Without it, v*c is unbound inside the functor body.
(functor (generic-vector (M (valloc
                             vset!
                             vlength
                             vref
                             vfill!
                             vcopy!
                             vswap!
                             vreverse!
                             vnull?
                             vpositive?
                             vnegative?
                             vnonnegative?
                             vequal?
                             vreal
                             vimag
                             v+
                             v+c
                             v*c
                             v-
                             v*
                             v/
                             vmax
                             vmin
                             vargmax
                             vargmin
                             vbasis!
                             vsubvector)))
    (vector?
     list->vector
     vector->list
     vector
     make-vector
     vector-length
     vector-map
     vector-map!
     vector-ref
     vector-set!
     vector-fill!
     vector-copy
     vector-swap!
     vector-reverse
     vector-reverse!
     vector-real-part
     vector-imag-part
     vector+
     vector+!
     vector-
     vector-!
     vector*
     vector*!
     vector/
     vector/!
     vector-scale
     vector-scale!
     vector-add-constant
     vector-add-constant!
     vector-zero?
     vector-positive?
     vector-negative?
     vector-nonnegative?
     vector=
     vector-max
     vector-min
     vector-argmax
     vector-argmin
     vector-basis!
     subvector)
  (import (except scheme
                  vector
                  vector?
                  list->vector
                  vector->list
                  make-vector
                  vector-length
                  vector-ref
                  vector-fill!
                  vector-set!)
          (except (chicken base) subvector)
          M)

  ;; Wrapper record around a backend pointer.
  ;; PARENT is #f for vectors that own their storage. For a subvector it holds
  ;; the vector whose storage the view points into. That keeps the parent (and
  ;; the storage a finalizer may release once it becomes unreachable) alive
  ;; for as long as the view is reachable.
  (define-record-type vector
    (make-vector-record ptr parent)
    vector?
    (ptr vector->ptr)
    (parent vector-parent))

  ;; Wrap a pointer that does not borrow another vector's storage.
  (define (ptr->vector ptr)
    (make-vector-record ptr #f))

  ;; Build a vector from the elements of LST.
  ;; COMPLEX is accepted for interface compatibility but is not used.
  (define (list->vector lst #!optional complex)
    (let* ((len (length lst))
           (v (valloc len)))
      (do ((i 0 (+ i 1))
           (lst lst (cdr lst)))
          ((= i len) (ptr->vector v))
        (vset! v i (car lst)))))

  ;; Elements of V as a list, built back to front so no reversal is needed.
  (define (vector->list v)
    (let* ((ptr (vector->ptr v))
           (len (vlength ptr)))
      (do ((i (- len 1) (- i 1))
           (res '() (cons (vref ptr i) res)))
          ((= i -1) res))))

  (define (vector . args)
    (list->vector args))

  ;; Allocate an N-element vector, filled with FILL when FILL is given
  ;; (and not #f). Otherwise the contents are whatever valloc provides.
  (define (make-vector n #!optional fill)
    (let ((v (valloc n)))
      (when fill
        (vfill! v fill))
      (ptr->vector v)))

  (define (vector-length v)
    (vlength (vector->ptr v)))

  ;; New vector whose i-th element is (f v1[i] v2[i] ...), truncated to the
  ;; shortest input.
  (define (vector-map f . v)
    (let* ((len (apply min (map vector-length v)))
           (d (map vector->ptr v))
           (r (valloc len)))
      (do ((i 0 (+ i 1)))
          ((= i len) (ptr->vector r))
        (vset! r i (apply f (map (cut vref <> i) d))))))

  ;; Like vector-map, but stores the results into the first vector.
  (define (vector-map! f . v)
    (let* ((len (apply min (map vector-length v)))
           (d (map vector->ptr v)))
      (do ((i 0 (+ i 1)))
          ((= i len))
        (vset! (car d) i (apply f (map (cut vref <> i) d))))))

  (define (vector-ref v i)
    (vref (vector->ptr v) i))

  ;; View of V covering indices [A, B) with the given STRIDE. The view shares
  ;; V's storage, so V is recorded as its parent to keep it from being
  ;; collected (and its data freed) while the view is still in use.
  (define (subvector v a b #!optional (stride 1))
    (make-vector-record
     (vsubvector (vector->ptr v)
                 a
                 stride
                 (inexact->exact
                  (ceiling (/ (- b a) stride))))
     v))

  (define (vector-set! v i n)
    (vset! (vector->ptr v) i n))

  (define (vector-fill! v n)
    (vfill! (vector->ptr v) n))

  ;; Fresh vector with the same contents as V.
  (define (vector-copy v)
    (let ((c (valloc (vector-length v)))
          (d (vector->ptr v)))
      (vcopy! c d)
      (ptr->vector c)))

  (define (vector-swap! v n1 n2)
    (vswap! (vector->ptr v) n1 n2))

  ;; Reversed copy of V; V itself is left unchanged.
  (define (vector-reverse v)
    (let ((r (valloc (vector-length v))))
      (vcopy! r (vector->ptr v))
      (vreverse! r)
      (ptr->vector r)))

  (define (vector-reverse! v)
    (vreverse! (vector->ptr v)))

  (define (vector-real-part v)
    (ptr->vector (vreal (vector->ptr v))))

  (define (vector-imag-part v)
    (ptr->vector (vimag (vector->ptr v))))

  ;; Binary arithmetic. The non-! forms apply the backend's in-place
  ;; operation to a copy of V1. The ! forms overwrite V1 and return (void).
  (define (vector+ v1 v2)
    (let ((c (vector-copy v1)))
      (v+ (vector->ptr c) (vector->ptr v2))
      c))
  (define (vector+! v1 v2)
    (v+ (vector->ptr v1) (vector->ptr v2))
    (void))
  (define (vector- v1 v2)
    (let ((c (vector-copy v1)))
      (v- (vector->ptr c) (vector->ptr v2))
      c))
  (define (vector-! v1 v2)
    (v- (vector->ptr v1) (vector->ptr v2))
    (void))
  (define (vector* v1 v2)
    (let ((c (vector-copy v1)))
      (v* (vector->ptr c) (vector->ptr v2))
      c))
  (define (vector*! v1 v2)
    (v* (vector->ptr v1) (vector->ptr v2))
    (void))
  (define (vector/ v1 v2)
    (let ((c (vector-copy v1)))
      (v/ (vector->ptr c) (vector->ptr v2))
      c))
  (define (vector/! v1 v2)
    (v/ (vector->ptr v1) (vector->ptr v2))
    (void))

  ;; Scalar arithmetic, following the same copy / in-place convention.
  (define (vector-scale v n)
    (let ((c (vector-copy v)))
      (v*c (vector->ptr c) n)
      c))
  (define (vector-scale! v n)
    (v*c (vector->ptr v) n)
    (void))
  (define (vector-add-constant v n)
    (let ((c (vector-copy v)))
      (v+c (vector->ptr c) n)
      c))
  (define (vector-add-constant! v n)
    (v+c (vector->ptr v) n)
    (void))

  ;; Extremes and their indices, delegated to the backend.
  (define (vector-max v)
    (vmax (vector->ptr v)))
  (define (vector-min v)
    (vmin (vector->ptr v)))
  (define (vector-argmax v)
    (vargmax (vector->ptr v)))
  (define (vector-argmin v)
    (vargmin (vector->ptr v)))

  ;; Predicates, delegated to the backend.
  (define (vector-zero? v)
    (vnull? (vector->ptr v)))
  (define (vector-positive? v)
    (vpositive? (vector->ptr v)))
  (define (vector-negative? v)
    (vnegative? (vector->ptr v)))
  (define (vector-nonnegative? v)
    (vnonnegative? (vector->ptr v)))
  (define (vector= v1 v2)
    (vequal? (vector->ptr v1) (vector->ptr v2)))

  (define (vector-basis! v n)
    (vbasis! (vector->ptr v) n)
    (void)))
(##core#begin
 ;; Backend module for complex (double) vectors over GSL's gsl_vector_complex
 ;; API. Inside this module the foreign type `gsl_vector` names
 ;; gsl_vector_complex, so the generic-vector functor can be instantiated with
 ;; the same operation names as the real backend.
 (module
  _csl.zdvector
  *
  (import scheme (chicken base) (chicken gc) (chicken foreign) foreigners)
  ;; Same file the real backend includes, spelled the same way (with .scm).
  (include "../csl-error.scm")
  (include "../complex-foreign-lambda.scm")

  ;; C-callable constructor used by code generated by complex-foreign-lambda
  ;; to turn a gsl_complex result into a Scheme number.
  ;; NOTE(review): %make-rectangular is presumably bound by the renamed
  ;; `make-rectangular` import in complex-foreign-lambda.scm; confirm.
  (define-external
    (scheme_make_rect (double r) (double i))
    scheme-object
    (%make-rectangular r i))

  (foreign-declare "#include <gsl/gsl_block_complex_double.h>")
  (define-foreign-record-type
   (gsl_block "gsl_block_complex")
   (unsigned-int size gsl_block.size)
   ((c-pointer double) data gsl_block.data))

  (foreign-declare "#include <gsl/gsl_vector_complex_double.h>")
  (foreign-declare "#include <gsl/gsl_complex.h>")
  (foreign-declare "#include <gsl/gsl_complex_math.h>")
  ;; Field accessors for gsl_vector_complex; `vlength` reads `size`.
  (define-foreign-record-type
   (gsl_vector "gsl_vector_complex")
   (unsigned-int size vlength)
   (unsigned-int stride gsl_vector.stride)
   ((c-pointer double) data gsl_vector.data)
   (gsl_block block gsl_vector.block)
   (int owner gsl_vector.owner))

  (define gsl_vector_free
    (foreign-safe-lambda void "gsl_vector_complex_free" gsl_vector))
  (define gsl_vector_alloc
    (foreign-safe-lambda gsl_vector "gsl_vector_complex_alloc" unsigned-int))

  ;; Allocate an N-element complex vector, released by gsl_vector_complex_free
  ;; once the Scheme pointer object is garbage collected.
  (define (valloc n) (set-finalizer! (gsl_vector_alloc n) gsl_vector_free))

  ;; Element access. Complex values cross the FFI via complex-foreign-lambda.
  (define vref
    (complex-foreign-lambda
     (complex double)
     "gsl_vector_complex_get"
     (const gsl_vector)
     (const unsigned-int)))
  (define vset!
    (complex-foreign-lambda
     void
     "gsl_vector_complex_set"
     gsl_vector
     (const unsigned-int)
     (complex double)))
  (define vfill!
    (complex-foreign-lambda
     void
     "gsl_vector_complex_set_all"
     gsl_vector
     (complex double)))
  (define vbasis!
    (foreign-safe-lambda
     void
     "gsl_vector_complex_set_basis"
     gsl_vector
     unsigned-int))

  ;; Strided view into V: N elements starting at OFFSET, STRIDE apart.
  ;; Previously this called gsl_vector_complex_alloc(v->size) and then
  ;; overwrote the struct with the view's header. That leaked a full data
  ;; block plus the header on every call. Now only a header is malloc'ed: it
  ;; shares V's data and carries the view's `owner` flag. A finalizer releases
  ;; it: gsl_vector_complex_free frees the block only for owning vectors, so
  ;; for a view it frees just this header.
  ;; NOTE(review): the view does not keep V alive by itself; the caller must
  ;; keep V reachable while the view is in use.
  (define vsubvector
    (let ((%vsubvector
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v)
             (unsigned-int offset)
             (unsigned-int stride)
             (unsigned-int n))
            "gsl_vector_complex *p0 = malloc(sizeof(gsl_vector_complex));"
            "gsl_vector_complex_view p1 = gsl_vector_complex_subvector_with_stride(v,offset,stride,n);"
            "memcpy(p0, &p1.vector, sizeof(gsl_vector_complex));"
            "C_return(p0);")))
      (lambda (v offset stride n)
        (set-finalizer! (%vsubvector v offset stride n) gsl_vector_free))))

  ;; Copy / permutation. vcopy! takes (destination source).
  (define vcopy!
    (foreign-safe-lambda
     int
     "gsl_vector_complex_memcpy"
     gsl_vector
     gsl_vector))
  (define vswap!
    (foreign-safe-lambda
     int
     "gsl_vector_complex_swap_elements"
     gsl_vector
     unsigned-int
     unsigned-int))
  (define vreverse!
    (foreign-safe-lambda int "gsl_vector_complex_reverse" gsl_vector))

  ;; In-place arithmetic: the first vector is overwritten with the result.
  (define v+
    (foreign-safe-lambda int "gsl_vector_complex_add" gsl_vector gsl_vector))
  (define v-
    (foreign-safe-lambda int "gsl_vector_complex_sub" gsl_vector gsl_vector))
  (define v*
    (foreign-safe-lambda int "gsl_vector_complex_mul" gsl_vector gsl_vector))
  (define v/
    (foreign-safe-lambda int "gsl_vector_complex_div" gsl_vector gsl_vector))
  (define v*c
    (complex-foreign-lambda
     int
     "gsl_vector_complex_scale"
     gsl_vector
     (complex double)))
  (define v+c
    (complex-foreign-lambda
     int
     "gsl_vector_complex_add_constant"
     gsl_vector
     (complex double)))

  ;; Ordering-based operations are rejected for complex vectors.
  (define (vmax v) (error "Complex numbers have no ordering."))
  (define (vmin v) (error "Complex numbers have no ordering."))
  (define (vargmax v) (error "Complex numbers have no ordering."))
  (define (vargmin v) (error "Complex numbers have no ordering."))

  ;; Predicates.
  (define vnull?
    (foreign-safe-lambda bool "gsl_vector_complex_isnull" gsl_vector))
  (define vpositive?
    (foreign-safe-lambda bool "gsl_vector_complex_ispos" gsl_vector))
  (define vnegative?
    (foreign-safe-lambda bool "gsl_vector_complex_isneg" gsl_vector))
  (define vnonnegative?
    (foreign-safe-lambda bool "gsl_vector_complex_isnonneg" gsl_vector))
  (define vequal?
    (foreign-safe-lambda
     bool
     "gsl_vector_complex_equal"
     gsl_vector
     gsl_vector))

  ;; Imaginary part as a new complex vector whose elements are 0 + Im(v[i])i.
  ;; The result is freshly allocated, so it now gets the same finalizer as
  ;; valloc results (it was previously leaked).
  (define vimag
    (let ((%vimag
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v))
            "gsl_vector_complex *p0 = gsl_vector_complex_alloc(v->size);"
            "gsl_vector_view p1 = gsl_vector_complex_imag(v);"
            "for (size_t i = 0; i < v->size; i++) { "
            "double iz = gsl_vector_get(&p1.vector, i);"
            "gsl_vector_complex_set(p0,i,gsl_complex_rect(0, iz));"
            "}"
            "C_return(p0);")))
      (lambda (v) (set-finalizer! (%vimag v) gsl_vector_free))))

  ;; Real part as a new complex vector whose elements are Re(v[i]) + 0i.
  ;; Freshly allocated and finalized, like vimag.
  (define vreal
    (let ((%vreal
           (foreign-safe-lambda*
            gsl_vector
            ((gsl_vector v))
            "gsl_vector_complex *p0 = gsl_vector_complex_alloc(v->size);"
            "gsl_vector_view p1 = gsl_vector_complex_real(v);"
            "for (size_t i = 0; i < v->size; ++i) { "
            "double rz = gsl_vector_get(&p1.vector, i);"
            "gsl_vector_complex_set(p0,i,gsl_complex_rect(rz, 0));"
            "}"
            "C_return(p0);")))
      (lambda (v) (set-finalizer! (%vreal v) gsl_vector_free)))))
 (module csl.zdvector = (generic-vector _csl.zdvector)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment