Skip to content

Instantly share code, notes, and snippets.

@Ellrion
Created August 8, 2018 12:17
Show Gist options
  • Save Ellrion/fe3ef119ca35ae335d520a06d4b0247c to your computer and use it in GitHub Desktop.
Save Ellrion/fe3ef119ca35ae335d520a06d4b0247c to your computer and use it in GitHub Desktop.
package tunnel
import (
	"errors"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/url"
	"os"
	"path/filepath"
	"strings"

	"golang.org/x/crypto/ssh"
)
const (
	// defaultSshPort is the SSH server port used when the server DSN passed
	// to New does not specify one.
	defaultSshPort = "22"
)
// Tunnel forwards TCP connections accepted on a local listener to an address
// that is dialed through an SSH server. Tunnels are created by New.
type Tunnel struct {
	// localAddr is where the tunnel listens, remoteAddr is dialed through
	// the SSH connection, and serverAddr is the SSH server's host:port.
	localAddr, remoteAddr, serverAddr string

	// SSH credentials. sshPass is the login password, or the private key's
	// passphrase when authByKey is set and sshPass is non-empty.
	sshUser, sshPass string
	authByKey        bool

	conn   *ssh.Client  // SSH client connected to serverAddr
	remote net.Conn     // connection to remoteAddr opened through conn
	local  net.Listener // listener bound to localAddr
}
// New parses serverDsn, connects to the SSH server it names, and starts
// forwarding connections accepted on 127.0.0.1:toPort to 127.0.0.1:port, where
// the latter is dialed through the SSH connection (i.e. from the server side).
//
// serverDsn is a URL such as "ssh://user:secret@example.com:2222". When the
// URL has no port, defaultSshPort is used. Password authentication is used
// when the URL carries a password; otherwise, or whenever the "by_key" query
// parameter has any non-empty value (even "false"), key authentication is
// used and a URL password, if present, serves as the key's passphrase.
//
// On failure New returns a nil *Tunnel together with the error.
func New(port, toPort, serverDsn string) (*Tunnel, error) {
	dsn, err := url.Parse(serverDsn)
	if err != nil {
		return nil, err
	}

	host := dsn.Hostname()
	if host == "" {
		// Without this check net.JoinHostPort yields ":22", which the dialer
		// treats as the local machine instead of reporting a malformed DSN.
		// The DSN itself is not included because it may contain a password.
		return nil, errors.New("tunnel: server DSN has no host")
	}

	sshPort := dsn.Port()
	if sshPort == "" {
		sshPort = defaultSshPort
	}

	pass, hasPass := dsn.User.Password()
	byKey := !hasPass
	if dsn.Query().Get("by_key") != "" {
		byKey = true
	}

	t := &Tunnel{
		remoteAddr: "127.0.0.1:" + port,
		localAddr:  "127.0.0.1:" + toPort,
		serverAddr: net.JoinHostPort(host, sshPort),
		sshUser:    dsn.User.Username(),
		sshPass:    pass,
		authByKey:  byKey,
	}
	// open releases whatever it acquired before failing, so the half-built
	// Tunnel is simply discarded here rather than handed to the caller.
	if err = t.open(); err != nil {
		return nil, err
	}
	return t, nil
}
// Close shuts the tunnel down. It closes the local listener, then the
// connection dialed through SSH, then the SSH client, which is the reverse of
// the order in which open acquires them. Fields that were never set are
// skipped, so Close is safe on a partially opened Tunnel. Every close error is
// reported, with the messages joined by "; ".
func (t *Tunnel) Close() error {
	// Each field is nil-checked with its own type before being converted to
	// io.Closer: a nil *ssh.Client wrapped in an io.Closer is a non-nil
	// interface value, so a single "c == nil" test over []io.Closer misses it
	// and calling Close on it panics.
	var closers []io.Closer
	if t.local != nil {
		closers = append(closers, t.local)
	}
	if t.remote != nil {
		closers = append(closers, t.remote)
	}
	if t.conn != nil {
		closers = append(closers, t.conn)
	}

	var msgs []string
	for _, c := range closers {
		if err := c.Close(); err != nil {
			msgs = append(msgs, err.Error())
		}
	}
	if len(msgs) == 0 {
		return nil
	}
	return errors.New(strings.Join(msgs, "; "))
}
func (t *Tunnel) open() error {
var err error
t.conn, err = ssh.Dial("tcp", t.serverAddr, &ssh.ClientConfig{
User: t.sshUser,
Auth: []ssh.AuthMethod{
t.authBy(),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
return err
}
// local connection
t.remote, err = t.conn.Dial("tcp", t.remoteAddr)
if err != nil {
t.Close()
return err
}
t.local, err = net.Listen("tcp", t.localAddr)
if err != nil {
t.Close()
return err
}
go t.transfer()
return nil
}
// authBy selects the SSH authentication method for the tunnel.
//
// When key authentication is off, t.sshPass is sent as the login password.
// When it is on, the private key $HOME/.ssh/id_rsa is read inside the callback
// given to ssh.PublicKeysCallback. A non-empty t.sshPass is then used as the
// key's passphrase.
func (t *Tunnel) authBy() ssh.AuthMethod {
	if !t.authByKey {
		return ssh.Password(t.sshPass)
	}
	return ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
		home := os.Getenv("HOME")
		if home == "" {
			// Fail clearly instead of silently looking in /.ssh.
			return nil, errors.New("tunnel: HOME is not set; cannot locate .ssh/id_rsa")
		}
		buffer, err := ioutil.ReadFile(filepath.Join(home, ".ssh", "id_rsa"))
		if err != nil {
			return nil, err
		}

		var key ssh.Signer
		if t.sshPass != "" {
			key, err = ssh.ParsePrivateKeyWithPassphrase(buffer, []byte(t.sshPass))
		} else {
			key, err = ssh.ParsePrivateKey(buffer)
		}
		if err != nil {
			// Do not return a slice holding a nil signer alongside the error.
			return nil, err
		}
		return []ssh.Signer{key}, nil
	})
}
// transfer accepts connections on t.local until Accept fails and hands each
// connection to serveConn in its own goroutine.
//
// Accept fails, for example, once Close has shut the listener. transfer then
// logs and returns, rather than calling log.Fatal, so stopping a tunnel never
// terminates the host process.
func (t *Tunnel) transfer() {
	for {
		local, err := t.local.Accept()
		if err != nil {
			log.Printf("tunnel: stopped accepting on %s: %v", t.localAddr, err)
			return
		}
		go t.serveConn(local)
	}
}

// serveConn relays bytes in both directions between one accepted local
// connection and a new connection to t.remoteAddr dialed through t.conn.
//
// Each client gets its own SSH-side connection, so concurrent or successive
// clients never share a byte stream. (t.remote, opened by open, is not used
// for traffic here.) Both connections are closed as soon as either direction
// finishes, which also unblocks the copy running in the other direction.
func (t *Tunnel) serveConn(local net.Conn) {
	defer local.Close()

	remote, err := t.conn.Dial("tcp", t.remoteAddr)
	if err != nil {
		log.Printf("tunnel: dial %s via %s: %v", t.remoteAddr, t.serverAddr, err)
		return
	}
	defer remote.Close()

	// The buffer has room for both senders, so the copy that finishes second
	// never blocks after serveConn has returned.
	done := make(chan struct{}, 2)
	relay := func(dst, src net.Conn) {
		// Copy errors are routine during teardown (peer reset, or the other
		// direction closing the sockets) and need no action, so they are
		// deliberately dropped.
		_, _ = io.Copy(dst, src)
		done <- struct{}{}
	}
	go relay(remote, local)
	go relay(local, remote)
	<-done
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment