Skip to content

Instantly share code, notes, and snippets.

@lixin9311
Created May 12, 2022 09:23
Show Gist options
  • Save lixin9311/c111a19e4e39e5a1341a71adc2b9fc57 to your computer and use it in GitHub Desktop.
Save lixin9311/c111a19e4e39e5a1341a71adc2b9fc57 to your computer and use it in GitHub Desktop.
Go generic datastore db driver
package dbdriver
import (
"context"
"errors"
"reflect"
"cloud.google.com/go/datastore"
)
// Compile-time assertion that UnimplementedIndexable satisfies Indexable.
var _ Indexable = UnimplementedIndexable{}
// TXOption is a hook attached to a transactional operation on entities of
// type T. Both hooks receive the operation's txinfo and the entity, and may
// return an error.
//
// In the code visible here, before is invoked by BatchPutTX for every object
// ahead of its OnPut hook; no call site for after appears in this file.
//
// Both methods are unexported, so the interface cannot be implemented
// directly by types declared outside this package.
type TXOption[T Indexable] interface {
before(*txinfo, T) error
after(*txinfo, T) error
}
// Compile-time assertion that emptyTXOption satisfies TXOption.
var _ TXOption[UnimplementedIndexable] = emptyTXOption[UnimplementedIndexable]{}
// emptyTXOption is a TXOption whose hooks do nothing. Other options embed it
// so they only have to define the hook they care about.
type emptyTXOption[T Indexable] struct{}
// before is a no-op and always returns nil.
func (emptyTXOption[T]) before(*txinfo, T) error { return nil }
// after is a no-op and always returns nil.
func (emptyTXOption[T]) after(*txinfo, T) error { return nil }
// txinfo carries the context handed to TXOption hooks.
type txinfo struct {
// client is the datastore client of the driver running the operation.
client *datastore.Client
// tx is the transaction the operation runs in.
tx *datastore.Transaction
// op names the operation; "PUT" is the only value set in the visible code.
op string
}
// beforePutOption is a TXOption that runs beforeFunc ahead of PUT operations.
// It embeds emptyTXOption, which supplies the no-op after hook.
type beforePutOption[T Indexable] struct {
emptyTXOption[T]
// beforeFunc receives the transaction and the entity that is about to be put.
beforeFunc func(*datastore.Transaction, T) error
}
// before forwards to beforeFunc when the running operation is "PUT" and is a
// no-op returning nil for every other operation.
func (o *beforePutOption[T]) before(txinfo *txinfo, t T) error {
	if txinfo.op != "PUT" {
		return nil
	}
	return o.beforeFunc(txinfo.tx, t)
}
// BeforePut returns a TXOption that calls f with the transaction and the
// entity before that entity is written by a PUT operation. An error returned
// by f is returned from the option's before hook.
func BeforePut[T Indexable](f func(*datastore.Transaction, T) error) TXOption[T] {
	opt := new(beforePutOption[T])
	opt.beforeFunc = f
	return opt
}
// Indexable is the constraint for entity types handled by DatastoreDriver.
// The implementing type should be a pointer: NewDatastoreDriver panics unless
// its type argument is a pointer to a struct.
type Indexable interface {
// BuildKey (required) returns the datastore key of the entity. The driver
// overwrites Namespace on the returned key and on each of its parent keys.
BuildKey() *datastore.Key
// Kind (required) returns the entity kind. NewDatastoreDriver calls it once,
// on a placeholder value rather than on a populated entity.
Kind() string
// OnCreate (optional) runs before an insert, and in UpdateTX when the entity
// does not exist and a constructor was supplied. Make sure it does not
// conflict with what that update constructor sets up.
OnCreate() error
// OnUpdate (optional) runs in UpdateTX before the entity is written back.
OnUpdate() error
// OnPut (optional) runs before a put.
OnPut() error
// AfterGet (optional) runs after the get operations (single and batch) and
// after the single-entity insert, put and update operations. The batch
// insert, batch put and delete paths do not call it.
AfterGet() error
// mustEmbedUnimplementedIndexable is unexported, so types outside this
// package satisfy Indexable by embedding UnimplementedIndexable.
mustEmbedUnimplementedIndexable()
}
// UnimplementedIndexable provides default implementations of Indexable.
// The required methods (BuildKey, Kind) panic until the embedding type
// overrides them; the optional hooks are no-ops that return nil.
type UnimplementedIndexable struct{}
// BuildKey panics; the embedding type must provide its own.
func (UnimplementedIndexable) BuildKey() *datastore.Key {
panic("UnimplementedIndexable.BuildKey")
}
// Kind panics; the embedding type must provide its own.
func (UnimplementedIndexable) Kind() string {
panic("UnimplementedIndexable.Kind")
}
// OnCreate is a no-op default hook.
func (UnimplementedIndexable) OnCreate() error { return nil }
// OnUpdate is a no-op default hook.
func (UnimplementedIndexable) OnUpdate() error { return nil }
// AfterGet is a no-op default hook.
func (UnimplementedIndexable) AfterGet() error { return nil }
// OnPut is a no-op default hook.
func (UnimplementedIndexable) OnPut() error { return nil }
// mustEmbedUnimplementedIndexable marks the type as satisfying Indexable.
func (UnimplementedIndexable) mustEmbedUnimplementedIndexable() {}
// DatastoreDriver is a generic datastore access layer for entities of type T.
// It embeds *datastore.Client, so the client's methods are also available
// directly on the driver.
type DatastoreDriver[T Indexable] struct {
*datastore.Client
// namespace is written into every key the driver builds.
namespace string
// kind is the value returned by T's Kind method, captured by
// NewDatastoreDriver.
kind string
// constructor creates the empty instance that a read is decoded into. It
// receives the object passed to the read, so it can populate default
// values from it, such as the wallet address of an empty profile or a
// zero balance for an empty balance.
constructor func(T) T
}
// NewDatastoreDriver builds a driver for entities of type T.
//
// client is the datastore client used for every operation, namespace is
// written into every key the driver builds, and constructor creates the empty
// instance that reads are decoded into.
//
// T must be a pointer to a struct; otherwise NewDatastoreDriver panics with
// "Type must be a pointer" or "Type must be a pointer to a struct". T's Kind
// method is called once here to record the entity kind.
func NewDatastoreDriver[T Indexable](
	client *datastore.Client,
	namespace string,
	constructor func(T) T,
) *DatastoreDriver[T] {
	// Inspect the static type of T. reflect.TypeOf on a zero T yields nil when
	// T is an interface type, which would turn the checks below into a nil
	// dereference instead of the intended panic message.
	typ := reflect.TypeOf((*T)(nil)).Elem()
	if typ.Kind() != reflect.Ptr {
		panic("Type must be a pointer")
	}
	if typ.Elem().Kind() != reflect.Struct {
		panic("Type must be a pointer to a struct")
	}

	// Ask a freshly allocated zero struct for its kind rather than a nil
	// pointer, so Kind implementations with a value receiver do not panic.
	sample := reflect.New(typ.Elem()).Interface().(T)

	return &DatastoreDriver[T]{
		Client:      client,
		namespace:   namespace,
		kind:        sample.Kind(),
		constructor: constructor,
	}
}
// GetNamespace reports the namespace the driver writes into the keys it builds.
func (db *DatastoreDriver[_]) GetNamespace() string {
	ns := db.namespace
	return ns
}
// buildKey returns obj's key with the driver's namespace written into the key
// itself and into every key on its parent chain. The key returned by
// obj.BuildKey is modified in place.
func (db *DatastoreDriver[T]) buildKey(obj T) *datastore.Key {
	root := obj.BuildKey()
	root.Namespace = db.namespace
	for ancestor := root.Parent; ancestor != nil; ancestor = ancestor.Parent {
		ancestor.Namespace = db.namespace
	}
	return root
}
// Get loads the entity identified by obj's key outside of any transaction.
// The entity is decoded into a fresh instance produced by the driver's
// constructor, and AfterGet runs on it once the load succeeds. That instance
// is returned together with the first error encountered, if any.
func (db *DatastoreDriver[T]) Get(ctx context.Context, obj T) (T, error) {
	key := db.buildKey(obj)
	entity := db.constructor(obj)
	err := db.Client.Get(ctx, key, entity)
	if err == nil {
		err = entity.AfterGet()
	}
	return entity, err
}
// GetTX is the transactional counterpart of Get: it loads the entity through
// tx into a fresh instance produced by the driver's constructor and then runs
// AfterGet on it. Use it inside a transaction.
func (db *DatastoreDriver[T]) GetTX(tx *datastore.Transaction, obj T) (T, error) {
	key := db.buildKey(obj)
	loaded := db.constructor(obj)

	getErr := tx.Get(key, loaded)
	if getErr != nil {
		return loaded, getErr
	}
	return loaded, loaded.AfterGet()
}
// BatchGet loads all entities identified by objs in one non-transactional
// GetMulti call. The results are decoded into the elements of objs, so the
// given slice is reused and returned. AfterGet runs on each element in order;
// the first error stops processing and is returned.
func (db *DatastoreDriver[T]) BatchGet(ctx context.Context, objs []T) ([]T, error) {
	keys := make([]*datastore.Key, 0, len(objs))
	for _, obj := range objs {
		keys = append(keys, db.buildKey(obj))
	}

	if err := db.GetMulti(ctx, keys, objs); err != nil {
		return objs, err
	}

	for i := range objs {
		if err := objs[i].AfterGet(); err != nil {
			return objs, err
		}
	}
	return objs, nil
}
// BatchGetTX loads all entities identified by objs through tx. The given
// slice is reused: results are decoded into its elements and the same slice
// is returned. AfterGet runs on each element in order; the first error stops
// processing and is returned.
func (db *DatastoreDriver[T]) BatchGetTX(tx *datastore.Transaction, objs []T) ([]T, error) {
	keys := make([]*datastore.Key, 0, len(objs))
	for idx := range objs {
		keys = append(keys, db.buildKey(objs[idx]))
	}

	err := tx.GetMulti(keys, objs)
	if err != nil {
		return objs, err
	}

	for _, loaded := range objs {
		if err = loaded.AfterGet(); err != nil {
			return objs, err
		}
	}
	return objs, nil
}
// Insert writes obj as an insert mutation outside of any transaction.
// OnCreate runs before the write and AfterGet after it. obj is returned
// together with the first error encountered, if any.
func (db *DatastoreDriver[T]) Insert(ctx context.Context, obj T) (T, error) {
	key := db.buildKey(obj)

	err := obj.OnCreate()
	if err != nil {
		return obj, err
	}

	_, err = db.Client.Mutate(ctx, datastore.NewInsert(key, obj))
	if err != nil {
		return obj, err
	}

	return obj, obj.AfterGet()
}
// InsertTX adds an insert mutation for obj to tx. OnCreate runs before the
// mutation is added and AfterGet after it. Use it inside a transaction.
func (db *DatastoreDriver[T]) InsertTX(tx *datastore.Transaction, obj T) (T, error) {
	key := db.buildKey(obj)

	if hookErr := obj.OnCreate(); hookErr != nil {
		return obj, hookErr
	}

	insert := datastore.NewInsert(key, obj)
	if _, mutErr := tx.Mutate(insert); mutErr != nil {
		return obj, mutErr
	}

	if hookErr := obj.AfterGet(); hookErr != nil {
		return obj, hookErr
	}
	return obj, nil
}
// BatchInsert runs BatchInsertTX for objs inside a transaction created with
// the driver's client and returns the transaction's error, if any.
func (db *DatastoreDriver[T]) BatchInsert(ctx context.Context, objs []T) error {
	_, err := db.Client.RunInTransaction(ctx, func(tx *datastore.Transaction) error {
		return db.BatchInsertTX(tx, objs)
	})
	return err
}
// BatchInsertTX adds one insert mutation per element of objs to tx.
//
// OnCreate runs on each object before its key is built. If a hook fails, its
// error is returned immediately and no mutation is added to the transaction.
// Otherwise the error from tx.Mutate, if any, is returned.
func (db *DatastoreDriver[T]) BatchInsertTX(tx *datastore.Transaction, objs []T) error {
	muts := make([]*datastore.Mutation, len(objs))
	for i, obj := range objs {
		// Surface hook failures instead of inserting an entity whose
		// creation hook did not complete.
		if err := obj.OnCreate(); err != nil {
			return err
		}
		muts[i] = datastore.NewInsert(db.buildKey(obj), obj)
	}
	_, err := tx.Mutate(muts...)
	return err
}
// Put writes obj outside of any transaction. OnPut runs before the write and
// AfterGet after it. obj is returned together with the first error
// encountered, if any.
func (db *DatastoreDriver[T]) Put(ctx context.Context, obj T) (T, error) {
	key := db.buildKey(obj)

	err := obj.OnPut()
	if err != nil {
		return obj, err
	}

	_, err = db.Client.Put(ctx, key, obj)
	if err != nil {
		return obj, err
	}

	return obj, obj.AfterGet()
}
// PutTX writes obj through tx. OnPut runs before the write and AfterGet after
// it. Use it inside a transaction.
func (db *DatastoreDriver[T]) PutTX(tx *datastore.Transaction, obj T) (T, error) {
	key := db.buildKey(obj)

	if hookErr := obj.OnPut(); hookErr != nil {
		return obj, hookErr
	}

	_, putErr := tx.Put(key, obj)
	if putErr != nil {
		return obj, putErr
	}

	if hookErr := obj.AfterGet(); hookErr != nil {
		return obj, hookErr
	}
	return obj, nil
}
// BatchPut runs BatchPutTX for objs, with the given options, inside a
// transaction created with the driver's client and returns the transaction's
// error, if any.
func (db *DatastoreDriver[T]) BatchPut(ctx context.Context, objs []T, dbopts ...TXOption[T]) error {
	runPut := func(tx *datastore.Transaction) error {
		return db.BatchPutTX(tx, objs, dbopts...)
	}
	_, err := db.Client.RunInTransaction(ctx, runPut)
	return err
}
// BatchPutTX adds one upsert mutation per element of objs to tx. Be careful:
// this overrides existing objects.
//
// For each object, the before hook of every option runs first, then the
// object's OnPut hook; the first error aborts the batch before any mutation
// is added to the transaction. Only the before hooks of dbopts are invoked
// here.
func (db *DatastoreDriver[T]) BatchPutTX(tx *datastore.Transaction, objs []T, dbopts ...TXOption[T]) error {
	info := &txinfo{client: db.Client, tx: tx, op: "PUT"}

	muts := make([]*datastore.Mutation, 0, len(objs))
	for _, obj := range objs {
		for _, opt := range dbopts {
			if err := opt.before(info, obj); err != nil {
				return err
			}
		}
		if err := obj.OnPut(); err != nil {
			return err
		}
		muts = append(muts, datastore.NewUpsert(db.buildKey(obj), obj))
	}

	if _, err := tx.Mutate(muts...); err != nil {
		return err
	}
	return nil
}
// Update runs UpdateTX inside a transaction created with the driver's client.
// It returns the entity produced by the last UpdateTX call and the
// transaction's error, if any. See UpdateTX for the meaning of update and
// constructor.
func (db *DatastoreDriver[T]) Update(ctx context.Context, obj T, update func(T) (T, error), constructor func(T) T) (result T, err error) {
	_, txErr := db.Client.RunInTransaction(ctx, func(tx *datastore.Transaction) error {
		// Write through to the named results so the entity is available to
		// the caller once the transaction function has returned.
		result, err = db.UpdateTX(tx, obj, update, constructor)
		return err
	})
	if txErr != nil {
		return result, txErr
	}
	return result, nil
}
// UpdateTX loads the entity identified by obj's key through tx, applies
// update to it and writes the result back with tx.Put.
//
// If the entity does not exist and constructor is non-nil, the entity is
// created with constructor(obj) and its OnCreate hook runs before update is
// applied. If constructor is nil, datastore.ErrNoSuchEntity is returned.
//
// If update returns an error matching ErrAborted, nothing is written and the
// entity returned by update is handed back with a nil error. Any other error
// from update is returned as is. Otherwise OnUpdate runs before the write and
// AfterGet after it.
func (db *DatastoreDriver[T]) UpdateTX(tx *datastore.Transaction, obj T, update func(T) (T, error), constructor func(T) T) (result T, err error) {
	key := db.buildKey(obj)
	result = db.constructor(obj)

	getErr := tx.Get(key, result)
	switch {
	// errors.Is keeps the check working if the sentinel is ever wrapped and
	// matches how ErrAborted is tested below.
	case errors.Is(getErr, datastore.ErrNoSuchEntity):
		if constructor == nil {
			return result, getErr
		}
		result = constructor(obj)
		if err := result.OnCreate(); err != nil {
			return result, err
		}
	case getErr != nil:
		return result, getErr
	}

	result, err = update(result)
	if errors.Is(err, ErrAborted) {
		// The caller asked to skip the write; this is not a failure.
		return result, nil
	}
	if err != nil {
		return result, err
	}

	if err := result.OnUpdate(); err != nil {
		return result, err
	}
	if _, err := tx.Put(key, result); err != nil {
		return result, err
	}
	if err := result.AfterGet(); err != nil {
		return result, err
	}
	return result, nil
}
// BatchDelete runs BatchDeleteTX for objs inside a transaction created with
// the driver's client and returns the transaction's error, if any.
func (db *DatastoreDriver[T]) BatchDelete(ctx context.Context, objs []T) error {
	_, err := db.Client.RunInTransaction(ctx, func(tx *datastore.Transaction) error {
		return db.BatchDeleteTX(tx, objs)
	})
	return err
}
// BatchDeleteTX deletes the entities identified by objs through tx. Only the
// keys of the objects are used; no entity hooks run.
func (db *DatastoreDriver[T]) BatchDeleteTX(tx *datastore.Transaction, objs []T) error {
	keys := make([]*datastore.Key, 0, len(objs))
	for _, obj := range objs {
		keys = append(keys, db.buildKey(obj))
	}
	return tx.DeleteMulti(keys)
}
package dbdriver
import (
"context"
"fmt"
"testing"
"cloud.google.com/go/datastore"
)
// Test is a minimal Indexable entity used to exercise the generic driver.
// It embeds UnimplementedIndexable for the hooks it does not override.
type Test struct {
UnimplementedIndexable
ID string `datastore:"id"`
Name string `datastore:"name"`
}
// BuildKey returns a name key of kind "Test" built from t.ID, with no parent.
func (t *Test) BuildKey() *datastore.Key {
return datastore.NameKey("Test", t.ID, nil)
}
// OnCreate is a no-op hook.
func (t *Test) OnCreate() error {
return nil
}
// OnUpdate is a no-op hook.
func (t *Test) OnUpdate() error {
return nil
}
// Kind returns the entity kind, "Test".
func (t *Test) Kind() string {
return "Test"
}
// TestGenericDriver creates a driver for *Test entities in the "test"
// namespace and reads the entity with ID "test" through it.
//
// NOTE(review): this talks to a real datastore client for "project-id", so it
// presumably needs credentials or an emulator to pass — confirm.
func TestGenericDriver(t *testing.T) {
	ctx := context.Background()
	client, err := datastore.NewClient(ctx, "project-id")
	if err != nil {
		t.Fatal(err)
	}
	// Release the client's connections once the test finishes, including when
	// it stops early through t.Fatal.
	t.Cleanup(func() {
		if closeErr := client.Close(); closeErr != nil {
			t.Logf("closing datastore client: %v", closeErr)
		}
	})

	db := NewDatastoreDriver(
		client,
		"test",
		// The argument is unused; leaving it unnamed also avoids shadowing
		// the *testing.T.
		func(_ *Test) *Test { return new(Test) },
	)
	result, err := db.Get(ctx, &Test{ID: "test"})
	if err != nil {
		t.Fatal(err)
	}
	fmt.Println(result)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment