Skip to content

Instantly share code, notes, and snippets.

@zaharidichev
Created February 10, 2020 12:03
Show Gist options
  • Save zaharidichev/0ec8d4a02fac5becf555e7718893d4a8 to your computer and use it in GitHub Desktop.
Save zaharidichev/0ec8d4a02fac5becf555e7718893d4a8 to your computer and use it in GitHub Desktop.
diff --git a/cli/cmd/upgrade.go b/cli/cmd/upgrade.go
index 1ae754be..19470eb2 100644
--- a/cli/cmd/upgrade.go
+++ b/cli/cmd/upgrade.go
@@ -7,6 +7,7 @@ import (
"io/ioutil"
"os"
"strings"
+ "time"
"github.com/linkerd/linkerd2/pkg/config"
"github.com/linkerd/linkerd2/pkg/issuercerts"
@@ -399,7 +400,7 @@ func ensureIssuerCertWorksWithAllProxies(k kubernetes.Interface, cred *tls.Cred)
roots, err := tls.DecodePEMCertPool(pod.Anchors)
if roots != nil {
- err = cred.Verify(roots, "")
+ err = cred.Verify(roots, "", time.Time{})
}
if err != nil {
@@ -546,7 +547,7 @@ func verifyWebhookTLS(value *charts.TLS, webhook string) error {
return err
}
roots := crt.CertPool()
- if err := crt.Verify(roots, webhookCommonName(webhook)); err != nil {
+ if err := crt.Verify(roots, webhookCommonName(webhook), time.Time{}); err != nil {
return err
}
diff --git a/pkg/healthcheck/healthcheck.go b/pkg/healthcheck/healthcheck.go
index 2bba7aa4..1ba2a3f2 100644
--- a/pkg/healthcheck/healthcheck.go
+++ b/pkg/healthcheck/healthcheck.go
@@ -963,7 +963,7 @@ func (hc *HealthChecker) allCategories() []category {
description: "issuer cert is issued by the trust root",
hintAnchor: "l5d-identity-issuer-cert-issued-by-trust-root",
check: func(ctx context.Context) error {
- return hc.issuerCert.Verify(tls.CertificatesToPool(hc.roots), hc.issuerIdentity())
+ return hc.issuerCert.Verify(tls.CertificatesToPool(hc.roots), hc.issuerIdentity(), time.Time{})
},
},
},
diff --git a/pkg/identity/service.go b/pkg/identity/service.go
index ff2c3e77..208eefef 100644
--- a/pkg/identity/service.go
+++ b/pkg/identity/service.go
@@ -114,7 +114,7 @@ func (svc *Service) loadCredentials() (tls.Issuer, error) {
return nil, fmt.Errorf("failed to read CA from disk: %s", err)
}
- if err := creds.Crt.Verify(svc.trustAnchors, svc.expectedName); err != nil {
+ if err := creds.Crt.Verify(svc.trustAnchors, svc.expectedName, time.Time{}); err != nil {
return nil, fmt.Errorf("failed to verify issuer credentials for '%s' with trust anchors: %s", svc.expectedName, err)
}
@@ -149,7 +149,7 @@ func (svc *Service) ensureIssuerStillValid() error {
issuer := *svc.issuer
switch is := issuer.(type) {
case *tls.CA:
- return is.Cred.Verify(svc.trustAnchors, svc.expectedName)
+ return is.Cred.Verify(svc.trustAnchors, svc.expectedName, time.Time{})
default:
return fmt.Errorf("unsupported issuer type. Expected *tls.CA, got %v", is)
}
diff --git a/pkg/issuercerts/issuercerts.go b/pkg/issuercerts/issuercerts.go
index ad1b853f..50262d7d 100644
--- a/pkg/issuercerts/issuercerts.go
+++ b/pkg/issuercerts/issuercerts.go
@@ -169,7 +169,7 @@ func (ic *IssuerCertData) VerifyAndBuildCreds(dnsName string) (*tls.Cred, error)
return nil, err
}
- if err := creds.Verify(roots, dnsName); err != nil {
+ if err := creds.Verify(roots, dnsName, time.Time{}); err != nil {
return nil, err
}
diff --git a/pkg/tls/cred.go b/pkg/tls/cred.go
index 79e3da71..a079a9bf 100644
--- a/pkg/tls/cred.go
+++ b/pkg/tls/cred.go
@@ -83,8 +83,13 @@ func (crt *Crt) CertPool() *x509.CertPool {
return p
}
-// Verify the validity of the provided certificate
-func (crt *Crt) Verify(roots *x509.CertPool, name string) error {
+// Verify the validity of the provided certificate. If currentTime is the zero
+// value, the actual current time from time.Now() is used instead.
+func (crt *Crt) Verify(roots *x509.CertPool, name string, currentTime time.Time) error {
+ if currentTime.IsZero() {
+ currentTime = time.Now()
+ }
+
i := x509.NewCertPool()
for _, c := range crt.TrustChain {
i.AddCert(c)
@@ -92,8 +97,8 @@ func (crt *Crt) Verify(roots *x509.CertPool, name string) error {
- vo := x509.VerifyOptions{Roots: roots, Intermediates: i, DNSName: name}
+ vo := x509.VerifyOptions{Roots: roots, Intermediates: i, DNSName: name, CurrentTime: currentTime}
_, err := crt.Certificate.Verify(vo)
- if _, ok := crtExpiryError(err, *crt.Certificate); ok {
- return fmt.Errorf("%s - Current Time : %s - Invalid before %s - Invalid After %s", err, time.Now(), crt.Certificate.NotBefore, crt.Certificate.NotAfter)
+ if crtExpiryError(err) {
+ return fmt.Errorf("%s - Current Time : %s - Invalid before %s - Invalid After %s", err, currentTime, crt.Certificate.NotBefore, crt.Certificate.NotAfter)
}
return err
}
@@ -227,25 +232,12 @@ func DecodePEMCrt(txt string) (*Crt, error) {
return &crt, nil
}
-func crtExpiryError(err error, crt x509.Certificate) (timeFmtError, bool) {
- timeFmtErrObj := timeFmtError{now: time.Now(), crt: crt}
+// crtExpiryError reports whether err is an x509.CertificateInvalidError whose reason is x509.Expired.
+func crtExpiryError(err error) bool {
switch v := err.(type) {
case x509.CertificateInvalidError:
- return timeFmtErrObj, v.Reason == x509.Expired
+ return v.Reason == x509.Expired
default:
- return timeFmtErrObj, false
+ return false
}
}
-
-type timeFmtError struct {
- now time.Time
- crt x509.Certificate
-}
-
-//Return error message with lesser precision on time; for testing only
-func (t timeFmtError) Error() string {
- now := t.now
- fmtTime := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), 0, 0, now.Location())
- return fmt.Sprintf("%s - Current Time : %s - Invalid before %s - Invalid After %s", x509.CertificateInvalidError{Reason: x509.Expired}.Error(), fmtTime.String(), t.crt.NotBefore, t.crt.NotAfter)
-
-}
diff --git a/pkg/tls/cred_test.go b/pkg/tls/cred_test.go
index 0fc21c79..8739808c 100644
--- a/pkg/tls/cred_test.go
+++ b/pkg/tls/cred_test.go
@@ -29,7 +29,7 @@ func TestCrtRoundtrip(t *testing.T) {
t.Fatalf("Failed to decode PEM Crt: %s", err)
}
- if err := crt.Verify(rootTrust, "endentity.test"); err != nil {
+ if err := crt.Verify(rootTrust, "endentity.test", time.Time{}); err != nil {
t.Fatal("Failed to verify round-tripped certificate")
}
}
@@ -67,9 +67,7 @@ func TestCrtExpiry(t *testing.T) {
fakeExpiryError := x509.CertificateInvalidError{Reason: x509.Expired}
- //need to remove seconds and nanoseconds for testing returned error
now := time.Now()
- curTimeDate := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), 0, 0, now.Location())
testCases := []struct {
notBefore time.Time
@@ -78,15 +76,15 @@ func TestCrtExpiry(t *testing.T) {
}{
//cert not valid yet
{
- notAfter: curTimeDate.AddDate(0, 0, 20),
- notBefore: curTimeDate.AddDate(0, 0, 10),
- expected: fmt.Sprintf("%s - Current Time : %s - Invalid before %s - Invalid After %s", fakeExpiryError.Error(), curTimeDate, curTimeDate.AddDate(0, 0, 10), curTimeDate.AddDate(0, 0, 20)),
+ notAfter: now.AddDate(0, 0, 20),
+ notBefore: now.AddDate(0, 0, 10),
+ expected: fmt.Sprintf("%s - Current Time : %s - Invalid before %s - Invalid After %s", fakeExpiryError.Error(), now, now.AddDate(0, 0, 10), now.AddDate(0, 0, 20)),
},
//cert has expired
{
- notAfter: curTimeDate.AddDate(0, 0, -10),
- notBefore: curTimeDate.AddDate(0, 0, -20),
- expected: fmt.Sprintf("%s - Current Time : %s - Invalid before %s - Invalid After %s", fakeExpiryError.Error(), curTimeDate, curTimeDate.AddDate(0, 0, -20), curTimeDate.AddDate(0, 0, -10)),
+ notAfter: now.AddDate(0, 0, -10),
+ notBefore: now.AddDate(0, 0, -20),
+ expected: fmt.Sprintf("%s - Current Time : %s - Invalid before %s - Invalid After %s", fakeExpiryError.Error(), now, now.AddDate(0, 0, -20), now.AddDate(0, 0, -10)),
},
// cert is valid
{
@@ -103,16 +101,20 @@ func TestCrtExpiry(t *testing.T) {
crt.Certificate.NotBefore = tc.notBefore
crt.Certificate.NotAfter = tc.notAfter
- if err := crt.Verify(rootTrust, "expired.test"); err != nil {
- if timeFmtErr, ok := crtExpiryError(err, *crt.Certificate); ok {
- if tc.expected != timeFmtErr.Error() {
- t.Logf("Returned error : %s\n", timeFmtErr.Error())
- t.Logf("Expected error : %s\n", tc.expected)
- t.Fatal("test case failed")
+ err := crt.Verify(rootTrust, "expired.test", now)
+ if tc.expected != "" {
+ if err != nil {
+ // we need to get the same error as we were expecting
+ if tc.expected != err.Error() {
+ t.Fatalf("got error: %s but was expecting: %s", err.Error(), tc.expected)
}
+ } else {
+ t.Fatalf("expected error : %s, but got no error", tc.expected)
}
} else {
- t.Log("no error on verification")
+ if err != nil {
+ t.Fatalf("was not expecting error but got: %s", err.Error())
+ }
}
})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment