Skip to content

Instantly share code, notes, and snippets.

@glerchundi
Created July 31, 2018 14:24
Show Gist options
  • Save glerchundi/315be9ae9e4b72c467f4ef39d57ef004 to your computer and use it in GitHub Desktop.
Save glerchundi/315be9ae9e4b72c467f4ef39d57ef004 to your computer and use it in GitHub Desktop.
tenancy
package tenancy
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/lib/pq"
)
// tenantContext contains the state that must propagate across process
// boundaries. WithID stores a pointer to one in a context.Context and
// fromContext reads it back.
type tenantContext struct {
// id is the tenant identifier; QueryContext and ExecContext pass it to
// useStatement to build the USE statement they issue.
id string
}
// contextKey is the unexported key type under which the *tenantContext is
// stored in a context.Context. Because the type is private to this package,
// code outside it cannot construct an equal key.
type contextKey struct{}
// fromContext looks up the *tenantContext stored under contextKey{} in ctx.
// It returns an error when nothing is stored under that key or the stored
// value is not a *tenantContext.
func fromContext(ctx context.Context) (*tenantContext, error) {
	if tc, found := ctx.Value(contextKey{}).(*tenantContext); found {
		return tc, nil
	}
	return nil, errors.New("tenancy: unable to retrieve tenant context")
}
// WithID returns a new context with the given tenant id attached.
func WithID(parent context.Context, id string) context.Context {
return context.WithValue(parent, contextKey{}, &tenantContext{id: id})
}
// Driver is the Postgres database driver for Multi-Tenancy. This package's
// init function registers it with database/sql under the name
// "postgres-tenancy". Connections it opens are lib/pq connections wrapped in
// conn, whose QueryContext and ExecContext select the tenant found in the
// call's context before running the statement.
type Driver struct{}
// Open opens a new connection to the database. name is a connection string;
// it is passed unchanged to open, which hands it to pq.Open and wraps the
// result. Most users should only use it through database/sql package from
// the standard library.
func (d *Driver) Open(name string) (driver.Conn, error) {
return open(name)
}
// init registers Driver with database/sql under the name "postgres-tenancy",
// making it selectable via sql.Open("postgres-tenancy", ...).
func init() {
sql.Register("postgres-tenancy", &Driver{})
}
// conn wraps the driver.Conn returned by pq.Open. Methods that conn does not
// define itself are promoted from the embedded Conn.
type conn struct {
// Conn is the underlying connection; conn's methods type-assert it to the
// optional driver interfaces they forward to.
driver.Conn
}
// open opens a connection through pq.Open using name as the connection
// string and wraps it in a *conn. An error from pq.Open is returned
// unchanged, with a nil connection.
func open(name string) (driver.Conn, error) {
	inner, err := pq.Open(name)
	if err != nil {
		return nil, err
	}
	wrapped := &conn{Conn: inner}
	return wrapped, nil
}
// Prepare implements driver.Conn.Prepare by forwarding query unchanged to
// the wrapped connection.
// NOTE(review): no tenant USE statement is issued on this path — confirm
// that statements prepared here are not expected to be tenant-scoped.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.Conn.Prepare(query)
}
// Close implements driver.Conn.Close by closing the wrapped connection and
// returning whatever error it reports.
func (c *conn) Close() error {
return c.Conn.Close()
}
// BeginTx implements driver.ConnBeginTx.BeginTx by forwarding ctx and opts to
// the wrapped connection.
//
// The wrapped connection is stored as a plain driver.Conn, so support for the
// optional driver.ConnBeginTx interface is checked at run time. If it is
// missing, an error is returned instead of panicking on the type assertion.
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	beginner, ok := c.Conn.(driver.ConnBeginTx)
	if !ok {
		return nil, errors.New("tenancy: underlying connection does not implement driver.ConnBeginTx")
	}
	return beginner.BeginTx(ctx, opts)
}
// Query implements driver.Queryer.Query. It never runs the query and always
// returns an error; the tenant-aware entry point is QueryContext.
// NOTE(review): presumably unsupported because this signature has no
// context.Context to read the tenant id from — confirm.
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, errors.New("driver.Queryer.Query not supported")
}
// QueryContext implements driver.QueryerContext.QueryContext.
//
// The tenant id is read from ctx (see WithID). If ctx carries none, the error
// from fromContext is returned and nothing is sent to the database. Otherwise
// the tenant is selected with a USE statement before the query runs:
//
//   - with no args, the USE statement is prepended to query, separated by
//     ";", and both are sent in a single call;
//   - with args, the USE statement is executed on its own first and query is
//     then sent unchanged together with args.
//
// The wrapped connection is stored as a plain driver.Conn, so support for the
// optional driver.QueryerContext and driver.ExecerContext interfaces is
// checked at run time. If one is missing, an error is returned instead of
// panicking on the type assertion.
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	tenantCtx, err := fromContext(ctx)
	if err != nil {
		return nil, err
	}

	queryer, ok := c.Conn.(driver.QueryerContext)
	if !ok {
		return nil, errors.New("tenancy: underlying connection does not implement driver.QueryerContext")
	}

	useStmt := useStatement(tenantCtx.id)
	if len(args) == 0 {
		return queryer.QueryContext(ctx, useStmt+";"+query, args)
	}

	// NOTE(review): the separate round trip is presumably needed because a
	// parameterized query cannot carry a second statement — confirm.
	execer, ok := c.Conn.(driver.ExecerContext)
	if !ok {
		return nil, errors.New("tenancy: underlying connection does not implement driver.ExecerContext")
	}
	if _, err := execer.ExecContext(ctx, useStmt, nil); err != nil {
		return nil, err
	}
	return queryer.QueryContext(ctx, query, args)
}
// Exec implements driver.Execer.Exec. It never runs the statement and always
// returns an error; the tenant-aware entry point is ExecContext.
// NOTE(review): presumably unsupported because this signature has no
// context.Context to read the tenant id from — confirm.
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
return nil, errors.New("driver.Execer.Exec not supported")
}
// ExecContext implements driver.ExecerContext.ExecContext.
//
// The tenant id is read from ctx (see WithID). If ctx carries none, the error
// from fromContext is returned and nothing is sent to the database. Otherwise
// the tenant is selected with a USE statement before the statement runs:
//
//   - with no args, the USE statement is prepended to query, separated by
//     ";", and both are sent in a single call;
//   - with args, the USE statement is executed on its own first and query is
//     then sent unchanged together with args.
//
// The wrapped connection is stored as a plain driver.Conn, so support for the
// optional driver.ExecerContext interface is checked at run time. If it is
// missing, an error is returned instead of panicking on the type assertion.
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	tenantCtx, err := fromContext(ctx)
	if err != nil {
		return nil, err
	}

	execer, ok := c.Conn.(driver.ExecerContext)
	if !ok {
		return nil, errors.New("tenancy: underlying connection does not implement driver.ExecerContext")
	}

	useStmt := useStatement(tenantCtx.id)
	if len(args) > 0 {
		// NOTE(review): the separate round trip is presumably needed because
		// a parameterized statement cannot carry a second statement — confirm.
		if _, err := execer.ExecContext(ctx, useStmt, nil); err != nil {
			return nil, err
		}
	} else {
		query = useStmt + ";" + query
	}
	return execer.ExecContext(ctx, query, args)
}
// useStatement builds the "USE '<tenant>'" statement for tenantID. Every
// single quote in tenantID is doubled so the value cannot terminate the
// quoted literal early; all other bytes are copied through unchanged.
func useStatement(tenantID string) string {
	// Capacity covers the worst case, in which every byte is a quote.
	escaped := make([]byte, 0, 2*len(tenantID))
	for _, b := range []byte(tenantID) {
		if b == '\'' {
			escaped = append(escaped, '\'', '\'')
			continue
		}
		escaped = append(escaped, b)
	}
	return fmt.Sprintf("USE '%s'", string(escaped))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment