Created
June 28, 2021 12:54
-
-
Save yutakahashi114/c2205e6feb06cad3bee470b665a0dd33 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"crypto/rsa" | |
"encoding/base64" | |
"encoding/binary" | |
"encoding/json" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"math/rand" | |
"net/http" | |
"os" | |
"strings" | |
"sync" | |
"time" | |
"github.com/aws/aws-sdk-go/service/cognitoidentityprovider" | |
"github.com/dgrijalva/jwt-go" | |
"github.com/go-chi/chi/v5" | |
"github.com/go-chi/chi/v5/middleware" | |
"github.com/go-chi/render" | |
"github.com/gofrs/uuid" | |
) | |
var userPool UserPool | |
var clientID ClientID | |
var signKey *rsa.PrivateKey | |
const jwkKeyID = "hoge" | |
const poolFileName = "pool.json" | |
func main() { | |
data, err := ioutil.ReadFile(poolFileName) | |
if err != nil { | |
panic(err) | |
} | |
poolMap := make(map[UserPoolID]UserMap) | |
err = json.Unmarshal(data, &poolMap) | |
if err != nil { | |
panic(err) | |
} | |
userPool = UserPool{ | |
poolMap: poolMap, | |
mutex: &sync.Mutex{}, | |
} | |
clientID = ClientID(os.Getenv("CLIENT_ID")) | |
signKey, err = getPrivateKey() | |
if err != nil { | |
panic(err) | |
} | |
pubKeyString, err := getPublicKey() | |
if err != nil { | |
panic(err) | |
} | |
mux := chi.NewRouter() | |
mux.Use(middleware.Logger) | |
mux.Get("/{userPoolID}/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { | |
w.Write([]byte(fmt.Sprintf(`{"keys":[%s]}`, pubKeyString))) | |
}) | |
mux.Post("/", func(w http.ResponseWriter, r *http.Request) { | |
b, err := ioutil.ReadAll(r.Body) | |
if err != nil { | |
render.JSON(w, r, http.StatusInternalServerError) | |
return | |
} | |
defer r.Body.Close() | |
out, err := route(r.Header.Get("X-Amz-Target"), b) | |
if err != nil { | |
log.Println(err) | |
render.JSON(w, r, http.StatusInternalServerError) | |
return | |
} | |
w.WriteHeader(http.StatusOK) | |
w.Write(out) | |
}) | |
log.Println(http.ListenAndServe(":80", mux)) | |
} | |
type User struct { | |
UUID string `json:"uuid"` | |
Password string `json:"password"` | |
Username Username `json:"username"` | |
Email string `json:"email"` | |
EmailVerified bool `json:"email_verified"` | |
} | |
func (u User) ToToken(userPoolID UserPoolID) *jwt.Token { | |
token := jwt.New(jwt.SigningMethodRS256) | |
token.Header["kid"] = jwkKeyID | |
claims := token.Claims.(jwt.MapClaims) | |
claims["iat"] = time.Now().Unix() | |
claims["exp"] = time.Now().Add(time.Hour * 24 * 365 * 10).Unix() | |
claims["sub"] = u.UUID | |
claims["email"] = u.Email | |
return token | |
} | |
type ClientID string | |
type Username string | |
type UserMap map[Username]*User | |
type UserPoolID string | |
type UserPool struct { | |
poolMap map[UserPoolID]UserMap | |
mutex *sync.Mutex | |
} | |
func (pool UserPool) GetUser(userPoolID UserPoolID, username Username) (*User, bool) { | |
pool.mutex.Lock() | |
defer pool.mutex.Unlock() | |
return pool.getUser(userPoolID, username) | |
} | |
func (pool UserPool) getUser(userPoolID UserPoolID, username Username) (*User, bool) { | |
uMap, ok := pool.poolMap[userPoolID] | |
if !ok { | |
return nil, false | |
} | |
u, ok := uMap[username] | |
if !ok { | |
return nil, false | |
} | |
return u, true | |
} | |
func (pool UserPool) CreateUser(userPoolID UserPoolID, user User) error { | |
pool.mutex.Lock() | |
defer pool.mutex.Unlock() | |
uMap, ok := pool.poolMap[userPoolID] | |
if !ok { | |
uMap = make(UserMap) | |
pool.poolMap[userPoolID] = uMap | |
} | |
if _, ok := uMap[user.Username]; ok { | |
return fmt.Errorf("already exist") | |
} | |
uMap[user.Username] = &user | |
return pool.updateFile() | |
} | |
func (pool UserPool) updateFile() error { | |
file, err := os.Create(poolFileName) | |
if err != nil { | |
return err | |
} | |
defer file.Close() | |
content, err := json.Marshal(pool.poolMap) | |
if err != nil { | |
return err | |
} | |
_, err = file.Write(content) | |
return err | |
} | |
func (pool UserPool) DeleteUser(userPoolID UserPoolID, username Username) error { | |
pool.mutex.Lock() | |
defer pool.mutex.Unlock() | |
if _, exist := pool.getUser(userPoolID, username); exist { | |
delete(pool.poolMap[userPoolID], username) | |
} | |
return pool.updateFile() | |
} | |
// jwkKey is the JSON Web Key object published in the JWKS document.
// Field order is kept so the marshaled key order stays the same.
type jwkKey struct {
	Alg string `json:"alg"` // signing algorithm, e.g. "RS256"
	E   string `json:"e"`   // base64url RSA public exponent
	Kid string `json:"kid"` // key id matched against the token header
	Kty string `json:"kty"` // key type, "RSA"
	N   string `json:"n"`   // base64url RSA modulus
	Use string `json:"use"` // intended use, "sig"
}
func getPublicKey() (string, error) { | |
verifyBytes, err := ioutil.ReadFile("./key.pem.pub.pkcs8") | |
if err != nil { | |
return "", err | |
} | |
verifyKey, err := jwt.ParseRSAPublicKeyFromPEM(verifyBytes) | |
if err != nil { | |
return "", err | |
} | |
encodedN := base64.RawURLEncoding.EncodeToString(verifyKey.N.Bytes()) | |
bytesE := make([]byte, 4) | |
binary.BigEndian.PutUint32(bytesE, uint32(verifyKey.E)) | |
encodedE := base64.RawURLEncoding.EncodeToString(bytesE) | |
pubKey, err := json.Marshal(jwkKey{ | |
Alg: jwt.SigningMethodRS256.Alg(), | |
E: encodedE, | |
Kid: jwkKeyID, | |
Kty: "RSA", | |
N: encodedN, | |
Use: "sig", | |
}) | |
if err != nil { | |
return "", err | |
} | |
return string(pubKey), nil | |
} | |
func getPrivateKey() (*rsa.PrivateKey, error) { | |
signBytes, err := ioutil.ReadFile("./key.pem") | |
if err != nil { | |
return nil, err | |
} | |
return jwt.ParseRSAPrivateKeyFromPEM(signBytes) | |
} | |
func route(xAmzTarget string, body []byte) ([]byte, error) { | |
log.Println(xAmzTarget) | |
targets := strings.Split(xAmzTarget, ".") | |
if len(targets) < 2 { | |
return nil, fmt.Errorf("invalid header") | |
} | |
switch targets[1] { | |
case "AdminInitiateAuth": | |
return adminInitiateAuth(body) | |
case "AdminCreateUser": | |
return adminCreateUser(body) | |
} | |
return nil, fmt.Errorf("invalid operation name") | |
} | |
func adminInitiateAuth(body []byte) ([]byte, error) { | |
in := cognitoidentityprovider.AdminInitiateAuthInput{} | |
err := json.Unmarshal(body, &in) | |
if err != nil { | |
return nil, err | |
} | |
// TODO: まだ ADMIN_NO_SRP_AUTH だけ | |
if in.AuthFlow == nil || *in.AuthFlow != cognitoidentityprovider.AuthFlowTypeAdminNoSrpAuth { | |
return nil, fmt.Errorf("invalid auth flow") | |
} | |
var cID ClientID | |
if in.ClientId != nil { | |
cID = ClientID(*in.ClientId) | |
} | |
if cID != clientID { | |
return nil, fmt.Errorf("invalid client id") | |
} | |
var userPoolID UserPoolID | |
if in.UserPoolId != nil { | |
userPoolID = UserPoolID(*in.UserPoolId) | |
} | |
var username Username | |
if u, ok := in.AuthParameters["USERNAME"]; ok && u != nil { | |
username = Username(*u) | |
} | |
u, ok := userPool.GetUser( | |
userPoolID, | |
username, | |
) | |
if !ok { | |
return nil, fmt.Errorf("user not found") | |
} | |
if !u.EmailVerified { | |
return nil, fmt.Errorf("email not verified") | |
} | |
var password string | |
if p, ok := in.AuthParameters["PASSWORD"]; ok && p != nil { | |
password = *p | |
} | |
if u.Password != password { | |
return nil, fmt.Errorf("password not match") | |
} | |
tokenString, err := u.ToToken(userPoolID).SignedString(signKey) | |
if err != nil { | |
return nil, err | |
} | |
return json.Marshal(cognitoidentityprovider.AdminInitiateAuthOutput{ | |
AuthenticationResult: &cognitoidentityprovider.AuthenticationResultType{ | |
AccessToken: &[]string{tokenString}[0], | |
}, | |
}) | |
} | |
func adminCreateUser(body []byte) ([]byte, error) { | |
in := cognitoidentityprovider.AdminCreateUserInput{} | |
err := json.Unmarshal(body, &in) | |
if err != nil { | |
return nil, err | |
} | |
var userPoolID UserPoolID | |
if in.UserPoolId != nil { | |
userPoolID = UserPoolID(*in.UserPoolId) | |
} | |
var username Username | |
if in.Username != nil { | |
username = Username(*in.Username) | |
} | |
if in.MessageAction != nil && *in.MessageAction == "RESEND" { | |
if _, exist := userPool.GetUser(userPoolID, username); !exist { | |
return nil, fmt.Errorf("user not found") | |
} | |
// TODO: パスワード変更して通知メール再送信 | |
return json.Marshal(cognitoidentityprovider.AdminCreateUserOutput{}) | |
} | |
var email string | |
for _, attr := range in.UserAttributes { | |
if attr.Name == nil || attr.Value == nil { | |
continue | |
} | |
if *attr.Name == "email" { | |
email = *attr.Value | |
} | |
} | |
// TODO: email_verified が true ならパスワードメール送信, false なら検証メール送信 | |
// 常にtrueとして扱っている | |
if email == "" { | |
return nil, fmt.Errorf("invalid email") | |
} | |
id, err := uuid.NewV4() | |
if err != nil { | |
return nil, err | |
} | |
pass, err := makeRandomStr(8) | |
if err != nil { | |
return nil, err | |
} | |
idString := id.String() | |
user := User{ | |
UUID: idString, | |
Password: pass, | |
Username: username, | |
Email: email, | |
EmailVerified: true, | |
} | |
err = userPool.CreateUser(userPoolID, user) | |
if err != nil { | |
return nil, err | |
} | |
return json.Marshal(cognitoidentityprovider.AdminCreateUserOutput{ | |
User: &cognitoidentityprovider.UserType{ | |
Attributes: []*cognitoidentityprovider.AttributeType{ | |
{ | |
Name: &[]string{"sub"}[0], | |
Value: &[]string{idString}[0], | |
}, | |
}, | |
}, | |
}) | |
} | |
// makeRandomStr returns a string of digit characters. Each character is drawn
// uniformly from ASCII letters, digits and the symbols "@&%/:;,.".
// The error result is always nil; it is kept for compatibility with callers.
//
// Each draw uses rand.Intn, which is uniform over the alphabet. The previous
// form, byte % len(letters), favored the first 256 % 70 = 46 characters.
//
// NOTE(review): math/rand is not a cryptographically secure generator, and
// before Go 1.20 its global source is deterministically seeded unless
// rand.Seed is called. This function mints user passwords, so switching the
// file to crypto/rand is recommended.
func makeRandomStr(digit uint32) (string, error) {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789@&%/:;,."
	var sb strings.Builder
	sb.Grow(int(digit))
	for i := uint32(0); i < digit; i++ {
		sb.WriteByte(letters[rand.Intn(len(letters))])
	}
	return sb.String(), nil
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment