Last active: July 17, 2018 12:38
-
-
Save knz/04c44b4d609df0e073d9f53bcf30bd68 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/pkg/cli/sql.go b/pkg/cli/sql.go | |
index 5957fae4bb..160f5edeec 100644 | |
--- a/pkg/cli/sql.go | |
+++ b/pkg/cli/sql.go | |
@@ -53,7 +53,7 @@ const ( | |
` | |
) | |
-const defaultPromptPattern = "%n@%M/%?%x>" | |
+const defaultPromptPattern = "%n@%M/%/%x>" | |
// sqlShellCmd opens a sql shell. | |
var sqlShellCmd = &cobra.Command{ | |
@@ -86,6 +86,8 @@ type cliState struct { | |
fullPrompt string | |
// The prompt on a continuation line in a multi-line entry. | |
continuePrompt string | |
+ // Which prompt to use to populate currentPrompt | |
+ useContinuePrompt bool | |
// The current prompt, either fullPrompt or continuePrompt. | |
currentPrompt string | |
// The string used to produce the value of fullPrompt. | |
@@ -137,7 +139,7 @@ const ( | |
// Querying the server for the current transaction status | |
// and setting the prompt accordingly. | |
- cliRefreshPrompts | |
+ cliRefreshPrompt | |
// Just before reading the first line of a potentially multi-line | |
// statement. | |
@@ -243,8 +245,7 @@ var options = map[string]struct { | |
set func(c *cliState, val string) error | |
reset func(c *cliState) error | |
// display is used to retrieve the current value. | |
- display func(c *cliState) string | |
- nextState cliStateEnum | |
+ display func(c *cliState) string | |
}{ | |
`display_format`: { | |
"the output format for tabular data (pretty, csv, tsv, html, sql, records, raw)", | |
@@ -262,7 +263,6 @@ var options = map[string]struct { | |
return nil | |
}, | |
func(_ *cliState) string { return cliCtx.tableDisplayFormat.String() }, | |
- 0, | |
}, | |
`echo`: { | |
"show SQL queries before they are sent to the server", | |
@@ -271,7 +271,6 @@ var options = map[string]struct { | |
func(_ *cliState, _ string) error { sqlCtx.echo = true; return nil }, | |
func(_ *cliState) error { sqlCtx.echo = false; return nil }, | |
func(_ *cliState) string { return strconv.FormatBool(sqlCtx.echo) }, | |
- 0, | |
}, | |
`errexit`: { | |
"exit the shell upon a query error", | |
@@ -280,7 +279,6 @@ var options = map[string]struct { | |
func(c *cliState, _ string) error { c.errExit = true; return nil }, | |
func(c *cliState) error { c.errExit = false; return nil }, | |
func(c *cliState) string { return strconv.FormatBool(c.errExit) }, | |
- 0, | |
}, | |
`check_syntax`: { | |
"check the SQL syntax before running a query (needs SHOW SYNTAX support on the server)", | |
@@ -289,7 +287,6 @@ var options = map[string]struct { | |
func(c *cliState, _ string) error { c.checkSyntax = true; return nil }, | |
func(c *cliState) error { c.checkSyntax = false; return nil }, | |
func(c *cliState) string { return strconv.FormatBool(c.checkSyntax) }, | |
- 0, | |
}, | |
`show_times`: { | |
"display the execution time after each query", | |
@@ -298,10 +295,9 @@ var options = map[string]struct { | |
func(_ *cliState, _ string) error { cliCtx.showTimes = true; return nil }, | |
func(_ *cliState) error { cliCtx.showTimes = false; return nil }, | |
func(_ *cliState) string { return strconv.FormatBool(cliCtx.showTimes) }, | |
- 0, | |
}, | |
`prompt1`: { | |
- "prompt string to use before each command (the following are expanded: %M full host, %m host, %> port number, %n user, %? database, %x txn status)", | |
+ "prompt string to use before each command (the following are expanded: %M full host, %m host, %> port number, %n user, %/ database, %x txn status)", | |
false, | |
true, | |
func(c *cliState, val string) error { | |
@@ -313,7 +309,6 @@ var options = map[string]struct { | |
return nil | |
}, | |
func(c *cliState) string { return c.customPromptPattern }, | |
- cliRefreshPrompts, | |
}, | |
} | |
@@ -389,10 +384,6 @@ func (c *cliState) handleSet(args []string, nextState, errState cliStateEnum) cl | |
return errState | |
} | |
- if opt.nextState != 0 { | |
- return opt.nextState | |
- } | |
- | |
return nextState | |
} | |
@@ -412,9 +403,6 @@ func (c *cliState) handleUnset(args []string, nextState, errState cliStateEnum) | |
fmt.Fprintf(stderr, "\\unset %s: %v\n", args[0], err) | |
return errState | |
} | |
- if opt.nextState != 0 { | |
- return opt.nextState | |
- } | |
return nextState | |
} | |
@@ -535,7 +523,7 @@ func (c *cliState) pipeSyscmd(line string, nextState, errState cliStateEnum) cli | |
} | |
// rePromptFmt: available keys compile with regex expression one time. | |
-var rePromptFmt = regexp.MustCompile("%(.)") | |
+var rePromptFmt = regexp.MustCompile("(%.)") | |
// doRefreshPrompts refreshes the prompts of the client depending on the | |
// status of the current transaction. | |
@@ -553,6 +541,19 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { | |
return nextState | |
} | |
+ // Prepare variables for use during the substitution below. | |
+ c.refreshTransactionStatus() | |
+ // refreshDatabaseName() must be called *after* refreshTransactionStatus(), | |
+ // even when %/ appears before %x in the prompt format. | |
+ dbName, hasDbName := c.refreshDatabaseName() | |
+ if !hasDbName { | |
+ dbName = "?" | |
+ } | |
+ userName := "" | |
+ if parsedURL.User != nil { | |
+ userName = parsedURL.User.Username() | |
+ } | |
+ | |
c.fullPrompt = rePromptFmt.ReplaceAllStringFunc(c.customPromptPattern, func(m string) string { | |
switch m { | |
case "%M": | |
@@ -561,29 +562,22 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { | |
return parsedURL.Hostname() // host name. | |
case "%>": | |
return parsedURL.Port() // port. | |
- case "%n": | |
- userName := "" | |
- if parsedURL.User != nil { // user name. | |
- userName = parsedURL.User.Username() | |
- } | |
+ case "%n": // user name. | |
return userName | |
- case "%?": | |
- dbName, hasDbName := c.refreshDatabaseName() // database name. | |
- if hasDbName { | |
- return dbName | |
- } | |
+ case "%/": // database name. | |
+ return dbName | |
case "%x": // txn status. | |
- c.refreshTransactionStatus() | |
return c.lastKnownTxnStatus | |
case "%%": | |
return "%" | |
- // default: | |
- // err = fmt.Errorf("unrecognized format code in prompt: %q", m) | |
- // return err.Error() | |
+ default: | |
+ err = fmt.Errorf("unrecognized format code in prompt: %q", m) | |
+ return "" | |
} | |
- | |
- return m | |
}) | |
+ if err != nil { | |
+ c.fullPrompt = err.Error() | |
+ } | |
c.fullPrompt += " " | |
@@ -594,6 +588,16 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { | |
c.continuePrompt = strings.Repeat(" ", len(c.fullPrompt)-3) + "-> " | |
} | |
+ switch c.useContinuePrompt { | |
+ case true: | |
+ c.currentPrompt = c.continuePrompt | |
+ case false: | |
+ c.currentPrompt = c.fullPrompt | |
+ } | |
+ | |
+ // Configure the editor to use the new prompt. | |
+ c.ins.SetLeftPrompt(c.currentPrompt) | |
+ | |
return nextState | |
} | |
@@ -756,23 +760,14 @@ func (c *cliState) doStartLine(nextState cliStateEnum) cliStateEnum { | |
c.atEOF = false | |
c.partialLines = c.partialLines[:0] | |
c.partialStmtsLen = 0 | |
- | |
- if c.hasEditor() { | |
- c.currentPrompt = c.fullPrompt | |
- c.ins.SetLeftPrompt(c.currentPrompt) | |
- } | |
+ c.useContinuePrompt = false | |
return nextState | |
} | |
func (c *cliState) doContinueLine(nextState cliStateEnum) cliStateEnum { | |
c.atEOF = false | |
- | |
- if c.hasEditor() { | |
- c.currentPrompt = c.continuePrompt | |
- c.ins.SetLeftPrompt(c.currentPrompt) | |
- } | |
- | |
+ c.useContinuePrompt = true | |
return nextState | |
} | |
@@ -1164,16 +1159,16 @@ func runInteractive(conn *sqlConn) (exitErr error) { | |
c.buf = bufio.NewReader(stdin) | |
} | |
- state = c.doStart(cliRefreshPrompts) | |
- | |
- case cliRefreshPrompts: | |
- state = c.doRefreshPrompts(cliStartLine) | |
+ state = c.doStart(cliStartLine) | |
case cliStartLine: | |
- state = c.doStartLine(cliReadLine) | |
+ state = c.doStartLine(cliRefreshPrompt) | |
case cliContinueLine: | |
- state = c.doContinueLine(cliReadLine) | |
+ state = c.doContinueLine(cliRefreshPrompt) | |
+ | |
+ case cliRefreshPrompt: | |
+ state = c.doRefreshPrompts(cliReadLine) | |
case cliReadLine: | |
state = c.doReadLine(cliDecidePath) | |
@@ -1182,21 +1177,21 @@ func runInteractive(conn *sqlConn) (exitErr error) { | |
state = c.doDecidePath() | |
case cliProcessFirstLine: | |
- state = c.doProcessFirstLine(cliRefreshPrompts, cliHandleCliCmd) | |
+ state = c.doProcessFirstLine(cliStartLine, cliHandleCliCmd) | |
case cliHandleCliCmd: | |
- state = c.doHandleCliCmd(cliReadLine, cliPrepareStatementLine) | |
+ state = c.doHandleCliCmd(cliRefreshPrompt, cliPrepareStatementLine) | |
case cliPrepareStatementLine: | |
state = c.doPrepareStatementLine( | |
- cliRefreshPrompts, cliContinueLine, cliCheckStatement, cliRunStatement, | |
+ cliStartLine, cliContinueLine, cliCheckStatement, cliRunStatement, | |
) | |
case cliCheckStatement: | |
- state = c.doCheckStatement(cliRefreshPrompts, cliContinueLine, cliRunStatement) | |
+ state = c.doCheckStatement(cliStartLine, cliContinueLine, cliRunStatement) | |
case cliRunStatement: | |
- state = c.doRunStatement(cliRefreshPrompts) | |
+ state = c.doRunStatement(cliStartLine) | |
default: | |
panic(fmt.Sprintf("unknown state: %d", state)) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.