Skip to content

Instantly share code, notes, and snippets.

@zdunecki
Last active December 12, 2021 22:10
Show Gist options
  • Save zdunecki/a728e152a3b5bd5885444a350cde8612 to your computer and use it in GitHub Desktop.
Save zdunecki/a728e152a3b5bd5885444a350cde8612 to your computer and use it in GitHub Desktop.
Example persistence tests for Golang - Blog post
package main
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/google/go-cmp/cmp"
"github.com/ory/dockertest/v3"
"testing"
)
const (
	// RootUser is the MySQL account the tests connect as. Running as root keeps
	// the setup simple because no extra users or grants have to be created.
	// Disclaimer: root is used here only for testing simplicity.
	RootUser = "root"

	// DockerRepository is the Docker Hub repository the MySQL image is pulled from.
	DockerRepository = "mysql"

	// Version is the image tag of DockerRepository to run. The available tags
	// are listed on that repository's Docker Hub page.
	Version = "8.0"

	// AllowEmptyPassword configures the container so that RootUser can connect
	// without a password.
	// Disclaimer: never use this in production; it is acceptable for throwaway
	// test containers that have no specific security requirements.
	AllowEmptyPassword = "MYSQL_ALLOW_EMPTY_PASSWORD=yes"
)

// Environments are the environment variables passed to the MySQL container.
var Environments = []string{
	AllowEmptyPassword,
}
// pool is the dockertest handle used to talk to the Docker daemon. It is set
// once by init and shared by the helpers that start containers.
var pool *dockertest.Pool

// init builds the shared Docker pool. It passes an empty endpoint to
// dockertest.NewPool (presumably letting dockertest choose its default Docker
// endpoint — confirm against the dockertest docs) and panics if the pool
// cannot be created, because none of the tests can run without it.
func init() {
	var err error
	if pool, err = dockertest.NewPool(""); err != nil {
		panic(err)
	}
}
// mysqlContainer starts a DockerRepository:Version container configured with
// Environments. It then waits, via pool.Retry, until MySQL answers a ping as
// RootUser on the host port mapped to the container's 3306/tcp.
//
// On success it returns an open *sql.DB connected to the "mysql" schema and
// the container resource; the caller owns both and must close them.
// On failure it returns a non-nil error and nil resources. A container that was
// started but never became reachable is removed before returning, so nothing
// is left running.
func mysqlContainer() (*sql.DB, *dockertest.Resource, error) {
	container, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: DockerRepository,
		Tag:        Version,
		Env:        Environments,
	})
	if err != nil {
		return nil, nil, fmt.Errorf("starting %s:%s container: %w", DockerRepository, Version, err)
	}

	// Connect through the host port Docker mapped to the container's 3306/tcp.
	dsn := fmt.Sprintf("%s@tcp(localhost:%s)/mysql", RootUser, container.GetPort("3306/tcp"))

	var db *sql.DB
	retryErr := pool.Retry(func() error {
		conn, err := sql.Open("mysql", dsn)
		if err != nil {
			return err
		}
		if err := conn.Ping(); err != nil {
			// MySQL is not ready yet; release this handle so each retry
			// attempt does not leak a connection pool.
			_ = conn.Close() // best effort: the Ping error is what matters
			return err
		}
		db = conn
		return nil
	})
	if retryErr != nil {
		// Don't leave an orphaned container behind when MySQL never came up.
		if closeErr := container.Close(); closeErr != nil {
			return nil, nil, fmt.Errorf("waiting for mysql: %w (removing container also failed: %v)", retryErr, closeErr)
		}
		return nil, nil, fmt.Errorf("waiting for mysql: %w", retryErr)
	}
	return db, container, nil
}
// getMessages returns the message column of every row in the demo table, in
// the order the database returns them. An empty table yields a non-nil empty
// slice.
//
// It returns an error if the query fails or if any row cannot be scanned into
// a string (for example a NULL value). It also returns an error if row
// iteration stops with an error. On error the returned slice is nil; partial
// results are not reported as success.
func getMessages(db *sql.DB) ([]string, error) {
	rows, err := db.Query("SELECT message FROM demo")
	if err != nil {
		return nil, fmt.Errorf("querying messages: %w", err)
	}
	defer rows.Close()

	// Non-nil so callers comparing against []string{} see an empty slice, not nil.
	messages := make([]string, 0)
	for rows.Next() {
		var msg string
		if err := rows.Scan(&msg); err != nil {
			return nil, fmt.Errorf("scanning message: %w", err)
		}
		messages = append(messages, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating messages: %w", err)
	}
	return messages, nil
}
// seed prepares the database before queries. It creates the demo table (an
// AUTO_INCREMENT id primary key plus a NOT NULL VARCHAR(100) message column)
// and inserts one row holding message.
//
// Both statements run through db.Exec, so no *sql.Rows or *sql.Stmt is left
// open and the connection goes back to the pool. seed is not idempotent: the
// CREATE TABLE has no IF NOT EXISTS, so it errors if demo already exists.
func seed(db *sql.DB, message string) error {
	if _, err := db.Exec(`
create table demo(
id INT NOT NULL AUTO_INCREMENT,
message VARCHAR(100) NOT NULL,
PRIMARY KEY ( id )
);
`); err != nil {
		return fmt.Errorf("creating demo table: %w", err)
	}
	// The placeholder keeps message out of the SQL text.
	if _, err := db.Exec(`
insert into demo(message) values (?);
`, message); err != nil {
		return fmt.Errorf("inserting message: %w", err)
	}
	return nil
}
// beforeEach starts an ephemeral MySQL container for a single test run. It
// returns a connected *sql.DB together with done, which closes the DB handle
// and removes the container; callers should defer done() right after the call.
//
// tx must be the *testing.T of the test (or subtest) that is currently running:
// on setup failure beforeEach calls tx.Fatal, and FailNow must be called from
// that test's goroutine. Because Fatal stops the calling goroutine, the
// returned err is nil whenever beforeEach actually returns; it is kept for
// signature compatibility.
func beforeEach(tx *testing.T) (db *sql.DB, done func(), err error) {
	tx.Helper()

	var container *dockertest.Resource
	db, container, err = mysqlContainer()
	if err != nil {
		// mysqlContainer may hand back a started container even on failure;
		// remove it so it does not outlive the test.
		if container != nil {
			_ = container.Close() // best effort: the setup error is reported below
		}
		tx.Fatal(err)
		return nil, nil, err // not reached: Fatal stops this goroutine
	}

	done = func() {
		// Error (not Fatal) so the container is still removed when closing
		// the DB handle fails.
		if closeErr := db.Close(); closeErr != nil {
			tx.Error(closeErr)
		}
		if closeErr := container.Close(); closeErr != nil {
			tx.Error(closeErr)
		}
	}
	return db, done, nil
}
// diff compares expected with current using cmp.Diff. It returns nil when they
// match; otherwise it returns an error whose message carries the
// (-want +got) diff.
func diff(expected interface{}, current interface{}) error {
	delta := cmp.Diff(expected, current)
	if delta == "" {
		return nil
	}
	return fmt.Errorf("%s mismatch (-want +got):\n%s", "messages are not equal", delta)
}
// TestMySQLInsert seeds a fresh MySQL container with one message and checks
// that getMessages reads back exactly that message.
func TestMySQLInsert(tx *testing.T) {
	tx.Run("Insert and query messages from database", func(t *testing.T) {
		// Use the subtest's t, not the parent tx: Fatal/FailNow must be called
		// from the goroutine running the test they belong to.
		db, done, err := beforeEach(t)
		if done != nil {
			defer done()
		}
		if err != nil {
			t.Fatal(err)
		}

		if err := seed(db, "hello world"); err != nil {
			t.Fatal(err)
		}

		messages, err := getMessages(db)
		if err != nil {
			// Comparing results after a failed read would be meaningless.
			t.Fatal(err)
		}

		expected := []string{"hello world"}
		if err := diff(expected, messages); err != nil {
			t.Fatal(err)
		}
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment