Last active
July 26, 2022 09:44
-
-
Save transacid/78ff18ad25daadcb4b62c123b35dfc5e to your computer and use it in GitHub Desktop.
Concurrent distributed SSH. I really like pdsh, and on my journey of learning Go I thought I'd give it a try :)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"bytes"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"os/user"
	"path/filepath"
	"sync"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)
func main() { | |
var jFlag = flag.Bool("j", false, "json output") | |
var cFlag = flag.String("c", "echo", "command") | |
var lFlag = flag.String("l", "", "dsh list") | |
var pFlag = flag.String("p", "22", "port") | |
flag.Parse() | |
h := getHostList(*lFlag) | |
execute(h, *cFlag, jFlag, pFlag) | |
} | |
// getHostList reads the group file <home>/.dsh/group/<hosts> and returns the
// host names it lists, one per line.
//
// Whitespace around each entry (including the "\r" of CRLF line endings) is
// trimmed and blank lines are skipped, so an empty file or a stray empty line
// never produces an empty host name. The returned slice is empty, not nil,
// when the file lists no hosts.
//
// The process exits via log.Fatal when the home directory cannot be
// determined or the file cannot be read.
func getHostList(hosts string) []string {
	hd, err := os.UserHomeDir()
	if err != nil {
		log.Fatal(err)
	}
	f, err := os.ReadFile(filepath.Join(hd, ".dsh", "group", hosts))
	if err != nil {
		log.Fatal(err)
	}

	lines := bytes.Split(f, []byte("\n"))
	list := make([]string, 0, len(lines))
	for _, line := range lines {
		if h := bytes.TrimSpace(line); len(h) > 0 {
			list = append(list, string(h))
		}
	}
	return list
}
func getSshConfig() *ssh.ClientConfig { | |
socket := os.Getenv("SSH_AUTH_SOCK") | |
conn, err := net.Dial("unix", socket) | |
if err != nil { | |
log.Fatalf("Failed to open SSH_AUTH_SOCK: %v", err) | |
} | |
u, err := user.Current() | |
if err != nil { | |
log.Fatalf(err.Error()) | |
} | |
agentClient := agent.NewClient(conn) | |
config := &ssh.ClientConfig{ | |
User: u.Username, | |
Auth: []ssh.AuthMethod{ | |
ssh.PublicKeysCallback(agentClient.Signers), | |
}, | |
HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
} | |
return config | |
} | |
// result is the outcome of running the command on a single host. It is
// printed per host either as "Host: Res" text or as a JSON object.
//
// The field order matters: values are built with positional literals.
type result struct {
	// Host is the host the command was sent to.
	Host string `json:"Host"`
	// Res is the command's output, or a "Failed to ..." message when the
	// connection, session or command failed. It is serialized as "Output".
	Res string `json:"Output"`
}
func sshCommand(host, cmd, port string) result { | |
client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%s", host, port), getSshConfig()) | |
if err != nil { | |
return result{host, fmt.Sprintf("Failed to dial %s: %s", host, err.Error())} | |
} | |
defer client.Close() | |
session, err := client.NewSession() | |
if err != nil { | |
return result{host, fmt.Sprintf("Failed to create session for %s: %s", host, err.Error())} | |
} | |
defer session.Close() | |
var b bytes.Buffer | |
session.Stdout = &b | |
if err := session.Run(cmd); err != nil { | |
return result{host, fmt.Sprintf("Failed to run on %s: %s", host, err.Error())} | |
} | |
r := bytes.TrimRight(b.Bytes(), "\n") | |
return result{host, string(r)} | |
} | |
func execute(target []string, cmd string, jFlag *bool, pFlag *string) { | |
var wg sync.WaitGroup | |
ch := make(chan result) | |
defer close(ch) | |
for _, h := range target { | |
wg.Add(1) | |
go func(h, cmd string) { | |
defer wg.Done() | |
ch <- sshCommand(h, cmd, *pFlag) | |
}(h, cmd) | |
} | |
for i := 0; i < len(target); i++ { | |
res := <-ch | |
if *jFlag { | |
r, err := json.Marshal(res) | |
if err != nil { | |
log.Fatal(err.Error()) | |
} | |
fmt.Println(string(r)) | |
} else { | |
fmt.Printf("%s: %s\n", res.Host, res.Res) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment