Skip to content

Instantly share code, notes, and snippets.

@akarasz
Last active March 20, 2021 15:58
Show Gist options
  • Save akarasz/9b6ced58256ae35e00ada18b1fb7fed1 to your computer and use it in GitHub Desktop.
Save akarasz/9b6ced58256ae35e00ada18b1fb7fed1 to your computer and use it in GitHub Desktop.
PostgreSQL: Unique Constraint on Joined Tables
-- Base table for every vehicle. The car and truck tables each carry a
-- vehicle_id that references vehicle.id.
create table vehicle (
    id    integer primary key,
    color text
);
-- A car is a vehicle: each row extends exactly one vehicle row.
-- vehicle_id is the primary key so that one vehicle cannot be recorded as a
-- car twice. The color-uniqueness trigger ignores rows that point at the
-- vehicle being checked, so duplicates would otherwise slip through it. The
-- key's index also serves the trigger's join on car.vehicle_id.
create table car (
    vehicle_id int primary key references vehicle (id)
);
-- A truck is a vehicle: each row extends exactly one vehicle row.
-- vehicle_id is the primary key so that one vehicle cannot be recorded as a
-- truck twice. The color-uniqueness trigger ignores rows that point at the
-- vehicle being checked, so duplicates would otherwise slip through it. The
-- key's index also serves the trigger's join on truck.vehicle_id.
create table truck (
    vehicle_id    int primary key references vehicle (id),
    trailer_color text
);
-- Trigger function for car: rejects the new/updated row if another vehicle
-- with the same color already has a car row.
-- NULL colors never conflict (NULL = NULL is not true), matching the semantics
-- of an ordinary UNIQUE constraint. A vehicle_id that does not exist passes
-- here and is left to the foreign key to reject.
create or replace function unique_car_check() returns trigger as $$
declare
    new_color text;
begin
    select v.color into new_color from vehicle as v where v.id = NEW.vehicle_id;
    if new_color is null then
        return NEW;
    end if;

    -- Serialize checks for the same color. Without this, two concurrent
    -- transactions each inserting a red car would both see no other red car
    -- and both commit. The lock lasts until the transaction ends. A hash
    -- collision only causes extra waiting, never a wrong answer.
    -- NOTE: this closes the race under READ COMMITTED, where the query below
    -- takes a fresh snapshot after the lock is granted. Higher isolation levels
    -- reuse the transaction's snapshot, so the lock alone is not enough there.
    perform pg_advisory_xact_lock(hashtext('car:' || new_color));

    if exists (
        select 1
        from vehicle as other
        join car on car.vehicle_id = other.id
        where other.color = new_color
          and other.id <> NEW.vehicle_id
    ) then
        raise exception 'already has a car with color';
    end if;
    return NEW;
end;
$$ language plpgsql;
-- Runs unique_car_check before each inserted or updated car row, so a
-- conflicting row is rejected before it is written.
-- EXECUTE FUNCTION (PostgreSQL 11+) replaces the deprecated EXECUTE PROCEDURE
-- spelling; the two are equivalent.
create trigger car_unique_vehicle
    before insert or update on car
    for each row
    execute function unique_car_check();
-- Trigger function for truck: rejects the new/updated row if another vehicle
-- with the same color already has a truck row.
-- NULL colors never conflict (NULL = NULL is not true), matching the semantics
-- of an ordinary UNIQUE constraint. A vehicle_id that does not exist passes
-- here and is left to the foreign key to reject.
create or replace function unique_truck_check() returns trigger as $$
declare
    new_color text;
begin
    select v.color into new_color from vehicle as v where v.id = NEW.vehicle_id;
    if new_color is null then
        return NEW;
    end if;

    -- Serialize checks for the same color. Without this, two concurrent
    -- transactions each inserting a red truck would both see no other red
    -- truck and both commit. The lock lasts until the transaction ends. A hash
    -- collision only causes extra waiting, never a wrong answer.
    -- NOTE: this closes the race under READ COMMITTED, where the query below
    -- takes a fresh snapshot after the lock is granted. Higher isolation levels
    -- reuse the transaction's snapshot, so the lock alone is not enough there.
    perform pg_advisory_xact_lock(hashtext('truck:' || new_color));

    if exists (
        select 1
        from vehicle as other
        join truck on truck.vehicle_id = other.id
        where other.color = new_color
          and other.id <> NEW.vehicle_id
    ) then
        raise exception 'already has a truck with color';
    end if;
    return NEW;
end;
$$ language plpgsql;
-- Runs unique_truck_check before each inserted or updated truck row, so a
-- conflicting row is rejected before it is written.
-- EXECUTE FUNCTION (PostgreSQL 11+) replaces the deprecated EXECUTE PROCEDURE
-- spelling; the two are equivalent.
create trigger truck_unique_vehicle
    before insert or update on truck
    for each row
    execute function unique_truck_check();
package unique_test
import (
	"context"
	"fmt"
	"io/ioutil"
	"net"
	"strings"
	"testing"

	"github.com/jackc/pgx/v4"
	"github.com/testcontainers/testcontainers-go"
	"github.com/testcontainers/testcontainers-go/wait"
)
// TestUnique checks the color-uniqueness triggers from schema.sql against a
// real PostgreSQL container. A car and a truck may share a color, but a second
// car (or a second truck) with an already-used color must be rejected.
func TestUnique(t *testing.T) {
	ctx := context.Background()

	container, host, port := startTestContainer(ctx)
	defer container.Terminate(ctx)

	conn := createPgConnection(ctx, host, port)
	// Deferred after Terminate, so it runs first: the connection is closed
	// while the server is still up.
	defer conn.Close(ctx)

	applySchema(ctx, conn, "schema.sql")

	if err := insertCar(ctx, conn, 1, "red"); err != nil {
		t.Fatal("inserting the first red car should not fail:", err)
	}
	if err := insertTruck(ctx, conn, 2, "red"); err != nil {
		t.Fatal("inserting the first red truck when a red car is present should not fail:", err)
	}

	// The duplicates must fail, and specifically with the triggers' messages.
	// An unrelated error (connection loss, bad SQL) must not count as a pass.
	err := insertCar(ctx, conn, 3, "red")
	switch {
	case err == nil:
		t.Error("inserting the second red car should fail")
	case !strings.Contains(err.Error(), "already has a car with color"):
		t.Error("inserting the second red car failed for an unexpected reason:", err)
	}

	err = insertTruck(ctx, conn, 4, "red")
	switch {
	case err == nil:
		t.Error("inserting the second red truck should fail")
	case !strings.Contains(err.Error(), "already has a truck with color"):
		t.Error("inserting the second red truck failed for an unexpected reason:", err)
	}
}
// insertCar inserts a vehicle with the given id and color, plus the car row
// that references it, in one transaction. If the car insert is rejected (for
// example by the color-uniqueness trigger), the vehicle insert is rolled back
// too. Returned errors are wrapped with the failing step and the id.
func insertCar(ctx context.Context, conn *pgx.Conn, id int, color string) error {
	return inTransaction(ctx, conn, func(tx pgx.Tx) error {
		if _, err := tx.Exec(ctx, "insert into vehicle(id, color) values ($1, $2)", id, color); err != nil {
			return fmt.Errorf("inserting vehicle %d: %w", id, err)
		}
		if _, err := tx.Exec(ctx, "insert into car(vehicle_id) values ($1)", id); err != nil {
			return fmt.Errorf("inserting car %d: %w", id, err)
		}
		return nil
	})
}
// insertTruck inserts a vehicle with the given id and color, plus the truck
// row that references it, in one transaction. If the truck insert is rejected
// (for example by the color-uniqueness trigger), the vehicle insert is rolled
// back too. Returned errors are wrapped with the failing step and the id.
func insertTruck(ctx context.Context, conn *pgx.Conn, id int, color string) error {
	return inTransaction(ctx, conn, func(tx pgx.Tx) error {
		if _, err := tx.Exec(ctx, "insert into vehicle(id, color) values ($1, $2)", id, color); err != nil {
			return fmt.Errorf("inserting vehicle %d: %w", id, err)
		}
		if _, err := tx.Exec(ctx, "insert into truck(vehicle_id) values ($1)", id); err != nil {
			return fmt.Errorf("inserting truck %d: %w", id, err)
		}
		return nil
	})
}
// inTransaction runs commands inside a new transaction on conn and commits it
// if commands returns nil. If commands returns an error (or panics), the
// transaction is rolled back and the error from commands is returned unchanged.
// Failures to begin or commit are wrapped with that step.
func inTransaction(ctx context.Context, conn *pgx.Conn, commands func(pgx.Tx) error) error {
	tx, err := conn.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Covers the error and panic paths. After a successful Commit the rollback
	// cannot undo anything, so its error is intentionally ignored.
	defer tx.Rollback(ctx)

	if err := commands(tx); err != nil {
		return err
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}
	return nil
}
// startTestContainer launches a disposable postgres:12 container and returns
// it along with the host and the host-side port mapped to the container's
// 5432/tcp. The caller must Terminate the container when done.
// It panics on any failure. If the container was created before the failure,
// it is terminated first so a broken setup does not leave it running.
func startTestContainer(ctx context.Context) (container testcontainers.Container, host, port string) {
	const pgPort = "5432/tcp"

	fail := func(err error) {
		if container != nil {
			_ = container.Terminate(ctx) // best effort; err is the useful diagnostic
		}
		panic(err)
	}

	container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
		ContainerRequest: testcontainers.ContainerRequest{
			Image:        "postgres:12",
			ExposedPorts: []string{pgPort},
			Env:          map[string]string{"POSTGRES_PASSWORD": "changeit"},
			WaitingFor:   wait.ForListeningPort(pgPort),
		},
		Started: true,
	})
	if err != nil {
		fail(err)
	}

	host, err = container.Host(ctx)
	if err != nil {
		fail(err)
	}

	mapped, err := container.MappedPort(ctx, pgPort)
	if err != nil {
		fail(err)
	}
	return container, host, mapped.Port()
}
// createPgConnection connects to the "postgres" database at host:port as user
// postgres, with password "changeit" to match the POSTGRES_PASSWORD given to
// the test container. It prints the connection string and panics if the
// connection cannot be established.
func createPgConnection(ctx context.Context, host, port string) *pgx.Conn {
	// net.JoinHostPort brackets IPv6 literals ("[::1]:5432"). Plain
	// "host:port" concatenation would yield an unparseable URL for them.
	connString := fmt.Sprintf("postgresql://postgres:changeit@%s/postgres", net.JoinHostPort(host, port))
	fmt.Println("Connecting to", connString)
	conn, err := pgx.Connect(ctx, connString)
	if err != nil {
		panic(err)
	}
	return conn
}
// applySchema reads the SQL script at path file and executes its full text on
// conn in one Exec call. It panics if the file cannot be read or the script
// fails.
func applySchema(ctx context.Context, conn *pgx.Conn, file string) {
	script, readErr := ioutil.ReadFile(file)
	if readErr != nil {
		panic(readErr)
	}

	if _, execErr := conn.Exec(ctx, string(script)); execErr != nil {
		panic(execErr)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment