Created February 10, 2022 11:39
Save pwood/3cbfd86030339ea7d47bb3ff6a92e00f to your computer and use it in GitHub Desktop.
Go AWS Lambda Generic Prototype
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package lambdawrap | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"github.com/aws/aws-lambda-go/events" | |
"github.com/stretchr/testify/assert" | |
"io" | |
"io/ioutil" | |
"net/http" | |
"net/http/httptest" | |
"strings" | |
"testing" | |
"testing/iotest" | |
) | |
func SNS[O any](next func(context.Context, O) ([]byte, error)) func(ctx context.Context, event events.SNSEvent) ([]byte, error) { | |
return func(ctx context.Context, event events.SNSEvent) ([]byte, error) { | |
var res []byte | |
for _, r := range event.Records { | |
v, err := optionalUnmarshal[O]([]byte(r.SNS.Message)) | |
d, err := next(ctx, v) | |
if err != nil { | |
return nil, fmt.Errorf("child failure: %w", err) | |
} | |
res = append(res, d...) | |
} | |
return res, nil | |
} | |
} | |
func SQS[O any](next func(context.Context, O) ([]byte, error)) func(ctx context.Context, event events.SQSEvent) ([]byte, error) { | |
return func(ctx context.Context, event events.SQSEvent) ([]byte, error) { | |
var res []byte | |
for _, r := range event.Records { | |
v, err := optionalUnmarshal[O]([]byte(r.Body)) | |
d, err := next(ctx, v) | |
if err != nil { | |
return nil, fmt.Errorf("child failure: %w", err) | |
} | |
res = append(res, d...) | |
} | |
return res, nil | |
} | |
} | |
// optionalUnmarshal decodes data into a value of type O.
//
// When O is exactly []byte, data is returned as-is (no copy, no JSON decoding),
// so raw payloads can flow through a wrapper chain untouched. For every other
// type, data is decoded with encoding/json. On a decode failure the zero value
// of O is returned (never a partially populated one) together with the wrapped
// error.
func optionalUnmarshal[O any](data []byte) (O, error) {
	var out O
	if raw, isBytes := any(&out).(*[]byte); isBytes {
		*raw = data
		return out, nil
	}
	if err := json.Unmarshal(data, &out); err != nil {
		var zero O
		return zero, fmt.Errorf("couldn't unmarshal: %w", err)
	}
	return out, nil
}
// Codec converts between Go values and their serialized byte form.
// DomainObject uses a Codec to decode incoming payloads and encode results.
type Codec interface {
	// Marshal encodes v.
	Marshal(v any) ([]byte, error)
	// Unmarshal decodes data into v; v is expected to be a pointer.
	Unmarshal(data []byte, v any) error
}

// jsonCodec is a stateless Codec backed by encoding/json.
type jsonCodec struct{}

// JSONCodec is the encoding/json implementation of Codec.
var JSONCodec jsonCodec

// Compile-time check that jsonCodec satisfies Codec.
var _ Codec = jsonCodec{}

// Marshal encodes v with json.Marshal.
func (jsonCodec) Marshal(v any) ([]byte, error) { return json.Marshal(v) }

// Unmarshal decodes data into v with json.Unmarshal.
func (jsonCodec) Unmarshal(data []byte, v any) error { return json.Unmarshal(data, v) }
func DomainObject[I any, O any](next func(context.Context, I) (O, error), c Codec) func(ctx context.Context, n []byte) ([]byte, error) { | |
return func(ctx context.Context, n []byte) ([]byte, error) { | |
in := new(I) | |
err := c.Unmarshal(n, in) | |
if err != nil { | |
return nil, fmt.Errorf("codec unmarshal failure: %w", err) | |
} | |
ret, err := next(ctx, *in) | |
if err != nil { | |
return nil, fmt.Errorf("child failure: %w", err) | |
} | |
data, err := c.Marshal(ret) | |
if err != nil { | |
return nil, fmt.Errorf("codec marshal failure: %w", err) | |
} | |
return data, nil | |
} | |
} | |
func DynamoDB(next func(context.Context, events.DynamoDBEventRecord) ([]byte, error)) func(ctx context.Context, event events.DynamoDBEvent) ([]byte, error) { | |
return func(ctx context.Context, event events.DynamoDBEvent) ([]byte, error) { | |
var res []byte | |
for _, r := range event.Records { | |
d, err := next(ctx, r) | |
if err != nil { | |
return nil, fmt.Errorf("child failure: %w", err) | |
} | |
res = append(res, d...) | |
} | |
return res, nil | |
} | |
} | |
func S3Notification(next func(context.Context, events.S3EventRecord) ([]byte, error)) func(ctx context.Context, event events.S3Event) ([]byte, error) { | |
return func(ctx context.Context, event events.S3Event) ([]byte, error) { | |
var res []byte | |
for _, r := range event.Records { | |
d, err := next(ctx, r) | |
if err != nil { | |
return nil, fmt.Errorf("child failure: %w", err) | |
} | |
res = append(res, d...) | |
} | |
return res, nil | |
} | |
} | |
func S3FetchReader(next func(context.Context, io.Reader) ([]byte, error)) func(ctx context.Context, event events.S3EventRecord) ([]byte, error) { | |
return func(ctx context.Context, event events.S3EventRecord) ([]byte, error) { | |
// Fetch from S3, and get reader. | |
s3Reader := iotest.ErrReader(io.EOF) | |
return next(ctx, s3Reader) | |
} | |
} | |
// ReaderToBytes adapts next, which works on a complete payload, into a handler
// that accepts an io.Reader. The reader is drained fully before next is called.
// If reading fails, the error is wrapped and next is not invoked.
//
// NOTE: ioutil.ReadAll is a deprecated alias of io.ReadAll; switch to io.ReadAll
// together with removing the "io/ioutil" import (this is its only use).
func ReaderToBytes(next func(context.Context, []byte) ([]byte, error)) func(ctx context.Context, r io.Reader) ([]byte, error) {
	return func(ctx context.Context, r io.Reader) ([]byte, error) {
		payload, err := ioutil.ReadAll(r)
		if err == nil {
			return next(ctx, payload)
		}
		return nil, fmt.Errorf("reader to bytes: failure to read all: %w", err)
	}
}
func APIGatewayProxyToHandler(handler func(http.ResponseWriter, *http.Request)) func(ctx context.Context, event events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { | |
return func(ctx context.Context, event events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { | |
r := convertAPIGatewayProxyRequestToGoHttpRequest(ctx, event) | |
w := httptest.NewRecorder() | |
handler(w, r) | |
return convertGoHttpResponseToAPIGatewayProxyResponse(w) | |
} | |
} | |
func convertAPIGatewayProxyRequestToGoHttpRequest(ctx context.Context, event events.APIGatewayProxyRequest) *http.Request { | |
r := new(http.Request) | |
r.Body = io.NopCloser(strings.NewReader(event.Body)) | |
r.Method = event.HTTPMethod | |
// need to canonocialise these | |
r.Header = event.MultiValueHeaders | |
r.ContentLength = int64(len(event.Body)) | |
// etc | |
r = r.WithContext(ctx) | |
return r | |
} | |
func convertGoHttpResponseToAPIGatewayProxyResponse(w *httptest.ResponseRecorder) (events.APIGatewayProxyResponse, error) { | |
r := w.Result() | |
var data []byte | |
if r.Body != nil { | |
d, err := io.ReadAll(r.Body) | |
if err != nil { | |
return events.APIGatewayProxyResponse{}, err | |
} | |
if err := r.Body.Close(); err != nil { | |
return events.APIGatewayProxyResponse{}, err | |
} | |
data = d | |
if len(d) > 6*1024*1024 { | |
return events.APIGatewayProxyResponse{}, fmt.Errorf("response too large: %d > %d", len(d), 6*1024*1024) | |
} | |
} | |
return events.APIGatewayProxyResponse{ | |
StatusCode: r.StatusCode, | |
MultiValueHeaders: r.Header, | |
Body: string(data), | |
}, nil | |
} | |
// Example domain types used by Test to demonstrate the wrapper chain.
type (
	// DomainInput is the example request payload.
	DomainInput struct {
		Account int
		Delta   float64
	}

	// DomainOutput is the example response payload.
	DomainOutput struct {
		Success bool
	}
)
func Test(t *testing.T) { | |
myBusinessLogic := func(_ context.Context, _ DomainInput) (DomainOutput, error) { | |
return DomainOutput{Success: true}, nil | |
} | |
t.Run("example usage", func(t *testing.T) { | |
// Set up example record that would be passed by lambda.Start() | |
data, _ := json.Marshal(DomainInput{Account: 1, Delta: -10}) | |
data, _ = json.Marshal(events.SQSEvent{Records: []events.SQSMessage{{Body: string(data)}}}) | |
snsEvent := events.SNSEvent{Records: []events.SNSEventRecord{{SNS: events.SNSEntity{Message: string(data)}}}} | |
// Define our chain of expected AWS events | |
wrapChain := SNS(SQS(DomainObject(myBusinessLogic, JSONCodec))) | |
// Call it like lambda.Start() would | |
data, err := wrapChain(context.Background(), snsEvent) | |
assert.NoError(t, err) | |
// Print response. | |
fmt.Println(string(data)) | |
// lambda.StartWithContext(context.Background(), SNS(SQS(DomainObject(myBusinessLogic)))) | |
_ = SQS(S3Notification(S3FetchReader(ReaderToBytes(DomainObject(myBusinessLogic, JSONCodec))))) | |
_ = APIGatewayProxyToHandler(func(writer http.ResponseWriter, request *http.Request) { | |
}) | |
}) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment