Created
August 27, 2019 22:01
-
-
Save dreampuf/b224fbb26167c275a32f847d77c3dc33 to your computer and use it in GitHub Desktop.
Golang SSH interactive shell showcase
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"bytes"
	"context"
	"io/ioutil"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"testing"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)
// singleWriter is an io.Writer backed by a bytes.Buffer in which every
// operation is serialized by a mutex. One goroutine (e.g. an SSH session
// copying remote output) can keep writing while another polls and drains the
// buffer. The zero value is ready to use; it must not be copied after first use.
type singleWriter struct {
	b  bytes.Buffer
	mu sync.Mutex
}

// Write appends p to the buffer under the lock.
func (w *singleWriter) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.b.Write(p)
}

// Reset discards all buffered bytes.
func (w *singleWriter) Reset() {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.b.Reset()
}

// String returns a copy of the buffered bytes without consuming them.
func (w *singleWriter) String() string {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.b.String()
}

// Len reports the number of buffered bytes.
func (w *singleWriter) Len() int {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.b.Len()
}

// drain takes and clears the buffered bytes in a single critical section, so
// nothing written concurrently can slip in between the read and the reset.
// ok is false when the buffer was empty.
func (w *singleWriter) drain() (msg string, ok bool) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.b.Len() == 0 {
		return "", false
	}
	msg = w.b.String()
	w.b.Reset()
	return msg, true
}

// WaitUntilMessage returns the buffered output as soon as there is any,
// emptying the buffer. While the buffer is empty it re-checks every 100ms.
// It returns "" if timeoutCtx is done before any data arrives.
func (w *singleWriter) WaitUntilMessage(timeoutCtx context.Context) string {
	if msg, ok := w.drain(); ok {
		return msg
	}
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-timeoutCtx.Done():
			return ""
		case <-ticker.C:
		}
		if msg, ok := w.drain(); ok {
			return msg
		}
	}
}
func BuildClientConfig(username string, auths []ssh.AuthMethod) *ssh.ClientConfig { | |
cfg := &ssh.ClientConfig{ | |
User: username, | |
Auth: auths, | |
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { | |
return nil | |
}, | |
Timeout: 0, | |
} | |
cfg.SetDefaults() | |
return cfg | |
} | |
func BuildSSHAuths(password, identityFile string) ([]ssh.AuthMethod, error) { | |
auths := []ssh.AuthMethod{} | |
if password != "" { | |
auths = append(auths, ssh.Password(password)) | |
} | |
if identityFile != "" { | |
if f, err := os.Open(identityFile); err != nil { | |
return auths, err | |
} else { | |
bts, _ := ioutil.ReadAll(f) | |
_ = f.Close() | |
if signer, err := ssh.ParsePrivateKey(bts); err != nil { | |
return auths, err | |
} else { | |
auths = append(auths, ssh.PublicKeys(signer)) | |
} | |
} | |
} | |
SSH_AUTH_SOCK := os.Getenv("SSH_AUTH_SOCK") | |
if SSH_AUTH_SOCK != "" { | |
c, _ := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) | |
defer c.Close() | |
a := agent.NewClient(c) | |
// get all the pubkeys from the agent | |
agentSigners, _ := a.Signers() | |
auths = append(auths, ssh.PublicKeys(agentSigners...)) | |
} | |
return auths, nil | |
} | |
func Test_Pty(t *testing.T) { | |
ctx, _ := context.WithCancel(context.Background()) | |
dialer := &net.Dialer{ | |
Timeout: time.Duration(10) * time.Second, | |
} | |
var ( | |
conn net.Conn | |
err error | |
) | |
if conn, err = dialer.DialContext(ctx, "tcp", "localhost:22"); err != nil { | |
t.Error(err) | |
return | |
} | |
auths, _ := BuildSSHAuths("", "~/.ssh/id_rsa") | |
sshConn, chans, reqs, err := ssh.NewClientConn(conn, "localhost:22", BuildClientConfig("dreampuf", auths)) | |
if err != nil { | |
_ = conn.Close() | |
t.Error(err) | |
return | |
} | |
client := ssh.NewClient(sshConn, chans, reqs) | |
session, err := client.NewSession() | |
if err != nil { | |
t.Error(err) | |
return | |
} | |
//t.Error(session.Setenv("LC_ALL", "tmp")) | |
modes := ssh.TerminalModes{ | |
ssh.ECHO: 0, // disable echoing | |
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud | |
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud | |
} | |
term := os.Getenv("TERM") | |
if err := session.RequestPty(term, 40, 80, modes); err != nil { | |
t.Error(err) | |
return | |
} | |
var buffer singleWriter | |
stdin, _ := session.StdinPipe() | |
session.Stdout = &buffer | |
session.Stderr = &buffer | |
if err = session.Shell(); err != nil { | |
return | |
} | |
_, err = stdin.Write([]byte("export PS1=\"\"\n")) | |
beforePrepare := buffer.WaitUntilMessage(ctx) | |
t.Log("reset\n", beforePrepare) | |
go func() { | |
for n:=0; n < 3; n ++ { | |
_, err = stdin.Write([]byte("sudo date\n")) | |
if err != nil { | |
t.Error(err) | |
return | |
} | |
time.Sleep(time.Second) | |
} | |
}() | |
for { | |
ctxTO, _ := context.WithTimeout(ctx, time.Second * 2) | |
result := buffer.WaitUntilMessage(ctxTO) | |
if result == "" { | |
break | |
} | |
t.Log(result) | |
} | |
t.Error(session.Close()) | |
t.Error(client.Close()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment