Skip to content

Instantly share code, notes, and snippets.

@kmpm
Last active July 27, 2022 05:31
Show Gist options
  • Save kmpm/22450dfd53a2adeb11a2deccca1e7d06 to your computer and use it in GitHub Desktop.
Save kmpm/22450dfd53a2adeb11a2deccca1e7d06 to your computer and use it in GitHub Desktop.
bun: check has-many relations with composite keys
package main
import (
"context"
"database/sql"
"fmt"
"os"
_ "github.com/denisenkom/go-mssqldb"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/mssqldialect"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/driver/sqliteshim"
"github.com/uptrace/bun/extra/bundebug"
)
// Company is the top-level model of this repro. No is its primary key
// (column "no"). Departments is a has-many relation joined on the single
// column pair no = company_no, i.e. Company.No -> Department.CompanyNo.
type Company struct {
bun.BaseModel
// No is the primary key; the join tag on Departments refers to it as "no".
No string `bun:",pk"`
Name string
// Departments is loaded via Relation("Departments") in main.
Departments []*Department `bun:"rel:has-many,join:no=company_no"`
}
type Department struct {
CompanyNo string `bun:"company_no,pk"`
No string `bun:"no,pk"`
Name string
Contacts []*Contact `bun:"rel:has-many,join:company_no=company_no,join:no=department_no"`
}
// Contact is the child model loaded through Department.Contacts. Its primary
// key is the composite (company_no, department_no, name); company_no and
// department_no are the columns the Department.Contacts join matches on.
type Contact struct {
bun.BaseModel
CompanyNo string `bun:"company_no,pk"`
DepartmentNo string `bun:"department_no,pk"`
Name string `bun:"name,pk"`
}
// getSqlite opens a shared-cache in-memory SQLite database through the
// sqliteshim driver and wraps it in a bun.DB using the SQLite dialect.
// It panics if sql.Open reports an error.
func getSqlite() *bun.DB {
	const dsn = "file::memory:?cache=shared"

	conn, openErr := sql.Open(sqliteshim.ShimName, dsn)
	if openErr != nil {
		panic(openErr)
	}
	return bun.NewDB(conn, sqlitedialect.New())
}
// getMssql opens a SQL Server database and wraps it in a bun.DB using the
// MSSQL dialect.
//
// The connection string is read from the MSSQL_DSN environment variable when
// it is set. Otherwise a local development default is used. The function
// panics if the database cannot be opened or does not answer a ping.
func getMssql() *bun.DB {
	dsn := os.Getenv("MSSQL_DSN")
	if dsn == "" {
		// Local development default; prefer overriding via MSSQL_DSN over
		// editing credentials in source.
		dsn = "sqlserver://sa:passWORD1@localhost?database=test"
	}
	sqldb, err := sql.Open("sqlserver", dsn)
	if err != nil {
		panic(err)
	}
	// sql.Open only validates its arguments and does not connect. Ping here
	// so an unreachable server is reported immediately, not as a failure
	// inside the first query.
	if err := sqldb.Ping(); err != nil {
		panic(fmt.Errorf("connecting to mssql: %w", err))
	}
	return bun.NewDB(sqldb, mssqldialect.New())
}
// main reproduces loading bun has-many relations keyed on a single column
// and on a composite key.
//
// The first command-line argument selects the dialect: "mssql" or "sqlite";
// anything else prints an error and exits with status 1. Any failed query or
// check panics. On success the program prints "done".
func main() {
	var db *bun.DB
	var dn string
	if len(os.Args) > 1 {
		dn = os.Args[1]
	}
	switch dn {
	case "mssql":
		db = getMssql()
	case "sqlite":
		db = getSqlite()
	default:
		fmt.Printf("argument '%s' is not a valid dialect name\n", dn)
		os.Exit(1)
	}
	// db is non-nil here because the default case exits. The deferred Close
	// also runs while a panic unwinds.
	defer db.Close()

	db.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
	ctx := context.Background()
	migrate(ctx, db)
	if err := data(ctx, db); err != nil {
		panic(err)
	}

	// Departments: has-many joined on a single column
	// (Company.No -> Department.CompanyNo).
	c := new(Company)
	err := db.NewSelect().Model(c).Relation("Departments").Where("no=?", "ABC").Scan(ctx)
	if err != nil {
		panic(err)
	}
	if c.Name != "Test Company" {
		panic("wrong company name")
	}
	if len(c.Departments) != 2 {
		panic(fmt.Sprintf("incorrect number of departments. got %d, want 2", len(c.Departments)))
	}

	// Contacts: has-many joined on a composite key
	// (Department.CompanyNo, Department.No) -> (Contact.CompanyNo, Contact.DepartmentNo).
	// Check the Scan error so a failing query is reported as itself rather
	// than as a wrong contact count.
	dep := new(Department)
	err = db.NewSelect().Model(dep).
		Relation("Contacts").
		Where("company_no=? AND no=?", c.No, "AA").
		Scan(ctx)
	if err != nil {
		panic(err)
	}
	if len(dep.Contacts) != 2 {
		panic(fmt.Sprintf("incorrect number of contacts. got %d, want 2", len(dep.Contacts)))
	}
	fmt.Println("done")
}
// migrate creates the tables for Company, Department and Contact, in that
// order.
//
// It is best-effort: a failed CREATE TABLE (for example because the table
// already exists on a reused database) is reported on stderr and does not
// stop the remaining tables from being created. "migrated" is printed
// afterwards.
func migrate(ctx context.Context, db *bun.DB) {
	models := []interface{}{
		(*Company)(nil),
		(*Department)(nil),
		(*Contact)(nil),
	}
	for _, model := range models {
		if _, err := db.NewCreateTable().Model(model).Exec(ctx); err != nil {
			fmt.Fprintf(os.Stderr, "create table for %T: %v\n", model, err)
		}
	}
	fmt.Println("migrated")
}
// data seeds the fixture rows:
//   - company "ABC" ("Test Company");
//   - department "AA" (Accounting) with contacts alfa and bravo;
//   - department "AB" (HR) with contacts charlie and delta.
//
// Errors from inserting the company or a department are returned unchanged.
// Contact inserts go through addConcact, which panics on failure. "data" is
// printed once everything has been inserted.
func data(ctx context.Context, db *bun.DB) error {
	company := &Company{No: "ABC", Name: "Test Company"}
	if _, err := db.NewInsert().Model(company).Exec(ctx); err != nil {
		return err
	}

	seeds := []struct {
		no       string
		name     string
		contacts []string
	}{
		{no: "AA", name: "Accounting", contacts: []string{"alfa", "bravo"}},
		{no: "AB", name: "HR", contacts: []string{"charlie", "delta"}},
	}
	for _, seed := range seeds {
		dept := &Department{CompanyNo: company.No, No: seed.no, Name: seed.name}
		if _, err := db.NewInsert().Model(dept).Exec(ctx); err != nil {
			return err
		}
		for _, contactName := range seed.contacts {
			addConcact(ctx, db, dept, contactName)
		}
	}

	fmt.Println("data")
	return nil
}
// addConcact inserts a Contact named name under department d. Its composite
// key fields are copied from d: CompanyNo from d.CompanyNo and DepartmentNo
// from d.No. It returns the inserted value and panics if the insert fails.
func addConcact(ctx context.Context, db *bun.DB, d *Department, name string) *Contact {
	contact := &Contact{
		CompanyNo:    d.CompanyNo,
		DepartmentNo: d.No,
		Name:         name,
	}
	if _, insertErr := db.NewInsert().Model(contact).Exec(ctx); insertErr != nil {
		panic(insertErr)
	}
	return contact
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment