Skip to content

Instantly share code, notes, and snippets.

@rjnienaber
Last active April 15, 2024 17:36
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rjnienaber/e7a833542d00a430f1c784ee842fa7fb to your computer and use it in GitHub Desktop.
Save rjnienaber/e7a833542d00a430f1c784ee842fa7fb to your computer and use it in GitHub Desktop.
repository pattern in Go
package controllers
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/rjnienaber/server-sent-events/internal/errorcodes"
"github.com/rjnienaber/server-sent-events/internal/repositories"
"github.com/rjnienaber/server-sent-events/internal/views"
)
// examController serves the /exams routes. It depends only on the read-side
// repositories.ScoreMessageFinder interface, which lets tests substitute a
// fake finder (see mockFinder in the controller tests).
type examController struct {
// scoreMesageFinder supplies score data to both handlers. The field name is
// misspelled ("Mesage") but is referenced by name elsewhere, so it is kept.
scoreMesageFinder repositories.ScoreMessageFinder
}
// RegisterExamRoutes mounts the exam endpoints on engine:
//
//	GET /exams      -> every exam with its scores
//	GET /exams/:id  -> the scores of a single exam
//
// repo is the data source for both handlers.
func RegisterExamRoutes(engine *gin.Engine, repo repositories.ScoreMessageFinder) {
	handlers := examController{scoreMesageFinder: repo}

	group := engine.Group("/exams")
	group.GET("", handlers.getAll)
	group.GET("/:id", handlers.getById)
}
// getById handles GET /exams/:id.
//
// Responses:
//   - 400 (InvalidExamId) when the id path parameter is not an integer;
//   - 500 (InternalError) when the finder returns an error;
//   - 404 (NoExamScoresFound) when the exam has no scores;
//   - 200 with views.ExamViewBuilder's output otherwise.
//
// Every error response echoes the raw id back as the "examId" param.
func (c examController) getById(g *gin.Context) {
	rawID := g.Params.ByName("id")
	errParams := map[string]string{"examId": rawID}

	examID, convErr := strconv.Atoi(rawID)
	if convErr != nil {
		writeError(g, http.StatusBadRequest, errorcodes.InvalidExamId,
			fmt.Sprintf("invalid id received '%s'", rawID), &errParams)
		return
	}

	scores, findErr := c.scoreMesageFinder.FindByExamId(examID)
	switch {
	case findErr != nil:
		writeError(g, http.StatusInternalServerError, errorcodes.InternalError,
			fmt.Sprintf("could not retrieve scores for exam id '%s'", rawID), &errParams)
	case len(scores) == 0:
		writeError(g, http.StatusNotFound, errorcodes.NoExamScoresFound,
			fmt.Sprintf("no scores found for exam id '%s'", rawID), &errParams)
	default:
		g.JSON(http.StatusOK, views.ExamViewBuilder(examID, scores))
	}
}
// getAll handles GET /exams and responds with every exam's scores, rendered
// by views.ExamsViewBuilder.
//
// If the finder fails, the problem is on the server side, not in the request.
// It is therefore reported as a 500 with the same JSON error envelope getById
// uses, instead of a 400 carrying the raw internal error text.
func (c examController) getAll(g *gin.Context) {
	exams, err := c.scoreMesageFinder.GroupByExam()
	if err != nil {
		// getAll has no request parameters to echo back, so params is empty.
		params := map[string]string{}
		writeError(g, http.StatusInternalServerError, errorcodes.InternalError,
			"could not retrieve exams", &params)
		return
	}
	g.JSON(http.StatusOK, views.ExamsViewBuilder(exams))
}
package controllers
import (
"bytes"
"encoding/json"
"errors"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/rjnienaber/server-sent-events/internal/models"
"github.com/stretchr/testify/assert"
)
// mockFinder is a test double for repositories.ScoreMessageFinder.
// FindByExamId delegates to findByExamId when it is set. Any call that has no
// stub behind it panics, so an unexpected repository call fails the test loudly.
type mockFinder struct {
// findByExamId, when non-nil, backs FindByExamId.
findByExamId func(examId int) (scores []models.ScoreMessage, err error)
}
// FindByExamId forwards to the configured findByExamId stub and panics when
// no stub was provided.
func (m mockFinder) FindByExamId(examId int) (scores []models.ScoreMessage, err error) {
	if m.findByExamId == nil {
		panic("implement me")
	}
	return m.findByExamId(examId)
}
// GroupByExam is not stubbable. No test in this file exercises it, so any call panics.
func (mockFinder) GroupByExam() ([]models.ExamScores, error) {
	panic("implement me")
}
// assertBasicResponseDetails checks that the recorded response has the given
// status code and was written as JSON ("application/json; charset=utf-8").
func assertBasicResponseDetails(t *testing.T, recorder *httptest.ResponseRecorder, statusCode int) {
	// Mark this as a helper so failures are attributed to the calling test's
	// line rather than to this function.
	t.Helper()
	assert.Equal(t, statusCode, recorder.Code)
	assert.Equal(t, []string{"application/json; charset=utf-8"}, recorder.Result().Header["Content-Type"])
}
// createTestContext returns a response recorder and a gin test context (in
// release mode) whose route Params hold every key/value pair in params.
//
// The previous version reassigned context.Params on each loop iteration. With
// more than one entry, only one arbitrary (map-order) param survived. Params
// are now appended so all of them are present. Their order follows map
// iteration and is random, which does not matter for lookups via ByName.
func createTestContext(params map[string]string) (recorder *httptest.ResponseRecorder, context *gin.Context) {
	recorder = httptest.NewRecorder()
	gin.SetMode(gin.ReleaseMode)
	context, _ = gin.CreateTestContext(recorder)
	for key, value := range params {
		context.Params = append(context.Params, gin.Param{Key: key, Value: value})
	}
	return recorder, context
}
// TestExamControllerReturns400WhenInvalidID checks that a non-numeric id is
// rejected with a 400 before any lookup. The controller has no finder here, so
// reaching the repository would call a method on a nil interface.
func TestExamControllerReturns400WhenInvalidID(t *testing.T) {
	const want = `{"statusCode":400,"code":"invalid_exam_id","message":"invalid id received 'abcd'","params":{"examId":"abcd"}}`
	recorder, ctx := createTestContext(map[string]string{"id": "abcd"})

	var controller examController
	controller.getById(ctx)

	assertBasicResponseDetails(t, recorder, 400)
	assert.Equal(t, want, recorder.Body.String())
}
// TestExamControllerReturns500WhenDatabaseAccessFails checks that a finder
// error is reported as a 500 internal_error response.
func TestExamControllerReturns500WhenDatabaseAccessFails(t *testing.T) {
	const want = `{"statusCode":500,"code":"internal_error","message":"could not retrieve scores for exam id '999'","params":{"examId":"999"}}`
	recorder, ctx := createTestContext(map[string]string{"id": "999"})

	failing := mockFinder{
		findByExamId: func(int) ([]models.ScoreMessage, error) {
			return nil, errors.New("database access failed")
		},
	}
	examController{scoreMesageFinder: failing}.getById(ctx)

	assertBasicResponseDetails(t, recorder, 500)
	assert.Equal(t, want, recorder.Body.String())
}
// TestExamControllerReturns404WhenExamIsNotFound checks that an exam with no
// scores (the finder returns no rows and no error) yields a 404.
func TestExamControllerReturns404WhenExamIsNotFound(t *testing.T) {
	const want = `{"statusCode":404,"code":"no_exam_scores_found","message":"no scores found for exam id '123'","params":{"examId":"123"}}`
	recorder, ctx := createTestContext(map[string]string{"id": "123"})

	empty := mockFinder{
		findByExamId: func(int) ([]models.ScoreMessage, error) {
			return nil, nil
		},
	}
	examController{scoreMesageFinder: empty}.getById(ctx)

	assertBasicResponseDetails(t, recorder, 404)
	assert.Equal(t, want, recorder.Body.String())
}
// TestExamControllerReturnsJSONWhenExamIsFound covers the happy path. A
// finder that returns one score yields a 200 whose body is the exam view
// rendered as compact JSON.
func TestExamControllerReturnsJSONWhenExamIsFound(t *testing.T) {
	recorder, ctx := createTestContext(map[string]string{"id": "456"})

	stub := mockFinder{
		findByExamId: func(int) ([]models.ScoreMessage, error) {
			return []models.ScoreMessage{{StudentId: "John.Doe", Exam: 456, Score: 0.75}}, nil
		},
	}
	examController{scoreMesageFinder: stub}.getById(ctx)

	assertBasicResponseDetails(t, recorder, 200)

	// The expectation is written pretty-printed for readability, then
	// compacted so it can be compared byte-for-byte with gin's output.
	pretty := []byte(`{
"exam": {
"id": 456,
"scores": [{
"id": "John.Doe",
"score": "0.750000",
"_links": {
"exam": {
"href": "/exams/456"
},
"self": {
"href": "/students/John.Doe"
}
}
}],
"average": "0.750000",
"_links": {
"self": {
"href": "/exams/456"
}
}
}
}`)
	var want bytes.Buffer
	assert.NoError(t, json.Compact(&want, pretty))
	assert.Equal(t, want.String(), recorder.Body.String())
}
package repositories
import (
"log"
"os"
"time"
"github.com/rjnienaber/server-sent-events/internal/models"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// Repositories bundles the data-access repositories (currently only
// ScoreMessages) with the settings used to open their shared *gorm.DB.
// Build it with NewRepositories; the unexported fields are set through Option
// values.
type Repositories struct {
// ScoreMessages is populated by NewRepositories once the database is open.
ScoreMessages ScoreMessageRepository
// logLevel is the gorm log level, set via WithLogLevel.
logLevel logger.LogLevel
// dialector selects the database driver and DSN, set via WithSqlite. When it
// is nil, NewRepositories falls back to an in-memory SQLite database.
dialector *gorm.Dialector
}
// Option configures a Repositories value before NewRepositories opens the
// database (see WithLogLevel and WithSqlite).
type Option func(svc *Repositories)
// createLogger builds the gorm logger up front so the chosen log level
// applies from the very first statement gorm runs, including the schema
// migration. Apart from the level, the settings are intended to mirror gorm's
// logger.Default: stdout, 200ms slow-query threshold, colored output.
func createLogger(logLevel logger.LogLevel) logger.Interface {
	out := log.New(os.Stdout, "\r\n", log.LstdFlags)
	cfg := logger.Config{
		SlowThreshold: 200 * time.Millisecond,
		LogLevel:      logLevel,
		Colorful:      true,
	}
	return logger.New(out, cfg)
}
// openDatabaseConnection opens a gorm DB for dialector using gormLogger and
// auto-migrates the ScoreMessage schema. Errors from either step are
// returned unchanged, and the DB is only returned when both succeed.
func openDatabaseConnection(dialector gorm.Dialector, gormLogger logger.Interface) (*gorm.DB, error) {
	db, err := gorm.Open(dialector, &gorm.Config{Logger: gormLogger})
	if err != nil {
		return nil, err
	}
	if err := db.AutoMigrate(&models.ScoreMessage{}); err != nil {
		return nil, err
	}
	return db, nil
}
// NewRepositories applies opts, opens and migrates the database, and returns
// the repositories backed by it. Without a WithSqlite option, a private
// in-memory SQLite database is used. Each call gets its own, initially empty
// database.
//
// Any error from opening or migrating the database is returned unwrapped,
// together with a zero Repositories.
func NewRepositories(opts ...Option) (Repositories, error) {
	repos := Repositories{}
	for _, opt := range opts {
		opt(&repos)
	}
	dbLogger := createLogger(repos.logLevel)

	// No dialector given: fall back to an in-memory SQLite database.
	inMemory := repos.dialector == nil
	if inMemory {
		WithSqlite("file::memory:")(&repos)
	}

	db, err := openDatabaseConnection(*repos.dialector, dbLogger)
	if err != nil {
		return Repositories{}, err
	}

	if inMemory {
		// Every new connection to "file::memory:" opens a brand-new, empty
		// database. If the pool ever opened a second connection (for example
		// under concurrent requests), queries on it would see no tables and no
		// rows. Limiting the pool to one connection keeps every query on the
		// database that was just migrated.
		sqlDB, dbErr := db.DB()
		if dbErr != nil {
			return Repositories{}, dbErr
		}
		sqlDB.SetMaxOpenConns(1)
	}

	repos.ScoreMessages = ScoreMessageRepository{db: db}
	return repos, nil
}
// WithLogLevel sets the gorm log level that NewRepositories uses for its
// database logger.
func WithLogLevel(logLevel logger.LogLevel) Option {
	setLevel := func(r *Repositories) {
		r.logLevel = logLevel
	}
	return setLevel
}
// WithSqlite makes NewRepositories open a SQLite database from dsn instead of
// its in-memory default.
func WithSqlite(dsn string) Option {
	return func(r *Repositories) {
		dialector := sqlite.Open(dsn)
		r.dialector = &dialector
	}
}
package repositories
import (
"testing"
"github.com/stretchr/testify/assert"
"gorm.io/gorm/logger"
)
func TestDatabaseReturnsErrorOnInvalidDsn(t *testing.T) {
_, err := NewRepositories(WithLogLevel(logger.Silent), WithSqlite("asdfasfd/asdfasfda"))
assert.EqualError(t, err, "unable to open database file: no such file or directory")
}
package repositories
import (
"sort"
"github.com/rjnienaber/server-sent-events/internal/models"
"gorm.io/gorm"
)
// ScoreMessageFinder is the read-side contract the exam controller depends
// on. ScoreMessageRepository implements it.
type ScoreMessageFinder interface {
// FindByExamId returns the score messages recorded for one exam.
FindByExamId(examId int) (scores []models.ScoreMessage, err error)
// GroupByExam returns all score messages grouped per exam.
// ScoreMessageRepository orders the groups by ascending exam id.
GroupByExam() (examScores []models.ExamScores, err error)
}
// ScoreMessageSaver is the write-side contract for persisting a single score
// message. ScoreMessageRepository implements it.
type ScoreMessageSaver interface {
Save(score models.ScoreMessage) error
}
// ScoreMessageRepository is the gorm-backed store for score messages. It
// satisfies both ScoreMessageFinder and ScoreMessageSaver. Obtain one from
// NewRepositories, which sets db. The zero value has a nil db and is not usable.
type ScoreMessageRepository struct {
db *gorm.DB
}
// FindAll returns every stored score message. No ordering is requested.
func (r ScoreMessageRepository) FindAll() (scores []models.ScoreMessage, err error) {
	err = r.db.Find(&scores).Error
	return scores, err
}
// Save persists score using gorm's Save. score is taken by value, so any
// fields gorm writes back (such as a generated key, if the model has one) are
// not visible to the caller.
func (r ScoreMessageRepository) Save(score models.ScoreMessage) error {
	result := r.db.Save(&score)
	return result.Error
}
// SaveAll persists every score in scores using gorm's Save.
//
// An empty or nil slice is a no-op that returns nil. gorm reports an error
// (gorm.ErrEmptySlice) when asked to save an empty slice, but "save nothing"
// is not a failure for callers.
func (r ScoreMessageRepository) SaveAll(scores []models.ScoreMessage) error {
	if len(scores) == 0 {
		return nil
	}
	return r.db.Save(&scores).Error
}
// FindByExamId returns the score messages whose exam column equals examId.
// The result is empty (nil) when there are none.
func (r ScoreMessageRepository) FindByExamId(examId int) (scores []models.ScoreMessage, err error) {
	query := r.db.Where("exam = ?", examId)
	err = query.Find(&scores).Error
	return scores, err
}
// FindByStudentId returns the score messages whose student_id column equals
// studentId.
func (r ScoreMessageRepository) FindByStudentId(studentId string) (scores []models.ScoreMessage, err error) {
	query := r.db.Where("student_id = ?", studentId)
	err = query.Find(&scores).Error
	return scores, err
}
// GroupByExam loads every score message and groups them in memory by Exam.
//
// Groups are returned in ascending exam-id order. Within a group, scores keep
// the order FindAll returned them in. When there are no scores, the result is
// nil. A FindAll error is returned as is.
func (r ScoreMessageRepository) GroupByExam() (examScores []models.ExamScores, err error) {
	all, err := r.FindAll()
	if err != nil {
		return nil, err
	}

	byExam := make(map[int][]models.ScoreMessage)
	var examOrder []int // each exam id once, first-seen order (sorted below)
	for _, msg := range all {
		if _, seen := byExam[msg.Exam]; !seen {
			examOrder = append(examOrder, msg.Exam)
		}
		byExam[msg.Exam] = append(byExam[msg.Exam], msg)
	}

	sort.Ints(examOrder)
	for _, exam := range examOrder {
		examScores = append(examScores, models.ExamScores{Exam: exam, Scores: byExam[exam]})
	}
	return examScores, nil
}
package repositories
import (
"testing"
"github.com/rjnienaber/server-sent-events/internal/models"
"github.com/stretchr/testify/assert"
)
// createRepository returns a ScoreMessageRepository backed by a fresh default
// (in-memory SQLite) database.
//
// If the database cannot be opened, the test is stopped right away. Otherwise
// NewRepositories' zero Repositories, with its nil *gorm.DB, would be handed
// to the caller, and the real error would be buried under a later nil
// dereference.
func createRepository(t *testing.T) ScoreMessageRepository {
	t.Helper()
	repos, err := NewRepositories()
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	return repos.ScoreMessages
}
// TestScoreMessagesSaveAndFind checks that a saved message is returned,
// field for field, by FindAll.
func TestScoreMessagesSaveAndFind(t *testing.T) {
	repo := createRepository(t)
	want := models.ScoreMessage{StudentId: "John.Doe", Exam: 123, Score: 0.75}
	assert.NoError(t, repo.Save(want))

	got, err := repo.FindAll()
	assert.NoError(t, err)
	assert.Len(t, got, 1)

	first := got[0]
	assert.Equal(t, want.StudentId, first.StudentId)
	assert.Equal(t, want.Exam, first.Exam)
	assert.Equal(t, want.Score, first.Score)
}
// TestScoreMessagesFindByExamId checks that a saved message is found by its
// exam id.
func TestScoreMessagesFindByExamId(t *testing.T) {
	repo := createRepository(t)
	want := models.ScoreMessage{StudentId: "John.Doe", Exam: 123, Score: 0.75}
	assert.NoError(t, repo.Save(want))

	got, err := repo.FindByExamId(123)
	assert.NoError(t, err)
	assert.Len(t, got, 1)

	first := got[0]
	assert.Equal(t, want.StudentId, first.StudentId)
	assert.Equal(t, want.Exam, first.Exam)
	assert.Equal(t, want.Score, first.Score)
}
// TestScoreMessagesFindByStudentId checks that a saved message is found by
// its student id.
func TestScoreMessagesFindByStudentId(t *testing.T) {
	repo := createRepository(t)
	want := models.ScoreMessage{StudentId: "John.Doe", Exam: 123, Score: 0.75}
	assert.NoError(t, repo.Save(want))

	got, err := repo.FindByStudentId("John.Doe")
	assert.NoError(t, err)
	assert.Len(t, got, 1)

	first := got[0]
	assert.Equal(t, want.StudentId, first.StudentId)
	assert.Equal(t, want.Exam, first.Exam)
	assert.Equal(t, want.Score, first.Score)
}
// TestScoreMessagesGroupByExam saves three messages across two exams. It
// expects two groups in ascending exam-id order, with each group's scores in
// insertion order.
func TestScoreMessagesGroupByExam(t *testing.T) {
	repo := createRepository(t)
	assert.NoError(t, repo.SaveAll([]models.ScoreMessage{
		{StudentId: "John.Doe", Exam: 123, Score: 0.75},
		{StudentId: "Jane.Doe", Exam: 456, Score: 0.8},
		{StudentId: "Josiah.Doe", Exam: 123, Score: 0.55},
	}))

	groups, err := repo.GroupByExam()
	assert.NoError(t, err)
	assert.Len(t, groups, 2)

	first := groups[0]
	assert.Equal(t, 123, first.Exam)
	assert.Len(t, first.Scores, 2)
	assert.Equal(t, "John.Doe", first.Scores[0].StudentId)
	assert.Equal(t, "Josiah.Doe", first.Scores[1].StudentId)

	second := groups[1]
	assert.Equal(t, 456, second.Exam)
	assert.Len(t, second.Scores, 1)
	assert.Equal(t, "Jane.Doe", second.Scores[0].StudentId)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment