Skip to content

Instantly share code, notes, and snippets.

@hbjydev
Created April 10, 2024 21:34
Show Gist options
  • Save hbjydev/f42de0b6446dc886a3e7f6aee945615c to your computer and use it in GitHub Desktop.
Save hbjydev/f42de0b6446dc886a3e7f6aee945615c to your computer and use it in GitHub Desktop.
package tmp
import (
"context"
"encoding/json"
"fmt"
"os"
"reflect"
"github.com/go-kit/log"
)
// errorHandler returns a handler that ignores both of its arguments and always
// fails with e (returning a nil value). WrapHandler returns it in place of the
// real handler when validation fails, so the problem is reported on invocation.
func errorHandler(e error) func(context.Context, interface{}) (interface{}, error) {
	failing := func(_ context.Context, _ interface{}) (interface{}, error) {
		return nil, e
	}
	return failing
}
// WrapHandler adapts handlerFunc into a function with a fixed signature that
// can be registered with a Lambda-style runtime:
//
//   - func(context.Context) (interface{}, error) when handlerFunc takes no
//     event (no parameters, or only a context.Context), or
//   - func(context.Context, interface{}) (interface{}, error) otherwise.
//
// handlerFunc may take at most two parameters. When it takes two, the first
// must implement context.Context. In the event-taking form, the incoming
// payload is marshalled to JSON and decoded into a new value of handlerFunc's
// last parameter type before the handler is invoked. JSON errors are returned
// unchanged.
//
// If handlerFunc is nil, not a function, or has an unsupported parameter list,
// the returned function always fails with a descriptive error.
func WrapHandler(handlerFunc interface{}) interface{} {
	if handlerFunc == nil {
		return errorHandler(fmt.Errorf("handler is nil"))
	}

	fnType := reflect.TypeOf(handlerFunc)
	if kind := fnType.Kind(); kind != reflect.Func {
		return errorHandler(fmt.Errorf("handler kind %s is not %s", kind, reflect.Func))
	}

	takesContext, err := validateArguments(fnType)
	if err != nil {
		return errorHandler(err)
	}

	// validateArguments guarantees arity <= 2, and a 2-arity handler takes a
	// context first, so an event parameter exists exactly in these cases.
	arity := fnType.NumIn()
	expectsEvent := arity == 2 || (arity == 1 && !takesContext)

	if !expectsEvent {
		return func(ctx context.Context) (interface{}, error) {
			// A typed nil pointer stands in for the absent event.
			placeholder := reflect.ValueOf((*interface{})(nil))
			return wrap(ctx, handlerFunc, []byte{}, placeholder, takesContext)
		}
	}

	eventType := fnType.In(arity - 1)
	return func(ctx context.Context, payload interface{}) (interface{}, error) {
		// Round-trip through JSON to convert the generic payload into the
		// handler's concrete event type.
		raw, err := json.Marshal(payload)
		if err != nil {
			return nil, err
		}
		decoded := reflect.New(eventType)
		if err := json.Unmarshal(raw, decoded.Interface()); err != nil {
			return nil, err
		}
		return wrap(ctx, handlerFunc, raw, decoded.Elem(), takesContext)
	}
}
// validateArguments checks the parameter list of the function type handler.
// It rejects more than two parameters, and it rejects two parameters whose
// first does not implement context.Context.
// On success it reports whether the first parameter implements
// context.Context (false for a parameterless handler).
func validateArguments(handler reflect.Type) (bool, error) {
	argc := handler.NumIn()
	if argc > 2 {
		return false, fmt.Errorf("handlers may not take more than two arguments, but handler takes %d", argc)
	}
	if argc == 0 {
		return false, nil
	}

	ctxIface := reflect.TypeOf((*context.Context)(nil)).Elem()
	first := handler.In(0)
	takesCtx := first.Implements(ctxIface)
	if argc == 2 && !takesCtx {
		return false, fmt.Errorf("handler takes two arguments, but the first is not Context. got %s", first.Kind())
	}
	return takesCtx, nil
}
// eventExists reports whether event holds a usable value. It is false for a
// nil interface and for nil values of nilable kinds (chan, func, map, pointer,
// unsafe pointer, interface, slice), and true for everything else.
func eventExists(event interface{}) bool {
	if event == nil {
		return false
	}
	// reflect.Value.IsNil can only be called on Values of certain Kinds;
	// other Kinds panic rather than return false, so filter by Kind first.
	switch reflect.TypeOf(event).Kind() {
	case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
		return !reflect.ValueOf(event).IsNil()
	}
	return true
}

// wrapper returns a function that invokes handlerFunc and returns its raw
// results. handlerFunc must be a func already checked by validateArguments.
// The function passes ctx first when takesContext is true. It then passes
// event if the handler's signature has an event parameter. eventJSON is
// accepted for signature compatibility but is not used.
//
// Whether to pass the event is decided from the handler's arity, not from the
// event's value. A nil event can arise, for example, from a JSON null decoded
// into a pointer, map, slice or interface parameter. In that case the zero
// value of the declared parameter type is passed. Omitting the argument would
// make reflect.Value.Call panic with too few input arguments.
func wrapper(handlerFunc interface{}) func(ctx context.Context, eventJSON []byte, event interface{}, takesContext bool) []reflect.Value {
	return func(ctx context.Context, eventJSON []byte, event interface{}, takesContext bool) []reflect.Value {
		handler := reflect.ValueOf(handlerFunc)
		handlerType := handler.Type()

		args := make([]reflect.Value, 0, handlerType.NumIn())
		if takesContext {
			args = append(args, reflect.ValueOf(ctx))
		}
		// Any parameter left after the optional context is the event.
		if len(args) < handlerType.NumIn() {
			if eventExists(event) {
				args = append(args, reflect.ValueOf(event))
			} else {
				args = append(args, reflect.Zero(handlerType.In(len(args))))
			}
		}
		return handler.Call(args)
	}
}
func wrap(ctx context.Context, handlerFunc interface{}, eventJSON []byte, event reflect.Value, takesContext bool) (interface{}, error) {
wrappedLambdaHandler := reflect.ValueOf(wrapper(handlerFunc))
ctx = SetLogger(ctx, log.NewLogfmtLogger(os.Stderr))
argsWrapped := []reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(eventJSON), event, reflect.ValueOf(takesContext)}
response := wrappedLambdaHandler.Call(argsWrapped)[0].Interface().([]reflect.Value)
var err error
if len(response) > 0 {
if errVal, ok := response[len(response)-1].Interface().(error); ok {
err = errVal
}
}
var val interface{}
if len(response) > 1 {
val = response[0].Interface()
}
return val, err
}
package tmp_test
import (
"context"
"fmt"
"os"
"reflect"
"testing"
"tmp"
"github.com/aws/aws-lambda-go/lambdacontext"
"github.com/stretchr/testify/assert"
)
// mockContext is a background context carrying an empty Lambda invocation
// context. It is passed as the first argument when invoking wrapped handlers.
var mockContext = lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{})

// TestLambdaHandlerSignatures wraps each test handler with tmp.WrapHandler,
// invokes the result through reflection, and checks both the returned value
// and the returned error.
func TestLambdaHandlerSignatures(t *testing.T) {
	setEnvVars()
	testCases := []struct {
		name     string
		handler  interface{}
		expected error
		want     interface{}
		args     []reflect.Value
	}{
		{
			name:     "simple handler",
			expected: nil,
			want:     "Hello, ",
			handler: func(ctx context.Context, in string) (string, error) {
				l := tmp.GetLogger(ctx)
				l.Log("hello", "world")
				return fmt.Sprintf("Hello, %s", in), nil
			},
			args: []reflect.Value{reflect.ValueOf(mockContext), reflect.ValueOf("")},
		},
	}
	for i, testCase := range testCases {
		testCase := testCase
		t.Run(fmt.Sprintf("testCase[%d] %s", i, testCase.name), func(t *testing.T) {
			lambdaHandler := tmp.WrapHandler(testCase.handler)
			handler := reflect.ValueOf(lambdaHandler)
			resp := handler.Call(testCase.args)
			// Stop on a wrong result count: indexing resp below would panic.
			if !assert.Len(t, resp, 2) {
				return
			}
			assert.Equal(t, testCase.want, resp[0].Interface())
			assert.Equal(t, testCase.expected, resp[1].Interface())
		})
	}
}
// setEnvVars sets fixed values for the AWS Lambda environment variables
// listed below (function name, region, version, log stream, memory size and
// X-Ray trace ID). Values are set in the listed order.
func setEnvVars() {
	lambdaEnv := [...][2]string{
		{"AWS_LAMBDA_FUNCTION_NAME", "testFunction"},
		{"AWS_REGION", "us-texas-1"},
		{"AWS_LAMBDA_FUNCTION_VERSION", "$LATEST"},
		{"AWS_LAMBDA_LOG_STREAM_NAME", "2023/01/01/[$LATEST]5d1edb9e525d486696cf01a3503487bc"},
		{"AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128"},
		{"_X_AMZN_TRACE_ID", "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1"},
	}
	for _, kv := range lambdaEnv {
		// The keys are fixed, well-formed names, so a failure is not expected;
		// errors are deliberately ignored as in best-effort test setup.
		_ = os.Setenv(kv[0], kv[1])
	}
}
package tmp
import (
"context"
"os"
"github.com/go-kit/log"
)
// loggerContextKeyType is the unexported type of the context key under which
// SetLogger stores a logger. Because it is private, other packages cannot
// construct this key, which avoids key collisions.
type loggerContextKeyType int

// loggerContextKey is the context key used by SetLogger and GetLogger.
const loggerContextKey = loggerContextKeyType(0)
// SetLogger returns a child of ctx that carries logger; retrieve it with
// GetLogger. ctx itself is not modified.
func SetLogger(ctx context.Context, logger log.Logger) context.Context {
	withLogger := context.WithValue(ctx, loggerContextKey, logger)
	return withLogger
}
// GetLogger returns the logger stored in ctx by SetLogger. If ctx carries no
// logger (including when a nil logger was stored), it returns a new logfmt
// logger writing to os.Stderr.
func GetLogger(ctx context.Context) log.Logger {
	if stored, found := ctx.Value(loggerContextKey).(log.Logger); found {
		return stored
	}
	return log.NewLogfmtLogger(os.Stderr)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment