Skip to content

Instantly share code, notes, and snippets.

@timakin
Created September 28, 2019 05:06
Show Gist options
  • Save timakin/641d11e0765a00a45dadeee7ce2036df to your computer and use it in GitHub Desktop.
Save timakin/641d11e0765a00a45dadeee7ce2036df to your computer and use it in GitHub Desktop.
Test utilities for end-to-end (e2e) tests of a Go API
package testutils
import (
"encoding/json"
"fmt"
"io/ioutil"
"path"
"path/filepath"
"runtime"
"testing"
)
// GetInputJSON loads the JSON input fixture for the running test and decodes
// it into dst.
//
// The fixture is read from testdata/<t.Name()>.input, where testdata sits two
// directories above the directory containing this source file. A subtest's
// name contains '/', so its fixture lives in a sub-directory of testdata.
// A missing, unreadable or malformed fixture stops the test immediately.
func GetInputJSON(t *testing.T, dst interface{}) {
	t.Helper()
	// runtime.Caller(0) yields this source file's path, so the fixture location
	// does not depend on the working directory `go test` runs in.
	_, thisFile, _, _ := runtime.Caller(0)
	input := filepath.Join(path.Dir(thisFile), "..", "..", "testdata", fmt.Sprintf("%s.input", t.Name()))
	i, err := ioutil.ReadFile(input)
	if err != nil {
		t.Fatalf("failed reading .input: %s", err)
	}
	// Fatalf rather than Errorf: continuing with a zero-valued or partially
	// decoded dst only produces misleading follow-on failures.
	if err := json.Unmarshal(i, dst); err != nil {
		t.Fatalf("failed to parse input file: %s, err: %v", input, err)
	}
}
// GetGoldenJSON loads the golden (expected-result) JSON for the running test
// and decodes it into dst.
//
// The file is read from testdata/<t.Name()>.golden, where testdata sits two
// directories above the directory containing this source file. A missing,
// unreadable or malformed golden file stops the test immediately.
func GetGoldenJSON(t *testing.T, dst interface{}) {
	t.Helper()
	// Resolve relative to this source file, not the process working directory.
	_, thisFile, _, _ := runtime.Caller(0)
	golden := filepath.Join(path.Dir(thisFile), "..", "..", "testdata", fmt.Sprintf("%s.golden", t.Name()))
	g, err := ioutil.ReadFile(golden)
	if err != nil {
		t.Fatalf("failed reading .golden: %s", err)
	}
	// Fatalf rather than Errorf: comparing against a zero-valued or partially
	// decoded expectation would only produce misleading failures.
	if err := json.Unmarshal(g, dst); err != nil {
		t.Fatalf("failed to parse golden file: %s, err: %v", golden, err)
	}
}
package dbsetup
import (
"io/ioutil"
"path/filepath"
)
// TestDatabaseHandler describes a test database that a test can prepare and
// reset without depending on a concrete store.
//
// NOTE(review): no implementation is visible here; the per-method notes
// below are read from the method names — confirm against the implementations.
type TestDatabaseHandler interface {
	// SetupDefaultFixture presumably loads the fixtures shared by all tests.
	SetupDefaultFixture()

	// SetupOptionalFixture presumably loads the named, test-specific fixtures.
	SetupOptionalFixture(names []string)

	// Cleanup presumably restores the database to a clean state.
	Cleanup()
}
package testutils
import (
"encoding/json"
"fmt"
"io/ioutil"
"path"
"path/filepath"
"runtime"
"testing"
"github.com/golang/mock/gomock"
)
// SetupMockCtrl creates a gomock controller bound to t for running interface
// mocks. It returns the controller together with a teardown function that
// calls its Finish method.
func SetupMockCtrl(t *testing.T) (*gomock.Controller, func()) {
	mockCtrl := gomock.NewController(t)
	finish := mockCtrl.Finish
	return mockCtrl, finish
}
// GetMockJSON loads the JSON that a mock should return for the running test
// and decodes it into dst.
//
// The file is read from testdata/<t.Name()>.mock, where testdata sits two
// directories above the directory containing this source file. A missing,
// unreadable or malformed mock file stops the test immediately.
func GetMockJSON(t *testing.T, dst interface{}) {
	t.Helper()
	// Resolve relative to this source file, not the process working directory.
	_, thisFile, _, _ := runtime.Caller(0)
	mockFile := filepath.Join(path.Dir(thisFile), "..", "..", "testdata", fmt.Sprintf("%s.mock", t.Name()))
	raw, err := ioutil.ReadFile(mockFile)
	if err != nil {
		t.Fatalf("failed reading .mock: %s", err)
	}
	// Fatalf rather than Errorf: a mock returning a zero-valued or partially
	// decoded payload would only produce misleading follow-on failures.
	if err := json.Unmarshal(raw, dst); err != nil {
		t.Fatalf("failed to parse mock file: %s, err: %v", mockFile, err)
	}
}
package testutils
import (
"time"
"github.com/xxx/internal/library"
)
// MockTimeNow pins the fake clock, via MockTime, to midnight on 1 January
// 2020 and returns the function that undoes it.
//
// NOTE(review): the instant is built in time.Local, so the pinned moment
// depends on the time zone of the machine running the tests — confirm that
// this is intended.
func MockTimeNow() func() {
	pinned := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.Local)
	return MockTime(pinned)
}
// MockTime hands mt to library.SetFakeTime and returns a function that calls
// library.ResetFake to undo it.
func MockTime(mt time.Time) func() {
	library.SetFakeTime(mt)
	restore := func() {
		library.ResetFake()
	}
	return restore
}
package testutils
import (
	"context"
	"database/sql"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"strings"
	"time"

	"github.com/go-sql-driver/mysql"
	_ "github.com/go-sql-driver/mysql"
)
// dbConn is the shared test MySQL handle; SetupMySQLConn assigns it and the
// fixture and cleanup helpers in this file use it.
var dbConn *sql.DB

// Printf-style path templates. The first verb receives the directory of
// this source file, the second a file or sub-directory name.
const (
	schemaDirRelativePathFormat   = "%s/../../schema/%s"
	fixturesDirRelativePathFormat = "%s/../../schema/fixtures/%s"
)
// SetupMySQLConn opens the shared test MySQL handle from the TEST_MYSQL_*
// environment variables and stores it in dbConn for the other helpers. It
// returns a teardown function that closes the handle. The process exits via
// log.Fatalf if the handle cannot be opened or the server is unreachable.
func SetupMySQLConn() func() {
	cfg := mysql.Config{
		DBName:               os.Getenv("TEST_MYSQL_DATABASE"),
		User:                 os.Getenv("TEST_MYSQL_USER"),
		Passwd:               os.Getenv("TEST_MYSQL_PASSWORD"),
		Addr:                 os.Getenv("TEST_MYSQL_ADDRESS"),
		Net:                  "tcp",
		Loc:                  time.UTC,
		ParseTime:            true,
		AllowNativePasswords: true,
	}
	db, err := sql.Open("mysql", cfg.FormatDSN())
	if err != nil {
		log.Fatalf("Could not connect to mysql: %s", err)
	}
	// Keep no idle connections in the pool.
	// NOTE(review): presumably so every statement gets a fresh session —
	// confirm intent.
	db.SetMaxIdleConns(0)
	// sql.Open only validates its arguments and never dials. Ping here so
	// that a misconfigured or unreachable server fails immediately with a
	// clear message. Otherwise it would first show up as schema errors that
	// are merely logged.
	if err := db.Ping(); err != nil {
		log.Fatalf("Could not connect to mysql: %s", err)
	}
	dbConn = db
	// Close the handle opened by this call, not whatever dbConn happens to
	// point at when the teardown finally runs.
	return func() { db.Close() }
}
// GetTestMySQLConn returns the shared test MySQL handle together with a
// cleanup function that truncates every table. It panics if SetupMySQLConn
// has not stored a handle yet.
func GetTestMySQLConn() (*sql.DB, func()) {
	if dbConn != nil {
		return dbConn, truncateTables
	}
	panic("mysql connection is not initialized yet")
}
// setupDefaultFixtures inserts the fixtures shared by every test. It executes
// each file that walkSchema lists in schema/fixtures/default (resolved
// relative to this source file) against the test database.
func setupDefaultFixtures() {
	_, thisFile, _, _ := runtime.Caller(0)
	dir := fmt.Sprintf(fixturesDirRelativePathFormat, path.Dir(thisFile), "default")
	for _, fixturePath := range walkSchema(dir) {
		execSchema(fixturePath)
	}
}
// SetupOptionalFixtures inserts test-case-specific fixtures. For each name,
// in the order given, it executes schema/fixtures/optional/<name>.sql
// (resolved relative to this source file).
func SetupOptionalFixtures(names []string) {
	_, thisFile, _, _ := runtime.Caller(0)
	dir := fmt.Sprintf(fixturesDirRelativePathFormat, path.Dir(thisFile), "optional")
	for i := range names {
		execSchema(filepath.Join(dir, names[i]+".sql"))
	}
}
// walkSchema returns the paths of the non-directory entries directly inside
// dir. ioutil.ReadDir sorts entries by file name, so the order is
// deterministic. Sub-directories are skipped because callers pass every
// returned path to a file reader, which cannot read a directory. It panics if
// dir cannot be read.
func walkSchema(dir string) []string {
	entries, err := ioutil.ReadDir(dir)
	if err != nil {
		panic(err)
	}
	paths := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		paths = append(paths, filepath.Join(dir, entry.Name()))
	}
	return paths
}
// execSchema runs every ';'-separated statement in the SQL file at fpath
// against dbConn.
//
// If the file cannot be read, the process exits via log.Fatalf. A statement
// that fails is logged and the remaining statements still run (best-effort).
func execSchema(fpath string) {
	b, err := ioutil.ReadFile(fpath)
	if err != nil {
		log.Fatalf("schema reading error: %v", err)
	}
	// NOTE: a naive split — a ';' inside a string literal or comment will cut
	// a statement in two.
	for _, query := range strings.Split(string(b), ";") {
		// Skip the blank tail after the final ';' and any empty segments.
		// A last statement without a trailing ';' is still executed; dropping
		// the final segment unconditionally would silently lose it.
		if strings.TrimSpace(query) == "" {
			continue
		}
		if _, err := dbConn.Exec(query); err != nil {
			log.Printf("exec schema error: %v, query: %s", err, query)
		}
	}
}
func createTablesIfNotExist() {
_, pwd, _, _ := runtime.Caller(0)
schemaPath := fmt.Sprintf(schemaDirRelativePathFormat, path.Dir(pwd), "schema.sql")
execSchema(schemaPath)
}
// truncateTables empties every table in the test database so the next test
// starts from a clean state. Any failure exits the process via log.Fatalf.
func truncateTables() {
	ctx := context.Background()

	// FOREIGN_KEY_CHECKS is a session variable: it only affects the
	// connection it is issued on. The SET statements and the TRUNCATEs must
	// therefore share one physical connection. Going through the pool could
	// run each statement on a different connection, and TRUNCATE would still
	// fail on tables referenced by a foreign key.
	conn, err := dbConn.Conn(ctx)
	if err != nil {
		log.Fatalf("acquire connection error: %#v", err)
	}
	defer conn.Close()

	// Read the whole SHOW TABLES result first: the pinned connection cannot
	// run another statement while a result set is still open on it.
	rows, err := conn.QueryContext(ctx, "SHOW TABLES")
	if err != nil {
		log.Fatalf("show tables error: %#v", err)
	}
	var tables []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			log.Fatalf("show table error: %#v", err)
		}
		tables = append(tables, tableName)
	}
	if err := rows.Err(); err != nil {
		log.Fatalf("show tables error: %#v", err)
	}
	rows.Close()

	if _, err := conn.ExecContext(ctx, "SET FOREIGN_KEY_CHECKS = 0"); err != nil {
		log.Fatalf("truncate error: %#v", err)
	}
	for _, table := range tables {
		// Quote the identifier so table names that are reserved words work.
		stmt := fmt.Sprintf("TRUNCATE `%s`", strings.ReplaceAll(table, "`", "``"))
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			log.Fatalf("truncate error: %#v", err)
		}
	}
	// Re-enable the checks before the connection goes back to the pool.
	if _, err := conn.ExecContext(ctx, "SET FOREIGN_KEY_CHECKS = 1"); err != nil {
		log.Fatalf("truncate error: %#v", err)
	}
}
// dropTables drops every table in the test database. Any failure exits the
// process via log.Fatalf.
//
// NOTE(review): tables are dropped in the order SHOW TABLES lists them, and
// this function does not disable FOREIGN_KEY_CHECKS. A table still referenced
// by another table's foreign key may therefore fail to drop.
func dropTables() {
	rows, err := dbConn.Query("SHOW TABLES")
	if err != nil {
		log.Fatalf("show tables error: %#v", err)
	}
	// Read the full result set before issuing any DDL so the SHOW TABLES
	// cursor is not held open across the DROP statements.
	var tables []string
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			log.Fatalf("show table error: %#v", err)
		}
		tables = append(tables, tableName)
	}
	if err := rows.Err(); err != nil {
		log.Fatalf("show tables error: %#v", err)
	}
	rows.Close()

	for _, table := range tables {
		// Quote the identifier so table names that are reserved words work.
		stmt := fmt.Sprintf("DROP TABLE `%s`", strings.ReplaceAll(table, "`", "``"))
		if _, err := dbConn.Exec(stmt); err != nil {
			log.Fatalf("drop table error: %#v", err)
		}
	}
}
package dbsetup
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"runtime"
"github.com/gomodule/redigo/redis"
"github.com/xxx/config"
validator "gopkg.in/go-playground/validator.v9"
)
// fixturesDirRelativePathFormat is a printf-style template for the Redis
// fixture directories. The first verb receives this source file's directory,
// the second the sub-directory name ("default" or "optional").
const fixturesDirRelativePathFormat = "%s/../../../schema/fixtures/redis/%s"

var (
	// pool is the shared test Redis pool assigned by SetupRedisPool.
	pool *redis.Pool
	// validate checks fixture files before they are written to Redis.
	validate *validator.Validate
)
// fixtureType selects how a fixture's payload is written to Redis. It is an
// alias for string, so it compares directly with the decoded "type" field.
type fixtureType = string

// Accepted values for fixture.Type.
const (
	fixtureTypeString      = "string"       // SET key <string>
	fixtureTypeObject      = "object"       // SET key <JSON object>
	fixtureTypeStringArray = "string_array" // RPUSH key <each string>
	fixtureTypeObjectArray = "object_array" // RPUSH key <each JSON object>
)

// fixture is the JSON shape of a Redis fixture file.
//
// Key is the Redis key to populate. Value holds the payload for "string"
// and "object" fixtures; Values holds it for "string_array" and
// "object_array" fixtures. Both payload fields keep the raw JSON and are
// decoded according to Type.
type fixture struct {
	Key    string          `json:"key"`
	Value  json.RawMessage `json:"value,omitempty"`
	Values json.RawMessage `json:"values,omitempty"`
	Type   fixtureType     `json:"type"`
}
// fixtureStructLevelValidation is the struct-level validator for fixture.
// It reports an error when Type is not one of the known fixture types, or
// when the payload field that Type needs is missing: Value for "string" and
// "object", Values for "string_array" and "object_array".
func fixtureStructLevelValidation(sl validator.StructLevel) {
	f := sl.Current().Interface().(fixture)
	// An explicit JSON null reaches json.RawMessage as the bytes "null"
	// rather than nil, so treat it the same as an absent key.
	missing := func(raw json.RawMessage) bool {
		return len(raw) == 0 || string(raw) == "null"
	}
	switch f.Type {
	case fixtureTypeString, fixtureTypeObject:
		if missing(f.Value) {
			sl.ReportError(f.Value, "Value", "value", "typeandvalue", "")
		}
	case fixtureTypeStringArray, fixtureTypeObjectArray:
		if missing(f.Values) {
			sl.ReportError(f.Values, "Values", "values", "typeandvalues", "")
		}
	default:
		sl.ReportError(f.Type, "Type", "type", "invalidtype", "")
	}
}
// SetupRedisPool initialises the shared test Redis pool with
// config.NewRedisPool for the server named by TEST_REDIS_SERVER. It returns
// a teardown function that passes that pool to config.CleanRedisPool.
func SetupRedisPool() func() {
	pool = config.NewRedisPool(os.Getenv("TEST_REDIS_SERVER"))
	created := pool
	return func() { config.CleanRedisPool(created) }
}
// GetRedisPool returns the shared test Redis pool. It panics if
// SetupRedisPool has not been called yet.
func GetRedisPool() *redis.Pool {
	if pool != nil {
		return pool
	}
	panic("redis connection pool is not initialized yet")
}
// SetupDefaultRedisFixture inserts the Redis fixtures shared by every test.
// It writes each file that walkSchema lists in the "default" fixture
// directory (resolved relative to this source file) to Redis.
func SetupDefaultRedisFixture() {
	_, thisFile, _, _ := runtime.Caller(0)
	dir := fmt.Sprintf(fixturesDirRelativePathFormat, path.Dir(thisFile), "default")
	for _, fixturePath := range walkSchema(dir) {
		execRedisCommand(fixturePath)
	}
}
// SetupOptionalRedisFixtures inserts test-case-specific Redis fixtures. For
// each name, in the order given, it writes optional/<name>.json from the
// Redis fixture directory (resolved relative to this source file).
func SetupOptionalRedisFixtures(names []string) {
	_, thisFile, _, _ := runtime.Caller(0)
	dir := fmt.Sprintf(fixturesDirRelativePathFormat, path.Dir(thisFile), "optional")
	for i := range names {
		execRedisCommand(filepath.Join(dir, names[i]+".json"))
	}
}
func execRedisCommand(fpath string) {
validate = validator.New()
validate.RegisterStructValidation(fixtureStructLevelValidation, fixture{})
b, err := ioutil.ReadFile(fpath)
if err != nil {
log.Fatalf("schema reading error: %v", err)
}
var data fixture
err = json.Unmarshal(b, &data)
if err != nil {
panic(err)
}
err = validate.Struct(data)
if err != nil {
panic(err)
}
conn := pool.Get()
defer conn.Close()
switch data.Type {
case fixtureTypeString:
var val string
json.Unmarshal(data.Value, &val)
conn.Send("SET", data.Key, val)
case fixtureTypeObject:
var val map[string]interface{}
json.Unmarshal(data.Value, &val)
conn.Send("SET", data.Key, val)
case fixtureTypeStringArray:
var vals []string
json.Unmarshal(data.Values, &vals)
for _, v := range vals {
conn.Send("RPUSH", data.Key, v)
}
case fixtureTypeObjectArray:
var vals []map[string]interface{}
json.Unmarshal(data.Values, &vals)
for _, v := range vals {
b, _ := json.Marshal(v)
conn.Send("RPUSH", data.Key, b)
}
}
conn.Flush()
}
// CleanupRedis wipes the test Redis server by issuing FLUSHALL on a
// connection from the shared pool. It panics if the command fails.
func CleanupRedis() {
	conn := pool.Get()
	defer conn.Close()
	if _, err := conn.Do("FLUSHALL"); err != nil {
		panic(err)
	}
}
package testutils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp"
)
// TryRequest sends a single request to mux through a throwaway httptest
// server and checks the response.
//
// The parameters are used as follows:
//   - desc prefixes every failure message.
//   - method and path (appended to the server URL) select the endpoint.
//   - payload is sent as the request body with Content-Type application/json.
//
// The test fails if the status code differs from wantCode. When wantBody is
// non-nil, it also fails if the response body is not JSON-equivalent to
// wantBody (whitespace and object key order are ignored). Failures are
// reported with t.Errorf, so the calling test keeps running.
func TryRequest(t *testing.T, desc, method, path, payload string, mux *chi.Mux, wantCode int, wantBody []byte) {
	// Attribute failures to the caller's line instead of this helper.
	t.Helper()

	srv := httptest.NewServer(mux)
	defer srv.Close()

	req, err := http.NewRequest(method, srv.URL+path, strings.NewReader(payload))
	if err != nil {
		t.Errorf("%s: generate request: %v", desc, err)
		return
	}
	req.Header.Set("Content-Type", "application/json")

	// srv.Client() belongs to this test server, and srv.Close closes its idle
	// connections. That avoids leaving connections in the process-wide
	// http.DefaultClient.
	// NOTE(review): the client has no timeout, so a hung handler blocks until
	// the `go test -timeout` deadline.
	resp, err := srv.Client().Do(req)
	if err != nil {
		t.Errorf("%s: %s %s: %v", desc, method, path, err)
		return
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Errorf("%s: reading body: %v", desc, err)
		return
	}

	if resp.StatusCode != wantCode {
		t.Errorf("%s: got HTTP %d, want %d", desc, resp.StatusCode, wantCode)
		t.Errorf("response body: %s", string(body))
		return
	}

	if wantBody == nil {
		return
	}
	equal, diff, err := areEqualJSON(body, wantBody)
	if err != nil {
		// Report invalid JSON explicitly. Otherwise it shows up as a mismatch
		// with an empty or misleading diff.
		t.Errorf("%s: comparing JSON bodies: %v\n got: %s\n want: %s", desc, err, body, wantBody)
		return
	}
	if !equal {
		t.Errorf("%s: got HTTP body %q, \n want %q, \n diff: %s", desc, body, wantBody, diff)
	}
}
// areEqualJSON reports whether s1 and s2 hold the same JSON value once
// decoded, so whitespace and object key order do not matter. Numbers decode
// as float64, so integers beyond 2^53 compare approximately.
//
// It also returns cmp.Diff of the two decoded values for failure messages.
// If either input is not valid JSON, the result is false and the error names
// the malformed argument. The diff then covers only what was decoded before
// the failure.
func areEqualJSON(s1, s2 []byte) (bool, string, error) {
	var o1, o2 interface{}
	if err := json.Unmarshal(s1, &o1); err != nil {
		return false, cmp.Diff(o1, o2), fmt.Errorf("unmarshalling string 1: %v", err)
	}
	if err := json.Unmarshal(s2, &o2); err != nil {
		return false, cmp.Diff(o1, o2), fmt.Errorf("unmarshalling string 2: %v", err)
	}
	return reflect.DeepEqual(o1, o2), cmp.Diff(o1, o2), nil
}
package testutils
import (
"os"
"testing"
"github.com/xxx/internal/testutils/dbsetup"
)
// IntegrationTestRunner performs the one-time (BeforeAll / AfterAll) setup
// and teardown for an integration test package; call it from TestMain.
//
// Setup opens the MySQL and Redis connections, applies the schema, empties
// every table, and pins the fake clock via MockTimeNow. It then runs the
// tests, releases the resources in the order they were created, and exits
// with the test result code.
func IntegrationTestRunner(m *testing.M) {
	closeMySQL := SetupMySQLConn()
	closeRedis := dbsetup.SetupRedisPool()
	createTablesIfNotExist()
	truncateTables()
	restoreClock := MockTimeNow()

	exitCode := m.Run()

	// Teardown runs explicitly rather than via defer because os.Exit skips
	// deferred calls.
	for _, teardown := range []func(){closeMySQL, closeRedis, restoreClock} {
		teardown()
	}
	os.Exit(exitCode)
}
// IntegrationTestFuncCall runs f as subtest name of t. It wraps the subtest
// with the per-case hooks: BeforeEachIntegrationTest before it and
// AfterEachIntegrationTest after it.
//
// NOTE(review): the after-hook runs when t.Run returns. If f calls
// t.Parallel, t.Run returns before f finishes, and the cleanup would then
// run underneath the still-running subtest.
func IntegrationTestFuncCall(t *testing.T, name string, f func(t *testing.T)) {
	BeforeEachIntegrationTest()
	defer AfterEachIntegrationTest()

	t.Run(name, f)
}
// BeforeEachIntegrationTest is the per-case (BeforeEach) setup for
// integration tests: it inserts the default MySQL fixtures.
func BeforeEachIntegrationTest() {
	setupDefaultFixtures()

	// TODO: call dbsetup.SetupDefaultRedisFixture() here once default Redis
	// fixtures exist; there are none yet, so it stays disabled.
}
// AfterEachIntegrationTest is the per-case (AfterEach) cleanup for
// integration tests. It empties every MySQL table and then wipes Redis.
func AfterEachIntegrationTest() {
	truncateTables()
	dbsetup.CleanupRedis()
}
// RepositoryUnitTestRunner performs the one-time (BeforeAll / AfterAll)
// setup and teardown for repository test packages; call it from TestMain.
// Repository tests need the same environment as integration tests (schema,
// connections, fake clock), so this simply delegates to
// IntegrationTestRunner.
func RepositoryUnitTestRunner(m *testing.M) {
	IntegrationTestRunner(m)
}
// RepositoryUnitTestFuncCall runs f as subtest name of t, with
// BeforeEachRepositoryUnitTest before it and AfterEachRepositoryUnitTest
// after it.
//
// NOTE(review): as with IntegrationTestFuncCall, a subtest that calls
// t.Parallel would have its cleanup run before it finishes.
func RepositoryUnitTestFuncCall(t *testing.T, name string, f func(t *testing.T)) {
	BeforeEachRepositoryUnitTest()
	defer AfterEachRepositoryUnitTest()

	t.Run(name, f)
}
// BeforeEachRepositoryUnitTest is the per-case (BeforeEach) setup for
// repository tests. It currently delegates to BeforeEachIntegrationTest.
func BeforeEachRepositoryUnitTest() {
	BeforeEachIntegrationTest()
}
// AfterEachRepositoryUnitTest is the per-case (AfterEach) cleanup for
// repository tests. It currently delegates to AfterEachIntegrationTest.
func AfterEachRepositoryUnitTest() {
	AfterEachIntegrationTest()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment