Skip to content

Instantly share code, notes, and snippets.

@maraino
Created November 4, 2020 01:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save maraino/ea3a8015f655eaf106e42e8f0c48c5da to your computer and use it in GitHub Desktop.
Save maraino/ea3a8015f655eaf106e42e8f0c48c5da to your computer and use it in GitHub Desktop.
package sshkms
import (
"bytes"
"context"
"crypto"
"io"
"net"
"os"
"golang.org/x/crypto/ssh"
"github.com/pkg/errors"
"github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/sshutil"
"golang.org/x/crypto/ssh/agent"
)
// algorithmAttributes holds the key type and curve name that are handed to
// generateKey for a given signature algorithm. Curve is left empty for key
// types that have no curve.
type algorithmAttributes struct {
	Type, Curve string
}
// DefaultRSAKeySize is the default size for RSA keys. generateKey falls back
// to this value when an RSA key is requested with a size of 0.
const DefaultRSAKeySize = 3072
// signatureAlgorithmMapping translates an apiv1 signature algorithm into the
// key type and curve used to generate a matching key. An unspecified
// algorithm maps to an EC P-256 key. Algorithms absent from this table are
// rejected by CreateKey.
var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]algorithmAttributes{
	apiv1.UnspecifiedSignAlgorithm: {Type: "EC", Curve: "P-256"},

	// RSA algorithms carry no curve; the key size comes from the request.
	apiv1.SHA256WithRSA:    {Type: "RSA", Curve: ""},
	apiv1.SHA384WithRSA:    {Type: "RSA", Curve: ""},
	apiv1.SHA512WithRSA:    {Type: "RSA", Curve: ""},
	apiv1.SHA256WithRSAPSS: {Type: "RSA", Curve: ""},
	apiv1.SHA384WithRSAPSS: {Type: "RSA", Curve: ""},
	apiv1.SHA512WithRSAPSS: {Type: "RSA", Curve: ""},

	apiv1.ECDSAWithSHA256: {Type: "EC", Curve: "P-256"},
	apiv1.ECDSAWithSHA384: {Type: "EC", Curve: "P-384"},
	apiv1.ECDSAWithSHA512: {Type: "EC", Curve: "P-521"},

	apiv1.PureEd25519: {Type: "OKP", Curve: "Ed25519"},
}
// generateKey creates a key pair through keys.GenerateKeyPair. It is a
// package-level variable so that it can be replaced for testing purposes.
//
// For RSA keys a size of 0 is replaced by DefaultRSAKeySize; every other
// combination of arguments is passed through untouched.
var generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
	bits := size
	if bits == 0 && kty == "RSA" {
		bits = DefaultRSAKeySize
	}
	return keys.GenerateKeyPair(kty, crv, bits)
}
// AgentKMS exposes keys held by an ssh-agent through the apiv1 request and
// response types. Keys are identified by their ssh-agent comment.
type AgentKMS struct {
// agent is the ssh-agent client used to list keys, add keys and obtain
// signers.
agent agent.ExtendedAgent
// conn is the connection the agent client talks over; Close closes it.
conn net.Conn
}
// New returns a new AgentKMS connected to the ssh-agent listening on the unix
// socket named by the SSH_AUTH_SOCK environment variable.
//
// ctx bounds the dial: if it is cancelled or expires before the connection is
// established the dial is aborted. A nil ctx is treated as
// context.Background(). opts is currently unused.
//
// An error is returned when SSH_AUTH_SOCK is not set or when the socket
// cannot be dialed.
func New(ctx context.Context, opts apiv1.Options) (*AgentKMS, error) {
	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		// Dialing an empty unix address fails with an opaque "missing
		// address" style error; report the real cause instead.
		return nil, errors.New("error connecting with ssh-agent: SSH_AUTH_SOCK is not set")
	}

	if ctx == nil {
		ctx = context.Background()
	}

	var dialer net.Dialer
	conn, err := dialer.DialContext(ctx, "unix", socket)
	if err != nil {
		return nil, errors.Wrap(err, "error connecting with ssh-agent")
	}

	return &AgentKMS{
		agent: agent.NewClient(conn),
		conn:  conn,
	}, nil
}
// GetPublicKey returns the public key of the first key listed by the
// ssh-agent whose comment equals req.Name.
//
// It returns an error if the agent keys cannot be listed, if no key carries
// that comment, or if sshutil.PublicKey cannot convert the matching key.
func (k *AgentKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) {
	agentKeys, err := k.agent.List()
	if err != nil {
		return nil, errors.Wrap(err, "error listing ssh-agent keys")
	}

	for i := range agentKeys {
		if candidate := agentKeys[i]; candidate.Comment == req.Name {
			return sshutil.PublicKey(candidate)
		}
	}

	return nil, errors.Errorf("key with comment '%s' was not found", req.Name)
}
// CreateKey generates a new key pair for req.SignatureAlgorithm and adds the
// private key to the ssh-agent, using req.Name as the key comment.
//
// The response carries the generated public key together with a
// CreateSignerRequest that refers to the key by the same name. An error is
// returned when the signature algorithm is not in signatureAlgorithmMapping,
// when key generation fails, or when the agent refuses the key.
func (k *AgentKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) {
	attrs, supported := signatureAlgorithmMapping[req.SignatureAlgorithm]
	if !supported {
		return nil, errors.Errorf("AgentKMS does not support signature algorithm '%s'", req.SignatureAlgorithm)
	}

	pub, priv, err := generateKey(attrs.Type, attrs.Curve, req.Bits)
	if err != nil {
		return nil, err
	}

	added := agent.AddedKey{
		PrivateKey: priv,
		Comment:    req.Name,
		// A zero lifetime asks the agent to keep the key without expiry.
		LifetimeSecs: 0,
	}
	if err := k.agent.Add(added); err != nil {
		return nil, errors.Wrap(err, "error adding key to ssh-agent")
	}

	resp := &apiv1.CreateKeyResponse{
		Name:                req.Name,
		PublicKey:           pub,
		CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: req.Name},
	}
	return resp, nil
}
// CreateSigner returns a crypto.Signer backed by the ssh-agent key whose
// comment equals req.SigningKey.
//
// The key is first located by comment in the agent key list, then the agent
// signer with the same marshaled public key is wrapped in a Signer. An error
// is returned when the agent cannot be queried, when no key has the requested
// comment, or when no signer matches the located key.
func (k *AgentKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
	agentKeys, err := k.agent.List()
	if err != nil {
		return nil, errors.Wrap(err, "error listing ssh-agent keys")
	}

	// The loop deliberately has no break: when several keys share the
	// comment, the last one listed is selected.
	// NOTE(review): GetPublicKey returns the first match instead; confirm
	// which behavior is intended for duplicated comments.
	var selected *agent.Key
	for _, candidate := range agentKeys {
		if candidate.Comment == req.SigningKey {
			selected = candidate
		}
	}
	if selected == nil {
		return nil, errors.Errorf("key with comment '%s' was not found", req.SigningKey)
	}

	agentSigners, err := k.agent.Signers()
	if err != nil {
		return nil, errors.Wrap(err, "error listing ssh-agent keys")
	}

	// Signers carry no comment, so match on the wire encoding of the key.
	want := selected.Marshal()
	for _, candidate := range agentSigners {
		if got := candidate.PublicKey().Marshal(); bytes.Equal(want, got) {
			return &Signer{signer: candidate}, nil
		}
	}

	return nil, errors.Errorf("signer with comment '%s' was not found", req.SigningKey)
}
// Close closes the connection to the ssh-agent. It returns nil on success and
// a wrapped error if closing the connection fails.
func (k *AgentKMS) Close() error {
	if err := k.conn.Close(); err != nil {
		return errors.Wrap(err, "error closing ssh-agent connection")
	}
	return nil
}
// Signer wraps an ssh.Signer and provides the Public and Sign methods
// defined on it; CreateSigner returns it as a crypto.Signer.
type Signer struct {
// signer is the underlying ssh signer that performs the signing operation.
signer ssh.Signer
}
// Public returns the public key of the wrapped ssh signer, converted with
// sshutil.PublicKey.
//
// If the conversion fails, the error value itself is returned in place of the
// key, since the method signature has no error result. Callers that
// type-switch on the result will therefore see an error instead of a key.
func (s *Signer) Public() crypto.PublicKey {
	pub, err := sshutil.PublicKey(s.signer.PublicKey())
	if err == nil {
		return pub
	}
	return err
}
// Sign passes digest to the wrapped ssh signer and returns the raw Blob of
// the resulting ssh signature. opts is ignored.
//
// NOTE(review): the blob is returned exactly as produced by the ssh signer;
// for ECDSA keys this is presumably not the ASN.1 encoding that crypto.Signer
// consumers expect (see the disabled conversion below) - confirm.
// NOTE(review): ssh.Signer.Sign is documented to sign data rather than a
// precomputed digest, so the input may be hashed again - verify against
// callers.
func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
	signature, err := s.signer.Sign(rand, digest)
	if err == nil {
		return signature.Blob, nil
	}
	return nil, errors.Wrap(err, "error signing data")

	// Disabled sketch of a conversion of the ssh blob to ASN.1:
	//
	// type asn1Signature struct {
	// 	R, S *big.Int
	// }
	// var asn1Sig asn1Signature
	// if err := ssh.Unmarshal(sig.Blob, &asn1Sig); err != nil {
	// 	return nil, err
	// }
	// return asn1.Marshal(asn1Sig)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment