Skip to content

Instantly share code, notes, and snippets.

@rsms
Created November 23, 2020 17:35
Show Gist options
  • Save rsms/42b5e6712a499c96995c6b7340a0a9e4 to your computer and use it in GitHub Desktop.
Save rsms/42b5e6712a499c96995c6b7340a0a9e4 to your computer and use it in GitHub Desktop.
Optimized AW-Set aka OR-SWOT — a state-based CRDT set
package crdt
import (
"fmt"
"os"
"sort"
"strings"
"testing"
)
// CRDT "Optimized AW-Set" aka "OR-SWOT".
// Inspired by a talk by Russell Brown of Riak at Erlang Factory in Stockholm, 2016.
// Slides: http://www.erlang-factory.com/static/upload/media/1474729847848977russellbrownbiggersetseuc2016.pdf
// Clock is the stamp AWSet.Add stores next to an element: which replica added
// it and the value of that replica's counter right after the add.
type Clock struct {
// ClientId is the adding replica; it is used as an index into a VectorClock.
ClientId uint
// Version is the adding replica's own counter value, taken after Add incremented it.
Version uint
}
// VectorClock holds one counter per replica. The slice index is the replica's
// ClientId; the value is the highest version recorded for that replica.
type VectorClock []uint

// Merge folds src into v element-wise: for every index both clocks share the
// larger counter wins, and counters that only src has are appended. The
// receiver is a pointer because the slice may have to grow.
func (v *VectorClock) Merge(src VectorClock) {
	merged := *v
	for idx, count := range src {
		switch {
		case idx >= len(merged):
			// src is longer than v: adopt its extra counters as they are.
			merged = append(merged, count)
		case merged[idx] < count:
			merged[idx] = count
		}
	}
	*v = merged
}

// String renders the clock as "[A 1, B 4]": each counter is prefixed with a
// letter derived from its index ('A' for index 0, 'B' for index 1, ...).
func (v VectorClock) String() string {
	parts := make([]string, len(v))
	for idx, count := range v {
		parts[idx] = fmt.Sprintf("%c %d", 'A'+idx, count)
	}
	return "[" + strings.Join(parts, ", ") + "]"
}

// Clone returns a copy of v backed by its own array. The result is never nil,
// even when v is.
func (v VectorClock) Clone() VectorClock {
	return append(make(VectorClock, 0, len(v)), v...)
}
// AWSet is the state of the set as held by one replica.
type AWSet struct {
// ClientId is this replica's slot in VectorClock; Add increments that slot.
ClientId uint
// VectorClock is embedded; merging a peer folds the peer's clock into it.
VectorClock
// Entries maps every element currently in the set to the Clock recorded by
// the Add that inserted it.
Entries map[string]Clock
}
// SortedValues returns the elements currently in the set in ascending string
// order. The result is a freshly allocated slice; it is empty (not nil) when
// the set holds nothing.
func (s *AWSet) SortedValues() []string {
	keys := make([]string, 0, len(s.Entries))
	for key := range s.Entries {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
// Clone returns a deep copy of the replica: the vector clock and the entry
// map are duplicated, so changes to the copy never reach s. The copy's
// Entries map is always non-nil.
func (s *AWSet) Clone() *AWSet {
	entries := make(map[string]Clock, len(s.Entries))
	for key, dot := range s.Entries {
		entries[key] = dot
	}
	return &AWSet{
		ClientId:    s.ClientId,
		VectorClock: s.VectorClock.Clone(),
		Entries:     entries,
	}
}
// Has reports whether k is currently an element of this replica's set.
func (s *AWSet) Has(k string) bool {
	_, present := s.Entries[k]
	return present
}
// Add inserts k into the set, or refreshes it when it is already present.
// The replica's own counter in the vector clock is incremented and the new
// value, together with the replica's ClientId, is stored as k's Clock.
//
// Add works on a partially initialised AWSet: the vector clock is extended
// with zero counters until it has a slot for ClientId, and Entries is created
// on first use.
func (s *AWSet) Add(k string) {
	// Make sure this replica has a counter slot; indexing past the end of the
	// clock would panic.
	for uint(len(s.VectorClock)) <= s.ClientId {
		s.VectorClock = append(s.VectorClock, 0)
	}
	// Writing to a nil map panics.
	if s.Entries == nil {
		s.Entries = make(map[string]Clock)
	}

	s.VectorClock[s.ClientId]++
	s.Entries[k] = Clock{
		ClientId: s.ClientId,
		Version:  s.VectorClock[s.ClientId],
	}
}
// Del removes k from this replica's Entries. The vector clock is not touched
// and no tombstone is stored. Deleting a key that is not present is a no-op.
func (s *AWSet) Del(k string) {
delete(s.Entries, k)
}
// MergeMinimal merges src into s while transferring only the entries of src
// that s has not seen yet (work in progress).
//
// Protocol sketch:
//  1. src asks dst (s) for its vector clock.
//  2. dst replies with a copy of its clock.
//  3. src selects the entries whose Clock is not covered by that copy.
//  4. dst applies the selected entries, drops the keys src has removed and
//     merges src's vector clock into its own.
//
// Step 4 must not treat "missing from the delta" as "deleted by src": an
// entry both replicas already share is also missing from the delta. The
// removal pass therefore checks membership against src.Entries.
// TODO: a wire protocol has to ship that membership information (for example
// the key set, or the removed dots) next to the delta.
func (s *AWSet) MergeMinimal(src *AWSet) {
	// seen reports whether vc already covers dot c. A clock without a slot
	// for the dot's client has never seen it.
	seen := func(vc VectorClock, c Clock) bool {
		return c.ClientId < uint(len(vc)) && c.Version <= vc[c.ClientId]
	}

	// 1-2. Snapshot of dst's clock as src would receive it.
	dstVectorClock := s.VectorClock.Clone()

	// 3. src side: keep only the entries dst has not seen.
	srcEntries := make(map[string]Clock)
	for k, v := range src.Entries {
		if !seen(dstVectorClock, v) {
			srcEntries[k] = v
		}
	}

	// 4a. dst side: apply the delta.
	if s.Entries == nil {
		s.Entries = make(map[string]Clock, len(srcEntries))
	}
	for k, srcClock := range srcEntries {
		if dstClock, ok := s.Entries[k]; ok && !seen(src.VectorClock, dstClock) {
			// dst holds its own add of k that src has not seen: keep ours so
			// a later removal on src cannot cancel it.
			continue
		}
		s.Entries[k] = srcClock
	}

	// 4b. Drop entries src has observed and no longer holds. Deleting from a
	// map while ranging over it is permitted in Go.
	for k, dstClock := range s.Entries {
		if _, ok := src.Entries[k]; !ok && seen(src.VectorClock, dstClock) {
			delete(s.Entries, k)
		}
	}

	// 4c. dst has now seen everything src has seen.
	s.VectorClock.Merge(src.VectorClock)
}
// Merge folds the complete state of src (its vector clock and all of its
// entries) into s by delegating to merge. src itself is only read.
func (s *AWSet) Merge(src *AWSet) {
	clock, entries := src.VectorClock, src.Entries
	s.merge(clock, entries)
	// Alternative kept for experimentation: s.MergeMinimal(src)
}
// merge folds a peer's state (srcVectorClock, srcEntries) into s.
//
// A dot (Clock) counts as "seen" by a vector clock when that clock has a slot
// for the dot's client and the slot's counter is >= the dot's version.
//
// For every key in srcEntries:
//   - absent from s, dot not seen by s: the key is new, add it;
//   - absent from s, dot already seen by s: s removed it earlier, skip it;
//   - present with the same dot: nothing to do;
//   - present with a different dot: adopt the peer's dot only when the peer
//     has seen ours and we have not seen theirs (the peer re-added the key
//     after observing our version); otherwise our dot stays, so a newer or
//     concurrent local add is never rolled back.
//
// For every key in s that is absent from srcEntries: drop it when the peer
// has seen its dot (the peer observed the add and removed the key), keep it
// otherwise (the peer has not heard of it yet).
//
// Finally the peer's clock is merged into s.VectorClock. Every decision is
// traced to stdout with fmt.Printf. srcVectorClock and srcEntries are only
// read.
func (s *AWSet) merge(srcVectorClock VectorClock, srcEntries map[string]Clock) {
	// seen also covers clocks of different lengths: a clock without a slot
	// for the dot's client has never seen the dot.
	seen := func(vc VectorClock, c Clock) bool {
		return c.ClientId < uint(len(vc)) && c.Version <= vc[c.ClientId]
	}

	// merge dst <- src
	dst := s
	if dst.Entries == nil {
		dst.Entries = make(map[string]Clock, len(srcEntries))
	}

	for k, srcClock := range srcEntries {
		dstClock, ok := dst.Entries[k]
		switch {
		case !ok && seen(dst.VectorClock, srcClock):
			// dst has seen this add and no longer holds the key: it was
			// removed locally, do not resurrect it.
			fmt.Printf("** skip %q\t %v\n", k, srcClock)
		case !ok:
			fmt.Printf("** add %q\t %v\n", k, srcClock)
			dst.Entries[k] = srcClock
		case dstClock == srcClock:
			fmt.Printf("** keep %q \t %v\n", k, dstClock)
		case seen(srcVectorClock, dstClock) && !seen(dst.VectorClock, srcClock):
			// src knew our dot and produced a newer one.
			fmt.Printf("** updated %q \t %v -> %v\n", k, dstClock, srcClock)
			dst.Entries[k] = srcClock
		default:
			// Our dot is unknown to src (newer or concurrent), or both dots
			// are already known to both sides: leave ours in place.
			fmt.Printf("** keep %q \t %v\n", k, dstClock)
		}
	}

	// process deleted. Deleting from a map while ranging over it is
	// permitted in Go.
	for k, dstClock := range dst.Entries {
		if _, ok := srcEntries[k]; ok {
			continue
		}
		// Missing in src (the incoming state).
		if seen(srcVectorClock, dstClock) {
			// src observed this add and does not hold the key any more.
			fmt.Printf("** remove %q\t %v\n", k, dstClock)
			delete(dst.Entries, k)
		} else {
			fmt.Printf("** keep %q\t %v\n", k, dstClock)
		}
	}

	dst.VectorClock.Merge(srcVectorClock)
}
// merge1 is an alternative two-pass merge of a peer's state into s; nothing
// in this file calls it.
//
// Pass 1 walks the local entries: a key the peer also holds is kept; a key
// the peer lacks is dropped when the peer's clock has reached the entry's
// version, and otherwise left alone ("mark").
//
// Pass 2 walks the peer's entries: a key held locally takes the peer's Clock;
// a key not held locally is added only when its version is ahead of the local
// clock for that client, and skipped otherwise.
//
// Finally the peer's clock is merged into s.VectorClock. Every decision is
// traced to stdout.
func (s *AWSet) merge1(srcVectorClock VectorClock, srcEntries map[string]Clock) {
	// Pass 1: local entries versus the peer.
	for key, local := range s.Entries {
		_, inSrc := srcEntries[key]
		switch {
		case inSrc:
			fmt.Printf("> keep %q\n", key)
		case srcVectorClock[local.ClientId] >= local.Version:
			// The peer has seen this entry and no longer holds it.
			fmt.Printf("> drop %q\n", key)
			delete(s.Entries, key)
		default:
			// The peer has not seen this entry; the stored entry is left
			// unchanged.
			fmt.Printf("> mark %q\n", key)
		}
	}

	// Pass 2: the peer's entries versus the local state.
	for key, remote := range srcEntries {
		local, inDst := s.Entries[key]
		if !inDst {
			// Either never seen locally or removed locally.
			if remote.Version <= s.VectorClock[remote.ClientId] {
				fmt.Printf("** skip %q\t %v\n", key, remote)
				continue
			}
			fmt.Printf("> add %q\t %v\n", key, remote)
		} else if local != remote {
			fmt.Printf("> update %q \t %v -> %v\n", key, local, remote)
		} else {
			fmt.Printf("> keep %q \t %v\n", key, local)
		}
		s.Entries[key] = remote
	}

	s.VectorClock.Merge(srcVectorClock)
}
// TestAWSet drives two replicas, A (client 0) and B (client 1), through a
// fixed script of Add, Del and Merge calls and checks the contents of both
// sets after each step. The steps build on each other, so their order matters.
// printstate dumps both replicas to stdout for manual inspection.
func TestAWSet(t *testing.T) {
A, B, printstate, assertEntries := testAWSetInit(t)
// Both replicas start out empty.
assertEntries(A)
assertEntries(B)
// An add on A is local until B merges A.
A.Add("Shelly")
printstate()
assertEntries(A, "Shelly")
assertEntries(B)
B.Merge(A) // B <- A
printstate()
assertEntries(A, "Shelly")
assertEntries(B, "Shelly")
// Adds on B reach A through a merge in the other direction.
B.Add("Bob")
B.Add("Phil")
B.Add("Pete")
assertEntries(A, "Shelly")
assertEntries(B, "Shelly", "Bob", "Phil", "Pete")
A.Merge(B) // A <- B
printstate()
assertEntries(A, "Shelly", "Bob", "Phil", "Pete")
assertEntries(B, "Shelly", "Bob", "Phil", "Pete")
// A removes one element, re-adds an existing one and adds a new one;
// B must pick up all three changes.
A.Del("Phil")
A.Add("Bob") // update
A.Add("Anna")
assertEntries(A, "Shelly", "Bob" /* */, "Pete", "Anna")
assertEntries(B, "Shelly", "Bob", "Phil", "Pete")
B.Merge(A) // B <- A
printstate()
assertEntries(A, "Shelly", "Bob", "Pete", "Anna")
assertEntries(B, "Shelly", "Bob", "Pete", "Anna")
// Test commutativity (merge order doesn't matter.)
// A removes "Anna" while B adds/updates "Anna".
A.Del("Anna")
B.Add("Anna")
assertEntries(A, "Shelly", "Bob", "Pete")
assertEntries(B, "Shelly", "Bob", "Pete", "Anna")
// The outcome should be that "Anna" is restored (undeleted.)
expectedAfterMerge := []string{"Shelly", "Bob", "Pete", "Anna"}
//
// We try different merge orders to ensure the results are the same.
// Merge order: A -> B -> A
// A and B are shadowed by clones so this order runs on copies and leaves the
// real replicas untouched. Clone never returns nil, so the condition only
// serves to open the scope.
if A, B := A.Clone(), B.Clone(); A != nil {
B.Merge(A) // B <- A
A.Merge(B) // A <- B
assertEntries(A, expectedAfterMerge...)
assertEntries(B, expectedAfterMerge...)
}
// Merge order: B -> A -> B
A.Merge(B) // A <- B
B.Merge(A) // B <- A
assertEntries(A, expectedAfterMerge...)
assertEntries(B, expectedAfterMerge...)
printstate()
// Both replicas delete independently (one key in common); after merging in
// both directions only "Anna" is left on either side.
A.Del("Bob")
A.Del("Pete")
B.Del("Bob")
B.Del("Shelly")
A.Merge(B) // A <- B
B.Merge(A) // B <- A
printstate()
assertEntries(A, "Anna")
assertEntries(B, "Anna")
// Add, delete and re-add the same key on A before B merges.
A.Add("A")
A.Add("B")
A.Add("C")
A.Del("A")
A.Add("A")
B.Merge(A) // B <- A
printstate()
assertEntries(A, "Anna", "A", "B", "C")
assertEntries(B, "Anna", "A", "B", "C")
// NOTE(review): os.Exit(0) ends the whole test process here, so nothing after
// this test gets to run. Since Go 1.16 `go test` treats a call to os.Exit(0)
// from a test function as a failure - confirm this line is still wanted.
// Removing it also means dropping the "os" import, which has no other use in
// this file.
os.Exit(0)
}
// testAWSetInit builds the fixtures used by the AWSet tests.
//
// It returns two empty replicas, A (ClientId 0) and B (ClientId 1), each with
// a two-slot vector clock, plus two helpers:
//
//   - printstate prints both replicas to stdout between two ruler lines;
//   - assertEntries fails the test immediately unless the given set holds
//     exactly the expected values (order does not matter). It sorts the
//     expectedValues slice in place and returns true on success.
func testAWSetInit(
	t *testing.T,
) (A, B *AWSet, printstate func(), assertEntries func(*AWSet, ...string) bool) {
	newReplica := func(id uint) *AWSet {
		return &AWSet{
			ClientId:    id,
			VectorClock: VectorClock{0, 0},
			Entries:     make(map[string]Clock),
		}
	}
	A, B = newReplica(0), newReplica(1)

	const ruler = "————————————————————————————————————————————————\n"
	printstate = func() {
		fmt.Print(ruler)
		fmt.Printf("Replica A: %s\n", A)
		fmt.Printf("Replica B: %s\n", B)
		fmt.Print(ruler)
	}

	assertEntries = func(s *AWSet, expectedValues ...string) bool {
		t.Helper()
		sort.Strings(expectedValues)
		actualValues := s.SortedValues()
		if len(actualValues) != len(expectedValues) {
			t.Fatalf("expected %d values, got %d\nexpected: %v\ngot: %v",
				len(expectedValues), len(actualValues),
				expectedValues, actualValues)
		}
		for i := range actualValues {
			if actualValues[i] != expectedValues[i] {
				t.Fatalf("expected values[%d] to be %q, got %q\nexpected: %v\ngot: %v",
					i, expectedValues[i], actualValues[i],
					expectedValues, actualValues)
			}
		}
		return true
	}

	return A, B, printstate, assertEntries
}
// ----------------------------------------------------
// String renders the clock as "(<letter>, <version>)". The letter is 'A'
// offset by ClientId, so client 0 prints as A, client 1 as B, and so on.
func (v Clock) String() string {
	letter := 'A' + v.ClientId
	return fmt.Sprintf("(%c, %d)", letter, v.Version)
}
// String renders the replica for debugging: the vector clock on the first
// line, followed by one line per element (in sorted order) showing the
// element's Clock and the quoted element.
func (s AWSet) String() string {
	var out strings.Builder
	out.WriteString(s.VectorClock.String())
	for _, key := range s.SortedValues() {
		fmt.Fprintf(&out, "\n %s %q", s.Entries[key], key)
	}
	return out.String()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment