Skip to content

Instantly share code, notes, and snippets.

@weedge
Last active December 28, 2022 16:21
Show Gist options
  • Save weedge/c2d830dceef4163acc6dd749a05493db to your computer and use it in GitHub Desktop.
Save weedge/c2d830dceef4163acc6dd749a05493db to your computer and use it in GitHub Desktop.
// ¥¥¥ show me the money ¥¥¥
// author: weedge
// desc: redis change asset tx demo
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"golang.org/x/sync/singleflight"
"log"
"sync"
"time"
"github.com/go-redis/redis/v8"
"github.com/go-redsync/redsync/v4"
redsyncredis "github.com/go-redsync/redsync/v4/redis"
"github.com/go-redsync/redsync/v4/redis/goredis/v8"
//"github.com/jaevor/go-nanoid"
"github.com/segmentio/ksuid"
)
// Sentinel errors returned by the asset cache / tx operations below.
var (
	ErrorNoEnoughAsset             = errors.New("asset not enough")
	ErrorSetAssetFromCBTimeOut     = errors.New("set asset from db exec timeout")
	ErrorSetAssetDoneOut           = errors.New("set asset from db exec done out")
	ErrorWatchAssetCasLoopMaxRetry = errors.New("watch user asset CAS loop reached maximum number of retries")
)
// Redis deployment flavors selectable in NewUserAssetCache.
const (
	RedisClientTypeReplica = 1 // single-node client (initRedisDefaultClient)
	RedisClientTypeCluster = 2 // cluster client (initRedisClusterDefaultClient)
)
// Key-name prefixes: "I." info caches, "L." distributed locks, "M." event messages.
const (
	RedisKeyPrefixUserAssetInfo     = "I.asset."
	RedisKeyPrefixGiftInfo          = "I.gift."
	RedisKeyPrefixUserInfo          = "I.user."
	RedisKeyPrefixRoomInfo          = "I.room."
	RedisKeyPrefixUserAssetInfoLock = "L.asset."
	RedisKeyPrefixAssetEventMsg     = "M.asset."
)
// TTLs for the keys above: info caches one day, gift info seven days,
// the fill lock 60s, event-message markers one day.
const (
	RedisKeyExpireUserAssetInfo     = 86400 * time.Second
	RedisKeyExpireGiftInfo          = 7 * 86400 * time.Second
	RedisKeyExpireUserInfo          = 86400 * time.Second
	RedisKeyExpireRoomInfo          = 86400 * time.Second
	RedisKeyExpireUserAssetInfoLock = 60 * time.Second
	RedisKeyExpireAssetEventMsg     = 86400 * time.Second
)
const (
	// DisLockerBlockRetryCn caps how many acquire attempts redsync makes for
	// the distributed fill lock before giving up.
	DisLockerBlockRetryCn = 100
	// MaxRetries Redis transactions use optimistic locking.
	// Upper bound on the WATCH/MULTI CAS retry loop in watchUserAssetChangeTx.
	MaxRetries = 1000
)
const (
	// TimeOutSetAssetFromCB bounds how long a caller of setAssetFromCallBack
	// waits for the shared singleflight cache fill to finish.
	TimeOutSetAssetFromCB = 3 * time.Second
)
// getUserAssetInfoKey builds the cache key for one user's asset of a given
// type; the {...} hash tags pin related keys to redis-cluster slots.
func getUserAssetInfoKey(userId int64, assetType int) string {
	k := fmt.Sprintf("%s{%d}.{%d}", RedisKeyPrefixUserAssetInfo, userId, assetType)
	return k
}
// getGiftInfoKey builds the cache key for a gift's info record.
func getGiftInfoKey(giftId int64) string {
	k := fmt.Sprintf("%s{%d}", RedisKeyPrefixGiftInfo, giftId)
	return k
}
// getUserInfoKey builds the cache key for a user's info record.
func getUserInfoKey(userId int64) string {
	k := fmt.Sprintf("%s{%d}", RedisKeyPrefixUserInfo, userId)
	return k
}
// getRoomInfoKey builds the cache key for a room's info record.
func getRoomInfoKey(roomId int64) string {
	k := fmt.Sprintf("%s{%d}", RedisKeyPrefixRoomInfo, roomId)
	return k
}
// getUserAssetInfoLockKey builds the distributed-lock key used while filling
// a user's asset cache; tag distinguishes independent lock scopes.
func getUserAssetInfoLockKey(userId int64, tag string) string {
	k := fmt.Sprintf("%s{%d}.%s", RedisKeyPrefixUserAssetInfoLock, userId, tag)
	return k
}
// getUserAssetEventMsgKey builds the per-event marker key written alongside
// each asset change (eventId is unique per operation).
func getUserAssetEventMsgKey(opUserId int64, eventId string) string {
	k := fmt.Sprintf("%s{%d}.%s", RedisKeyPrefixAssetEventMsg, opUserId, eventId)
	return k
}
//var once sync.Once
//var instance redis.UniversalClient
//todo: initRedisClusterClient once init instance by config
// initRedisClusterDefaultClient init default instance
// initRedisClusterDefaultClient builds a redis cluster client with this
// demo's default connection settings (single local node, random routing).
func initRedisClusterDefaultClient() redis.UniversalClient {
	opts := &redis.ClusterOptions{
		//Addrs: []string{":26379", ":26380", ":26381", ":26382", ":26383", ":26384"},
		Addrs:    []string{":26379"},
		Username: "",
		Password: "",

		// command retry/backoff
		MaxRetries:      3,
		MinRetryBackoff: 3 * time.Second,
		MaxRetryBackoff: 5 * time.Second,

		// per-command deadlines
		DialTimeout:  5 * time.Second,
		ReadTimeout:  3 * time.Second,
		WriteTimeout: 3 * time.Second,

		// connect pool
		PoolSize:           100,
		MinIdleConns:       10,
		MaxConnAge:         60 * time.Second,
		PoolTimeout:        5 * time.Second,
		IdleTimeout:        30 * time.Second,
		IdleCheckFrequency: 3 * time.Second,

		// To route commands by latency or randomly, enable one of the following.
		//RouteByLatency: true,
		RouteRandomly: true,
	}
	return redis.NewClusterClient(opts)
}
// initRedisDefaultClient builds a single-node client against localhost:6379,
// default DB, no auth.
func initRedisDefaultClient() redis.UniversalClient {
	opt := &redis.Options{}
	opt.Addr = "localhost:6379"
	opt.Password = "" // no password set
	opt.DB = 0        // use default DB
	return redis.NewClient(opt)
}
// initRedsync creates a redsync instance (used to obtain distributed mutual
// exclusion) backed by one connection pool per supplied redis client.
func initRedsync(clients ...redis.UniversalClient) *redsync.Redsync {
	pools := make([]redsyncredis.Pool, 0, len(clients))
	for _, c := range clients {
		pools = append(pools, goredis.NewPool(c))
	}
	return redsync.New(pools...)
}
// IAssetCallBack loads a user's asset from the backing source of truth
// (presumably a DB — the demo mocks it) when the redis cache is cold.
type IAssetCallBack interface {
	getAsset(ctx context.Context) (assetDto *UserAssetDto, err error)
}
// UserAssetDto is the JSON document cached per user + assetType in redis.
type UserAssetDto struct {
	AssetCn   int   `json:"assetCn"`   // current asset balance
	AssetType int   `json:"assetType"` // asset category
	UserId    int64 `json:"userId"`
}
// getAsset implements IAssetCallBack with mock data: it hands back a detached
// copy of m so callers cannot mutate the original.
func (m *UserAssetDto) getAsset(ctx context.Context) (assetDto *UserAssetDto, err error) {
	// mock data source
	cp := *m
	return &cp, nil
}
// assetIncrHandler computes the (possibly negative) delta to apply to an
// asset balance for one operation.
type assetIncrHandler func(ctx context.Context) (incrAssetCn int)

// UserAssetCache caches user asset info in redis, collapsing cold-cache fills
// with singleflight (in-process) and a redsync lock (cross-process).
type UserAssetCache struct {
	redisClient    redis.UniversalClient
	getAssetLocker *redsync.Redsync
	sfGroup        *singleflight.Group
}
// NewUserAssetCache new asset cache obj with option: redisType selects the
// backing client and locker supplies the distributed lock used on cache fill.
// NOTE(review): an unrecognized redisType leaves redisClient nil.
func NewUserAssetCache(redisType int, locker *redsync.Redsync) (obj *UserAssetCache) {
	obj = &UserAssetCache{
		getAssetLocker: locker,
		sfGroup:        &singleflight.Group{},
	}
	if redisType == RedisClientTypeReplica {
		obj.redisClient = initRedisDefaultClient()
	} else if redisType == RedisClientTypeCluster {
		obj.redisClient = initRedisClusterDefaultClient()
	}
	return
}
// getAsset reads the JSON document at key and unmarshals it into a
// UserAssetDto. A missing key surfaces as redis.Nil from the client.
func (m *UserAssetCache) getAsset(ctx context.Context, key string) (assetObj *UserAssetDto, err error) {
	assetObj = &UserAssetDto{}
	var raw []byte
	if raw, err = m.redisClient.Get(ctx, key).Bytes(); err != nil {
		return
	}
	err = json.Unmarshal(raw, assetObj)
	return
}
// setAsset marshals val and stores it at key with the standard asset TTL.
func (m *UserAssetCache) setAsset(ctx context.Context, key string, val *UserAssetDto) (err error) {
	var payload []byte
	if payload, err = json.Marshal(*val); err != nil {
		return
	}
	return m.redisClient.Set(ctx, key, payload, RedisKeyExpireUserAssetInfo).Err()
}
// setAssetFromCallBack ensures the asset JSON at key exists in redis. If the
// key is already cached it returns nil immediately; otherwise it loads the
// value from cb and writes it, collapsing concurrent fills in this process via
// singleflight and across processes via a redsync lock on lockerKey.
// Returns ErrorSetAssetFromCBTimeOut when the fill takes longer than
// TimeOutSetAssetFromCB and ErrorSetAssetDoneOut when ctx finishes first; in
// both cases the singleflight goroutine may still complete in the background.
func (m *UserAssetCache) setAssetFromCallBack(ctx context.Context, key, lockerKey string, cb IAssetCallBack) (err error) {
	// Fast path: already cached.
	_, err = m.getAsset(ctx, key)
	if err == nil {
		return
	}
	// Anything other than "key missing" is a real redis error; bail out.
	if err != nil && err != redis.Nil {
		return
	}
	// _, err, _ = m.sfGroup.Do(key, func() (data interface{}, err error) {
	// DoChan (instead of Do) lets the select below bound the wait.
	resCh := m.sfGroup.DoChan(key, func() (data interface{}, err error) {
		// Create a new lock client.
		// Obtain a new mutex by using the same name for all instances wanting the
		// same lock.
		mutex := m.getAssetLocker.NewMutex(lockerKey, redsync.WithTries(DisLockerBlockRetryCn), redsync.WithExpiry(RedisKeyExpireUserAssetInfoLock))
		// Obtain a lock for our given mutex. After this is successful, no one else
		// can obtain the same lock (the same mutex name) until we unlock it.
		if err = mutex.LockContext(ctx); err != nil {
			return
		}
		defer func() {
			// Release the lock so other processes or threads can obtain a lock.
			// (err here deliberately shadows the outer err: an unlock failure
			// is logged, not returned.)
			if ok, err := mutex.UnlockContext(ctx); !ok || err != nil {
				log.Printf("[error] key %s unlock error:%+v", lockerKey, err)
			}
		}()
		// Re-check under the lock: another process may have filled the key
		// while we waited to acquire it.
		data, err = m.getAsset(ctx, key)
		if err == nil {
			return
		}
		if err != nil && err != redis.Nil {
			return
		}
		// Still missing: load from the callback and cache it.
		data, err = cb.getAsset(ctx)
		if err != nil {
			return
		}
		err = m.setAsset(ctx, key, data.(*UserAssetDto))
		if err != nil {
			return
		}
		log.Printf("set asset key %s val %+v is ok\n", key, data)
		return
	})
	// Wait for the shared fill, a timeout, or caller cancellation —
	// whichever happens first.
	select {
	case <-ctx.Done():
		return ErrorSetAssetDoneOut
	case <-time.After(TimeOutSetAssetFromCB):
		return ErrorSetAssetFromCBTimeOut
	case res := <-resCh:
		return res.Err
	}
	//return
}
// watchUserAssetChangeTx applies an asset delta with redis optimistic locking
// (WATCH + MULTI/EXEC): it warms the cache from cb if needed, then retries a
// read-modify-write transaction on key until it commits, the balance would go
// negative (ErrorNoEnoughAsset), MaxRetries is exhausted
// (ErrorWatchAssetCasLoopMaxRetry), or another error occurs.
// eventMsgKey is written in the same transaction as a per-operation marker.
func (m *UserAssetCache) watchUserAssetChangeTx(ctx context.Context, key, lockerKey, eventMsgKey string, cb IAssetCallBack, handle assetIncrHandler) error {
	// Make sure the asset is cached before starting the CAS loop.
	err := m.setAssetFromCallBack(ctx, key, lockerKey, cb)
	if err != nil {
		return err
	}
	// Transactional function.
	txf := func(tx *redis.Tx) error {
		bytes, err := tx.Get(ctx, key).Bytes()
		if err != nil {
			// Fix: redis.Nil previously slipped through to json.Unmarshal(nil),
			// which failed with a misleading empty-JSON error. Surface the
			// redis error (including a vanished key) directly instead.
			return err
		}
		assetObj := &UserAssetDto{}
		if err = json.Unmarshal(bytes, assetObj); err != nil {
			return err
		}
		// Actual operation (local in optimistic lock).
		assetObj.AssetCn += handle(ctx)
		if assetObj.AssetCn < 0 {
			return ErrorNoEnoughAsset
		}
		bytes, err = json.Marshal(assetObj)
		if err != nil {
			return err
		}
		// Operation is committed only if the watched keys remain unchanged.
		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			pipe.Set(ctx, key, bytes, RedisKeyExpireUserAssetInfo)
			// Fix: use the event-message TTL (was RedisKeyExpireUserAssetInfo),
			// consistent with the lua path in userAssetChangeLuaAtomicTx.
			pipe.Set(ctx, eventMsgKey, 1, RedisKeyExpireAssetEventMsg)
			return nil
		})
		return err
	}
	// Retry if the watched key has been changed by a concurrent writer.
	for i := 0; i < MaxRetries; i++ {
		err := m.redisClient.Watch(ctx, txf, key)
		if err == nil {
			// Success.
			return nil
		}
		if err == redis.TxFailedErr {
			// Optimistic lock lost. Retry.
			continue
		}
		// Return any other error.
		return err
	}
	return ErrorWatchAssetCasLoopMaxRetry
}
var (
	// assetStringChangeLua atomically applies an asset delta inside redis.
	//   KEYS[1] = asset info key (JSON string or hash with field "assetCn")
	//   KEYS[2] = event message key
	//   ARGV[1] = increment (may be negative)
	//   ARGV[2] = asset key TTL seconds; ARGV[3] = event msg TTL seconds
	// Returns 1 on success, 0 when the writes did not complete, -1 when the
	// asset key does not exist, -2 when the balance would go negative.
	// NOTE: the script body is the EVAL payload — keep it byte-stable.
	assetStringChangeLua = `
if redis.call("exists", KEYS[1]) ~= 1 then
return -1
end
if redis.call("type", KEYS[1]).ok=="string" then
local assetStr = redis.call('get',KEYS[1]);
local assetInfo = cjson.decode(assetStr);
if assetInfo.assetCn == nil then
assetInfo.assetCn = 0
end
local incr = 0
if ARGV[1] ~= nil then
incr = tonumber(ARGV[1]);
end
if assetInfo.assetCn+incr < 0 then
return -2;
end
assetInfo.assetCn = assetInfo.assetCn+incr;
assetStr = cjson.encode(assetInfo);
if redis.call('set',KEYS[1],assetStr,'ex',tonumber(ARGV[2]))
and redis.call('set',KEYS[2],1,'ex',tonumber(ARGV[3])) then
return 1;
end
return 0;
end
if redis.call("type", KEYS[1]).ok=="hash" then
local assetCn = redis.call('hincrby',KEYS[1],'assetCn',0);
local incr = 0
if ARGV[1] ~= nil then
incr = tonumber(ARGV[1]);
end
if assetCn+incr < 0 then
return -2;
end
assetCn = assetCn+incr;
if redis.call('hincrby',KEYS[1],'assetCn',incr)
and redis.call('expire',KEYS[1],tonumber(ARGV[2]))
and redis.call('set',KEYS[2],1,'ex',tonumber(ARGV[3])) then
return 1;
end
return 0;
end
`
)
// Result codes returned by assetStringChangeLua.
const (
	redisLuaAssetChangeResCodeSuccess  = 1  // delta applied, keys written
	redisLuaAssetChangeResCodeNoDone   = 0  // script ran but writes did not complete
	redisLuaAssetChangeResCodeNoExists = -1 // asset key missing; warm the cache first
	redisLuaAssetChangeResCodeNoEnough = -2 // balance would drop below zero
)
// userAssetChangeLuaAtomicTx applies an asset delta atomically via the
// assetStringChangeLua script. When the asset key does not exist it warms the
// cache through cb (guarded by the distributed lock at lockerKey) and runs
// the script once more. Returns ErrorNoEnoughAsset when the balance would go
// negative, and an error for any other non-success result code.
func (m *UserAssetCache) userAssetChangeLuaAtomicTx(ctx context.Context, key string, lockerKey, eventMsgKey string, cb IAssetCallBack, handle assetIncrHandler) error {
	eval := func() (int64, error) {
		res, err := m.redisClient.Eval(ctx, assetStringChangeLua, []string{key, eventMsgKey},
			handle(ctx), RedisKeyExpireUserAssetInfo.Seconds(), RedisKeyExpireAssetEventMsg.Seconds()).Result()
		if err != nil {
			return 0, err
		}
		// Fix: go-redis delivers lua integer replies as int64 inside an
		// interface{}; the previous direct comparison against untyped int
		// constants (dynamic type int) could never match, so every result
		// branch was dead. Type-assert before comparing.
		code, ok := res.(int64)
		if !ok {
			return 0, fmt.Errorf("unexpected lua reply type %T: %v", res, res)
		}
		return code, nil
	}
	code, err := eval()
	if err != nil {
		return err
	}
	if code == redisLuaAssetChangeResCodeNoExists {
		// Key missing: fill the cache from the callback, then retry once.
		if err = m.setAssetFromCallBack(ctx, key, lockerKey, cb); err != nil {
			return err
		}
		if code, err = eval(); err != nil {
			return err
		}
	}
	switch code {
	case redisLuaAssetChangeResCodeSuccess:
		return nil
	case redisLuaAssetChangeResCodeNoEnough:
		return ErrorNoEnoughAsset
	default:
		// Fix: the retry's result code was previously ignored and code 0
		// ("not done") silently returned nil; report it instead.
		return fmt.Errorf("asset change lua not applied, result code %d", code)
	}
}
// main wires up a cluster client and stress-tests the three asset paths in
// order: concurrent cache fill, WATCH/MULTI optimistic transactions (+500
// each), and the lua atomic script (-500 each).
func main() {
	client := initRedisClusterDefaultClient()
	UserAssetCacheObj := NewUserAssetCache(RedisClientTypeCluster, initRedsync(client))
	ctx := context.Background()
	userId := int64(100)
	assetType := 1
	key := getUserAssetInfoKey(userId, assetType)
	lockerKey := getUserAssetInfoLockKey(userId, "gift")
	// Seed value handed to the cache-fill callback (mock data source).
	dto := &UserAssetDto{AssetCn: 10000, AssetType: assetType, UserId: userId}
	obj, err := UserAssetCacheObj.getAsset(ctx, key)
	if err != nil && err != redis.Nil {
		log.Fatal(key, err)
	}
	log.Printf("start: %s,%+v\n", key, obj)
	// 10000 concurrent fills should collapse to a single SET thanks to
	// singleflight plus the distributed lock.
	concurrencyTest(10000, "concurrencyTestSetAssetFromCallBack", func() {
		err := UserAssetCacheObj.setAssetFromCallBack(ctx, key, lockerKey, dto)
		if err != nil {
			log.Fatal(key, lockerKey, err)
		}
	})
	obj, err = UserAssetCacheObj.getAsset(ctx, key)
	if err != nil && err != redis.Nil {
		log.Fatal(key, err)
	}
	log.Printf("before: %s,%+v\n", key, obj)
	//// for short url case, faster and shorter than uuid
	//decenaryID, err := nanoid.CustomASCII("0123456789", 12)
	//if err != nil {
	//	log.Fatal(err)
	//}
	// 100 concurrent +500 changes through the WATCH/MULTI CAS loop.
	concurrencyTest(100, "concurrencyTestWatchUserAssetChangeTx", func() {
		ctx := context.Background()
		//eventId := decenaryID()
		// use ksuid generate id with generate time
		eventId := ksuid.New().String()
		eventMsgKey := getUserAssetEventMsgKey(userId, eventId)
		log.Println("eventKey", eventMsgKey)
		err := UserAssetCacheObj.watchUserAssetChangeTx(ctx, key, lockerKey, eventMsgKey, dto, func(ctx context.Context) (incrAssetCn int) {
			return 500
			//return 1
		})
		if err != nil {
			log.Printf("[error] lockerKey:%s eventMsgKey:%s %s\n", lockerKey, eventMsgKey, err.Error())
		}
	})
	obj, err = UserAssetCacheObj.getAsset(ctx, key)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("after: %s,%+v\n", key, obj)
	// 100 concurrent -500 changes through the lua atomic path.
	concurrencyTest(100, "concurrencyTestUserAssetChangeLuaAtomicTx", func() {
		ctx := context.Background()
		//eventId := decenaryID()
		// use ksuid generate id with generate time
		eventId := ksuid.New().String()
		eventMsgKey := getUserAssetEventMsgKey(userId, eventId)
		log.Println("eventKey", eventMsgKey)
		err := UserAssetCacheObj.userAssetChangeLuaAtomicTx(ctx, key, lockerKey, eventMsgKey, dto, func(ctx context.Context) (incrAssetCn int) {
			return -500
			//return 1
		})
		if err != nil {
			log.Printf("[error] lockerKey:%s eventMsgKey:%s %s\n", lockerKey, eventMsgKey, err.Error())
		}
	})
	obj, err = UserAssetCacheObj.getAsset(ctx, key)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("after: %s,%+v\n", key, obj)
	if err = client.Close(); err != nil {
		log.Fatal(err)
	}
}
// concurrencyTest runs handle from n goroutines at once, waits for all of
// them to finish, and prints the batch's wall-clock cost in microseconds.
func concurrencyTest(n int, name string, handle func()) {
	startTime := time.Now()
	wg := &sync.WaitGroup{}
	wg.Add(n)
	for i := 0; i < n; i++ {
		// Fix: the goroutine previously took an `i` argument it never used.
		go func() {
			defer wg.Done()
			handle()
		}()
	}
	wg.Wait()
	// time.Since is the idiomatic form of time.Now().Sub(startTime).
	println(name, "concurrency", n, "cost", time.Since(startTime).Microseconds(), "microsecond")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment