Skip to content

Instantly share code, notes, and snippets.

@RyanJarv
Last active July 4, 2021 03:44
Show Gist options
  • Save RyanJarv/03354c9997854772f52c51dc18e7d890 to your computer and use it in GitHub Desktop.
Save RyanJarv/03354c9997854772f52c51dc18e7d890 to your computer and use it in GitHub Desktop.
package main
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"os"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/smithy-go/middleware"
	awshttp "github.com/aws/smithy-go/transport/http"
)
// TmpFix is a deserialize-step middleware that works around the S3 client
// resetting its TCP connection after every request. It buffers the whole raw
// HTTP response body in memory and replaces the body with an in-memory reader.
//
// NOTE(review): the root cause is unconfirmed. Presumably the body was
// otherwise left partially unread or unclosed, which stops net/http from
// reusing the connection. Verify before relying on this.
type TmpFix struct{}

// ID returns the identifier under which TmpFix is registered in a middleware stack.
func (*TmpFix) ID() string { return "TmpFix" }
func (r *TmpFix) HandleDeserialize(
ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
) (
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
out, metadata, err = next.HandleDeserialize(ctx, in)
resp, ok := out.RawResponse.(*awshttp.Response)
if !ok {
return out, metadata, fmt.Errorf("unknown transport type %T", out.RawResponse)
}
var buf bytes.Buffer
if _, err = buf.ReadFrom(resp.Body); err != nil {
//save = nil
return middleware.DeserializeOutput{}, middleware.Metadata{}, err
}
resp.Body = io.NopCloser(bytes.NewReader(buf.Bytes()))
return out, metadata, err
}
// main prints the SDK version and runs the connection-reuse comparison.
// On failure it writes the error to stderr and exits with status 1, so shells
// and CI can detect the failure.
func main() {
	fmt.Println("SDK Version:", aws.SDKVersion)
	if err := run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}
func run() error {
cfg, err := config.LoadDefaultConfig(context.Background())
if err != nil {
return err
}
s3client := s3.NewFromConfig(cfg)
fmt.Println("running default GetObjectAcl")
start := time.Now()
for i := 0; i < 30; i++ {
_, err := s3client.GetObjectAcl(context.Background(), &s3.GetObjectAclInput{
Bucket: aws.String("golang-sdkv2-tcp-reset-bug"),
Key: aws.String("test"),
})
if err != nil {
return err
}
}
end := time.Now()
fmt.Println("Seconds elapsed: ", end.Second() - start.Second())
fmt.Println("running GetObjectAcl with fix")
start = time.Now()
for i := 0; i < 30; i++ {
_, err := s3client.GetObjectAcl(context.Background(), &s3.GetObjectAclInput{
Bucket: aws.String("golang-sdkv2-tcp-reset-bug"),
Key: aws.String("test"),
}, func(opts *s3.Options) {
opts.APIOptions = []func(*middleware.Stack) error{
func(stack *middleware.Stack) error {
return stack.Deserialize.Add(&TmpFix{}, middleware.After)
},
}
})
if err != nil {
return err
}
}
end = time.Now()
fmt.Println("Seconds elapsed: ", end.Second() - start.Second())
return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment