Created
November 22, 2020 01:27
-
-
Save philipmuir/0b3f6b4a3cc769cb8fbf870ee70663fd to your computer and use it in GitHub Desktop.
Vault Transit package with AWS login [WIP]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import ( | |
"encoding/base64" | |
"encoding/json" | |
"fmt" | |
"io/ioutil" | |
"github.com/aws/aws-sdk-go/aws/session" | |
"github.com/aws/aws-sdk-go/service/sts" | |
vaultapi "github.com/hashicorp/vault/api" | |
) | |
// Names used when talking to Vault's AWS auth backend.
const (
	// vaultAuthHeaderName is the extra HTTP header added to the STS request
	// before it is signed.
	vaultAuthHeaderName = "X-Vault-AWS-IAM-Server-ID"

	// Keys of the login payload written to auth/<provider>/login.
	iamHTTPRequestMethod = "iam_http_request_method"
	iamRequestURL        = "iam_request_url"
	iamRequestHeaders    = "iam_request_headers"
	iamRequestBody       = "iam_request_body"
	role                 = "role"
)
// AuthMethod handles renewing a client token | |
type AuthMethod interface { | |
VaultAuth() (*RenewableToken, error) | |
} | |
// AuthMechanismFromConfig returns a configured AuthMethod | |
func AuthMechanismFromConfig(c *Config) AuthMethod { | |
switch c.AuthMethod { | |
case "aws": | |
return NewAWSAuth(c.AwsConfig) | |
case "static": | |
return NewStaticAuth(c.StaticConfig) | |
default: | |
panic(fmt.Errorf("invalid auth mechanism defined in config: %s", c.AuthMethod)) | |
} | |
} | |
// AWSAuth is a thing | |
type AWSAuth struct { | |
client vaultapi.Client | |
aws *AWSConfig | |
} | |
// NewAWSAuth creates a blah | |
func NewAWSAuth(c *AWSConfig) *AWSAuth { | |
return &AWSAuth{aws: c} | |
} | |
// VaultAuth This code was adapted from Hashicorp Vault: | |
// https://github.com/hashicorp/vault/blob/e2bb2ec3b93a242a167f763684f93df867bb253d/builtin/credential/aws/cli.go#L78 | |
func (a *AWSAuth) VaultAuth() (*RenewableToken, error) { | |
if a.aws.AuthProviderName == "" || a.aws.Role == "" { | |
return nil, fmt.Errorf("you must set the Host, AuthProviderName, and Role config values") | |
} | |
/* | |
We are relying on ENV vars for access key and secret access key existing: | |
credAccessEnvKey = []string{ | |
"AWS_ACCESS_KEY_ID", | |
"AWS_ACCESS_KEY", | |
} | |
credSecretEnvKey = []string{ | |
"AWS_SECRET_ACCESS_KEY", | |
"AWS_SECRET_KEY", | |
*/ | |
stsSvc := sts.New(session.New()) | |
req, _ := stsSvc.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) | |
if a.aws.VaultAuthHeader != "" { | |
// if supplied, and then sign the request including that header | |
req.HTTPRequest.Header.Add(vaultAuthHeaderName, a.aws.VaultAuthHeader) | |
} | |
_ = req.Sign() | |
headers, err := json.Marshal(req.HTTPRequest.Header) | |
if err != nil { | |
return nil, err | |
} | |
body, err := ioutil.ReadAll(req.HTTPRequest.Body) | |
if err != nil { | |
return nil, err | |
} | |
d := make(map[string]interface{}) | |
d[iamHTTPRequestMethod] = req.HTTPRequest.Method | |
d[iamRequestURL] = base64.StdEncoding.EncodeToString([]byte(req.HTTPRequest.URL.String())) | |
d[iamRequestHeaders] = base64.StdEncoding.EncodeToString(headers) | |
d[iamRequestBody] = base64.StdEncoding.EncodeToString(body) | |
d[role] = a.aws.Role | |
resp, err := a.client.Logical().Write(fmt.Sprintf("auth/%s/login", a.aws.AuthProviderName), d) | |
if err != nil { | |
return nil, err | |
} | |
if resp == nil { | |
return nil, fmt.Errorf("got no response from the %s authentication provider", a.aws.AuthProviderName) | |
} | |
return parseToken(resp) | |
} | |
// StaticAuth the static auth mechanism is good for environments where the root token, or a long lived dev token is in use. | |
type StaticAuth struct { | |
token string | |
} | |
// NewStaticAuth is an auth mechanism whose VaultAuth() function always returns a static token whose expiry is 1 year in the future. | |
func NewStaticAuth(c *StaticConfig) *StaticAuth { | |
return &StaticAuth{token: c.Token} | |
} | |
// VaultAuth returns a RenewableToken whose expiry is 1 year in the future | |
func (a *StaticAuth) VaultAuth() (*RenewableToken, error) { | |
return NewStaticToken(a.token), nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import ( | |
vaultapi "github.com/hashicorp/vault/api" | |
) | |
// Client encapsulates the vault api client with needed configuration and authentication mechanism (aws) in use. | |
type Client struct { | |
client *vaultapi.Client | |
config *Config | |
auth AuthMethod | |
token ClientToken | |
} | |
// newClient returns a configured Client given the auth mechanism and initial token. | |
func newClient(t ClientToken, a AuthMethod, c *Config) (*Client, error) { | |
vaultClient, err := vaultapi.NewClient(nil) | |
if err != nil { | |
return nil, err | |
} | |
return &Client{ | |
client: vaultClient, | |
config: c, | |
auth: a, | |
token: t, | |
}, nil | |
} | |
// NewClientFromConfig returns a configured & authorised Client | |
func NewClientFromConfig(c *Config) (*Client, error) { | |
a := AuthMechanismFromConfig(c) | |
t, err := a.VaultAuth() | |
if err != nil { | |
return nil, err | |
} | |
return newClient(t, a, c) | |
} | |
// Client returns the underlying Vault API client, in an authenticated state. | |
func (c *Client) Client() (*vaultapi.Client, error) { | |
if c.token.IsExpired() { | |
return c.client, c.vaultAuth() | |
} | |
if c.token.ShouldRenew() { | |
return c.client, c.renewToken() | |
} | |
return c.client, nil | |
} | |
func (c *Client) vaultAuth() error { | |
t, err := c.auth.VaultAuth() | |
if err != nil { | |
return err | |
} | |
c.token = t | |
c.client.SetToken(c.token.String()) | |
return nil | |
} | |
// RenewToken Renew the token if it is renewable. If it isn't, or if it's expired, refresh | |
// authentication instead. This is typically called internally. | |
func (c *Client) renewToken() error { | |
if c.token.IsExpired() || !c.token.IsRenewable() { | |
return c.vaultAuth() | |
} | |
resp, err := c.client.Auth().Token().RenewSelf(int(c.token.TTL())) | |
if err != nil { | |
return err | |
} | |
token, err := parseToken(resp) | |
if err != nil { | |
return err | |
} | |
c.token = token | |
c.client.SetToken(token.String()) | |
return nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import ( | |
"reflect" | |
"testing" | |
"time" | |
vaultapi "github.com/hashicorp/vault/api" | |
) | |
func TestNewClient(t *testing.T) { | |
now := time.Now() | |
token := NewRenewableToken(now, "xyz123", false, 3600) | |
type args struct { | |
t ClientToken | |
auth AuthMethod | |
} | |
tests := []struct { | |
name string | |
args args | |
want reflect.Type | |
}{ | |
{ | |
name: "returns instance of client", | |
args: args{ | |
t: token, | |
auth: &AWSAuth{}, | |
}, | |
want: reflect.TypeOf(&Client{}), | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
c, err := NewClient(tt.args.t, tt.args.auth) | |
if err != nil { | |
t.Errorf("NewClient() error not nil %v", err) | |
} | |
if got := reflect.TypeOf(c); !reflect.DeepEqual(got, tt.want) { | |
t.Errorf("NewClient() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestClient_Client(t *testing.T) { | |
type fields struct { | |
client *vaultapi.Client | |
auth AuthMethod | |
token ClientToken | |
} | |
tests := []struct { | |
name string | |
fields fields | |
want *vaultapi.Client | |
wantErr bool | |
}{ | |
// TODO: Add test cases. | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
c := &Client{ | |
client: tt.fields.client, | |
auth: tt.fields.auth, | |
token: tt.fields.token, | |
} | |
got, err := c.Client() | |
if (err != nil) != tt.wantErr { | |
t.Errorf("Client.Client() error = %v, wantErr %v", err, tt.wantErr) | |
return | |
} | |
if !reflect.DeepEqual(got, tt.want) { | |
t.Errorf("Client.Client() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestClient_vaultAuth(t *testing.T) { | |
type fields struct { | |
client *vaultapi.Client | |
auth AuthMethod | |
token ClientToken | |
} | |
tests := []struct { | |
name string | |
fields fields | |
wantErr bool | |
}{ | |
// TODO: Add test cases. | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
c := &Client{ | |
client: tt.fields.client, | |
auth: tt.fields.auth, | |
token: tt.fields.token, | |
} | |
if err := c.vaultAuth(); (err != nil) != tt.wantErr { | |
t.Errorf("Client.vaultAuth() error = %v, wantErr %v", err, tt.wantErr) | |
} | |
}) | |
} | |
} | |
func TestClient_renewToken(t *testing.T) { | |
type fields struct { | |
client *vaultapi.Client | |
auth AuthMethod | |
token ClientToken | |
} | |
tests := []struct { | |
name string | |
fields fields | |
wantErr bool | |
}{ | |
// TODO: Add test cases. | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
c := &Client{ | |
client: tt.fields.client, | |
auth: tt.fields.auth, | |
token: tt.fields.token, | |
} | |
if err := c.renewToken(); (err != nil) != tt.wantErr { | |
t.Errorf("Client.renewToken() error = %v, wantErr %v", err, tt.wantErr) | |
} | |
}) | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
// Config contains general Vault configuration the package needs to operate. | |
type Config struct { | |
AuthMethod string `mapstructure:"auth_method"` | |
Host string `mapstructure:"host"` | |
TransitPath string `mapstructure:"transit_path"` | |
DefaultKey string `mapstructure:"default_key"` | |
AwsConfig *AWSConfig `mapstructure:"aws"` | |
StaticConfig *StaticConfig `mapstructure:"static"` | |
} | |
// StaticConfig holds the settings for the "static" auth mechanism.
type StaticConfig struct {
	// Token is the Vault token to authenticate with.
	Token string `mapstructure:"token"`
}
// AWSConfig holds the settings for the "aws" auth mechanism.
type AWSConfig struct {
	// Role is sent as the "role" value of the login request.
	Role string `mapstructure:"role"`
	// AuthProviderName is the auth mount name used in auth/<name>/login.
	AuthProviderName string `mapstructure:"auth_provider_name"`
	// VaultAuthHeader, when non-empty, is added to the STS request as the
	// X-Vault-AWS-IAM-Server-ID header before signing.
	VaultAuthHeader string `mapstructure:"vault_auth_header"`
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import "time" | |
// NewStaticToken wraps a fixed token string in a non-renewable
// RenewableToken whose lifetime starts at the moment of the call.
//
// NOTE(review): the TTL is the bare number 86400 typed as time.Duration,
// which is 86400 nanoseconds rather than one day — confirm the intent
// (TestNewStaticToken pins this exact value).
func NewStaticToken(token string) *RenewableToken {
	const ttl time.Duration = 86400
	return NewRenewableToken(time.Now(), token, false, ttl)
}

// NewRenewableToken builds a RenewableToken from the token string, its
// renewable flag and its TTL. The expiration is currentTime plus tokenTTL;
// the expiration and renewal windows are left at zero.
func NewRenewableToken(currentTime time.Time, token string, tokenIsRenewable bool, tokenTTL time.Duration) *RenewableToken {
	return &RenewableToken{
		token:            token,
		tokenIsRenewable: tokenIsRenewable,
		tokenExpiration:  currentTime.Add(tokenTTL),
		tokenTTL:         tokenTTL,
	}
}

// RenewableToken is a Vault token together with its expiry bookkeeping.
type RenewableToken struct {
	token            string
	tokenIsRenewable bool
	tokenExpiration  time.Time     // actual expiration
	tokenTTL         time.Duration // lifetime of the auth token received
	expirationWindow time.Duration // time to allow to process a token renewal
	renewalWindow    time.Duration // time before expiration when token should be actively renewed
}

// TTL reports the lifetime the token was issued with.
func (t *RenewableToken) TTL() time.Duration {
	return t.tokenTTL
}

// IsRenewable reports whether the token was flagged as renewable.
func (t *RenewableToken) IsRenewable() bool {
	return t.tokenIsRenewable
}

// IsExpired reports whether the expiration lies before now plus the
// expiration window.
func (t *RenewableToken) IsExpired() bool {
	cutoff := time.Now().Add(t.expirationWindow)
	return t.tokenExpiration.Before(cutoff)
}

// ShouldRenew reports whether the expiration lies before now plus the
// renewal window.
func (t *RenewableToken) ShouldRenew() bool {
	cutoff := time.Now().Add(t.renewalWindow)
	return t.tokenExpiration.Before(cutoff)
}

// String returns the raw token string.
func (t *RenewableToken) String() string {
	return t.token
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import ( | |
"reflect" | |
"testing" | |
"time" | |
) | |
func TestNewStaticToken(t *testing.T) { | |
tokenValue := "abc-xyz" | |
now := time.Now() | |
type args struct { | |
token string | |
} | |
tests := []struct { | |
name string | |
args args | |
want *RenewableToken | |
}{ | |
{ | |
name: "creates a renewable token", | |
args: args{token: tokenValue}, | |
want: NewRenewableToken(now, tokenValue, false, 86400), | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
got := NewStaticToken(tt.args.token) | |
if reflect.TypeOf(got) != reflect.TypeOf(&RenewableToken{}) { | |
t.Errorf("NewStaticToken() = %v, want %v", got, reflect.TypeOf(tt.want)) | |
} | |
if got.String() != tokenValue { | |
t.Errorf("String() %v, want %v", got.String(), tokenValue) | |
} | |
if got.IsRenewable() != false { | |
t.Errorf("NewStaticToken() should not be renewable") | |
} | |
if got.TTL() != 86400 { | |
t.Errorf("NewStaticToken() TTL should be 1 day") | |
} | |
}) | |
} | |
} | |
func TestNewRenewableToken(t *testing.T) { | |
now := time.Now() | |
ttl := time.Duration(8600) | |
type args struct { | |
token string | |
tokenIsRenewable bool | |
tokenTTL time.Duration | |
} | |
tests := []struct { | |
name string | |
args args | |
want *RenewableToken | |
}{ | |
{ | |
name: "test happy path", | |
args: args{ | |
token: "abc-xyz", | |
tokenIsRenewable: false, | |
tokenTTL: ttl, | |
}, | |
want: &RenewableToken{ | |
token: "abc-xyz", | |
tokenIsRenewable: false, | |
tokenExpiration: now.Add(ttl), | |
tokenTTL: ttl, | |
expirationWindow: 0, | |
renewalWindow: 0, | |
}, | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
if got := NewRenewableToken(now, tt.args.token, tt.args.tokenIsRenewable, tt.args.tokenTTL); !reflect.DeepEqual(got, tt.want) { | |
t.Errorf("NewRenewableToken() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestRenewableToken_TTL(t *testing.T) { | |
now := time.Now() | |
type fields struct { | |
token string | |
tokenIsRenewable bool | |
tokenExpiration time.Time | |
tokenTTL time.Duration | |
expirationWindow time.Duration | |
renewalWindow time.Duration | |
} | |
tests := []struct { | |
name string | |
fields fields | |
want time.Duration | |
}{ | |
{ | |
name: "test happy path", | |
fields: fields{ | |
tokenExpiration: now, | |
tokenTTL: time.Duration(8600), | |
}, | |
want: time.Duration(8600), | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
token := &RenewableToken{ | |
token: tt.fields.token, | |
tokenIsRenewable: tt.fields.tokenIsRenewable, | |
tokenExpiration: tt.fields.tokenExpiration, | |
tokenTTL: tt.fields.tokenTTL, | |
expirationWindow: tt.fields.expirationWindow, | |
renewalWindow: tt.fields.renewalWindow, | |
} | |
if got := token.TTL(); got != tt.want { | |
t.Errorf("RenewableToken.TTL() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestRenewableToken_IsRenewable(t *testing.T) { | |
type fields struct { | |
token string | |
tokenIsRenewable bool | |
tokenExpiration time.Time | |
tokenTTL time.Duration | |
expirationWindow time.Duration | |
renewalWindow time.Duration | |
} | |
tests := []struct { | |
name string | |
fields fields | |
want bool | |
}{ | |
{ | |
name: "token is renewable", | |
fields: fields{ | |
tokenIsRenewable: true, | |
}, | |
want: true, | |
}, | |
{ | |
name: "token is not renewable", | |
fields: fields{ | |
tokenIsRenewable: false, | |
}, | |
want: false, | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
token := &RenewableToken{ | |
token: tt.fields.token, | |
tokenIsRenewable: tt.fields.tokenIsRenewable, | |
tokenExpiration: tt.fields.tokenExpiration, | |
tokenTTL: tt.fields.tokenTTL, | |
expirationWindow: tt.fields.expirationWindow, | |
renewalWindow: tt.fields.renewalWindow, | |
} | |
if got := token.IsRenewable(); got != tt.want { | |
t.Errorf("RenewableToken.IsRenewable() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestRenewableToken_IsExpired(t *testing.T) { | |
type fields struct { | |
token string | |
tokenIsRenewable bool | |
tokenExpiration time.Time | |
tokenTTL time.Duration | |
expirationWindow time.Duration | |
renewalWindow time.Duration | |
} | |
tests := []struct { | |
name string | |
fields fields | |
want bool | |
}{ | |
{ | |
name: "token is expired", | |
fields: fields{ | |
tokenExpiration: time.Date(1999, 1, 1, 1, 1, 1, 1, &time.Location{}), | |
}, | |
want: true, | |
}, | |
{ | |
name: "token is not expired", | |
fields: fields{ | |
expirationWindow: time.Hour, | |
tokenExpiration: time.Now().Add(6 * time.Hour), | |
}, | |
want: false, | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
token := &RenewableToken{ | |
token: tt.fields.token, | |
tokenIsRenewable: tt.fields.tokenIsRenewable, | |
tokenExpiration: tt.fields.tokenExpiration, | |
tokenTTL: tt.fields.tokenTTL, | |
expirationWindow: tt.fields.expirationWindow, | |
renewalWindow: tt.fields.renewalWindow, | |
} | |
if got := token.IsExpired(); got != tt.want { | |
t.Errorf("RenewableToken.IsExpired() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestRenewableToken_ShouldRenew(t *testing.T) { | |
type fields struct { | |
token string | |
tokenIsRenewable bool | |
tokenExpiration time.Time | |
tokenTTL time.Duration | |
expirationWindow time.Duration | |
renewalWindow time.Duration | |
} | |
tests := []struct { | |
name string | |
fields fields | |
want bool | |
}{ | |
{ | |
name: "token is expired and should renew", | |
fields: fields{ | |
tokenExpiration: time.Date(1999, 1, 1, 1, 1, 1, 1, &time.Location{}), | |
renewalWindow: 1 * time.Hour, | |
}, | |
want: true, | |
}, | |
{ | |
name: "token should not renew", | |
fields: fields{ | |
expirationWindow: 1 * time.Hour, | |
tokenExpiration: time.Now().Add(6 * time.Hour), | |
}, | |
want: false, | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
token := &RenewableToken{ | |
token: tt.fields.token, | |
tokenIsRenewable: tt.fields.tokenIsRenewable, | |
tokenExpiration: tt.fields.tokenExpiration, | |
tokenTTL: tt.fields.tokenTTL, | |
expirationWindow: tt.fields.expirationWindow, | |
renewalWindow: tt.fields.renewalWindow, | |
} | |
if got := token.ShouldRenew(); got != tt.want { | |
t.Errorf("RenewableToken.ShouldRenew() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} | |
func TestRenewableToken_String(t *testing.T) { | |
type fields struct { | |
token string | |
tokenIsRenewable bool | |
tokenExpiration time.Time | |
tokenTTL time.Duration | |
expirationWindow time.Duration | |
renewalWindow time.Duration | |
} | |
tests := []struct { | |
name string | |
fields fields | |
want string | |
}{ | |
{ | |
name: "returns token value", | |
fields: fields{ | |
token: "abc-xyz", | |
}, | |
want: "abc-xyz", | |
}, | |
} | |
for _, tt := range tests { | |
t.Run(tt.name, func(t *testing.T) { | |
token := &RenewableToken{ | |
token: tt.fields.token, | |
tokenIsRenewable: tt.fields.tokenIsRenewable, | |
tokenExpiration: tt.fields.tokenExpiration, | |
tokenTTL: tt.fields.tokenTTL, | |
expirationWindow: tt.fields.expirationWindow, | |
renewalWindow: tt.fields.renewalWindow, | |
} | |
if got := token.String(); got != tt.want { | |
t.Errorf("RenewableToken.String() = %v, want %v", got, tt.want) | |
} | |
}) | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
// EncryptDataResponse is the result of a successful transit encrypt call.
type EncryptDataResponse struct {
	// Ciphertext is the "ciphertext" value returned by Vault.
	Ciphertext string `mapstructure:"ciphertext"`
}
// DecryptDataResponse is the result of a successful transit decrypt call.
type DecryptDataResponse struct {
	// Plaintext is the decrypted payload.
	Plaintext string `mapstructure:"plaintext"`
}
// EncryptRequest is the payload of a transit encrypt call.
type EncryptRequest struct {
	// Plaintext is the data to encrypt.
	Plaintext string `mapstructure:"plaintext"`
	// Key names the transit key to encrypt with.
	Key string `mapstructure:"key"`
	// KeyType is sent to Vault as "type".
	KeyType string `mapstructure:"type"`
}
// DecryptRequest is the payload of a transit decrypt call.
type DecryptRequest struct {
	// Ciphertext is the data to decrypt.
	Ciphertext string `mapstructure:"ciphertext"`
	// Key names the transit key to decrypt with.
	Key string `mapstructure:"key"`
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import ( | |
"encoding/base64" | |
"errors" | |
"fmt" | |
"github.com/mitchellh/mapstructure" | |
"github.com/deputyapp/go-lib/logger" | |
) | |
// TransitBackend wraps the vault API client with request/response objects for Vault transit backend calls | |
type TransitBackend struct { | |
config *Config | |
client *Client | |
logger *logger.Logger | |
} | |
// NewTransitBackend returns a configured transit backend client given app Vault configuration | |
func NewTransitBackend(c *Config) (*TransitBackend, error) { | |
client, err := NewClientFromConfig(c) | |
if err != nil { | |
return nil, err | |
} | |
return &TransitBackend{ | |
config: c, | |
client: client, | |
logger: logger.New("transit"), | |
}, nil | |
} | |
// EncryptData calls the configured transit backend to encrypt the request plaintext. Returning the ciphertext or error. | |
func (t *TransitBackend) EncryptData(req *EncryptRequest) (*EncryptDataResponse, error) { | |
c, e := t.client.Client() | |
if e != nil { | |
return nil, e | |
} | |
data := make(map[string]interface{}) | |
err := mapstructure.Decode(req, &data) | |
if err != nil { | |
t.logger.Errorf("mapstructure.Decode err %#v", err) | |
return nil, err | |
} | |
path := fmt.Sprintf("/%s/encrypt/%s", t.config.TransitPath, req.Key) | |
secret, err := c.Logical().Write(path, data) | |
if err != nil { | |
t.logger.Errorf("vault transit write err %#v", err) | |
return nil, err | |
} | |
ct := secret.Data["ciphertext"].(string) | |
return &EncryptDataResponse{ | |
Ciphertext: ct, | |
}, nil | |
} | |
// DecryptData calls the configured transit backend to decrypt the request ciphertext. Returning the plaintext or error. | |
func (t *TransitBackend) DecryptData(req *DecryptRequest) (*DecryptDataResponse, error) { | |
c, e := t.client.Client() | |
if e != nil { | |
return nil, e | |
} | |
data := make(map[string]interface{}) | |
err := mapstructure.Decode(req, &data) | |
if err != nil { | |
return nil, err | |
} | |
path := fmt.Sprintf("/%s/decrypt/%s", t.config.TransitPath, req.Key) | |
secret, err := c.Logical().Write(path, data) | |
if err != nil { | |
t.logger.Errorf("vault client write err %#v", err) | |
return nil, err | |
} | |
return t.decodeTransitResponse(secret.Data) | |
} | |
func (t *TransitBackend) decodeTransitResponse(data map[string]interface{}) (*DecryptDataResponse, error) { | |
if _, ok := data["plaintext"]; !ok { | |
return nil, errors.New("transit response missing plaintext key") | |
} | |
bytes, err := base64.StdEncoding.DecodeString(data["plaintext"].(string)) | |
if err != nil { | |
t.logger.Errorf("error base64 decoding vault response err= %#v", err) | |
return nil, err | |
} | |
return &DecryptDataResponse{ | |
Plaintext: string(bytes), | |
}, nil | |
} | |
// NewEncryptDataRequest method takes care of instanciating a valid EncryptRequest object. | |
// Vault doco - https://www.vaultproject.io/api/secret/transit/index.html#encrypt-data | |
// plaintext - the data to be encrypted by vault. | |
// key - The key name to use to encrypt, if the key does not exist it will be created as the type specified. | |
// keyType - This parameter is required when encryption key is expected to be created. | |
// When performing an upsert operation, the type of key to create. | |
// We default to aes256-gcm96 when empty. | |
func (t *TransitBackend) NewEncryptDataRequest(plaintext, key, keyType string) (*EncryptRequest, error) { | |
if keyType == "" { | |
keyType = "aes256-gcm96" | |
} | |
if keyType != "aes256-gcm96" { | |
return nil, errors.New("we currently only support aes256-gcm96 for the key type") | |
} | |
if key == "" { | |
key = t.config.DefaultKey | |
} | |
if len(plaintext) == 0 { | |
return nil, errors.New("plaintext can't be empty") | |
} | |
plaintext = base64.StdEncoding.EncodeToString([]byte(plaintext)) | |
return &EncryptRequest{Plaintext: plaintext, Key: key, KeyType: keyType}, nil | |
} | |
// NewDecryptDataRequest method takes care of instantiating a valid DecryptRequest object. | |
// Vault doco - https://www.vaultproject.io/api/secret/transit/index.html#decrypt-data | |
// ciphertext - the data to be encrypted by vault. | |
// key - The key name to use to decrypt, if the key is empty default key from config will be assumed. | |
func (t *TransitBackend) NewDecryptDataRequest(ciphertext, key string) (*DecryptRequest, error) { | |
if key == "" { | |
key = t.config.DefaultKey | |
} | |
if len(ciphertext) == 0 { | |
return nil, errors.New("ciphertext can't be empty") | |
} | |
return &DecryptRequest{Ciphertext: ciphertext, Key: key}, nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package vault | |
import ( | |
"time" | |
vaultapi "github.com/hashicorp/vault/api" | |
) | |
// ClientToken interface | |
type ClientToken interface { | |
String() string | |
IsRenewable() bool | |
ShouldRenew() bool | |
IsExpired() bool | |
TTL() time.Duration | |
} | |
func parseToken(resp *vaultapi.Secret) (*RenewableToken, error) { | |
var err error | |
var tokenString string | |
var tokenIsRenewable bool | |
var tokenTTL time.Duration | |
if tokenString, err = resp.TokenID(); err != nil { | |
return nil, err | |
} | |
if tokenIsRenewable, err = resp.TokenIsRenewable(); err != nil { | |
return nil, err | |
} | |
if tokenTTL, err = resp.TokenTTL(); err != nil { | |
return nil, err | |
} | |
return NewRenewableToken(time.Now(), tokenString, tokenIsRenewable, tokenTTL), nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment