Skip to content

Instantly share code, notes, and snippets.

@pwood
Created February 10, 2022 11:39
Show Gist options
  • Save pwood/3cbfd86030339ea7d47bb3ff6a92e00f to your computer and use it in GitHub Desktop.
Save pwood/3cbfd86030339ea7d47bb3ff6a92e00f to your computer and use it in GitHub Desktop.
Go AWS Lambda Generic Prototype
package lambdawrap
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"testing/iotest"

	"github.com/aws/aws-lambda-go/events"
	"github.com/stretchr/testify/assert"
)
// SNS adapts next into a Lambda handler for SNS events.
//
// Each record's message is converted to an O by optionalUnmarshal: []byte is
// passed through untouched, anything else is JSON-decoded. The value is then
// handed to next, one record at a time and in order. The byte slices returned
// by next are concatenated into the result. Processing stops at the first
// record that fails to decode or whose handler returns an error.
//
// NOTE(review): outputs are concatenated byte-wise, so several JSON responses
// become `{...}{...}` rather than a JSON array — confirm this is intended.
func SNS[O any](next func(context.Context, O) ([]byte, error)) func(ctx context.Context, event events.SNSEvent) ([]byte, error) {
	return func(ctx context.Context, event events.SNSEvent) ([]byte, error) {
		var res []byte
		for i, r := range event.Records {
			v, err := optionalUnmarshal[O]([]byte(r.SNS.Message))
			if err != nil {
				// Previously this error was overwritten by next's result and a
				// zero-valued O was processed as if it were valid input.
				return nil, fmt.Errorf("decoding sns record %d: %w", i, err)
			}
			d, err := next(ctx, v)
			if err != nil {
				return nil, fmt.Errorf("child failure: %w", err)
			}
			res = append(res, d...)
		}
		return res, nil
	}
}
// SQS adapts next into a Lambda handler for SQS events.
//
// Each record's body is converted to an O by optionalUnmarshal: []byte is
// passed through untouched, anything else is JSON-decoded. The value is then
// handed to next, one record at a time and in order. The byte slices returned
// by next are concatenated into the result. Processing stops at the first
// record that fails to decode or whose handler returns an error.
//
// NOTE(review): outputs are concatenated byte-wise, so several JSON responses
// become `{...}{...}` rather than a JSON array — confirm this is intended.
func SQS[O any](next func(context.Context, O) ([]byte, error)) func(ctx context.Context, event events.SQSEvent) ([]byte, error) {
	return func(ctx context.Context, event events.SQSEvent) ([]byte, error) {
		var res []byte
		for i, r := range event.Records {
			v, err := optionalUnmarshal[O]([]byte(r.Body))
			if err != nil {
				// Previously this error was overwritten by next's result and a
				// zero-valued O was processed as if it were valid input.
				return nil, fmt.Errorf("decoding sqs record %d: %w", i, err)
			}
			d, err := next(ctx, v)
			if err != nil {
				return nil, fmt.Errorf("child failure: %w", err)
			}
			res = append(res, d...)
		}
		return res, nil
	}
}
// optionalUnmarshal turns a raw payload into an O.
//
// When O is exactly []byte, data is returned as-is: no copy, no decoding. This
// lets wrappers forward opaque bodies untouched. For every other O the payload
// is JSON-decoded. On failure the zero O is returned together with the wrapped
// decode error.
func optionalUnmarshal[O any](data []byte) (O, error) {
	var out O
	if raw, isBytes := any(&out).(*[]byte); isBytes {
		*raw = data
		return out, nil
	}
	if err := json.Unmarshal(data, &out); err != nil {
		// Discard any partially decoded state and report the zero value.
		var zero O
		return zero, fmt.Errorf("couldn't unmarshal: %w", err)
	}
	return out, nil
}
// Codec converts domain values to and from their wire representation.
type Codec interface {
	// Marshal encodes v.
	Marshal(v any) ([]byte, error)
	// Unmarshal decodes data into v; v is expected to be a pointer
	// (DomainObject passes one).
	Unmarshal(data []byte, v any) error
}

// jsonCodec is a stateless Codec backed by encoding/json.
type jsonCodec struct{}

// JSONCodec is the ready-to-use encoding/json implementation of Codec.
var JSONCodec jsonCodec

// Compile-time check that jsonCodec satisfies Codec.
var _ Codec = jsonCodec{}

// Marshal encodes v with json.Marshal.
func (jsonCodec) Marshal(v any) ([]byte, error) { return json.Marshal(v) }

// Unmarshal decodes data into v with json.Unmarshal.
func (jsonCodec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
func DomainObject[I any, O any](next func(context.Context, I) (O, error), c Codec) func(ctx context.Context, n []byte) ([]byte, error) {
return func(ctx context.Context, n []byte) ([]byte, error) {
in := new(I)
err := c.Unmarshal(n, in)
if err != nil {
return nil, fmt.Errorf("codec unmarshal failure: %w", err)
}
ret, err := next(ctx, *in)
if err != nil {
return nil, fmt.Errorf("child failure: %w", err)
}
data, err := c.Marshal(ret)
if err != nil {
return nil, fmt.Errorf("codec marshal failure: %w", err)
}
return data, nil
}
}
// DynamoDB adapts a per-record handler into a Lambda handler for DynamoDB
// events.
//
// Records are passed to next one at a time, in order. Their outputs are
// concatenated, and the first error aborts the remaining records.
func DynamoDB(next func(context.Context, events.DynamoDBEventRecord) ([]byte, error)) func(ctx context.Context, event events.DynamoDBEvent) ([]byte, error) {
	return func(ctx context.Context, event events.DynamoDBEvent) ([]byte, error) {
		var combined []byte
		for i := range event.Records {
			out, err := next(ctx, event.Records[i])
			if err != nil {
				return nil, fmt.Errorf("child failure: %w", err)
			}
			combined = append(combined, out...)
		}
		return combined, nil
	}
}
// S3Notification adapts a per-record handler into a Lambda handler for S3
// notification events.
//
// Records are passed to next one at a time, in order. Their outputs are
// concatenated, and the first error aborts the remaining records.
func S3Notification(next func(context.Context, events.S3EventRecord) ([]byte, error)) func(ctx context.Context, event events.S3Event) ([]byte, error) {
	return func(ctx context.Context, event events.S3Event) ([]byte, error) {
		var combined []byte
		for i := range event.Records {
			out, err := next(ctx, event.Records[i])
			if err != nil {
				return nil, fmt.Errorf("child failure: %w", err)
			}
			combined = append(combined, out...)
		}
		return combined, nil
	}
}
// S3FetchReader adapts a reader-based handler into a handler for a single S3
// event record.
//
// NOTE(review): this is a placeholder. The record is ignored and no object is
// fetched. next always receives a reader whose Read immediately returns
// io.EOF, which amounts to an empty body.
func S3FetchReader(next func(context.Context, io.Reader) ([]byte, error)) func(ctx context.Context, event events.S3EventRecord) ([]byte, error) {
	return func(ctx context.Context, _ events.S3EventRecord) ([]byte, error) {
		// TODO: fetch the object referenced by the record and pass its body.
		return next(ctx, iotest.ErrReader(io.EOF))
	}
}
// ReaderToBytes adapts next, which works on a complete in-memory payload, into
// a handler that accepts an io.Reader. The reader is drained fully before next
// runs. A read failure is wrapped and returned without calling next.
//
// NOTE: ioutil.ReadAll is deprecated in favour of io.ReadAll (Go 1.16). It is
// kept because this is the file's only use of the io/ioutil import.
func ReaderToBytes(next func(context.Context, []byte) ([]byte, error)) func(ctx context.Context, r io.Reader) ([]byte, error) {
	return func(ctx context.Context, r io.Reader) ([]byte, error) {
		payload, readErr := ioutil.ReadAll(r)
		if readErr == nil {
			return next(ctx, payload)
		}
		return nil, fmt.Errorf("reader to bytes: failure to read all: %w", readErr)
	}
}
// APIGatewayProxyToHandler adapts a standard net/http handler function into a
// Lambda handler for API Gateway proxy events.
//
// The event is converted into an *http.Request, and handler writes into an
// in-memory httptest.ResponseRecorder. The recorded response is then converted
// back into a proxy response.
func APIGatewayProxyToHandler(handler func(http.ResponseWriter, *http.Request)) func(ctx context.Context, event events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
	return func(ctx context.Context, event events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
		recorder := httptest.NewRecorder()
		handler(recorder, convertAPIGatewayProxyRequestToGoHttpRequest(ctx, event))
		return convertGoHttpResponseToAPIGatewayProxyResponse(recorder)
	}
}
// convertAPIGatewayProxyRequestToGoHttpRequest builds an *http.Request from an
// API Gateway proxy event so that it can be served by an ordinary net/http
// handler. ctx is attached to the returned request.
//
// The request gets:
//   - the method;
//   - a non-nil URL carrying the event's path and query string, so handlers
//     and ServeMux can read r.URL safely;
//   - canonicalised headers (e.g. "content-type" becomes "Content-Type"),
//     with Host taken from the Host header;
//   - the body and its content length.
//
// NOTE(review): the body is passed through verbatim. Events with
// IsBase64Encoded set are not decoded yet.
func convertAPIGatewayProxyRequestToGoHttpRequest(ctx context.Context, event events.APIGatewayProxyRequest) *http.Request {
	// Prefer the multi-value maps. Fall back to the single-value ones when the
	// event did not populate them.
	header := make(http.Header)
	for k, vs := range event.MultiValueHeaders {
		for _, v := range vs {
			header.Add(k, v) // Add canonicalises the key.
		}
	}
	if len(event.MultiValueHeaders) == 0 {
		for k, v := range event.Headers {
			header.Set(k, v)
		}
	}

	query := make(url.Values)
	for k, vs := range event.MultiValueQueryStringParameters {
		for _, v := range vs {
			query.Add(k, v)
		}
	}
	if len(event.MultiValueQueryStringParameters) == 0 {
		for k, v := range event.QueryStringParameters {
			query.Set(k, v)
		}
	}

	r := &http.Request{
		Method:        event.HTTPMethod,
		URL:           &url.URL{Path: event.Path, RawQuery: query.Encode()},
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        header,
		Host:          header.Get("Host"),
		Body:          io.NopCloser(strings.NewReader(event.Body)),
		ContentLength: int64(len(event.Body)),
	}
	return r.WithContext(ctx)
}
// convertGoHttpResponseToAPIGatewayProxyResponse converts the response
// captured by w into an API Gateway proxy response.
//
// The status code and headers are copied as-is and the body is returned as a
// string. An error is returned in any of these cases:
//   - the recorded body cannot be read;
//   - the recorded body cannot be closed;
//   - the body exceeds maxResponseBodyBytes.
//
// NOTE(review): binary bodies are not base64-encoded, so non-text payloads may
// not survive API Gateway — confirm whether IsBase64Encoded support is needed.
func convertGoHttpResponseToAPIGatewayProxyResponse(w *httptest.ResponseRecorder) (events.APIGatewayProxyResponse, error) {
	// NOTE(review): presumably the Lambda synchronous response payload limit
	// (6 MiB) — confirm against AWS quotas.
	const maxResponseBodyBytes = 6 * 1024 * 1024

	res := w.Result()

	var body []byte
	if res.Body != nil {
		data, readErr := io.ReadAll(res.Body)
		// Close unconditionally so a failed read does not leak the body.
		closeErr := res.Body.Close()
		if readErr != nil {
			return events.APIGatewayProxyResponse{}, fmt.Errorf("reading recorded response body: %w", readErr)
		}
		if closeErr != nil {
			return events.APIGatewayProxyResponse{}, fmt.Errorf("closing recorded response body: %w", closeErr)
		}
		if len(data) > maxResponseBodyBytes {
			return events.APIGatewayProxyResponse{}, fmt.Errorf("response too large: %d > %d", len(data), maxResponseBodyBytes)
		}
		body = data
	}

	return events.APIGatewayProxyResponse{
		StatusCode:        res.StatusCode,
		MultiValueHeaders: res.Header,
		Body:              string(body),
	}, nil
}
type (
	// DomainInput is the example business-logic request decoded in Test.
	DomainInput struct {
		Account int
		Delta   float64
	}

	// DomainOutput is the example business-logic response produced in Test.
	DomainOutput struct {
		Success bool
	}
)
// Test exercises an SNS -> SQS -> domain-object wrapper chain end to end. It
// also checks that the other example chains compose and type-check.
func Test(t *testing.T) {
	myBusinessLogic := func(_ context.Context, _ DomainInput) (DomainOutput, error) {
		return DomainOutput{Success: true}, nil
	}

	t.Run("example usage", func(t *testing.T) {
		// Build the record lambda.Start() would pass in, from the inside out:
		// a DomainInput, carried as the body of an SQS message, whose SQS
		// event is in turn published as an SNS message.
		input, err := json.Marshal(DomainInput{Account: 1, Delta: -10})
		assert.NoError(t, err)
		sqsEvent, err := json.Marshal(events.SQSEvent{Records: []events.SQSMessage{{Body: string(input)}}})
		assert.NoError(t, err)
		snsEvent := events.SNSEvent{Records: []events.SNSEventRecord{{SNS: events.SNSEntity{Message: string(sqsEvent)}}}}

		// Chain of wrappers, outermost AWS event first.
		wrapChain := SNS(SQS(DomainObject(myBusinessLogic, JSONCodec)))

		// Invoke it the way lambda.Start() would.
		out, err := wrapChain(context.Background(), snsEvent)
		assert.NoError(t, err)
		assert.JSONEq(t, `{"Success":true}`, string(out))
		t.Logf("response: %s", out)

		// The chains below are only constructed, never invoked; they prove
		// that the wrappers compose and type-check.
		// lambda.StartWithContext(context.Background(), SNS(SQS(DomainObject(myBusinessLogic, JSONCodec))))
		_ = SQS(S3Notification(S3FetchReader(ReaderToBytes(DomainObject(myBusinessLogic, JSONCodec)))))
		_ = APIGatewayProxyToHandler(func(http.ResponseWriter, *http.Request) {})
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment