Skip to content

Instantly share code, notes, and snippets.

@keymon
Created October 4, 2023 14:50
Show Gist options
  • Save keymon/e9b91d3644254bbc8dc587ca57d9eece to your computer and use it in GitHub Desktop.
Save keymon/e9b91d3644254bbc8dc587ca57d9eece to your computer and use it in GitHub Desktop.
A Go function that updates a temporary pgpass file and sets the PGPASSFILE environment variable. It can be adapted to use pgpassfile.
package pgfile
import (
"context"
"fmt"
"io/ioutil"
"log"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/google/renameio"
)
// StartPgPassFileRdsIamUpdater creates a temporary pgpass file holding an RDS
// IAM authentication token for rdsUser at rdsHost:rdsPort, points the
// PGPASSFILE environment variable at it, and starts a goroutine that
// regenerates the token every 5 minutes until ctx is cancelled, at which point
// the goroutine removes the file and exits.
//
// AWS credentials come from assuming rdsRoleArn with the web identity token
// read from rdsWebIdentityTokenFile (e.g. IRSA), using STS in rdsRegion.
//
// The first token is written synchronously. If that (or setting PGPASSFILE)
// fails, the temporary file is removed, the error is returned and no
// goroutine is started. Errors from later refreshes are only logged.
//
// WARNING: PGPASSFILE is process-wide state. Every call creates its own file
// and overwrites PGPASSFILE, so only the most recent call's file is the one
// referenced by the variable; call this at most once per program.
func StartPgPassFileRdsIamUpdater(
	ctx context.Context,
	rdsHost string,
	rdsPort string,
	rdsRegion string,
	rdsUser string,
	rdsRoleArn string,
	rdsWebIdentityTokenFile string,
) error {
	const refreshInterval = 5 * time.Minute

	// An empty dir makes TempFile use the default temp directory (os.TempDir);
	// a relative "tmp" would only work if ./tmp happened to exist.
	file, err := ioutil.TempFile("", "pgpass-")
	if err != nil {
		return fmt.Errorf("creating pgpass file: %w", err)
	}
	path := file.Name()
	// Only the path is needed: renameio replaces the file by renaming a new
	// one over it, so keeping this descriptor open would just leak it.
	if err := file.Close(); err != nil {
		_ = os.Remove(path)
		return fmt.Errorf("closing pgpass file %q: %w", path, err)
	}

	cfg := aws.Config{
		Region: rdsRegion,
	}
	stsClient := sts.NewFromConfig(cfg)
	// Cache the assumed-role credentials so a refresh does not have to call
	// STS again while the previously obtained credentials are still valid.
	credsProvider := aws.NewCredentialsCache(stscreds.NewWebIdentityRoleProvider(
		stsClient, rdsRoleArn, stscreds.IdentityTokenFile(rdsWebIdentityTokenFile),
	))

	// pgpass fields are ':'-separated; a literal ':' or '\' inside any field
	// must be escaped with '\'. The generated token can contain ':'.
	escape := strings.NewReplacer(`\`, `\\`, `:`, `\:`).Replace

	// updateCreds builds a fresh auth token and atomically rewrites the file.
	updateCreds := func() error {
		password, err := auth.BuildAuthToken(
			ctx,
			fmt.Sprintf("%s:%s", rdsHost, rdsPort), // database endpoint (with port)
			rdsRegion,
			rdsUser,
			credsProvider,
		)
		if err != nil {
			return fmt.Errorf("building RDS IAM auth token: %w", err)
		}
		// hostname:port:database:username:password, "*" matching any database.
		// Never print this line: it contains the auth token.
		line := fmt.Sprintf("%s:%s:*:%s:%s\n",
			escape(rdsHost), escape(rdsPort), escape(rdsUser), escape(password))
		// renameio writes a temp file and renames it over path, so readers
		// never observe a partially written file.
		if err := renameio.WriteFile(path, []byte(line), 0600); err != nil {
			return fmt.Errorf("writing pgpass file %q: %w", path, err)
		}
		return nil
	}

	if err := updateCreds(); err != nil {
		_ = os.Remove(path) // best-effort; the update error is what matters
		return err
	}

	// WARNING: this is global, process-wide state.
	if err := os.Setenv("PGPASSFILE", path); err != nil {
		_ = os.Remove(path)
		return fmt.Errorf("setting PGPASSFILE: %w", err)
	}

	ticker := time.NewTicker(refreshInterval)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				// Best-effort cleanup; there is no caller left to report to.
				_ = os.Remove(path)
				return
			case <-ticker.C:
				// TODO: add retry/backoff; on failure the previous token stays
				// in the file until the next tick.
				if err := updateCreds(); err != nil {
					log.Printf("Warning: %s", err)
				}
			}
		}
	}()
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment