Skip to content

Instantly share code, notes, and snippets.

@guotie
Created September 11, 2016 00:38
Show Gist options
  • Save guotie/c2a8000ec8bd706bc50706ca8e1e5e43 to your computer and use it in GitHub Desktop.
Save guotie/c2a8000ec8bd706bc50706ca8e1e5e43 to your computer and use it in GitHub Desktop.
package user
import (
"fmt"
"time"
"dxmall/utils"
"github.com/satori/go.uuid"
"github.com/smtc/glog"
"gopkg.in/redis.v3"
)
// TokenType is the default type ("Bearer") of generated access tokens.
const TokenType = "Bearer"

// Key prefixes for the per-user redis entries that map a (user id, client)
// pair to the token string currently issued to it. The full keys are built by
// accessTokenName and refreshTokenName.
const (
	AccessTokenPrefix  = "accesstoken-"
	RefreshTokenPrefix = "refreshtoken-"
)
// AccessToken describes an access token issued by GrantAccessToken.
// GrantAccessToken caches it in redis (via utils.SetCache) under the key
// Token, i.e. the token string itself is the lookup key.
type AccessToken struct {
// Client is the name of the client the token was issued to.
Client string
// UserId is the id of the user the token belongs to.
UserId int64
// Token is the token string handed out to the client; GrantAccessToken
// generates it from a random UUID.
Token string
// ExpiresIn is the token lifetime in seconds.
ExpiresIn int
// ExpiresAt is the issue time plus ExpiresIn seconds.
ExpiresAt time.Time
// Scope is the scope string the token was granted with; this file stores
// and copies it but does not interpret it.
Scope string
// TokenType is set to "Bearer" by GrantAccessToken.
TokenType string
}
// RefreshToken describes a refresh token created by NewRefreshToken.
// Unless ExpiresIn is 0, NewRefreshToken caches it in redis (via
// utils.SetCache) under the key Token, i.e. the token string itself.
type RefreshToken struct {
// Client is the name of the client the token was issued to; only that
// client may redeem it in RefreshAccessToken.
Client string
// UserId is the id of the user the token belongs to.
UserId int64
// Token is the token string handed out to the client; NewRefreshToken
// generates it from a random UUID.
Token string
// ExpiresIn is the token lifetime in seconds (0 means it is not stored).
ExpiresIn int
// ExpiresAt is the creation time plus ExpiresIn seconds.
ExpiresAt time.Time
// Scope is the scope originally granted; RefreshAccessToken reuses it when
// the caller does not request a scope.
Scope string
}
// accessTokenName builds the redis key that maps user uid on client to the
// access-token string currently issued to that pair, e.g.
// "accesstoken-42-web".
func accessTokenName(uid int64, client string) string {
	// fmt.Sprint only inserts spaces between two adjacent non-string operands,
	// so no separators are added here.
	return fmt.Sprint(AccessTokenPrefix, uid, "-", client)
}
// refreshTokenName builds the redis key that maps user uid on client to the
// refresh-token string currently issued to that pair, e.g.
// "refreshtoken-42-web".
func refreshTokenName(uid int64, client string) string {
	suffix := fmt.Sprintf("%d-%s", uid, client)
	return RefreshTokenPrefix + suffix
}
// GrantAccessToken issues a new access token for user uid on client and
// revokes the access token previously issued to the same (uid, client) pair,
// if there is one.
//
// Two redis entries are written, both expiring after expiresIn seconds:
//   - token string -> *AccessToken (via utils.SetCache), used to look the
//     token up;
//   - accessTokenName(uid, client) -> token string (raw SET), used to find the
//     pair's current token again, e.g. when revoking it.
//
// If the second write fails, the first one is rolled back (best effort) so no
// unreachable token is left behind. The superseded token is deleted only
// after the new mapping has been stored.
func GrantAccessToken(rc *redis.Client, client string, uid int64, expiresIn int, scope string) (*AccessToken, error) {
	ttl := time.Duration(expiresIn) * time.Second
	accessToken := &AccessToken{
		Client:    client,
		UserId:    uid,
		Token:     uuid.NewV4().String(),
		ExpiresIn: expiresIn,
		ExpiresAt: time.Now().Add(ttl),
		Scope:     scope,
		TokenType: TokenType,
	}

	name := accessTokenName(uid, client)
	// Remember the token currently mapped to this pair so it can be revoked.
	// Any lookup error (including redis.Nil when there is no previous token)
	// simply means there is nothing to revoke.
	oldToken, _ := rc.Get(name).Result()

	// token -> access token
	if err := utils.SetCache(rc, accessToken.Token, accessToken, expiresIn); err != nil {
		return nil, err
	}
	// (uid, client) -> token
	if err := rc.Set(name, accessToken.Token, ttl).Err(); err != nil {
		// Best-effort rollback: without the mapping, this token could never be
		// found again by RemoveTokens.
		rc.Del(accessToken.Token)
		return nil, err
	}

	// Revoke the superseded token; otherwise it would stay valid until expiry
	// even though nothing references it any more.
	if oldToken != "" && oldToken != accessToken.Token {
		rc.Del(oldToken)
	}
	return accessToken, nil
}
// NewRefreshToken creates a brand-new refresh token for user uid on client.
//
// When expiresIn is 0 the token is returned without being stored in redis,
// so later lookups by token (e.g. in RefreshAccessToken) will presumably not
// find it. Otherwise two entries are written with utils.SetCache, both
// expiring after expiresIn seconds:
//   - token string -> *RefreshToken;
//   - refreshTokenName(uid, client) -> token string.
//
// If the second write fails, the first one is rolled back (best effort) so no
// unreachable token is left behind.
func NewRefreshToken(rc *redis.Client, client string, uid int64, expiresIn int, scope string) (*RefreshToken, error) {
	refToken := &RefreshToken{
		Client:    client,
		UserId:    uid,
		Token:     uuid.NewV4().String(),
		ExpiresIn: expiresIn,
		ExpiresAt: time.Now().Add(time.Duration(expiresIn) * time.Second),
		Scope:     scope,
	}

	// A zero lifetime means the token is not persisted in redis.
	if expiresIn == 0 {
		return refToken, nil
	}

	// token -> refresh token
	if err := utils.SetCache(rc, refToken.Token, refToken, expiresIn); err != nil {
		return nil, err
	}
	// (uid, client) -> token
	if err := utils.SetCache(rc, refreshTokenName(uid, client), refToken.Token, expiresIn); err != nil {
		// Best-effort rollback of the half-written token.
		rc.Del(refToken.Token)
		return nil, err
	}
	return refToken, nil
}
// GrantRefreshToken returns the refresh token already issued to user uid on
// client if it is still stored in redis; otherwise it creates a new one with
// NewRefreshToken(rc, client, uid, expiresIn, scope).
//
// The per-pair entry refreshTokenName(uid, client) holds only the token
// string (NewRefreshToken writes refToken.Token there), so the lookup takes
// two steps: name -> token string -> *RefreshToken. A stored token is reused
// only if it still belongs to the same user and client; any lookup failure
// falls back to creating a new token.
func GrantRefreshToken(rc *redis.Client, client string, uid int64, expiresIn int, scope string) (*RefreshToken, error) {
	var token string
	if err := utils.GetCache(rc, refreshTokenName(uid, client), &token); err == nil && token != "" {
		var rt RefreshToken
		err := utils.GetCache(rc, token, &rt)
		if err == nil && rt.Token == token && rt.UserId == uid && rt.Client == client {
			glog.Info("GrantRefreshToken: uid=%d refreshToken=%s\n", uid, rt.Token)
			return &rt, nil
		}
	}
	return NewRefreshToken(rc, client, uid, expiresIn, scope)
}
// RefreshAccessToken redeems the refresh token refToken for a new access
// token whose lifetime is seconds seconds.
//
// The refresh token is loaded from redis and must have been issued to
// clientName; otherwise the mismatch is logged and an error is returned.
// When scope is empty, the refresh token's original scope is used.
// NOTE(review): a non-empty scope is passed through as-is and is not checked
// against the originally granted scope — confirm whether it should be
// restricted (RFC 6749 §6).
// NOTE(review): the error log below writes client names to the log; the
// original author flagged logging here as a possible security concern.
func RefreshAccessToken(rc *redis.Client, clientName, refToken string, scope string, seconds int) (*AccessToken, error) {
	var stored RefreshToken
	if err := utils.GetCache(rc, refToken, &stored); err != nil {
		return nil, err
	}

	// Only the client the refresh token was issued to may redeem it.
	if stored.Client != clientName {
		glog.Error("RefreshAccessToken: refresh client name %s NOT equal with client name %s\n",
			stored.Client, clientName)
		return nil, fmt.Errorf("refresh client name %s NOT equal with client name %s", stored.Client, clientName)
	}

	grantedScope := scope
	if grantedScope == "" {
		grantedScope = stored.Scope
	}

	at, err := GrantAccessToken(rc, stored.Client, stored.UserId, seconds, grantedScope)
	if err != nil {
		return nil, err
	}
	glog.Info("RefreshAccessToken: uid=%d accessToken=%s", at.UserId, at.Token)
	return at, nil
}
// RemoveToken deletes the redis key token. Both access tokens and refresh
// tokens are stored under their own token string as the key, so this works
// for either kind. The result of the DEL command is intentionally discarded.
func RemoveToken(rc *redis.Client, token string) {
	_ = rc.Del(token)
}
// GetAccessToken loads the access token stored in redis under the key token.
//
// On failure it returns a nil *AccessToken and a non-nil error. The returned
// value is also rejected when it is not an access token. Every token written
// by GrantAccessToken has TokenType == TokenType. A RefreshToken, which is
// stored in the same key space and has no TokenType field, decodes with an
// empty TokenType and is therefore refused.
func GetAccessToken(rc *redis.Client, token string) (*AccessToken, error) {
	var at AccessToken
	if err := utils.GetCache(rc, token, &at); err != nil {
		return nil, err
	}
	if at.TokenType != TokenType {
		// Do not echo the token itself: it is a credential.
		return nil, fmt.Errorf("stored token is not an access token (type %q)", at.TokenType)
	}
	return &at, nil
}
func GetRefreshTokenByAccessToken(rc *redis.Client, at *AccessToken) (string, error) {
name := refreshTokenName(at.UserId, at.Client)
// 这里得到的是refresh token的name
rt, err := rc.Get(name).Result()
return rt, err
}
// RemoveTokens revokes the access token token together with the refresh
// token issued to the same user and client.
//
// It deletes, in order:
//   - the access-token entry;
//   - the (uid, client) -> access token mapping;
//   - the refresh-token entry, only when it could be resolved;
//   - the (uid, client) -> refresh token mapping.
//
// Deletions are best effort and their errors are not reported. An error is
// returned only when token cannot be loaded as an access token.
func RemoveTokens(rc *redis.Client, token string) error {
	accessToken, err := GetAccessToken(rc, token)
	if err != nil {
		return err
	}
	glog.Info("RemoveTokens: uid=%d accessToken=%s\n",
		accessToken.UserId, token)

	// Delete the access token and its per-pair mapping.
	RemoveToken(rc, token)
	rc.Del(accessTokenName(accessToken.UserId, accessToken.Client))

	// Resolve the refresh token before its mapping is deleted below. Skip the
	// deletion when it cannot be resolved, rather than deleting the key "".
	rt, err := GetRefreshTokenByAccessToken(rc, accessToken)
	switch {
	case err != nil:
		glog.Error("GetRefreshTokenByAccessToken failed: %v\n", err)
	case rt != "":
		glog.Info("RemoveTokens: refreshToken=%s\n", rt)
		RemoveToken(rc, rt)
	}

	rc.Del(refreshTokenName(accessToken.UserId, accessToken.Client))
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment