Skip to content

Instantly share code, notes, and snippets.

@rwhelan
Last active March 3, 2023 18:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rwhelan/795dde0d22c069c0f31fdf4cb2ab74da to your computer and use it in GitHub Desktop.
Save rwhelan/795dde0d22c069c0f31fdf4cb2ab74da to your computer and use it in GitHub Desktop.
package main
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net"
	"os"
	"strconv"
	"sync"

	"github.com/bramvdbogaerde/go-scp"
	"golang.org/x/crypto/ssh"
)
// SSHServer is an established SSH connection to a single host. Values are
// created by NewSSHServer; the zero value holds no connection and cannot be
// used to upload files or run commands.
type SSHServer struct {
	// host is the host name/address this server was created for; it is used
	// in error messages and reported back in list results.
	host string
	// sshc is the underlying SSH client connection that every upload and
	// command on this server is issued through.
	sshc *ssh.Client
}

// Close tears down the underlying SSH connection. Because sshc is unexported,
// callers otherwise have no way to release a connection opened by
// NewSSHServer. Calling Close on a zero-value SSHServer is a no-op.
func (s *SSHServer) Close() error {
	if s.sshc == nil {
		return nil
	}
	return s.sshc.Close()
}
// UploadFile copies the file at localPath to remotePath on this server over
// SCP. The permissions string is passed through to go-scp unchanged. A fresh
// SCP client is created for each call and closed before returning, as is the
// local file handle.
func (s *SSHServer) UploadFile(localPath, remotePath, permissions string) error {
	src, openErr := os.Open(localPath)
	if openErr != nil {
		return fmt.Errorf("unable to read local file %s: %w", localPath, openErr)
	}
	defer src.Close()

	client, clientErr := scp.NewClientBySSH(s.sshc)
	if clientErr != nil {
		return fmt.Errorf("unable to create scp client for host %s: %w", s.host, clientErr)
	}
	defer client.Close()

	// CopyFromFile takes the os.File by value, hence the dereference.
	ctx := context.Background()
	return client.CopyFromFile(ctx, *src, remotePath, permissions)
}
// UploadFileContent writes the in-memory content to remotePath on this server
// over SCP, passing permissions through to go-scp. It is the byte-slice
// counterpart of UploadFile; the SCP client is created per call and closed
// before returning.
func (s *SSHServer) UploadFileContent(content []byte, remotePath, permissions string) error {
	client, err := scp.NewClientBySSH(s.sshc)
	if err != nil {
		return fmt.Errorf("unable to create scp client for host %s: %w", s.host, err)
	}
	defer client.Close()

	// The reader is fresh, so its unread length is the full payload size.
	size := int64(len(content))
	return client.Copy(context.Background(), bytes.NewReader(content), remotePath, permissions, size)
}
// ExecCmd runs cmd in a new SSH session on this server and returns its
// captured standard output and standard error, along with the error from
// running it. Output captured before a failure is still returned.
func (s *SSHServer) ExecCmd(cmd string) (string, string, error) {
	sess, err := s.sshc.NewSession()
	if err != nil {
		return "", "", fmt.Errorf("unable to create ssh session: %w", err)
	}
	defer sess.Close()

	var outBuf, errBuf bytes.Buffer
	sess.Stdout, sess.Stderr = &outBuf, &errBuf

	runErr := sess.Run(cmd)
	return outBuf.String(), errBuf.String(), runErr
}
// NewSSHServer opens an SSH connection to host on the given port and returns
// it wrapped in an SSHServer.
//
// The login user and private key are read from the environment:
//   - TF_VAR_private_key: path to the private key file
//   - TF_VAR_ssh_user:    remote user to authenticate as
//
// Public-key authentication is used with the parsed key. An error is returned
// if either variable is unset, the key file cannot be read or parsed, or the
// TCP/SSH dial fails.
//
// SECURITY: host keys are not verified (ssh.InsecureIgnoreHostKey), which
// leaves the connection open to man-in-the-middle attacks. Consider a real
// HostKeyCallback, e.g. one built with golang.org/x/crypto/ssh/knownhosts.
func NewSSHServer(host string, port int) (*SSHServer, error) {
	// TODO: not this
	privKeyPath, ok := os.LookupEnv("TF_VAR_private_key")
	if !ok {
		return nil, errors.New("missing required env var: 'TF_VAR_private_key'")
	}
	// TODO: ...or this
	sshUser, ok := os.LookupEnv("TF_VAR_ssh_user")
	if !ok {
		return nil, errors.New("missing required env var: 'TF_VAR_ssh_user'")
	}

	privKeyContent, err := os.ReadFile(privKeyPath)
	if err != nil {
		// Keep the underlying cause (missing file, permission denied, ...)
		// in the error chain instead of discarding it.
		return nil, fmt.Errorf("unable to read ssh key file %s: %w", privKeyPath, err)
	}

	signer, err := ssh.ParsePrivateKey(privKeyContent)
	if err != nil {
		return nil, fmt.Errorf("unable to parse private ssh key %s: %w", privKeyPath, err)
	}

	sshConfig := &ssh.ClientConfig{
		User: sshUser,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // see SECURITY note above
	}

	// Build the dial address with net.JoinHostPort so IPv6 literals are
	// bracketed ("[::1]:22"); "%s:%d" would yield the undialable "::1:22".
	// An already-bracketed host is unwrapped first so it is not bracketed
	// twice, keeping such inputs working as before.
	dialHost := host
	if n := len(dialHost); n >= 2 && dialHost[0] == '[' && dialHost[n-1] == ']' {
		dialHost = dialHost[1 : n-1]
	}
	addr := net.JoinHostPort(dialHost, strconv.Itoa(port))

	sshc, err := ssh.Dial("tcp", addr, sshConfig)
	if err != nil {
		return nil, fmt.Errorf("unable to connect to host %s: %w", host, err)
	}

	return &SSHServer{
		host: host,
		sshc: sshc,
	}, nil
}
// ============================================================================
// TODO: server_list.go ?

// SSHServerList is a list of connected servers that uploads and commands are
// fanned out to concurrently, one goroutine per server.
type SSHServerList []SSHServer

// SSHServerListUploadResult reports the outcome of an upload to one server.
type SSHServerListUploadResult struct {
	Host string // host the upload was sent to
	Err  error  // error from the upload; nil when it succeeded
}

// SSHSeverListExecCmdResult reports the outcome of running a command on one
// server: its captured stdout and stderr and the error from running it.
//
// The name is misspelled ("Sever") but kept for compatibility; new code can
// use the SSHServerListExecCmdResult alias instead.
type SSHSeverListExecCmdResult struct {
	Host   string
	Stdout string
	Stderr string
	Err    error
}

// SSHServerListExecCmdResult is the correctly spelled name for
// SSHSeverListExecCmdResult. It is a type alias, so the two are the same type
// and fully interchangeable.
type SSHServerListExecCmdResult = SSHSeverListExecCmdResult
// NewSSHServerList connects to every host in servers on the standard SSH port
// (22) and returns the connected servers in the same order as the input.
//
// If any connection fails, the connections already opened by this call are
// closed and the error for the failing host is returned. A failed call
// therefore leaves no connections behind.
//
// TODO: support non-standard ssh port?
func NewSSHServerList(servers []string) (SSHServerList, error) {
	sshServers := make(SSHServerList, 0, len(servers))
	for _, host := range servers {
		serv, err := NewSSHServer(host, 22)
		if err != nil {
			// Best-effort cleanup: the connect error is what the caller needs,
			// so errors from closing the earlier connections are ignored.
			for _, opened := range sshServers {
				_ = opened.sshc.Close()
			}
			return nil, err
		}
		sshServers = append(sshServers, *serv)
	}
	return sshServers, nil
}
// UploadFileContent writes content to remotePath, with the given permissions,
// on every server in the list concurrently. It blocks until every upload has
// finished and returns one result per server, in list order.
//
// Note: one goroutine (and so one SCP transfer) is started per server with no
// concurrency limit, which could become a thundering-herd problem for large
// lists.
func (s *SSHServerList) UploadFileContent(content []byte, remotePath, permissions string) []SSHServerListUploadResult {
	servers := *s
	out := make([]SSHServerListUploadResult, len(servers))

	var wg sync.WaitGroup
	for idx := range servers {
		wg.Add(1)
		// Each goroutine writes only its own index of out, so no lock is needed.
		go func(idx int, srv SSHServer) {
			defer wg.Done()
			out[idx] = SSHServerListUploadResult{
				Host: srv.host,
				Err:  srv.UploadFileContent(content, remotePath, permissions),
			}
		}(idx, servers[idx])
	}
	wg.Wait()

	return out
}
// UploadFile reads localPath once and uploads its contents to every server in
// the list through UploadFileContent. The returned error is non-nil only when
// the local file cannot be read. Per-server failures are reported in the
// result slice instead.
func (s *SSHServerList) UploadFile(localPath, remotePath, permissions string) ([]SSHServerListUploadResult, error) {
	data, readErr := os.ReadFile(localPath)
	if readErr != nil {
		return nil, fmt.Errorf("unable to read local file %s: %w", localPath, readErr)
	}

	results := s.UploadFileContent(data, remotePath, permissions)
	return results, nil
}
// ExecCmd runs cmd on every server in the list concurrently and blocks until
// all of them have finished. It returns one result per server, in list order,
// holding that server's captured stdout, stderr and run error.
func (s *SSHServerList) ExecCmd(cmd string) []SSHSeverListExecCmdResult {
	servers := *s
	out := make([]SSHSeverListExecCmdResult, len(servers))

	var wg sync.WaitGroup
	wg.Add(len(servers))
	for idx, srv := range servers {
		// Each goroutine fills only its own slot of out, so no lock is needed.
		go func(slot *SSHSeverListExecCmdResult, srv SSHServer) {
			defer wg.Done()
			stdout, stderr, err := srv.ExecCmd(cmd)
			*slot = SSHSeverListExecCmdResult{
				Host:   srv.host,
				Stdout: stdout,
				Stderr: stderr,
				Err:    err,
			}
		}(&out[idx], srv)
	}
	wg.Wait()

	return out
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment