Skip to content

Instantly share code, notes, and snippets.

@caelifer caelifer/ssh_client.go
Last active Dec 15, 2016

Embed
What would you like to do?
package main
import (
"errors"
"flag"
"io"
"io/ioutil"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
"time"
"github.com/mattn/go-isatty"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
log "github.com/golang/glog"
)
// Command-line state shared between init and main.
var (
// remoteCommand is the value of the -exec flag; main wraps it in a shell
// pipeline before it is started on each host.
remoteCommand string
// username is the value of the -user flag, used as the SSH login name.
username string
// hosts holds the positional arguments (flag.Args()) assigned in main.
hosts []string
)
// init registers the command-line flags.
//
// The -user flag defaults to the empty string so that main can detect "not
// set" and fall back to the current OS user, as the flag's help text
// promises. A non-empty placeholder default would make that fallback
// unreachable.
func init() {
	flag.StringVar(&remoteCommand, "exec", "", "Command(s) to execute on the remote server(s)")
	flag.StringVar(&username, "user", "", "Use this name for remote SSH connection instead of current user.")
}
func main() {
flag.Parse()
if remoteCommand == "" {
remoteCommand = `/bin/bash -s 2>&1 | sed "s/^/$(hostname): /"`
} else {
remoteCommand = "(" + remoteCommand + `) 2>&1 | sed "s/^/$(hostname): /"`
}
if hosts = flag.Args(); len(hosts) == 0 {
// usage
os.Exit(1)
}
user, err := user.Current()
if err != nil {
log.Fatalf("Unable to get user info: %v", err)
}
if username == "" {
// Get user name
username = user.Username
log.Warningf("Set username: %v", username)
}
// Private key patterns
keyPat := filepath.Join(user.HomeDir, ".ssh", "id_*")
log.Infof("PK file pattern: %s", keyPat)
// Search for pk files on Unix-like file system
pkFiles, err := getPKFiles(keyPat)
if err != nil {
// Try the Windows paths
keyPat = filepath.Join("F:\\", "_ssh", "id_*")
log.Infof("PK file pattern on Windows: %s", keyPat)
// Search for pk files
pkFiles, err = getPKFiles(keyPat)
if err != nil {
log.Fatalf("Unable to parse private key files: %v", err)
}
}
// Add public keys signers
signers, err := getSignersFromPK(pkFiles)
if err != nil {
log.Fatalf("Unable to read signers: %v", err)
}
// Add SSH Agent signers
sock := os.Getenv("SSH_AUTH_SOCK")
if sock != "" {
log.V(2).Info("Found set SSH_AUTH_SOCK environment viriable")
log.V(2).Infof("Loading creds from SSH Agent socket: %s", sock)
if ss, err := getSignersFromAgent(sock); err == nil {
signers = append(signers, ss...)
}
}
if len(signers) == 0 {
log.Fatal("Did not find valid authentication methods for non-interactive SSH session")
}
// Create SSH client config
sshConfig := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signers...),
},
Timeout: 10 * time.Second,
}
var wg sync.WaitGroup
stdins := make([]io.Writer, 0, len(hosts))
port := "22" // XXX hardcode port number for now.
for _, host := range hosts {
// Create connection to the remote host
log.V(2).Infof("Connecting via SSH: host - %s:%s, user - %s", host, port, username)
conn, err := ssh.Dial("tcp", host+":"+port, sshConfig)
if err != nil {
log.Errorf("Failed to connect to %s:%s: %v", host, port, err)
continue
}
defer conn.Close()
// Forward Agent if available
if agt.Agent != nil {
if err := agent.ForwardToAgent(conn, agt.Agent); err != nil {
log.Errorf("Failed to forward to agent: host - %s:%s: %v", host, port, err)
continue
}
}
// Create SSH session
sess, err := conn.NewSession()
if err != nil {
log.Errorf("Failed to open SSH session to %s:%s: %v", host, port, err)
continue
}
defer sess.Close()
log.V(2).Infof("SSH Session state for %s:%s: openned", host, port)
// Enable Agent Forwarding
if agt.Agent != nil {
if err := agent.RequestAgentForwarding(sess); err != nil {
log.Errorf("Failed to agent forward: host - %s:%s: %v", host, port, err)
continue
}
}
// Request PTY
// err = sess.RequestPty("xterm", 80, 24, ssh.TerminalModes{})
// Connect stdout and stderr
sess.Stdout = os.Stdout
sess.Stderr = os.Stderr
// Add stdin pipe
stdin, err := sess.StdinPipe()
if err != nil {
log.Warningf("Unable to open pipe for SSH session %s:%s: %v", host, port, err)
continue
}
stdins = append(stdins, stdin)
// Set remote environment
// err = sess.Setenv("LC_JAVA_HOME", "/usr/java/latest")
wg.Add(1)
// Start login shell
go func(host string) {
defer wg.Done()
log.V(2).Infof("Running remote command on %s:%s: %s", host, port, remoteCommand)
// Execute command on multiple boxes at the same time
err := sess.Start(remoteCommand)
if err != nil {
log.Warningf("SSH Session state for %s:%s: error: %v", host, port, err)
return
}
err = sess.Wait()
if err != nil {
log.Warningf("SSH Session state for %s:%s: error: %v", host, port, err)
return
}
log.V(2).Infof("SSH Session state for %s:%s: closed", host, port)
}(host)
}
// Close all open pipes
if !isatty.IsTerminal(os.Stdin.Fd()) {
io.Copy(io.MultiWriter(stdins...), os.Stdin)
for _, in := range stdins {
in.(io.Closer).Close()
}
}
wg.Wait()
}
// getPKFiles expands the glob pattern pat and returns every match that does
// not end in ".pub".
//
// It returns the error from filepath.Glob if the pattern cannot be expanded,
// and a "no files found" error if the pattern matches nothing at all. When
// matches exist but every one of them ends in ".pub", the result is an empty
// slice with a nil error.
func getPKFiles(pat string) ([]string, error) {
	matches, err := filepath.Glob(pat)
	switch {
	case err != nil:
		return nil, err
	case len(matches) == 0:
		return nil, errors.New("no files found")
	}

	// Filter in place, reusing the backing array of the glob result.
	kept := matches[:0]
	for i := range matches {
		if strings.HasSuffix(matches[i], ".pub") {
			// Public-key half of a key pair; skip it.
			continue
		}
		kept = append(kept, matches[i])
	}
	return kept, nil
}
// getSignersFromPK reads every file named in pkFiles and parses it as an SSH
// private key, returning one signer per file that loads successfully.
//
// A file that cannot be read or parsed is skipped, not treated as fatal. The
// warning includes the underlying error so the reason for skipping (for
// example a permission problem or an unparseable key) is visible in the log.
// The returned error is currently always nil.
func getSignersFromPK(pkFiles []string) ([]ssh.Signer, error) {
	signers := make([]ssh.Signer, 0, len(pkFiles))
	for _, pkFile := range pkFiles {
		log.V(2).Infof("Processing private key file: %v", pkFile)

		buf, err := ioutil.ReadFile(pkFile)
		if err != nil {
			log.Warningf("Unable to read PK file: %v: %v", pkFile, err)
			continue
		}

		signer, err := ssh.ParsePrivateKey(buf)
		if err != nil {
			log.Warningf("Unable to parse PK file: %v: %v", pkFile, err)
			continue
		}
		signers = append(signers, signer)
	}
	return signers, nil
}
// agt is package-level state describing the SSH agent connection. It is
// populated by getSignersFromAgent and read by main, which checks
// agt.Agent != nil before setting up agent forwarding.
var agt struct {
// Agent is the agent client; nil until getSignersFromAgent has set it.
Agent agent.Agent
// CleanFn closes the underlying connection to the agent socket.
CleanFn func() error
}
// getSignersFromAgent connects to the SSH agent listening on the Unix socket
// sock and returns the signers it offers.
//
// On success the package-level agt is populated with the agent client and a
// function that closes the socket. On failure the socket is closed and agt is
// left untouched, so callers never see a half-initialised agent.
func getSignersFromAgent(sock string) ([]ssh.Signer, error) {
	conn, err := net.Dial("unix", sock)
	if err != nil {
		log.Warningf("Failed to connect to SSH Agent: %v", err)
		return nil, err
	}

	client := agent.NewClient(conn)
	signers, err := client.Signers()
	if err != nil {
		// The agent is unusable; do not leak the socket or publish the client.
		conn.Close()
		return nil, err
	}

	agt.Agent = client
	agt.CleanFn = conn.Close
	return signers, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.