Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save meatballhat/854b13af8305f3ea4468d0bf9a1f5f7b to your computer and use it in GitHub Desktop.
Save meatballhat/854b13af8305f3ea4468d0bf9a1f5f7b to your computer and use it in GitHub Desktop.
diff --git a/acme/acme.go b/acme/acme.go
index 8619508..a7b6ce4 100644
--- a/acme/acme.go
+++ b/acme/acme.go
@@ -15,6 +15,7 @@ package acme
import (
"bytes"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -36,9 +37,6 @@ import (
"strings"
"sync"
"time"
-
- "golang.org/x/net/context"
- "golang.org/x/net/context/ctxhttp"
)
// LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
@@ -133,7 +131,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
if dirURL == "" {
dirURL = LetsEncryptURL
}
- res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
+ res, err := c.get(ctx, dirURL)
if err != nil {
return Directory{}, err
}
@@ -154,7 +152,7 @@ func (c *Client) Discover(ctx context.Context) (Directory, error) {
CAA []string `json:"caa-identities"`
}
}
- if json.NewDecoder(res.Body).Decode(&v); err != nil {
+ if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
return Directory{}, err
}
c.dir = &Directory{
@@ -200,7 +198,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
req.NotAfter = now.Add(exp).Format(time.RFC3339)
}
- res, err := c.postJWS(ctx, c.Key, c.dir.CertURL, req)
+ res, err := c.retryPostJWS(ctx, c.Key, c.dir.CertURL, req)
if err != nil {
return nil, "", err
}
@@ -216,7 +214,7 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
return cert, curl, err
}
// slurp issued cert and CA chain, if requested
- cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
+ cert, err := c.responseCert(ctx, res, bundle)
return cert, curl, err
}
@@ -231,13 +229,13 @@ func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration,
// and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
for {
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+ res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
- return responseCert(ctx, c.HTTPClient, res, bundle)
+ return c.responseCert(ctx, res, bundle)
}
if res.StatusCode > 299 {
return nil, responseError(res)
@@ -275,7 +273,7 @@ func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte,
if key == nil {
key = c.Key
}
- res, err := c.postJWS(ctx, key, c.dir.RevokeURL, body)
+ res, err := c.retryPostJWS(ctx, key, c.dir.RevokeURL, body)
if err != nil {
return err
}
@@ -363,7 +361,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
Resource: "new-authz",
Identifier: authzID{Type: "dns", Value: domain},
}
- res, err := c.postJWS(ctx, c.Key, c.dir.AuthzURL, req)
+ res, err := c.retryPostJWS(ctx, c.Key, c.dir.AuthzURL, req)
if err != nil {
return nil, err
}
@@ -387,7 +385,7 @@ func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization,
// If a caller needs to poll an authorization until its status is final,
// see the WaitAuthorization method.
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+ res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -421,7 +419,7 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
Status: "deactivated",
Delete: true,
}
- res, err := c.postJWS(ctx, c.Key, url, req)
+ res, err := c.retryPostJWS(ctx, c.Key, url, req)
if err != nil {
return err
}
@@ -438,25 +436,11 @@ func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
//
// It returns a non-nil Authorization only if its Status is StatusValid.
// In all other cases WaitAuthorization returns an error.
-// If the Status is StatusInvalid, the returned error is ErrAuthorizationFailed.
+// If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
- var count int
- sleep := func(v string, inc int) error {
- count += inc
- d := backoff(count, 10*time.Second)
- d = retryAfter(v, d)
- wakeup := time.NewTimer(d)
- defer wakeup.Stop()
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-wakeup.C:
- return nil
- }
- }
-
+ sleep := sleeper(ctx)
for {
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+ res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -481,7 +465,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
return raw.authorization(url), nil
}
if raw.Status == StatusInvalid {
- return nil, ErrAuthorizationFailed
+ return nil, raw.error(url)
}
if err := sleep(retry, 0); err != nil {
return nil, err
@@ -493,7 +477,7 @@ func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorizat
//
// A client typically polls a challenge status using this method.
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
- res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+ res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -527,7 +511,7 @@ func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error
Type: chal.Type,
Auth: auth,
}
- res, err := c.postJWS(ctx, c.Key, chal.URI, req)
+ res, err := c.retryPostJWS(ctx, c.Key, chal.URI, req)
if err != nil {
return nil, err
}
@@ -660,7 +644,7 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
req.Contact = acct.Contact
req.Agreement = acct.AgreedTerms
}
- res, err := c.postJWS(ctx, c.Key, url, req)
+ res, err := c.retryPostJWS(ctx, c.Key, url, req)
if err != nil {
return nil, err
}
@@ -697,6 +681,40 @@ func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Accoun
}, nil
}
+// retryPostJWS will retry calls to postJWS if there is a badNonce error,
+// clearing the stored nonces after each error.
+// If the response status is 4XX-5XX, then responseError is called on the response,
+// the body is closed, and the resulting error is returned.
+func (c *Client) retryPostJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
+ sleep := sleeper(ctx)
+ for {
+ res, err := c.postJWS(ctx, key, url, body)
+ if err != nil {
+ return nil, err
+ }
+ // handle errors 4XX-5XX with responseError
+ if res.StatusCode >= 400 && res.StatusCode <= 599 {
+ err := responseError(res)
+ res.Body.Close()
+ // according to spec badNonce is urn:ietf:params:acme:error:badNonce
+ // however, acme servers in the wild return their version of the error
+ // https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4
+ if ae, ok := err.(*Error); ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") {
+ // clear any nonces that we might've stored that might now be
+ // considered bad
+ c.clearNonces()
+ retry := res.Header.Get("retry-after")
+ if err := sleep(retry, 1); err != nil {
+ return nil, err
+ }
+ continue
+ }
+ return nil, err
+ }
+ return res, nil
+ }
+}
+
// postJWS signs the body with the given key and POSTs it to the provided url.
// The body argument must be JSON-serializable.
func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
@@ -708,7 +726,7 @@ func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, bod
if err != nil {
return nil, err
}
- res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b))
+ res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
if err != nil {
return nil, err
}
@@ -722,7 +740,7 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
c.noncesMu.Lock()
defer c.noncesMu.Unlock()
if len(c.nonces) == 0 {
- return fetchNonce(ctx, c.HTTPClient, url)
+ return c.fetchNonce(ctx, url)
}
var nonce string
for nonce = range c.nonces {
@@ -732,6 +750,13 @@ func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
return nonce, nil
}
+// clearNonces clears any stored nonces.
+func (c *Client) clearNonces() {
+ c.noncesMu.Lock()
+ defer c.noncesMu.Unlock()
+ c.nonces = make(map[string]struct{})
+}
+
// addNonce stores a nonce value found in h (if any) for future use.
func (c *Client) addNonce(h http.Header) {
v := nonceFromHeader(h)
@@ -749,8 +774,58 @@ func (c *Client) addNonce(h http.Header) {
c.nonces[v] = struct{}{}
}
-func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
- resp, err := ctxhttp.Head(ctx, client, url)
+func (c *Client) httpClient() *http.Client {
+ if c.HTTPClient != nil {
+ return c.HTTPClient
+ }
+ return http.DefaultClient
+}
+
+func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
+ req, err := http.NewRequest("GET", urlStr, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.do(ctx, req)
+}
+
+func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
+ req, err := http.NewRequest("HEAD", urlStr, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.do(ctx, req)
+}
+
+func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
+ req, err := http.NewRequest("POST", urlStr, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", contentType)
+ return c.do(ctx, req)
+}
+
+func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
+ res, err := c.httpClient().Do(req.WithContext(ctx))
+ if err != nil {
+ select {
+ case <-ctx.Done():
+ // Prefer the unadorned context error.
+ // (The acme package had tests assuming this, previously from ctxhttp's
+ // behavior, predating net/http supporting contexts natively)
+ // TODO(bradfitz): reconsider this in the future. But for now this
+ // requires no test updates.
+ return nil, ctx.Err()
+ default:
+ return nil, err
+ }
+ }
+ return res, nil
+}
+
+func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
+ resp, err := c.head(ctx, url)
if err != nil {
return "", err
}
@@ -769,7 +844,7 @@ func nonceFromHeader(h http.Header) string {
return h.Get("Replay-Nonce")
}
-func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
+func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) {
b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
if err != nil {
return nil, fmt.Errorf("acme: response stream: %v", err)
@@ -793,7 +868,7 @@ func responseCert(ctx context.Context, client *http.Client, res *http.Response,
return nil, errors.New("acme: rel=up link is too large")
}
for _, url := range up {
- cc, err := chainCert(ctx, client, url, 0)
+ cc, err := c.chainCert(ctx, url, 0)
if err != nil {
return nil, err
}
@@ -807,14 +882,8 @@ func responseError(resp *http.Response) error {
// don't care if ReadAll returns an error:
// json.Unmarshal will fail in that case anyway
b, _ := ioutil.ReadAll(resp.Body)
- e := struct {
- Status int
- Type string
- Detail string
- }{
- Status: resp.StatusCode,
- }
- if err := json.Unmarshal(b, &e); err != nil {
+ e := &wireError{Status: resp.StatusCode}
+ if err := json.Unmarshal(b, e); err != nil {
// this is not a regular error response:
// populate detail with anything we received,
// e.Status will already contain HTTP response code value
@@ -823,12 +892,7 @@ func responseError(resp *http.Response) error {
e.Detail = resp.Status
}
}
- return &Error{
- StatusCode: e.Status,
- ProblemType: e.Type,
- Detail: e.Detail,
- Header: resp.Header,
- }
+ return e.error(resp.Header)
}
// chainCert fetches CA certificate chain recursively by following "up" links.
@@ -836,12 +900,12 @@ func responseError(resp *http.Response) error {
// if the recursion level reaches maxChainLen.
//
// First chainCert call starts with depth of 0.
-func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) {
+func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) {
if depth >= maxChainLen {
return nil, errors.New("acme: certificate chain is too deep")
}
- res, err := ctxhttp.Get(ctx, client, url)
+ res, err := c.get(ctx, url)
if err != nil {
return nil, err
}
@@ -863,7 +927,7 @@ func chainCert(ctx context.Context, client *http.Client, url string, depth int)
return nil, errors.New("acme: certificate chain is too large")
}
for _, up := range uplink {
- cc, err := chainCert(ctx, client, up, depth+1)
+ cc, err := c.chainCert(ctx, up, depth+1)
if err != nil {
return nil, err
}
@@ -893,6 +957,28 @@ func linkHeader(h http.Header, rel string) []string {
return links
}
+// sleeper returns a function that accepts the Retry-After HTTP header value
+// and an increment that's used with backoff to increasingly sleep on
+// consecutive calls until the context is done. If the Retry-After header
+// cannot be parsed, then backoff is used with a maximum sleep time of 10
+// seconds.
+func sleeper(ctx context.Context) func(ra string, inc int) error {
+ var count int
+ return func(ra string, inc int) error {
+ count += inc
+ d := backoff(count, 10*time.Second)
+ d = retryAfter(ra, d)
+ wakeup := time.NewTimer(d)
+ defer wakeup.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-wakeup.C:
+ return nil
+ }
+ }
+}
+
// retryAfter parses a Retry-After HTTP header value,
// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
// It returns d if v cannot be parsed.
@@ -974,7 +1060,8 @@ func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) {
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
BasicConstraintsValid: true,
- KeyUsage: x509.KeyUsageKeyEncipherment,
+ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
}
tmpl.DNSNames = san
diff --git a/acme/acme_test.go b/acme/acme_test.go
index 1205dbb..a4d276d 100644
--- a/acme/acme_test.go
+++ b/acme/acme_test.go
@@ -6,6 +6,7 @@ package acme
import (
"bytes"
+ "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@@ -23,8 +24,6 @@ import (
"strings"
"testing"
"time"
-
- "golang.org/x/net/context"
)
// Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided
@@ -544,6 +543,9 @@ func TestWaitAuthorizationInvalid(t *testing.T) {
if err == nil {
t.Error("err is nil")
}
+ if _, ok := err.(*AuthorizationError); !ok {
+ t.Errorf("err is %T; want *AuthorizationError", err)
+ }
}
}
@@ -981,7 +983,8 @@ func TestNonce_fetch(t *testing.T) {
defer ts.Close()
for ; i < len(tests); i++ {
test := tests[i]
- n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+ c := &Client{}
+ n, err := c.fetchNonce(context.Background(), ts.URL)
if n != test.nonce {
t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
}
@@ -999,7 +1002,8 @@ func TestNonce_fetchError(t *testing.T) {
w.WriteHeader(http.StatusTooManyRequests)
}))
defer ts.Close()
- _, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+ c := &Client{}
+ _, err := c.fetchNonce(context.Background(), ts.URL)
e, ok := err.(*Error)
if !ok {
t.Fatalf("err is %T; want *Error", err)
@@ -1064,6 +1068,44 @@ func TestNonce_postJWS(t *testing.T) {
}
}
+func TestRetryPostJWS(t *testing.T) {
+ var count int
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ count++
+ w.Header().Set("replay-nonce", fmt.Sprintf("nonce%d", count))
+ if r.Method == "HEAD" {
+ // We expect the client to do 2 HEAD requests to fetch
+ // nonces, one to start and another after getting badNonce.
+ return
+ }
+
+ head, err := decodeJWSHead(r)
+ if err != nil {
+ t.Errorf("decodeJWSHead: %v", err)
+ } else if head.Nonce == "" {
+ t.Error("head.Nonce is empty")
+ } else if head.Nonce == "nonce1" {
+ // return a badNonce error to force the call to retry
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`))
+ return
+ }
+ // Make client.Authorize happy; we're not testing its result.
+ w.WriteHeader(http.StatusCreated)
+ w.Write([]byte(`{"status":"valid"}`))
+ }))
+ defer ts.Close()
+
+ client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
+ // This call will fail with badNonce, causing a retry
+ if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
+ t.Errorf("client.Authorize 1: %v", err)
+ }
+ if count != 4 {
+ t.Errorf("total requests count: %d; want 4", count)
+ }
+}
+
func TestLinkHeader(t *testing.T) {
h := http.Header{"Link": {
`<https://example.com/acme/new-authz>;rel="next"`,
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go
index 4b15816..a478eff 100644
--- a/acme/autocert/autocert.go
+++ b/acme/autocert/autocert.go
@@ -10,6 +10,7 @@ package autocert
import (
"bytes"
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -30,9 +31,14 @@ import (
"time"
"golang.org/x/crypto/acme"
- "golang.org/x/net/context"
)
+// createCertRetryAfter is how much time to wait before removing a failed state
+// entry due to an unsuccessful createCert call.
+// This is a variable instead of a const for testing.
+// TODO: Consider making it configurable or using an exponential backoff.
+var createCertRetryAfter = time.Minute
+
// pseudoRand is safe for concurrent use.
var pseudoRand *lockedMathRand
@@ -41,8 +47,9 @@ func init() {
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)}
}
-// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service
-// during account registration.
+// AcceptTOS is a Manager.Prompt function that always returns true to
+// indicate acceptance of the CA's Terms of Service during account
+// registration.
func AcceptTOS(tosURL string) bool { return true }
// HostPolicy specifies which host names the Manager is allowed to respond to.
@@ -76,18 +83,6 @@ func defaultHostPolicy(context.Context, string) error {
// It obtains and refreshes certificates automatically,
// as well as providing them to a TLS server via tls.Config.
//
-// A simple usage example:
-//
-// m := autocert.Manager{
-// Prompt: autocert.AcceptTOS,
-// HostPolicy: autocert.HostWhitelist("example.org"),
-// }
-// s := &http.Server{
-// Addr: ":https",
-// TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
-// }
-// s.ListenAndServeTLS("", "")
-//
// To preserve issued certificates and improve overall performance,
// use a cache implementation of Cache. For instance, DirCache.
type Manager struct {
@@ -123,7 +118,7 @@ type Manager struct {
// RenewBefore optionally specifies how early certificates should
// be renewed before they expire.
//
- // If zero, they're renewed 1 week before expiration.
+ // If zero, they're renewed 30 days before expiration.
RenewBefore time.Duration
// Client is used to perform low-level operations, such as account registration
@@ -173,10 +168,23 @@ type Manager struct {
// The error is propagated back to the caller of GetCertificate and is user-visible.
// This does not affect cached certs. See HostPolicy field description for more details.
func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if m.Prompt == nil {
+ return nil, errors.New("acme/autocert: Manager.Prompt not set")
+ }
+
name := hello.ServerName
if name == "" {
return nil, errors.New("acme/autocert: missing server name")
}
+ if !strings.Contains(strings.Trim(name, "."), ".") {
+ return nil, errors.New("acme/autocert: server name component count invalid")
+ }
+ if strings.ContainsAny(name, `/\`) {
+ return nil, errors.New("acme/autocert: server name contains invalid character")
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+ defer cancel()
// check whether this is a token cert requested for TLS-SNI challenge
if strings.HasSuffix(name, ".acme.invalid") {
@@ -185,7 +193,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if cert := m.tokenCert[name]; cert != nil {
return cert, nil
}
- if cert, err := m.cacheGet(name); err == nil {
+ if cert, err := m.cacheGet(ctx, name); err == nil {
return cert, nil
}
// TODO: cache error results?
@@ -194,7 +202,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
// regular domain
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114
- cert, err := m.cert(name)
+ cert, err := m.cert(ctx, name)
if err == nil {
return cert, nil
}
@@ -203,7 +211,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
}
// first-time
- ctx := context.Background() // TODO: use a deadline?
if err := m.hostPolicy()(ctx, name); err != nil {
return nil, err
}
@@ -211,14 +218,14 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
if err != nil {
return nil, err
}
- m.cachePut(name, cert)
+ m.cachePut(ctx, name, cert)
return cert, nil
}
// cert returns an existing certificate either from m.state or cache.
// If a certificate is found in cache but not in m.state, the latter will be filled
// with the cached value.
-func (m *Manager) cert(name string) (*tls.Certificate, error) {
+func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) {
m.stateMu.Lock()
if s, ok := m.state[name]; ok {
m.stateMu.Unlock()
@@ -227,7 +234,7 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
return s.tlscert()
}
defer m.stateMu.Unlock()
- cert, err := m.cacheGet(name)
+ cert, err := m.cacheGet(ctx, name)
if err != nil {
return nil, err
}
@@ -249,12 +256,11 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
}
// cacheGet always returns a valid certificate, or an error otherwise.
-func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
+// If a cached certificate exists but is not valid, ErrCacheMiss is returned.
+func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) {
if m.Cache == nil {
return nil, ErrCacheMiss
}
- // TODO: might want to define a cache timeout on m
- ctx := context.Background()
data, err := m.Cache.Get(ctx, domain)
if err != nil {
return nil, err
@@ -263,7 +269,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
// private
priv, pub := pem.Decode(data)
if priv == nil || !strings.Contains(priv.Type, "PRIVATE") {
- return nil, errors.New("acme/autocert: no private key found in cache")
+ return nil, ErrCacheMiss
}
privKey, err := parsePrivateKey(priv.Bytes)
if err != nil {
@@ -281,13 +287,14 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
pubDER = append(pubDER, b.Bytes)
}
if len(pub) > 0 {
- return nil, errors.New("acme/autocert: invalid public key")
+ // Leftover content not consumed by pem.Decode. Corrupt. Ignore.
+ return nil, ErrCacheMiss
}
// verify and create TLS cert
leaf, err := validCert(domain, pubDER, privKey)
if err != nil {
- return nil, err
+ return nil, ErrCacheMiss
}
tlscert := &tls.Certificate{
Certificate: pubDER,
@@ -297,7 +304,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
return tlscert, nil
}
-func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
+func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error {
if m.Cache == nil {
return nil
}
@@ -329,8 +336,6 @@ func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
}
}
- // TODO: might want to define a cache timeout on m
- ctx := context.Background()
return m.Cache.Put(ctx, domain, buf.Bytes())
}
@@ -370,6 +375,23 @@ func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certifica
der, leaf, err := m.authorizedCert(ctx, state.key, domain)
if err != nil {
+ // Remove the failed state after some time,
+ // making the manager call createCert again on the following TLS hello.
+ time.AfterFunc(createCertRetryAfter, func() {
+ defer testDidRemoveState(domain)
+ m.stateMu.Lock()
+ defer m.stateMu.Unlock()
+ // Verify the state hasn't changed and it's still invalid
+ // before deleting.
+ s, ok := m.state[domain]
+ if !ok {
+ return
+ }
+ if _, err := validCert(domain, s.cert, s.key); err == nil {
+ return
+ }
+ delete(m.state, domain)
+ })
return nil, err
}
state.cert = der
@@ -418,7 +440,6 @@ func (m *Manager) certState(domain string) (*certState, error) {
// authorizedCert starts domain ownership verification process and requests a new cert upon success.
// The key argument is the certificate private key.
func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) {
- // TODO: make m.verify retry or retry m.verify calls here
if err := m.verify(ctx, domain); err != nil {
return nil, nil, err
}
@@ -494,7 +515,7 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
if err != nil {
return err
}
- m.putTokenCert(name, &cert)
+ m.putTokenCert(ctx, name, &cert)
defer func() {
// verification has ended at this point
// don't need token cert anymore
@@ -512,14 +533,14 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
// putTokenCert stores the cert under the named key in both m.tokenCert map
// and m.Cache.
-func (m *Manager) putTokenCert(name string, cert *tls.Certificate) {
+func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) {
m.tokenCertMu.Lock()
defer m.tokenCertMu.Unlock()
if m.tokenCert == nil {
m.tokenCert = make(map[string]*tls.Certificate)
}
m.tokenCert[name] = cert
- m.cachePut(name, cert)
+ m.cachePut(ctx, name, cert)
}
// deleteTokenCert removes the token certificate for the specified domain name
@@ -644,10 +665,10 @@ func (m *Manager) hostPolicy() HostPolicy {
}
func (m *Manager) renewBefore() time.Duration {
- if m.RenewBefore > maxRandRenew {
+ if m.RenewBefore > renewJitter {
return m.RenewBefore
}
- return 7 * 24 * time.Hour // 1 week
+ return 720 * time.Hour // 30 days
}
// certState is ready when its mutex is unlocked for reading.
@@ -789,5 +810,10 @@ func (r *lockedMathRand) int63n(max int64) int64 {
return n
}
-// for easier testing
-var timeNow = time.Now
+// For easier testing.
+var (
+ timeNow = time.Now
+
+ // Called when a state is removed.
+ testDidRemoveState = func(domain string) {}
+)
diff --git a/acme/autocert/autocert_test.go b/acme/autocert/autocert_test.go
index 7afb213..0352e34 100644
--- a/acme/autocert/autocert_test.go
+++ b/acme/autocert/autocert_test.go
@@ -5,6 +5,7 @@
package autocert
import (
+ "context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -27,7 +28,6 @@ import (
"time"
"golang.org/x/crypto/acme"
- "golang.org/x/net/context"
)
var discoTmpl = template.Must(template.New("disco").Parse(`{
@@ -150,7 +150,7 @@ func TestGetCertificate_ForceRSA(t *testing.T) {
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
testGetCertificate(t, man, "example.org", hello)
- cert, err := man.cacheGet("example.org")
+ cert, err := man.cacheGet(context.Background(), "example.org")
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
@@ -159,9 +159,110 @@ func TestGetCertificate_ForceRSA(t *testing.T) {
}
}
-// tests man.GetCertificate flow using the provided hello argument.
+func TestGetCertificate_nilPrompt(t *testing.T) {
+ man := &Manager{}
+ defer man.stopRenew()
+ url, finish := startACMEServerStub(t, man, "example.org")
+ defer finish()
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+ man.Client = &acme.Client{
+ Key: key,
+ DirectoryURL: url,
+ }
+ hello := &tls.ClientHelloInfo{ServerName: "example.org"}
+ if _, err := man.GetCertificate(hello); err == nil {
+ t.Error("got certificate for example.org; wanted error")
+ }
+}
+
+func TestGetCertificate_expiredCache(t *testing.T) {
+ // Make an expired cert and cache it.
+ pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tmpl := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{CommonName: "example.org"},
+ NotAfter: time.Now(),
+ }
+ pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &pk.PublicKey, pk)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tlscert := &tls.Certificate{
+ Certificate: [][]byte{pub},
+ PrivateKey: pk,
+ }
+
+ man := &Manager{Prompt: AcceptTOS, Cache: newMemCache()}
+ defer man.stopRenew()
+ if err := man.cachePut(context.Background(), "example.org", tlscert); err != nil {
+ t.Fatalf("man.cachePut: %v", err)
+ }
+
+ // The expired cached cert should trigger a new cert issuance
+ // and return without an error.
+ hello := &tls.ClientHelloInfo{ServerName: "example.org"}
+ testGetCertificate(t, man, "example.org", hello)
+}
+
+func TestGetCertificate_failedAttempt(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+ defer ts.Close()
+
+ const example = "example.org"
+ d := createCertRetryAfter
+ f := testDidRemoveState
+ defer func() {
+ createCertRetryAfter = d
+ testDidRemoveState = f
+ }()
+ createCertRetryAfter = 0
+ done := make(chan struct{})
+ testDidRemoveState = func(domain string) {
+ if domain != example {
+ t.Errorf("testDidRemoveState: domain = %q; want %q", domain, example)
+ }
+ close(done)
+ }
+
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+ man := &Manager{
+ Prompt: AcceptTOS,
+ Client: &acme.Client{
+ Key: key,
+ DirectoryURL: ts.URL,
+ },
+ }
+ defer man.stopRenew()
+ hello := &tls.ClientHelloInfo{ServerName: example}
+ if _, err := man.GetCertificate(hello); err == nil {
+ t.Error("GetCertificate: err is nil")
+ }
+ select {
+ case <-time.After(5 * time.Second):
+ t.Errorf("took too long to remove the %q state", example)
+ case <-done:
+ man.stateMu.Lock()
+ defer man.stateMu.Unlock()
+ if v, exist := man.state[example]; exist {
+ t.Errorf("state exists for %q: %+v", example, v)
+ }
+ }
+}
+
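+// startACMEServerStub runs an ACME server stub and returns its URL and a finish func that closes it and checks that the token cert was removed.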
// The domain argument is the expected domain name of a certificate request.
-func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
+func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) {
// echo token-02 | shasum -a 256
// then divide result in 2 parts separated by dot
tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
@@ -187,7 +288,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
// discovery
case "/":
if err := discoTmpl.Execute(w, ca.URL); err != nil {
- t.Fatalf("discoTmpl: %v", err)
+ t.Errorf("discoTmpl: %v", err)
}
// client key registration
case "/new-reg":
@@ -197,7 +298,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
w.Header().Set("location", ca.URL+"/authz/1")
w.WriteHeader(http.StatusCreated)
if err := authzTmpl.Execute(w, ca.URL); err != nil {
- t.Fatalf("authzTmpl: %v", err)
+ t.Errorf("authzTmpl: %v", err)
}
// accept tls-sni-02 challenge
case "/challenge/2":
@@ -215,14 +316,14 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
csr, err := x509.ParseCertificateRequest(b)
if err != nil {
- t.Fatalf("new-cert: CSR: %v", err)
+ t.Errorf("new-cert: CSR: %v", err)
}
if csr.Subject.CommonName != domain {
t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
}
der, err := dummyCert(csr.PublicKey, domain)
if err != nil {
- t.Fatalf("new-cert: dummyCert: %v", err)
+ t.Errorf("new-cert: dummyCert: %v", err)
}
chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
w.Header().Set("link", chainUp)
@@ -232,14 +333,51 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
case "/ca-cert":
der, err := dummyCert(nil, "ca")
if err != nil {
- t.Fatalf("ca-cert: dummyCert: %v", err)
+ t.Errorf("ca-cert: dummyCert: %v", err)
}
w.Write(der)
default:
t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
}
}))
- defer ca.Close()
+ finish = func() {
+ ca.Close()
+
+ // make sure token cert was removed
+ cancel := make(chan struct{})
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ tick := time.NewTicker(100 * time.Millisecond)
+ defer tick.Stop()
+ for {
+ hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
+ if _, err := man.GetCertificate(hello); err != nil {
+ return
+ }
+ select {
+ case <-tick.C:
+ case <-cancel:
+ return
+ }
+ }
+ }()
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ close(cancel)
+ t.Error("token cert was not removed")
+ <-done
+ }
+ }
+ return ca.URL, finish
+}
+
+// testGetCertificate tests the man.GetCertificate flow using the provided hello argument.
+// The domain argument is the expected domain name of a certificate request.
+func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
+ url, finish := startACMEServerStub(t, man, domain)
+ defer finish()
// use EC key to run faster on 386
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@@ -248,7 +386,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
}
man.Client = &acme.Client{
Key: key,
- DirectoryURL: ca.URL,
+ DirectoryURL: url,
}
// simulate tls.Config.GetCertificate
@@ -279,23 +417,6 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl
t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
}
- // make sure token cert was removed
- done = make(chan struct{})
- go func() {
- for {
- hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
- if _, err := man.GetCertificate(hello); err != nil {
- break
- }
- time.Sleep(100 * time.Millisecond)
- }
- close(done)
- }()
- select {
- case <-time.After(5 * time.Second):
- t.Error("token cert was not removed")
- case <-done:
- }
}
func TestAccountKeyCache(t *testing.T) {
@@ -335,10 +456,11 @@ func TestCache(t *testing.T) {
man := &Manager{Cache: newMemCache()}
defer man.stopRenew()
- if err := man.cachePut("example.org", tlscert); err != nil {
+ ctx := context.Background()
+ if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
t.Fatalf("man.cachePut: %v", err)
}
- res, err := man.cacheGet("example.org")
+ res, err := man.cacheGet(ctx, "example.org")
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
@@ -438,3 +560,47 @@ func TestValidCert(t *testing.T) {
}
}
}
+
+type cacheGetFunc func(ctx context.Context, key string) ([]byte, error)
+
+func (f cacheGetFunc) Get(ctx context.Context, key string) ([]byte, error) {
+ return f(ctx, key)
+}
+
+func (f cacheGetFunc) Put(ctx context.Context, key string, data []byte) error {
+ return fmt.Errorf("unsupported Put of %q = %q", key, data)
+}
+
+func (f cacheGetFunc) Delete(ctx context.Context, key string) error {
+ return fmt.Errorf("unsupported Delete of %q", key)
+}
+
+func TestManagerGetCertificateBogusSNI(t *testing.T) {
+ m := Manager{
+ Prompt: AcceptTOS,
+ Cache: cacheGetFunc(func(ctx context.Context, key string) ([]byte, error) {
+ return nil, fmt.Errorf("cache.Get of %s", key)
+ }),
+ }
+ tests := []struct {
+ name string
+ wantErr string
+ }{
+ {"foo.com", "cache.Get of foo.com"},
+ {"foo.com.", "cache.Get of foo.com"},
+ {`a\b.com`, "acme/autocert: server name contains invalid character"},
+ {`a/b.com`, "acme/autocert: server name contains invalid character"},
+ {"", "acme/autocert: missing server name"},
+ {"foo", "acme/autocert: server name component count invalid"},
+ {".foo", "acme/autocert: server name component count invalid"},
+ {"foo.", "acme/autocert: server name component count invalid"},
+ {"fo.o", "cache.Get of fo.o"},
+ }
+ for _, tt := range tests {
+ _, err := m.GetCertificate(&tls.ClientHelloInfo{ServerName: tt.name})
+ got := fmt.Sprint(err)
+ if got != tt.wantErr {
+ t.Errorf("GetCertificate(SNI = %q) = %q; want %q", tt.name, got, tt.wantErr)
+ }
+ }
+}
diff --git a/acme/autocert/cache.go b/acme/autocert/cache.go
index 9b184aa..61a5fd2 100644
--- a/acme/autocert/cache.go
+++ b/acme/autocert/cache.go
@@ -5,12 +5,11 @@
package autocert
import (
+ "context"
"errors"
"io/ioutil"
"os"
"path/filepath"
-
- "golang.org/x/net/context"
)
// ErrCacheMiss is returned when a certificate is not found in cache.
@@ -78,12 +77,13 @@ func (d DirCache) Put(ctx context.Context, name string, data []byte) error {
if tmp, err = d.writeTempFile(name, data); err != nil {
return
}
- // prevent overwriting the file if the context was cancelled
- if ctx.Err() != nil {
- return // no need to set err
+ select {
+ case <-ctx.Done():
+ // Don't overwrite the file if the context was canceled.
+ default:
+ newName := filepath.Join(string(d), name)
+ err = os.Rename(tmp, newName)
}
- name = filepath.Join(string(d), name)
- err = os.Rename(tmp, name)
}()
select {
case <-ctx.Done():
diff --git a/acme/autocert/cache_test.go b/acme/autocert/cache_test.go
index ad6d4a4..6e1b88d 100644
--- a/acme/autocert/cache_test.go
+++ b/acme/autocert/cache_test.go
@@ -5,13 +5,12 @@
package autocert
import (
+ "context"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
-
- "golang.org/x/net/context"
)
// make sure DirCache satisfies Cache interface
diff --git a/acme/autocert/example_test.go b/acme/autocert/example_test.go
new file mode 100644
index 0000000..c6267b8
--- /dev/null
+++ b/acme/autocert/example_test.go
@@ -0,0 +1,34 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package autocert_test
+
+import (
+ "crypto/tls"
+ "fmt"
+ "log"
+ "net/http"
+
+ "golang.org/x/crypto/acme/autocert"
+)
+
+func ExampleNewListener() {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, TLS user! Your config: %+v", r.TLS)
+ })
+ log.Fatal(http.Serve(autocert.NewListener("example.com"), mux))
+}
+
+func ExampleManager() {
+ m := autocert.Manager{
+ Prompt: autocert.AcceptTOS,
+ HostPolicy: autocert.HostWhitelist("example.org"),
+ }
+ s := &http.Server{
+ Addr: ":https",
+ TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
+ }
+ s.ListenAndServeTLS("", "")
+}
diff --git a/acme/autocert/listener.go b/acme/autocert/listener.go
new file mode 100644
index 0000000..d4c93d2
--- /dev/null
+++ b/acme/autocert/listener.go
@@ -0,0 +1,153 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package autocert
+
+import (
+ "crypto/tls"
+ "log"
+ "net"
+ "os"
+ "path/filepath"
+ "runtime"
+ "time"
+)
+
+// NewListener returns a net.Listener that listens on the standard TLS
+// port (443) on all interfaces and returns *tls.Conn connections with
+// Let's Encrypt certificates for the provided domain or domains.
+//
+// It enables one-line HTTPS servers:
+//
+// log.Fatal(http.Serve(autocert.NewListener("example.com"), handler))
+//
+// NewListener is a convenience function for a common configuration.
+// More complex or custom configurations can use the autocert.Manager
+// type instead.
+//
+// Use of this function implies acceptance of the Let's Encrypt Terms of
+// Service. If domains is not empty, the provided domains are passed
+// to HostWhitelist. If domains is empty, the listener will do
+// Let's Encrypt challenges for any requested domain, which is not
+// recommended.
+//
+// Certificates are cached in a "golang-autocert" directory under an
+// operating system-specific cache or temp directory. This may not
+// be suitable for servers spanning multiple machines.
+//
+// The returned Listener also enables TCP keep-alives on the accepted
+// connections. The *tls.Conn connections are returned before their TLS
+// handshake has completed.
+func NewListener(domains ...string) net.Listener {
+ m := &Manager{
+ Prompt: AcceptTOS,
+ }
+ if len(domains) > 0 {
+ m.HostPolicy = HostWhitelist(domains...)
+ }
+ dir := cacheDir()
+ if err := os.MkdirAll(dir, 0700); err != nil {
+ log.Printf("warning: autocert.NewListener not using a cache: %v", err)
+ } else {
+ m.Cache = DirCache(dir)
+ }
+ return m.Listener()
+}
+
+// Listener listens on the standard TLS port (443) on all interfaces
+// and returns a net.Listener returning *tls.Conn connections.
+//
+// The returned Listener also enables TCP keep-alives on the accepted
+// connections. The *tls.Conn connections are returned before their TLS
+// handshake has completed.
+//
+// Unlike NewListener, it is the caller's responsibility to initialize
+// the Manager m's Prompt, Cache, HostPolicy, and other desired options.
+func (m *Manager) Listener() net.Listener {
+ ln := &listener{
+ m: m,
+ conf: &tls.Config{
+ GetCertificate: m.GetCertificate, // bonus: panic on nil m
+ },
+ }
+ ln.tcpListener, ln.tcpListenErr = net.Listen("tcp", ":443")
+ return ln
+}
+
+type listener struct {
+ m *Manager
+ conf *tls.Config
+
+ tcpListener net.Listener
+ tcpListenErr error
+}
+
+func (ln *listener) Accept() (net.Conn, error) {
+ if ln.tcpListenErr != nil {
+ return nil, ln.tcpListenErr
+ }
+ conn, err := ln.tcpListener.Accept()
+ if err != nil {
+ return nil, err
+ }
+ tcpConn := conn.(*net.TCPConn)
+
+ // Because Listener is a convenience function, help out with
+ // this too. This is not possible for the caller to set once
+ // we return a *tls.Conn wrapping an inaccessible net.Conn.
+ // If callers don't want this, they can do things the manual
+ // way and tweak as needed. But this is what net/http does
+ // itself, so copy that. If net/http changes, we can change
+ // here too.
+ tcpConn.SetKeepAlive(true)
+ tcpConn.SetKeepAlivePeriod(3 * time.Minute)
+
+ return tls.Server(tcpConn, ln.conf), nil
+}
+
+func (ln *listener) Addr() net.Addr {
+ if ln.tcpListener != nil {
+ return ln.tcpListener.Addr()
+ }
+ // net.Listen failed. Return something non-nil in case callers
+ // call Addr before Accept:
+ return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 443}
+}
+
+func (ln *listener) Close() error {
+ if ln.tcpListenErr != nil {
+ return ln.tcpListenErr
+ }
+ return ln.tcpListener.Close()
+}
+
+func homeDir() string {
+ if runtime.GOOS == "windows" {
+ return os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
+ }
+ if h := os.Getenv("HOME"); h != "" {
+ return h
+ }
+ return "/"
+}
+
+func cacheDir() string {
+ const base = "golang-autocert"
+ switch runtime.GOOS {
+ case "darwin":
+ return filepath.Join(homeDir(), "Library", "Caches", base)
+ case "windows":
+ for _, ev := range []string{"APPDATA", "CSIDL_APPDATA", "TEMP", "TMP"} {
+ if v := os.Getenv(ev); v != "" {
+ return filepath.Join(v, base)
+ }
+ }
+ // Worst case:
+ return filepath.Join(homeDir(), base)
+ }
+ if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
+ return filepath.Join(xdg, base)
+ }
+ return filepath.Join(homeDir(), ".cache", base)
+}
diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go
index 1a5018c..6c5da2b 100644
--- a/acme/autocert/renewal.go
+++ b/acme/autocert/renewal.go
@@ -5,15 +5,14 @@
package autocert
import (
+ "context"
"crypto"
"sync"
"time"
-
- "golang.org/x/net/context"
)
-// maxRandRenew is a maximum deviation from Manager.RenewBefore.
-const maxRandRenew = time.Hour
+// renewJitter is the maximum deviation from Manager.RenewBefore.
+const renewJitter = time.Hour
// domainRenewal tracks the state used by the periodic timers
// renewing a single domain's cert.
@@ -65,7 +64,7 @@ func (dr *domainRenewal) renew() {
// TODO: rotate dr.key at some point?
next, err := dr.do(ctx)
if err != nil {
- next = maxRandRenew / 2
+ next = renewJitter / 2
next += time.Duration(pseudoRand.int63n(int64(next)))
}
dr.timer = time.AfterFunc(next, dr.renew)
@@ -83,9 +82,9 @@ func (dr *domainRenewal) renew() {
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
// a race is likely unavoidable in a distributed environment
// but we try nonetheless
- if tlscert, err := dr.m.cacheGet(dr.domain); err == nil {
+ if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
next := dr.next(tlscert.Leaf.NotAfter)
- if next > dr.m.renewBefore()+maxRandRenew {
+ if next > dr.m.renewBefore()+renewJitter {
return next, nil
}
}
@@ -103,7 +102,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
if err != nil {
return 0, err
}
- dr.m.cachePut(dr.domain, tlscert)
+ dr.m.cachePut(ctx, dr.domain, tlscert)
dr.m.stateMu.Lock()
defer dr.m.stateMu.Unlock()
// m.state is guaranteed to be non-nil at this point
@@ -114,7 +113,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
func (dr *domainRenewal) next(expiry time.Time) time.Duration {
d := expiry.Sub(timeNow()) - dr.m.renewBefore()
// add a bit of randomness to renew deadline
- n := pseudoRand.int63n(int64(maxRandRenew))
+ n := pseudoRand.int63n(int64(renewJitter))
d -= time.Duration(n)
if d < 0 {
return 0
diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go
index 10c811a..f232619 100644
--- a/acme/autocert/renewal_test.go
+++ b/acme/autocert/renewal_test.go
@@ -5,6 +5,7 @@
package autocert
import (
+ "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@@ -31,7 +32,7 @@ func TestRenewalNext(t *testing.T) {
expiry time.Time
min, max time.Duration
}{
- {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - maxRandRenew, 83 * 24 * time.Hour},
+ {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour},
{now.Add(time.Hour), 0, 1},
{now, 0, 1},
{now.Add(-time.Hour), 0, 1},
@@ -127,7 +128,7 @@ func TestRenewFromCache(t *testing.T) {
t.Fatal(err)
}
tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}}
- if err := man.cachePut(domain, tlscert); err != nil {
+ if err := man.cachePut(context.Background(), domain, tlscert); err != nil {
t.Fatal(err)
}
@@ -151,7 +152,7 @@ func TestRenewFromCache(t *testing.T) {
// ensure the new cert is cached
after := time.Now().Add(future)
- tlscert, err := man.cacheGet(domain)
+ tlscert, err := man.cacheGet(context.Background(), domain)
if err != nil {
t.Fatalf("man.cacheGet: %v", err)
}
diff --git a/acme/jws.go b/acme/jws.go
index 49ba313..6cbca25 100644
--- a/acme/jws.go
+++ b/acme/jws.go
@@ -134,7 +134,7 @@ func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
return "ES256", crypto.SHA256
case "P-384":
return "ES384", crypto.SHA384
- case "P-512":
+ case "P-521":
return "ES512", crypto.SHA512
}
}
diff --git a/acme/jws_test.go b/acme/jws_test.go
index 1def873..0ff0fb5 100644
--- a/acme/jws_test.go
+++ b/acme/jws_test.go
@@ -12,11 +12,13 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
+ "fmt"
"math/big"
"testing"
)
-const testKeyPEM = `
+const (
+ testKeyPEM = `
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq
WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30
@@ -46,10 +48,9 @@ EQeIP6dZtv8IMgtGIb91QX9pXvP0aznzQKwYIA8nZgoENCPfiMTPiEDT9e/0lObO
-----END RSA PRIVATE KEY-----
`
-// This thumbprint is for the testKey defined above.
-const testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ"
+ // This thumbprint is for the testKey defined above.
+ testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ"
-const (
// openssl ecparam -name secp256k1 -genkey -noout
testKeyECPEM = `
-----BEGIN EC PRIVATE KEY-----
@@ -58,39 +59,78 @@ AwEHoUQDQgAE5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HThqIrvawF5
QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ==
-----END EC PRIVATE KEY-----
`
- // 1. opnessl ec -in key.pem -noout -text
+ // openssl ecparam -name secp384r1 -genkey -noout
+ testKeyEC384PEM = `
+-----BEGIN EC PRIVATE KEY-----
+MIGkAgEBBDAQ4lNtXRORWr1bgKR1CGysr9AJ9SyEk4jiVnlUWWUChmSNL+i9SLSD
+Oe/naPqXJ6CgBwYFK4EEACKhZANiAAQzKtj+Ms0vHoTX5dzv3/L5YMXOWuI5UKRj
+JigpahYCqXD2BA1j0E/2xt5vlPf+gm0PL+UHSQsCokGnIGuaHCsJAp3ry0gHQEke
+WYXapUUFdvaK1R2/2hn5O+eiQM8YzCg=
+-----END EC PRIVATE KEY-----
+`
+ // openssl ecparam -name secp521r1 -genkey -noout
+ testKeyEC512PEM = `
+-----BEGIN EC PRIVATE KEY-----
+MIHcAgEBBEIBSNZKFcWzXzB/aJClAb305ibalKgtDA7+70eEkdPt28/3LZMM935Z
+KqYHh/COcxuu3Kt8azRAUz3gyr4zZKhlKUSgBwYFK4EEACOhgYkDgYYABAHUNKbx
+7JwC7H6pa2sV0tERWhHhB3JmW+OP6SUgMWryvIKajlx73eS24dy4QPGrWO9/ABsD
+FqcRSkNVTXnIv6+0mAF25knqIBIg5Q8M9BnOu9GGAchcwt3O7RDHmqewnJJDrbjd
+GGnm6rb+NnWR9DIopM0nKNkToWoF/hzopxu4Ae/GsQ==
+-----END EC PRIVATE KEY-----
+`
+ // 1. openssl ec -in key.pem -noout -text
// 2. remove first byte, 04 (the header); the rest is X and Y
- // 3. covert each with: echo <val> | xxd -r -p | base64 | tr -d '=' | tr '/+' '_-'
- testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ"
- testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk"
+ // 3. convert each with: echo <val> | xxd -r -p | base64 -w 100 | tr -d '=' | tr '/+' '_-'
+ testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ"
+ testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk"
+ testKeyEC384PubX = "MyrY_jLNLx6E1-Xc79_y-WDFzlriOVCkYyYoKWoWAqlw9gQNY9BP9sbeb5T3_oJt"
+ testKeyEC384PubY = "Dy_lB0kLAqJBpyBrmhwrCQKd68tIB0BJHlmF2qVFBXb2itUdv9oZ-TvnokDPGMwo"
+ testKeyEC512PubX = "AdQ0pvHsnALsfqlraxXS0RFaEeEHcmZb44_pJSAxavK8gpqOXHvd5Lbh3LhA8atY738AGwMWpxFKQ1VNeci_r7SY"
+ testKeyEC512PubY = "AXbmSeogEiDlDwz0Gc670YYByFzC3c7tEMeap7CckkOtuN0Yaebqtv42dZH0MiikzSco2ROhagX-HOinG7gB78ax"
+
// echo -n '{"crv":"P-256","kty":"EC","x":"<testKeyECPubX>","y":"<testKeyECPubY>"}' | \
// openssl dgst -binary -sha256 | base64 | tr -d '=' | tr '/+' '_-'
testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU"
)
var (
- testKey *rsa.PrivateKey
- testKeyEC *ecdsa.PrivateKey
+ testKey *rsa.PrivateKey
+ testKeyEC *ecdsa.PrivateKey
+ testKeyEC384 *ecdsa.PrivateKey
+ testKeyEC512 *ecdsa.PrivateKey
)
func init() {
- d, _ := pem.Decode([]byte(testKeyPEM))
+ testKey = parseRSA(testKeyPEM, "testKeyPEM")
+ testKeyEC = parseEC(testKeyECPEM, "testKeyECPEM")
+ testKeyEC384 = parseEC(testKeyEC384PEM, "testKeyEC384PEM")
+ testKeyEC512 = parseEC(testKeyEC512PEM, "testKeyEC512PEM")
+}
+
+func decodePEM(s, name string) []byte {
+ d, _ := pem.Decode([]byte(s))
if d == nil {
- panic("no block found in testKeyPEM")
+ panic("no block found in " + name)
}
- var err error
- testKey, err = x509.ParsePKCS1PrivateKey(d.Bytes)
+ return d.Bytes
+}
+
+func parseRSA(s, name string) *rsa.PrivateKey {
+ b := decodePEM(s, name)
+ k, err := x509.ParsePKCS1PrivateKey(b)
if err != nil {
- panic(err.Error())
+ panic(fmt.Sprintf("%s: %v", name, err))
}
+ return k
+}
- if d, _ = pem.Decode([]byte(testKeyECPEM)); d == nil {
- panic("no block found in testKeyECPEM")
- }
- testKeyEC, err = x509.ParseECPrivateKey(d.Bytes)
+func parseEC(s, name string) *ecdsa.PrivateKey {
+ b := decodePEM(s, name)
+ k, err := x509.ParseECPrivateKey(b)
if err != nil {
- panic(err.Error())
+ panic(fmt.Sprintf("%s: %v", name, err))
}
+ return k
}
func TestJWSEncodeJSON(t *testing.T) {
@@ -141,50 +181,63 @@ func TestJWSEncodeJSON(t *testing.T) {
}
func TestJWSEncodeJSONEC(t *testing.T) {
- claims := struct{ Msg string }{"Hello JWS"}
-
- b, err := jwsEncodeJSON(claims, testKeyEC, "nonce")
- if err != nil {
- t.Fatal(err)
- }
- var jws struct{ Protected, Payload, Signature string }
- if err := json.Unmarshal(b, &jws); err != nil {
- t.Fatal(err)
+ tt := []struct {
+ key *ecdsa.PrivateKey
+ x, y string
+ alg, crv string
+ }{
+ {testKeyEC, testKeyECPubX, testKeyECPubY, "ES256", "P-256"},
+ {testKeyEC384, testKeyEC384PubX, testKeyEC384PubY, "ES384", "P-384"},
+ {testKeyEC512, testKeyEC512PubX, testKeyEC512PubY, "ES512", "P-521"},
}
+ for i, test := range tt {
+ claims := struct{ Msg string }{"Hello JWS"}
+ b, err := jwsEncodeJSON(claims, test.key, "nonce")
+ if err != nil {
+ t.Errorf("%d: %v", i, err)
+ continue
+ }
+ var jws struct{ Protected, Payload, Signature string }
+ if err := json.Unmarshal(b, &jws); err != nil {
+ t.Errorf("%d: %v", i, err)
+ continue
+ }
- if b, err = base64.RawURLEncoding.DecodeString(jws.Protected); err != nil {
- t.Fatalf("jws.Protected: %v", err)
- }
- var head struct {
- Alg string
- Nonce string
- JWK struct {
- Crv string
- Kty string
- X string
- Y string
- } `json:"jwk"`
- }
- if err := json.Unmarshal(b, &head); err != nil {
- t.Fatalf("jws.Protected: %v", err)
- }
- if head.Alg != "ES256" {
- t.Errorf("head.Alg = %q; want ES256", head.Alg)
- }
- if head.Nonce != "nonce" {
- t.Errorf("head.Nonce = %q; want nonce", head.Nonce)
- }
- if head.JWK.Crv != "P-256" {
- t.Errorf("head.JWK.Crv = %q; want P-256", head.JWK.Crv)
- }
- if head.JWK.Kty != "EC" {
- t.Errorf("head.JWK.Kty = %q; want EC", head.JWK.Kty)
- }
- if head.JWK.X != testKeyECPubX {
- t.Errorf("head.JWK.X = %q; want %q", head.JWK.X, testKeyECPubX)
- }
- if head.JWK.Y != testKeyECPubY {
- t.Errorf("head.JWK.Y = %q; want %q", head.JWK.Y, testKeyECPubY)
+ b, err = base64.RawURLEncoding.DecodeString(jws.Protected)
+ if err != nil {
+ t.Errorf("%d: jws.Protected: %v", i, err)
+ }
+ var head struct {
+ Alg string
+ Nonce string
+ JWK struct {
+ Crv string
+ Kty string
+ X string
+ Y string
+ } `json:"jwk"`
+ }
+ if err := json.Unmarshal(b, &head); err != nil {
+ t.Errorf("%d: jws.Protected: %v", i, err)
+ }
+ if head.Alg != test.alg {
+ t.Errorf("%d: head.Alg = %q; want %q", i, head.Alg, test.alg)
+ }
+ if head.Nonce != "nonce" {
+ t.Errorf("%d: head.Nonce = %q; want nonce", i, head.Nonce)
+ }
+ if head.JWK.Crv != test.crv {
+ t.Errorf("%d: head.JWK.Crv = %q; want %q", i, head.JWK.Crv, test.crv)
+ }
+ if head.JWK.Kty != "EC" {
+ t.Errorf("%d: head.JWK.Kty = %q; want EC", i, head.JWK.Kty)
+ }
+ if head.JWK.X != test.x {
+ t.Errorf("%d: head.JWK.X = %q; want %q", i, head.JWK.X, test.x)
+ }
+ if head.JWK.Y != test.y {
+ t.Errorf("%d: head.JWK.Y = %q; want %q", i, head.JWK.Y, test.y)
+ }
}
}
diff --git a/acme/types.go b/acme/types.go
index 0513b2e..ab4de0b 100644
--- a/acme/types.go
+++ b/acme/types.go
@@ -1,9 +1,15 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
package acme
import (
"errors"
"fmt"
"net/http"
+ "strings"
+ "time"
)
// ACME server response statuses used to describe Authorization and Challenge states.
@@ -33,14 +39,8 @@ const (
CRLReasonAACompromise CRLReasonCode = 10
)
-var (
- // ErrAuthorizationFailed indicates that an authorization for an identifier
- // did not succeed.
- ErrAuthorizationFailed = errors.New("acme: identifier authorization failed")
-
- // ErrUnsupportedKey is returned when an unsupported key type is encountered.
- ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
-)
+// ErrUnsupportedKey is returned when an unsupported key type is encountered.
+var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
// Error is an ACME error, defined in Problem Details for HTTP APIs doc
// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem.
@@ -53,6 +53,7 @@ type Error struct {
// Detail is a human-readable explanation specific to this occurrence of the problem.
Detail string
// Header is the original server error response headers.
+ // It may be nil.
Header http.Header
}
@@ -60,6 +61,50 @@ func (e *Error) Error() string {
return fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail)
}
+// AuthorizationError indicates that an authorization for an identifier
+// did not succeed.
+// It contains all errors from Challenge items of the failed Authorization.
+type AuthorizationError struct {
+ // URI uniquely identifies the failed Authorization.
+ URI string
+
+ // Identifier is an AuthzID.Value of the failed Authorization.
+ Identifier string
+
+ // Errors is a collection of non-nil error values of Challenge items
+ // of the failed Authorization.
+ Errors []error
+}
+
+func (a *AuthorizationError) Error() string {
+ e := make([]string, len(a.Errors))
+ for i, err := range a.Errors {
+ e[i] = err.Error()
+ }
+ return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; "))
+}
+
+// RateLimit reports whether err represents a rate limit error and
+// any Retry-After duration returned by the server.
+//
+// See the following for more details on rate limiting:
+// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6
+func RateLimit(err error) (time.Duration, bool) {
+ e, ok := err.(*Error)
+ if !ok {
+ return 0, false
+ }
+ // Some CA implementations may return incorrect values.
+ // Use case-insensitive comparison.
+ if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") {
+ return 0, false
+ }
+ if e.Header == nil {
+ return 0, true
+ }
+ return retryAfter(e.Header.Get("Retry-After"), 0), true
+}
+
// Account is a user account. It is associated with a private key.
type Account struct {
// URI is the account unique ID, which is also a URL used to retrieve
@@ -118,6 +163,8 @@ type Directory struct {
}
// Challenge encodes a returned CA challenge.
+// Its Error field may be non-nil if the challenge is part of an Authorization
+// with StatusInvalid.
type Challenge struct {
// Type is the challenge type, e.g. "http-01", "tls-sni-02", "dns-01".
Type string
@@ -130,6 +177,11 @@ type Challenge struct {
// Status identifies the status of this challenge.
Status string
+
+ // Error indicates the reason for an authorization failure
+ // when this challenge was used.
+ // The type of a non-nil value is *Error.
+ Error error
}
// Authorization encodes an authorization response.
@@ -187,12 +239,26 @@ func (z *wireAuthz) authorization(uri string) *Authorization {
return a
}
+func (z *wireAuthz) error(uri string) *AuthorizationError {
+ err := &AuthorizationError{
+ URI: uri,
+ Identifier: z.Identifier.Value,
+ }
+ for _, raw := range z.Challenges {
+ if raw.Error != nil {
+ err.Errors = append(err.Errors, raw.Error.error(nil))
+ }
+ }
+ return err
+}
+
// wireChallenge is ACME JSON challenge representation.
type wireChallenge struct {
URI string `json:"uri"`
Type string
Token string
Status string
+ Error *wireError
}
func (c *wireChallenge) challenge() *Challenge {
@@ -205,5 +271,25 @@ func (c *wireChallenge) challenge() *Challenge {
if v.Status == "" {
v.Status = StatusPending
}
+ if c.Error != nil {
+ v.Error = c.Error.error(nil)
+ }
return v
}
+
+// wireError is a subset of fields of the Problem Details object
+// as described in https://tools.ietf.org/html/rfc7807#section-3.1.
+type wireError struct {
+ Status int
+ Type string
+ Detail string
+}
+
+func (e *wireError) error(h http.Header) *Error {
+ return &Error{
+ StatusCode: e.Status,
+ ProblemType: e.Type,
+ Detail: e.Detail,
+ Header: h,
+ }
+}
diff --git a/acme/types_test.go b/acme/types_test.go
new file mode 100644
index 0000000..a7553e6
--- /dev/null
+++ b/acme/types_test.go
@@ -0,0 +1,63 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package acme
+
+import (
+ "errors"
+ "net/http"
+ "testing"
+ "time"
+)
+
+func TestRateLimit(t *testing.T) {
+ now := time.Date(2017, 04, 27, 10, 0, 0, 0, time.UTC)
+ f := timeNow
+ defer func() { timeNow = f }()
+ timeNow = func() time.Time { return now }
+
+ h120, hTime := http.Header{}, http.Header{}
+ h120.Set("Retry-After", "120")
+ hTime.Set("Retry-After", "Tue Apr 27 11:00:00 2017")
+
+ err1 := &Error{
+ ProblemType: "urn:ietf:params:acme:error:nolimit",
+ Header: h120,
+ }
+ err2 := &Error{
+ ProblemType: "urn:ietf:params:acme:error:rateLimited",
+ Header: h120,
+ }
+ err3 := &Error{
+ ProblemType: "urn:ietf:params:acme:error:rateLimited",
+ Header: nil,
+ }
+ err4 := &Error{
+ ProblemType: "urn:ietf:params:acme:error:rateLimited",
+ Header: hTime,
+ }
+
+ tt := []struct {
+ err error
+ res time.Duration
+ ok bool
+ }{
+ {nil, 0, false},
+ {errors.New("dummy"), 0, false},
+ {err1, 0, false},
+ {err2, 2 * time.Minute, true},
+ {err3, 0, true},
+ {err4, time.Hour, true},
+ }
+ for i, test := range tt {
+ res, ok := RateLimit(test.err)
+ if ok != test.ok {
+ t.Errorf("%d: RateLimit(%+v): ok = %v; want %v", i, test.err, ok, test.ok)
+ continue
+ }
+ if res != test.res {
+ t.Errorf("%d: RateLimit(%+v) = %v; want %v", i, test.err, res, test.res)
+ }
+ }
+}
diff --git a/bcrypt/bcrypt.go b/bcrypt/bcrypt.go
index f8b807f..202fa8a 100644
--- a/bcrypt/bcrypt.go
+++ b/bcrypt/bcrypt.go
@@ -12,9 +12,10 @@ import (
"crypto/subtle"
"errors"
"fmt"
- "golang.org/x/crypto/blowfish"
"io"
"strconv"
+
+ "golang.org/x/crypto/blowfish"
)
const (
@@ -205,7 +206,6 @@ func bcrypt(password []byte, cost int, salt []byte) ([]byte, error) {
}
func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cipher, error) {
-
csalt, err := base64Decode(salt)
if err != nil {
return nil, err
@@ -213,7 +213,8 @@ func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cip
// Bug compatibility with C bcrypt implementations. They use the trailing
// NULL in the key string during expansion.
- ckey := append(key, 0)
+ // We copy the key to prevent changing the underlying array.
+ ckey := append(key[:len(key):len(key)], 0)
c, err := blowfish.NewSaltedCipher(ckey, csalt)
if err != nil {
diff --git a/bcrypt/bcrypt_test.go b/bcrypt/bcrypt_test.go
index f08a6f5..aecf759 100644
--- a/bcrypt/bcrypt_test.go
+++ b/bcrypt/bcrypt_test.go
@@ -224,3 +224,20 @@ func BenchmarkGeneration(b *testing.B) {
GenerateFromPassword(passwd, 10)
}
}
+
+// See Issue https://github.com/golang/go/issues/20425.
+func TestNoSideEffectsFromCompare(t *testing.T) {
+ source := []byte("passw0rd123456")
+ password := source[:len(source)-6]
+ token := source[len(source)-6:]
+ want := make([]byte, len(source))
+ copy(want, source)
+
+ wantHash := []byte("$2a$10$LK9XRuhNxHHCvjX3tdkRKei1QiCDUKrJRhZv7WWZPuQGRUM92rOUa")
+ _ = CompareHashAndPassword(wantHash, password)
+
+ got := bytes.Join([][]byte{password, token}, []byte(""))
+ if !bytes.Equal(got, want) {
+ t.Errorf("got=%q want=%q", got, want)
+ }
+}
diff --git a/blake2b/blake2b.go b/blake2b/blake2b.go
index fa9e48e..ce62241 100644
--- a/blake2b/blake2b.go
+++ b/blake2b/blake2b.go
@@ -4,7 +4,7 @@
// Package blake2b implements the BLAKE2b hash algorithm as
// defined in RFC 7693.
-package blake2b
+package blake2b // import "golang.org/x/crypto/blake2b"
import (
"encoding/binary"
diff --git a/blake2b/blake2b_test.go b/blake2b/blake2b_test.go
index a38fceb..7954346 100644
--- a/blake2b/blake2b_test.go
+++ b/blake2b/blake2b_test.go
@@ -22,7 +22,7 @@ func fromHex(s string) []byte {
func TestHashes(t *testing.T) {
defer func(sse4, avx, avx2 bool) {
- useSSE4, useAVX, useAVX2 = sse4, useAVX, avx2
+ useSSE4, useAVX, useAVX2 = sse4, avx, avx2
}(useSSE4, useAVX, useAVX2)
if useAVX2 {
diff --git a/blake2s/blake2s.go b/blake2s/blake2s.go
index 394c121..f2d8221 100644
--- a/blake2s/blake2s.go
+++ b/blake2s/blake2s.go
@@ -4,7 +4,7 @@
// Package blake2s implements the BLAKE2s hash algorithm as
// defined in RFC 7693.
-package blake2s
+package blake2s // import "golang.org/x/crypto/blake2s"
import (
"encoding/binary"
@@ -15,8 +15,12 @@ import (
const (
// The blocksize of BLAKE2s in bytes.
BlockSize = 64
+
// The hash size of BLAKE2s-256 in bytes.
Size = 32
+
+ // The hash size of BLAKE2s-128 in bytes.
+ Size128 = 16
)
var errKeySize = errors.New("blake2s: invalid key size")
@@ -37,6 +41,17 @@ func Sum256(data []byte) [Size]byte {
// key turns the hash into a MAC. The key must between zero and 32 bytes long.
func New256(key []byte) (hash.Hash, error) { return newDigest(Size, key) }
+// New128 returns a new hash.Hash computing the BLAKE2s-128 checksum given a
+// non-empty key of at most 32 bytes; an empty key is rejected with an error.
+// A 128-bit digest is too small to be secure as a cryptographic hash and
+// should only be used as a MAC, which is why the key argument is not optional.
+func New128(key []byte) (hash.Hash, error) {
+ if len(key) == 0 {
+ return nil, errors.New("blake2s: a key is required for a 128-bit hash")
+ }
+ return newDigest(Size128, key)
+}
+
func newDigest(hashSize int, key []byte) (*digest, error) {
if len(key) > Size {
return nil, errKeySize
diff --git a/blake2s/blake2s_test.go b/blake2s/blake2s_test.go
index e6f2eeb..ff41670 100644
--- a/blake2s/blake2s_test.go
+++ b/blake2s/blake2s_test.go
@@ -18,21 +18,25 @@ func TestHashes(t *testing.T) {
if useSSE4 {
t.Log("SSE4 version")
testHashes(t)
+ testHashes128(t)
useSSE4 = false
}
if useSSSE3 {
t.Log("SSSE3 version")
testHashes(t)
+ testHashes128(t)
useSSSE3 = false
}
if useSSE2 {
t.Log("SSE2 version")
testHashes(t)
+ testHashes128(t)
useSSE2 = false
}
if useGeneric {
t.Log("generic version")
testHashes(t)
+ testHashes128(t)
}
}
@@ -69,6 +73,39 @@ func testHashes(t *testing.T) {
}
}
+func testHashes128(t *testing.T) {
+ key, _ := hex.DecodeString("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") // 32-byte key 0x00..0x1f; error ignored for this fixed literal
+
+ input := make([]byte, 255)
+ for i := range input {
+ input[i] = byte(i)
+ }
+
+ for i, expectedHex := range hashes128 { // entry i is the expected keyed digest of input[:i]
+ h, err := New128(key)
+ if err != nil {
+ t.Fatalf("#%d: error from New128: %v", i, err)
+ }
+
+ h.Write(input[:i])
+ sum := h.Sum(nil)
+
+ if gotHex := fmt.Sprintf("%x", sum); gotHex != expectedHex {
+ t.Fatalf("#%d (single write): got %s, wanted %s", i, gotHex, expectedHex)
+ }
+
+ h.Reset() // reuse the same keyed hash for the incremental pass
+ for j := 0; j < i; j++ {
+ h.Write(input[j : j+1])
+ }
+
+ sum = h.Sum(sum[:0]) // append into sum's existing backing array
+ if gotHex := fmt.Sprintf("%x", sum); gotHex != expectedHex {
+ t.Fatalf("#%d (byte-by-byte): got %s, wanted %s", i, gotHex, expectedHex)
+ }
+ }
+}
+
// Benchmarks
func benchmarkSum(b *testing.B, size int) {
@@ -355,3 +392,262 @@ var hashes = []string{
"db444c15597b5f1a03d1f9edd16e4a9f43a667cc275175dfa2b704e3bb1a9b83",
"3fb735061abc519dfe979e54c1ee5bfad0a9d858b3315bad34bde999efd724dd",
}
+
+var hashes128 = []string{
+ "9536f9b267655743dee97b8a670f9f53",
+ "13bacfb85b48a1223c595f8c1e7e82cb",
+ "d47a9b1645e2feae501cd5fe44ce6333",
+ "1e2a79436a7796a3e9826bfedf07659f",
+ "7640360ed3c4f3054dba79a21dda66b7",
+ "d1207ac2bf5ac84fc9ef016da5a46a86",
+ "3123987871e59305ece3125abfc0099a",
+ "cf9e072ad522f2cda2d825218086731c",
+ "95d22870392efe2846b12b6e8e84efbb",
+ "7d63c30e2d51333f245601b038c0b93b",
+ "ed608b98e13976bdf4bedc63fa35e443",
+ "ed704b5cd1abf8e0dd67a6ac667a3fa5",
+ "77dc70109827dc74c70fd26cba379ae5",
+ "d2bf34508b07825ee934f33958f4560e",
+ "a340baa7b8a93a6e658adef42e78eeb7",
+ "b85c5ceaecbe9a251eac76f6932ba395",
+ "246519722001f6e8e97a2183f5985e53",
+ "5bce5aa0b7c6cac2ecf6406183cd779a",
+ "13408f1647c02f6efd0047ad8344f695",
+ "a63970f196760aa36cb965ab62f0e0fa",
+ "bc26f48421dd99fd45e15e736d3e7dac",
+ "4c6f70f9e3237cde918afb52d26f1823",
+ "45ed610cfbc37db80c4bf0eef14ae8d6",
+ "87c4c150705ea5078209ec008200539c",
+ "54de21f5e0e6f2afe04daeb822b6931e",
+ "9732a04e505064e19de3d542e7e71631",
+ "d2bd27e95531d6957eef511c4ba64ad4",
+ "7a36c9f70dcc7c3063b547101a5f6c35",
+ "322007d1a44c4257bc7903b183305529",
+ "dbcc9a09f412290ca2e0d53dfd142ddb",
+ "df12ed43b8e53a56db20e0f83764002c",
+ "d114cc11e7d5b33a360c45f18d4c7c6e",
+ "c43b5e836af88620a8a71b1652cb8640",
+ "9491c653e8867ed73c1b4ac6b5a9bb4d",
+ "06d0e988df94ada6c6f9f36f588ab7c5",
+ "561efad2480e93262c8eeaa3677615c4",
+ "ba8ffc702e5adc93503045eca8702312",
+ "5782be6ccdc78c8425285e85de8ccdc6",
+ "aa1c4393e4c07b53ea6e2b5b1e970771",
+ "42a229dc50e52271c51e8666023ebc1e",
+ "53706110e919f84de7f8d6c7f0e7b831",
+ "fc5ac8ee39cc1dd1424391323e2901bd",
+ "bed27b62ff66cac2fbb68193c727106a",
+ "cd5e689b96d0b9ea7e08dac36f7b211e",
+ "0b4c7f604eba058d18e322c6e1baf173",
+ "eb838227fdfad09a27f0f8413120675d",
+ "3149cf9d19a7fd529e6154a8b4c3b3ad",
+ "ca1e20126df930fd5fb7afe4422191e5",
+ "b23398f910599f3c09b6549fa81bcb46",
+ "27fb17c11b34fa5d8b5afe5ee3321ead",
+ "0f665f5f04cf2d46b7fead1a1f328158",
+ "8f068be73b3681f99f3b282e3c02bba5",
+ "ba189bbd13808dcf4e002a4dd21660d5",
+ "2732dcd1b16668ae6ab6a61595d0d62a",
+ "d410ccdd059f0e02b472ec9ec54bdd3c",
+ "b2eaa07b055b3a03a399971327f7e8c2",
+ "2e8a225655e9f99b69c60dc8b4d8e566",
+ "4eb55416c853f2152e67f8a224133cec",
+ "49552403790d8de0505a8e317a443687",
+ "7f2747cd41f56942752e868212c7d5ac",
+ "02a28f10e193b430df7112d2d98cf759",
+ "d4213404a9f1cf759017747cf5958270",
+ "faa34884344f9c65e944882db8476d34",
+ "ece382a8bd5018f1de5da44b72cea75b",
+ "f1efa90d2547036841ecd3627fafbc36",
+ "811ff8686d23a435ecbd0bdafcd27b1b",
+ "b21beea9c7385f657a76558530438721",
+ "9cb969da4f1b4fc5b13bf78fe366f0c4",
+ "8850d16d7b614d3268ccfa009d33c7fc",
+ "aa98a2b6176ea86415b9aff3268c6f6d",
+ "ec3e1efa5ed195eff667e16b1af1e39e",
+ "e40787dca57411d2630db2de699beb08",
+ "554835890735babd06318de23d31e78a",
+ "493957feecddc302ee2bb2086b6ebfd3",
+ "f6069709ad5b0139163717e9ce1114ab",
+ "ba5ed386098da284484b211555505a01",
+ "9244c8dfad8cbb68c118fa51465b3ae4",
+ "51e309a5008eb1f5185e5cc007cfb36f",
+ "6ce9ff712121b4f6087955f4911eafd4",
+ "59b51d8dcda031218ccdd7c760828155",
+ "0012878767a3d4f1c8194458cf1f8832",
+ "82900708afd5b6582dc16f008c655edd",
+ "21302c7e39b5a4cdf1d6f86b4f00c9b4",
+ "e894c7431591eab8d1ce0fe2aa1f01df",
+ "b67e1c40ee9d988226d605621854d955",
+ "6237bdafa34137cbbec6be43ea9bd22c",
+ "4172a8e19b0dcb09b978bb9eff7af52b",
+ "5714abb55bd4448a5a6ad09fbd872fdf",
+ "7ce1700bef423e1f958a94a77a94d44a",
+ "3742ec50cded528527775833453e0b26",
+ "5d41b135724c7c9c689495324b162f18",
+ "85c523333c6442c202e9e6e0f1185f93",
+ "5c71f5222d40ff5d90e7570e71ab2d30",
+ "6e18912e83d012efb4c66250ced6f0d9",
+ "4add4448c2e35e0b138a0bac7b4b1775",
+ "c0376c6bc5e7b8b9d2108ec25d2aab53",
+ "f72261d5ed156765c977751c8a13fcc1",
+ "cff4156c48614b6ceed3dd6b9058f17e",
+ "36bfb513f76c15f514bcb593419835aa",
+ "166bf48c6bffaf8291e6fdf63854bef4",
+ "0b67d33f8b859c3157fbabd9e6e47ed0",
+ "e4da659ca76c88e73a9f9f10f3d51789",
+ "33c1ae2a86b3f51c0642e6ed5b5aa1f1",
+ "27469b56aca2334449c1cf4970dcd969",
+ "b7117b2e363378aa0901b0d6a9f6ddc0",
+ "a9578233b09e5cd5231943fdb12cd90d",
+ "486d7d75253598b716a068243c1c3e89",
+ "66f6b02d682b78ffdc85e9ec86852489",
+ "38a07b9a4b228fbcc305476e4d2e05d2",
+ "aedb61c7970e7d05bf9002dae3c6858c",
+ "c03ef441f7dd30fdb61ad2d4d8e4c7da",
+ "7f45cc1eea9a00cb6aeb2dd748361190",
+ "a59538b358459132e55160899e47bd65",
+ "137010fef72364411820c3fbed15c8df",
+ "d8362b93fc504500dbd33ac74e1b4d70",
+ "a7e49f12c8f47e3b29cf8c0889b0a9c8",
+ "072e94ffbfc684bd8ab2a1b9dade2fd5",
+ "5ab438584bd2229e452052e002631a5f",
+ "f233d14221097baef57d3ec205c9e086",
+ "3a95db000c4a8ff98dc5c89631a7f162",
+ "0544f18c2994ab4ddf1728f66041ff16",
+ "0bc02116c60a3cc331928d6c9d3ba37e",
+ "b189dca6cb5b813c74200834fba97f29",
+ "ac8aaab075b4a5bc24419da239212650",
+ "1e9f19323dc71c29ae99c479dc7e8df9",
+ "12d944c3fa7caa1b3d62adfc492274dd",
+ "b4c68f1fffe8f0030e9b18aad8c9dc96",
+ "25887fab1422700d7fa3edc0b20206e2",
+ "8c09f698d03eaf88abf69f8147865ef6",
+ "5c363ae42a5bec26fbc5e996428d9bd7",
+ "7fdfc2e854fbb3928150d5e3abcf56d6",
+ "f0c944023f714df115f9e4f25bcdb89b",
+ "6d19534b4c332741c8ddd79a9644de2d",
+ "32595eb23764fbfc2ee7822649f74a12",
+ "5a51391aab33c8d575019b6e76ae052a",
+ "98b861ce2c620f10f913af5d704a5afd",
+ "b7fe2fc8b77fb1ce434f8465c7ddf793",
+ "0e8406e0cf8e9cc840668ece2a0fc64e",
+ "b89922db99c58f6a128ccffe19b6ce60",
+ "e1be9af665f0932b77d7f5631a511db7",
+ "74b96f20f58de8dc9ff5e31f91828523",
+ "36a4cfef5a2a7d8548db6710e50b3009",
+ "007e95e8d3b91948a1dedb91f75de76b",
+ "a87a702ce08f5745edf765bfcd5fbe0d",
+ "847e69a388a749a9c507354d0dddfe09",
+ "07176eefbc107a78f058f3d424ca6a54",
+ "ad7e80682333b68296f6cb2b4a8e446d",
+ "53c4aba43896ae422e5de5b9edbd46bf",
+ "33bd6c20ca2a7ab916d6e98003c6c5f8",
+ "060d088ea94aa093f9981a79df1dfcc8",
+ "5617b214b9df08d4f11e58f5e76d9a56",
+ "ca3a60ee85bd971e1daf9f7db059d909",
+ "cd2b7754505d8c884eddf736f1ec613e",
+ "f496163b252f1439e7e113ba2ecabd8e",
+ "5719c7dcf9d9f756d6213354acb7d5cf",
+ "6f7dd40b245c54411e7a9be83ae5701c",
+ "c8994dd9fdeb077a45ea04a30358b637",
+ "4b1184f1e35458c1c747817d527a252f",
+ "fc7df674afeac7a3fd994183f4c67a74",
+ "4f68e05ce4dcc533acf9c7c01d95711e",
+ "d4ebc59e918400720035dfc88e0c486a",
+ "d3105dd6fa123e543b0b3a6e0eeaea9e",
+ "874196128ed443f5bdb2800ca048fcad",
+ "01645f134978dc8f9cf0abc93b53780e",
+ "5b8b64caa257873a0ffd47c981ef6c3f",
+ "4ee208fc50ba0a6e65c5b58cec44c923",
+ "53f409a52427b3b7ffabb057ca088428",
+ "c1d6cd616f5341a93d921e356e5887a9",
+ "e85c20fea67fa7320dc23379181183c8",
+ "7912b6409489df001b7372bc94aebde7",
+ "e559f761ec866a87f1f331767fafc60f",
+ "20a6f5a36bc37043d977ed7708465ef8",
+ "6a72f526965ab120826640dd784c6cc4",
+ "bf486d92ad68e87c613689dd370d001b",
+ "d339fd0eb35edf3abd6419c8d857acaf",
+ "9521cd7f32306d969ddabc4e6a617f52",
+ "a1cd9f3e81520842f3cf6cc301cb0021",
+ "18e879b6f154492d593edd3f4554e237",
+ "66e2329c1f5137589e051592587e521e",
+ "e899566dd6c3e82cbc83958e69feb590",
+ "8a4b41d7c47e4e80659d77b4e4bfc9ae",
+ "f1944f6fcfc17803405a1101998c57dd",
+ "f6bcec07567b4f72851b307139656b18",
+ "22e7bb256918fe9924dce9093e2d8a27",
+ "dd25b925815fe7b50b7079f5f65a3970",
+ "0457f10f299acf0c230dd4007612e58f",
+ "ecb420c19efd93814fae2964d69b54af",
+ "14eb47b06dff685d88751c6e32789db4",
+ "e8f072dbb50d1ab6654aa162604a892d",
+ "69cff9c62092332f03a166c7b0034469",
+ "d3619f98970b798ca32c6c14cd25af91",
+ "2246d423774ee9d51a551e89c0539d9e",
+ "75e5d1a1e374a04a699247dad827b6cf",
+ "6d087dd1d4cd15bf47db07c7a96b1db8",
+ "967e4c055ac51b4b2a3e506cebd5826f",
+ "7417aa79247e473401bfa92a25b62e2a",
+ "24f3f4956da34b5c533d9a551ccd7b16",
+ "0c40382de693a5304e2331eb951cc962",
+ "9436f949d51b347db5c8e6258dafaaac",
+ "d2084297fe84c4ba6e04e4fb73d734fe",
+ "42a6f8ff590af21b512e9e088257aa34",
+ "c484ad06b1cdb3a54f3f6464a7a2a6fd",
+ "1b8ac860f5ceb4365400a201ed2917aa",
+ "c43eadabbe7b7473f3f837fc52650f54",
+ "0e5d3205406126b1f838875deb150d6a",
+ "6bf4946f8ec8a9c417f50cd1e67565be",
+ "42f09a2522314799c95b3fc121a0e3e8",
+ "06b8f1487f691a3f7c3f74e133d55870",
+ "1a70a65fb4f314dcf6a31451a9d2704f",
+ "7d4acdd0823279fd28a1e48b49a04669",
+ "09545cc8822a5dfc93bbab708fd69174",
+ "efc063db625013a83c9a426d39a9bddb",
+ "213bbf89b3f5be0ffdb14854bbcb2588",
+ "b69624d89fe2774df9a6f43695d755d4",
+ "c0f9ff9ded82bd73c512e365a894774d",
+ "d1b68507ed89c17ead6f69012982db71",
+ "14cf16db04648978e35c44850855d1b0",
+ "9f254d4eccab74cd91d694df863650a8",
+ "8f8946e2967baa4a814d36ff01d20813",
+ "6b9dc4d24ecba166cb2915d7a6cba43b",
+ "eb35a80418a0042b850e294db7898d4d",
+ "f55f925d280c637d54055c9df088ef5f",
+ "f48427a04f67e33f3ba0a17f7c9704a7",
+ "4a9f5bfcc0321aea2eced896cee65894",
+ "8723a67d1a1df90f1cef96e6fe81e702",
+ "c166c343ee25998f80bad4067960d3fd",
+ "dab67288d16702e676a040fd42344d73",
+ "c8e9e0d80841eb2c116dd14c180e006c",
+ "92294f546bacf0dea9042c93ecba8b34",
+ "013705b1502b37369ad22fe8237d444e",
+ "9b97f8837d5f2ebab0768fc9a6446b93",
+ "7e7e5236b05ec35f89edf8bf655498e7",
+ "7be8f2362c174c776fb9432fe93bf259",
+ "2422e80420276d2df5702c6470879b01",
+ "df645795db778bcce23bbe819a76ba48",
+ "3f97a4ac87dfc58761cda1782d749074",
+ "50e3f45df21ebfa1b706b9c0a1c245a8",
+ "7879541c7ff612c7ddf17cb8f7260183",
+ "67f6542b903b7ba1945eba1a85ee6b1c",
+ "b34b73d36ab6234b8d3f5494d251138e",
+ "0aea139641fdba59ab1103479a96e05f",
+ "02776815a87b8ba878453666d42afe3c",
+ "5929ab0a90459ebac5a16e2fb37c847e",
+ "c244def5b20ce0468f2b5012d04ac7fd",
+ "12116add6fefce36ed8a0aeccce9b6d3",
+ "3cd743841e9d8b878f34d91b793b4fad",
+ "45e87510cf5705262185f46905fae35f",
+ "276047016b0bfb501b2d4fc748165793",
+ "ddd245df5a799417d350bd7f4e0b0b7e",
+ "d34d917a54a2983f3fdbc4b14caae382",
+ "7730fbc09d0c1fb1939a8fc436f6b995",
+ "eb4899ef257a1711cc9270a19702e5b5",
+ "8a30932014bce35bba620895d374df7a",
+ "1924aabf9c50aa00bee5e1f95b5d9e12",
+ "1758d6f8b982aec9fbe50f20e3082b46",
+ "cd075928ab7e6883e697fe7fd3ac43ee",
+}
diff --git a/chacha20poly1305/chacha20poly1305.go b/chacha20poly1305/chacha20poly1305.go
index eb6739a..3f0dcb9 100644
--- a/chacha20poly1305/chacha20poly1305.go
+++ b/chacha20poly1305/chacha20poly1305.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// Package chacha20poly1305 implements the ChaCha20-Poly1305 AEAD as specified in RFC 7539.
-package chacha20poly1305
+package chacha20poly1305 // import "golang.org/x/crypto/chacha20poly1305"
import (
"crypto/cipher"
diff --git a/chacha20poly1305/chacha20poly1305_amd64.go b/chacha20poly1305/chacha20poly1305_amd64.go
index 4755033..7cd7ad8 100644
--- a/chacha20poly1305/chacha20poly1305_amd64.go
+++ b/chacha20poly1305/chacha20poly1305_amd64.go
@@ -14,13 +14,60 @@ func chacha20Poly1305Open(dst []byte, key []uint32, src, ad []byte) bool
//go:noescape
func chacha20Poly1305Seal(dst []byte, key []uint32, src, ad []byte)
-//go:noescape
-func haveSSSE3() bool
+// cpuid is implemented in chacha20poly1305_amd64.s.
+func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
+
+// xgetbv with ecx = 0 is implemented in chacha20poly1305_amd64.s.
+func xgetbv() (eax, edx uint32)
-var canUseASM bool
+var (
+ useASM bool
+ useAVX2 bool
+)
func init() {
- canUseASM = haveSSSE3()
+ detectCPUFeatures()
+}
+
+// detectCPUFeatures sets useASM and useAVX2 from the CPUID (and, when
+// OSXSAVE is reported, XGETBV) results, i.e. whether the instructions used by
+// the assembly routines in chacha20poly1305_amd64.s are supported.
+func detectCPUFeatures() {
+ maxID, _, _, _ := cpuid(0, 0) // EAX of leaf 0 is used as the highest queryable leaf
+ if maxID < 1 {
+ return
+ }
+
+ _, _, ecx1, _ := cpuid(1, 0)
+
+ haveSSSE3 := isSet(9, ecx1)
+ useASM = haveSSSE3 // SSSE3 alone gates the assembly implementation
+
+ haveOSXSAVE := isSet(27, ecx1)
+
+ osSupportsAVX := false
+ // For XGETBV, OSXSAVE bit is required and sufficient.
+ if haveOSXSAVE {
+ eax, _ := xgetbv()
+ // Check if XMM and YMM registers have OS support.
+ osSupportsAVX = isSet(1, eax) && isSet(2, eax)
+ }
+ haveAVX := isSet(28, ecx1) && osSupportsAVX
+
+ if maxID < 7 {
+ return // leaf 7 cannot be queried, so useAVX2 keeps its zero value
+ }
+
+ _, ebx7, _, _ := cpuid(7, 0)
+ haveAVX2 := isSet(5, ebx7) && haveAVX
+ haveBMI2 := isSet(8, ebx7)
+
+ useAVX2 = haveAVX2 && haveBMI2 // the AVX2 code path additionally requires BMI2
+}
+
+// isSet reports whether bit number bitpos (0 = least significant) is set in value.
+func isSet(bitpos uint, value uint32) bool {
+ return value&(1<<bitpos) != 0
+}
// setupState writes a ChaCha20 input matrix to state. See
@@ -47,7 +94,7 @@ func setupState(state *[16]uint32, key *[32]byte, nonce []byte) {
}
func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []byte {
- if !canUseASM {
+ if !useASM {
return c.sealGeneric(dst, nonce, plaintext, additionalData)
}
@@ -60,7 +107,7 @@ func (c *chacha20poly1305) seal(dst, nonce, plaintext, additionalData []byte) []
}
func (c *chacha20poly1305) open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
- if !canUseASM {
+ if !useASM {
return c.openGeneric(dst, nonce, ciphertext, additionalData)
}
diff --git a/chacha20poly1305/chacha20poly1305_amd64.s b/chacha20poly1305/chacha20poly1305_amd64.s
index 39c58b4..1c57e38 100644
--- a/chacha20poly1305/chacha20poly1305_amd64.s
+++ b/chacha20poly1305/chacha20poly1305_amd64.s
@@ -278,15 +278,8 @@ TEXT ·chacha20Poly1305Open(SB), 0, $288-97
MOVQ ad+72(FP), adp
// Check for AVX2 support
- CMPB runtime·support_avx2(SB), $0
- JE noavx2bmi2Open
-
- // Check BMI2 bit for MULXQ.
- // runtime·cpuid_ebx7 is always available here
- // because it passed avx2 check
- TESTL $(1<<8), runtime·cpuid_ebx7(SB)
- JNE chacha20Poly1305Open_AVX2
-noavx2bmi2Open:
+ CMPB ·useAVX2(SB), $1
+ JE chacha20Poly1305Open_AVX2
// Special optimization, for very short buffers
CMPQ inl, $128
@@ -1491,16 +1484,8 @@ TEXT ·chacha20Poly1305Seal(SB), 0, $288-96
MOVQ src_len+56(FP), inl
MOVQ ad+72(FP), adp
- // Check for AVX2 support
- CMPB runtime·support_avx2(SB), $0
- JE noavx2bmi2Seal
-
- // Check BMI2 bit for MULXQ.
- // runtime·cpuid_ebx7 is always available here
- // because it passed avx2 check
- TESTL $(1<<8), runtime·cpuid_ebx7(SB)
- JNE chacha20Poly1305Seal_AVX2
-noavx2bmi2Seal:
+ CMPB ·useAVX2(SB), $1
+ JE chacha20Poly1305Seal_AVX2
// Special optimization, for very short buffers
CMPQ inl, $128
@@ -2709,13 +2694,21 @@ sealAVX2Tail512LoopB:
JMP sealAVX2SealHash
-// func haveSSSE3() bool
-TEXT ·haveSSSE3(SB), NOSPLIT, $0
- XORQ AX, AX
- INCL AX
+// func cpuid(eaxArg, ecxArg uint32) (eax, ebx, ecx, edx uint32)
+TEXT ·cpuid(SB), NOSPLIT, $0-24
+ MOVL eaxArg+0(FP), AX
+ MOVL ecxArg+4(FP), CX
CPUID
- SHRQ $9, CX
- ANDQ $1, CX
- MOVB CX, ret+0(FP)
+ MOVL AX, eax+8(FP)
+ MOVL BX, ebx+12(FP)
+ MOVL CX, ecx+16(FP)
+ MOVL DX, edx+20(FP)
RET
+// func xgetbv() (eax, edx uint32)
+TEXT ·xgetbv(SB),NOSPLIT,$0-8
+ MOVL $0, CX // XGETBV is always issued with ECX = 0
+ XGETBV
+ MOVL AX, eax+0(FP) // results: eax at offset 0, edx at offset 4
+ MOVL DX, edx+4(FP)
+ RET
diff --git a/cryptobyte/string.go b/cryptobyte/string.go
index b1215b3..6780336 100644
--- a/cryptobyte/string.go
+++ b/cryptobyte/string.go
@@ -5,7 +5,7 @@
// Package cryptobyte implements building and parsing of byte strings for
// DER-encoded ASN.1 and TLS messages. See the examples for the Builder and
// String types to get started.
-package cryptobyte
+package cryptobyte // import "golang.org/x/crypto/cryptobyte"
// String represents a string of bytes. It provides methods for parsing
// fixed-length and length-prefixed values from it.
diff --git a/curve25519/cswap_amd64.s b/curve25519/cswap_amd64.s
index 45484d1..cd793a5 100644
--- a/curve25519/cswap_amd64.s
+++ b/curve25519/cswap_amd64.s
@@ -2,87 +2,64 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This code was translated into a form compatible with 6a from the public
-// domain sources in SUPERCOP: http://bench.cr.yp.to/supercop.html
-
// +build amd64,!gccgo,!appengine
-// func cswap(inout *[5]uint64, v uint64)
+// func cswap(inout *[4][5]uint64, v uint64)
TEXT ·cswap(SB),7,$0
MOVQ inout+0(FP),DI
MOVQ v+8(FP),SI
- CMPQ SI,$1
- MOVQ 0(DI),SI
- MOVQ 80(DI),DX
- MOVQ 8(DI),CX
- MOVQ 88(DI),R8
- MOVQ SI,R9
- CMOVQEQ DX,SI
- CMOVQEQ R9,DX
- MOVQ CX,R9
- CMOVQEQ R8,CX
- CMOVQEQ R9,R8
- MOVQ SI,0(DI)
- MOVQ DX,80(DI)
- MOVQ CX,8(DI)
- MOVQ R8,88(DI)
- MOVQ 16(DI),SI
- MOVQ 96(DI),DX
- MOVQ 24(DI),CX
- MOVQ 104(DI),R8
- MOVQ SI,R9
- CMOVQEQ DX,SI
- CMOVQEQ R9,DX
- MOVQ CX,R9
- CMOVQEQ R8,CX
- CMOVQEQ R9,R8
- MOVQ SI,16(DI)
- MOVQ DX,96(DI)
- MOVQ CX,24(DI)
- MOVQ R8,104(DI)
- MOVQ 32(DI),SI
- MOVQ 112(DI),DX
- MOVQ 40(DI),CX
- MOVQ 120(DI),R8
- MOVQ SI,R9
- CMOVQEQ DX,SI
- CMOVQEQ R9,DX
- MOVQ CX,R9
- CMOVQEQ R8,CX
- CMOVQEQ R9,R8
- MOVQ SI,32(DI)
- MOVQ DX,112(DI)
- MOVQ CX,40(DI)
- MOVQ R8,120(DI)
- MOVQ 48(DI),SI
- MOVQ 128(DI),DX
- MOVQ 56(DI),CX
- MOVQ 136(DI),R8
- MOVQ SI,R9
- CMOVQEQ DX,SI
- CMOVQEQ R9,DX
- MOVQ CX,R9
- CMOVQEQ R8,CX
- CMOVQEQ R9,R8
- MOVQ SI,48(DI)
- MOVQ DX,128(DI)
- MOVQ CX,56(DI)
- MOVQ R8,136(DI)
- MOVQ 64(DI),SI
- MOVQ 144(DI),DX
- MOVQ 72(DI),CX
- MOVQ 152(DI),R8
- MOVQ SI,R9
- CMOVQEQ DX,SI
- CMOVQEQ R9,DX
- MOVQ CX,R9
- CMOVQEQ R8,CX
- CMOVQEQ R9,R8
- MOVQ SI,64(DI)
- MOVQ DX,144(DI)
- MOVQ CX,72(DI)
- MOVQ R8,152(DI)
- MOVQ DI,AX
- MOVQ SI,DX
+ SUBQ $1, SI
+ NOTQ SI
+ MOVQ SI, X15
+ PSHUFD $0x44, X15, X15
+
+ MOVOU 0(DI), X0
+ MOVOU 16(DI), X2
+ MOVOU 32(DI), X4
+ MOVOU 48(DI), X6
+ MOVOU 64(DI), X8
+ MOVOU 80(DI), X1
+ MOVOU 96(DI), X3
+ MOVOU 112(DI), X5
+ MOVOU 128(DI), X7
+ MOVOU 144(DI), X9
+
+ MOVO X1, X10
+ MOVO X3, X11
+ MOVO X5, X12
+ MOVO X7, X13
+ MOVO X9, X14
+
+ PXOR X0, X10
+ PXOR X2, X11
+ PXOR X4, X12
+ PXOR X6, X13
+ PXOR X8, X14
+ PAND X15, X10
+ PAND X15, X11
+ PAND X15, X12
+ PAND X15, X13
+ PAND X15, X14
+ PXOR X10, X0
+ PXOR X10, X1
+ PXOR X11, X2
+ PXOR X11, X3
+ PXOR X12, X4
+ PXOR X12, X5
+ PXOR X13, X6
+ PXOR X13, X7
+ PXOR X14, X8
+ PXOR X14, X9
+
+ MOVOU X0, 0(DI)
+ MOVOU X2, 16(DI)
+ MOVOU X4, 32(DI)
+ MOVOU X6, 48(DI)
+ MOVOU X8, 64(DI)
+ MOVOU X1, 80(DI)
+ MOVOU X3, 96(DI)
+ MOVOU X5, 112(DI)
+ MOVOU X7, 128(DI)
+ MOVOU X9, 144(DI)
RET
diff --git a/curve25519/curve25519.go b/curve25519/curve25519.go
index 6918c47..2d14c2a 100644
--- a/curve25519/curve25519.go
+++ b/curve25519/curve25519.go
@@ -8,6 +8,10 @@
package curve25519
+import (
+ "encoding/binary"
+)
+
// This code is a port of the public domain, "ref10" implementation of
// curve25519 from SUPERCOP 20130419 by D. J. Bernstein.
@@ -50,17 +54,11 @@ func feCopy(dst, src *fieldElement) {
//
// Preconditions: b in {0,1}.
func feCSwap(f, g *fieldElement, b int32) {
- var x fieldElement
b = -b
- for i := range x {
- x[i] = b & (f[i] ^ g[i])
- }
-
for i := range f {
- f[i] ^= x[i]
- }
- for i := range g {
- g[i] ^= x[i]
+ t := b & (f[i] ^ g[i])
+ f[i] ^= t
+ g[i] ^= t
}
}
@@ -75,12 +73,7 @@ func load3(in []byte) int64 {
// load4 reads a 32-bit, little-endian value from in.
func load4(in []byte) int64 {
- var r int64
- r = int64(in[0])
- r |= int64(in[1]) << 8
- r |= int64(in[2]) << 16
- r |= int64(in[3]) << 24
- return r
+ return int64(binary.LittleEndian.Uint32(in))
}
func feFromBytes(dst *fieldElement, src *[32]byte) {
diff --git a/curve25519/curve25519_test.go b/curve25519/curve25519_test.go
index 14b0ee8..051a830 100644
--- a/curve25519/curve25519_test.go
+++ b/curve25519/curve25519_test.go
@@ -27,3 +27,13 @@ func TestBaseScalarMult(t *testing.T) {
t.Errorf("incorrect result: got %s, want %s", result, expectedHex)
}
}
+
+func BenchmarkScalarBaseMult(b *testing.B) {
+ var in, out [32]byte
+ in[0] = 1 // fixed input scalar; the remaining bytes stay zero
+
+ b.SetBytes(32) // account 32 bytes per iteration
+ for i := 0; i < b.N; i++ {
+ ScalarBaseMult(&out, &in)
+ }
+}
diff --git a/nacl/box/example_test.go b/nacl/box/example_test.go
new file mode 100644
index 0000000..25e42d2
--- /dev/null
+++ b/nacl/box/example_test.go
@@ -0,0 +1,95 @@
+package box_test
+
+import (
+ crypto_rand "crypto/rand" // Custom so it's clear which rand we're using.
+ "fmt"
+ "io"
+
+ "golang.org/x/crypto/nacl/box"
+)
+
+func Example() {
+ senderPublicKey, senderPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
+ if err != nil {
+ panic(err)
+ }
+
+ recipientPublicKey, recipientPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
+ if err != nil {
+ panic(err)
+ }
+
+ // You must use a different nonce for each message you encrypt with the
+ // same key. Since the nonce here is 192 bits long, a random value
+ // provides a sufficiently small probability of repeats.
+ var nonce [24]byte
+ if _, err := io.ReadFull(crypto_rand.Reader, nonce[:]); err != nil {
+ panic(err)
+ }
+
+ msg := []byte("Alas, poor Yorick! I knew him, Horatio")
+ // This encrypts msg and appends the result to the nonce, so encrypted starts with the 24 nonce bytes.
+ encrypted := box.Seal(nonce[:], msg, &nonce, recipientPublicKey, senderPrivateKey)
+
+ // The recipient can decrypt the message using their private key and the
+ // sender's public key. When you decrypt, you must use the same nonce you
+ // used to encrypt the message. One way to achieve this is to store the
+ // nonce alongside the encrypted message. Above, we stored the nonce in the
+ // first 24 bytes of the encrypted text.
+ var decryptNonce [24]byte
+ copy(decryptNonce[:], encrypted[:24])
+ decrypted, ok := box.Open(nil, encrypted[24:], &decryptNonce, senderPublicKey, recipientPrivateKey) // skip the nonce prefix
+ if !ok {
+ panic("decryption error")
+ }
+ fmt.Println(string(decrypted))
+ // Output: Alas, poor Yorick! I knew him, Horatio
+}
+
+func Example_precompute() {
+ senderPublicKey, senderPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
+ if err != nil {
+ panic(err)
+ }
+
+ recipientPublicKey, recipientPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
+ if err != nil {
+ panic(err)
+ }
+
+ // The shared key can be used to speed up processing when using the same
+ // pair of keys repeatedly.
+ sharedEncryptKey := new([32]byte)
+ box.Precompute(sharedEncryptKey, recipientPublicKey, senderPrivateKey)
+
+ // You must use a different nonce for each message you encrypt with the
+ // same key. Since the nonce here is 192 bits long, a random value
+ // provides a sufficiently small probability of repeats.
+ var nonce [24]byte
+ if _, err := io.ReadFull(crypto_rand.Reader, nonce[:]); err != nil {
+ panic(err)
+ }
+
+ msg := []byte("A fellow of infinite jest, of most excellent fancy")
+ // This encrypts msg and appends the result to the nonce, so encrypted starts with the 24 nonce bytes.
+ encrypted := box.SealAfterPrecomputation(nonce[:], msg, &nonce, sharedEncryptKey)
+
+ // The recipient precomputes its own copy of the shared key from the
+ // sender's public key and the recipient's private key.
+ var sharedDecryptKey [32]byte
+ box.Precompute(&sharedDecryptKey, senderPublicKey, recipientPrivateKey)
+
+ // The recipient can decrypt the message using the shared key. When you
+ // decrypt, you must use the same nonce you used to encrypt the message.
+ // One way to achieve this is to store the nonce alongside the encrypted
+ // message. Above, we stored the nonce in the first 24 bytes of the
+ // encrypted text.
+ var decryptNonce [24]byte
+ copy(decryptNonce[:], encrypted[:24])
+ decrypted, ok := box.OpenAfterPrecomputation(nil, encrypted[24:], &decryptNonce, &sharedDecryptKey)
+ if !ok {
+ panic("decryption error")
+ }
+ fmt.Println(string(decrypted))
+ // Output: A fellow of infinite jest, of most excellent fancy
+}
diff --git a/nacl/secretbox/example_test.go b/nacl/secretbox/example_test.go
index b25e663..789f4ff 100644
--- a/nacl/secretbox/example_test.go
+++ b/nacl/secretbox/example_test.go
@@ -43,7 +43,7 @@ func Example() {
// 24 bytes of the encrypted text.
var decryptNonce [24]byte
copy(decryptNonce[:], encrypted[:24])
- decrypted, ok := secretbox.Open([]byte{}, encrypted[24:], &decryptNonce, &secretKey)
+ decrypted, ok := secretbox.Open(nil, encrypted[24:], &decryptNonce, &secretKey)
if !ok {
panic("decryption error")
}
diff --git a/ocsp/ocsp.go b/ocsp/ocsp.go
index 8ed8796..6bd347e 100644
--- a/ocsp/ocsp.go
+++ b/ocsp/ocsp.go
@@ -450,8 +450,8 @@ func ParseRequest(bytes []byte) (*Request, error) {
// then the signature over the response is checked. If issuer is not nil then
// it will be used to validate the signature or embedded certificate.
//
-// Invalid signatures or parse failures will result in a ParseError. Error
-// responses will result in a ResponseError.
+// Invalid responses and parse failures will result in a ParseError.
+// Error responses will result in a ResponseError.
func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) {
return ParseResponseForCert(bytes, nil, issuer)
}
@@ -462,8 +462,8 @@ func ParseResponse(bytes []byte, issuer *x509.Certificate) (*Response, error) {
// issuer is not nil then it will be used to validate the signature or embedded
// certificate.
//
-// Invalid signatures or parse failures will result in a ParseError. Error
-// responses will result in a ResponseError.
+// Invalid responses and parse failures will result in a ParseError.
+// Error responses will result in a ResponseError.
func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Response, error) {
var resp responseASN1
rest, err := asn1.Unmarshal(bytes, &resp)
@@ -496,10 +496,32 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
return nil, ParseError("OCSP response contains bad number of responses")
}
+ var singleResp singleResponse
+ if cert == nil {
+ singleResp = basicResp.TBSResponseData.Responses[0]
+ } else {
+ match := false
+ for _, resp := range basicResp.TBSResponseData.Responses {
+ if cert == nil || cert.SerialNumber.Cmp(resp.CertID.SerialNumber) == 0 {
+ singleResp = resp
+ match = true
+ break
+ }
+ }
+ if !match {
+ return nil, ParseError("no response matching the supplied certificate")
+ }
+ }
+
ret := &Response{
TBSResponseData: basicResp.TBSResponseData.Raw,
Signature: basicResp.Signature.RightAlign(),
SignatureAlgorithm: getSignatureAlgorithmFromOID(basicResp.SignatureAlgorithm.Algorithm),
+ Extensions: singleResp.SingleExtensions,
+ SerialNumber: singleResp.CertID.SerialNumber,
+ ProducedAt: basicResp.TBSResponseData.ProducedAt,
+ ThisUpdate: singleResp.ThisUpdate,
+ NextUpdate: singleResp.NextUpdate,
}
// Handle the ResponderID CHOICE tag. ResponderID can be flattened into
@@ -542,25 +564,14 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
}
}
- var r singleResponse
- for _, resp := range basicResp.TBSResponseData.Responses {
- if cert == nil || cert.SerialNumber.Cmp(resp.CertID.SerialNumber) == 0 {
- r = resp
- break
- }
- }
-
- for _, ext := range r.SingleExtensions {
+ for _, ext := range singleResp.SingleExtensions {
if ext.Critical {
return nil, ParseError("unsupported critical extension")
}
}
- ret.Extensions = r.SingleExtensions
-
- ret.SerialNumber = r.CertID.SerialNumber
for h, oid := range hashOIDs {
- if r.CertID.HashAlgorithm.Algorithm.Equal(oid) {
+ if singleResp.CertID.HashAlgorithm.Algorithm.Equal(oid) {
ret.IssuerHash = h
break
}
@@ -570,20 +581,16 @@ func ParseResponseForCert(bytes []byte, cert, issuer *x509.Certificate) (*Respon
}
switch {
- case bool(r.Good):
+ case bool(singleResp.Good):
ret.Status = Good
- case bool(r.Unknown):
+ case bool(singleResp.Unknown):
ret.Status = Unknown
default:
ret.Status = Revoked
- ret.RevokedAt = r.Revoked.RevocationTime
- ret.RevocationReason = int(r.Revoked.Reason)
+ ret.RevokedAt = singleResp.Revoked.RevocationTime
+ ret.RevocationReason = int(singleResp.Revoked.Reason)
}
- ret.ProducedAt = basicResp.TBSResponseData.ProducedAt
- ret.ThisUpdate = r.ThisUpdate
- ret.NextUpdate = r.NextUpdate
-
return ret, nil
}
diff --git a/ocsp/ocsp_test.go b/ocsp/ocsp_test.go
index d325d85..df674b3 100644
--- a/ocsp/ocsp_test.go
+++ b/ocsp/ocsp_test.go
@@ -343,6 +343,21 @@ func TestOCSPDecodeMultiResponse(t *testing.T) {
}
}
+func TestOCSPDecodeMultiResponseWithoutMatchingCert(t *testing.T) {
+ wrongCert, _ := hex.DecodeString(startComHex) // a certificate the multi-response below is expected not to cover
+ cert, err := x509.ParseCertificate(wrongCert)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ responseBytes, _ := hex.DecodeString(ocspMultiResponseHex)
+ _, err = ParseResponseForCert(responseBytes, cert, nil) // nil issuer
+ want := ParseError("no response matching the supplied certificate")
+ if err != want { // the exact ParseError value is expected
+ t.Errorf("err: got %q, want %q", err, want)
+ }
+}
+
// This OCSP response was taken from Thawte's public OCSP responder.
// To recreate:
// $ openssl s_client -tls1 -showcerts -servername www.google.com -connect www.google.com:443
diff --git a/pkcs12/pkcs12.go b/pkcs12/pkcs12.go
index ad6341e..eff9ad3 100644
--- a/pkcs12/pkcs12.go
+++ b/pkcs12/pkcs12.go
@@ -109,6 +109,10 @@ func ToPEM(pfxData []byte, password string) ([]*pem.Block, error) {
bags, encodedPassword, err := getSafeContents(pfxData, encodedPassword)
+ if err != nil {
+ return nil, err
+ }
+
blocks := make([]*pem.Block, 0, len(bags))
for _, bag := range bags {
block, err := convertBag(&bag, encodedPassword)
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go
index a13a650..5fc47e5 100644
--- a/ssh/agent/client_test.go
+++ b/ssh/agent/client_test.go
@@ -180,9 +180,12 @@ func TestCert(t *testing.T) {
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
- listener, err := net.Listen("tcp", ":0")
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
- return nil, nil, err
+ listener, err = net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ return nil, nil, err
+ }
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
@@ -200,6 +203,9 @@ func netPipe() (net.Conn, net.Conn, error) {
}
func TestAuth(t *testing.T) {
+ agent, _, cleanup := startAgent(t)
+ defer cleanup()
+
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
@@ -208,9 +214,6 @@ func TestAuth(t *testing.T) {
defer a.Close()
defer b.Close()
- agent, _, cleanup := startAgent(t)
- defer cleanup()
-
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
t.Errorf("Add: %v", err)
}
@@ -233,7 +236,9 @@ func TestAuth(t *testing.T) {
conn.Close()
}()
- conf := ssh.ClientConfig{}
+ conf := ssh.ClientConfig{
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ }
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
conn, _, _, err := ssh.NewClientConn(b, "", &conf)
if err != nil {
diff --git a/ssh/agent/example_test.go b/ssh/agent/example_test.go
index c1130f7..8556225 100644
--- a/ssh/agent/example_test.go
+++ b/ssh/agent/example_test.go
@@ -6,20 +6,20 @@ package agent_test
import (
"log"
- "os"
"net"
+ "os"
- "golang.org/x/crypto/ssh"
- "golang.org/x/crypto/ssh/agent"
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/agent"
)
func ExampleClientAgent() {
// ssh-agent has a UNIX socket under $SSH_AUTH_SOCK
socket := os.Getenv("SSH_AUTH_SOCK")
- conn, err := net.Dial("unix", socket)
- if err != nil {
- log.Fatalf("net.Dial: %v", err)
- }
+ conn, err := net.Dial("unix", socket)
+ if err != nil {
+ log.Fatalf("net.Dial: %v", err)
+ }
agentClient := agent.NewClient(conn)
config := &ssh.ClientConfig{
User: "username",
@@ -29,6 +29,7 @@ func ExampleClientAgent() {
// wants it.
ssh.PublicKeysCallback(agentClient.Signers),
},
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
sshc, err := ssh.Dial("tcp", "localhost:22", config)
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go
index ec9cdee..6b0837d 100644
--- a/ssh/agent/server_test.go
+++ b/ssh/agent/server_test.go
@@ -56,7 +56,9 @@ func TestSetupForwardAgent(t *testing.T) {
incoming <- conn
}()
- conf := ssh.ClientConfig{}
+ conf := ssh.ClientConfig{
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ }
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
diff --git a/ssh/certs.go b/ssh/certs.go
index 6331c94..b1f0220 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -251,10 +251,18 @@ type CertChecker struct {
// for user certificates.
SupportedCriticalOptions []string
- // IsAuthority should return true if the key is recognized as
- // an authority. This allows for certificates to be signed by other
- // certificates.
- IsAuthority func(auth PublicKey) bool
+ // IsUserAuthority should return true if the key is recognized as an
+ // authority for the given user certificate. This allows for
+ // certificates to be signed by other certificates. This must be set
+ // if this CertChecker will be checking user certificates.
+ IsUserAuthority func(auth PublicKey) bool
+
+ // IsHostAuthority should report whether the key is recognized as
+ // an authority for this host. This allows for certificates to be
+ // signed by other keys, and for those other keys to only be valid
+ // signers for particular hostnames. This must be set if this
+ // CertChecker will be checking host certificates.
+ IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now
// is used.
@@ -268,7 +276,7 @@ type CertChecker struct {
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
// public key that is not a certificate. It must implement host key
// validation or else, if nil, all such keys are rejected.
- HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
+ HostKeyFallback HostKeyCallback
// IsRevoked is called for each certificate so that revocation checking
// can be implemented. It should return true if the given certificate
@@ -290,8 +298,17 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey)
if cert.CertType != HostCert {
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
}
+ if !c.IsHostAuthority(cert.SignatureKey, addr) {
+ return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
+ }
+
+ hostname, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return err
+ }
- return c.CheckCert(addr, cert)
+ // Pass hostname only as principal for host certificates (consistent with OpenSSH)
+ return c.CheckCert(hostname, cert)
}
// Authenticate checks a user certificate. Authenticate can be used as
@@ -308,6 +325,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
if cert.CertType != UserCert {
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
}
+ if !c.IsUserAuthority(cert.SignatureKey) {
+ return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
+ }
if err := c.CheckCert(conn.User(), cert); err != nil {
return nil, err
@@ -356,10 +376,6 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
}
}
- if !c.IsAuthority(cert.SignatureKey) {
- return fmt.Errorf("ssh: certificate signed by unrecognized authority")
- }
-
clock := c.Clock
if clock == nil {
clock = time.Now
diff --git a/ssh/certs_test.go b/ssh/certs_test.go
index c5f2e53..0200531 100644
--- a/ssh/certs_test.go
+++ b/ssh/certs_test.go
@@ -104,7 +104,7 @@ func TestValidateCert(t *testing.T) {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
checker := CertChecker{}
- checker.IsAuthority = func(k PublicKey) bool {
+ checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
@@ -142,7 +142,7 @@ func TestValidateCertTime(t *testing.T) {
checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) },
}
- checker.IsAuthority = func(k PublicKey) bool {
+ checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal())
}
@@ -160,7 +160,7 @@ func TestValidateCertTime(t *testing.T) {
func TestHostKeyCert(t *testing.T) {
cert := &Certificate{
- ValidPrincipals: []string{"hostname", "hostname.domain"},
+ ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"},
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: HostCert,
@@ -168,8 +168,8 @@ func TestHostKeyCert(t *testing.T) {
cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{
- IsAuthority: func(p PublicKey) bool {
- return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
+ IsHostAuthority: func(p PublicKey, addr string) bool {
+ return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
},
}
@@ -178,7 +178,14 @@ func TestHostKeyCert(t *testing.T) {
t.Errorf("NewCertSigner: %v", err)
}
- for _, name := range []string{"hostname", "otherhost"} {
+ for _, test := range []struct {
+ addr string
+ succeed bool
+ }{
+ {addr: "hostname:22", succeed: true},
+ {addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22'
+ {addr: "lasthost:22", succeed: false},
+ } {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
@@ -201,16 +208,15 @@ func TestHostKeyCert(t *testing.T) {
User: "user",
HostKeyCallback: checker.CheckHostKey,
}
- _, _, _, err = NewClientConn(c2, name, config)
+ _, _, _, err = NewClientConn(c2, test.addr, config)
- succeed := name == "hostname"
- if (err == nil) != succeed {
- t.Fatalf("NewClientConn(%q): %v", name, err)
+ if (err == nil) != test.succeed {
+ t.Fatalf("NewClientConn(%q): %v", test.addr, err)
}
err = <-errc
- if (err == nil) != succeed {
- t.Fatalf("NewServerConn(%q): %v", name, err)
+ if (err == nil) != test.succeed {
+ t.Fatalf("NewServerConn(%q): %v", test.addr, err)
}
}
}
diff --git a/ssh/client.go b/ssh/client.go
index c97f297..a7e3263 100644
--- a/ssh/client.go
+++ b/ssh/client.go
@@ -5,6 +5,7 @@
package ssh
import (
+ "bytes"
"errors"
"fmt"
"net"
@@ -13,7 +14,7 @@ import (
)
// Client implements a traditional SSH client that supports shells,
-// subprocesses, port forwarding and tunneled dialing.
+// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
type Client struct {
Conn
@@ -59,6 +60,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.forwards.closeAll()
}()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
+ go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
return conn
}
@@ -68,6 +70,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
+ if fullConf.HostKeyCallback == nil {
+ c.Close()
+ return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
+ }
+
conn := &connection{
sshConn: sshConn{conn: c},
}
@@ -173,6 +180,13 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
return NewClient(c, chans, reqs), nil
}
+// HostKeyCallback is the function type used for verifying server
+// keys. A HostKeyCallback must return nil if the host key is OK, or
+// an error to reject it. It receives the hostname as passed to Dial
+// or NewClientConn. The remote address is the RemoteAddr of the
+// net.Conn underlying the SSH connection.
+type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+
// A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function.
type ClientConfig struct {
@@ -188,10 +202,12 @@ type ClientConfig struct {
// be used during authentication.
Auth []AuthMethod
- // HostKeyCallback, if not nil, is called during the cryptographic
- // handshake to validate the server's host key. A nil HostKeyCallback
- // implies that all host keys are accepted.
- HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+ // HostKeyCallback is called during the cryptographic
+ // handshake to validate the server's host key. The client
+ // configuration must supply this callback for the connection
+ // to succeed. The functions InsecureIgnoreHostKey or
+ // FixedHostKey can be used for simplistic host key checks.
+ HostKeyCallback HostKeyCallback
// ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used.
@@ -209,3 +225,33 @@ type ClientConfig struct {
// A Timeout of zero means no timeout.
Timeout time.Duration
}
+
+// InsecureIgnoreHostKey returns a function that can be used for
+// ClientConfig.HostKeyCallback to accept any host key. It should
+// not be used for production code.
+func InsecureIgnoreHostKey() HostKeyCallback {
+ return func(hostname string, remote net.Addr, key PublicKey) error {
+ return nil
+ }
+}
+
+type fixedHostKey struct {
+ key PublicKey
+}
+
+func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
+ if f.key == nil {
+ return fmt.Errorf("ssh: required host key was nil")
+ }
+ if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
+ return fmt.Errorf("ssh: host key mismatch")
+ }
+ return nil
+}
+
+// FixedHostKey returns a function for use in
+// ClientConfig.HostKeyCallback to accept only a specific host key.
+func FixedHostKey(key PublicKey) HostKeyCallback {
+ hk := &fixedHostKey{key}
+ return hk.check
+}
diff --git a/ssh/client_auth.go b/ssh/client_auth.go
index fd1ec5d..b882da0 100644
--- a/ssh/client_auth.go
+++ b/ssh/client_auth.go
@@ -179,31 +179,26 @@ func (cb publicKeyCallback) method() string {
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
- // Authentication is performed in two stages. The first stage sends an
- // enquiry to test if each key is acceptable to the remote. The second
- // stage attempts to authenticate with the valid keys obtained in the
- // first stage.
+ // Authentication is performed by sending an enquiry to test if a key is
+ // acceptable to the remote. If the key is acceptable, the client will
+ // attempt to authenticate with the valid key. If not, the client will repeat
+ // the process with the remaining keys.
signers, err := cb()
if err != nil {
return false, nil, err
}
- var validKeys []Signer
+ var methods []string
for _, signer := range signers {
- if ok, err := validateKey(signer.PublicKey(), user, c); ok {
- validKeys = append(validKeys, signer)
- } else {
- if err != nil {
- return false, nil, err
- }
+ ok, err := validateKey(signer.PublicKey(), user, c)
+ if err != nil {
+ return false, nil, err
+ }
+ if !ok {
+ continue
}
- }
- // methods that may continue if this auth is not successful.
- var methods []string
- for _, signer := range validKeys {
pub := signer.PublicKey()
-
pubKey := pub.Marshal()
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
@@ -236,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
if err != nil {
return false, nil, err
}
- if success {
+
+ // If authentication succeeds or the list of available methods does not
+ // contain the "publickey" method, do not attempt to authenticate with any
+ // other keys. According to RFC 4252 Section 7, the latter can occur when
+ // additional authentication methods are required.
+ if success || !containsMethod(methods, cb.method()) {
return success, methods, err
}
}
+
return false, methods, nil
}
+func containsMethod(methods []string, method string) bool {
+ for _, m := range methods {
+ if m == method {
+ return true
+ }
+ }
+
+ return false
+}
+
// validateKey validates the key provided is acceptable to the server.
func validateKey(key PublicKey, user string, c packetConn) (bool, error) {
pubKey := key.Marshal()
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index e384c79..bd9f8a1 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
defer c2.Close()
certChecker := CertChecker{
- IsAuthority: func(k PublicKey) bool {
+ IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
@@ -76,8 +76,6 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
}
return nil, errors.New("keyboard-interactive failed")
},
- AuthLogCallback: func(conn ConnMetadata, method string, err error) {
- },
}
serverConfig.AddHostKey(testSigners["rsa"])
@@ -92,6 +90,7 @@ func TestClientAuthPublicKey(t *testing.T) {
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
@@ -104,6 +103,7 @@ func TestAuthMethodPassword(t *testing.T) {
Auth: []AuthMethod{
Password(clientPassword),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@@ -123,6 +123,7 @@ func TestAuthMethodFallback(t *testing.T) {
return "WRONG", nil
}),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@@ -141,6 +142,7 @@ func TestAuthMethodWrongPassword(t *testing.T) {
Password("wrong"),
PublicKeys(testSigners["rsa"]),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@@ -158,6 +160,7 @@ func TestAuthMethodKeyboardInteractive(t *testing.T) {
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@@ -203,6 +206,7 @@ func TestAuthMethodRSAandDSA(t *testing.T) {
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"], testSigners["rsa"]),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
@@ -219,6 +223,7 @@ func TestClientHMAC(t *testing.T) {
Config: Config{
MACs: []string{mac},
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
@@ -254,6 +259,7 @@ func TestClientUnsupportedKex(t *testing.T) {
Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") {
t.Errorf("got %v, expected 'common algorithm'", err)
@@ -273,7 +279,8 @@ func TestClientLoginCert(t *testing.T) {
}
clientConfig := &ClientConfig{
- User: "user",
+ User: "user",
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
@@ -363,6 +370,7 @@ func testPermissionsPassing(withPermissions bool, t *testing.T) {
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if withPermissions {
clientConfig.User = "permissions"
@@ -409,6 +417,7 @@ func TestRetryableAuth(t *testing.T) {
}), 2),
PublicKeys(testSigners["rsa"]),
},
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@@ -430,7 +439,8 @@ func ExampleRetryableAuthMethod(t *testing.T) {
}
config := &ClientConfig{
- User: user,
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ User: user,
Auth: []AuthMethod{
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts),
},
@@ -450,7 +460,8 @@ func TestClientAuthNone(t *testing.T) {
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
- User: user,
+ User: user,
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
@@ -469,3 +480,100 @@ func TestClientAuthNone(t *testing.T) {
t.Fatalf("server: got %q, want %q", serverConn.User(), user)
}
}
+
+// Test if authentication attempts are limited on server when MaxAuthTries is set
+func TestClientAuthMaxAuthTries(t *testing.T) {
+ user := "testuser"
+
+ serverConfig := &ServerConfig{
+ MaxAuthTries: 2,
+ PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
+ if conn.User() == "testuser" && string(pass) == "right" {
+ return nil, nil
+ }
+ return nil, errors.New("password auth failed")
+ },
+ }
+ serverConfig.AddHostKey(testSigners["rsa"])
+
+ expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
+ Reason: 2,
+ Message: "too many authentication failures",
+ })
+
+ for tries := 2; tries < 4; tries++ {
+ n := tries
+ clientConfig := &ClientConfig{
+ User: user,
+ Auth: []AuthMethod{
+ RetryableAuthMethod(PasswordCallback(func() (string, error) {
+ n--
+ if n == 0 {
+ return "right", nil
+ } else {
+ return "wrong", nil
+ }
+ }), tries),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+
+ go newServer(c1, serverConfig)
+ _, _, _, err = NewClientConn(c2, "", clientConfig)
+ if tries > 2 {
+ if err == nil {
+ t.Fatalf("client: got no error, want %s", expectedErr)
+ } else if err.Error() != expectedErr.Error() {
+ t.Fatalf("client: got %s, want %s", err, expectedErr)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("client: got %s, want no error", err)
+ }
+ }
+ }
+}
+
+// Test if authentication attempts are correctly limited on server
+// when more public keys are provided than MaxAuthTries
+func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
+ signers := []Signer{}
+ for i := 0; i < 6; i++ {
+ signers = append(signers, testSigners["dsa"])
+ }
+
+ validConfig := &ClientConfig{
+ User: "testuser",
+ Auth: []AuthMethod{
+ PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+ if err := tryAuth(t, validConfig); err != nil {
+ t.Fatalf("unable to dial remote side: %s", err)
+ }
+
+ expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
+ Reason: 2,
+ Message: "too many authentication failures",
+ })
+ invalidConfig := &ClientConfig{
+ User: "testuser",
+ Auth: []AuthMethod{
+ PublicKeys(append(signers, testSigners["rsa"])...),
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+ if err := tryAuth(t, invalidConfig); err == nil {
+ t.Fatalf("client: got no error, want %s", expectedErr)
+ } else if err.Error() != expectedErr.Error() {
+ t.Fatalf("client: got %s, want %s", err, expectedErr)
+ }
+}
diff --git a/ssh/client_test.go b/ssh/client_test.go
index 1fe790c..ccf5607 100644
--- a/ssh/client_test.go
+++ b/ssh/client_test.go
@@ -6,6 +6,7 @@ package ssh
import (
"net"
+ "strings"
"testing"
)
@@ -13,6 +14,7 @@ func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
receivedVersion := make(chan string, 1)
+ config.HostKeyCallback = InsecureIgnoreHostKey()
go func() {
version, err := readVersion(serverConn)
if err != nil {
@@ -37,3 +39,43 @@ func TestCustomClientVersion(t *testing.T) {
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
}
+
+func TestHostKeyCheck(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ wantError string
+ key PublicKey
+ }{
+ {"no callback", "must specify HostKeyCallback", nil},
+ {"correct key", "", testSigners["rsa"].PublicKey()},
+ {"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()},
+ } {
+ c1, c2, err := netPipe()
+ if err != nil {
+ t.Fatalf("netPipe: %v", err)
+ }
+ defer c1.Close()
+ defer c2.Close()
+ serverConf := &ServerConfig{
+ NoClientAuth: true,
+ }
+ serverConf.AddHostKey(testSigners["rsa"])
+
+ go NewServerConn(c1, serverConf)
+ clientConf := ClientConfig{
+ User: "user",
+ }
+ if tt.key != nil {
+ clientConf.HostKeyCallback = FixedHostKey(tt.key)
+ }
+
+ _, _, _, err = NewClientConn(c2, "", &clientConf)
+ if err != nil {
+ if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
+ t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
+ }
+ } else if tt.wantError != "" {
+ t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
+ }
+ }
+}
diff --git a/ssh/common.go b/ssh/common.go
index 8656d0f..dc39e4d 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -9,6 +9,7 @@ import (
"crypto/rand"
"fmt"
"io"
+ "math"
"sync"
_ "crypto/sha1"
@@ -40,7 +41,7 @@ var supportedKexAlgos = []string{
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
}
-// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
+// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
@@ -186,7 +187,7 @@ type Config struct {
// The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If
- // unspecified, 1 gigabyte is used.
+ // unspecified, a size suitable for the chosen cipher is used.
RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a
@@ -230,11 +231,12 @@ func (c *Config) SetDefaults() {
}
if c.RekeyThreshold == 0 {
- // RFC 4253, section 9 suggests rekeying after 1G.
- c.RekeyThreshold = 1 << 30
- }
- if c.RekeyThreshold < minRekeyThreshold {
+ // Keep zero so that a cipher-specific default is used instead.
+ } else if c.RekeyThreshold < minRekeyThreshold {
c.RekeyThreshold = minRekeyThreshold
+ } else if c.RekeyThreshold >= math.MaxInt64 {
+ // Avoid weirdness if somebody uses -1 as a threshold.
+ c.RekeyThreshold = math.MaxInt64
}
}
diff --git a/ssh/connection.go b/ssh/connection.go
index e786f2f..fd6b068 100644
--- a/ssh/connection.go
+++ b/ssh/connection.go
@@ -25,7 +25,7 @@ type ConnMetadata interface {
// User returns the user ID for this connection.
User() string
- // SessionID returns the sesson hash, also denoted by H.
+ // SessionID returns the session hash, also denoted by H.
SessionID() []byte
// ClientVersion returns the client's version string as hashed
diff --git a/ssh/doc.go b/ssh/doc.go
index d6be894..67b7322 100644
--- a/ssh/doc.go
+++ b/ssh/doc.go
@@ -14,5 +14,8 @@ others.
References:
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
+
+This package does not fall under the stability promise of the Go language itself,
+so its API may be changed when pressing needs arise.
*/
package ssh // import "golang.org/x/crypto/ssh"
diff --git a/ssh/example_test.go b/ssh/example_test.go
index 4d2eabd..618398c 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -5,12 +5,16 @@
package ssh_test
import (
+ "bufio"
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
+ "os"
+ "path/filepath"
+ "strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
@@ -91,8 +95,6 @@ func ExampleNewServerConn() {
go ssh.DiscardRequests(reqs)
// Service the incoming Channel channel.
-
- // Service the incoming Channel channel.
for newChannel := range chans {
// Channels have a type, depending on the application level
// protocol intended. In the case of a shell, the type is
@@ -131,16 +133,59 @@ func ExampleNewServerConn() {
}
}
+func ExampleHostKeyCheck() {
+ // Every client must provide a host key check. Here is a
+ // simple-minded parse of OpenSSH's known_hosts file.
+ host := "hostname"
+ file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer file.Close()
+
+ scanner := bufio.NewScanner(file)
+ var hostKey ssh.PublicKey
+ for scanner.Scan() {
+ fields := strings.Split(scanner.Text(), " ")
+ if len(fields) != 3 {
+ continue
+ }
+ if strings.Contains(fields[0], host) {
+ var err error
+ hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
+ if err != nil {
+ log.Fatalf("error parsing %q: %v", fields[2], err)
+ }
+ break
+ }
+ }
+
+ if hostKey == nil {
+ log.Fatalf("no hostkey for %s", host)
+ }
+
+ config := ssh.ClientConfig{
+ User: os.Getenv("USER"),
+ HostKeyCallback: ssh.FixedHostKey(hostKey),
+ }
+
+ _, err = ssh.Dial("tcp", host+":22", &config)
+ log.Println(err)
+}
+
func ExampleDial() {
+ var hostKey ssh.PublicKey
// An SSH client is represented with a ClientConn.
//
// To authenticate with the remote server you must pass at least one
- // implementation of AuthMethod via the Auth field in ClientConfig.
+ // implementation of AuthMethod via the Auth field in ClientConfig,
+ // and provide a HostKeyCallback.
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("yourpassword"),
},
+ HostKeyCallback: ssh.FixedHostKey(hostKey),
}
client, err := ssh.Dial("tcp", "yourserver.com:22", config)
if err != nil {
@@ -166,6 +211,7 @@ func ExampleDial() {
}
func ExamplePublicKeys() {
+ var hostKey ssh.PublicKey
// A public key may be used to authenticate against the remote
// server by using an unencrypted PEM-encoded private key file.
//
@@ -188,6 +234,7 @@ func ExamplePublicKeys() {
// Use the PublicKeys method for remote authentication.
ssh.PublicKeys(signer),
},
+ HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to the remote server and perform the SSH handshake.
@@ -199,11 +246,13 @@ func ExamplePublicKeys() {
}
func ExampleClient_Listen() {
+ var hostKey ssh.PublicKey
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
+ HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Dial your ssh server.
conn, err := ssh.Dial("tcp", "localhost:22", config)
@@ -226,12 +275,14 @@ func ExampleClient_Listen() {
}
func ExampleSession_RequestPty() {
+ var hostKey ssh.PublicKey
// Create client config
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
+ HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to ssh server
conn, err := ssh.Dial("tcp", "localhost:22", config)
diff --git a/ssh/handshake.go b/ssh/handshake.go
index 8de6506..932ce83 100644
--- a/ssh/handshake.go
+++ b/ssh/handshake.go
@@ -74,7 +74,7 @@ type handshakeTransport struct {
startKex chan *pendingKex
// data for host key checking
- hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
+ hostKeyCallback HostKeyCallback
dialAddress string
remoteAddr net.Addr
@@ -107,6 +107,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
config: config,
}
+ t.resetReadThresholds()
+ t.resetWriteThresholds()
// We always start with a mandatory key exchange.
t.requestKex <- struct{}{}
@@ -237,6 +239,17 @@ func (t *handshakeTransport) requestKeyExchange() {
}
}
+func (t *handshakeTransport) resetWriteThresholds() {
+ t.writePacketsLeft = packetRekeyThreshold
+ if t.config.RekeyThreshold > 0 {
+ t.writeBytesLeft = int64(t.config.RekeyThreshold)
+ } else if t.algorithms != nil {
+ t.writeBytesLeft = t.algorithms.w.rekeyBytes()
+ } else {
+ t.writeBytesLeft = 1 << 30
+ }
+}
+
func (t *handshakeTransport) kexLoop() {
write:
@@ -285,12 +298,8 @@ write:
t.writeError = err
t.sentInitPacket = nil
t.sentInitMsg = nil
- t.writePacketsLeft = packetRekeyThreshold
- if t.config.RekeyThreshold > 0 {
- t.writeBytesLeft = int64(t.config.RekeyThreshold)
- } else if t.algorithms != nil {
- t.writeBytesLeft = t.algorithms.w.rekeyBytes()
- }
+
+ t.resetWriteThresholds()
// we have completed the key exchange. Since the
// reader is still blocked, it is safe to clear out
@@ -344,6 +353,17 @@ write:
// key exchange itself.
const packetRekeyThreshold = (1 << 31)
+func (t *handshakeTransport) resetReadThresholds() {
+ t.readPacketsLeft = packetRekeyThreshold
+ if t.config.RekeyThreshold > 0 {
+ t.readBytesLeft = int64(t.config.RekeyThreshold)
+ } else if t.algorithms != nil {
+ t.readBytesLeft = t.algorithms.r.rekeyBytes()
+ } else {
+ t.readBytesLeft = 1 << 30
+ }
+}
+
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
p, err := t.conn.readPacket()
if err != nil {
@@ -391,12 +411,7 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
return nil, err
}
- t.readPacketsLeft = packetRekeyThreshold
- if t.config.RekeyThreshold > 0 {
- t.readBytesLeft = int64(t.config.RekeyThreshold)
- } else {
- t.readBytesLeft = t.algorithms.r.rekeyBytes()
- }
+ t.resetReadThresholds()
// By default, a key exchange is hidden from higher layers by
// translating it into msgIgnore.
@@ -574,7 +589,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
}
result.SessionID = t.sessionID
- t.conn.prepareKeyChange(t.algorithms, result)
+ if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
+ return err
+ }
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
@@ -614,11 +631,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
return nil, err
}
- if t.hostKeyCallback != nil {
- err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
- if err != nil {
- return nil, err
- }
+ err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
+ if err != nil {
+ return nil, err
}
return result, nil
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go
index 1b83112..91d4935 100644
--- a/ssh/handshake_test.go
+++ b/ssh/handshake_test.go
@@ -40,9 +40,12 @@ func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
- listener, err := net.Listen("tcp", ":0")
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
- return nil, nil, err
+ listener, err = net.Listen("tcp", "[::1]:0")
+ if err != nil {
+ return nil, nil, err
+ }
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
@@ -436,6 +439,7 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, couple
clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
+ clientConn.hostKeyCallback = InsecureIgnoreHostKey()
go clientConn.readLoop()
go clientConn.kexLoop()
@@ -525,3 +529,31 @@ func TestDisconnect(t *testing.T) {
t.Errorf("readPacket 3 succeeded")
}
}
+
+func TestHandshakeRekeyDefault(t *testing.T) {
+ clientConf := &ClientConfig{
+ Config: Config{
+ Ciphers: []string{"aes128-ctr"},
+ },
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
+ trC, trS, err := handshakePair(clientConf, "addr", false)
+ if err != nil {
+ t.Fatalf("handshakePair: %v", err)
+ }
+ defer trC.Close()
+ defer trS.Close()
+
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0})
+ trC.Close()
+
+ rgb := (1024 + trC.readBytesLeft) >> 30
+ wgb := (1024 + trC.writeBytesLeft) >> 30
+
+ if rgb != 64 {
+ t.Errorf("got rekey after %dG read, want 64G", rgb)
+ }
+ if wgb != 64 {
+ t.Errorf("got rekey after %dG write, want 64G", wgb)
+ }
+}
diff --git a/ssh/keys.go b/ssh/keys.go
index f38de98..cf68532 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -824,7 +824,7 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
// Implemented based on the documentation at
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
-func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
+func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
magic := append([]byte("openssh-key-v1"), 0)
if !bytes.Equal(magic, key[0:len(magic)]) {
return nil, errors.New("ssh: invalid openssh private key format")
@@ -844,14 +844,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, err
}
+ if w.KdfName != "none" || w.CipherName != "none" {
+ return nil, errors.New("ssh: cannot decode encrypted private keys")
+ }
+
pk1 := struct {
Check1 uint32
Check2 uint32
Keytype string
- Pub []byte
- Priv []byte
- Comment string
- Pad []byte `ssh:"rest"`
+ Rest []byte `ssh:"rest"`
}{}
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
@@ -862,24 +863,75 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) {
return nil, errors.New("ssh: checkint mismatch")
}
- // we only handle ed25519 keys currently
- if pk1.Keytype != KeyAlgoED25519 {
- return nil, errors.New("ssh: unhandled key type")
- }
+ // we only handle ed25519 and rsa keys currently
+ switch pk1.Keytype {
+ case KeyAlgoRSA:
+ // https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
+ key := struct {
+ N *big.Int
+ E *big.Int
+ D *big.Int
+ Iqmp *big.Int
+ P *big.Int
+ Q *big.Int
+ Comment string
+ Pad []byte `ssh:"rest"`
+ }{}
+
+ if err := Unmarshal(pk1.Rest, &key); err != nil {
+ return nil, err
+ }
- for i, b := range pk1.Pad {
- if int(b) != i+1 {
- return nil, errors.New("ssh: padding not as expected")
+ for i, b := range key.Pad {
+ if int(b) != i+1 {
+ return nil, errors.New("ssh: padding not as expected")
+ }
}
- }
- if len(pk1.Priv) != ed25519.PrivateKeySize {
- return nil, errors.New("ssh: private key unexpected length")
- }
+ pk := &rsa.PrivateKey{
+ PublicKey: rsa.PublicKey{
+ N: key.N,
+ E: int(key.E.Int64()),
+ },
+ D: key.D,
+ Primes: []*big.Int{key.P, key.Q},
+ }
- pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
- copy(pk, pk1.Priv)
- return &pk, nil
+ if err := pk.Validate(); err != nil {
+ return nil, err
+ }
+
+ pk.Precompute()
+
+ return pk, nil
+ case KeyAlgoED25519:
+ key := struct {
+ Pub []byte
+ Priv []byte
+ Comment string
+ Pad []byte `ssh:"rest"`
+ }{}
+
+ if err := Unmarshal(pk1.Rest, &key); err != nil {
+ return nil, err
+ }
+
+ if len(key.Priv) != ed25519.PrivateKeySize {
+ return nil, errors.New("ssh: private key unexpected length")
+ }
+
+ for i, b := range key.Pad {
+ if int(b) != i+1 {
+ return nil, errors.New("ssh: padding not as expected")
+ }
+ }
+
+ pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
+ copy(pk, key.Priv)
+ return &pk, nil
+ default:
+ return nil, errors.New("ssh: unhandled key type")
+ }
}
// FingerprintLegacyMD5 returns the user presentation of the key's
diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go
new file mode 100644
index 0000000..ea92b29
--- /dev/null
+++ b/ssh/knownhosts/knownhosts.go
@@ -0,0 +1,546 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package knownhosts implements a parser for the OpenSSH
+// known_hosts host key database.
+package knownhosts
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha1"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strings"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// See the sshd manpage
+// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
+// background.
+
+type addr struct{ host, port string }
+
+func (a *addr) String() string {
+ h := a.host
+ if strings.Contains(h, ":") {
+ h = "[" + h + "]"
+ }
+ return h + ":" + a.port
+}
+
+type matcher interface {
+ match([]addr) bool
+}
+
+type hostPattern struct {
+ negate bool
+ addr addr
+}
+
+func (p *hostPattern) String() string {
+ n := ""
+ if p.negate {
+ n = "!"
+ }
+
+ return n + p.addr.String()
+}
+
+type hostPatterns []hostPattern
+
+func (ps hostPatterns) match(addrs []addr) bool {
+ matched := false
+ for _, p := range ps {
+ for _, a := range addrs {
+ m := p.match(a)
+ if !m {
+ continue
+ }
+ if p.negate {
+ return false
+ }
+ matched = true
+ }
+ }
+ return matched
+}
+
+// See
+// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
+// The matching of * has no regard for separators, unlike filesystem globs
+func wildcardMatch(pat []byte, str []byte) bool {
+ for {
+ if len(pat) == 0 {
+ return len(str) == 0
+ }
+ if len(str) == 0 {
+ return false
+ }
+
+ if pat[0] == '*' {
+ if len(pat) == 1 {
+ return true
+ }
+
+ for j := range str {
+ if wildcardMatch(pat[1:], str[j:]) {
+ return true
+ }
+ }
+ return false
+ }
+
+ if pat[0] == '?' || pat[0] == str[0] {
+ pat = pat[1:]
+ str = str[1:]
+ } else {
+ return false
+ }
+ }
+}
+
+func (l *hostPattern) match(a addr) bool {
+ return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port
+}
+
+type keyDBLine struct {
+ cert bool
+ matcher matcher
+ knownKey KnownKey
+}
+
+func serialize(k ssh.PublicKey) string {
+ return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
+}
+
+func (l *keyDBLine) match(addrs []addr) bool {
+ return l.matcher.match(addrs)
+}
+
+type hostKeyDB struct {
+ // Serialized version of revoked keys
+ revoked map[string]*KnownKey
+ lines []keyDBLine
+}
+
+func newHostKeyDB() *hostKeyDB {
+ db := &hostKeyDB{
+ revoked: make(map[string]*KnownKey),
+ }
+
+ return db
+}
+
+func keyEq(a, b ssh.PublicKey) bool {
+ return bytes.Equal(a.Marshal(), b.Marshal())
+}
+
+// IsHostAuthority can be used as a callback in ssh.CertChecker
+func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
+ h, p, err := net.SplitHostPort(address)
+ if err != nil {
+ return false
+ }
+ a := addr{host: h, port: p}
+
+ for _, l := range db.lines {
+ if l.cert && keyEq(l.knownKey.Key, remote) && l.match([]addr{a}) {
+ return true
+ }
+ }
+ return false
+}
+
+// IsRevoked can be used as a callback in ssh.CertChecker
+func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
+ _, ok := db.revoked[string(key.Marshal())]
+ return ok
+}
+
+const markerCert = "@cert-authority"
+const markerRevoked = "@revoked"
+
+func nextWord(line []byte) (string, []byte) {
+ i := bytes.IndexAny(line, "\t ")
+ if i == -1 {
+ return string(line), nil
+ }
+
+ return string(line[:i]), bytes.TrimSpace(line[i:])
+}
+
+func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
+ if w, next := nextWord(line); w == markerCert || w == markerRevoked {
+ marker = w
+ line = next
+ }
+
+ host, line = nextWord(line)
+ if len(line) == 0 {
+ return "", "", nil, errors.New("knownhosts: missing host pattern")
+ }
+
+ // ignore the keytype as it's in the key blob anyway.
+ _, line = nextWord(line)
+ if len(line) == 0 {
+ return "", "", nil, errors.New("knownhosts: missing key type pattern")
+ }
+
+ keyBlob, _ := nextWord(line)
+
+ keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
+ if err != nil {
+ return "", "", nil, err
+ }
+ key, err = ssh.ParsePublicKey(keyBytes)
+ if err != nil {
+ return "", "", nil, err
+ }
+
+ return marker, host, key, nil
+}
+
+func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
+ marker, pattern, key, err := parseLine(line)
+ if err != nil {
+ return err
+ }
+
+ if marker == markerRevoked {
+ db.revoked[string(key.Marshal())] = &KnownKey{
+ Key: key,
+ Filename: filename,
+ Line: linenum,
+ }
+
+ return nil
+ }
+
+ entry := keyDBLine{
+ cert: marker == markerCert,
+ knownKey: KnownKey{
+ Filename: filename,
+ Line: linenum,
+ Key: key,
+ },
+ }
+
+ if pattern[0] == '|' {
+ entry.matcher, err = newHashedHost(pattern)
+ } else {
+ entry.matcher, err = newHostnameMatcher(pattern)
+ }
+
+ if err != nil {
+ return err
+ }
+
+ db.lines = append(db.lines, entry)
+ return nil
+}
+
+func newHostnameMatcher(pattern string) (matcher, error) {
+ var hps hostPatterns
+ for _, p := range strings.Split(pattern, ",") {
+ if len(p) == 0 {
+ continue
+ }
+
+ var a addr
+ var negate bool
+ if p[0] == '!' {
+ negate = true
+ p = p[1:]
+ }
+
+ if len(p) == 0 {
+ return nil, errors.New("knownhosts: negation without following hostname")
+ }
+
+ var err error
+ if p[0] == '[' {
+ a.host, a.port, err = net.SplitHostPort(p)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ a.host, a.port, err = net.SplitHostPort(p)
+ if err != nil {
+ a.host = p
+ a.port = "22"
+ }
+ }
+ hps = append(hps, hostPattern{
+ negate: negate,
+ addr: a,
+ })
+ }
+ return hps, nil
+}
+
+// KnownKey represents a key declared in a known_hosts file.
+type KnownKey struct {
+ Key ssh.PublicKey
+ Filename string
+ Line int
+}
+
+func (k *KnownKey) String() string {
+ return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key))
+}
+
+// KeyError is returned if we did not find the key in the host key
+// database, or there was a mismatch. Typically, in batch
+// applications, this should be interpreted as failure. Interactive
+// applications can offer an interactive prompt to the user.
+type KeyError struct {
+ // Want holds the accepted host keys. For each key algorithm,
+ // there can be one hostkey. If Want is empty, the host is
+ // unknown. If Want is non-empty, there was a mismatch, which
+ // can signify a MITM attack.
+ Want []KnownKey
+}
+
+func (u *KeyError) Error() string {
+ if len(u.Want) == 0 {
+ return "knownhosts: key is unknown"
+ }
+ return "knownhosts: key mismatch"
+}
+
+// RevokedError is returned if we found a key that was revoked.
+type RevokedError struct {
+ Revoked KnownKey
+}
+
+func (r *RevokedError) Error() string {
+ return "knownhosts: key is revoked"
+}
+
+// check checks a key against the host database. This should not be
+// used for verifying certificates.
+func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
+ if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
+ return &RevokedError{Revoked: *revoked}
+ }
+
+ host, port, err := net.SplitHostPort(remote.String())
+ if err != nil {
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
+ }
+
+ addrs := []addr{
+ {host, port},
+ }
+
+ if address != "" {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
+ }
+
+ addrs = append(addrs, addr{host, port})
+ }
+
+ return db.checkAddrs(addrs, remoteKey)
+}
+
+// checkAddrs checks if we can find the given public key for any of
+// the given addresses. If we only find an entry for the IP address,
+// or only the hostname, then this still succeeds.
+func (db *hostKeyDB) checkAddrs(addrs []addr, remoteKey ssh.PublicKey) error {
+ // TODO(hanwen): are these the right semantics? What if there
+ // is just a key for the IP address, but not for the
+ // hostname?
+
+ // Algorithm => key.
+ knownKeys := map[string]KnownKey{}
+ for _, l := range db.lines {
+ if l.match(addrs) {
+ typ := l.knownKey.Key.Type()
+ if _, ok := knownKeys[typ]; !ok {
+ knownKeys[typ] = l.knownKey
+ }
+ }
+ }
+
+ keyErr := &KeyError{}
+ for _, v := range knownKeys {
+ keyErr.Want = append(keyErr.Want, v)
+ }
+
+ // Unknown remote host.
+ if len(knownKeys) == 0 {
+ return keyErr
+ }
+
+ // If the remote host starts using a different, unknown key type, we
+ // also interpret that as a mismatch.
+ if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) {
+ return keyErr
+ }
+
+ return nil
+}
+
+// Read parses known_hosts data from r; filename labels the parsed entries and any errors.
+func (db *hostKeyDB) Read(r io.Reader, filename string) error {
+ scanner := bufio.NewScanner(r)
+
+ lineNum := 0
+ for scanner.Scan() {
+ lineNum++
+ line := scanner.Bytes()
+ line = bytes.TrimSpace(line)
+ if len(line) == 0 || line[0] == '#' {
+ continue
+ }
+
+ if err := db.parseLine(line, filename, lineNum); err != nil {
+ return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
+ }
+ }
+ return scanner.Err()
+}
+
+// New creates a host key callback from the given OpenSSH host key
+// files. The returned callback is for use in
+// ssh.ClientConfig.HostKeyCallback. Hashed hostname entries are supported.
+func New(files ...string) (ssh.HostKeyCallback, error) {
+ db := newHostKeyDB()
+ for _, fn := range files {
+ f, err := os.Open(fn)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ if err := db.Read(f, fn); err != nil {
+ return nil, err
+ }
+ }
+
+ var certChecker ssh.CertChecker
+ certChecker.IsHostAuthority = db.IsHostAuthority
+ certChecker.IsRevoked = db.IsRevoked
+ certChecker.HostKeyFallback = db.check
+
+ return certChecker.CheckHostKey, nil
+}
+
+// Normalize normalizes an address into the form used in known_hosts
+func Normalize(address string) string {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ host = address
+ port = "22"
+ }
+ entry := host
+ if port != "22" {
+ entry = "[" + entry + "]:" + port
+ } else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
+ entry = "[" + entry + "]"
+ }
+ return entry
+}
+
+// Line returns a line to append to a known_hosts file.
+func Line(addresses []string, key ssh.PublicKey) string {
+ var trimmed []string
+ for _, a := range addresses {
+ trimmed = append(trimmed, Normalize(a))
+ }
+
+ return strings.Join(trimmed, ",") + " " + serialize(key)
+}
+
+// HashHostname hashes the given hostname. The hostname is not
+// normalized before hashing.
+func HashHostname(hostname string) string {
+ // TODO(hanwen): check if we can safely normalize this always.
+ salt := make([]byte, sha1.Size)
+
+ _, err := rand.Read(salt)
+ if err != nil {
+ panic(fmt.Sprintf("crypto/rand failure %v", err))
+ }
+
+ hash := hashHost(hostname, salt)
+ return encodeHash(sha1HashType, salt, hash)
+}
+
+func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
+ if len(encoded) == 0 || encoded[0] != '|' {
+ err = errors.New("knownhosts: hashed host must start with '|'")
+ return
+ }
+ components := strings.Split(encoded, "|")
+ if len(components) != 4 {
+ err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
+ return
+ }
+
+ hashType = components[1]
+ if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
+ return
+ }
+ if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
+ return
+ }
+ return
+}
+
+func encodeHash(typ string, salt []byte, hash []byte) string {
+ return strings.Join([]string{"",
+ typ,
+ base64.StdEncoding.EncodeToString(salt),
+ base64.StdEncoding.EncodeToString(hash),
+ }, "|")
+}
+
+// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+func hashHost(hostname string, salt []byte) []byte {
+ mac := hmac.New(sha1.New, salt)
+ mac.Write([]byte(hostname))
+ return mac.Sum(nil)
+}
+
+type hashedHost struct {
+ salt []byte
+ hash []byte
+}
+
+const sha1HashType = "1"
+
+func newHashedHost(encoded string) (*hashedHost, error) {
+ typ, salt, hash, err := decodeHash(encoded)
+ if err != nil {
+ return nil, err
+ }
+
+ // The type field seems to be for future algorithm agility, but it's
+ // actually hardcoded in openssh currently, see
+ // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
+ if typ != sha1HashType {
+ return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
+ }
+
+ return &hashedHost{salt: salt, hash: hash}, nil
+}
+
+func (h *hashedHost) match(addrs []addr) bool {
+ for _, a := range addrs {
+ if bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go
new file mode 100644
index 0000000..be7cc0e
--- /dev/null
+++ b/ssh/knownhosts/knownhosts_test.go
@@ -0,0 +1,329 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package knownhosts
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "reflect"
+ "testing"
+
+ "golang.org/x/crypto/ssh"
+)
+
+const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
+const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
+const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
+
+var ecKey, alternateEdKey, edKey ssh.PublicKey
+var testAddr = &net.TCPAddr{
+ IP: net.IP{198, 41, 30, 196},
+ Port: 22,
+}
+
+var testAddr6 = &net.TCPAddr{
+ IP: net.IP{198, 41, 30, 196,
+ 1, 2, 3, 4,
+ 1, 2, 3, 4,
+ 1, 2, 3, 4,
+ },
+ Port: 22,
+}
+
+func init() {
+ var err error
+ ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
+ if err != nil {
+ panic(err)
+ }
+ edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
+ if err != nil {
+ panic(err)
+ }
+ alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
+ if err != nil {
+ panic(err)
+ }
+}
+
+func testDB(t *testing.T, s string) *hostKeyDB {
+ db := newHostKeyDB()
+ if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil {
+ t.Fatalf("Read: %v", err)
+ }
+
+ return db
+}
+
+func TestRevoked(t *testing.T) {
+ db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
+ want := &RevokedError{
+ Revoked: KnownKey{
+ Key: edKey,
+ Filename: "testdb",
+ Line: 3,
+ },
+ }
+ if err := db.check("", &net.TCPAddr{
+ Port: 42,
+ }, edKey); err == nil {
+ t.Fatal("no error for revoked key")
+ } else if !reflect.DeepEqual(want, err) {
+ t.Fatalf("got %#v, want %#v", want, err)
+ }
+}
+
+func TestHostAuthority(t *testing.T) {
+ for _, m := range []struct {
+ authorityFor string
+ address string
+
+ good bool
+ }{
+ {authorityFor: "localhost", address: "localhost:22", good: true},
+ {authorityFor: "localhost", address: "localhost", good: false},
+ {authorityFor: "localhost", address: "localhost:1234", good: false},
+ {authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
+ {authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
+ {authorityFor: "[localhost]:1234", address: "localhost", good: false},
+ } {
+ db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
+ if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good {
+ t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
+ m.authorityFor, m.address, m.good, ok)
+ }
+ }
+}
+
+func TestBracket(t *testing.T) {
+ db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
+
+ if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
+ IP: net.IP{198, 41, 30, 196},
+ Port: 29418,
+ }, edKey); err != nil {
+ t.Errorf("got error %v, want none", err)
+ }
+
+ if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
+ Port: 42,
+ }, edKey); err == nil {
+ t.Fatalf("no error for unknown address")
+ } else if ke, ok := err.(*KeyError); !ok {
+ t.Fatalf("got type %T, want *KeyError", err)
+ } else if len(ke.Want) > 0 {
+ t.Fatalf("got Want %v, want []", ke.Want)
+ }
+}
+
+func TestNewKeyType(t *testing.T) {
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
+ db := testDB(t, str)
+ if err := db.check("", testAddr, ecKey); err == nil {
+ t.Fatalf("no error for unknown address")
+ } else if ke, ok := err.(*KeyError); !ok {
+ t.Fatalf("got type %T, want *KeyError", err)
+ } else if len(ke.Want) == 0 {
+ t.Fatalf("got empty KeyError.Want")
+ }
+}
+
+func TestSameKeyType(t *testing.T) {
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
+ db := testDB(t, str)
+ if err := db.check("", testAddr, alternateEdKey); err == nil {
+ t.Fatalf("no error for unknown address")
+ } else if ke, ok := err.(*KeyError); !ok {
+ t.Fatalf("got type %T, want *KeyError", err)
+ } else if len(ke.Want) == 0 {
+ t.Fatalf("got empty KeyError.Want")
+ } else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
+ t.Fatalf("got key %q, want %q", got, want)
+ }
+}
+
+func TestIPAddress(t *testing.T) {
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
+ db := testDB(t, str)
+ if err := db.check("", testAddr, edKey); err != nil {
+ t.Errorf("got error %q, want none", err)
+ }
+}
+
+func TestIPv6Address(t *testing.T) {
+ str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
+ db := testDB(t, str)
+
+ if err := db.check("", testAddr6, edKey); err != nil {
+ t.Errorf("got error %q, want none", err)
+ }
+}
+
+func TestBasic(t *testing.T) {
+ str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
+ db := testDB(t, str)
+ if err := db.check("server.org:22", testAddr, edKey); err != nil {
+ t.Errorf("got error %q, want none", err)
+ }
+
+ want := KnownKey{
+ Key: edKey,
+ Filename: "testdb",
+ Line: 3,
+ }
+ if err := db.check("server.org:22", testAddr, ecKey); err == nil {
+ t.Errorf("succeeded, want KeyError")
+ } else if ke, ok := err.(*KeyError); !ok {
+ t.Errorf("got %T, want *KeyError", err)
+ } else if len(ke.Want) != 1 {
+ t.Errorf("got %v, want 1 entry", ke)
+ } else if !reflect.DeepEqual(ke.Want[0], want) {
+ t.Errorf("got %v, want %v", ke.Want[0], want)
+ }
+}
+
+func TestNegate(t *testing.T) {
+ str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
+ db := testDB(t, str)
+ if err := db.check("server.org:22", testAddr, ecKey); err == nil {
+ t.Errorf("succeeded")
+ } else if ke, ok := err.(*KeyError); !ok {
+ t.Errorf("got error type %T, want *KeyError", err)
+ } else if len(ke.Want) != 0 {
+ t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
+ }
+}
+
+func TestWildcard(t *testing.T) {
+ str := fmt.Sprintf("server*.domain %s", edKeyStr)
+ db := testDB(t, str)
+
+ want := &KeyError{
+ Want: []KnownKey{{
+ Filename: "testdb",
+ Line: 1,
+ Key: edKey,
+ }},
+ }
+
+ got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("got %s, want %s", got, want)
+ }
+}
+
+func TestLine(t *testing.T) {
+ for in, want := range map[string]string{
+ "server.org": "server.org " + edKeyStr,
+ "server.org:22": "server.org " + edKeyStr,
+ "server.org:23": "[server.org]:23 " + edKeyStr,
+ "[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr,
+ "[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr,
+ } {
+ if got := Line([]string{in}, edKey); got != want {
+ t.Errorf("Line(%q) = %q, want %q", in, got, want)
+ }
+ }
+}
+
+func TestWildcardMatch(t *testing.T) {
+ for _, c := range []struct {
+ pat, str string
+ want bool
+ }{
+ {"a?b", "abb", true},
+ {"ab", "abc", false},
+ {"abc", "ab", false},
+ {"a*b", "axxxb", true},
+ {"a*b", "axbxb", true},
+ {"a*b", "axbxbc", false},
+ {"a*?", "axbxc", true},
+ {"a*b*", "axxbxxxxxx", true},
+ {"a*b*c", "axxbxxxxxxc", true},
+ {"a*b*?", "axxbxxxxxxc", true},
+ {"a*b*z", "axxbxxbxxxz", true},
+ {"a*b*z", "axxbxxzxxxz", true},
+ {"a*b*z", "axxbxxzxxx", false},
+ } {
+ got := wildcardMatch([]byte(c.pat), []byte(c.str))
+ if got != c.want {
+ t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
+ }
+
+ }
+}
+
+// TODO(hanwen): test coverage for certificates.
+
+const testHostname = "hostname"
+
+// generated with ssh-keygen -H -f
+const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
+
+func TestHostHash(t *testing.T) {
+ testHostHash(t, testHostname, encodedTestHostnameHash)
+}
+
+func TestHashList(t *testing.T) {
+ encoded := HashHostname(testHostname)
+ testHostHash(t, testHostname, encoded)
+}
+
+func testHostHash(t *testing.T, hostname, encoded string) {
+ typ, salt, hash, err := decodeHash(encoded)
+ if err != nil {
+ t.Fatalf("decodeHash: %v", err)
+ }
+
+ if got := encodeHash(typ, salt, hash); got != encoded {
+ t.Errorf("got encoding %s want %s", got, encoded)
+ }
+
+ if typ != sha1HashType {
+ t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
+ }
+
+ got := hashHost(hostname, salt)
+ if !bytes.Equal(got, hash) {
+ t.Errorf("got hash %x want %x", got, hash)
+ }
+}
+
+func TestNormalize(t *testing.T) {
+ for in, want := range map[string]string{
+ "127.0.0.1:22": "127.0.0.1",
+ "[127.0.0.1]:22": "127.0.0.1",
+ "[127.0.0.1]:23": "[127.0.0.1]:23",
+ "127.0.0.1:23": "[127.0.0.1]:23",
+ "[a.b.c]:22": "a.b.c",
+ "[abcd:abcd:abcd:abcd]": "[abcd:abcd:abcd:abcd]",
+ "[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
+ "[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
+ } {
+ got := Normalize(in)
+ if got != want {
+ t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
+ }
+ }
+}
+
+func TestHashedHostkeyCheck(t *testing.T) {
+ str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
+ db := testDB(t, str)
+ if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
+ t.Errorf("check(%s): %v", testHostname, err)
+ }
+ want := &KeyError{
+ Want: []KnownKey{{
+ Filename: "testdb",
+ Line: 1,
+ Key: edKey,
+ }},
+ }
+ if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
+ t.Errorf("got error %v, want %v", got, want)
+ }
+}
diff --git a/ssh/server.go b/ssh/server.go
index 77c84d1..23b41d9 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -45,6 +45,12 @@ type ServerConfig struct {
// authenticating.
NoClientAuth bool
+ // MaxAuthTries specifies the maximum number of authentication attempts
+ // permitted per connection. If set to a negative number, the number of
+ // attempts is unlimited. If set to zero, the number of attempts is limited
+ // to 6.
+ MaxAuthTries int
+
// PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
@@ -143,6 +149,10 @@ type ServerConn struct {
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
+ if fullConf.MaxAuthTries == 0 {
+ fullConf.MaxAuthTries = 6
+ }
+
s := &connection{
sshConn: sshConn{conn: c},
}
@@ -267,8 +277,23 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
var cache pubKeyCache
var perms *Permissions
+ authFailures := 0
+
userAuthLoop:
for {
+ if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
+ discMsg := &disconnectMsg{
+ Reason: 2,
+ Message: "too many authentication failures",
+ }
+
+ if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
+ return nil, err
+ }
+
+ return nil, discMsg
+ }
+
var userAuthReq userAuthRequestMsg
if packet, err := s.transport.readPacket(); err != nil {
return nil, err
@@ -289,6 +314,11 @@ userAuthLoop:
if config.NoClientAuth {
authErr = nil
}
+
+ // allow initial attempt of 'none' without penalty
+ if authFailures == 0 {
+ authFailures--
+ }
case "password":
if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured")
@@ -360,6 +390,7 @@ userAuthLoop:
if isQuery {
// The client can query if the given public key
// would be okay.
+
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
@@ -409,6 +440,8 @@ userAuthLoop:
break userAuthLoop
}
+ authFailures++
+
var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
diff --git a/ssh/session_test.go b/ssh/session_test.go
index f35a378..7dce6dd 100644
--- a/ssh/session_test.go
+++ b/ssh/session_test.go
@@ -59,7 +59,8 @@ func dial(handler serverType, t *testing.T) *Client {
}()
config := &ClientConfig{
- User: "testuser",
+ User: "testuser",
+ HostKeyCallback: InsecureIgnoreHostKey(),
}
conn, chans, reqs, err := NewClientConn(c2, "", config)
@@ -641,7 +642,8 @@ func TestSessionID(t *testing.T) {
}
serverConf.AddHostKey(testSigners["ecdsa"])
clientConf := &ClientConfig{
- User: "user",
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ User: "user",
}
go func() {
@@ -747,7 +749,9 @@ func TestHostKeyAlgorithms(t *testing.T) {
// By default, we get the preferred algorithm, which is ECDSA 256.
- clientConf := &ClientConfig{}
+ clientConf := &ClientConfig{
+ HostKeyCallback: InsecureIgnoreHostKey(),
+ }
connect(clientConf, KeyAlgoECDSA256)
// Client asks for RSA explicitly.
diff --git a/ssh/streamlocal.go b/ssh/streamlocal.go
new file mode 100644
index 0000000..a2dccc6
--- /dev/null
+++ b/ssh/streamlocal.go
@@ -0,0 +1,115 @@
+package ssh
+
+import (
+ "errors"
+ "io"
+ "net"
+)
+
+// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
+// with "direct-streamlocal@openssh.com" string.
+//
+// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
+type streamLocalChannelOpenDirectMsg struct {
+ socketPath string
+ reserved0 string
+ reserved1 uint32
+}
+
+// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
+// with "forwarded-streamlocal@openssh.com" string.
+type forwardedStreamLocalPayload struct {
+ SocketPath string
+ Reserved0 string
+}
+
+// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
+// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
+type streamLocalChannelForwardMsg struct {
+ socketPath string
+}
+
+// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
+func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
+ m := streamLocalChannelForwardMsg{
+ socketPath,
+ }
+ // send message
+ ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
+ if err != nil {
+ return nil, err
+ }
+ if !ok {
+ return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
+ }
+ ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
+
+ return &unixListener{socketPath, c, ch}, nil
+}
+
+func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
+ msg := streamLocalChannelOpenDirectMsg{
+ socketPath: socketPath,
+ }
+ ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
+ if err != nil {
+ return nil, err
+ }
+ go DiscardRequests(in)
+ return ch, err
+}
+
+type unixListener struct {
+ socketPath string
+
+ conn *Client
+ in <-chan forward
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *unixListener) Accept() (net.Conn, error) {
+ s, ok := <-l.in
+ if !ok {
+ return nil, io.EOF
+ }
+ ch, incoming, err := s.newCh.Accept()
+ if err != nil {
+ return nil, err
+ }
+ go DiscardRequests(incoming)
+
+ return &chanConn{
+ Channel: ch,
+ laddr: &net.UnixAddr{
+ Name: l.socketPath,
+ Net: "unix",
+ },
+ raddr: &net.UnixAddr{
+ Name: "@",
+ Net: "unix",
+ },
+ }, nil
+}
+
+// Close closes the listener.
+func (l *unixListener) Close() error {
+ // this also closes the listener.
+ l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
+ m := streamLocalChannelForwardMsg{
+ l.socketPath,
+ }
+ ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
+ if err == nil && !ok {
+ err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
+ }
+ return err
+}
+
+// Addr returns the listener's network address.
+func (l *unixListener) Addr() net.Addr {
+ return &net.UnixAddr{
+ Name: l.socketPath,
+ Net: "unix",
+ }
+}
diff --git a/ssh/tcpip.go b/ssh/tcpip.go
index 6151241..acf1717 100644
--- a/ssh/tcpip.go
+++ b/ssh/tcpip.go
@@ -20,12 +20,20 @@ import (
// addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
+// The network n must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) {
- laddr, err := net.ResolveTCPAddr(n, addr)
- if err != nil {
- return nil, err
+ switch n {
+ case "tcp", "tcp4", "tcp6":
+ laddr, err := net.ResolveTCPAddr(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ return c.ListenTCP(laddr)
+ case "unix":
+ return c.ListenUnix(addr)
+ default:
+ return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
}
- return c.ListenTCP(laddr)
}
// Automatic port allocation is broken with OpenSSH before 6.0. See
@@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
}
// Register this forward, using the port number we obtained.
- ch := c.forwards.add(*laddr)
+ ch := c.forwards.add(laddr)
return &tcpListener{laddr, c, ch}, nil
}
@@ -131,7 +139,7 @@ type forwardList struct {
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
- laddr net.TCPAddr
+ laddr net.Addr
c chan forward
}
@@ -139,16 +147,16 @@ type forwardEntry struct {
// arguments to add/remove/lookup should be address as specified in
// the original forward-request.
type forward struct {
- newCh NewChannel // the ssh client channel underlying this forward
- raddr *net.TCPAddr // the raddr of the incoming connection
+ newCh NewChannel // the ssh client channel underlying this forward
+ raddr net.Addr // the raddr of the incoming connection
}
-func (l *forwardList) add(addr net.TCPAddr) chan forward {
+func (l *forwardList) add(addr net.Addr) chan forward {
l.Lock()
defer l.Unlock()
f := forwardEntry{
- addr,
- make(chan forward, 1),
+ laddr: addr,
+ c: make(chan forward, 1),
}
l.entries = append(l.entries, f)
return f.c
@@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in {
- var payload forwardedTCPPayload
- if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
- ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
- continue
+ var (
+ laddr net.Addr
+ raddr net.Addr
+ err error
+ )
+ switch channelType := ch.ChannelType(); channelType {
+ case "forwarded-tcpip":
+ var payload forwardedTCPPayload
+ if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
+ ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
+ continue
+ }
+
+ // RFC 4254 section 7.2 specifies that incoming
+ // addresses should list the address, in string
+ // format. It is implied that this should be an IP
+ // address, as it would be impossible to connect to it
+ // otherwise.
+ laddr, err = parseTCPAddr(payload.Addr, payload.Port)
+ if err != nil {
+ ch.Reject(ConnectionFailed, err.Error())
+ continue
+ }
+ raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
+ if err != nil {
+ ch.Reject(ConnectionFailed, err.Error())
+ continue
+ }
+
+ case "forwarded-streamlocal@openssh.com":
+ var payload forwardedStreamLocalPayload
+ if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
+ ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
+ continue
+ }
+ laddr = &net.UnixAddr{
+ Name: payload.SocketPath,
+ Net: "unix",
+ }
+ raddr = &net.UnixAddr{
+ Name: "@",
+ Net: "unix",
+ }
+ default:
+ panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
}
-
- // RFC 4254 section 7.2 specifies that incoming
- // addresses should list the address, in string
- // format. It is implied that this should be an IP
- // address, as it would be impossible to connect to it
- // otherwise.
- laddr, err := parseTCPAddr(payload.Addr, payload.Port)
- if err != nil {
- ch.Reject(ConnectionFailed, err.Error())
- continue
- }
- raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
- if err != nil {
- ch.Reject(ConnectionFailed, err.Error())
- continue
- }
-
- if ok := l.forward(*laddr, *raddr, ch); !ok {
+ if ok := l.forward(laddr, raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming
// connections.
ch.Reject(Prohibited, "no forward for address")
continue
}
+
}
}
// remove removes the forward entry, and the channel feeding its
// listener.
-func (l *forwardList) remove(addr net.TCPAddr) {
+func (l *forwardList) remove(addr net.Addr) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
- if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
+ if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c)
return
@@ -231,12 +264,12 @@ func (l *forwardList) closeAll() {
l.entries = nil
}
-func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
+func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
- if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
- f.c <- forward{ch, &raddr}
+ if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
+ f.c <- forward{newCh: ch, raddr: raddr}
return true
}
}
@@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) {
}
go DiscardRequests(incoming)
- return &tcpChanConn{
+ return &chanConn{
Channel: ch,
laddr: l.laddr,
raddr: s.raddr,
@@ -277,7 +310,7 @@ func (l *tcpListener) Close() error {
}
// this also closes the listener.
- l.conn.forwards.remove(*l.laddr)
+ l.conn.forwards.remove(l.laddr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed")
@@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr {
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {
- // Parse the address into host and numeric port.
- host, portString, err := net.SplitHostPort(addr)
- if err != nil {
- return nil, err
- }
- port, err := strconv.ParseUint(portString, 10, 16)
- if err != nil {
- return nil, err
- }
- // Use a zero address for local and remote address.
- zeroAddr := &net.TCPAddr{
- IP: net.IPv4zero,
- Port: 0,
- }
- ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
- if err != nil {
- return nil, err
+ var ch Channel
+ switch n {
+ case "tcp", "tcp4", "tcp6":
+ // Parse the address into host and numeric port.
+ host, portString, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+ port, err := strconv.ParseUint(portString, 10, 16)
+ if err != nil {
+ return nil, err
+ }
+ ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
+ if err != nil {
+ return nil, err
+ }
+ // Use a zero address for local and remote address.
+ zeroAddr := &net.TCPAddr{
+ IP: net.IPv4zero,
+ Port: 0,
+ }
+ return &chanConn{
+ Channel: ch,
+ laddr: zeroAddr,
+ raddr: zeroAddr,
+ }, nil
+ case "unix":
+ var err error
+ ch, err = c.dialStreamLocal(addr)
+ if err != nil {
+ return nil, err
+ }
+ return &chanConn{
+ Channel: ch,
+ laddr: &net.UnixAddr{
+ Name: "@",
+ Net: "unix",
+ },
+ raddr: &net.UnixAddr{
+ Name: addr,
+ Net: "unix",
+ },
+ }, nil
+ default:
+ return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
}
- return &tcpChanConn{
- Channel: ch,
- laddr: zeroAddr,
- raddr: zeroAddr,
- }, nil
}
// DialTCP connects to the remote address raddr on the network net,
@@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
if err != nil {
return nil, err
}
- return &tcpChanConn{
+ return &chanConn{
Channel: ch,
laddr: laddr,
raddr: raddr,
@@ -366,26 +422,26 @@ type tcpChan struct {
Channel // the backing channel
}
-// tcpChanConn fulfills the net.Conn interface without
+// chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
-type tcpChanConn struct {
+type chanConn struct {
Channel
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
-func (t *tcpChanConn) LocalAddr() net.Addr {
+func (t *chanConn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
-func (t *tcpChanConn) RemoteAddr() net.Addr {
+func (t *chanConn) RemoteAddr() net.Addr {
return t.raddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection.
-func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
+func (t *chanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil {
return err
}
@@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
// A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error
// with Timeout() == true.
-func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
+func (t *chanConn) SetReadDeadline(deadline time.Time) error {
+ // for compatibility with previous version,
+ // the error message contains "tcpChan"
return errors.New("ssh: tcpChan: deadline not supported")
}
// SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error.
-func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
+func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported")
}
diff --git a/ssh/terminal/util_solaris.go b/ssh/terminal/util_solaris.go
index 07eb5ed..a2e1b57 100644
--- a/ssh/terminal/util_solaris.go
+++ b/ssh/terminal/util_solaris.go
@@ -14,14 +14,12 @@ import (
// State contains the state of a terminal.
type State struct {
- termios syscall.Termios
+ state *unix.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
- // see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c
- var termio unix.Termio
- err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio)
+ _, err := unix.IoctlGetTermio(fd, unix.TCGETA)
return err == nil
}
@@ -71,3 +69,60 @@ func ReadPassword(fd int) ([]byte, error) {
return ret, nil
}
+
+// MakeRaw puts the terminal connected to the given file descriptor into raw
+// mode and returns the previous state of the terminal so that it can be
+// restored.
+// see http://cr.illumos.org/~webrev/andy_js/1060/
+func MakeRaw(fd int) (*State, error) {
+ oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
+ if err != nil {
+ return nil, err
+ }
+ oldTermios := *oldTermiosPtr
+
+ newTermios := oldTermios
+ newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON
+ newTermios.Oflag &^= syscall.OPOST
+ newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN
+ newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB
+ newTermios.Cflag |= syscall.CS8
+ newTermios.Cc[unix.VMIN] = 1
+ newTermios.Cc[unix.VTIME] = 0
+
+ if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil {
+ return nil, err
+ }
+
+ return &State{
+ state: oldTermiosPtr,
+ }, nil
+}
+
+// Restore restores the terminal connected to the given file descriptor to a
+// previous state.
+func Restore(fd int, oldState *State) error {
+ return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state)
+}
+
+// GetState returns the current state of a terminal which may be useful to
+// restore the terminal after a signal.
+func GetState(fd int) (*State, error) {
+ oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS)
+ if err != nil {
+ return nil, err
+ }
+
+ return &State{
+ state: oldTermiosPtr,
+ }, nil
+}
+
+// GetSize returns the dimensions of the given terminal.
+func GetSize(fd int) (width, height int, err error) {
+ ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
+ if err != nil {
+ return 0, 0, err
+ }
+ return int(ws.Col), int(ws.Row), nil
+}
diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go
index 364790f..b231dd8 100644
--- a/ssh/test/cert_test.go
+++ b/ssh/test/cert_test.go
@@ -7,12 +7,14 @@
package test
import (
+ "bytes"
"crypto/rand"
"testing"
"golang.org/x/crypto/ssh"
)
+// TestCertLogin tests both logging in with a cert and that the certificate presented by an OpenSSH host can be validated correctly.
func TestCertLogin(t *testing.T) {
s := newServer(t)
defer s.Shutdown()
@@ -37,11 +39,39 @@ func TestCertLogin(t *testing.T) {
conf := &ssh.ClientConfig{
User: username(),
+ HostKeyCallback: (&ssh.CertChecker{
+ IsHostAuthority: func(pk ssh.PublicKey, addr string) bool {
+ return bytes.Equal(pk.Marshal(), testPublicKeys["ca"].Marshal())
+ },
+ }).CheckHostKey,
}
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
- client, err := s.TryDial(conf)
- if err != nil {
- t.Fatalf("TryDial: %v", err)
+
+ for _, test := range []struct {
+ addr string
+ succeed bool
+ }{
+ {addr: "host.example.com:22", succeed: true},
+ {addr: "host.example.com:10000", succeed: true}, // non-standard port must be OK
+ {addr: "host.example.com", succeed: false}, // port must be specified
+ {addr: "host.ex4mple.com:22", succeed: false}, // wrong host
+ } {
+ client, err := s.TryDialWithAddr(conf, test.addr)
+
+ // Always close client if opened successfully
+ if err == nil {
+ client.Close()
+ }
+
+ // Now evaluate whether the test failed or passed
+ if test.succeed {
+ if err != nil {
+ t.Fatalf("TryDialWithAddr: %v", err)
+ }
+ } else {
+ if err == nil {
+ t.Fatalf("TryDialWithAddr, unexpected success")
+ }
+ }
}
- client.Close()
}
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go
new file mode 100644
index 0000000..091e48c
--- /dev/null
+++ b/ssh/test/dial_unix_test.go
@@ -0,0 +1,128 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !windows
+
+package test
+
+// direct-tcpip and direct-streamlocal functional tests
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "strings"
+ "testing"
+)
+
+type dialTester interface {
+ TestServerConn(t *testing.T, c net.Conn)
+ TestClientConn(t *testing.T, c net.Conn)
+}
+
+func testDial(t *testing.T, n, listenAddr string, x dialTester) {
+ server := newServer(t)
+ defer server.Shutdown()
+ sshConn := server.Dial(clientConfig())
+ defer sshConn.Close()
+
+ l, err := net.Listen(n, listenAddr)
+ if err != nil {
+ t.Fatalf("Listen: %v", err)
+ }
+ defer l.Close()
+
+ testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
+ go func() {
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ break
+ }
+ x.TestServerConn(t, c)
+
+ io.WriteString(c, testData)
+ c.Close()
+ }
+ }()
+
+ conn, err := sshConn.Dial(n, l.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ x.TestClientConn(t, conn)
+ defer conn.Close()
+ b, err := ioutil.ReadAll(conn)
+ if err != nil {
+ t.Fatalf("ReadAll: %v", err)
+ }
+ t.Logf("got %q", string(b))
+ if string(b) != testData {
+ t.Fatalf("expected %q, got %q", testData, string(b))
+ }
+}
+
+type tcpDialTester struct {
+ listenAddr string
+}
+
+func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
+ host := strings.Split(x.listenAddr, ":")[0]
+ prefix := host + ":"
+ if !strings.HasPrefix(c.LocalAddr().String(), prefix) {
+ t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String())
+ }
+ if !strings.HasPrefix(c.RemoteAddr().String(), prefix) {
+ t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String())
+ }
+}
+
+func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
+ // we use zero addresses. see *Client.Dial.
+ if c.LocalAddr().String() != "0.0.0.0:0" {
+ t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String())
+ }
+ if c.RemoteAddr().String() != "0.0.0.0:0" {
+ t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String())
+ }
+}
+
+func TestDialTCP(t *testing.T) {
+ x := &tcpDialTester{
+ listenAddr: "127.0.0.1:0",
+ }
+ testDial(t, "tcp", x.listenAddr, x)
+}
+
+type unixDialTester struct {
+ listenAddr string
+}
+
+func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
+ if c.LocalAddr().String() != x.listenAddr {
+ t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String())
+ }
+ if c.RemoteAddr().String() != "@" {
+ t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String())
+ }
+}
+
+func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
+ if c.RemoteAddr().String() != x.listenAddr {
+ t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String())
+ }
+ if c.LocalAddr().String() != "@" {
+ t.Fatalf("expected \"@\", got %q", c.LocalAddr().String())
+ }
+}
+
+func TestDialUnix(t *testing.T) {
+ addr, cleanup := newTempSocket(t)
+ defer cleanup()
+ x := &unixDialTester{
+ listenAddr: addr,
+ }
+ testDial(t, "unix", x.listenAddr, x)
+}
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go
index 877a88c..ea81937 100644
--- a/ssh/test/forward_unix_test.go
+++ b/ssh/test/forward_unix_test.go
@@ -16,13 +16,17 @@ import (
"time"
)
-func TestPortForward(t *testing.T) {
+type closeWriter interface {
+ CloseWrite() error
+}
+
+func testPortForward(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
- sshListener, err := conn.Listen("tcp", "localhost:0")
+ sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@@ -41,14 +45,14 @@ func TestPortForward(t *testing.T) {
}()
forwardedAddr := sshListener.Addr().String()
- tcpConn, err := net.Dial("tcp", forwardedAddr)
+ netConn, err := net.Dial(n, forwardedAddr)
if err != nil {
- t.Fatalf("TCP dial failed: %v", err)
+ t.Fatalf("net dial failed: %v", err)
}
readChan := make(chan []byte)
go func() {
- data, _ := ioutil.ReadAll(tcpConn)
+ data, _ := ioutil.ReadAll(netConn)
readChan <- data
}()
@@ -62,14 +66,14 @@ func TestPortForward(t *testing.T) {
for len(sent) < 1000*1000 {
// Send random sized chunks
m := rand.Intn(len(data))
- n, err := tcpConn.Write(data[:m])
+ n, err := netConn.Write(data[:m])
if err != nil {
break
}
sent = append(sent, data[:n]...)
}
- if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
- t.Errorf("tcpConn.CloseWrite: %v", err)
+ if err := netConn.(closeWriter).CloseWrite(); err != nil {
+ t.Errorf("netConn.CloseWrite: %v", err)
}
read := <-readChan
@@ -86,19 +90,29 @@ func TestPortForward(t *testing.T) {
}
// Check that the forward disappeared.
- tcpConn, err = net.Dial("tcp", forwardedAddr)
+ netConn, err = net.Dial(n, forwardedAddr)
if err == nil {
- tcpConn.Close()
+ netConn.Close()
t.Errorf("still listening to %s after closing", forwardedAddr)
}
}
-func TestAcceptClose(t *testing.T) {
+func TestPortForwardTCP(t *testing.T) {
+ testPortForward(t, "tcp", "localhost:0")
+}
+
+func TestPortForwardUnix(t *testing.T) {
+ addr, cleanup := newTempSocket(t)
+ defer cleanup()
+ testPortForward(t, "unix", addr)
+}
+
+func testAcceptClose(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
- sshListener, err := conn.Listen("tcp", "localhost:0")
+ sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@@ -124,13 +138,23 @@ func TestAcceptClose(t *testing.T) {
}
}
+func TestAcceptCloseTCP(t *testing.T) {
+ testAcceptClose(t, "tcp", "localhost:0")
+}
+
+func TestAcceptCloseUnix(t *testing.T) {
+ addr, cleanup := newTempSocket(t)
+ defer cleanup()
+ testAcceptClose(t, "unix", addr)
+}
+
// Check that listeners exit if the underlying client transport dies.
-func TestPortForwardConnectionClose(t *testing.T) {
+func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
- sshListener, err := conn.Listen("tcp", "localhost:0")
+ sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@@ -158,3 +182,13 @@ func TestPortForwardConnectionClose(t *testing.T) {
t.Logf("quit as expected (error %v)", err)
}
}
+
+func TestPortForwardConnectionCloseTCP(t *testing.T) {
+ testPortForwardConnectionClose(t, "tcp", "localhost:0")
+}
+
+func TestPortForwardConnectionCloseUnix(t *testing.T) {
+ addr, cleanup := newTempSocket(t)
+ defer cleanup()
+ testPortForwardConnectionClose(t, "unix", addr)
+}
diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go
deleted file mode 100644
index a2eb935..0000000
--- a/ssh/test/tcpip_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build !windows
-
-package test
-
-// direct-tcpip functional tests
-
-import (
- "io"
- "net"
- "testing"
-)
-
-func TestDial(t *testing.T) {
- server := newServer(t)
- defer server.Shutdown()
- sshConn := server.Dial(clientConfig())
- defer sshConn.Close()
-
- l, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("Listen: %v", err)
- }
- defer l.Close()
-
- go func() {
- for {
- c, err := l.Accept()
- if err != nil {
- break
- }
-
- io.WriteString(c, c.RemoteAddr().String())
- c.Close()
- }
- }()
-
- conn, err := sshConn.Dial("tcp", l.Addr().String())
- if err != nil {
- t.Fatalf("Dial: %v", err)
- }
- defer conn.Close()
-}
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 3bfd881..e673536 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -30,6 +30,7 @@ Protocol 2
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
+HostCertificate {{.Dir}}/id_rsa-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
@@ -119,6 +120,11 @@ func clientConfig() *ssh.ClientConfig {
ssh.PublicKeys(testSigners["user"]),
},
HostKeyCallback: hostKeyDB().Check,
+ HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
+ ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
+ ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
+ ssh.KeyAlgoED25519,
+ },
}
return config
}
@@ -154,6 +160,12 @@ func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
}
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
+ return s.TryDialWithAddr(config, "")
+}
+
+// TryDialWithAddr is like TryDial, but addr is the user-specified host:port.
+// The address is never actually dialed; it is only needed for host key matching.
+func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
@@ -179,7 +191,7 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
s.t.Fatalf("s.cmd.Start: %v", err)
}
s.clientConn = c1
- conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
+ conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
if err != nil {
return nil, err
}
@@ -250,6 +262,11 @@ func newServer(t *testing.T) *server {
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
}
+ for k, v := range testdata.SSHCertificates {
+ filename := "id_" + k + "-cert.pub"
+ writeFile(filepath.Join(dir, filename), v)
+ }
+
var authkeys bytes.Buffer
for k, _ := range testdata.PEMBytes {
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
@@ -266,3 +283,13 @@ func newServer(t *testing.T) *server {
},
}
}
+
+func newTempSocket(t *testing.T) (string, func()) {
+ dir, err := ioutil.TempDir("", "socket")
+ if err != nil {
+ t.Fatal(err)
+ }
+ deferFunc := func() { os.RemoveAll(dir) }
+ addr := filepath.Join(dir, "sock")
+ return addr, deferFunc
+}
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go
index 736dad9..3b3d26c 100644
--- a/ssh/testdata/keys.go
+++ b/ssh/testdata/keys.go
@@ -48,12 +48,69 @@ AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb
HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg==
-----END OPENSSH PRIVATE KEY-----
`),
+ "rsa-openssh-format": []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
+b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn
+NhAAAAAwEAAQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5l
+oEuW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lz
+a+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAIQWL0H31i9B98AAAAH
+c3NoLXJzYQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5loE
+uW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lza+
+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAADAQABAAAAgCThyTGsT4
+IARDxVMhWl6eiB2ZrgFgWSeJm/NOqtppWgOebsIqPMMg4UVuVFsl422/lE3RkPhVkjGXgE
+pWvZAdCnmLmApK8wK12vF334lZhZT7t3Z9EzJps88PWEHo7kguf285HcnUM7FlFeissJdk
+kXly34y7/3X/a6Tclm+iABAAAAQE0xR/KxZ39slwfMv64Rz7WKk1PPskaryI29aHE3mKHk
+pY2QA+P3QlrKxT/VWUMjHUbNNdYfJm48xu0SGNMRdKMAAABBAORh2NP/06JUV3J9W/2Hju
+X1ViJuqqcQnJPVzpgSL826EC2xwOECTqoY8uvFpUdD7CtpksIxNVqRIhuNOlz0lqEAAABB
+ANkaHTTaPojClO0dKJ/Zjs7pWOCGliebBYprQ/Y4r9QLBkC/XaWMS26gFIrjgC7D2Rv+rZ
+wSD0v0RcmkITP1ZR0AAAAYcHF1ZXJuYUBMdWNreUh5ZHJvLmxvY2FsAQID
+-----END OPENSSH PRIVATE KEY-----`),
"user": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w==
-----END EC PRIVATE KEY-----
`),
+ "ca": []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAvg9dQ9IRG59lYJb+GESfKWTch4yBpr7Ydw1jkK6vvtrx9jLo
+5hkA8X6+ElRPRqTAZSlN5cBm6YCAcQIOsmXDUn6Oj1lVPQAoOjTBTvsjM3NjGhvv
+52kHTY0nsMsBeY9q5DTtlzmlYkVUq2a6Htgf2mNi01dIw5fJ7uTTo8EbNf7O0i3u
+c9a8P19HaZl5NKiWN4EIZkfB2WdXYRJCVBsGgQj3dE/GrEmH9QINq1A+GkNvK96u
+vZm8H1jjmuqzHplWa7lFeXcx8FTVTbVb/iJrZ2Lc/JvIPitKZWhqbR59yrGjpwEp
+Id7bo4WhO5L3OB0fSIJYvfu+o4WYnt4f3UzecwIDAQABAoIBABRD9yHgKErVuC2Q
+bA+SYZY8VvdtF/X7q4EmQFORDNRA7EPgMc03JU6awRGbQ8i4kHs46EFzPoXvWcKz
+AXYsO6N0Myc900Tp22A5d9NAHATEbPC/wdje7hRq1KyZONMJY9BphFv3nZbY5apR
+Dc90JBFZP5RhXjTc3n9GjvqLAKfFEKVmPRCvqxCOZunw6XR+SgIQLJo36nsIsbhW
+QUXIVaCI6cXMN8bRPm8EITdBNZu06Fpu4ZHm6VaxlXN9smERCDkgBSNXNWHKxmmA
+c3Glo2DByUr2/JFBOrLEe9fkYgr24KNCQkHVcSaFxEcZvTggr7StjKISVHlCNEaB
+7Q+kPoECgYEA3zE9FmvFGoQCU4g4Nl3dpQHs6kaAW8vJlrmq3xsireIuaJoa2HMe
+wYdIvgCnK9DIjyxd5OWnE4jXtAEYPsyGD32B5rSLQrRO96lgb3f4bESCLUb3Bsn/
+sdgeE3p1xZMA0B59htqCrvVgN9k8WxyevBxYl3/gSBm/p8OVH1RTW/ECgYEA2f9Z
+95OLj0KQHQtxQXf+I3VjhCw3LkLW39QZOXVI0QrCJfqqP7uxsJXH9NYX0l0GFTcR
+kRrlyoaSU1EGQosZh+n1MvplGBTkTSV47/bPsTzFpgK2NfEZuFm9RoWgltS+nYeH
+Y2k4mnAN3PhReCMwuprmJz8GRLsO3Cs2s2YylKMCgYEA2UX+uO/q7jgqZ5UJW+ue
+1H5+W0aMuFA3i7JtZEnvRaUVFqFGlwXin/WJ2+WY1++k/rPrJ+Rk9IBXtBUIvEGw
+FC5TIfsKQsJyyWgqx/jbbtJ2g4s8+W/1qfTAuqeRNOg5d2DnRDs90wJuS4//0JaY
+9HkHyVwkQyxFxhSA/AHEMJECgYA2MvyFR1O9bIk0D3I7GsA+xKLXa77Ua53MzIjw
+9i4CezBGDQpjCiFli/fI8am+jY5DnAtsDknvjoG24UAzLy5L0mk6IXMdB6SzYYut
+7ak5oahqW+Y9hxIj+XvLmtGQbphtxhJtLu35x75KoBpxSh6FZpmuTEccs31AVCYn
+eFM/DQKBgQDOPUwbLKqVi6ddFGgrV9MrWw+SWsDa43bPuyvYppMM3oqesvyaX1Dt
+qDvN7owaNxNM4OnfKcZr91z8YPVCFo4RbBif3DXRzjNNBlxEjHBtuMOikwvsmucN
+vIrbeEpjTiUMTEAr6PoTiVHjsfS8WAM6MDlF5M+2PNswDsBpa2yLgA==
+-----END RSA PRIVATE KEY-----
+`),
+}
+
+var SSHCertificates = map[string][]byte{
+ // The following are corresponding certificates for the private keys above, signed by the CA key
+ // Generated by the following commands:
+ //
+ // 1. Assumes "rsa" key above in file named "rsa", write out the public key to "rsa.pub":
+ // ssh-keygen -y -f rsa > rsa.pub
+ //
+ // 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub":
+ // ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub
+ "rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLjYqmmuTSEmjVhSfLQphBSTJMLwIZhRgmpn8FHKLiEIAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABZHN8UAAAAAGsjIYUAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABDwAAAAdzc2gtcnNhAAABALeDea+60H6xJGhktAyosHaSY7AYzLocaqd8hJQjEIDifBwzoTlnBmcK9CxGhKuaoJFThdCLdaevCeOSuquh8HTkf+2ebZZc/G5T+2thPvPqmcuEcmMosWo+SIjYhbP3S6KD49aLC1X0kz8IBQeauFvURhkZ5ZjhA1L4aQYt9NjL73nqOl8PplRui+Ov5w8b4ldul4zOvYAFrzfcP6wnnXk3c1Zzwwf5wynD5jakO8GpYKBuhM7Z4crzkKSQjU3hla7xqgfomC5Gz4XbR2TNjcQiRrJQ0UlKtX3X3ObRCEhuvG0Kzjklhv+Ddw6txrhKjMjiSi/Yyius/AE8TmC1p4U= host.example.com
+`),
}
var PEMEncryptedKeys = []struct {
diff --git a/xts/xts.go b/xts/xts.go
index c9a283b..a7643fd 100644
--- a/xts/xts.go
+++ b/xts/xts.go
@@ -23,6 +23,7 @@ package xts // import "golang.org/x/crypto/xts"
import (
"crypto/cipher"
+ "encoding/binary"
"errors"
)
@@ -65,21 +66,20 @@ func (c *Cipher) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) {
}
var tweak [blockSize]byte
- for i := 0; i < 8; i++ {
- tweak[i] = byte(sectorNum)
- sectorNum >>= 8
- }
+ binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
c.k2.Encrypt(tweak[:], tweak[:])
- for i := 0; i < len(plaintext); i += blockSize {
- for j := 0; j < blockSize; j++ {
- ciphertext[i+j] = plaintext[i+j] ^ tweak[j]
+ for len(plaintext) > 0 {
+ for j := range tweak {
+ ciphertext[j] = plaintext[j] ^ tweak[j]
}
- c.k1.Encrypt(ciphertext[i:], ciphertext[i:])
- for j := 0; j < blockSize; j++ {
- ciphertext[i+j] ^= tweak[j]
+ c.k1.Encrypt(ciphertext, ciphertext)
+ for j := range tweak {
+ ciphertext[j] ^= tweak[j]
}
+ plaintext = plaintext[blockSize:]
+ ciphertext = ciphertext[blockSize:]
mul2(&tweak)
}
@@ -97,21 +97,20 @@ func (c *Cipher) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) {
}
var tweak [blockSize]byte
- for i := 0; i < 8; i++ {
- tweak[i] = byte(sectorNum)
- sectorNum >>= 8
- }
+ binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
c.k2.Encrypt(tweak[:], tweak[:])
- for i := 0; i < len(plaintext); i += blockSize {
- for j := 0; j < blockSize; j++ {
- plaintext[i+j] = ciphertext[i+j] ^ tweak[j]
+ for len(ciphertext) > 0 {
+ for j := range tweak {
+ plaintext[j] = ciphertext[j] ^ tweak[j]
}
- c.k1.Decrypt(plaintext[i:], plaintext[i:])
- for j := 0; j < blockSize; j++ {
- plaintext[i+j] ^= tweak[j]
+ c.k1.Decrypt(plaintext, plaintext)
+ for j := range tweak {
+ plaintext[j] ^= tweak[j]
}
+ plaintext = plaintext[blockSize:]
+ ciphertext = ciphertext[blockSize:]
mul2(&tweak)
}
diff --git a/xts/xts_test.go b/xts/xts_test.go
index 7a5e9fa..96d3b6c 100644
--- a/xts/xts_test.go
+++ b/xts/xts_test.go
@@ -83,3 +83,23 @@ func TestXTS(t *testing.T) {
}
}
}
+
+func TestShorterCiphertext(t *testing.T) {
+ // Decrypt used to panic if the input was shorter than the output. See
+ // https://go-review.googlesource.com/c/39954/
+ c, err := NewCipher(aes.NewCipher, make([]byte, 32))
+ if err != nil {
+ t.Fatalf("NewCipher failed: %s", err)
+ }
+
+ plaintext := make([]byte, 32)
+ encrypted := make([]byte, 48)
+ decrypted := make([]byte, 48)
+
+ c.Encrypt(encrypted, plaintext, 0)
+ c.Decrypt(decrypted, encrypted[:len(plaintext)], 0)
+
+ if !bytes.Equal(plaintext, decrypted[:len(plaintext)]) {
+ t.Errorf("En/Decryption is not inverse")
+ }
+}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment