Skip to content

Instantly share code, notes, and snippets.

@tobstarr
Created July 22, 2013 10:15
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
  • Save tobstarr/6052811 to your computer and use it in GitHub Desktop.
Save tobstarr/6052811 to your computer and use it in GitHub Desktop.
Golang SSH client

Golang SSH client

Requirements

This is a proof of concept that uses public-key authentication through a running SSH agent. The agent's socket path must be set in the SSH_AUTH_SOCK environment variable.

Usage

$ make
$ ./bin/test_ssh_client -h 192.168.0.13 -l root -c "ps aux | head -n 5"
package main
import (
"fmt"
"time"
)
var LogColors = map[int]int {
DEBUG: 102,
INFO: 28,
WARN: 214,
ERROR: 196,
}
// TIME_FORMAT is the time.Format layout for log timestamps: an ISO-8601-style
// date and time with microsecond precision and no zone suffix.
const (
	TIME_FORMAT = "2006-01-02T15:04:05.000000"
)
// colorize wraps s in ANSI escape sequences that set the foreground to
// palette color c (256-color mode) and then reset all attributes.
func colorize(c int, s string) (r string) {
	const esc = "\x1b["
	r = esc + "38;5;" + fmt.Sprint(c) + "m" + s + esc + "0m"
	return r
}
// Log levels in increasing order of severity. Logger only emits a message
// whose level is >= its LogLevel, so a lower value here means "more verbose".
const (
	DEBUG = 0
	INFO  = 1
	WARN  = 2
	ERROR = 3
)
var LogPrefixes = map[int]string {
DEBUG: "DEBUG",
INFO: "INFO ",
WARN: "WARN ",
ERROR: "ERROR",
}
// Logger prints leveled, timestamped, colorized lines to standard output.
//
// A message is printed only when its level is >= LogLevel (the zero value,
// DEBUG, prints everything). When Prefix is non-empty it is shown in square
// brackets between the timestamp and the level label.
type Logger struct {
	LogLevel int
	Prefix   string
}

// Debugf logs a printf-style message at DEBUG level.
func (l *Logger) Debugf(format string, n ...interface{}) { l.Logf(DEBUG, format, n...) }

// Infof logs a printf-style message at INFO level.
func (l *Logger) Infof(format string, n ...interface{}) { l.Logf(INFO, format, n...) }

// Warnf logs a printf-style message at WARN level.
func (l *Logger) Warnf(format string, n ...interface{}) { l.Logf(WARN, format, n...) }

// Errorf logs a printf-style message at ERROR level.
func (l *Logger) Errorf(format string, n ...interface{}) { l.Logf(ERROR, format, n...) }

// Logf formats s with n (as fmt.Sprintf does) and prints the result after
// the line prefix, unless level is below l.LogLevel.
func (l *Logger) Logf(level int, s string, n ...interface{}) {
	if level < l.LogLevel {
		return
	}
	// The prefix (and therefore the timestamp) is computed before the message
	// is formatted.
	prefix := l.LogPrefix(level)
	fmt.Println(prefix, fmt.Sprintf(s, n...))
}

// Debug logs its operands, space-separated, at DEBUG level.
func (l *Logger) Debug(n ...interface{}) { l.Log(DEBUG, n...) }

// Info logs its operands, space-separated, at INFO level.
func (l *Logger) Info(n ...interface{}) { l.Log(INFO, n...) }

// Warn logs its operands, space-separated, at WARN level.
func (l *Logger) Warn(n ...interface{}) { l.Log(WARN, n...) }

// Error logs its operands, space-separated, at ERROR level.
func (l *Logger) Error(n ...interface{}) { l.Log(ERROR, n...) }

// LogPrefix builds the start of a log line for level i:
// "<timestamp> [<Prefix>] <colored label>", omitting the bracketed part when
// l.Prefix is empty.
func (l *Logger) LogPrefix(i int) (s string) {
	stamp := time.Now().Format(TIME_FORMAT)
	if l.Prefix == "" {
		return stamp + " " + l.LogLevelPrefix(i)
	}
	return stamp + " [" + l.Prefix + "] " + l.LogLevelPrefix(i)
}

// LogLevelPrefix returns the label for level, colorized with that level's
// color. An unknown level yields the zero color and an empty label.
func (l *Logger) LogLevelPrefix(level int) (s string) {
	return colorize(LogColors[level], LogPrefixes[level])
}

// Log prints the line prefix followed by n, using fmt.Println spacing rules,
// unless level is below l.LogLevel.
func (l *Logger) Log(level int, n ...interface{}) {
	if level < l.LogLevel {
		return
	}
	line := make([]interface{}, 0, len(n)+1)
	line = append(line, l.LogPrefix(level))
	line = append(line, n...)
	fmt.Println(line...)
}
# default fetches the SSH library and builds every .go file in this directory
# into bin/test_ssh_client. The leading "@" keeps make from echoing the
# `go get` command.
# NOTE(review): code.google.com project hosting has been shut down, so this
# import path most likely no longer resolves — verify. The maintained package
# is golang.org/x/crypto/ssh, but its API appears to differ (no
# ssh.ClientAuth / ssh.ClientConn), so switching requires porting the Go
# sources as well, not just this line.
default:
	@go get code.google.com/p/go.crypto/ssh
	go build -o bin/test_ssh_client *.go
package main
import (
"strings"
"time"
"code.google.com/p/go.crypto/ssh"
"os"
"net"
"bytes"
)
// NewSSHClient returns a client for user@host that is not yet connected.
// Its logger tags every line with host. No network I/O happens here; the
// connection is opened by Connect (or by ConnectWhenNotConnected/Execute).
func NewSSHClient(host, user string) (c *SSHClient) {
	logger := &Logger{Prefix: host}
	return &SSHClient{Host: host, User: user, Logger: logger}
}
// SSHClient runs commands on one remote host over SSH, authenticating with
// keys from the local SSH agent.
type SSHClient struct {
	// User is the remote login name; Host is the remote host name or address.
	User, Host string
	// Agent is the socket to the local SSH agent, opened by Connect.
	Agent net.Conn
	// Conn is the established SSH connection, opened by Connect.
	Conn *ssh.ClientConn
	// Logger receives progress messages and the command's output lines.
	Logger *Logger
}
// Close releases the SSH connection and the SSH agent socket, if they are
// open. It then clears both fields, so a later ConnectWhenNotConnected dials
// afresh instead of treating the closed connection as still usable. Calling
// Close more than once is therefore a no-op after the first call.
//
// Errors from the underlying Close calls are ignored: this is best-effort
// cleanup, and the signature has no way to report them.
func (c *SSHClient) Close() {
	if c.Conn != nil {
		c.Conn.Close()
		c.Conn = nil
	}
	if c.Agent != nil {
		c.Agent.Close()
		c.Agent = nil
	}
}
// ConnectWhenNotConnected calls Connect unless a connection is already
// present, in which case it only logs that fact and returns nil.
func (c *SSHClient) ConnectWhenNotConnected() (e error) {
	if c.Conn == nil {
		return c.Connect()
	}
	c.Logger.Debug("already connected")
	return nil
}
// Connect dials c.Host on TCP port 22 and logs in as c.User. It offers
// whatever keys the SSH agent listening on $SSH_AUTH_SOCK provides.
//
// If the agent socket cannot be opened, the failure is logged at DEBUG
// level. The dial is still attempted, but without agent authentication, so
// the server will likely reject it.
// On success c.Conn holds the connection and, if the agent was reachable,
// c.Agent holds the agent socket (released by Close).
func (c *SSHClient) Connect() (e error) {
	c.Logger.Debug("connecting " + c.Host)

	var auths []ssh.ClientAuth
	agent, agentErr := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
	if agentErr == nil {
		c.Agent = agent
		auths = append(auths, ssh.ClientAuthAgent(ssh.NewAgentClient(agent)))
	} else {
		c.Logger.Debug("ssh agent not available:", agentErr)
	}

	config := &ssh.ClientConfig{
		// Use the configured login; this was previously hard-coded to "root",
		// which silently ignored the user passed to NewSSHClient.
		User: c.User,
		Auth: auths,
	}
	// JoinHostPort brackets IPv6 literals ("[::1]:22"), which plain
	// concatenation would mangle.
	c.Conn, e = ssh.Dial("tcp", net.JoinHostPort(c.Host, "22"), config)
	return e
}
// LogWriter is an io.Writer that keeps a verbatim copy of everything written
// to it in Buffer. When LogTo is set, it also forwards each non-blank line of
// every write to LogTo, with surrounding whitespace trimmed.
//
// Lines are split per Write call: a line that arrives across two writes is
// forwarded as two fragments.
type LogWriter struct {
	LogTo  func(n ...interface{})
	Buffer bytes.Buffer
}

// String returns everything written so far, unmodified.
func (w *LogWriter) String() (s string) {
	s = w.Buffer.String()
	return
}

// Write forwards the non-blank lines of b to LogTo (if set), appends b to the
// buffer and reports that all of b was consumed. It never returns an error.
func (w *LogWriter) Write(b []byte) (i int, e error) {
	if forward := w.LogTo; forward != nil {
		lines := strings.Split(string(b), "\n")
		for idx := range lines {
			if line := strings.TrimSpace(lines[idx]); line != "" {
				forward(line)
			}
		}
	}
	w.Buffer.Write(b)
	return len(b), nil
}
// Execute runs command s in a new session on the remote host, connecting
// first if necessary. Stdout lines are mirrored to the logger at DEBUG level
// and stderr lines at ERROR level. The complete output is kept in the
// returned SSHResult.
//
// The returned error reports failure to connect or to open a session; in
// that case r is nil. The error from Session.Run (the remote command's own
// outcome) is not returned here but stored in r.Error.
func (c *SSHClient) Execute(s string) (r *SSHResult, e error) {
	started := time.Now()
	// Previously this error was ignored, and a failed connect then panicked
	// on the nil c.Conn below.
	if e = c.ConnectWhenNotConnected(); e != nil {
		return nil, e
	}
	ses, e := c.Conn.NewSession()
	if e != nil {
		return nil, e
	}
	defer ses.Close()

	r = &SSHResult{
		StdoutBuffer: &LogWriter{LogTo: c.Logger.Debug},
		StderrBuffer: &LogWriter{LogTo: c.Logger.Error},
	}
	ses.Stdout = r.StdoutBuffer
	ses.Stderr = r.StderrBuffer

	c.Logger.Infof(`executing command "%s"`, s)
	r.Error = ses.Run(s)
	r.Runtime = time.Since(started)
	c.Logger.Debugf("executed in %.06f", r.Runtime.Seconds())
	return r, nil
}
package main
import (
"time"
"fmt"
)
// SSHResult holds the captured output and timing of one command run by
// SSHClient.Execute.
type SSHResult struct {
	// StdoutBuffer and StderrBuffer capture the command's two output streams.
	StdoutBuffer, StderrBuffer *LogWriter
	// Runtime is the wall-clock time Execute spent on the command.
	Runtime time.Duration
	// Error is whatever the session's Run returned (nil on success).
	Error error
}

// Stdout returns everything the command wrote to standard output.
func (r *SSHResult) Stdout() string {
	return r.StdoutBuffer.String()
}

// String summarizes the result. It reports the byte counts of stdout and
// stderr and the runtime in seconds, rendered as a map, e.g.
// "map[runtime:0.123456 stderr:0 bytes stdout:42 bytes]".
func (r *SSHResult) String() string {
	m := map[string]string{
		// Buffer.Len gives the same count as len(Buffer.String()) without
		// copying the whole captured output.
		"stdout":  fmt.Sprintf("%d bytes", r.StdoutBuffer.Buffer.Len()),
		"stderr":  fmt.Sprintf("%d bytes", r.StderrBuffer.Buffer.Len()),
		"runtime": fmt.Sprintf("%0.6f", r.Runtime.Seconds()),
	}
	return fmt.Sprintf("%+v", m)
}
package main
import (
"fmt"
"flag"
"os"
)
// ExitWith prints "ERROR: <reason>" followed by the flag usage and terminates
// the process with exit status 1; it never returns.
//
// The message goes to standard error, the same stream flag.PrintDefaults uses
// by default, so diagnostics are not mixed into the program's normal output.
func ExitWith(reason string) {
	fmt.Fprintln(os.Stderr, "ERROR:", reason)
	flag.PrintDefaults()
	os.Exit(1)
}
// main parses -h (host), -l (login) and -c (command), runs the command on the
// remote host over SSH and logs a summary of the result.
//
// The process exits with status 1 in three cases: a required flag is
// missing, the connection or session cannot be set up, or the remote command
// itself returns an error. Previously the last two cases exited 0 silently.
func main() {
	host := flag.String("h", "", "Host")
	login := flag.String("l", "", "Login")
	cmd := flag.String("c", "", "Command to execute")
	flag.Parse()

	// ExitWith never returns, so only the first missing flag is reported.
	switch {
	case *host == "":
		ExitWith("host must be provided")
	case *login == "":
		ExitWith("login must be provided")
	case *cmd == "":
		ExitWith("cmd must be provided")
	}

	// os.Exit skips deferred calls, so the work lives in runRemoteCommand,
	// whose deferred client.Close runs before we exit.
	os.Exit(runRemoteCommand(*host, *login, *cmd))
}

// runRemoteCommand executes cmd on host as login and returns the exit status
// for the process. It returns 0 on success and 1 if the command could not be
// started or returned an error.
func runRemoteCommand(host, login, cmd string) int {
	client := NewSSHClient(host, login)
	defer client.Close()

	rsp, err := client.Execute(cmd)
	if err != nil {
		client.Logger.Error("executing command:", err)
		return 1
	}
	client.Logger.Info("response:", rsp.String())
	if rsp.Error != nil {
		client.Logger.Error("remote command failed:", rsp.Error)
		return 1
	}
	return 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment