Skip to content

Instantly share code, notes, and snippets.

@yutakahashi114
Created June 28, 2021 12:54
Show Gist options
  • Save yutakahashi114/c2205e6feb06cad3bee470b665a0dd33 to your computer and use it in GitHub Desktop.
Save yutakahashi114/c2205e6feb06cad3bee470b665a0dd33 to your computer and use it in GitHub Desktop.
package main
import (
"crypto/rsa"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"math/rand"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/service/cognitoidentityprovider"
"github.com/dgrijalva/jwt-go"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/render"
"github.com/gofrs/uuid"
)
// Package-level state, assigned once in main before the HTTP server starts.
var (
	// userPool holds every user pool, loaded from poolFileName.
	userPool UserPool
	// clientID is the only app client ID accepted, read from $CLIENT_ID.
	clientID ClientID
	// signKey is the RSA private key used to sign issued tokens.
	signKey *rsa.PrivateKey
)

const (
	// jwkKeyID is the "kid" written into token headers and the published JWK.
	jwkKeyID = "hoge"
	// poolFileName is the JSON file user pools are loaded from and saved to.
	poolFileName = "pool.json"
)
// main starts the Cognito mock. It loads the persisted user pools from
// poolFileName and reads the RSA key pair used to sign and publish tokens. It
// then serves HTTP on :80 with two routes:
//
//   - GET /{userPoolID}/.well-known/jwks.json returns the public key as a JWK
//     set. The same key is served for every pool ID.
//   - POST / handles Cognito JSON API calls, dispatched by route on the
//     X-Amz-Target header.
//
// Any startup failure, or the server stopping, terminates the process.
func main() {
	data, err := ioutil.ReadFile(poolFileName)
	if err != nil {
		log.Fatalf("reading user pools from %s: %v", poolFileName, err)
	}
	poolMap := make(map[UserPoolID]UserMap)
	if err := json.Unmarshal(data, &poolMap); err != nil {
		log.Fatalf("decoding user pools from %s: %v", poolFileName, err)
	}
	// A file containing just `null` decodes to a nil map. Writing a new pool
	// into a nil map would panic, so replace it with an empty one.
	if poolMap == nil {
		poolMap = make(map[UserPoolID]UserMap)
	}
	userPool = UserPool{
		poolMap: poolMap,
		mutex:   &sync.Mutex{},
	}
	clientID = ClientID(os.Getenv("CLIENT_ID"))

	signKey, err = getPrivateKey()
	if err != nil {
		log.Fatalf("loading private signing key: %v", err)
	}
	pubKeyString, err := getPublicKey()
	if err != nil {
		log.Fatalf("loading public key: %v", err)
	}

	mux := chi.NewRouter()
	mux.Use(middleware.Logger)
	mux.Get("/{userPoolID}/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{"keys":[%s]}`, pubKeyString)
	})
	mux.Post("/", func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()
		b, err := ioutil.ReadAll(r.Body)
		if err != nil {
			writeInternalError(w, r, err)
			return
		}
		out, err := route(r.Header.Get("X-Amz-Target"), b)
		if err != nil {
			writeInternalError(w, r, err)
			return
		}
		w.Header().Set("Content-Type", "application/x-amz-json-1.1")
		w.WriteHeader(http.StatusOK)
		// If this write fails, the response cannot be delivered and nothing
		// more can be sent.
		w.Write(out)
	})

	srv := &http.Server{
		Addr:    ":80",
		Handler: mux,
		// Bound how long a client may take to send its request headers.
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}

// writeInternalError logs err and replies with HTTP 500 and a JSON body of the
// form {"__type": ..., "message": ...}, the error shape used by the AWS JSON
// protocol. render.Status must be called before render.JSON, because
// render.JSON reads the status from the request context.
func writeInternalError(w http.ResponseWriter, r *http.Request, err error) {
	log.Println(err)
	render.Status(r, http.StatusInternalServerError)
	render.JSON(w, r, map[string]string{
		"__type":  "InternalErrorException",
		"message": err.Error(),
	})
}
// User is one account in a user pool. Users are persisted as JSON in
// poolFileName through the UserPool they belong to.
type User struct {
// UUID is emitted as the token "sub" claim and as the "sub" attribute.
UUID string `json:"uuid"`
// Password is stored and compared in plaintext.
Password string `json:"password"`
// Username is also the user's key within its pool's UserMap.
Username Username `json:"username"`
// Email is copied into the token "email" claim.
Email string `json:"email"`
// EmailVerified must be true for adminInitiateAuth to issue a token.
EmailVerified bool `json:"email_verified"`
}
// tokenLifetime is how long issued tokens stay valid: ten years, so tokens
// from this mock are effectively non-expiring.
const tokenLifetime = 10 * 365 * 24 * time.Hour

// ToToken builds an unsigned RS256 JWT for u. The header carries kid jwkKeyID.
// The claims are iat, exp (iat + tokenLifetime), sub (u.UUID) and email
// (u.Email). The caller must sign the token, e.g. with SignedString(signKey).
//
// userPoolID is currently unused; the token carries no issuer claim.
func (u User) ToToken(userPoolID UserPoolID) *jwt.Token {
	// A single timestamp keeps exp exactly tokenLifetime after iat. Two
	// separate time.Now calls can land on different seconds.
	now := time.Now()
	token := jwt.New(jwt.SigningMethodRS256)
	token.Header["kid"] = jwkKeyID
	claims := token.Claims.(jwt.MapClaims)
	claims["iat"] = now.Unix()
	claims["exp"] = now.Add(tokenLifetime).Unix()
	claims["sub"] = u.UUID
	claims["email"] = u.Email
	return token
}
// Identifier and collection types for the in-memory user store.
type (
	// ClientID is an app client ID, compared against the CLIENT_ID setting.
	ClientID string
	// Username is a login name; CreateUser keeps it unique within a pool.
	Username string
	// UserMap indexes one pool's users by Username.
	UserMap map[Username]*User
	// UserPoolID names a user pool and is the top-level key in poolFileName.
	UserPoolID string
)

// UserPool stores every user pool keyed by UserPoolID, with access guarded by
// mutex. mutex is a pointer, so all copies made by the value-receiver methods
// share the same lock.
type UserPool struct {
	poolMap map[UserPoolID]UserMap
	mutex   *sync.Mutex
}
// GetUser looks up username in the pool identified by userPoolID while
// holding the pool lock. The boolean is false when either the pool or the
// user is absent. The pointer refers to the stored record, not a copy.
func (pool UserPool) GetUser(userPoolID UserPoolID, username Username) (*User, bool) {
	pool.mutex.Lock()
	found, exists := pool.getUser(userPoolID, username)
	pool.mutex.Unlock()
	return found, exists
}
// getUser is the unlocked lookup used by GetUser and DeleteUser, which hold
// pool.mutex while calling it. It reports whether username is present in the
// pool identified by userPoolID.
func (pool UserPool) getUser(userPoolID UserPoolID, username Username) (*User, bool) {
	// A missing pool yields a nil UserMap, and indexing a nil map is a safe
	// miss. One chained lookup therefore covers both "absent" cases.
	user, present := pool.poolMap[userPoolID][username]
	return user, present
}
// CreateUser adds user to the pool identified by userPoolID, creating that
// pool if needed. It then persists all pools to poolFileName.
//
// It returns an error if the pool already has a user with the same Username,
// or if persisting fails. On a persistence failure the in-memory insert is
// rolled back, so memory matches the file and the call can be retried.
func (pool UserPool) CreateUser(userPoolID UserPoolID, user User) error {
	pool.mutex.Lock()
	defer pool.mutex.Unlock()
	uMap := pool.poolMap[userPoolID]
	// A nil map can come from an absent pool or a `null` entry in the file.
	// Either way it must be replaced before writing to it.
	createdPool := uMap == nil
	if createdPool {
		uMap = make(UserMap)
		pool.poolMap[userPoolID] = uMap
	}
	if _, ok := uMap[user.Username]; ok {
		return fmt.Errorf("already exist")
	}
	uMap[user.Username] = &user
	if err := pool.updateFile(); err != nil {
		// Undo the insert so a retry is not rejected as a duplicate.
		delete(uMap, user.Username)
		if createdPool {
			delete(pool.poolMap, userPoolID)
		}
		return err
	}
	return nil
}
// updateFile writes every user pool to poolFileName as JSON, replacing the
// previous contents. CreateUser and DeleteUser call it with pool.mutex held.
//
// The pools are encoded before the file is opened, so an encoding failure
// leaves the existing file intact. Errors from closing the file are returned
// rather than dropped. The write is not atomic: a crash during the write can
// still leave a partial file.
func (pool UserPool) updateFile() error {
	content, err := json.Marshal(pool.poolMap)
	if err != nil {
		return err
	}
	// 0666 (before umask) matches the permissions os.Create used previously.
	return ioutil.WriteFile(poolFileName, content, 0666)
}
// DeleteUser removes username from the pool identified by userPoolID, if
// present, and then persists all pools to poolFileName. Deleting a missing
// user is not an error, and the file is rewritten either way.
func (pool UserPool) DeleteUser(userPoolID UserPoolID, username Username) error {
	pool.mutex.Lock()
	defer pool.mutex.Unlock()
	// delete on a missing key is a no-op, so no separate existence check is
	// needed once the pool itself is found.
	if users, found := pool.poolMap[userPoolID]; found {
		delete(users, username)
	}
	return pool.updateFile()
}
// jwkKey is the JSON Web Key form of the RSA public key built by
// getPublicKey and served from the jwks.json endpoint. The field order below
// is the key order in the emitted JSON.
type jwkKey struct {
// Alg is the signing algorithm; getPublicKey sets RS256.
Alg string `json:"alg"`
// E is the base64url-encoded public exponent.
E string `json:"e"`
// Kid is the key ID; getPublicKey uses jwkKeyID, the same kid as token headers.
Kid string `json:"kid"`
// Kty is the key type; getPublicKey sets "RSA".
Kty string `json:"kty"`
// N is the base64url-encoded modulus.
N string `json:"n"`
// Use is the intended key use; getPublicKey sets "sig".
Use string `json:"use"`
}
// getPublicKey reads the PEM-encoded RSA public key from ./key.pem.pub.pkcs8.
// It returns the key serialized as a single JWK JSON object with kid jwkKeyID,
// alg RS256 and use "sig", ready to be embedded in a JWK set.
func getPublicKey() (string, error) {
	verifyBytes, err := ioutil.ReadFile("./key.pem.pub.pkcs8")
	if err != nil {
		return "", err
	}
	verifyKey, err := jwt.ParseRSAPublicKeyFromPEM(verifyBytes)
	if err != nil {
		return "", err
	}
	pubKey, err := json.Marshal(jwkKey{
		Alg: jwt.SigningMethodRS256.Alg(),
		E:   encodeJWKExponent(verifyKey.E),
		Kid: jwkKeyID,
		Kty: "RSA",
		// big.Int.Bytes is already minimal big-endian, as Base64urlUInt requires.
		N:   base64.RawURLEncoding.EncodeToString(verifyKey.N.Bytes()),
		Use: "sig",
	})
	if err != nil {
		return "", err
	}
	return string(pubKey), nil
}

// encodeJWKExponent encodes an RSA public exponent as a JWK Base64urlUInt
// (RFC 7518 section 6.3.1.2): big-endian using the minimum number of octets.
// The common exponent 65537 therefore becomes "AQAB", not "AAEAAQ".
func encodeJWKExponent(e int) string {
	buf := make([]byte, 4)
	binary.BigEndian.PutUint32(buf, uint32(e))
	// Drop leading zero octets, but keep at least one so e == 0 still encodes.
	for len(buf) > 1 && buf[0] == 0 {
		buf = buf[1:]
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}
// getPrivateKey loads the RSA private key used for signing tokens from the
// PEM file ./key.pem.
func getPrivateKey() (*rsa.PrivateKey, error) {
	pemBytes, readErr := ioutil.ReadFile("./key.pem")
	if readErr != nil {
		return nil, readErr
	}
	key, parseErr := jwt.ParseRSAPrivateKeyFromPEM(pemBytes)
	return key, parseErr
}
// route sends a Cognito JSON API call to its handler. xAmzTarget is the
// X-Amz-Target header value, split on '.'. Only the second segment (the
// operation name) is inspected; the service prefix is not validated. body is
// passed to the handler unchanged.
func route(xAmzTarget string, body []byte) ([]byte, error) {
	log.Println(xAmzTarget)
	segments := strings.Split(xAmzTarget, ".")
	if len(segments) < 2 {
		return nil, fmt.Errorf("invalid header")
	}
	handlers := map[string]func([]byte) ([]byte, error){
		"AdminInitiateAuth": adminInitiateAuth,
		"AdminCreateUser":   adminCreateUser,
	}
	handle, known := handlers[segments[1]]
	if !known {
		return nil, fmt.Errorf("invalid operation name")
	}
	return handle(body)
}
// adminInitiateAuth handles the AdminInitiateAuth operation for the
// server-side username/password flows, ADMIN_NO_SRP_AUTH and
// ADMIN_USER_PASSWORD_AUTH. Both flows pass USERNAME and PASSWORD in
// AuthParameters.
//
// body is a JSON-encoded AdminInitiateAuthInput, and ClientId must equal the
// configured clientID. On success it returns a JSON-encoded
// AdminInitiateAuthOutput holding a Bearer access token. The token is built by
// User.ToToken and signed with signKey; no ID or refresh token is issued.
//
// It returns an error for:
//   - malformed JSON
//   - an unsupported auth flow
//   - a client ID mismatch
//   - an unknown user
//   - a wrong password
//   - an unverified email
func adminInitiateAuth(body []byte) ([]byte, error) {
	var in cognitoidentityprovider.AdminInitiateAuthInput
	if err := json.Unmarshal(body, &in); err != nil {
		return nil, err
	}
	// TODO: SRP, refresh-token and custom auth flows are not implemented.
	if in.AuthFlow == nil {
		return nil, fmt.Errorf("invalid auth flow")
	}
	switch *in.AuthFlow {
	case cognitoidentityprovider.AuthFlowTypeAdminNoSrpAuth,
		cognitoidentityprovider.AuthFlowTypeAdminUserPasswordAuth:
		// Supported: both flows carry the same USERNAME/PASSWORD parameters.
	default:
		return nil, fmt.Errorf("invalid auth flow")
	}

	var cID ClientID
	if in.ClientId != nil {
		cID = ClientID(*in.ClientId)
	}
	if cID != clientID {
		return nil, fmt.Errorf("invalid client id")
	}

	var userPoolID UserPoolID
	if in.UserPoolId != nil {
		userPoolID = UserPoolID(*in.UserPoolId)
	}
	// Indexing a nil AuthParameters map is safe and yields nil.
	var username Username
	if p := in.AuthParameters["USERNAME"]; p != nil {
		username = Username(*p)
	}
	var password string
	if p := in.AuthParameters["PASSWORD"]; p != nil {
		password = *p
	}

	u, ok := userPool.GetUser(userPoolID, username)
	if !ok {
		return nil, fmt.Errorf("user not found")
	}
	// Check the password before reporting anything else about the account.
	// Otherwise a caller without the password could learn whether the
	// account's email is verified.
	if u.Password != password {
		return nil, fmt.Errorf("password not match")
	}
	if !u.EmailVerified {
		return nil, fmt.Errorf("email not verified")
	}

	tokenString, err := u.ToToken(userPoolID).SignedString(signKey)
	if err != nil {
		return nil, err
	}
	tokenType := "Bearer"
	return json.Marshal(cognitoidentityprovider.AdminInitiateAuthOutput{
		AuthenticationResult: &cognitoidentityprovider.AuthenticationResultType{
			AccessToken: &tokenString,
			TokenType:   &tokenType,
		},
	})
}
// adminCreateUser handles the AdminCreateUser operation. body is a
// JSON-encoded AdminCreateUserInput.
//
// With MessageAction "RESEND" it only checks that the user exists and returns
// an empty output.
//
// Otherwise it requires a non-empty Username and a non-empty "email"
// attribute. It then creates an email-verified user whose "sub" is a fresh
// UUID and stores it via userPool.CreateUser. The password is the request's
// TemporaryPassword when one is given, otherwise a random 8-character string.
// The returned output carries the Username and the "sub" and "email"
// attributes.
func adminCreateUser(body []byte) ([]byte, error) {
	var in cognitoidentityprovider.AdminCreateUserInput
	if err := json.Unmarshal(body, &in); err != nil {
		return nil, err
	}
	var userPoolID UserPoolID
	if in.UserPoolId != nil {
		userPoolID = UserPoolID(*in.UserPoolId)
	}
	var username Username
	if in.Username != nil {
		username = Username(*in.Username)
	}
	if in.MessageAction != nil && *in.MessageAction == "RESEND" {
		if _, exist := userPool.GetUser(userPoolID, username); !exist {
			return nil, fmt.Errorf("user not found")
		}
		// TODO: change the password and resend the notification email.
		return json.Marshal(cognitoidentityprovider.AdminCreateUserOutput{})
	}
	// Without this check a user keyed by the empty string could be created.
	if username == "" {
		return nil, fmt.Errorf("invalid username")
	}

	var email string
	for _, attr := range in.UserAttributes {
		// A JSON null element decodes to a nil pointer; skip it.
		if attr == nil || attr.Name == nil || attr.Value == nil {
			continue
		}
		if *attr.Name == "email" {
			email = *attr.Value
		}
	}
	// TODO: send a password email when email_verified is true, or a
	// verification email when it is false. For now every user is treated as
	// verified.
	if email == "" {
		return nil, fmt.Errorf("invalid email")
	}

	id, err := uuid.NewV4()
	if err != nil {
		return nil, err
	}
	// Use the caller's TemporaryPassword so the new user can sign in with a
	// known password. Fall back to a random one when none is given.
	var pass string
	if in.TemporaryPassword != nil {
		pass = *in.TemporaryPassword
	}
	if pass == "" {
		if pass, err = makeRandomStr(8); err != nil {
			return nil, err
		}
	}

	sub := id.String()
	err = userPool.CreateUser(userPoolID, User{
		UUID:          sub,
		Password:      pass,
		Username:      username,
		Email:         email,
		EmailVerified: true,
	})
	if err != nil {
		return nil, err
	}

	name := string(username)
	subAttr, emailAttr := "sub", "email"
	return json.Marshal(cognitoidentityprovider.AdminCreateUserOutput{
		User: &cognitoidentityprovider.UserType{
			Username: &name,
			Attributes: []*cognitoidentityprovider.AttributeType{
				{Name: &subAttr, Value: &sub},
				{Name: &emailAttr, Value: &email},
			},
		},
	})
}
// makeRandomStr returns a random string of digit characters. The characters
// are drawn uniformly from letters: ASCII letters, digits and "@&%/:;,.". The
// error result is always nil and is kept for existing callers.
//
// NOTE: math/rand is not a cryptographically secure source. Before Go 1.20 its
// global generator is also deterministically seeded unless rand.Seed is
// called. Treat the output as unsuitable for real secrets.
func makeRandomStr(digit uint32) (string, error) {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789@&%/:;,."
	out := make([]byte, digit)
	for i := range out {
		// rand.Intn is uniform. Reducing a random byte modulo len(letters)
		// (70) was not, because 256 is not a multiple of 70.
		out[i] = letters[rand.Intn(len(letters))]
	}
	return string(out), nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment