Skip to content

Instantly share code, notes, and snippets.

@moonsub-kim
Last active January 6, 2021 08:55
Show Gist options
  • Save moonsub-kim/1e75666ae1a1678ecb68467276d24060 to your computer and use it in GitHub Desktop.
Save moonsub-kim/1e75666ae1a1678ecb68467276d24060 to your computer and use it in GitHub Desktop.
refresh credentials provider
package main
import (
"context"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
// tickDuration is the period of the ticker that NewRefreshCredentials creates
// to pace background credential refreshes.
const tickDuration = 5 * time.Minute
// Expirer is implemented by anything that can expire a credentials.Value on
// demand. RefreshProvider calls Expire after a background refresh succeeds.
type Expirer interface {
	// Expire expires the credentials.Value held by the implementation.
	Expire()
}
// RefreshProvider wraps a credentials.ProviderWithContext and serves a cached
// copy of its credentials. A background goroutine, started on the first
// retrieval, re-fetches the credentials on every tick of Ticker.
//
// Reference: https://github.com/aws/aws-sdk-go/issues/561#issuecomment-185974563
type RefreshProvider struct {
	// ProviderWithContext is the wrapped provider that actually fetches
	// credentials and decides whether they are expired.
	credentials.ProviderWithContext

	// Ticker paces the background refresh loop. It is stopped when that
	// loop's goroutine exits.
	Ticker *time.Ticker

	// creds and err hold the outcome of the most recent retrieval.
	creds credentials.Value
	err   error

	// mux is write-locked while refresh updates creds/err and read-locked
	// while they are returned or the provider's expiry is checked.
	mux sync.RWMutex

	// initRunner makes the initial retrieval and the start of the refresh
	// goroutine happen exactly once.
	initRunner sync.Once

	// Expirer is notified each time a background refresh stores new
	// credentials.
	Expirer Expirer
}
// Retrieve returns the cached credentials. It is shorthand for
// RetrieveWithContext with a background context.
func (p *RefreshProvider) Retrieve() (credentials.Value, error) {
	ctx := context.Background()
	return p.RetrieveWithContext(ctx)
}
// RetrieveWithContext returns the cached credentials together with the error,
// if any, recorded by the most recent retrieval.
//
// The first call performs the initial retrieval from the wrapped provider
// using ctx and starts the background refresh goroutine; both happen exactly
// once. Later calls do not use ctx and only read the cache.
func (p *RefreshProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	bootstrap := func() {
		p.creds, p.err = p.ProviderWithContext.RetrieveWithContext(ctx)

		go func() {
			defer func() {
				if r := recover(); r != nil {
					// A panic in the refresh loop is swallowed so it cannot
					// take the process down. Hook a logger in here, e.g.:
					//   p.logger.Error("periodicRefresh recovered", zap.Any("obj", r))
				}
				// The loop is over, normally or by panic: release the ticker.
				p.Ticker.Stop()
			}()
			p.periodicRefresh()
		}()
	}
	p.initRunner.Do(bootstrap)

	p.mux.RLock()
	cached, cachedErr := p.creds, p.err
	p.mux.RUnlock()
	return cached, cachedErr
}
// IsExpired reports whether the wrapped provider considers its credentials
// no longer valid. The provider is queried while holding the read lock.
func (p *RefreshProvider) IsExpired() bool {
	p.mux.RLock()
	defer p.mux.RUnlock()

	expired := p.ProviderWithContext.IsExpired()
	return expired
}
// periodicRefresh waits for ticks on p.Ticker.C and runs refresh on each one.
// Whenever refresh reports that new credentials were stored, it calls
// p.Expirer.Expire. It returns only if the ticker channel is closed.
//
// NOTE(review): time.Ticker.Stop does not close the channel, so nothing
// visible in this file ends the loop — confirm whether a shutdown path is
// needed.
func (p *RefreshProvider) periodicRefresh() {
	for {
		if _, open := <-p.Ticker.C; !open {
			return
		}
		if !p.refresh() {
			continue
		}
		// Expire must run only after refresh has released p.mux. The
		// credentials object takes its own lock and then calls Retrieve,
		// so calling Expire while holding p.mux could deadlock.
		p.Expirer.Expire()
	}
}
// refresh asks the wrapped provider for fresh credentials and updates the
// cache while holding the write lock.
//
// It returns true only when new credentials were stored, which tells the
// caller to expire whatever was handed out before. On failure the cached
// credentials are left untouched and false is returned; the error is recorded
// only once the provider reports itself expired, so a transient failure does
// not invalidate credentials that are still usable.
func (p *RefreshProvider) refresh() bool {
	p.mux.Lock()
	defer p.mux.Unlock()

	creds, err := p.ProviderWithContext.Retrieve()
	if err != nil {
		// Probably want to log err here.
		// Never store the value returned alongside an error: it is not
		// valid credentials, and caching it would replace good ones.
		if p.ProviderWithContext.IsExpired() {
			p.err = err
		}
		return false
	}

	p.creds = creds
	p.err = nil
	return true
}
// NewRefreshCredentials returns credentials backed by a RefreshProvider that
// re-fetches from provider every tickDuration in a background goroutine.
//
// The initial retrieval is performed before returning. Its result, including
// any error, is cached in the provider and surfaces on the first use of the
// returned credentials.
//
// The expirer argument is currently unused; it is kept so the signature stays
// compatible with existing callers.
func NewRefreshCredentials(
	provider credentials.ProviderWithContext,
	expirer credentials.Expirer,
) *credentials.Credentials {
	rp := &RefreshProvider{
		ProviderWithContext: provider,
		Ticker:              time.NewTicker(tickDuration),
	}

	// Wire up Expirer before the refresh goroutine exists. Assigning it after
	// the goroutine has started would race with the goroutine's read of
	// rp.Expirer.
	creds := credentials.NewCredentials(rp)
	rp.Expirer = creds

	// Perform the initial retrieval and start the refresh goroutine. The
	// outcome is deliberately not inspected here: it is cached in rp.
	_, _ = rp.Retrieve()

	return creds
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment