Skip to content

Instantly share code, notes, and snippets.

@korc
Last active April 21, 2020 02:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save korc/66fe38ad3f285c691648af3d91044cef to your computer and use it in GitHub Desktop.
Save korc/66fe38ad3f285c691648af3d91044cef to your computer and use it in GitHub Desktop.
S3 proxy test
package main
import (
"bytes"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/http/httptrace"
"net/url"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/signer/v4"
)
// Environment variables that supply the default AWS credentials; they are
// also re-exported by main so the environment-based credential provider
// sees any -key/-secret overrides.
const (
	envKey    = "AWS_ACCESS_KEY_ID"
	envSecret = "AWS_SECRET_ACCESS_KEY"
)

// Request-signing parameters.
const (
	// preSignExpireDuration is the expiry handed to the signer when the
	// pre-signed URL that gets logged for each request is generated.
	preSignExpireDuration = 5 * time.Minute

	// amzDateFormat is the time layout used for X-Amz-Date header values
	// and for the -sign-time flag.
	amzDateFormat = "20060102T150405Z"
)
// main runs a local HTTP proxy that re-signs every incoming request with the
// AWS v4 signer and forwards it to the host given by -remote. The request, a
// pre-signed variant of it, the outgoing wire format and the upstream response
// are all logged, which makes this useful for debugging S3-compatible services.
func main() {
	listenAddr := flag.String("listen", ":8080", "Listen address")
	remoteAddr := flag.String("remote", "", "Remote URL")
	awsID := flag.String("key", os.Getenv(envKey), "override $"+envKey)
	awsSecret := flag.String("secret", os.Getenv(envSecret), "override $"+envSecret)
	caFile := flag.String("cafile", "", "CA bundle")
	region := flag.String("region", "us-east-1", "Default region")
	service := flag.String("service", "s3", "default service")
	fixedSignTimeFlag := flag.String("sign-time", "", fmt.Sprintf(
		"Fixed signing time in UTC / "+amzDateFormat+" format (now: %s)", time.Now().UTC().Format(amzDateFormat)))
	flag.Parse()

	if *remoteAddr == "" {
		log.Fatal("Need -remote <url> option")
	}
	if *awsID == "" || *awsSecret == "" {
		log.Fatal("Need to set -key or $" + envKey + ", and -secret or $" + envSecret)
	}
	// The signer below reads credentials from the environment, so push any
	// flag overrides into it.
	if err := os.Setenv(envKey, *awsID); err != nil {
		log.Fatal("Cannot set environment: ", err)
	}
	if err := os.Setenv(envSecret, *awsSecret); err != nil {
		log.Fatal("Cannot set environment: ", err)
	}

	remoteURL, err := url.Parse(*remoteAddr)
	if err != nil {
		log.Fatal("Cannot parse remote URL: ", err)
	}
	// Only Scheme and Host are used for forwarding; without them every
	// proxied request would target an empty host.
	if remoteURL.Scheme == "" || remoteURL.Host == "" {
		log.Fatalf("Remote URL %q needs a scheme and host (e.g. https://host:port)", *remoteAddr)
	}

	var fixedSignTime time.Time
	if *fixedSignTimeFlag != "" {
		// Layout has no zone indicator, so Parse already yields a UTC time.
		if fixedSignTime, err = time.Parse(amzDateFormat, *fixedSignTimeFlag); err != nil {
			log.Fatalf("Cannot parse time %#v as "+amzDateFormat+": %s", *fixedSignTimeFlag, err)
		}
	}

	signer := v4.NewSigner(credentials.NewEnvCredentials())

	// Use a private copy of the default transport so the custom CA pool does
	// not leak into http.DefaultTransport / http.DefaultClient.
	tr := http.DefaultTransport.(*http.Transport).Clone()
	if *caFile != "" {
		caData, err := ioutil.ReadFile(*caFile)
		if err != nil {
			log.Fatal("Cannot read CA certs: ", err)
		}
		if tr.TLSClientConfig == nil {
			tr.TLSClientConfig = &tls.Config{}
		}
		tr.TLSClientConfig.RootCAs = x509.NewCertPool()
		if !tr.TLSClientConfig.RootCAs.AppendCertsFromPEM(caData) {
			log.Fatalf("Could not add CA cert data from %#v", *caFile)
		}
	}
	client := &http.Client{Transport: tr}

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		log.Printf("Got request: %s", r.RequestURI)
		// RequestURI must be empty on client requests; r is re-sent upstream.
		r.RequestURI = ""

		// All errors below are handler-local (declared with :=) so concurrent
		// requests never share an error variable.
		bodyReader, err := seekableBody(r.Body)
		if err != nil {
			reportError(w, http.StatusBadRequest, "Could not read body", err)
			return
		}

		r.URL.Scheme = remoteURL.Scheme
		r.URL.Host = remoteURL.Host
		r.Host = remoteURL.Host

		signTime, err := resolveSignTime(fixedSignTime, r.Header.Get("X-Amz-Date"))
		if err != nil {
			reportError(w, http.StatusBadRequest, "Bad time format", err)
			return
		}

		// Presign a copy purely for logging; the forwarded request uses Sign.
		pr := r.Clone(r.Context())
		if preSignHdr, err := signer.Presign(pr, nil, *service, *region, preSignExpireDuration, signTime); err != nil {
			log.Printf("Cannot do presign: %s", err)
		} else {
			log.Printf("Pre-Signed header/request: %#v, %s", preSignHdr, pr.URL)
		}

		signHdr, err := signer.Sign(r, bodyReader, *service, *region, signTime)
		if err != nil {
			reportError(w, http.StatusInternalServerError, "Could not sign", err)
			return
		}
		log.Printf("Signed headers: %#v", signHdr)

		var dump bytes.Buffer
		if err := r.Write(&dump); err != nil {
			log.Printf("Could not write to buffer: %s", err)
		}
		if bodyReader != nil {
			// r.Write drained the body; rewind it so it is sent again upstream.
			if _, err := bodyReader.Seek(0, io.SeekStart); err != nil {
				log.Printf("Could not seek body: %s", err)
			}
			r.Body = ioutil.NopCloser(bodyReader)
		}
		log.Printf("Sending request to %s:\n%s\n", r.URL.Host, dump.String())

		resp, err := client.Do(r.Clone(httptrace.WithClientTrace(r.Context(), newClientTrace())))
		if err != nil {
			reportError(w, http.StatusBadGateway, "Could not connect to remote", err)
			return
		}
		// Close so the transport can release or reuse the connection.
		defer resp.Body.Close()

		// resp.Status already includes the numeric code (e.g. "200 OK").
		log.Printf("Response status %s", resp.Status)
		respBody, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			reportError(w, http.StatusInternalServerError, "Cannot read response body", err)
			return
		}

		wHdr := w.Header()
		for name, values := range resp.Header {
			wHdr[name] = values
			log.Printf("Response header %s: %s", name, values)
		}
		w.WriteHeader(resp.StatusCode)
		log.Printf("Response body: %#v", string(respBody))
		if _, err := w.Write(respBody); err != nil {
			log.Printf("Could not send response body to client: %s", err)
		}
	})

	if err := http.ListenAndServe(*listenAddr, nil); err != nil {
		log.Fatal("Could not listen: ", err)
	}
}

// reportError logs message together with err, then replies to the client
// with statusCode and message as the response body.
func reportError(w http.ResponseWriter, statusCode int, message string, err error) {
	log.Printf("Error: %s: %s", message, err)
	w.WriteHeader(statusCode)
	if _, err := w.Write([]byte(message)); err != nil {
		log.Printf("Could not send error to client: %s", err)
	}
}

// seekableBody returns body as an io.ReadSeeker so it can be hashed by the
// signer and rewound before forwarding. A nil body yields a nil reader, an
// already-seekable body is returned unchanged, and any other body is read
// fully into memory.
func seekableBody(body io.ReadCloser) (io.ReadSeeker, error) {
	if body == nil {
		return nil, nil
	}
	if rs, ok := body.(io.ReadSeeker); ok {
		return rs, nil
	}
	data, err := ioutil.ReadAll(body)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(data), nil
}

// resolveSignTime chooses the signing time. A non-empty amzDate (the
// client's X-Amz-Date header) takes precedence, then fixed if it is set,
// and otherwise the current UTC time is used.
func resolveSignTime(fixed time.Time, amzDate string) (time.Time, error) {
	if amzDate != "" {
		return time.Parse(amzDateFormat, amzDate)
	}
	if !fixed.IsZero() {
		return fixed, nil
	}
	return time.Now().UTC(), nil
}

// newClientTrace returns trace hooks that log DNS resolution, every header
// field written, and completion of writing the outgoing request.
func newClientTrace() *httptrace.ClientTrace {
	return &httptrace.ClientTrace{
		WroteRequest: func(info httptrace.WroteRequestInfo) {
			log.Printf("Wrote request (error: %v)", info.Err)
		},
		DNSDone: func(info httptrace.DNSDoneInfo) {
			log.Printf("DNS done: %s (error: %v)", info.Addrs, info.Err)
		},
		WroteHeaderField: func(key string, value []string) {
			for _, v := range value {
				log.Printf("Wrote header %s: %s", key, v)
			}
		},
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment