Last active
March 3, 2023 18:26
-
-
Save rwhelan/795dde0d22c069c0f31fdf4cb2ab74da to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"context" | |
"errors" | |
"fmt" | |
"os" | |
"sync" | |
"github.com/bramvdbogaerde/go-scp" | |
"golang.org/x/crypto/ssh" | |
) | |
type SSHServer struct { | |
host string | |
sshc *ssh.Client | |
} | |
func (s *SSHServer) UploadFile(localPath, remotePath, permissions string) error { | |
localfd, err := os.Open(localPath) | |
if err != nil { | |
return fmt.Errorf("unable to read local file %s: %w", localPath, err) | |
} | |
defer localfd.Close() | |
scpc, err := scp.NewClientBySSH(s.sshc) | |
if err != nil { | |
return fmt.Errorf("unable to create scp client for host %s: %w", s.host, err) | |
} | |
defer scpc.Close() | |
return scpc.CopyFromFile(context.Background(), *localfd, remotePath, permissions) | |
} | |
func (s *SSHServer) UploadFileContent(content []byte, remotePath, permissions string) error { | |
buf := bytes.NewReader(content) | |
scpc, err := scp.NewClientBySSH(s.sshc) | |
if err != nil { | |
return fmt.Errorf("unable to create scp client for host %s: %w", s.host, err) | |
} | |
defer scpc.Close() | |
return scpc.Copy(context.Background(), buf, remotePath, permissions, int64(buf.Len())) | |
} | |
func (s *SSHServer) ExecCmd(cmd string) (string, string, error) { | |
session, err := s.sshc.NewSession() | |
if err != nil { | |
return "", "", fmt.Errorf("unable to create ssh session: %w", err) | |
} | |
defer session.Close() | |
var stdout bytes.Buffer | |
var stderr bytes.Buffer | |
session.Stdout = &stdout | |
session.Stderr = &stderr | |
err = session.Run(cmd) | |
return stdout.String(), stderr.String(), err | |
} | |
func NewSSHServer(host string, port int) (*SSHServer, error) { | |
// TODO: not this | |
privKeyPath, ok := os.LookupEnv("TF_VAR_private_key") | |
if !ok { | |
return nil, errors.New("missing required env var: 'TF_VAR_private_key'") | |
} | |
// TODO: ...or this | |
sshUser, ok := os.LookupEnv("TF_VAR_ssh_user") | |
if !ok { | |
return nil, errors.New("missing required env var: 'TF_VAR_ssh_user'") | |
} | |
privKeyContent, err := os.ReadFile(privKeyPath) | |
if err != nil { | |
return nil, errors.New(fmt.Sprintf("unable to read ssh key file %s", privKeyPath)) | |
} | |
signer, err := ssh.ParsePrivateKey(privKeyContent) | |
if err != nil { | |
return nil, fmt.Errorf("unable to parse private ssh key %s: %w", privKeyPath, err) | |
} | |
sshConfig := &ssh.ClientConfig{ | |
User: sshUser, | |
Auth: []ssh.AuthMethod{ | |
ssh.PublicKeys(signer), | |
}, | |
HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
} | |
sshc, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", host, port), sshConfig) | |
if err != nil { | |
return nil, fmt.Errorf("unable to connect to host %s: %w", host, err) | |
} | |
return &SSHServer{ | |
host: host, | |
sshc: sshc, | |
}, nil | |
} | |
// ============================================================================ | |
// TODO: server_list.go ? | |
type SSHServerList []SSHServer | |
type SSHServerListUploadResult struct { | |
Host string | |
Err error | |
} | |
type SSHSeverListExecCmdResult struct { | |
Host string | |
Stdout string | |
Stderr string | |
Err error | |
} | |
// TODO: support non-standard ssh port? | |
func NewSSHServerList(servers []string) (SSHServerList, error) { | |
sshServers := make([]SSHServer, 0, len(servers)) | |
for _, s := range servers { | |
serv, err := NewSSHServer(s, 22) | |
if err != nil { | |
return nil, err | |
} | |
sshServers = append(sshServers, *serv) | |
} | |
return sshServers, nil | |
} | |
func (s *SSHServerList) UploadFileContent(content []byte, remotePath, permissions string) []SSHServerListUploadResult { | |
results := make([]SSHServerListUploadResult, len(*s)) | |
wg := sync.WaitGroup{} | |
wg.Add(len(*s)) | |
for i, server := range *s { | |
i := i // local scoping | |
// possible thundering hurd issue; fix later | |
go func(server SSHServer) { | |
results[i] = SSHServerListUploadResult{ | |
Host: server.host, | |
Err: server.UploadFileContent(content, remotePath, permissions), | |
} | |
wg.Done() | |
}(server) | |
} | |
wg.Wait() | |
return results | |
} | |
func (s *SSHServerList) UploadFile(localPath, remotePath, permissions string) ([]SSHServerListUploadResult, error) { | |
fileContent, err := os.ReadFile(localPath) | |
if err != nil { | |
return nil, fmt.Errorf("unable to read local file %s: %w", localPath, err) | |
} | |
return s.UploadFileContent(fileContent, remotePath, permissions), nil | |
} | |
func (s *SSHServerList) ExecCmd(cmd string) []SSHSeverListExecCmdResult { | |
results := make([]SSHSeverListExecCmdResult, len(*s)) | |
wg := sync.WaitGroup{} | |
wg.Add(len(*s)) | |
for i, server := range *s { | |
i := i // keepin' it local | |
go func(server SSHServer) { | |
stdout, stderr, err := server.ExecCmd(cmd) | |
results[i] = SSHSeverListExecCmdResult{ | |
Host: server.host, | |
Stdout: stdout, | |
Stderr: stderr, | |
Err: err, | |
} | |
wg.Done() | |
}(server) | |
} | |
wg.Wait() | |
return results | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment