Skip to content

Instantly share code, notes, and snippets.

@imjasonh
Last active October 28, 2021 06:26
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save imjasonh/7791518 to your computer and use it in GitHub Desktop.
Save imjasonh/7791518 to your computer and use it in GitHub Desktop.
Simple, reusable Set mini-library for Go. It has no compile-time type checking (Go had no generics at the time, lol); elements are stored as interface{}. Something like this has come in handy enough times in the past, even without generics or typing, that I think it's worth jotting down.
// Package readyset provides a simple set implementation.
//
// Elements in the set must be comparable, i.e. usable as map keys; using a
// slice, map, or func value as an element panics at run time.
package readyset
import "fmt"
// Set is a container for arbitrary values that stores each distinct value at
// most once.
//
// Elements must be comparable (usable as map keys); using a slice, map, or
// func value as an element panics at run time. Because Set is a map type,
// copying a Set value (assignment, passing by value) shares the underlying
// storage rather than duplicating it. A nil Set can be queried, but Add on a
// nil Set panics; create sets with NewSet or a Set{} literal.
type Set map[interface{}]struct{}

// NewSet creates a new Set containing the given elements. Duplicate arguments
// are stored once.
func NewSet(in ...interface{}) Set {
	// Pre-size for the common case of mostly distinct arguments.
	set := make(Set, len(in))
	set.Add(in...)
	return set
}

// Contains reports whether i is an element of the set.
func (set Set) Contains(i interface{}) bool {
	_, in := set[i]
	return in
}

// Add inserts each of the given elements into the set; elements that are
// already present are left as they are.
func (set Set) Add(is ...interface{}) {
	for _, i := range is {
		set[i] = struct{}{}
	}
}

// Remove deletes each of the given elements from the set. Elements that are
// not present are ignored.
func (set Set) Remove(is ...interface{}) {
	for _, i := range is {
		delete(set, i)
	}
}

// Slice returns the elements of the set in a newly allocated, non-nil slice.
// The order is unspecified: it follows Go's map iteration order.
func (set Set) Slice() []interface{} {
	s := make([]interface{}, 0, len(set))
	for k := range set {
		s = append(s, k)
	}
	return s
}

// String renders the set as its elements, each formatted with %v, separated
// by single spaces and enclosed in brackets, e.g. "[1 2 3]". An empty set
// renders as "[]". Element order is unspecified.
func (set Set) String() string {
	// Format each element on its own (exactly as %v would), then let fmt join
	// them: printing a []string with %v yields "[a b c]". This is linear in the
	// output size, unlike repeated string concatenation.
	parts := make([]string, 0, len(set))
	for k := range set {
		parts = append(parts, fmt.Sprint(k))
	}
	return fmt.Sprint(parts)
}

// Len returns the number of elements in the set.
func (set Set) Len() int {
	return len(set)
}
// Intersection returns a new set containing the elements present in both a
// and b. Neither input is modified; nil inputs are treated as empty sets.
func Intersection(a, b Set) Set {
	// Every common element is found by walking either side, so one pass over
	// the smaller set (probing the larger) is enough.
	if len(a) > len(b) {
		a, b = b, a
	}
	i := Set{}
	for k := range a {
		if b.Contains(k) {
			i.Add(k)
		}
	}
	return i
}
// Union returns a new set containing every element that is in a, in b, or in
// both. Neither input is modified; nil inputs are treated as empty sets.
func Union(a, b Set) Set {
	// Build a fresh map. Assigning u := a would copy only the map header, so
	// adding b's elements would also add them to the caller's a (and would
	// panic when a is nil).
	u := make(Set, len(a)+len(b))
	for k := range a {
		u[k] = struct{}{}
	}
	for k := range b {
		u[k] = struct{}{}
	}
	return u
}
package readyset
import "testing"
// TestContains checks that every element passed to NewSet is reported by
// Contains and also appears in the result of Slice.
func TestContains(t *testing.T) {
	for _, tc := range []struct {
		s Set
		e []interface{}
	}{
		{s: NewSet(1, 2, 3), e: []interface{}{1, 2, 3}},
		{s: NewSet(true, "foo", struct{}{}), e: []interface{}{true, "foo", struct{}{}}},
	} {
		for _, want := range tc.e {
			if !tc.s.Contains(want) {
				t.Errorf("set %v should contain %v", tc.s, want)
			}
			if elems := tc.s.Slice(); !sliceContains(elems, want) {
				t.Errorf("slice %v should contain %v", elems, want)
			}
		}
	}
}
// sliceContains reports whether some element of s compares equal to e using
// interface equality (same dynamic type and equal value). Elements are
// compared in order and the scan stops at the first match.
func sliceContains(s []interface{}, e interface{}) bool {
	found := false
	for idx := 0; idx < len(s) && !found; idx++ {
		found = s[idx] == e
	}
	return found
}
// TestIntersection checks that Intersection returns exactly the common
// elements (no extras) and leaves both inputs unchanged.
func TestIntersection(t *testing.T) {
	cs := []struct {
		a, b Set
		e    []interface{}
	}{{
		NewSet(1, 2, 3),
		NewSet(3, 4, 5),
		[]interface{}{3},
	}, {
		NewSet(true, false, "foo"),
		NewSet(false, "foo", "bar"),
		[]interface{}{false, "foo"},
	}, {
		// Disjoint sets (values of different dynamic types never match).
		NewSet(1, 2),
		NewSet("1", "2"),
		nil,
	}, {
		// Unequal sizes, in both orders, so either iteration side is covered.
		NewSet(1),
		NewSet(1, 2, 3, 4),
		[]interface{}{1},
	}, {
		NewSet(1, 2, 3, 4),
		NewSet(4),
		[]interface{}{4},
	}}
	for _, c := range cs {
		lenA, lenB := c.a.Len(), c.b.Len()
		i := Intersection(c.a, c.b)
		for _, e := range c.e {
			if !i.Contains(e) {
				t.Errorf("expected %v in %v", e, i)
			}
		}
		// Contains alone cannot detect extra elements; the size check can.
		if i.Len() != len(c.e) {
			t.Errorf("unexpected len of %v, got %d, want %d", i, i.Len(), len(c.e))
		}
		if c.a.Len() != lenA || c.b.Len() != lenB {
			t.Errorf("Intersection modified its inputs: a=%v, b=%v", c.a, c.b)
		}
	}
}
// TestUnion checks that Union returns exactly the elements of both inputs and
// does not write into either argument.
func TestUnion(t *testing.T) {
	cs := []struct {
		a, b Set
		e    []interface{}
	}{{
		NewSet(1, 2),
		NewSet(2, 3),
		[]interface{}{1, 2, 3},
	}, {
		NewSet(false, true, "foo"),
		NewSet(true, "boo", struct{}{}),
		[]interface{}{false, true, "foo", "boo", struct{}{}},
	}, {
		NewSet(),
		NewSet(),
		nil,
	}}
	for _, c := range cs {
		// Record input sizes so we can detect Union adding into its arguments
		// (e.g. by reusing a's map as the result).
		lenA, lenB := c.a.Len(), c.b.Len()
		u := Union(c.a, c.b)
		for _, e := range c.e {
			if !u.Contains(e) {
				t.Errorf("expected %v in %v", e, u)
			}
		}
		if u.Len() != len(c.e) {
			t.Errorf("unexpected len, got %d, want %d", u.Len(), len(c.e))
		}
		if c.a.Len() != lenA || c.b.Len() != lenB {
			t.Errorf("Union modified its inputs: a=%v (len %d, want %d), b=%v (len %d, want %d)",
				c.a, c.a.Len(), lenA, c.b, c.b.Len(), lenB)
		}
	}
}
// TestAddRemove exercises Remove, adding a new element, and re-adding an
// existing one, checking membership and size after each step.
func TestAddRemove(t *testing.T) {
	s := NewSet(1, 2, 3)

	// expect reports msg, formatted with the current set, when ok is false.
	expect := func(ok bool, msg string) {
		t.Helper()
		if !ok {
			t.Errorf(msg, s)
		}
	}

	s.Remove(1)
	expect(!s.Contains(1), "expected 1 to be removed from %v")
	expect(s.Len() == 2, "expected %v.Len() == 2")

	// Adding an element that is not yet present grows the set.
	s.Add(4)
	expect(s.Contains(4), "expected 4 to be added to %v")
	expect(s.Len() == 3, "expected %v.Len() == 3")

	// Re-adding an existing element leaves the size unchanged.
	s.Add(2)
	expect(s.Contains(2), "expected 2 to be in %v")
	expect(s.Len() == 3, "expected %v.Len() == 3")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment