Skip to content

Instantly share code, notes, and snippets.

@andrewpillar
Created Oct 24, 2022
Embed
What would you like to do?
// The MIT License (MIT)
//
// Copyright (c) Andrew Pillar
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the “Software”), to
// deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
// sell copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package database
import (
"context"
"github.com/andrewpillar/query"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
// Model is implemented by the types a Store reads from and writes to its
// table.
type Model interface {
	// Params returns the column/value pairs written by Store.Create and
	// Store.Update.
	Params() map[string]any

	// Primary returns the name of the model's primary column and its current
	// value. Store uses the name in the RETURNING clause of Create and the
	// name/value pair in the WHERE clause of Update and Delete.
	Primary() (string, any)

	// Scan populates the model from the current result row. fields holds the
	// row's column names in order, and scan reads the row's values into the
	// destinations it is given.
	Scan(fields []string, scan ScanFunc) error
}
// ScanFunc reads the values of the current result row into dest, one
// destination per column. Store passes pgx's Rows.Scan as a ScanFunc.
type ScanFunc func(dest ...any) error

// Scan is a helper for implementing Model.Scan.
//
// desttab maps column names to the destinations (typically pointers to model
// fields) that the column values should be read into. fields lists the
// column names of the current row in order, as supplied to Model.Scan. Scan
// builds exactly one destination per field and calls scan with them,
// returning scan's error.
//
// A column with no entry in desttab gets a nil destination. This keeps the
// number of destinations equal to the number of columns, which pgx requires.
// pgx v4's Rows.Scan skips nil destinations, so such columns are discarded
// instead of failing the whole scan.
func Scan(desttab map[string]any, fields []string, scan ScanFunc) error {
	dest := make([]any, len(fields))

	for i, fld := range fields {
		// A missing key yields the zero value (nil), which acts as the
		// positional placeholder for an unmapped column.
		dest[i] = desttab[fld]
	}
	return scan(dest...)
}
// Store runs create, read, update and delete queries for models of type M
// against a single table. It embeds *pgxpool.Pool, so the pool's methods
// (such as Query and Exec) are available on the store directly.
type Store[M Model] struct {
	*pgxpool.Pool

	// table is the name of the table every query targets.
	table string

	// new constructs an empty model that a result row is scanned into.
	new func() M
}

// NewStore returns a Store that queries table through pool. The new function
// is called to obtain a fresh model for each result row that is scanned.
func NewStore[M Model](pool *pgxpool.Pool, table string, new func() M) *Store[M] {
	s := &Store[M]{Pool: pool}
	s.table = table
	s.new = new
	return s
}
// fields returns the column names of rows' result set, in column order, as
// strings suitable for passing to Model.Scan.
func (s *Store[M]) fields(rows pgx.Rows) []string {
	descs := rows.FieldDescriptions()

	names := make([]string, len(descs))
	for i := range descs {
		names[i] = string(descs[i].Name)
	}
	return names
}
// Create inserts m into the store's table using the column/value pairs from
// m.Params(). The value of m's primary column is requested via RETURNING and
// scanned back into m through m.Scan.
//
// If the insert returns no row, m is left untouched and the result set's
// error (nil if none) is returned.
func (s *Store[M]) Create(ctx context.Context, m M) error {
	p := m.Params()

	// cols and vals are filled in the same pass over p, so their positions
	// correspond even though map iteration order is unspecified.
	cols := make([]string, 0, len(p))
	vals := make([]any, 0, len(p))

	for k, v := range p {
		cols = append(cols, k)
		vals = append(vals, v)
	}

	primary, _ := m.Primary()

	q := query.Insert(
		s.table,
		query.Columns(cols...),
		query.Values(vals...),
		query.Returning(primary),
	)

	rows, err := s.Query(ctx, q.Build(), q.Args()...)
	if err != nil {
		return err
	}
	// Rows obtained from the pool hold a connection until they are closed.
	// Deferring Close covers every path, including a failed m.Scan. Close is
	// safe to call more than once.
	defer rows.Close()

	if !rows.Next() {
		rows.Close()
		return rows.Err()
	}

	if err := m.Scan(s.fields(rows), rows.Scan); err != nil {
		return err
	}

	// Close before checking Err so that errors reported while finishing the
	// result set are not lost.
	rows.Close()
	return rows.Err()
}
// Select queries the given cols from the store's table, applying opts after
// a leading query.From(s.table). It returns one model per result row; each
// model is created with the store's constructor and populated via Model.Scan.
//
// When no rows match, an empty non-nil slice is returned. If any row fails to
// scan, or the result set reports an error, nil and that error are returned.
func (s *Store[M]) Select(ctx context.Context, cols []string, opts ...query.Option) ([]M, error) {
	// Prepending into a new slice never writes to the caller's opts array.
	opts = append([]query.Option{
		query.From(s.table),
	}, opts...)

	q := query.Select(query.Columns(cols...), opts...)

	rows, err := s.Query(ctx, q.Build(), q.Args()...)
	if err != nil {
		return nil, err
	}
	// Rows hold a pooled connection until closed. Deferring Close covers the
	// early return on a Scan error. Close is safe to call more than once.
	defer rows.Close()

	fields := s.fields(rows)
	mm := make([]M, 0)

	for rows.Next() {
		m := s.new()

		if err := m.Scan(fields, rows.Scan); err != nil {
			return nil, err
		}
		mm = append(mm, m)
	}

	// Close before checking Err so that errors reported while finishing the
	// result set are observed.
	rows.Close()

	if err := rows.Err(); err != nil {
		return nil, err
	}
	return mm, nil
}
// Get selects all columns from the store's table, applying opts after a
// leading query.From(s.table), and returns the first result row as a model.
//
// The bool result reports whether a row was found. It is false with a nil
// error when the query matches nothing. Any query, scan or result-set error
// is returned with the zero M and false.
func (s *Store[M]) Get(ctx context.Context, opts ...query.Option) (M, bool, error) {
	var zero M

	opts = append([]query.Option{
		query.From(s.table),
	}, opts...)

	q := query.Select(query.Columns("*"), opts...)

	rows, err := s.Query(ctx, q.Build(), q.Args()...)
	if err != nil {
		// A failed query must be reported, not disguised as "not found".
		return zero, false, err
	}
	// Rows hold a pooled connection until closed. Deferring Close covers the
	// early return on a Scan error. Close is safe to call more than once.
	defer rows.Close()

	if !rows.Next() {
		rows.Close()
		return zero, false, rows.Err()
	}

	m := s.new()

	if err := m.Scan(s.fields(rows), rows.Scan); err != nil {
		return zero, false, err
	}

	// Only the first row is used. Closing here discards any remaining rows;
	// checking Err afterwards surfaces errors reported while finishing the
	// result set.
	rows.Close()

	if err := rows.Err(); err != nil {
		return zero, false, err
	}
	return m, true, nil
}
// All returns the models for every row of the store's table that matches
// opts, selecting all columns. It is shorthand for Select with a "*" column
// list.
func (s *Store[M]) All(ctx context.Context, opts ...query.Option) ([]M, error) {
	allCols := []string{"*"}
	return s.Select(ctx, allCols, opts...)
}
// Update writes every column/value pair from m.Params() to the row of the
// store's table whose primary column equals m's primary value. It returns
// the error from executing the statement, if any.
func (s *Store[M]) Update(ctx context.Context, m M) error {
	params := m.Params()

	opts := make([]query.Option, 0, len(params))
	for name, val := range params {
		opts = append(opts, query.Set(name, query.Arg(val)))
	}

	// The WHERE option follows all SET options and restricts the update to
	// this model's row.
	pkCol, pkVal := m.Primary()
	opts = append(opts, query.Where(pkCol, "=", query.Arg(pkVal)))

	q := query.Update(s.table, opts...)
	_, err := s.Exec(ctx, q.Build(), q.Args()...)
	return err
}
// Delete removes the row of the store's table whose primary column equals
// m's primary value. It returns the error from executing the statement, if
// any.
func (s *Store[M]) Delete(ctx context.Context, m M) error {
	pkCol, pkVal := m.Primary()
	byPK := query.Where(pkCol, "=", query.Arg(pkVal))

	q := query.Delete(s.table, byPK)
	_, err := s.Exec(ctx, q.Build(), q.Args()...)
	return err
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment