Skip to content

Instantly share code, notes, and snippets.

@meetme2meat
Created November 5, 2020 11:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save meetme2meat/5652d7b803f5a57760fafd08e5d0d55d to your computer and use it in GitHub Desktop.
Save meetme2meat/5652d7b803f5a57760fafd08e5d0d55d to your computer and use it in GitHub Desktop.
SFTP-proxy
package server
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"time"

	log "github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
// Config describes both ends of the proxy: the SSH server that downstream
// clients connect to (Server) and the SSH server traffic is forwarded to
// (Remote).
type Config struct {
	// Server configures the listening side of the proxy.
	Server struct {
		// Bind is the TCP address handed to net.Listen.
		Bind string
		// PrivateKey is the path of the host key file presented to clients.
		PrivateKey string
		// Auth holds the single accepted user name (User) and, optionally,
		// its password (Password).
		Auth sshAuth
		// AuthorizedKeys is the path of an authorized_keys file listing the
		// public keys accepted for Auth.User.
		AuthorizedKeys string
	}
	// Remote configures the upstream server connections are proxied to.
	Remote struct {
		Host string
		Port int
		// Auth.User and Auth.PrivateKey (a key file path) are used to log in
		// to the remote server.
		Auth sshAuth
	}
}
// sshAuth groups the credentials for one side of the proxy. Which fields are
// consulted depends on the side. Type is not read by any code in this file;
// NOTE(review): confirm whether a config loader elsewhere uses it.
type sshAuth struct {
	Type, User, Password, PrivateKey string
}
// Server is a single proxy instance.
type Server struct {
	c *Config // configuration the server was built from; only read by the methods in this file
}
// Start builds a proxy from config and runs it in a background goroutine.
// It returns immediately; a failure to bind the listen address is reported
// (fatally) from that goroutine, not to the caller.
func Start(config *Config) {
	go newServer(config).run()
}
// newServer wraps config in a Server. The config is stored as-is, without
// copying or validation.
func newServer(config *Config) *Server {
	srv := new(Server)
	srv.c = config
	return srv
}
// run starts listening on the configured bind address.
func (s *Server) run() {
	addr := s.c.Server.Bind
	log.Debugf("Listening at %s", addr)
	s.listen(addr)
}
// listen opens a TCP listener on addr and serves it from a new goroutine.
// A bind failure is fatal: logrus' Fatalf logs and then exits the process.
func (s *Server) listen(addr string) {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("could not start the server %s", err)
	}
	go s.Accept(l)
}
// Accept serves connections from l until l is closed, handling each one in
// its own goroutine.
//
// Accept errors other than a closed listener (e.g. running out of file
// descriptors) are logged and retried. The retry uses an exponential back-off
// capped at one second, so a persistent error cannot turn this loop into a
// busy spin.
func (s *Server) Accept(l net.Listener) {
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration
	for {
		conn, err := l.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// Permanent: nothing more will ever be accepted.
				log.Debugf("listener %s closed, no longer accepting connections", l.Addr())
				return
			}
			// The server should not die on a transient failure, but must not spin either.
			if backoff == 0 {
				backoff = minBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			log.Errorf("failed to accept incoming connection %s (retrying in %s)", err, backoff)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go s.handleConnection(conn)
	}
}
// handleConnection runs the SSH server handshake on an accepted TCP
// connection. It then proxies every channel the client opens to the remote
// server.
//
// Authentication is delegated to the callbacks built by buildServerConfig.
// On success they dial the remote server and store the client in rClient.
// The TCP connection, the SSH server connection and the remote client are
// all released when the client's channel stream ends or any setup step fails.
func (s *Server) handleConnection(conn net.Conn) {
	// A panic while serving one client must not bring down the proxy.
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("Recovered the handleConnection", r)
		}
	}()
	// Always release the TCP connection. Closing it again after the SSH
	// layer has closed it only returns an (ignored) error.
	defer conn.Close()
	log.Debugf("Received connection from %s", conn.RemoteAddr())

	// rClient is filled in by the auth callbacks.
	var rClient *ssh.Client
	// This is deferred as a closure so that a client dialled during a handshake
	// that later fails is released too, and a nil client is never dereferenced.
	defer func() {
		if rClient != nil {
			rClient.Close()
		}
	}()
	config := s.buildServerConfig(&rClient)

	privateBytes, err := ioutil.ReadFile(s.c.Server.PrivateKey)
	if err != nil {
		log.Errorf("Failed to load private key %s: %v", s.c.Server.PrivateKey, err)
		return
	}
	private, err := ssh.ParsePrivateKey(privateBytes)
	if err != nil {
		log.Errorf("Failed to parse private key %s: %v", s.c.Server.PrivateKey, err)
		return
	}
	config.AddHostKey(private)

	// Bound how long an unauthenticated client can hold this goroutine
	// (same default as OpenSSH's LoginGraceTime). Best effort: a conn that
	// does not support deadlines simply gets no limit.
	const handshakeTimeout = 2 * time.Minute
	_ = conn.SetDeadline(time.Now().Add(handshakeTimeout))
	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
	if err != nil {
		log.Errorf("Could not establish connection with %s : %v", conn.RemoteAddr().String(), err)
		return
	}
	defer sconn.Close()
	// An authenticated session may legitimately stay idle for a long time.
	_ = conn.SetDeadline(time.Time{})

	// Global (connection-level) requests such as keepalives are not
	// forwarded. DiscardRequests answers any that want a reply with failure.
	go ssh.DiscardRequests(reqs)
	for inputChannel := range chans {
		log.Debug("Received a new channel again")
		s.handleChannel(inputChannel, rClient)
	}
	log.Debugf("Lost connection with %s", conn.RemoteAddr())
}
// handleChannel proxies a single channel opened by the downstream client.
//
// Only "session" channels are proxied; anything else is rejected. SFTP runs
// as a subsystem of a session. The matching channel on the remote server is
// opened first. If the remote refuses it, the refusal is passed back to the
// client as a channel rejection, instead of accepting a channel that cannot
// be served.
//
// Channel requests (SSH_MSG_CHANNEL_REQUEST) are relayed by bypass. Data
// (SSH_MSG_CHANNEL_DATA) is copied in both directions by copyStream. Data
// may only flow once the request that starts the remote program ("shell",
// "exec" or "subsystem") has been forwarded. Clients may pipeline data
// directly behind that request, and the remote must receive the two in the
// same order. Until the copies start, client data simply waits in the
// channel's buffer.
//
// Everything after setup runs in a goroutine, so handleChannel returns
// without waiting for the channel to finish.
//
// Reference: https://github.com/jpillora/go-and-ssh/blob/master/channels/client.go#L41
func (s *Server) handleChannel(newChannel ssh.NewChannel, rClient *ssh.Client) {
	if newChannel.ChannelType() != "session" {
		newChannel.Reject(ssh.UnknownChannelType, "unknown channel type: "+newChannel.ChannelType())
		return
	}
	if rClient == nil {
		// Not expected after a successful handshake, but never dereference a nil client.
		newChannel.Reject(ssh.ConnectionFailed, "no connection to remote server")
		return
	}

	outputChannel, outputReq, err := rClient.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
	if err != nil {
		log.Errorf("could not open remote %s channel: %v", newChannel.ChannelType(), err)
		var openErr *ssh.OpenChannelError
		if errors.As(err, &openErr) {
			// Relay the remote's own reason to the client.
			newChannel.Reject(openErr.Reason, openErr.Message)
		} else {
			newChannel.Reject(ssh.ConnectionFailed, "could not open remote channel")
		}
		return
	}

	inputChannel, inputReq, err := newChannel.Accept()
	if err != nil {
		log.Errorf("could not accept %s channel: %v", newChannel.ChannelType(), err)
		outputChannel.Close()
		return
	}

	go func() {
		opened := time.Now()
		// Forward the client's setup requests (pty-req, env, ...) in order,
		// up to and including the one that starts the remote program.
		for started := false; !started; {
			req, ok := <-inputReq
			if !ok {
				// The client closed the channel before starting anything.
				inputChannel.Close()
				outputChannel.Close()
				return
			}
			if err := s.forwardRequest(req, outputChannel); err != nil {
				log.Errorf("forward error: %v", err)
			}
			switch req.Type {
			case "shell", "exec", "subsystem":
				started = true
			}
		}
		go s.copyStream(outputChannel, inputChannel)
		go s.copyStream(inputChannel, outputChannel)
		// bypass relays all further requests in both directions and closes
		// both channels once either side's request stream ends.
		s.bypass(inputChannel, outputChannel, inputReq, outputReq)
		log.Debugf("%s channel closed after %s", newChannel.ChannelType(), time.Since(opened))
	}()
}
// bypass relays channel requests in both directions. Requests arriving on
// req1 are sent to chan2, and requests arriving on req2 are sent to chan1.
// It returns, closing chan1 and then chan2, as soon as either request stream
// is closed. Forwarding errors are printed and do not stop the relay.
func (s *Server) bypass(chan1, chan2 ssh.Channel, req1, req2 <-chan *ssh.Request) {
	defer func() {
		if r := recover(); r != nil {
			fmt.Println("Recovered in bypass", r)
		}
	}()
	defer chan2.Close()
	defer chan1.Close()

	relay := func(req *ssh.Request, target ssh.Channel) {
		if err := s.forwardRequest(req, target); err != nil {
			fmt.Println("forward Error: " + err.Error())
		}
	}

	for {
		var (
			req    *ssh.Request
			ok     bool
			target ssh.Channel
		)
		select {
		case req, ok = <-req1:
			target = chan2
		case req, ok = <-req2:
			target = chan1
		}
		if !ok {
			// One side's request stream is closed: tear the pair down.
			return
		}
		relay(req, target)
	}
}
// forwardRequest re-sends req on channel. If the originator asked for a
// reply, it relays the answer from the other side back to the originator.
//
// If the request cannot be sent, an originator that wants a reply is answered
// with failure. Otherwise it would block forever waiting for a reply that
// never comes. The send error is still returned. An error from the final
// Reply is also returned.
func (s *Server) forwardRequest(req *ssh.Request, channel ssh.Channel) error {
	reply, err := channel.SendRequest(req.Type, req.WantReply, req.Payload)
	if err != nil {
		if req.WantReply {
			// Best effort; the send error is the one worth reporting.
			_ = req.Reply(false, nil)
		}
		return err
	}
	if req.WantReply {
		return req.Reply(reply, nil)
	}
	return nil
}
// copyStream pumps channel data from reader into writer until reader hits EOF
// or an error. It then half-closes writer (CloseWrite) so the far side sees
// end-of-data. Copy errors are logged; the CloseWrite result is ignored.
func (s *Server) copyStream(writer, reader ssh.Channel) {
	if _, copyErr := io.Copy(writer, reader); copyErr != nil {
		log.Errorf("Copy stream error %s", copyErr)
	}
	writer.CloseWrite()
	log.Debug("Closing writer")
}
// buildServerConfig returns the SSH server configuration for one downstream
// connection. Both auth callbacks receive rClient. They store the remote
// client they dial there, so the caller can use it after the handshake.
// Host keys are not set here; the caller adds them.
func (s *Server) buildServerConfig(rClient **ssh.Client) *ssh.ServerConfig {
	cfg := new(ssh.ServerConfig)
	cfg.ServerVersion = "SSH-2.0-ProxyServer"
	// Built in the same order as before: password hook first, then public key hook.
	cfg.PasswordCallback = s.passwordHook(rClient)
	cfg.PublicKeyCallback = s.publicKeyHook(rClient)
	return cfg
}
// publicKeyHook returns the PublicKeyCallback for the proxy's SSH server.
//
// The keys in Server.AuthorizedKeys are loaded once, when the hook is built.
// A key is accepted only for Server.Auth.User. If the file is not configured,
// cannot be read, or contains an unparsable entry, the problem is logged and
// public-key auth proceeds with whatever keys were loaded (possibly none).
// Password auth keeps working. Previously this called log.Fatalf, which exits
// the whole proxy on the first incoming connection.
//
// On acceptance the callback dials the remote server and stores the client
// in *rclient. Any client stored by an earlier callback on this connection
// is closed first. If the dial fails the key is refused.
//
// NOTE(review): x/crypto/ssh may invoke PublicKeyCallback for a key the
// client has not yet proven it holds (the query phase). The remote dial can
// therefore happen before the signature is verified. Consider dialling after
// the handshake instead.
//
// Key loading is adapted from
// https://github.com/golang/crypto/blob/master/ssh/example_test.go
func (s *Server) publicKeyHook(rclient **ssh.Client) func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
	authorizedKeysMap := map[string]bool{}
	if path := s.c.Server.AuthorizedKeys; path != "" {
		authorizedKeysBytes, err := ioutil.ReadFile(path)
		if err != nil {
			log.Errorf("Failed to load authorized_keys, public key auth disabled, err: %v", err)
		}
		for len(authorizedKeysBytes) > 0 {
			pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
			if err != nil {
				// No way to resynchronise past a bad entry; keep what we have.
				log.Errorf("ParseAuthorized Key : %v", err)
				break
			}
			authorizedKeysMap[string(pubKey.Marshal())] = true
			authorizedKeysBytes = rest
		}
	}

	return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
		if c.User() != s.c.Server.Auth.User || !authorizedKeysMap[string(pubKey.Marshal())] {
			return nil, fmt.Errorf("unknown public key for %q", c.User())
		}
		client, err := s.remoteSSHClient()
		if err != nil {
			log.Errorf("Could not authorize %s on %s: %s", c.User(), c.RemoteAddr().String(), err)
			return nil, fmt.Errorf("could not connect %q to remote server: %w", c.User(), err)
		}
		// Don't leak a client dialled by an earlier callback on this connection.
		if *rclient != nil {
			(*rclient).Close()
		}
		*rclient = client
		return &ssh.Permissions{
			// Record the public key used for authentication.
			Extensions: map[string]string{
				"pubkey-fp": ssh.FingerprintSHA256(pubKey),
			},
		}, nil
	}
}
// passwordHook returns the PasswordCallback for the proxy's SSH server.
//
// A login is accepted only when the user equals Server.Auth.User and the
// password equals a non-empty Server.Auth.Password. The comparison is
// constant-time because the password arrives from the network. On success the
// callback dials the remote server and stores the client in *rclient. Any
// client stored by an earlier callback on this connection is closed first.
// If the dial fails, the login is refused.
func (s *Server) passwordHook(rclient **ssh.Client) func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
	return func(c ssh.ConnMetadata, passwd []byte) (*ssh.Permissions, error) {
		// Fail safe: without a remote host there is nothing to proxy to
		// (presumably also checked at bootstrap).
		if s.c.Remote.Host == "" {
			return nil, fmt.Errorf("unknown user %s", c.User())
		}
		if c.User() != s.c.Server.Auth.User {
			return nil, fmt.Errorf("unknown user %s", c.User())
		}
		want := s.c.Server.Auth.Password
		// ConstantTimeCompare avoids leaking, via response timing, how much
		// of a guessed password was correct.
		if want == "" || subtle.ConstantTimeCompare([]byte(want), passwd) != 1 {
			return nil, errors.New("passwords do not match")
		}
		client, err := s.remoteSSHClient()
		if err != nil {
			log.Errorf("Could not authorize %s on %s: %s", c.User(), c.RemoteAddr().String(), err)
			return nil, fmt.Errorf("Could not authorize %s on %s: %s",
				c.User(), c.RemoteAddr().String(), err)
		}
		// Don't leak a client dialled by an earlier callback on this connection.
		if *rclient != nil {
			(*rclient).Close()
		}
		*rclient = client
		log.Debugf("User %s authenticated", s.c.Server.Auth.User)
		return nil, nil
	}
}
// remoteSSHClient opens an authenticated SSH client connection to the
// configured remote server. It uses Remote.Auth.User and the private key file
// at Remote.Auth.PrivateKey.
//
// The dial runs inside a downstream client's auth callback, so both the TCP
// connect and the SSH handshake are bounded by dialTimeout. ssh.Dial's
// ClientConfig.Timeout would only cover the TCP connect. The address is built
// with net.JoinHostPort so IPv6 literal hosts work.
func (s *Server) remoteSSHClient() (*ssh.Client, error) {
	const dialTimeout = 30 * time.Second

	key, err := ioutil.ReadFile(s.c.Remote.Auth.PrivateKey)
	if err != nil {
		return nil, fmt.Errorf("unable to connect: not a valid key: %w", err)
	}
	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		return nil, fmt.Errorf("Parsing err: %w", err)
	}
	config := &ssh.ClientConfig{
		User: s.c.Remote.Auth.User,
		// SECURITY(review): the remote host key is not verified, so this hop is
		// open to man-in-the-middle attacks. Replace with ssh.FixedHostKey or a
		// knownhosts callback once the config carries the expected key.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
		Timeout: dialTimeout,
	}

	addr := net.JoinHostPort(s.c.Remote.Host, fmt.Sprint(s.c.Remote.Port))
	conn, err := net.DialTimeout("tcp", addr, dialTimeout)
	if err != nil {
		return nil, fmt.Errorf("remote server connect error %w", err)
	}
	// Bound the handshake as well. Best effort: the error only matters for
	// conns without deadline support.
	_ = conn.SetDeadline(time.Now().Add(dialTimeout))
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("remote server connect error %w", err)
	}
	// The established connection may idle indefinitely.
	_ = conn.SetDeadline(time.Time{})
	return ssh.NewClient(sshConn, chans, reqs), nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment