Skip to content

Instantly share code, notes, and snippets.

@hiyosi
Created August 30, 2021 07:18
Show Gist options
  • Save hiyosi/82bd11f64a9634bf6742124df35bf4a5 to your computer and use it in GitHub Desktop.
Save hiyosi/82bd11f64a9634bf6742124df35bf4a5 to your computer and use it in GitHub Desktop.
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"log"
"os"
"strings"
"time"
"github.com/spf13/pflag"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
entryv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/entry/v1"
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
"github.com/spiffe/spire/pkg/agent/client"
)
// Command-line flags for the tool, registered on a dedicated pflag set.
// Because the set uses pflag.ExitOnError, an invalid flag terminates the
// process during parsing.
var (
	fs = pflag.NewFlagSet("", pflag.ExitOnError)

	// Connection to the SPIRE server.
	spireAddr = fs.String("spire-addr", "127.0.0.1", "DNS name or IP address of the SPIRE server")
	spirePort = fs.Int("spire-port", 8081, "Port number of the SPIRE server")

	// Credentials used to authenticate to the server.
	svidPath    = fs.String("svid-path", "", "Path to the SVID file")
	keyPath     = fs.String("key-path", "", "Path to the private key file")
	bundlePath  = fs.String("bundle-path", "", "Path to the bundle file")
	trustDomain = fs.String("trust-domain", "example.org", "Name of TrustDomain")

	// Contents of the registration entry to create. StringArray lets
	// --selector be given several times, one selector per occurrence.
	parentID  = fs.String("parent-id", "", "ParentID value for RegistrationEntry")
	spiffeID  = fs.String("spiffe-id", "", "SPIFFEID value for RegistrationEntry")
	selectors = fs.StringArray("selector", nil, "Selector value for RegistrationEntry (type:value, repeatable)")
)
// main registers one entry on a SPIRE server.
//
// It parses the command-line flags and loads the caller's X.509-SVID,
// private key and trust bundle from the given files. It then dials the
// server's Entry API, authenticating with that SVID, and creates a single
// registration entry built from --parent-id, --spiffe-id and --selector.
// On success it prints the new entry's ID, SPIFFE ID and parent ID. Any
// failure terminates the process through log.Fatal*.
func main() {
	fs.AddGoFlagSet(flag.CommandLine)
	// FlagSet.Parse expects the arguments without the program name. With
	// pflag.ExitOnError a parse failure exits the process itself, so the
	// returned error is always nil here.
	_ = fs.Parse(os.Args[1:])

	ctx := context.Background()

	td, err := spiffeid.TrustDomainFromString(*trustDomain)
	if err != nil {
		log.Fatalf("Unable to parse trust domain: %v", err)
	}

	svid, err := x509svid.Load(*svidPath, *keyPath)
	if err != nil {
		log.Fatalf("Unable to read svid or key file: %v", err)
	}

	trustBundle, err := x509bundle.Load(td, *bundlePath)
	if err != nil {
		log.Fatalf("Unable to read bundle file: %v", err)
	}

	entryClient, err := newEntryClient(ctx, td, trustBundle, svid)
	if err != nil {
		log.Fatal(err)
	}

	resp, err := createEntry(ctx, entryClient)
	if err != nil {
		log.Fatalf("Unable to create entry: %v", err)
	}

	// A successful batch RPC can still carry a failed result: each entry's
	// outcome is reported in its own Status. Entry is presumably unset (or
	// not the newly created one) when that status is not OK, so check
	// before using it instead of risking a nil dereference.
	results := resp.GetResults()
	if len(results) == 0 {
		log.Fatal("Entry API returned no results")
	}
	result := results[0]
	// A status code of 0 is gRPC's OK.
	if code := result.GetStatus().GetCode(); code != 0 {
		log.Fatalf("Unable to create entry: %s (code %d)", result.GetStatus().GetMessage(), code)
	}

	entry := result.GetEntry()
	fmt.Printf("Entry ID: %v\n", entry.GetId())
	fmt.Printf("SPIFFE ID: %v\n", entry.GetSpiffeId())
	fmt.Printf("Parent ID: %v\n", entry.GetParentId())
}
// newEntryClient dials the SPIRE server named by --spire-addr/--spire-port
// and returns an Entry API client bound to that connection.
//
// The server is verified against the X.509 authorities in trustBundle for
// trust domain td. The certificate chain and private key of svid are
// presented as the client certificate. The 3-second timeout applies to the
// dial call only. The returned connection is never closed by this program.
func newEntryClient(ctx context.Context, td spiffeid.TrustDomain, trustBundle *x509bundle.Bundle, svid *x509svid.SVID) (entryv1.EntryClient, error) {
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	conn, err := client.DialServer(ctx, client.DialServerConfig{
		Address:     serverAddress(*spireAddr, *spirePort),
		TrustDomain: td,
		GetBundle: func() []*x509.Certificate {
			return trustBundle.X509Authorities()
		},
		GetAgentCertificate: func() *tls.Certificate {
			c := &tls.Certificate{
				PrivateKey: svid.PrivateKey,
			}
			// tls.Certificate carries the chain as DER bytes.
			for _, cert := range svid.Certificates {
				c.Certificate = append(c.Certificate, cert.Raw)
			}
			return c
		},
	})
	if err != nil {
		return nil, fmt.Errorf("unable to create grpc connection: %w", err)
	}
	return entryv1.NewEntryClient(conn), nil
}

// serverAddress joins host and port into a "host:port" dial target.
//
// A bare IPv6 literal such as "::1" is bracketed ("[::1]:8081"), as
// net.JoinHostPort does. Plain "host:port" concatenation would be
// ambiguous for such a literal. A host that is already bracketed is left
// as-is so "[::1]" keeps working.
func serverAddress(host string, port int) string {
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		return fmt.Sprintf("[%s]:%d", host, port)
	}
	return fmt.Sprintf("%s:%d", host, port)
}
// createEntry builds a single registration entry from the --parent-id,
// --spiffe-id and --selector flags and submits it with the Entry API's
// BatchCreateEntry RPC, bounded by a 3-second timeout.
//
// Flag values that fail to parse are returned as errors instead of
// terminating the process, so the caller decides how to report them. The
// RPC response is returned unchanged; per-entry outcomes are carried in each
// result and must be checked by the caller.
func createEntry(ctx context.Context, entryClient entryv1.EntryClient) (*entryv1.BatchCreateEntryResponse, error) {
	pID, err := spiffeid.FromString(*parentID)
	if err != nil {
		return nil, fmt.Errorf("unable to parse Parent ID: %w", err)
	}
	sID, err := spiffeid.FromString(*spiffeID)
	if err != nil {
		return nil, fmt.Errorf("unable to parse SPIFFE ID: %w", err)
	}

	ss := make([]*types.Selector, 0, len(*selectors))
	for _, selector := range *selectors {
		s, err := parseSelector(selector)
		if err != nil {
			return nil, fmt.Errorf("unable to parse selectors: %w", err)
		}
		ss = append(ss, s)
	}

	req := &entryv1.BatchCreateEntryRequest{
		Entries: []*types.Entry{
			{
				ParentId: &types.SPIFFEID{
					TrustDomain: pID.TrustDomain().String(),
					Path:        pID.Path(),
				},
				SpiffeId: &types.SPIFFEID{
					TrustDomain: sID.TrustDomain().String(),
					Path:        sID.Path(),
				},
				Selectors: ss,
			},
		},
	}

	// Start the timeout only once the request is built, so the whole
	// budget goes to the RPC itself.
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	return entryClient.BatchCreateEntry(ctx, req)
}
// parseSelector converts a "type:value" string into a types.Selector.
//
// Only the first colon separates type from value, so the value may itself
// contain colons: "k8s:ns:default" yields type "k8s" and value "ns:default".
// Input without any colon is rejected.
//
// Adapted from the SPIRE CLI:
// https://github.com/spiffe/spire/blob/ea7204453b053c8011ae5827d91334c553a419e2/cmd/spire-server/cli/entry/util.go#L20
func parseSelector(str string) (*types.Selector, error) {
	sep := strings.IndexByte(str, ':')
	if sep < 0 {
		return nil, fmt.Errorf("selector \"%s\" must be formatted as type:value", str)
	}
	return &types.Selector{
		Type:  str[:sep],
		Value: str[sep+1:],
	}, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment