Created
August 22, 2012 05:44
-
-
Save takikawa/3422588 to your computer and use it in GitHub Desktop.
Persistent union-find
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#lang racket
;; Persistent union-find, a Racket port of Conchon & Filliâtre,
;; "A Persistent Union-Find Data Structure":
;; http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.79.8494

;; Persistent arrays, as implemented by the units below:
;;   (init n f)   -> array of length n whose slot i holds (f i)
;;   (get a i)    -> the value in slot i of a
;;   (set a i v)  -> a version of a whose slot i holds v
(define-signature persistent-array^
  (init
   get
   set))

;; Union-find over the elements 0 .. n-1:
;;   (create n)    -> n singleton classes
;;   (find h x)    -> the representative of x's class in h
;;   (union h x y) -> a structure in which x and y are in the same class
(define-signature union-find^
  (create
   find
   union))
;; Union-find parameterised over any persistent-array^ implementation.
;; A `uf` holds two persistent arrays indexed by element:
;;   rank   - per-root rank that `union` uses to choose the new root
;;   parent - parent pointers; an element is a root iff it is its own parent
;; `parent` is mutable only so that `find` can store its path-compressed
;; array back; that array describes the same partition.
(define-unit persistent-union-find@
  (import persistent-array^)
  (export union-find^)

  (struct uf (rank [parent #:mutable]))

  ;; create : Natural -> uf
  ;; n singleton classes over 0 .. n-1, every rank 0.
  (define (create n)
    (uf (init n (λ (_) 0))
        (init n (λ (i) i))))

  ;; find-aux : parent-array Natural -> (values parent-array Natural)
  ;; Follows parent pointers from i to its root r and returns a
  ;; path-compressed parent array together with r. A node is rewritten
  ;; only when its parent is not already r: such a write would leave the
  ;; array's contents unchanged but would still allocate a new version
  ;; (a full O(n) copy with vector@). Lookups on a root or on a direct
  ;; child of a root therefore return f itself.
  (define (find-aux f i)
    (define fi (get f i))
    (cond [(= fi i) (values f i)]
          [else
           (define-values (f* r) (find-aux f fi))
           (if (= fi r)
               (values f* r)
               (values (set f* i r) r))]))

  ;; find : uf Natural -> Natural
  ;; Returns the representative of x's class. Side effect: stores the
  ;; compressed parent array back into h.
  (define (find h x)
    (define-values (f cx) (find-aux (uf-parent h) x))
    (set-uf-parent! h f)
    cx)

  ;; union : uf Natural Natural -> uf
  ;; Returns a uf in which the classes of x and y are merged, by rank:
  ;; - the root of lower rank goes under the root of higher rank;
  ;; - on a tie, y's root goes under x's root and x's root's rank is
  ;;   incremented.
  ;; Returns h itself when x and y are already in the same class. Apart
  ;; from the path compression done by `find`, h is not modified.
  (define (union h x y)
    (define cx (find h x))
    (define cy (find h y))
    (cond
      [(= cx cy) h]
      [else
       ;; Read the arrays only after both finds, so that the new version
       ;; extends the compressed parent array.
       (define rank (uf-rank h))
       (define parent (uf-parent h))
       (define rx (get rank cx))
       (define ry (get rank cy))
       (cond [(> rx ry) (uf rank (set parent cy cx))]
             [(< rx ry) (uf rank (set parent cx cy))]
             [else (uf (set rank cx (add1 rx))
                       (set parent cy cx))])])))
;; Baseline persistent-array^ built on plain Racket vectors.
;; It is persistent by copying: `set` never mutates its argument. Instead
;; it returns a fresh copy with one slot changed, so each update is O(n).
(define-unit vector@
  (import)
  (export persistent-array^)
  (define init
    (λ (n f) (build-vector n f)))
  (define get
    (λ (v i) (vector-ref v i)))
  (define set
    (λ (v i elem)
      (let ([fresh (vector-copy v)])
        (vector-set! fresh i elem)
        fresh))))
;; Persistent array without rerooting. Exactly one version holds the
;; underlying vector, and every other version is a chain of diffs leading
;; to it. `set` is O(1) on any version, but `get` must walk the chain.
;; Chains are never shortened, so reads of old versions get steadily
;; slower.
(define-unit persistent-vector-slow@
  (import)
  (export persistent-array^)

  ;; A PVector's mutable `data` field holds either
  ;;   - a vector?        : this version owns the array, or
  ;;   - (diff idx val vec) : same contents as the PVector `vec`, except
  ;;                          that slot idx holds val.
  (struct pvector (data) #:mutable)
  (struct diff (idx val vec))

  (define (init n f)
    (pvector (build-vector n f)))

  ;; Walk the diff chain until slot i is overridden or the array is reached.
  (define (get pv i)
    (let walk ([node pv])
      (match (pvector-data node)
        [(? vector? arr) (vector-ref arr i)]
        [(diff j x next) (if (= i j) x (walk next))])))

  ;; On a diff version, just push a new diff on top of pv.
  ;; On the owning version, write in place, hand the array to a fresh
  ;; PVector, and turn pv into a diff that remembers the previous value.
  (define (set pv i v)
    (match (pvector-data pv)
      [(diff _ _ _)
       (pvector (diff i v pv))]
      [(? vector? arr)
       (define prev (vector-ref arr i))
       (vector-set! arr i v)
       (define fresh (pvector arr))
       (set-pvector-data! pv (diff i prev fresh))
       fresh])))
;; Persistent array with rerooting. Exactly one version holds the
;; underlying vector; the others are diff chains leading to it. Every
;; access first reroots the accessed version: the diffs on its path are
;; reversed so that it owns the vector. Repeated accesses to the same
;; version are therefore cheap.
(define-unit persistent-vector@
  (import)
  (export persistent-array^)

  ;; `data` is either a vector? (this version owns the array) or
  ;; (diff idx val vec), meaning "the PVector vec, except that slot idx
  ;; holds val".
  (struct pvector (data) #:mutable)
  (struct diff (idx val vec))

  ;; reroot : PVector -> Void
  ;; Makes pv own the underlying vector by reversing every diff on the
  ;; path from pv to the current owner. The contents seen through each
  ;; version are unchanged.
  (define (reroot pv)
    (match (pvector-data pv)
      [(? vector? _) (void)]
      [(diff i v pv*)
       (reroot pv*)
       (match (pvector-data pv*)
         [(? vector? vec)
          (define v* (vector-ref vec i))
          (vector-set! vec i v)
          (set-pvector-data! pv vec)
          (set-pvector-data! pv* (diff i v* pv))]
         [(diff _ _ _) (error "Internal error")])]))

  (define (init n f)
    (pvector (build-vector n f)))

  ;; get : PVector Natural -> Any
  ;; Reroots pv (a no-op if it already owns the array), then reads slot i.
  (define (get pv i)
    (reroot pv)
    (match (pvector-data pv)
      [(? vector? vec) (vector-ref vec i)]
      [(diff _ _ _) (error "Internal error")]))

  ;; set : PVector Natural Any -> PVector
  ;; Returns a version equal to pv except that slot i holds v.
  ;; If slot i already holds a value eqv? to v, the write would be a
  ;; no-op, so pv itself is returned and nothing is allocated or mutated.
  ;; eqv? (not =) is used so that e.g. 0.0 vs -0.0 still counts as a
  ;; change. Otherwise the array is written in place, handed to a fresh
  ;; PVector, and pv becomes a diff that remembers the old value.
  (define (set pv i v)
    (reroot pv)
    (match (pvector-data pv)
      [(? vector? vec)
       (define old (vector-ref vec i))
       (cond
         [(eqv? old v) pv]
         [else
          (vector-set! vec i v)
          (define res (pvector vec))
          (set-pvector-data! pv (diff i old res))
          res])]
      [(diff _ _ _) (error "Internal error")])))
;; Instantiate union-find^ on top of the rerooting persistent array.
;; To compare implementations, link vector@ (or persistent-vector-slow@)
;; instead of persistent-vector@, e.g.
;;   (define-values/invoke-unit/infer (link persistent-union-find@ vector@))
(define-values/invoke-unit/infer
  (link persistent-union-find@
        persistent-vector@))
(module+ test
  (require rackunit)

  ;; Basic behaviour (the original checks).
  (check-equal? (find (create 10) 5) 5)
  (check-equal? (find (union (create 10) 5 3) 5) 5)
  (check-equal? (find (union (create 10) 5 3) 3) 5)
  (check-equal? (find (union (union (create 10) 5 3) 1 5) 1) 5)

  ;; Persistence: union returns a new version and leaves older ones intact.
  (define u0 (create 10))
  (define u1 (union u0 5 3))   ; equal ranks: 3 goes under 5, rank(5) = 1
  (define u2 (union u1 1 5))   ; rank(1) = 0 < rank(5) = 1: 1 goes under 5
  (check-equal? (find u0 3) 3)
  (check-equal? (find u0 5) 5)
  (check-equal? (find u1 3) 5)
  (check-equal? (find u1 1) 1)
  (check-equal? (find u2 1) 5)
  (check-equal? (find u2 3) 5)

  ;; Both rank-ordering branches attach the lower-ranked root under 5.
  (check-equal? (find (union u1 5 0) 0) 5)
  (check-equal? (find (union u1 0 5) 0) 5)

  ;; Merging two elements already in the same class returns h itself.
  (check-eq? (union u1 3 5) u1)

  ;; Older versions still answer correctly after newer ones were used.
  (check-equal? (find u0 1) 1)
  (check-equal? (find u1 5) 5))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment