Skip to content

Instantly share code, notes, and snippets.

@knz
Last active July 17, 2018 12:38
Show Gist options
  • Save knz/04c44b4d609df0e073d9f53bcf30bd68 to your computer and use it in GitHub Desktop.
Save knz/04c44b4d609df0e073d9f53bcf30bd68 to your computer and use it in GitHub Desktop.
diff --git a/pkg/cli/sql.go b/pkg/cli/sql.go
index 5957fae4bb..160f5edeec 100644
--- a/pkg/cli/sql.go
+++ b/pkg/cli/sql.go
@@ -53,7 +53,7 @@ const (
`
)
-const defaultPromptPattern = "%n@%M/%?%x>"
+const defaultPromptPattern = "%n@%M/%/%x>"
// sqlShellCmd opens a sql shell.
var sqlShellCmd = &cobra.Command{
@@ -86,6 +86,8 @@ type cliState struct {
fullPrompt string
// The prompt on a continuation line in a multi-line entry.
continuePrompt string
+ // Whether currentPrompt is set to continuePrompt (true) or fullPrompt (false).
+ useContinuePrompt bool
// The current prompt, either fullPrompt or continuePrompt.
currentPrompt string
// The string used to produce the value of fullPrompt.
@@ -137,7 +139,7 @@ const (
// Querying the server for the current transaction status
// and setting the prompt accordingly.
- cliRefreshPrompts
+ cliRefreshPrompt
// Just before reading the first line of a potentially multi-line
// statement.
@@ -243,8 +245,7 @@ var options = map[string]struct {
set func(c *cliState, val string) error
reset func(c *cliState) error
// display is used to retrieve the current value.
- display func(c *cliState) string
- nextState cliStateEnum
+ display func(c *cliState) string
}{
`display_format`: {
"the output format for tabular data (pretty, csv, tsv, html, sql, records, raw)",
@@ -262,7 +263,6 @@ var options = map[string]struct {
return nil
},
func(_ *cliState) string { return cliCtx.tableDisplayFormat.String() },
- 0,
},
`echo`: {
"show SQL queries before they are sent to the server",
@@ -271,7 +271,6 @@ var options = map[string]struct {
func(_ *cliState, _ string) error { sqlCtx.echo = true; return nil },
func(_ *cliState) error { sqlCtx.echo = false; return nil },
func(_ *cliState) string { return strconv.FormatBool(sqlCtx.echo) },
- 0,
},
`errexit`: {
"exit the shell upon a query error",
@@ -280,7 +279,6 @@ var options = map[string]struct {
func(c *cliState, _ string) error { c.errExit = true; return nil },
func(c *cliState) error { c.errExit = false; return nil },
func(c *cliState) string { return strconv.FormatBool(c.errExit) },
- 0,
},
`check_syntax`: {
"check the SQL syntax before running a query (needs SHOW SYNTAX support on the server)",
@@ -289,7 +287,6 @@ var options = map[string]struct {
func(c *cliState, _ string) error { c.checkSyntax = true; return nil },
func(c *cliState) error { c.checkSyntax = false; return nil },
func(c *cliState) string { return strconv.FormatBool(c.checkSyntax) },
- 0,
},
`show_times`: {
"display the execution time after each query",
@@ -298,10 +295,9 @@ var options = map[string]struct {
func(_ *cliState, _ string) error { cliCtx.showTimes = true; return nil },
func(_ *cliState) error { cliCtx.showTimes = false; return nil },
func(_ *cliState) string { return strconv.FormatBool(cliCtx.showTimes) },
- 0,
},
`prompt1`: {
- "prompt string to use before each command (the following are expanded: %M full host, %m host, %> port number, %n user, %? database, %x txn status)",
+ "prompt string to use before each command (the following are expanded: %M full host, %m host, %> port number, %n user, %/ database, %x txn status)",
false,
true,
func(c *cliState, val string) error {
@@ -313,7 +309,6 @@ var options = map[string]struct {
return nil
},
func(c *cliState) string { return c.customPromptPattern },
- cliRefreshPrompts,
},
}
@@ -389,10 +384,6 @@ func (c *cliState) handleSet(args []string, nextState, errState cliStateEnum) cl
return errState
}
- if opt.nextState != 0 {
- return opt.nextState
- }
-
return nextState
}
@@ -412,9 +403,6 @@ func (c *cliState) handleUnset(args []string, nextState, errState cliStateEnum)
fmt.Fprintf(stderr, "\\unset %s: %v\n", args[0], err)
return errState
}
- if opt.nextState != 0 {
- return opt.nextState
- }
return nextState
}
@@ -535,7 +523,7 @@ func (c *cliState) pipeSyscmd(line string, nextState, errState cliStateEnum) cli
}
// rePromptFmt: available keys compile with regex expression one time.
-var rePromptFmt = regexp.MustCompile("%(.)")
+var rePromptFmt = regexp.MustCompile("(%.)")
// doRefreshPrompts refreshes the prompts of the client depending on the
// status of the current transaction.
@@ -553,6 +541,19 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum {
return nextState
}
+ // Prepare variables for use during the substitution below.
+ c.refreshTransactionStatus()
+ // refreshDatabaseName() must be called *after* refreshTransactionStatus(),
+ // even when %/ appears before %x in the prompt format.
+ dbName, hasDbName := c.refreshDatabaseName()
+ if !hasDbName {
+ dbName = "?"
+ }
+ userName := ""
+ if parsedURL.User != nil {
+ userName = parsedURL.User.Username()
+ }
+
c.fullPrompt = rePromptFmt.ReplaceAllStringFunc(c.customPromptPattern, func(m string) string {
switch m {
case "%M":
@@ -561,29 +562,22 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum {
return parsedURL.Hostname() // host name.
case "%>":
return parsedURL.Port() // port.
- case "%n":
- userName := ""
- if parsedURL.User != nil { // user name.
- userName = parsedURL.User.Username()
- }
+ case "%n": // user name.
return userName
- case "%?":
- dbName, hasDbName := c.refreshDatabaseName() // database name.
- if hasDbName {
- return dbName
- }
+ case "%/": // database name.
+ return dbName
case "%x": // txn status.
- c.refreshTransactionStatus()
return c.lastKnownTxnStatus
case "%%":
return "%"
- // default:
- // err = fmt.Errorf("unrecognized format code in prompt: %q", m)
- // return err.Error()
+ default:
+ err = fmt.Errorf("unrecognized format code in prompt: %q", m)
+ return ""
}
-
- return m
})
+ if err != nil {
+ c.fullPrompt = err.Error()
+ }
c.fullPrompt += " "
@@ -594,6 +588,16 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum {
c.continuePrompt = strings.Repeat(" ", len(c.fullPrompt)-3) + "-> "
}
+ switch c.useContinuePrompt {
+ case true:
+ c.currentPrompt = c.continuePrompt
+ case false:
+ c.currentPrompt = c.fullPrompt
+ }
+
+ // Configure the editor to use the new prompt.
+ c.ins.SetLeftPrompt(c.currentPrompt)
+
return nextState
}
@@ -756,23 +760,14 @@ func (c *cliState) doStartLine(nextState cliStateEnum) cliStateEnum {
c.atEOF = false
c.partialLines = c.partialLines[:0]
c.partialStmtsLen = 0
-
- if c.hasEditor() {
- c.currentPrompt = c.fullPrompt
- c.ins.SetLeftPrompt(c.currentPrompt)
- }
+ c.useContinuePrompt = false
return nextState
}
func (c *cliState) doContinueLine(nextState cliStateEnum) cliStateEnum {
c.atEOF = false
-
- if c.hasEditor() {
- c.currentPrompt = c.continuePrompt
- c.ins.SetLeftPrompt(c.currentPrompt)
- }
-
+ c.useContinuePrompt = true
return nextState
}
@@ -1164,16 +1159,16 @@ func runInteractive(conn *sqlConn) (exitErr error) {
c.buf = bufio.NewReader(stdin)
}
- state = c.doStart(cliRefreshPrompts)
-
- case cliRefreshPrompts:
- state = c.doRefreshPrompts(cliStartLine)
+ state = c.doStart(cliStartLine)
case cliStartLine:
- state = c.doStartLine(cliReadLine)
+ state = c.doStartLine(cliRefreshPrompt)
case cliContinueLine:
- state = c.doContinueLine(cliReadLine)
+ state = c.doContinueLine(cliRefreshPrompt)
+
+ case cliRefreshPrompt:
+ state = c.doRefreshPrompts(cliReadLine)
case cliReadLine:
state = c.doReadLine(cliDecidePath)
@@ -1182,21 +1177,21 @@ func runInteractive(conn *sqlConn) (exitErr error) {
state = c.doDecidePath()
case cliProcessFirstLine:
- state = c.doProcessFirstLine(cliRefreshPrompts, cliHandleCliCmd)
+ state = c.doProcessFirstLine(cliStartLine, cliHandleCliCmd)
case cliHandleCliCmd:
- state = c.doHandleCliCmd(cliReadLine, cliPrepareStatementLine)
+ state = c.doHandleCliCmd(cliRefreshPrompt, cliPrepareStatementLine)
case cliPrepareStatementLine:
state = c.doPrepareStatementLine(
- cliRefreshPrompts, cliContinueLine, cliCheckStatement, cliRunStatement,
+ cliStartLine, cliContinueLine, cliCheckStatement, cliRunStatement,
)
case cliCheckStatement:
- state = c.doCheckStatement(cliRefreshPrompts, cliContinueLine, cliRunStatement)
+ state = c.doCheckStatement(cliStartLine, cliContinueLine, cliRunStatement)
case cliRunStatement:
- state = c.doRunStatement(cliRefreshPrompts)
+ state = c.doRunStatement(cliStartLine)
default:
panic(fmt.Sprintf("unknown state: %d", state))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment