Created July 10, 2020 19:53
Save bendrucker/5efdbde8a389aaa0c909895091dca76c to your computer and use it in GitHub Desktop.
Diff between the Vault PostgreSQL and Redshift database plugins
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
--- plugins/database/postgresql/postgresql.go 2020-07-08 15:54:28.000000000 -0700 | |
+++ plugins/database/redshift/redshift.go 2020-07-08 15:54:28.000000000 -0700 | |
@@ -1,15 +1,15 @@ | |
-package postgresql | |
+package redshift | |
import ( | |
"context" | |
"database/sql" | |
"errors" | |
"fmt" | |
- "regexp" | |
"strings" | |
"time" | |
"github.com/hashicorp/errwrap" | |
+ "github.com/hashicorp/go-multierror" | |
"github.com/hashicorp/vault/api" | |
"github.com/hashicorp/vault/sdk/database/dbplugin" | |
"github.com/hashicorp/vault/sdk/database/helper/connutil" | |
@@ -21,52 +21,46 @@ | |
) | |
const ( | |
- postgreSQLTypeName = "postgres" | |
- defaultPostgresRenewSQL = ` | |
-ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; | |
-` | |
- defaultPostgresRotateRootCredentialsSQL = ` | |
-ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}'; | |
-` | |
-) | |
+ // This is how this plugin will be reflected in middleware | |
+ // such as metrics. | |
+ middlewareTypeName = "redshift" | |
-var ( | |
- _ dbplugin.Database = &PostgreSQL{} | |
+ // This allows us to use the postgres database driver. | |
+ sqlTypeName = "postgres" | |
- // postgresEndStatement is basically the word "END" but | |
- // surrounded by a word boundary to differentiate it from | |
- // other words like "APPEND". | |
- postgresEndStatement = regexp.MustCompile(`\bEND\b`) | |
- | |
- // doubleQuotedPhrases finds substrings like "hello" | |
- // and pulls them out with the quotes included. | |
- doubleQuotedPhrases = regexp.MustCompile(`(".*?")`) | |
- | |
- // singleQuotedPhrases finds substrings like 'hello' | |
- // and pulls them out with the quotes included. | |
- singleQuotedPhrases = regexp.MustCompile(`('.*?')`) | |
+ defaultRenewSQL = ` | |
+ALTER USER "{{name}}" VALID UNTIL '{{expiration}}'; | |
+` | |
+ defaultRotateRootCredentialsSQL = ` | |
+ALTER USER "{{username}}" WITH PASSWORD '{{password}}'; | |
+` | |
) | |
-// New implements builtinplugins.BuiltinFactory | |
-func New() (interface{}, error) { | |
- db := new() | |
+// lowercaseUsername is the reason we wrote this plugin. Redshift implements (mostly) | |
+// a postgres 8 interface, and part of that is under the hood, it's lowercasing the | |
+// usernames. | |
+func New(lowercaseUsername bool) func() (interface{}, error) { | |
+ return func() (interface{}, error) { | |
+ db := newRedshift(lowercaseUsername) | |
// Wrap the plugin with middleware to sanitize errors | |
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) | |
return dbType, nil | |
} | |
+} | |
-func new() *PostgreSQL { | |
+func newRedshift(lowercaseUsername bool) *RedShift { | |
connProducer := &connutil.SQLConnectionProducer{} | |
- connProducer.Type = postgreSQLTypeName | |
+ connProducer.Type = sqlTypeName | |
credsProducer := &credsutil.SQLCredentialsProducer{ | |
DisplayNameLen: 8, | |
RoleNameLen: 8, | |
UsernameLen: 63, | |
Separator: "-", | |
+ LowercaseUsername: lowercaseUsername, | |
} | |
- db := &PostgreSQL{ | |
+ db := &RedShift{ | |
SQLConnectionProducer: connProducer, | |
CredentialsProducer: credsProducer, | |
} | |
@@ -74,9 +68,9 @@ | |
return db | |
} | |
-// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin | |
+// Run instantiates a RedShift object, and runs the RPC server for the plugin | |
func Run(apiTLSConfig *api.TLSConfig) error { | |
- dbType, err := New() | |
+ dbType, err := New(true)() | |
if err != nil { | |
return err | |
} | |
@@ -86,21 +80,22 @@ | |
return nil | |
} | |
-type PostgreSQL struct { | |
+type RedShift struct { | |
*connutil.SQLConnectionProducer | |
credsutil.CredentialsProducer | |
} | |
-func (p *PostgreSQL) Type() (string, error) { | |
- return postgreSQLTypeName, nil | |
+func (r *RedShift) Type() (string, error) { | |
+ return middlewareTypeName, nil | |
} | |
-func (p *PostgreSQL) getConnection(ctx context.Context) (*sql.DB, error) { | |
- db, err := p.Connection(ctx) | |
+// getConnection accepts a context and retuns a new pointer to a sql.DB object. | |
+// It's up to the caller to close the connection or handle reuse logic. | |
+func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) { | |
+ db, err := r.Connection(ctx) | |
if err != nil { | |
return nil, err | |
} | |
- | |
return db.(*sql.DB), nil | |
} | |
@@ -110,9 +105,9 @@ | |
// and setting the password of static accounts, as well as rolling back | |
// passwords in the database in the event an updated database fails to save in | |
// Vault's storage. | |
-func (p *PostgreSQL) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { | |
+func (r *RedShift) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { | |
if len(statements.Rotation) == 0 { | |
- statements.Rotation = []string{defaultPostgresRotateRootCredentialsSQL} | |
+ statements.Rotation = []string{defaultRotateRootCredentialsSQL} | |
} | |
username = staticUser.Username | |
@@ -122,18 +117,19 @@ | |
} | |
// Grab the lock | |
- p.Lock() | |
- defer p.Unlock() | |
+ r.Lock() | |
+ defer r.Unlock() | |
// Get the connection | |
- db, err := p.getConnection(ctx) | |
+ db, err := r.getConnection(ctx) | |
if err != nil { | |
return "", "", err | |
} | |
+ defer db.Close() | |
// Check if the role exists | |
var exists bool | |
- err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) | |
+ err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) | |
if err != nil && err != sql.ErrNoRows { | |
return "", "", err | |
} | |
@@ -148,7 +144,7 @@ | |
return "", "", err | |
} | |
defer func() { | |
- _ = tx.Rollback() | |
+ tx.Rollback() | |
}() | |
// Execute each query | |
@@ -161,7 +157,6 @@ | |
m := map[string]string{ | |
"name": staticUser.Username, | |
- "username": staticUser.Username, | |
"password": password, | |
} | |
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | |
@@ -174,11 +169,10 @@ | |
if err := tx.Commit(); err != nil { | |
return "", "", err | |
} | |
- | |
return username, password, nil | |
} | |
-func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | |
+func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { | |
statements = dbutil.StatementCompatibilityHelper(statements) | |
if len(statements.Creation) == 0 { | |
@@ -186,29 +180,30 @@ | |
} | |
// Grab the lock | |
- p.Lock() | |
- defer p.Unlock() | |
+ r.Lock() | |
+ defer r.Unlock() | |
- username, err = p.GenerateUsername(usernameConfig) | |
+ username, err = r.GenerateUsername(usernameConfig) | |
if err != nil { | |
return "", "", err | |
} | |
- password, err = p.GeneratePassword() | |
+ password, err = r.GeneratePassword() | |
if err != nil { | |
return "", "", err | |
} | |
- expirationStr, err := p.GenerateExpiration(expiration) | |
+ expirationStr, err := r.GenerateExpiration(expiration) | |
if err != nil { | |
return "", "", err | |
} | |
// Get the connection | |
- db, err := p.getConnection(ctx) | |
+ db, err := r.getConnection(ctx) | |
if err != nil { | |
return "", "", err | |
} | |
+ defer db.Close() | |
// Start a transaction | |
tx, err := db.BeginTx(ctx, nil) | |
@@ -222,20 +217,6 @@ | |
// Execute each query | |
for _, stmt := range statements.Creation { | |
- if containsMultilineStatement(stmt) { | |
- // Execute it as-is. | |
- m := map[string]string{ | |
- "name": username, | |
- "username": username, | |
- "password": password, | |
- "expiration": expirationStr, | |
- } | |
- if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil { | |
- return "", "", err | |
- } | |
- continue | |
- } | |
- // Otherwise, it's fine to split the statements on the semicolon. | |
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { | |
query = strings.TrimSpace(query) | |
if len(query) == 0 { | |
@@ -244,7 +225,6 @@ | |
m := map[string]string{ | |
"name": username, | |
- "username": username, | |
"password": password, | |
"expiration": expirationStr, | |
} | |
@@ -258,25 +238,25 @@ | |
if err := tx.Commit(); err != nil { | |
return "", "", err | |
} | |
- | |
return username, password, nil | |
} | |
-func (p *PostgreSQL) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | |
- p.Lock() | |
- defer p.Unlock() | |
+func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { | |
+ r.Lock() | |
+ defer r.Unlock() | |
statements = dbutil.StatementCompatibilityHelper(statements) | |
renewStmts := statements.Renewal | |
if len(renewStmts) == 0 { | |
- renewStmts = []string{defaultPostgresRenewSQL} | |
+ renewStmts = []string{defaultRenewSQL} | |
} | |
- db, err := p.getConnection(ctx) | |
+ db, err := r.getConnection(ctx) | |
if err != nil { | |
return err | |
} | |
+ defer db.Close() | |
tx, err := db.BeginTx(ctx, nil) | |
if err != nil { | |
@@ -286,7 +266,7 @@ | |
tx.Rollback() | |
}() | |
- expirationStr, err := p.GenerateExpiration(expiration) | |
+ expirationStr, err := r.GenerateExpiration(expiration) | |
if err != nil { | |
return err | |
} | |
@@ -300,7 +280,6 @@ | |
m := map[string]string{ | |
"name": username, | |
- "username": username, | |
"expiration": expirationStr, | |
} | |
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | |
@@ -312,25 +291,26 @@ | |
return tx.Commit() | |
} | |
-func (p *PostgreSQL) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | |
+func (r *RedShift) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { | |
// Grab the lock | |
- p.Lock() | |
- defer p.Unlock() | |
+ r.Lock() | |
+ defer r.Unlock() | |
statements = dbutil.StatementCompatibilityHelper(statements) | |
if len(statements.Revocation) == 0 { | |
- return p.defaultRevokeUser(ctx, username) | |
+ return r.defaultRevokeUser(ctx, username) | |
} | |
- return p.customRevokeUser(ctx, username, statements.Revocation) | |
+ return r.customRevokeUser(ctx, username, statements.Revocation) | |
} | |
-func (p *PostgreSQL) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { | |
- db, err := p.getConnection(ctx) | |
+func (r *RedShift) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { | |
+ db, err := r.getConnection(ctx) | |
if err != nil { | |
return err | |
} | |
+ defer db.Close() | |
tx, err := db.BeginTx(ctx, nil) | |
if err != nil { | |
@@ -349,7 +329,6 @@ | |
m := map[string]string{ | |
"name": username, | |
- "username": username, | |
} | |
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | |
return err | |
@@ -360,15 +339,16 @@ | |
return tx.Commit() | |
} | |
-func (p *PostgreSQL) defaultRevokeUser(ctx context.Context, username string) error { | |
- db, err := p.getConnection(ctx) | |
+func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error { | |
+ db, err := r.getConnection(ctx) | |
if err != nil { | |
return err | |
} | |
+ defer db.Close() | |
// Check if the role exists | |
var exists bool | |
- err = db.QueryRowContext(ctx, "SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) | |
+ err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) | |
if err != nil && err != sql.ErrNoRows { | |
return err | |
} | |
@@ -419,10 +399,6 @@ | |
pq.QuoteIdentifier(username))) | |
revocationStmts = append(revocationStmts, fmt.Sprintf( | |
- "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", | |
- pq.QuoteIdentifier(username))) | |
- | |
- revocationStmts = append(revocationStmts, fmt.Sprintf( | |
"REVOKE USAGE ON SCHEMA public FROM %s;", | |
pq.QuoteIdentifier(username))) | |
@@ -434,18 +410,36 @@ | |
} | |
if dbname.Valid { | |
- revocationStmts = append(revocationStmts, fmt.Sprintf( | |
- `REVOKE CONNECT ON DATABASE %s FROM %s;`, | |
- pq.QuoteIdentifier(dbname.String), | |
- pq.QuoteIdentifier(username))) | |
+ /* | |
+ We create this stored procedure to ensure we can durably revoke users on Redshift. We do not | |
+ clean up since that can cause race conditions with other instances of Vault attempting to use | |
+ this SP at the same time. | |
+ */ | |
+ revocationStmts = append(revocationStmts, `CREATE OR REPLACE PROCEDURE terminateloop(dbusername varchar(100)) | |
+LANGUAGE plpgsql | |
+AS $$ | |
+DECLARE | |
+ currentpid int; | |
+ loopvar int; | |
+ qtyconns int; | |
+BEGIN | |
+SELECT COUNT(process) INTO qtyconns FROM stv_sessions WHERE user_name=dbusername; | |
+ FOR loopvar IN 1..qtyconns LOOP | |
+ SELECT INTO currentpid process FROM stv_sessions WHERE user_name=dbusername ORDER BY process ASC LIMIT 1; | |
+ SELECT pg_terminate_backend(currentpid); | |
+ END LOOP; | |
+END | |
+$$;`) | |
+ | |
+ revocationStmts = append(revocationStmts, fmt.Sprintf(`call terminateloop('%s');`, username)) | |
} | |
// again, here, we do not stop on error, as we want to remove as | |
// many permissions as possible right now | |
- var lastStmtError error | |
+ var lastStmtError *multierror.Error //error | |
for _, query := range revocationStmts { | |
if err := dbtxn.ExecuteDBQuery(ctx, db, nil, query); err != nil { | |
- lastStmtError = err | |
+ lastStmtError = multierror.Append(lastStmtError, err) | |
} | |
} | |
@@ -459,7 +453,7 @@ | |
// Drop this user | |
stmt, err = db.PrepareContext(ctx, fmt.Sprintf( | |
- `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) | |
+ `DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username))) | |
if err != nil { | |
return err | |
} | |
@@ -471,23 +465,24 @@ | |
return nil | |
} | |
-func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { | |
- p.Lock() | |
- defer p.Unlock() | |
+func (r *RedShift) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { | |
+ r.Lock() | |
+ defer r.Unlock() | |
- if len(p.Username) == 0 || len(p.Password) == 0 { | |
+ if len(r.Username) == 0 || len(r.Password) == 0 { | |
return nil, errors.New("username and password are required to rotate") | |
} | |
rotateStatements := statements | |
if len(rotateStatements) == 0 { | |
- rotateStatements = []string{defaultPostgresRotateRootCredentialsSQL} | |
+ rotateStatements = []string{defaultRotateRootCredentialsSQL} | |
} | |
- db, err := p.getConnection(ctx) | |
+ db, err := r.getConnection(ctx) | |
if err != nil { | |
return nil, err | |
} | |
+ defer db.Close() | |
tx, err := db.BeginTx(ctx, nil) | |
if err != nil { | |
@@ -497,7 +492,7 @@ | |
tx.Rollback() | |
}() | |
- password, err := p.GeneratePassword() | |
+ password, err := r.GeneratePassword() | |
if err != nil { | |
return nil, err | |
} | |
@@ -509,8 +504,7 @@ | |
continue | |
} | |
m := map[string]string{ | |
- "name": p.Username, | |
- "username": p.Username, | |
+ "username": r.Username, | |
"password": password, | |
} | |
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { | |
@@ -523,48 +517,6 @@ | |
return nil, err | |
} | |
- // Close the database connection to ensure no new connections come in | |
- if err := db.Close(); err != nil { | |
- return nil, err | |
- } | |
- | |
- p.RawConfig["password"] = password | |
- return p.RawConfig, nil | |
-} | |
- | |
-// containsMultilineStatement is a best effort to determine whether | |
-// a particular statement is multiline, and therefore should not be | |
-// split upon semicolons. If it's unsure, it defaults to false. | |
-func containsMultilineStatement(stmt string) bool { | |
- // We're going to look for the word "END", but first let's ignore | |
- // anything the user provided within single or double quotes since | |
- // we're looking for an "END" within the Postgres syntax. | |
- literals, err := extractQuotedStrings(stmt) | |
- if err != nil { | |
- return false | |
- } | |
- stmtWithoutLiterals := stmt | |
- for _, literal := range literals { | |
- stmtWithoutLiterals = strings.Replace(stmt, literal, "", -1) | |
- } | |
- // Now look for the word "END" specifically. This will miss any | |
- // representations of END that aren't surrounded by spaces, but | |
- // it should be easy to change on the user's side. | |
- return postgresEndStatement.MatchString(stmtWithoutLiterals) | |
-} | |
- | |
-// extractQuotedStrings extracts 0 or many substrings | |
-// that have been single- or double-quoted. Ex: | |
-// `"Hello", silly 'elephant' from the "zoo".` | |
-// returns [ `Hello`, `'elephant'`, `"zoo"` ] | |
-func extractQuotedStrings(s string) ([]string, error) { | |
- var found []string | |
- toFind := []*regexp.Regexp{ | |
- doubleQuotedPhrases, | |
- singleQuotedPhrases, | |
- } | |
- for _, typeOfPhrase := range toFind { | |
- found = append(found, typeOfPhrase.FindAllString(s, -1)...) | |
- } | |
- return found, nil | |
+ r.RawConfig["password"] = password | |
+ return r.RawConfig, nil | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.