Skip to content

Instantly share code, notes, and snippets.

@samth
Created April 21, 2021 02:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samth/92d24b44184fbcdace3215b193848e64 to your computer and use it in GitHub Desktop.
Save samth/92d24b44184fbcdace3215b193848e64 to your computer and use it in GitHub Desktop.
#lang racket
(require racket/fixnum)
;; Stand-ins for runtime-internal primitives so the multiplication code
;; below runs under plain Racket.  NOTE(review): the `$`-prefixed name and
;; the `constant` form used later suggest these mirror Chez Scheme
;; internals -- confirm.
;; Anything `integer?` accepts is treated as a "bignum".
(define bignum?
  (lambda (v) (integer? v)))
;; Length in "bigits": the bit length divided by 10, rounded down.
(define $bignum-length
  (lambda (v)
    (let ([bits (integer-length v)])
      (quotient bits 10))))
;; Base-case multiplication primitive used by the recursive multipliers.
(define integer*
  (lambda (a b) (* a b)))
(require racket/trace)
;; (assert test extra ...)
;; Does nothing when `test` is true.  Otherwise it raises an error whose
;; message includes the quoted `test` expression followed by the values of
;; the `extra` expressions.
(define-syntax-rule (assert test extra ...)
  (when (not test)
    (error "assertion failure" 'test extra ...)))
;; Identity stand-in for a compile-time `constant` form.
(define constant (lambda (e) e))
;; Bits per "bigit".  This must agree with the 10-bit unit that
;; $bignum-length divides by.  The multipliers compute their split point as
;;   k = (bigit length / parts) * bigit-bits
;; With a value of 1, k landed at roughly 1/20 (karatsuba), 1/30 (toom3) and
;; 1/40 (toom4) of the operand width instead of 1/2, 1/3 and 1/4.  The high
;; piece then stayed almost as large as the input, defeating the
;; divide-and-conquer.  Results were still exact, just far slower.
(define bigit-bits 10)
;; Shift aliases: (ash n s) is n * 2^s for s >= 0 and floor(n / 2^(-s))
;; for s < 0 (arithmetic-shift semantics).
(define ash arithmetic-shift)
(define integer-ash ash)
;; _Modern Computer Arithmetic_, Brent and Zimmermann
;; karatsuba : integer integer -> integer
;; Returns the product of x and y using Karatsuba's method.
;;
;; When both operands are shorter than 32 bigits (per $bignum-length), the
;; base multiplier integer* is used.  Otherwise each operand is split at
;; k bits:
;;   x = x-hi*2^k + x-lo,   0 <= x-lo < 2^k
;; arithmetic-shift floors and bitwise-bit-field reads two's-complement bits,
;; so the split is exact for negative operands too.  With c0 = x-lo*y-lo and
;; c1 = x-hi*y-hi:
;;   x*y = c0 + (c0 + c1 - (x-lo - x-hi)*(y-lo - y-hi))*2^k + c1*2^(2k)
;; so only three recursive products are needed.
(define (karatsuba x y)
  (let ([x-len (if (bignum? x) ($bignum-length x) 0)]
        [y-len (if (bignum? y) ($bignum-length y) 0)])
    (if (and (fx< x-len 32) (fx< y-len 32))
        (integer* x y)
        (let* ([k (fx* (fxquotient (fxmax x-len y-len) 2) (constant bigit-bits))]
               [x-hi (ash x (fx- k))]
               [y-hi (ash y (fx- k))]
               [x-lo (bitwise-bit-field x 0 k)]
               [y-lo (bitwise-bit-field y 0 k)])
          ;; At least one operand must actually be split, or recursion
          ;; would make no progress.
          (assert (or (not (= x-lo x))
                      (not (= y-lo y)))
                  x-lo x
                  y-lo y)
          (let* ([c0 (karatsuba x-lo y-lo)]
                 [c1 (karatsuba x-hi y-hi)]
                 [x-diff (- x-lo x-hi)]
                 [y-diff (- y-lo y-hi)]
                 ;; Recur on magnitudes only; the sign of x-diff*y-diff is
                 ;; restored when combining below.
                 [cross (karatsuba (abs x-diff) (abs y-diff))]
                 ;; c1 - x-diff*y-diff, i.e. the middle coefficient minus c0.
                 [c1-c2 (if (eq? (negative? x-diff) (negative? y-diff))
                            (- c1 cross)
                            (+ c1 cross))])
            (+ c0
               (integer-ash (+ c0 c1-c2) k)
               (integer-ash c1 (fx* 2 k))))))))
;; Operator aliases used by toom3's interpolation step.
;; (>> x y): floor(x / 2^y) via an arithmetic right shift.
(define (>> x y)
  (arithmetic-shift x (- y)))
;; (<< x y): x * 2^y.
(define <<
  (lambda (x y) (arithmetic-shift x y)))
;; (// x y): integer division truncating toward zero (quotient).
(define //
  (lambda (x y) (quotient x y)))
;; Small named constants.
(define-values (ZERO ONE TWO SIX FOUR)
  (values 0 1 2 6 4))
;; _Modern Computer Arithmetic_, Brent and Zimmermann
;; toom3 : integer integer -> integer
;; Returns the product of x and y using Toom-Cook 3-way splitting.
;;
;; Dispatch on size (both operands must be below the threshold, in bigits
;; per $bignum-length):
;;   below 32  -> integer*
;;   below 100 -> karatsuba
;;
;; Otherwise each operand is cut at k and 2k bits:
;;   x = x-hi*2^(2k) + x-mid*2^k + x-lo   (likewise y)
;; Read the pieces as P(t) = x-hi + x-mid*t + x-lo*t^2, with Q built the same
;; way from y.  The product R = P*Q = z0 + z1*t + z2*t^2 + z3*t^3 + z4*t^4 is
;; rebuilt from five recursive products: R(0), its leading coefficient,
;; R(1), R(-1) and R(2).  Then x*y = sum of z_i * 2^((4-i)k).
(define (toom3 x y)
  (define xl (if (bignum? x) ($bignum-length x) 0))
  (define yl (if (bignum? y) ($bignum-length y) 0))
  (cond
    [(and (fx< xl 32) (fx< yl 32))
     (integer* x y)]
    [(and (fx< xl 100) (fx< yl 100))
     (karatsuba x y)]
    [else
     (let* ([k (fx* (fxquotient (fxmax xl yl) 3) (constant bigit-bits))]
            [x-hi (ash x (* -2 k))]
            [y-hi (ash y (* -2 k))]
            [x-mid (bitwise-bit-field x (fx* 1 k) (fx* 2 k))]
            [y-mid (bitwise-bit-field y (fx* 1 k) (fx* 2 k))]
            [x-lo (bitwise-bit-field x 0 k)]
            [y-lo (bitwise-bit-field y 0 k)]
            ;; z0 = R(0); z4 = leading coefficient of R.
            [z0 (toom3 x-hi y-hi)]
            [z4 (toom3 x-lo y-lo)]
            ;; t1 = R(1), t2 = R(-1).
            [t1 (toom3 (+ x-hi x-mid x-lo) (+ y-hi y-mid y-lo))]
            [t2 (toom3 (+ (- x-hi x-mid) x-lo) (+ (- y-hi y-mid) y-lo))]
            ;; t3 = R(2).  This recurses through toom3 like the other four
            ;; products instead of handing the product to the builtin `*`.
            [t3 (toom3 (+ x-hi (<< x-mid ONE) (<< x-lo TWO))
                       (+ y-hi (<< y-mid ONE) (<< y-lo TWO)))]
            ;; Interpolation; every shift and division below is exact.
            [z2 (- (>> (+ t1 t2) ONE) z0 z4)]        ; (t1+t2)/2 = z0+z2+z4
            [t4 (- t3 z0 (<< z2 TWO) (<< z4 FOUR))]  ; = 2*z1 + 8*z3
            [z3 (// (+ (- t4 t1) t2) SIX)]           ; t1-t2 = 2*z1 + 2*z3
            [z1 (- (>> (- t1 t2) ONE) z3)])
       (+ (ash z0 (* k 4))
          (ash z1 (* k 3))
          (ash z2 (* k 2))
          (ash z3 (* k 1))
          (ash z4 (* k 0))))]))
;; toom4 : integer integer -> integer
;; Returns the product of x and y using Toom-Cook 4-way splitting.
;;
;; Dispatch on size (both operands must be below the threshold, in bigits
;; per $bignum-length):
;;   below 32  -> integer*
;;   below 100 -> karatsuba
;;   below 256 -> toom3
;;
;; Otherwise each operand is cut at k, 2k and 3k bits:
;;   x = x0*2^(3k) + x1*2^(2k) + x2*2^k + x3   (likewise y)
;; Read the pieces as P(t) = x0 + x1*t + x2*t^2 + x3*t^3, with Q built the
;; same way from y.  The coefficients z0..z6 of R = P*Q are rebuilt from
;; seven recursive products: R(0), its leading coefficient, and R at
;; t = 1, -1, 2, -2 and 3.  Then x*y = sum of z_i * 2^((6-i)k).
(define (toom4 x y)
  (define xl (if (bignum? x) ($bignum-length x) 0))
  (define yl (if (bignum? y) ($bignum-length y) 0))
  (cond
    [(and (fx< xl 32) (fx< yl 32))
     (integer* x y)]
    [(and (fx< xl 100) (fx< yl 100))
     (karatsuba x y)]
    [(and (fx< xl 256) (fx< yl 256))
     (toom3 x y)]
    [else
     (let* ([k (fx* (fxquotient (fxmax xl yl) 4) (constant bigit-bits))]
            ;; Four limbs per operand: x0 is the top piece, x3 the bottom.
            [x0 (ash x (fx* -3 k))]
            [y0 (ash y (fx* -3 k))]
            [x1 (bitwise-bit-field x (fx* 2 k) (fx* 3 k))]
            [y1 (bitwise-bit-field y (fx* 2 k) (fx* 3 k))]
            [x2 (bitwise-bit-field x (fx* 1 k) (fx* 2 k))]
            [y2 (bitwise-bit-field y (fx* 1 k) (fx* 2 k))]
            [x3 (bitwise-bit-field x 0 k)]
            [y3 (bitwise-bit-field y 0 k)]
            ;; z0 = R(0); z6 = leading coefficient of R.
            [z0 (toom4 x0 y0)]
            [z6 (toom4 x3 y3)]
            ;; t = +1 / -1, with z0 + z6 removed:
            ;;   t1 = z1 + z2 + z3 + z4 + z5
            ;;   t2 = -z1 + z2 - z3 + z4 - z5
            [ends1 (+ z0 z6)]
            [x-ev1 (+ x0 x2)]
            [x-od1 (+ x1 x3)]
            [y-ev1 (+ y0 y2)]
            [y-od1 (+ y1 y3)]
            [t1 (- (toom4 (+ x-ev1 x-od1) (+ y-ev1 y-od1)) ends1)]
            [t2 (- (toom4 (- x-ev1 x-od1) (- y-ev1 y-od1)) ends1)]
            ;; t = +2 / -2, with z0 + 64*z6 removed:
            ;;   t3 = 2z1 + 4z2 + 8z3 + 16z4 + 32z5
            ;;   t4 = -2z1 + 4z2 - 8z3 + 16z4 - 32z5
            [x-ev2 (+ x0 (ash x2 2))]
            [x-od2 (+ (ash x1 1) (ash x3 3))]
            [y-ev2 (+ y0 (ash y2 2))]
            [y-od2 (+ (ash y1 1) (ash y3 3))]
            [ends2 (+ z0 (ash z6 6))]
            [t3 (- (toom4 (+ x-ev2 x-od2) (+ y-ev2 y-od2)) ends2)]
            [t4 (- (toom4 (- x-ev2 x-od2) (- y-ev2 y-od2)) ends2)]
            ;; t = 3, with z0 + 729*z6 removed:
            ;;   t5 = 3z1 + 9z2 + 27z3 + 81z4 + 243z5
            ;; This recurses through toom4 like the other six products
            ;; instead of handing the product to the builtin `*`.
            [t5 (- (toom4 (+ x0 (* 3 x1) (* 9 x2) (* 27 x3))
                          (+ y0 (* 3 y1) (* 9 y2) (* 27 y3)))
                   (+ z0 (* 729 z6)))]
            ;; Interpolation; every shift and quotient below is exact.
            [t6 (+ t1 t2)]                        ; 2z2 + 2z4
            [t7 (+ t3 t4)]                        ; 8z2 + 32z4
            [z4 (quotient (- t7 (ash t6 2)) 24)]
            [z2 (- (ash t6 -1) z4)]
            [t8 (- t1 z2 z4)]                     ; z1 + z3 + z5
            [t9 (- t3 (ash z2 2) (ash z4 4))]     ; 2z1 + 8z3 + 32z5
            [t10 (- t5 (* 9 z2) (* 81 z4))]       ; 3z1 + 27z3 + 243z5
            [t11 (- t10 (* 3 t8))]                ; 24z3 + 240z5
            [t12 (- t9 (ash t8 1))]               ; 6z3 + 30z5
            [z5 (quotient (- t11 (ash t12 2)) 120)]
            [z3 (quotient (- (ash t12 3) t11) 24)]
            [z1 (- t8 z3 z5)])
       (+ (ash z0 (* k 6))
          (ash z1 (* k 5))
          (ash z2 (* k 4))
          (ash z3 (* k 3))
          (ash z4 (* k 2))
          (ash z5 (* k 1))
          (ash z6 (* k 0))))]))
(require rackunit)
;; check-mul : (integer integer -> integer) integer integer [#:chg boolean] -> void
;; Checks the multiplier `mul` against the builtin `*` with rackunit's
;; check-equal?.  Each sign combination is checked in order:
;;   (a b) (-a b) (a -b) (-a -b)
;; Unless #:chg is #f, each operand is first shifted by a random offset in
;; [-500, 499].  The "trying ..." line printed first shows the exact inputs,
;; so a failure can be replayed deterministically with #:chg #f.
(define (check-mul mul _a _b #:chg [chg? #t])
  ;; Random offset, drawn only when perturbation is enabled.
  (define (jitter) (if chg? (- (random 1000) 500) 0))
  (define a (+ _a (jitter)))
  (define b (+ _b (jitter)))
  (printf "trying (check-mul ~s ~s ~s #:chg #f)\n" (object-name mul) a b)
  (for ([p (in-list (list a (- a) a (- a)))]
        [q (in-list (list b b (- b) (- b)))])
    ;; rackunit convention: actual value first, expected value second.
    (check-equal? (mul p q) (* p q))))
;(check-mul toom4 (expt 2 1000) (+ 1 (expt 3 1000)))
;; run-tests : [#:start n] [#:step d] [#:count c] -> void
;; Stress-tests toom4 through check-mul on 2^n and 2^n + 1 for
;; n = start, start+step, start+2*step, ...
;; After each size it prints "<n> worked".
;; With the default #:count of +inf.0 it loops forever.  Pass a finite
;; #:count to check that many sizes and then return.
(define (run-tests #:start [start 1000] #:step [step 10] #:count [count +inf.0])
  (let loop ([n start] [done 0])
    (when (< done count)
      (check-mul toom4 (expt 2 n) (+ 1 (expt 2 n)))
      (printf "~s worked\n" n)
      (loop (+ n step) (add1 done)))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment