Skip to content

Instantly share code, notes, and snippets.

@blacknon
Last active February 12, 2023 09:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save blacknon/201de46166213df77437d02891b98d82 to your computer and use it in GitHub Desktop.
Save blacknon/201de46166213df77437d02891b98d82 to your computer and use it in GitHub Desktop.
goで`github.com/miekg/pkcs11/p11`を使って、Yubikey内からsshのCryptoSignerを取得するサンプルコード
// Copyright (c) 2020 Blacknon. All rights reserved.
// Use of this source code is governed by an MIT license
// that can be found in the LICENSE file.
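// Sample code that uses the `github.com/miekg/pkcs11/p11` package (imported here from the github.com/blacknon/pkcs11 fork) to obtain an SSH CryptoSigner from inside a YubiKey.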
package main
import (
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"fmt"
"log"
"os"
"github.com/ThalesIgnite/crypto11"
"github.com/blacknon/pkcs11/p11"
"github.com/miekg/pkcs11"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
// provider is the filesystem path of the PKCS#11 provider library
// (OpenSC's opensc-pkcs11.so). It is used both for key discovery through
// p11 and as the crypto11 Config.Path.
var provider = "/usr/local/lib/opensc-pkcs11.so"
// C11 bundles the settings needed to open a crypto11 context on a single
// PKCS#11 token, together with that context once OpenCtx has opened it.
type C11 struct {
	// Provider is the path of the PKCS#11 provider library (crypto11 Config.Path).
	Provider string
	// PIN is passed to crypto11 as Config.Pin.
	PIN string
	// Label selects the token by label (crypto11 Config.TokenLabel).
	Label string

	// Ctx is the crypto11 context; OpenCtx assigns it.
	Ctx *crypto11.Context
}
// OpenCtx configures crypto11 for the token named c.Label, using the
// provider library at c.Provider and the PIN in c.PIN.
//
// Whatever crypto11.Configure returns as the context is stored in c.Ctx,
// even on failure; the error from Configure is returned unchanged.
func (c *C11) OpenCtx() (err error) {
	cfg := crypto11.Config{
		Path:       c.Provider,
		TokenLabel: c.Label,
		Pin:        c.PIN,
	}
	opened, err := crypto11.Configure(&cfg)
	c.Ctx = opened
	return err
}
// Slotdata holds the CKA_ID values of the public-key objects found on one
// token, keyed by each object's position in that token's FindObjects result.
type Slotdata struct {
	keyid map[int][]byte // object index -> CKA_ID bytes
}
// getPassphrase writes msg to stderr as a prompt, then reads one line from
// the controlling terminal (/dev/tty) with terminal.ReadPassword, so the
// input is not echoed. Works only on UNIX-like systems that provide /dev/tty.
//
// It returns an error when reading from the terminal fails or when the
// entered string is empty. If /dev/tty cannot be opened the program exits
// via log.Fatal.
func getPassphrase(msg string) (input string, err error) {
	// msg is printed verbatim; it is not a format string, so a '%' in it
	// must not be treated as a formatting verb.
	fmt.Fprint(os.Stderr, msg)

	// Read from /dev/tty rather than stdin so the prompt works even when
	// stdin is redirected.
	tty, err := os.Open("/dev/tty")
	if err != nil {
		log.Fatal(err)
	}
	defer tty.Close()

	result, err := terminal.ReadPassword(int(tty.Fd()))
	// Echo was disabled, so the user's Enter did not advance the cursor.
	// Finish the prompt line on the same stream the prompt was written to,
	// on every path (including errors and empty input).
	fmt.Fprintln(os.Stderr)
	if err != nil {
		return "", fmt.Errorf("reading passphrase: %w", err)
	}
	if len(result) == 0 {
		return "", fmt.Errorf("err: input is empty")
	}
	return string(result), nil
}
// getPIN asks for the PIN of the token called label, prompting with
// "<label>'s PIN:", and returns the result of getPassphrase unchanged.
func getPIN(label string) (pin string, err error) {
	prompt := label + "'s PIN:"
	return getPassphrase(prompt)
}
func main() {
// Create p11.module
module, err := p11.OpenModule(provider)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
// get slots
slots, err := module.Slots()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
slotids := map[string]*Slotdata{}
for _, slot := range slots {
// var slotid int
tokenInfo, err := slot.TokenInfo()
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
fmt.Println(tokenInfo.Label)
session, err := slot.OpenSession()
if err != nil {
fmt.Fprintln(os.Stderr, err)
session.Close()
continue
}
// get public key
pub := []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_PUBLIC_KEY),
}
keyids := map[int][]byte{}
fmt.Println("----- Get Public Key -----")
obj, _ := session.FindObjects(pub)
for n, o := range obj {
// get label
l, err := o.Label()
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
b, _ := o.Attribute(pkcs11.CKA_ID)
fmt.Println(b)
keyids[n] = b
v, err := o.Value()
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
rsaPubKey, err := x509.ParsePKIXPublicKey(v)
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
sshKey, ok := rsaPubKey.(*rsa.PublicKey)
if !ok {
fmt.Fprintln(os.Stderr, "invalid PEM passed in from user")
continue
}
pubkey, err := ssh.NewPublicKey(sshKey)
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
p := base64.StdEncoding.EncodeToString(pubkey.Marshal())
fmt.Println(l, ":", p)
}
// set slot id
sd := new(Slotdata)
sd.keyid = keyids
slotids[tokenInfo.Label] = sd
}
fmt.Println("=====")
module.Destroy()
fmt.Println(slotids)
fmt.Println("=====")
for s, v := range slotids {
pin, err := getPIN("test")
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
fmt.Println(s)
c11 := new(C11)
c11.Provider = provider
c11.Label = "uesugi"
c11.PIN = pin
err = c11.OpenCtx()
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
for _, kid := range v.keyid {
set, err := crypto11.NewAttributeSetWithID(kid)
fmt.Println(set)
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
signer, err := c11.Ctx.FindKeyPairWithAttributes(set)
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
fmt.Println("Print signer")
fmt.Println(signer)
fmt.Println("++++++++++")
signers, err := c11.Ctx.FindAllKeyPairs()
if err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
fmt.Println("Print signers")
fmt.Println(signers)
}
fmt.Println(123)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment