Skip to content

Instantly share code, notes, and snippets.

@17twenty
Last active February 17, 2024 05:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 17twenty/4aea5eeb7bcb35ee098293f470114972 to your computer and use it in GitHub Desktop.
Save 17twenty/4aea5eeb7bcb35ee098293f470114972 to your computer and use it in GitHub Desktop.
jsonfile: Use `jsonfile` to persist a Go value to a JSON file.
// Copyright (c) David Crawshaw
// SPDX-License-Identifier: BSD-3-Clause
// Package jsonfile persists a Go value to a JSON file.
package jsonfile
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
)
// JSONFile holds a Go value of type Data and persists it to a JSON file.
//
// Inspect the value with Read and change it with Write; a Write that changes
// the value rewrites the file. Create a JSONFile with New (for a new file) or
// Load (for an existing one). Read and Write may be called from multiple
// goroutines, provided Read callbacks do not modify the value they are given.
type JSONFile[Data any] struct {
	path string // location of the file; set by New or Load and never reassigned

	mu    sync.RWMutex // guards bytes and data
	bytes []byte       // JSON most recently read from or written to path
	data  *Data        // decoded form of bytes, replaced wholesale by each successful Write
}
// New creates a JSONFile at path holding an empty Data (the result of
// decoding the JSON object {}) and writes it to disk immediately.
//
// New returns an error if the file cannot be written, for example when path
// names an existing directory. Unlike Load, New does not look for an existing
// file: an existing regular file at path is replaced. Data must be a type
// that a JSON object can be decoded into, such as a struct or a map.
func New[Data any](path string) (*JSONFile[Data], error) {
	// Seed the "last persisted" bytes with " {}". It decodes exactly like
	// "{}", but the leading space means it can never equal json.Marshal output
	// (which is always compact), so Write's unchanged-value check cannot skip
	// this initial write. Seeding with plain "{}" would silently skip creating
	// the file whenever Data marshals to "{}" (a map, or a struct without
	// exported fields).
	p := &JSONFile[Data]{path: path, bytes: []byte(" {}"), data: new(Data)}
	if err := p.Write(func(*Data) error { return nil }); err != nil {
		return nil, fmt.Errorf("jsonfile.New: %w", err)
	}
	return p, nil
}
// Load loads an existing JSONFile from the given path.
//
// If the file cannot be read, Load returns the error from os.ReadFile
// unwrapped, so a missing file can be detected with either
// errors.Is(err, os.ErrNotExist) or os.IsNotExist(err). (os.IsNotExist does
// not see through fmt.Errorf wrapping, which is why this error is not
// wrapped.) If the contents are not valid JSON for Data, Load returns a
// wrapped decoding error.
//
// Load and New are separate to avoid creating a new file when starting a
// service, which could lead to data loss. To load an existing file or else
// create it (which you may want to do in a development environment), combine
// Load with New, like this:
//
//	db, err := jsonfile.Load[Data](path)
//	if errors.Is(err, os.ErrNotExist) {
//		db, err = jsonfile.New[Data](path)
//	}
func Load[Data any](path string) (*JSONFile[Data], error) {
	b, err := os.ReadFile(path)
	if err != nil {
		// Already an *os.PathError naming the operation and path.
		return nil, err
	}
	data := new(Data)
	if err := json.Unmarshal(b, data); err != nil {
		return nil, fmt.Errorf("jsonfile.Load: %w", err)
	}
	return &JSONFile[Data]{path: path, bytes: b, data: data}, nil
}
// Read calls fn with a pointer to the current data while holding the read
// lock, so fn may run concurrently with other Read calls but never with Write.
//
// The pointer refers to the stored value itself, not a copy: fn must treat it
// as read-only and must not keep it after returning, since a later Write
// replaces the stored value.
func (p *JSONFile[Data]) Read(fn func(data *Data)) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	current := p.data
	fn(current)
}
// Write calls fn with a private copy of the data and, if fn succeeds and the
// value changed, persists the result to the file.
//
// The copy is decoded from the last persisted JSON, so fn's changes are
// discarded whenever Write fails. If fn returns an error, Write returns that
// error unmodified. If encoding, decoding, or writing the file fails, Write
// returns a wrapped error. In every failure case the file and the value seen
// by Read are left as they were. Write holds the exclusive lock throughout, so
// it never overlaps with Read or another Write.
func (p *JSONFile[Data]) Write(fn func(*Data) error) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Work on a fresh copy so that a failure can be rolled back simply by
	// dropping it.
	data := new(Data)
	if err := json.Unmarshal(p.bytes, data); err != nil {
		return fmt.Errorf("JSONFile.Write: %w", err)
	}
	if err := fn(data); err != nil {
		return err
	}
	b, err := json.Marshal(data)
	if err != nil {
		return fmt.Errorf("JSONFile.Write: %w", err)
	}
	if bytes.Equal(b, p.bytes) {
		return nil // no change
	}

	// Store a value decoded from b rather than data itself: fn may have placed
	// caller-owned slices or maps into data, and the stored value must not
	// alias them. Decoding before touching the disk means a decode failure
	// leaves both the file and the in-memory state unchanged.
	fresh := new(Data)
	if err := json.Unmarshal(b, fresh); err != nil {
		return fmt.Errorf("JSONFile.Write: %w", err)
	}
	if err := writeFileAtomic(p.path, b); err != nil {
		return fmt.Errorf("JSONFile.Write: %w", err)
	}
	p.data = fresh
	p.bytes = b
	return nil
}

// writeFileAtomic replaces the file at path with b by writing b to a temporary
// file in the same directory, syncing it to stable storage, and renaming it
// over path. On POSIX filesystems the rename is atomic, so readers of path see
// either the old or the new contents. The temporary file is removed if any
// step fails.
func writeFileAtomic(path string, b []byte) (err error) {
	f, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".tmp")
	if err != nil {
		return fmt.Errorf("temp: %w", err)
	}
	defer func() {
		if err != nil {
			// Best effort: the temporary file is garbage once the write has
			// failed, and the original error is the one worth reporting.
			os.Remove(f.Name())
		}
	}()
	if _, err = f.Write(b); err == nil {
		// Flush the data before the rename so a crash cannot leave path
		// pointing at an empty or truncated file.
		err = f.Sync()
	}
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		return err
	}
	if err = os.Rename(f.Name(), path); err != nil {
		return fmt.Errorf("rename: %w", err)
	}
	return nil
}
// Copyright (c) David Crawshaw
// SPDX-License-Identifier: BSD-3-Clause
package jsonfile
import (
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
// mustWrite applies fn to data through Write and fails the test immediately
// if Write reports an error. fn itself cannot fail.
func mustWrite[Data any](t *testing.T, data *JSONFile[Data], fn func(db *Data)) {
	t.Helper()
	infallible := func(db *Data) error {
		fn(db)
		return nil
	}
	if err := data.Write(infallible); err != nil {
		t.Fatal(err)
	}
}
// TestBasic writes a value with slices and maps, checks it through Read, then
// reloads the file from disk and checks it again.
func TestBasic(t *testing.T) {
	t.Parallel()

	type Data struct {
		Name    string
		Friends []string
		Ages    map[string]int
	}
	want := Data{
		Name:    "Alice",
		Friends: []string{"Bob", "Carol", "Dave"},
		Ages:    map[string]int{"Bob": 25, "Carol": 30, "Dave": 35},
	}
	check := func(got *Data) {
		if !reflect.DeepEqual(*got, want) {
			t.Errorf("got %+v, want %+v", *got, want)
		}
	}

	path := filepath.Join(t.TempDir(), "testbasic.json")
	db, err := New[Data](path)
	if err != nil {
		t.Fatal(err)
	}
	mustWrite(t, db, func(d *Data) {
		d.Name = want.Name
		d.Friends = append([]string{}, want.Friends...)
		ages := make(map[string]int, len(want.Ages))
		for name, age := range want.Ages {
			ages[name] = age
		}
		d.Ages = ages
	})
	mustWrite(t, db, func(*Data) {}) // a no-op write must leave the value intact
	db.Read(check)

	// Reload from disk to confirm the value was persisted.
	reloaded, err := Load[Data](path)
	if err != nil {
		t.Fatal(err)
	}
	reloaded.Read(check)
}
// TestRollbackOnProgramError checks that when the Write callback returns an
// error, Write reports that error and the stored value keeps its last
// successfully written state.
func TestRollbackOnProgramError(t *testing.T) {
	t.Parallel()
	type DB struct{ Val int }

	db, err := New[DB](filepath.Join(t.TempDir(), "testrollback.json"))
	if err != nil {
		t.Fatal(err)
	}
	mustWrite(t, db, func(d *DB) { d.Val = 3 })
	mustWrite(t, db, func(d *DB) { d.Val = 1 })

	rollbackErr := fmt.Errorf("rollback")
	failing := func(d *DB) error {
		d.Val = 2 // must be discarded because the callback fails
		return rollbackErr
	}
	// errors.Is(nil, rollbackErr) is false, so this also catches a nil error.
	if err := db.Write(failing); !errors.Is(err, rollbackErr) {
		t.Fatalf("Write err=%v, want %v", err, rollbackErr)
	}
	db.Read(func(d *DB) {
		if d.Val != 1 {
			t.Fatalf("Val = %d after rollback, want 1", d.Val)
		}
	})
}
// TestFileError checks that when the file cannot be written (its directory is
// read-only), Write returns a permission error and the stored value is
// unchanged.
func TestFileError(t *testing.T) {
	t.Parallel()
	type DB struct{ Val int }
	dir := filepath.Join(t.TempDir(), "tstdir")
	path := filepath.Join(dir, "testfserr.json")
	if err := os.MkdirAll(dir, 0777); err != nil {
		t.Fatal(err)
	}
	db, err := New[DB](path)
	if err != nil {
		t.Fatal(err)
	}
	mustWrite(t, db, func(db *DB) { db.Val = 1 })

	// Make the directory read-only so Write cannot create its temporary file.
	if err := os.Chmod(dir, 0500); err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		// Restore write permission so the TempDir cleanup can remove the tree.
		// Best effort: there is nothing useful to do if this fails.
		os.Chmod(dir, 0700)
	})
	// Some environments ignore directory permissions (for example when running
	// as root, or on Windows); the failure cannot be provoked there.
	probe := filepath.Join(dir, "probe")
	if f, err := os.Create(probe); err == nil {
		f.Close()
		os.Remove(probe)
		t.Skip("directory permissions are not enforced here (running as root?)")
	}

	if err := db.Write(func(db *DB) error {
		db.Val = 2
		return nil
	}); !errors.Is(err, os.ErrPermission) {
		t.Fatalf("Write err=%v, want %v", err, os.ErrPermission)
	}
	db.Read(func(db *DB) {
		if db.Val != 1 {
			t.Fatalf("Val = %d after rollback, want 1", db.Val)
		}
	})
}
// TestNoAlias checks that the stored value shares no memory with the slice
// the caller handed to Write, or with the copy a failing Write callback
// modified.
func TestNoAlias(t *testing.T) {
	t.Parallel()
	type DB struct{ Vals []int }
	path := filepath.Join(t.TempDir(), "testnoalias.json")
	db, err := New[DB](path)
	if err != nil {
		t.Fatal(err)
	}
	someVals := []int{1, 2, 3}
	mustWrite(t, db, func(db *DB) { db.Vals = someVals })
	checkVals := func(db *DB) {
		if !reflect.DeepEqual(db.Vals, []int{1, 2, 3}) {
			t.Fatalf("Vals = %v want 1, 2, 3", db.Vals)
		}
	}
	db.Read(checkVals)

	// Mutating the caller's slice after Write must not change the stored value.
	someVals[0] = 10
	db.Read(checkVals)

	// Changes a failing callback makes through its copy must not reach the
	// stored value either.
	abort := errors.New("abort")
	if err := db.Write(func(db *DB) error {
		db.Vals[0] = 20
		return abort
	}); !errors.Is(err, abort) {
		t.Fatalf("Write err=%v, want %v", err, abort)
	}
	db.Read(checkVals)
}
// TestBadLoad checks Load's two failure modes: a missing file and a file that
// does not contain valid JSON.
func TestBadLoad(t *testing.T) {
	t.Parallel()
	type DB struct{ Val int }
	path := filepath.Join(t.TempDir(), "testbadload.json")

	// Nothing exists at path yet.
	if _, err := Load[DB](path); !errors.Is(err, os.ErrNotExist) {
		t.Fatalf("Load err=%v, want %v", err, os.ErrNotExist)
	}

	// Now path exists but holds garbage.
	if err := os.WriteFile(path, []byte("not json"), 0666); err != nil {
		t.Fatal(err)
	}
	_, err := Load[DB](path)
	switch {
	case err == nil, !strings.Contains(err.Error(), "invalid character"):
		t.Fatalf("Load err=%v, want error", err)
	}
}
// TestBadNew checks that New fails with an error matching os.ErrExist when
// its path names an existing directory.
func TestBadNew(t *testing.T) {
	t.Parallel()
	type DB struct{ Val int }
	dir := t.TempDir()
	if _, err := New[DB](dir); !errors.Is(err, os.ErrExist) {
		t.Fatalf("New err=%v, want %v", err, os.ErrExist)
	}
}
type JSONFile
func Load[Data any](path string) (*JSONFile[Data], error)
func New[Data any](path string) (*JSONFile[Data], error)
func (p *JSONFile[Data]) Read(fn func(data *Data))
func (p *JSONFile[Data]) Write(fn func(*Data) error) error

There is a bit more thought put into the few lines of code in this repository than you might expect. For more details, see the blog post.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment