Skip to content

Instantly share code, notes, and snippets.

@Timmmm
Last active Aug 16, 2016
Embed
What would you like to do?
// An SSH and SFTP server. It doesn't support 'exec' commands yet (WIP).
package main
import (
"crypto/rand"
"crypto/rsa"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"sync"
"syscall"
"unsafe"
"github.com/kr/pty"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
func main() {
	config := &ssh.ServerConfig{
		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
			// Obviously not production code. Needs constant-time comparison and stronger password and hashing.
			if c.User() == "root" && string(pass) == "sueme" {
				return nil, nil
			}
			return nil, fmt.Errorf("password rejected for %q", c.User())
		},
		// You may also explicitly allow anonymous client authentication.
		// NoClientAuth: true,
	}

	// The host key is loaded from ./id_rsa if present (you can generate one
	// with 'ssh-keygen -t rsa'). Otherwise a random key is generated for this
	// run. NOTE: the generated key is NOT saved (the saving code below is
	// disabled), so clients will see a different host key after each restart.
	var private ssh.Signer
	privateBytes, err := ioutil.ReadFile("id_rsa")
	if err == nil {
		private, err = ssh.ParsePrivateKey(privateBytes)
		if err != nil {
			log.Fatalf("Failed to parse private key (%s)", err)
		}
	} else {
		log.Print("Failed to load private key (./id_rsa).")
	}
	if private == nil {
		log.Print("Generating random key")
		privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			log.Fatalf("Failed to generate random key (%s)", err)
		}
		private, err = ssh.NewSignerFromKey(privateKey)
		if err != nil {
			log.Fatalf("Failed to convert random key (%s)", err)
		}
		// Try to save it (this is optional).
		// if _, err := os.Stat("id_rsa"); os.IsNotExist(err) {
		// // Save it to the file.
		// log.Print("Saving new key to ./id_rsa")
		// }
	}
	config.AddHostKey(private)

	// Once a ServerConfig has been configured, connections can be accepted.
	listener, err := net.Listen("tcp", "0.0.0.0:22")
	if err != nil {
		log.Fatalf("Failed to listen on 22 (%s)", err)
	}
	log.Print("Listening on 22...")
	for {
		tcpConn, err := listener.Accept()
		if err != nil {
			log.Printf("Failed to accept incoming connection (%s)", err)
			continue
		}
		// Handshake in a separate goroutine. Doing it inline let a single
		// slow or silent client block the accept loop for everyone else.
		go handleConn(tcpConn, config)
	}
}

// handleConn performs the SSH handshake on a freshly accepted TCP connection
// and then services it until the client stops opening channels. Global
// (connection-level) requests are discarded.
func handleConn(tcpConn net.Conn, config *ssh.ServerConfig) {
	// Before use, a handshake must be performed on the incoming net.Conn.
	sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, config)
	if err != nil {
		log.Printf("Failed to handshake (%s)", err)
		return
	}
	log.Printf("New SSH connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())
	// Discard all global out-of-band Requests
	go ssh.DiscardRequests(reqs)
	// Accept all channels
	handleChannels(chans)
}
func handleChannels(chans <-chan ssh.NewChannel) {
	// Each newly opened channel is handed off to its own goroutine; this loop
	// ends once the connection stops delivering new channels.
	for {
		nc, ok := <-chans
		if !ok {
			return
		}
		go handleChannel(nc)
	}
}
// handleChannel serves one channel opened by a client. Only "session"
// channels are accepted. Within a session it handles these requests:
//   - "shell": runs /bin/sh on a new pty, piped to and from the channel
//   - "pty-req", "window-change": record and apply the terminal size
//   - "env": logged and refused
//   - "exec": not implemented yet, refused
//   - "subsystem": only "sftp" is supported
//
// Any other request is refused. At most one shell/exec/subsystem may be
// started per channel. The function returns once the request stream ends.
func handleChannel(newChannel ssh.NewChannel) {
	// RFC 4254 also defines "x11", "direct-tcpip" and "forwarded-tcpip"
	// channel types; none of them are supported here.
	if t := newChannel.ChannelType(); t != "session" {
		newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("Unknown channel type: %s", t))
		return
	}
	channel, requests, err := newChannel.Accept()
	if err != nil {
		log.Printf("Could not accept channel (%s)", err)
		return
	}

	// Only one shell, exec or subsystem is allowed per channel.
	commandStarted := false
	// The shell process and the master end of its pty (nil until "shell").
	var shell *exec.Cmd
	var shellFile *os.File
	// Terminal size. It is remembered so that a "pty-req" sent before
	// "shell" applies to the shell once it starts. Defaults to 80x24.
	shellWidth, shellHeight := uint32(80), uint32(24)

	// teardown closes the channel and, if a shell was started, closes its pty
	// and reaps the process. It is named teardown rather than close so it
	// does not shadow the builtin.
	teardown := func() {
		channel.Close()
		if shell == nil || shell.Process == nil {
			return
		}
		// Closing the pty master should hang up the shell's terminal so it
		// exits, rather than leaving Wait blocked forever after the client
		// has gone away.
		if shellFile != nil {
			shellFile.Close()
		}
		if err := shell.Wait(); err != nil {
			log.Printf("Failed to exit bash (%s)", err)
		}
		log.Printf("Session closed")
	}

	// resize applies the remembered size to the pty, if one exists yet.
	resize := func() {
		if shellFile != nil {
			SetWinsize(shellFile.Fd(), shellWidth, shellHeight)
		}
	}

	// Sessions have out-of-band requests such as "shell", "pty-req" and "env".
	// The loop keeps draining requests even after a failure; the ssh package
	// closes the channel once the session is torn down.
	for req := range requests {
		log.Print("Request:", req.Type)
		switch req.Type {
		case "shell":
			if commandStarted {
				req.Reply(false, nil)
				break
			}
			commandStarted = true
			// Allocate a terminal for this channel
			log.Print("Creating pty...")
			shell = exec.Command("/bin/sh")
			f, err := pty.Start(shell)
			if err != nil {
				log.Printf("Could not start pty (%s)", err)
				req.Reply(false, nil)
				teardown()
				break
			}
			shellFile = f
			resize()
			// Pipe the session to the shell and vice versa. Whichever copy
			// finishes first tears the session down exactly once.
			var once sync.Once
			go func() {
				io.Copy(channel, shellFile)
				once.Do(teardown)
			}()
			go func() {
				io.Copy(shellFile, channel)
				once.Do(teardown)
			}()
			req.Reply(true, nil)
		case "pty-req":
			// Assign to the outer variables (the old `:=` shadowed them, so the
			// requested size was lost). Zero dimensions, including those from a
			// malformed payload, are ignored.
			if _, w, h, _, _, _ := parsePtyReq(req.Payload); w != 0 && h != 0 {
				shellWidth, shellHeight = w, h
			}
			resize()
			req.Reply(true, nil)
		case "window-change":
			if w, h, _, _ := parseWindowChange(req.Payload); w != 0 && h != 0 {
				shellWidth, shellHeight = w, h
			}
			resize()
			req.Reply(true, nil)
		case "env":
			envName, envVal := parseEnv(req.Payload)
			log.Printf("Env %s: %s", envName, envVal)
			req.Reply(false, nil)
		case "exec":
			// TODO
			req.Reply(false, nil)
			// Only one shell or exec is allowed per channel.
			// if commandStarted {
			// req.Reply(false, nil)
			// break
			// }
			// commandStarted = true
			// // The first 4 bytes are the length which we don't need.
			// // Need at least 1 byte.
			// if len(req.Payload) < 5 {
			// req.Reply(false, nil)
			// break
			// }
			// commandLine := string(req.Payload[4:])
			// // Allocate a terminal for this channel
			// log.Print("Exec command:", commandLine)
			// cmdParts := strings.Split(commandLine, " ")
			// if len(cmdParts) == 0 {
			// req.Reply(false, nil)
			// break
			// }
			// shell = exec.Command(cmdParts[0], cmdParts[1:]...)
			// shell.Start()
			// shellFile, err := pty.Start(shell)
			// if err != nil {
			// log.Printf("Could not start pty (%s)", err)
			// teardown()
			// return
			// }
			// //pipe session to bash and visa-versa
			// var once sync.Once
			// go func() {
			// io.Copy(channel, shellFile)
			// once.Do(teardown)
			// }()
			// go func() {
			// io.Copy(shellFile, channel)
			// once.Do(teardown)
			// }()
			// req.Reply(true, nil)
		case "subsystem":
			if commandStarted {
				req.Reply(false, nil)
				break
			}
			commandStarted = true
			// The payload is an SSH string: a uint32 length followed by the
			// name. It is validated rather than sliced blindly, because a
			// payload shorter than 4 bytes used to panic and crash the server.
			// A malformed payload leaves subsys empty and is refused below.
			var subsys string
			if p := req.Payload; len(p) >= 4 {
				if n := binary.BigEndian.Uint32(p); uint64(n) <= uint64(len(p)-4) {
					subsys = string(p[4 : 4+uint64(n)])
				}
			}
			if subsys != "sftp" {
				log.Print("Unsupported subsystem: ", subsys)
				req.Reply(false, nil)
				break
			}
			serverOptions := []sftp.ServerOption{
				sftp.WithDebug(os.Stderr),
			}
			// if readOnly {
			// serverOptions = append(serverOptions, sftp.ReadOnly())
			// }
			server, err := sftp.NewServer(
				channel,
				serverOptions...,
			)
			if err != nil {
				// Stop here: falling through would call Serve on a nil server.
				log.Print("Error starting sftp server:", err)
				req.Reply(false, nil)
				teardown()
				break
			}
			go func() {
				if err := server.Serve(); err != nil {
					log.Print("sftp server finished:", err)
				}
				teardown()
			}()
			req.Reply(true, nil)
		default:
			// Unrecognised requests must still be answered when the client
			// asked for a reply; Reply is a no-op when WantReply is false.
			req.Reply(false, nil)
		}
	}
}
// The parsers below decode the type-specific payloads of SSH channel
// requests (RFC 4254 §6). They share wireReader, which checks every length
// prefix against the bytes actually present. A hostile or truncated payload
// therefore yields zero values instead of an out-of-range panic that would
// crash the server.

// wireReader consumes big-endian SSH wire-format fields (RFC 4251 §5) from
// the front of buf.
type wireReader struct {
	buf []byte
}

// readUint32 consumes one big-endian uint32. It reports false if fewer than
// 4 bytes remain.
func (r *wireReader) readUint32() (uint32, bool) {
	if len(r.buf) < 4 {
		return 0, false
	}
	v := binary.BigEndian.Uint32(r.buf)
	r.buf = r.buf[4:]
	return v, true
}

// readString consumes a uint32 length prefix followed by that many bytes.
// It reports false if the prefix or the data is truncated.
func (r *wireReader) readString() (string, bool) {
	n, ok := r.readUint32()
	// Compare as uint64 so a length near 2^32 cannot overflow int on 32-bit
	// platforms.
	if !ok || uint64(n) > uint64(len(r.buf)) {
		return "", false
	}
	s := string(r.buf[:n])
	r.buf = r.buf[n:]
	return s, true
}

// readDims consumes four uint32s: columns, rows, pixel width, pixel height.
// It consumes nothing and reports false unless all 16 bytes are present.
func (r *wireReader) readDims() (d [4]uint32, ok bool) {
	if len(r.buf) < 16 {
		return d, false
	}
	for i := range d {
		d[i] = binary.BigEndian.Uint32(r.buf[4*i:])
	}
	r.buf = r.buf[16:]
	return d, true
}

// parsePtyReq decodes a "pty-req" payload. The payload holds the TERM value,
// then cols, rows, pixelWidth and pixelHeight, then the encoded terminal
// modes. Fields are decoded in that order. If a field is missing or truncated
// it is left zero, along with every field after it.
func parsePtyReq(b []byte) (term string, cols, rows, pixelWidth, pixelHeight uint32, modes string) {
	r := wireReader{buf: b}
	var ok bool
	if term, ok = r.readString(); !ok {
		return "", 0, 0, 0, 0, ""
	}
	d, ok := r.readDims()
	if !ok {
		return term, 0, 0, 0, 0, ""
	}
	modes, _ = r.readString()
	return term, d[0], d[1], d[2], d[3], modes
}

// parseWindowChange decodes a "window-change" payload. The payload is exactly
// four uint32s starting at offset 0: cols, rows, pixel width, pixel height.
// All results are zero if fewer than 16 bytes are present. Previously the
// fields were read from offset 1 and 17 bytes were required, so a
// well-formed payload always produced 0x0.
func parseWindowChange(b []byte) (cols, rows, pixelWidth, pixelHeight uint32) {
	r := wireReader{buf: b}
	d, ok := r.readDims()
	if !ok {
		return 0, 0, 0, 0
	}
	return d[0], d[1], d[2], d[3]
}

// parseEnv decodes an "env" payload, which is two SSH strings: the variable
// name and its value. If a field is missing or truncated it is returned as
// "", along with any field after it.
func parseEnv(b []byte) (name, value string) {
	r := wireReader{buf: b}
	name, ok := r.readString()
	if !ok {
		return "", ""
	}
	value, _ = r.readString()
	return name, value
}
// ======================
// Winsize stores the Height and Width of a terminal.
//
// Winsize and SetWinsize are borrowed from
// https://github.com/creack/termios/blob/master/win/win.go
// The field layout is assumed to mirror the C struct winsize passed to the
// TIOCSWINSZ ioctl (rows, columns, then two pixel fields) — confirm per
// platform.
type Winsize struct {
	Height uint16
	Width  uint16
	x      uint16 // unused
	y      uint16 // unused
}

// SetWinsize sets the size of the pty whose file descriptor is fd to w
// columns by h rows. Dimensions too large for a uint16 are clamped to 65535
// instead of silently wrapping. If the ioctl fails (for example on a closed
// descriptor), the error is logged rather than ignored.
func SetWinsize(fd uintptr, w, h uint32) {
	ws := &Winsize{Width: clampUint16(w), Height: clampUint16(h)}
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, fd, uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(ws)))
	if errno != 0 {
		log.Printf("Failed to set window size (%s)", errno)
	}
}

// clampUint16 converts v to uint16, saturating at 65535.
func clampUint16(v uint32) uint16 {
	if v > 0xFFFF {
		return 0xFFFF
	}
	return uint16(v)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment