diff --git a/models/policy.go b/models/policy.go index 5f1c46f..25ef899 100644 --- a/models/policy.go +++ b/models/policy.go @@ -1,6 +1,7 @@ package model import ( + "encoding/gob" "encoding/json" "github.com/HFO4/cloudreve/pkg/cache" "github.com/HFO4/cloudreve/pkg/util" @@ -42,11 +43,16 @@ type PolicyOption struct { RangeTransferEnabled bool `json:"range_transfer_enabled"` } +func init() { + // 注册缓存用到的复杂结构 + gob.Register(Policy{}) +} + // GetPolicyByID 用ID获取存储策略 func GetPolicyByID(ID interface{}) (Policy, error) { // 尝试读取缓存 cacheKey := "policy_" + strconv.Itoa(int(ID.(uint))) - if policy, ok := cache.Store.Get(cacheKey); ok { + if policy, ok := cache.Get(cacheKey); ok { return policy.(Policy), nil } @@ -55,7 +61,7 @@ func GetPolicyByID(ID interface{}) (Policy, error) { // 写入缓存 if result.Error == nil { - _ = cache.Store.Set(cacheKey, policy) + _ = cache.Set(cacheKey, policy) } return policy, result.Error diff --git a/models/setting.go b/models/setting.go index 1e6f783..268ceee 100644 --- a/models/setting.go +++ b/models/setting.go @@ -13,14 +13,6 @@ type Setting struct { Value string `gorm:"size:‎65535"` } -// settingCache 设置项缓存 -var settingCache = make(map[string]string) - -// ClearCache 清空设置缓存 -func ClearCache() { - settingCache = make(map[string]string) -} - // IsTrueVal 返回设置的值是否为真 func IsTrueVal(val string) bool { return val == "1" || val == "true" @@ -32,13 +24,13 @@ func GetSettingByName(name string) string { // 优先从缓存中查找 cacheKey := "setting_" + name - if optionValue, ok := cache.Store.Get(cacheKey); ok { + if optionValue, ok := cache.Get(cacheKey); ok { return optionValue.(string) } // 尝试数据库中查找 result := DB.Where("name = ?", name).First(&setting) if result.Error == nil { - _ = cache.Store.Set(cacheKey, setting.Value) + _ = cache.Set(cacheKey, setting.Value) return setting.Value } return "" @@ -48,13 +40,14 @@ func GetSettingByName(name string) string { // TODO 其他设置获取也使用缓存 func GetSettingByNames(names []string) map[string]string { var queryRes []Setting - res := make(map[string]string) + res, miss := cache.GetsSettingByName(names) - DB.Where("name IN (?)", names).Find(&queryRes) + DB.Where("name IN (?)", miss).Find(&queryRes) for _, setting := range queryRes { res[setting.Name] = setting.Value } + _ = cache.SetSettings(res) return res } diff --git a/models/setting_test.go b/models/setting_test.go index ae3fd8a..6c5afb5 100644 --- a/models/setting_test.go +++ b/models/setting_test.go @@ -73,11 +73,11 @@ func TestGetSettingByNames(t *testing.T) { //找到其中一个设置时 rows = sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic") + AddRow("siteName2", "Cloudreve", "basic") mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByNames([]string{"siteName", "siteDes2333"}) + settings = GetSettingByNames([]string{"siteName2", "siteDes2333"}) asserts.Equal(map[string]string{ - "siteName": "Cloudreve", + "siteName2": "Cloudreve", }, settings) asserts.NoError(mock.ExpectationsWereMet()) @@ -87,6 +87,17 @@ func TestGetSettingByNames(t *testing.T) { settings = GetSettingByNames([]string{"siteName2333", "siteDes2333"}) asserts.Equal(map[string]string{}, settings) asserts.NoError(mock.ExpectationsWereMet()) + + // 一个设置命中缓存 + mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WithArgs("siteDes2").WillReturnRows(sqlmock.NewRows([]string{"name", "value", "type"}). + AddRow("siteDes2", "Cloudreve2", "basic")) + settings = GetSettingByNames([]string{"siteName", "siteDes2"}) + asserts.Equal(map[string]string{ + "siteName": "Cloudreve", + "siteDes2": "Cloudreve2", + }, settings) + asserts.NoError(mock.ExpectationsWereMet()) + } // TestGetSettingByName 测试GetSettingByName diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index eb3bd8b..0f0ff64 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -1,10 +1,63 @@ package cache +import ( + "github.com/HFO4/cloudreve/pkg/conf" + "github.com/gin-gonic/gin" +) + // Store 缓存存储器 -var Store Driver = NewMemoStore() +var Store Driver + +func init() { + Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") + return + + if conf.RedisConfig.Server == "" || gin.Mode() == gin.TestMode { + Store = NewMemoStore() + } else { + Store = NewRedisStore(10, "tcp", conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB) + } +} // Driver 键值缓存存储容器 type Driver interface { + // 设置值 Set(key string, value interface{}) error + // 取值 Get(key string) (interface{}, bool) + // 批量取值,返回成功取值的map即不存在的值 + Gets(keys []string, prefix string) (map[string]interface{}, []string) + // 批量设置值 + Sets(values map[string]interface{}, prefix string) error +} + +// Set 设置缓存值 +func Set(key string, value interface{}) error { + return Store.Set(key, value) +} + +// Get 获取缓存值 +func Get(key string) (interface{}, bool) { + return Store.Get(key) +} + +// GetsSettingByName 根据名称批量获取设置项缓存 +func GetsSettingByName(keys []string) (map[string]string, []string) { + raw, miss := Store.Gets(keys, "setting_") + + res := make(map[string]string, len(raw)) + for k, v := range raw { + res[k] = v.(string) + } + + return res, miss +} + +// SetSettings 批量设置站点设置缓存 +func SetSettings(values map[string]string) error { + var toBeSet = make(map[string]interface{}, len(values)) + for key, value := range values { + toBeSet[key] = interface{}(value) + } + return Store.Sets(toBeSet, "setting_") } diff --git a/pkg/cache/memo.go b/pkg/cache/memo.go index fca7a20..858665e 100644 --- a/pkg/cache/memo.go +++ b/pkg/cache/memo.go @@ -24,3 +24,27 @@ func (store *MemoStore) Set(key string, value interface{}) error { func (store *MemoStore) Get(key string) (interface{}, bool) { return store.Store.Load(key) } + +// Gets 批量取值 +func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) { + var res = make(map[string]interface{}) + var notFound = make([]string, 0, len(keys)) + + for _, key := range keys { + if value, ok := store.Store.Load(prefix + key); ok { + res[key] = value + } else { + notFound = append(notFound, key) + } + } + + return res, notFound +} + +// Sets 批量设置值 +func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error { + for key, value := range values { + store.Store.Store(prefix+key, value) + } + return nil +} diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go index 4a59b04..2f0be0b 100644 --- a/pkg/cache/memo_test.go +++ b/pkg/cache/memo_test.go @@ -59,3 +59,50 @@ func TestMemoStore_Get(t *testing.T) { } } + +func TestMemoStore_Gets(t *testing.T) { + asserts := assert.New(t) + store := NewMemoStore() + + err := store.Set("1", "1,val") + err = store.Set("2", "2,val") + err = store.Set("3", "3,val") + err = store.Set("4", "4,val") + asserts.NoError(err) + + // 全部命中 + { + values, miss := store.Gets([]string{"1", "2", "3", "4"}, "") + asserts.Len(values, 4) + asserts.Len(miss, 0) + } + + // 命中一半 + { + values, miss := store.Gets([]string{"1", "2", "9", "10"}, "") + asserts.Len(values, 2) + asserts.Equal([]string{"9", "10"}, miss) + } +} + +func TestMemoStore_Sets(t *testing.T) { + asserts := assert.New(t) + store := NewMemoStore() + + err := store.Sets(map[string]interface{}{ + "1": "1.val", + "2": "2.val", + "3": "3.val", + "4": "4.val", + }, "test_") + asserts.NoError(err) + + vals, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_") + asserts.Len(miss, 0) + asserts.Equal(map[string]interface{}{ + "1": "1.val", + "2": "2.val", + "3": "3.val", + "4": "4.val", + }, vals) +} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go new file mode 100644 index 0000000..be30744 --- /dev/null +++ b/pkg/cache/redis.go @@ -0,0 +1,163 @@ +package cache + +import ( + "bytes" + "encoding/gob" + "github.com/HFO4/cloudreve/pkg/util" + "github.com/garyburd/redigo/redis" + "strconv" + "time" +) + +// RedisStore redis存储驱动 +type RedisStore struct { + pool *redis.Pool +} + +type item struct { + Value interface{} +} + +// NewRedisStore 创建新的redis存储 +func NewRedisStore(size int, network, address, password, database string) *RedisStore { + return &RedisStore{ + pool: &redis.Pool{ + MaxIdle: size, + IdleTimeout: 240 * time.Second, + TestOnBorrow: func(c redis.Conn, t time.Time) error { + _, err := c.Do("PING") + return err + }, + Dial: func() (redis.Conn, error) { + db, err := strconv.Atoi(database) + if err != nil { + return nil, err + } + + c, err := redis.Dial( + network, + address, + redis.DialDatabase(db), + redis.DialPassword(password), + ) + if err != nil { + util.Log().Warning("无法创建Redis连接:%s", err) + return nil, err + } + return c, nil + }, + }, + } +} + +// Set 存储值 +func (store *RedisStore) Set(key string, value interface{}) error { + rc := store.pool.Get() + defer rc.Close() + + var buffer bytes.Buffer + enc := gob.NewEncoder(&buffer) + storeValue := item{ + Value: value, + } + err := enc.Encode(storeValue) + if err != nil { + return err + } + + if rc.Err() == nil { + _, err := rc.Do("SET", key, buffer.Bytes()) + if err != nil { + return err + } + return nil + } + + return rc.Err() +} + +// Get 取值 +func (store *RedisStore) Get(key string) (interface{}, bool) { + rc := store.pool.Get() + defer rc.Close() + + v, err := redis.Bytes(rc.Do("GET", key)) + if err != nil { + return nil, false + } + + var res item + buffer := bytes.NewReader(v) + dec := gob.NewDecoder(buffer) + err = dec.Decode(&res) + if err != nil { + return nil, false + } + + return res.Value, true + +} + +// Gets 批量取值 +func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) { + rc := store.pool.Get() + defer rc.Close() + + var queryKeys = make([]string, len(keys)) + for key, value := range keys { + queryKeys[key] = prefix + value + } + + v, err := redis.ByteSlices(rc.Do("MGET", redis.Args{}.AddFlat(queryKeys)...)) + if err != nil { + return nil, keys + } + + var res = make(map[string]interface{}) + var missed = make([]string, 0, len(keys)) + + for key, value := range v { + var decoded item + buffer := bytes.NewReader(value) + dec := gob.NewDecoder(buffer) + err = dec.Decode(&decoded) + if err != nil || decoded.Value == nil { + missed = append(missed, keys[key]) + } else { + res[keys[key]] = decoded.Value + } + } + // 解码所得值 + return res, missed +} + +// Sets 批量设置值 +func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error { + rc := store.pool.Get() + defer rc.Close() + var setValues = make(map[string]interface{}) + + // 编码待设置值 + for key, value := range values { + var buffer bytes.Buffer + enc := gob.NewEncoder(&buffer) + storeValue := item{ + Value: value, + } + err := enc.Encode(storeValue) + if err != nil { + return err + } + setValues[prefix+key] = buffer.Bytes() + } + + if rc.Err() == nil { + _, err := rc.Do("MSET", redis.Args{}.AddFlat(setValues)...) + if err != nil { + return err + } + return nil + } + + return rc.Err() +}