Skip to content

Instantly share code, notes, and snippets.

@valsteen
Created November 6, 2022 12:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save valsteen/60b39247457d11325cc433afe938ba94 to your computer and use it in GitHub Desktop.
Save valsteen/60b39247457d11325cc433afe938ba94 to your computer and use it in GitHub Desktop.
Chaining calls until something fails, with error mapping
package main
import (
"errors"
"fmt"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
// identity returns its argument unchanged. Chain and Chain2 instantiate it as
// identity[error] to serve as a "leave the error alone" mapper.
func identity[T any](value T) T {
	return value
}
func Chain[I, R, R1 any](
input I,
prelude func(I) (R, error),
success1 func(R) (R1, error),
) (R1, error) {
return ChainMapErr(input, prelude, identity[error], success1, identity[error])
}
func Chain2[I, R, R1, R2 any](
input I,
prelude func(I) (R, error),
success1 func(R) (R1, error),
success2 func(R1) (R2, error),
) (R2, error) {
return ChainMapErr2(input, prelude, identity[error], success1, identity[error], success2, identity[error])
}
// ChainMapErr calls prelude(input) and then success1 on its result.
//
// If prelude fails, success1 is not called and the error is passed through
// mapError1. If success1 fails, its error is passed through mapError2. In both
// failure cases the zero R1 is returned, even if success1 produced a value
// alongside its error. Whatever the mapper returns (including nil) becomes the
// returned error.
func ChainMapErr[I, R, R1 any](
	input I,
	prelude func(I) (R, error),
	mapError1 func(error) error,
	success1 func(R) (R1, error),
	mapError2 func(error) error,
) (R1, error) {
	intermediate, preludeErr := prelude(input)
	if preludeErr != nil {
		return *new(R1), mapError1(preludeErr)
	}

	out, stepErr := success1(intermediate)
	if stepErr != nil {
		return *new(R1), mapError2(stepErr)
	}

	return out, nil
}
// ChainMapErr2 extends ChainMapErr with a third stage. Input flows through
// prelude, success1 and success2 in order. The first error encountered goes
// through the mapper paired with its stage: mapError1 for prelude, mapError2
// for success1, mapError3 for success2. On error the zero R2 is returned.
func ChainMapErr2[I, R, R1, R2 any](
	input I,
	prelude func(I) (R, error),
	mapError1 func(error) error,
	success1 func(R) (R1, error),
	mapError2 func(error) error,
	success2 func(R1) (R2, error),
	mapError3 func(error) error,
) (R2, error) {
	var none R2

	mid, earlyErr := ChainMapErr(input, prelude, mapError1, success1, mapError2)
	if earlyErr != nil {
		// ChainMapErr has already applied the right mapper; don't map twice.
		return none, earlyErr
	}

	final, lastErr := success2(mid)
	if lastErr != nil {
		return none, mapError3(lastErr)
	}
	return final, nil
}
// process triples num.
//
// It rejects negative input with "cannot be negative". It also rejects any
// input whose triple would not fit in an int. Go does not trap on signed
// overflow, so without that bound num*3 silently wraps, possibly to a negative
// value that downstream stages (such as render's upper-bound check) would
// wrongly accept.
func process(num int) (int, error) {
	// Largest value of int on this platform (same value as math.MaxInt).
	const maxInt = int(^uint(0) >> 1)

	if num < 0 {
		return 0, errors.New("cannot be negative")
	}
	if num > maxInt/3 {
		return 0, errors.New("too large to triple")
	}
	return num * 3, nil
}
// render formats num in base 10. Values above 1000 are refused with a
// "too large" error and an empty string.
func render(num int) (string, error) {
	const limit = 1000
	if num <= limit {
		return strconv.Itoa(num), nil
	}
	return "", errors.New("too large")
}
// Test drives every input through two pipelines: Chain2, which returns stage
// errors unchanged, and ChainMapErr2, which prefixes each stage's error with
// "error while <stage>: ". Both pipelines must produce the same result string.
func Test(t *testing.T) {
	type testCase struct {
		input           string
		result          string
		errString       string // expected error from Chain2; "" means success
		mappedErrString string // expected error from ChainMapErr2; "" means success
	}

	// wrapStage builds an error mapper that records which stage failed.
	wrapStage := func(stage string) func(error) error {
		return func(e error) error { return fmt.Errorf("error while %s: %w", stage, e) }
	}

	for _, tc := range []testCase{
		{input: "1", result: "3"},
		{
			input:           "invalidint",
			errString:       `strconv.Atoi: parsing "invalidint": invalid syntax`,
			mappedErrString: `error while parsing: strconv.Atoi: parsing "invalidint": invalid syntax`,
		},
		{
			input:           "-3",
			errString:       "cannot be negative",
			mappedErrString: "error while processing: cannot be negative",
		},
		{
			input:           "340",
			errString:       "too large",
			mappedErrString: "error while rendering: too large",
		},
	} {
		// t.Run is synchronous here (no t.Parallel), so capturing tc is safe.
		t.Run(tc.input, func(t *testing.T) {
			// testify's argument order is (expected, actual).
			result, err := Chain2(tc.input, strconv.Atoi, process, render)
			require.Equal(t, tc.result, result)
			if tc.errString == "" {
				require.NoError(t, err)
			} else {
				require.EqualError(t, err, tc.errString)
			}

			result, err = ChainMapErr2(
				tc.input,
				strconv.Atoi,
				wrapStage("parsing"),
				process,
				wrapStage("processing"),
				render,
				wrapStage("rendering"),
			)
			require.Equal(t, tc.result, result)
			if tc.mappedErrString == "" {
				require.NoError(t, err)
			} else {
				require.EqualError(t, err, tc.mappedErrString)
			}
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment