Cloudreve/pkg/cache/redis.go

203 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cache
import (
"bytes"
"encoding/gob"
"github.com/HFO4/cloudreve/pkg/util"
"github.com/gomodule/redigo/redis"
"strconv"
"time"
)
// RedisStore is the Redis storage driver for the cache package.
// Each method borrows a connection from pool and returns it (Close) when done.
type RedisStore struct {
	// pool hands out connections to the Redis server; it is configured by
	// NewRedisStore.
	pool *redis.Pool
}
// item wraps a cached value so that gob can encode an arbitrary interface{}.
// gob matches struct fields by name, so the field name Value is part of the
// stored encoding; renaming it would make existing entries undecodable.
type item struct {
	Value interface{}
}

// serializer gob-encodes value (wrapped in an item) and returns the bytes.
// The concrete type held in value must be known to gob — its built-in basic
// types, or a type registered with gob.Register — otherwise encoding fails.
func serializer(value interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	if err := gob.NewEncoder(out).Encode(item{Value: value}); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// deserializer reverses serializer: it gob-decodes value into an item and
// returns the wrapped Value. Empty or malformed input yields an error.
func deserializer(value []byte) (interface{}, error) {
	var wrapped item
	if err := gob.NewDecoder(bytes.NewReader(value)).Decode(&wrapped); err != nil {
		return nil, err
	}
	return wrapped.Value, nil
}
// NewRedisStore creates a RedisStore backed by a redigo connection pool.
//
// size is the maximum number of idle connections the pool keeps. Idle
// connections are closed after 240 seconds, and each borrowed connection is
// checked with PING first. network, address and password are passed to
// redis.Dial. database is parsed as a decimal database index every time a
// new connection is dialed, so an invalid value surfaces as a dial error
// rather than from this constructor.
func NewRedisStore(size int, network, address, password, database string) *RedisStore {
	dial := func() (redis.Conn, error) {
		dbIndex, parseErr := strconv.Atoi(database)
		if parseErr != nil {
			return nil, parseErr
		}
		conn, dialErr := redis.Dial(network, address,
			redis.DialDatabase(dbIndex),
			redis.DialPassword(password),
		)
		if dialErr != nil {
			util.Log().Warning("无法创建Redis连接%s", dialErr)
			return nil, dialErr
		}
		return conn, nil
	}

	ping := func(conn redis.Conn, _ time.Time) error {
		_, err := conn.Do("PING")
		return err
	}

	pool := &redis.Pool{
		MaxIdle:      size,
		IdleTimeout:  240 * time.Second,
		TestOnBorrow: ping,
		Dial:         dial,
	}
	return &RedisStore{pool: pool}
}
// Set stores value under key after gob-encoding it.
// When ttl > 0 the key is written with SETEX and expires after ttl seconds;
// otherwise it is written with SET and has no expiry.
func (store *RedisStore) Set(key string, value interface{}, ttl int) error {
	conn := store.pool.Get()
	defer conn.Close()

	payload, err := serializer(value)
	if err != nil {
		return err
	}
	if connErr := conn.Err(); connErr != nil {
		return connErr
	}

	cmd, args := "SET", []interface{}{key, payload}
	if ttl > 0 {
		cmd, args = "SETEX", []interface{}{key, ttl, payload}
	}
	_, err = conn.Do(cmd, args...)
	return err
}
// Get fetches and decodes the value stored under key.
// The boolean is false on any failure: connection error, missing key, or a
// payload that cannot be decoded.
func (store *RedisStore) Get(key string) (interface{}, bool) {
	conn := store.pool.Get()
	defer conn.Close()
	if conn.Err() != nil {
		return nil, false
	}

	raw, err := redis.Bytes(conn.Do("GET", key))
	if err == nil && raw != nil {
		if value, decodeErr := deserializer(raw); decodeErr == nil {
			return value, true
		}
	}
	return nil, false
}
// Gets fetches many keys at once with a single MGET. Each key is looked up
// as prefix+key.
//
// It returns the decoded hits, keyed by the unprefixed key, plus the
// unprefixed keys that were absent, undecodable, or decoded to nil. If the
// connection or the MGET command fails, it returns a nil map and reports
// every key as missed.
func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
	conn := store.pool.Get()
	defer conn.Close()
	if conn.Err() != nil {
		return nil, keys
	}

	args := make(redis.Args, 0, len(keys))
	for _, k := range keys {
		args = append(args, prefix+k)
	}

	replies, err := redis.ByteSlices(conn.Do("MGET", args...))
	if err != nil {
		return nil, keys
	}

	// MGET replies positionally, so replies[i] belongs to keys[i].
	hits := make(map[string]interface{})
	misses := make([]string, 0, len(keys))
	for i, raw := range replies {
		name := keys[i]
		if decoded, decodeErr := deserializer(raw); decodeErr == nil && decoded != nil {
			hits[name] = decoded
			continue
		}
		misses = append(misses, name)
	}
	return hits, misses
}
// Sets stores every entry of values in a single MSET. Each key is stored as
// prefix+key with its gob-encoded value. Entries are written without expiry.
//
// An empty values map is a no-op that returns nil. Redis rejects MSET with
// no arguments, so sending it would turn a harmless call into an error.
// If any value fails to encode, that error is returned and nothing is
// written.
func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error {
	if len(values) == 0 {
		return nil
	}

	// Encode everything before borrowing a pooled connection, so that an
	// encoding failure never holds a connection and never writes partially.
	args := make(redis.Args, 0, 2*len(values))
	for key, value := range values {
		serialized, err := serializer(value)
		if err != nil {
			return err
		}
		args = append(args, prefix+key, serialized)
	}

	rc := store.pool.Get()
	defer rc.Close()
	if err := rc.Err(); err != nil {
		return err
	}

	_, err := rc.Do("MSET", args...)
	return err
}
// Delete removes every key in keys in a single DEL command. Each key is
// addressed as prefix+key.
//
// The caller's keys slice is not modified. An empty keys slice is a no-op
// that returns nil, because Redis rejects DEL with no arguments.
func (store *RedisStore) Delete(keys []string, prefix string) error {
	if len(keys) == 0 {
		return nil
	}

	// Build the prefixed names in a fresh argument list, so the caller's
	// slice stays untouched and repeated calls with the same slice do not
	// stack prefixes.
	args := make(redis.Args, 0, len(keys))
	for _, key := range keys {
		args = append(args, prefix+key)
	}

	rc := store.pool.Get()
	defer rc.Close()
	if err := rc.Err(); err != nil {
		return err
	}

	_, err := rc.Do("DEL", args...)
	return err
}