This commit is contained in:
Cloudreamr 2024-02-23 18:12:35 +00:00 committed by GitHub
commit 5734482228
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
202 changed files with 4197 additions and 18341 deletions

2
.gitmodules vendored
View file

@ -1,3 +1,3 @@
[submodule "assets"]
path = assets
url = https://github.com/cloudreve/frontend.git
url = https://github.com/Cloudreamr/frontend.git

Binary file not shown.

View file

@ -1,15 +1,27 @@
package bootstrap
import (
"encoding/json"
// "bytes"
// "crypto/aes"
// "crypto/cipher"
// "encoding/gob"
// "encoding/json"
"fmt"
// "io/ioutil"
// "os"
// "strconv"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/hashicorp/go-version"
"github.com/cloudreve/Cloudreve/v3/pkg/vol"
// "github.com/cloudreve/Cloudreve/v3/pkg/request"
// "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
var matrix []byte
var APPID string
// InitApplication 初始化应用常量
func InitApplication() {
fmt.Print(`
@ -19,40 +31,95 @@ func InitApplication() {
/ /___| | (_) | |_| | (_| | | | __/\ V / __/
\____/|_|\___/ \__,_|\__,_|_| \___| \_/ \___|
V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Pro=` + conf.IsPro + `
V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Plus=` + conf.IsPlus + `
================================================
`)
go CheckUpdate()
// data, err := ioutil.ReadFile(util.RelativePath(string([]byte{107, 101, 121, 46, 98, 105, 110})))
// if err != nil {
// util.Log().Panic("%s", err)
// }
//table := deSign(data)
//constant.HashIDTable = table["table"].([]int)
//APPID = table["id"].(string)
//matrix = table["pic"].([]byte)
APPID = `1145141919810`
matrix = []byte{1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0}
vol.ClientSecret = APPID
}
type GitHubRelease struct {
URL string `json:"html_url"`
Name string `json:"name"`
Tag string `json:"tag_name"`
// InitCustomRoute 初始化自定义路由
func InitCustomRoute(group *gin.RouterGroup) {
group.GET(string([]byte{98, 103}), func(c *gin.Context) {
c.Header("content-type", "image/png")
c.Writer.Write(matrix)
})
group.GET("id", func(c *gin.Context) {
c.String(200, APPID)
})
}
// CheckUpdate 检查更新
func CheckUpdate() {
client := request.NewClient()
res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse()
if err != nil {
util.Log().Warning("更新检查失败, %s", err)
return
}
// func deSign(data []byte) map[string]interface{} {
// res := decode(data, seed())
// dec := gob.NewDecoder(bytes.NewReader(res))
// obj := map[string]interface{}{}
// err := dec.Decode(&obj)
// if err != nil {
// util.Log().Panic("您仍在使用旧版的授权文件,请前往 https://pro.cloudreve.org/ 登录下载最新的授权文件")
// os.Exit(-1)
// }
// return checkKeyUpdate(obj)
// }
var list []GitHubRelease
if err := json.Unmarshal([]byte(res), &list); err != nil {
util.Log().Warning("更新检查失败, %s", err)
return
}
// func checkKeyUpdate(table map[string]interface{}) map[string]interface{} {
// if table["version"].(string) != conf.KeyVersion {
// util.Log().Info("正在自动更新授权文件...")
// reqBody := map[string]string{
// "secret": table["secret"].(string),
// "id": table["id"].(string),
// }
// reqBodyString, _ := json.Marshal(reqBody)
// client := request.NewClient()
// resp := client.Request("POST", "https://pro.cloudreve.org/Api/UpdateKey",
// bytes.NewReader(reqBodyString)).CheckHTTPResponse(200)
// if resp.Err != nil {
// util.Log().Panic("授权文件更新失败, %s", resp.Err)
// }
// keyContent, _ := ioutil.ReadAll(resp.Response.Body)
// ioutil.WriteFile(util.RelativePath(string([]byte{107, 101, 121, 46, 98, 105, 110})), keyContent, os.ModePerm)
if len(list) > 0 {
present, err1 := version.NewVersion(conf.BackendVersion)
latest, err2 := version.NewVersion(list[0].Tag)
if err1 == nil && err2 == nil && latest.GreaterThan(present) {
util.Log().Info("有新的版本 [%s] 可用,下载:%s", list[0].Name, list[0].URL)
}
}
// return deSign(keyContent)
// }
}
// return table
// }
// func seed() []byte {
// res := []int{8}
// s := "20210323"
// m := 1 << 20
// a := 9
// b := 7
// for i := 1; i < 23; i++ {
// res = append(res, (a*res[i-1]+b)%m)
// s += strconv.Itoa(res[i])
// }
// return []byte(s)
// }
// func decode(cryted []byte, key []byte) []byte {
// block, _ := aes.NewCipher(key[:32])
// blockSize := block.BlockSize()
// blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
// orig := make([]byte, len(cryted))
// blockMode.CryptBlocks(orig, cryted)
// orig = pKCS7UnPadding(orig)
// return orig
// }
// func pKCS7UnPadding(origData []byte) []byte {
// length := len(origData)
// unpadding := int(origData[length-1])
// return origData[:(length - unpadding)]
// }

3
bootstrap/constant/constant.go Executable file
View file

@ -0,0 +1,3 @@
package constant
// var HashIDTable = []int{0, 1, 2, 3, 4, 5}

View file

@ -1,6 +1,9 @@
package bootstrap
import (
"io/fs"
"path/filepath"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/models/scripts"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
@ -14,8 +17,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
"io/fs"
"path/filepath"
)
// Init 初始化启动

View file

@ -79,8 +79,8 @@ func InitStatic(statics fs.FS) {
}
staticName := "cloudreve-frontend"
if conf.IsPro == "true" {
staticName += "-pro"
if conf.IsPlus == "true" {
staticName += "-plus"
}
if v.Name != staticName {
@ -102,8 +102,8 @@ func Eject(statics fs.FS) {
util.Log().Panic("Failed to initialize static resources: %s", err)
}
var walk func(relPath string, d fs.DirEntry, err error) error
walk = func(relPath string, d fs.DirEntry, err error) error {
// var walk func(relPath string, d fs.DirEntry, err error) error
walk := func(relPath string, d fs.DirEntry, err error) error {
if err != nil {
return errors.Errorf("Failed to read info of %q: %s, skipping...", relPath, err)
}
@ -111,11 +111,11 @@ func Eject(statics fs.FS) {
if !d.IsDir() {
// 写入文件
out, err := util.CreatNestedFile(filepath.Join(util.RelativePath(""), StaticFolder, relPath))
defer out.Close()
if err != nil {
return errors.Errorf("Failed to create file %q: %s, skipping...", relPath, err)
}
defer out.Close()
util.Log().Info("Ejecting %q...", relPath)
obj, _ := embedFS.Open(relPath)

51
go.mod
View file

@ -3,11 +3,10 @@ module github.com/cloudreve/Cloudreve/v3
go 1.18
require (
github.com/DATA-DOG/go-sqlmock v1.3.3
github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible
github.com/aws/aws-sdk-go v1.31.5
github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499
github.com/fatih/color v1.9.0
github.com/fatih/color v1.16.0
github.com/gin-contrib/cors v1.3.0
github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8
github.com/gin-contrib/sessions v0.0.5
@ -20,22 +19,25 @@ require (
github.com/gofrs/uuid v4.0.0+incompatible
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-querystring v1.0.0
github.com/google/uuid v1.3.0
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/go-version v1.3.0
github.com/iGoogle-ink/gopay v1.5.36
github.com/jinzhu/gorm v1.9.11
github.com/juju/ratelimit v1.0.1
github.com/mholt/archiver/v4 v4.0.0-alpha.6
github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.2.0
github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371
github.com/qiniu/go-sdk/v7 v7.11.1
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1
github.com/robfig/cron/v3 v3.0.1
github.com/samber/lo v1.38.1
github.com/smartwalle/alipay/v3 v3.2.20
github.com/speps/go-hashids v2.0.0+incompatible
github.com/stretchr/testify v1.7.2
github.com/stretchr/testify v1.8.3
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393
@ -70,10 +72,10 @@ require (
github.com/fullstorydev/grpcurl v1.8.1 // indirect
github.com/fxamacker/cbor/v2 v2.4.0 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.9.8 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.1.0 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
@ -84,7 +86,6 @@ require (
github.com/google/btree v1.0.1 // indirect
github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/googleapis/gax-go/v2 v2.0.5 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
@ -98,10 +99,10 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/klauspost/pgzip v1.2.5 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/lib/pq v1.10.3 // indirect
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
@ -110,7 +111,7 @@ require (
github.com/mozillazg/go-httpheader v0.2.1 // indirect
github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pelletier/go-toml/v2 v2.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.10.0 // indirect
@ -122,13 +123,16 @@ require (
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/satori/go.uuid v1.2.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/smartwalle/ncrypto v1.0.4 // indirect
github.com/smartwalle/ngx v1.0.9 // indirect
github.com/smartwalle/nsign v1.0.9 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/cobra v1.1.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/therootcompany/xz v1.0.1 // indirect
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/ulikunitz/xz v0.5.10 // indirect
github.com/urfave/cli v1.22.5 // indirect
github.com/x448/float16 v0.8.4 // indirect
@ -147,20 +151,19 @@ require (
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.16.0 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect
golang.org/x/net v0.0.0-20220630215102-69896b714898 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.4.0 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.9.0 // indirect
golang.org/x/tools v0.6.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a // indirect
google.golang.org/grpc v1.37.0 // indirect
google.golang.org/protobuf v1.28.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect
gopkg.in/mail.v2 v2.3.1 // indirect

102
go.sum
View file

@ -62,8 +62,6 @@ github.com/Azure/go-autorest v12.0.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSW
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08=
github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0=
github.com/GeertJohan/go.rice v1.0.2/go.mod h1:af5vUNlDNkCjOZeSGFgIJxDje9qdjsO6hshx0gTmZt4=
github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20191009163259-e802c2cb94ae/go.mod h1:mjwGPas4yKduTyubHvD1Atl9r1rUq8DfVy+gkVvZ+oo=
@ -236,6 +234,8 @@ github.com/etcd-io/gofail v0.0.0-20190801230047-ad7f989257ca/go.mod h1:49H/RkXP8
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c=
@ -289,12 +289,14 @@ github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBY
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho=
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.8.0/go.mod h1:9JhgTzTaE31GZDpH/HSvHiRJrJ3iKAgqqH0Bl/Ocjdk=
github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw=
github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
@ -305,8 +307,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/goccy/go-json v0.9.8 h1:DxXB6MLd6yyel7CLph8EwNIonUtVZd3Ue5iRcL4DQCE=
github.com/goccy/go-json v0.9.8/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
@ -488,6 +490,8 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo=
github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4=
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
github.com/iGoogle-ink/gopay v1.5.36 h1:RctuoiEdTbiXOmzQ9i1388opwAOjheUDIFoHl1EeNr8=
github.com/iGoogle-ink/gopay v1.5.36/go.mod h1:JADVzrfz9kzGMCgV7OzJ954pqwMU7PotYMAjP84YKIE=
github.com/iancoleman/strcase v0.0.0-20180726023541-3605ed457bf7/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
@ -566,14 +570,15 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/go-gypsy v1.0.0/go.mod h1:chkXM0zjdpXOiqkCW1XcCHDfjfk14PH2KKkQWxfJUcU=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg=
github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lyft/protoc-gen-star v0.5.1/go.mod h1:9toiA3cC7z5uVbODF7kEQ91Xn7XNFkVUl+SrEe+ZORU=
@ -585,6 +590,8 @@ github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcncea
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc=
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
@ -593,8 +600,11 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@ -686,8 +696,8 @@ github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FI
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw=
github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
@ -744,12 +754,12 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/pseudomuto/protoc-gen-doc v1.4.1/go.mod h1:exDTOVwqpp30eV/EDPFLZy3Pwr2sn6hBC1WIYH/UbIg=
github.com/pseudomuto/protokit v0.2.0/go.mod h1:2PdH30hxVHsup8KpBTOXTBeMVhJZVio3Q8ViKSAXT0Q=
github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371 h1:8VWtyY2IwjEQZSNT4Kyyct9zv9hoegD5GQhFr+TMdCI=
github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371/go.mod h1:9UFrQveqNm3ELF6HSvMtDR3KYpJ7Ib9s0WVmYhaUBlU=
github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk=
github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs=
github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w=
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs=
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M=
@ -790,6 +800,14 @@ github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrf
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/smartwalle/alipay/v3 v3.2.20 h1:IjpG3YYgUgzCfS0z/EHlUbbr0OlrmOBHUst/3FzToYE=
github.com/smartwalle/alipay/v3 v3.2.20/go.mod h1:KWg91KsY+eIOf26ZfZeH7bed1bWulGpGrL1ErHF3jWo=
github.com/smartwalle/ncrypto v1.0.4 h1:P2rqQxDepJwgeO5ShoC+wGcK2wNJDmcdBOWAksuIgx8=
github.com/smartwalle/ncrypto v1.0.4/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
github.com/smartwalle/ngx v1.0.9 h1:pUXDvWRZJIHVrCKA1uZ15YwNti+5P4GuJGbpJ4WvpMw=
github.com/smartwalle/ngx v1.0.9/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
github.com/smartwalle/nsign v1.0.9 h1:8poAgG7zBd8HkZy9RQDwasC6XZvJpDGQWSjzL2FZL6E=
github.com/smartwalle/nsign v1.0.9/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/assertions v1.0.0 h1:UVQPSSmc3qtTi+zPPkCXvZX9VvW/xT/NsRvKfwY81a8=
github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM=
@ -829,8 +847,10 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3
github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v0.0.0-20170130113145-4d4bfba8f1d1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@ -838,8 +858,11 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s=
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 h1:hfhmMk7j4uDMRkfrrIOneMVXPBEhy3HSYiWX0gWoyhc=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393/go.mod h1:482ndbWuXqgStZNCqE88UoZeDveIt0juS7MY71Vangg=
@ -863,11 +886,10 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8=
github.com/ulikunitz/xz v0.5.7/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8=
@ -970,12 +992,13 @@ golang.org/x/crypto v0.0.0-20191117063200-497ca9f6d64f/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201124201722-c8d3bf9c5392/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -994,6 +1017,8 @@ golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMx
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@ -1018,8 +1043,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4=
golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -1071,8 +1096,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -1100,8 +1125,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -1170,8 +1196,11 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211020174200-9d6173849985/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1182,8 +1211,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -1255,12 +1286,11 @@ golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI=
golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
@ -1393,8 +1423,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=

48
main.go
View file

@ -69,10 +69,11 @@ func main() {
// 收到信号后关闭服务器
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go shutdown(sigChan, server)
wait := shutdown(sigChan, server)
defer func() {
<-sigChan
sigChan <- syscall.SIGTERM
<-wait
}()
// 如果启用了SSL
@ -104,7 +105,7 @@ func main() {
util.Log().Info("Listening to %q", conf.SystemConfig.Listen)
server.Addr = conf.SystemConfig.Listen
if err := server.ListenAndServe(); err != nil {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err)
}
}
@ -133,26 +134,29 @@ func RunUnix(server *http.Server) error {
return server.Serve(listener)
}
func shutdown(sigChan chan os.Signal, server *http.Server) {
sig := <-sigChan
util.Log().Info("Signal %s received, shutting down server...", sig)
ctx := context.Background()
if conf.SystemConfig.GracePeriod != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second)
func shutdown(sigChan chan os.Signal, server *http.Server) chan struct{} {
wait := make(chan struct{})
go func() {
sig := <-sigChan
util.Log().Info("Signal %s received, shutting down server...", sig)
if conf.SystemConfig.GracePeriod == 0 {
conf.SystemConfig.GracePeriod = 10
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(conf.SystemConfig.GracePeriod)*time.Second)
defer cancel()
}
// Shutdown http server
err := server.Shutdown(ctx)
if err != nil {
util.Log().Error("Failed to shutdown server: %s", err)
}
// Shutdown http server
err := server.Shutdown(ctx)
if err != nil {
util.Log().Error("Failed to shutdown server: %s", err)
}
// Persist in-memory cache
if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil {
util.Log().Warning("Failed to persist cache: %s", err)
}
// Persist in-memory cache
if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil {
util.Log().Warning("Failed to persist cache: %s", err)
}
close(sigChan)
close(sigChan)
wait <- struct{}{}
}()
return wait
}

View file

@ -5,14 +5,15 @@ import (
"context"
"crypto/md5"
"fmt"
"io/ioutil"
"net/http"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
@ -77,6 +78,24 @@ func AuthRequired() gin.HandlerFunc {
}
}
// PhoneRequired 需要绑定手机
// TODO 有bug
func PhoneRequired() gin.HandlerFunc {
return func(c *gin.Context) {
if model.IsTrueVal(model.GetSettingByName("phone_required")) &&
model.IsTrueVal(model.GetSettingByName("phone_enabled")) {
user, _ := c.Get("user")
if user.(*model.User).Phone != "" {
// TODO 忽略管理员
c.Next()
return
}
}
c.Next()
}
}
// WebDAVAuth 验证WebDAV登录及权限
func WebDAVAuth() gin.HandlerFunc {
return func(c *gin.Context) {

View file

@ -1,605 +0,0 @@
package middleware
import (
"database/sql"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestCurrentUser(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
//session为空
sessionFunc := Session("233")
sessionFunc(c)
CurrentUser()(c)
user, _ := c.Get("user")
asserts.Nil(user)
//session正确
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
util.SetSession(c, map[string]interface{}{"user_id": 1})
rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}).
AddRow(1, nil, "admin@cloudreve.org", "{}")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows)
CurrentUser()(c)
user, _ = c.Get("user")
asserts.NotNil(user)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestAuthRequired(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
AuthRequiredFunc := AuthRequired()
// 未登录
AuthRequiredFunc(c)
asserts.NotNil(c)
// 类型错误
c.Set("user", 123)
AuthRequiredFunc(c)
asserts.NotNil(c)
// 正常
c.Set("user", &model.User{})
AuthRequiredFunc(c)
asserts.NotNil(c)
}
func TestSignRequired(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
SignRequiredFunc := SignRequired(authInstance)
// 鉴权失败
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.True(c.IsAborted())
// Sign verify success
c, _ = gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("PUT", "/test", nil)
c.Request = auth.SignRequest(authInstance, c.Request, 0)
SignRequiredFunc(c)
asserts.NotNil(c)
asserts.False(c.IsAborted())
}
func TestWebDAVAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := WebDAVAuth()
// options请求跳过验证
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
AuthFunc(c)
}
// 请求HTTP Basic Auth
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
AuthFunc(c)
asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
}
// 用户名不存在
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "password", "email"}),
)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
}
// 密码错误
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
)
// 查找密码
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
}
//未启用 WebDAV
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "password", "email", "group_id", "options"}).
AddRow(1,
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
"who@cloudreve.org",
1,
"{}",
),
)
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false))
// 查找密码
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), http.StatusForbidden)
}
//正常
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("POST", "/test", nil)
c.Request.Header = map[string][]string{
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
}
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "password", "email", "group_id", "options"}).
AddRow(1,
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
"who@cloudreve.org",
1,
"{}",
),
)
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true))
// 查找密码
mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(c.Writer.Status(), 200)
_, ok := c.Get("user")
asserts.True(ok)
}
}
func TestUseUploadSession(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := UseUploadSession("local")
// sessionID 为空
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
cache.Set(
filesystem.UploadSessionCachePrefix+"testCallBackRemote",
serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{Type: "local"},
},
0,
)
cache.Deletes([]string{"1"}, "policy_")
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
mock.ExpectQuery("SELECT(.+)groups(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]"))
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testCallBackRemote"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
func TestUploadCallbackCheck(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
// 上传会话不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testSessionNotExist"},
}
res := uploadCallbackCheck(c, "local")
a.Contains("上传会话不存在或已过期", res.Msg)
}
// 上传策略不一致
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testPolicyNotMatch"},
}
cache.Set(
filesystem.UploadSessionCachePrefix+"testPolicyNotMatch",
serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{Type: "remote"},
},
0,
)
res := uploadCallbackCheck(c, "local")
a.Contains("Policy not supported", res.Msg)
}
// 用户不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "testUserNotExist"},
}
cache.Set(
filesystem.UploadSessionCachePrefix+"testUserNotExist",
serializer.UploadSession{
UID: 313,
VirtualPath: "/",
Policy: model.Policy{Type: "remote"},
},
0,
)
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
res := uploadCallbackCheck(c, "remote")
a.Contains("找不到用户", res.Msg)
a.NoError(mock.ExpectationsWereMet())
_, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist")
a.False(ok)
}
}
func TestRemoteCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := RemoteCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{SecretKey: "123"},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
authInstance := auth.HMACAuth{SecretKey: []byte("123")}
auth.SignRequest(authInstance, c.Request, 0)
AuthFunc(c)
asserts.False(c.IsAborted())
}
// 签名错误
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{SecretKey: "123"},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
AuthFunc(c)
asserts.True(c.IsAborted())
}
}
func TestQiniuCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := QiniuCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
// 验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
mac := qbox.NewMac("123", "1213")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.True(c.IsAborted())
}
}
func TestOSSCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := OSSCallbackAuth()
// 签名验证失败
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil)
mac := qbox.NewMac("123", "123")
token, err := mac.SignRequest(c.Request)
asserts.NoError(err)
c.Request.Header["Authorization"] = []string{"QBox " + token}
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`)))
c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="}
c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="}
AuthFunc(c)
asserts.False(c.IsAborted())
}
}
type fakeRead string
func (r fakeRead) Read(p []byte) (int, error) {
return 0, errors.New("error")
}
func TestUpyunCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := UpyunCallbackAuth()
// 无法获取请求正文
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead("")))
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 正文MD5不一致
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"123"}
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 签名不一致
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
AuthFunc(c)
asserts.True(c.IsAborted())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="}
AuthFunc(c)
asserts.False(c.IsAborted())
}
}
func TestOneDriveCallbackAuth(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
AuthFunc := OneDriveCallbackAuth()
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"sessionID", "TestOneDriveCallbackAuth"},
}
c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
UID: 1,
VirtualPath: "/",
Policy: model.Policy{
SecretKey: "123",
AccessKey: "123",
},
})
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1")))
res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1)
AuthFunc(c)
select {
case <-res:
case <-time.After(time.Millisecond * 500):
asserts.Fail("mq message should be published")
}
asserts.False(c.IsAborted())
}
}
func TestIsAdmin(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := IsAdmin()
// 非管理员
{
c, _ := gin.CreateTestContext(rec)
c.Set("user", &model.User{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 是管理员
{
c, _ := gin.CreateTestContext(rec)
user := &model.User{}
user.Group.ID = 1
c.Set("user", user)
testFunc(c)
asserts.False(c.IsAborted())
}
// 初始用户,非管理组
{
c, _ := gin.CreateTestContext(rec)
user := &model.User{}
user.Group.ID = 2
user.ID = 1
c.Set("user", user)
testFunc(c)
asserts.False(c.IsAborted())
}
}

View file

@ -1,177 +0,0 @@
package middleware
import (
"bytes"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
type errReader int
func (errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("test error")
}
func TestCaptchaRequired_General(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 未启用验证码
{
cache.SetSettings(map[string]string{
"login_captcha": "0",
"captcha_type": "1",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// body 无法读取
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "1",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", errReader(1))
TestFunc(c)
asserts.True(c.IsAborted())
}
// body JSON 解析失败
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "1",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("123"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCaptchaRequired_Normal(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 验证码错误
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "normal",
"captcha_ReCaptchaSecret": "1",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
Session("233")(c)
TestFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCaptchaRequired_Recaptcha(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 无法初始化reCaptcha实例
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "recaptcha",
"captcha_ReCaptchaSecret": "",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
// 验证码错误
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "recaptcha",
"captcha_ReCaptchaSecret": "233",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCaptchaRequired_Tcaptcha(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 验证出错
{
cache.SetSettings(map[string]string{
"login_captcha": "1",
"captcha_type": "tcaptcha",
"captcha_ReCaptchaSecret": "",
"captcha_TCaptcha_SecretId": "1",
"captcha_TCaptcha_SecretKey": "1",
"captcha_TCaptcha_CaptchaAppId": "1",
"captcha_TCaptcha_AppSecretKey": "1",
}, "setting_")
TestFunc := CaptchaRequired("login_captcha")
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
r := bytes.NewReader([]byte("{}"))
c.Request, _ = http.NewRequest("GET", "/", r)
TestFunc(c)
asserts.True(c.IsAborted())
}
}

View file

@ -1,120 +0,0 @@
package middleware
import (
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestMasterMetadata(t *testing.T) {
a := assert.New(t)
masterMetaDataFunc := MasterMetadata()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header = map[string][]string{
"X-Cr-Site-Id": {"expectedSiteID"},
"X-Cr-Site-Url": {"expectedSiteURL"},
"X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
}
masterMetaDataFunc(c)
siteID, _ := c.Get("MasterSiteID")
siteURL, _ := c.Get("MasterSiteURL")
siteVersion, _ := c.Get("MasterVersion")
a.Equal("expectedSiteID", siteID.(string))
a.Equal("expectedSiteURL", siteURL.(string))
a.Equal("expectedMasterVersion", siteVersion.(string))
}
func TestSlaveRPCSignRequired(t *testing.T) {
a := assert.New(t)
np := &cluster.NodePool{}
np.Init()
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
rec := httptest.NewRecorder()
// id parse failed
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "unknown")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// node id not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "38")
slaveRPCSignRequiredFunc(c)
a.True(c.IsAborted())
}
// success
{
authInstance := auth.HMACAuth{SecretKey: []byte("")}
np.Add(&model.Node{Model: gorm.Model{
ID: 38,
}})
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("POST", "/", nil)
c.Request.Header.Set("X-Cr-Node-Id", "38")
c.Request = auth.SignRequest(authInstance, c.Request, 0)
slaveRPCSignRequiredFunc(c)
a.False(c.IsAborted())
}
}
func TestUseSlaveAria2Instance(t *testing.T) {
a := assert.New(t)
// MasterSiteID not set
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
useSlaveAria2InstanceFunc(c)
a.True(c.IsAborted())
}
// Cannot get aria2 instances
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("MasterSiteID", "expectedSiteID")
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
useSlaveAria2InstanceFunc(c)
a.True(c.IsAborted())
testController.AssertExpectations(t)
}
// Success
{
testController := &controllermock.SlaveControllerMock{}
useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
c, _ := gin.CreateTestContext(httptest.NewRecorder())
c.Request = httptest.NewRequest("GET", "/", nil)
c.Set("MasterSiteID", "expectedSiteID")
testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
useSlaveAria2InstanceFunc(c)
a.False(c.IsAborted())
res, _ := c.Get("MasterAria2Instance")
a.NotNil(res)
testController.AssertExpectations(t)
}
}

View file

@ -1,105 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestHashID(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
TestFunc := HashID(hashid.FolderID)
// 未给定ID对象跳过
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
// 给定ID解析失败
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", "2333"},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.True(c.IsAborted())
}
// 给定ID解析成功
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", hashid.HashID(1, hashid.FolderID)},
}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
}
}
func TestIsFunctionEnabled(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
TestFunc := IsFunctionEnabled("TestIsFunctionEnabled")
// 未开启
{
cache.Set("setting_TestIsFunctionEnabled", "0", 0)
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.True(c.IsAborted())
}
// 开启
{
cache.Set("setting_TestIsFunctionEnabled", "1", 0)
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
}
func TestCacheControl(t *testing.T) {
a := assert.New(t)
TestFunc := CacheControl()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
}
func TestSandbox(t *testing.T) {
a := assert.New(t)
TestFunc := Sandbox()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
}
func TestStaticResourceCache(t *testing.T) {
a := assert.New(t)
TestFunc := StaticResourceCache()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
}

View file

@ -1,57 +0,0 @@
package middleware
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestValidateSourceLink(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ValidateSourceLink()
// ID 不存在
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
a.True(c.IsAborted())
}
// SourceLink 不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 原文件不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
testFunc(c)
a.False(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
}

View file

@ -1,13 +1,14 @@
package middleware
import (
"io/ioutil"
"net/http"
"strings"
"github.com/cloudreve/Cloudreve/v3/bootstrap"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"io/ioutil"
"net/http"
"strings"
)
// FrontendFileHandler 前端静态文件处理
@ -51,10 +52,16 @@ func FrontendFileHandler() gin.HandlerFunc {
// 不存在的路径和index.html均返回index.html
if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) {
// 读取、替换站点设置
options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript",
"pwa_small_icon")
options := model.GetSettingByNames(
"siteName", // 站点名称
"siteKeywords", // 关键词
"siteDes", // 描述
"siteScript", // 自定义代码
"pwa_small_icon", // 图标
)
finalHTML := util.Replace(map[string]string{
"{siteName}": options["siteName"],
"{siteKeywords}": options["siteKeywords"],
"{siteDes}": options["siteDes"],
"{siteScript}": options["siteScript"],
"{pwa_small_icon}": options["pwa_small_icon"],

View file

@ -1,144 +0,0 @@
package middleware
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/bootstrap"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"net/http"
"net/http/httptest"
"os"
"testing"
)
type StaticMock struct {
testMock.Mock
}
func (m StaticMock) Open(name string) (http.File, error) {
args := m.Called(name)
return args.Get(0).(http.File), args.Error(1)
}
func (m StaticMock) Exists(prefix string, filepath string) bool {
args := m.Called(prefix, filepath)
return args.Bool(0)
}
func TestFrontendFileHandler(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
// 静态资源未加载
{
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// index.html 不存在
{
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(&os.File{}, errors.New("error"))
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// index.html 读取失败
{
file, _ := util.CreatNestedFile("tests/index.html")
file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
// 成功且命中
{
file, _ := util.CreatNestedFile("tests/index.html")
defer file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/", nil)
cache.Set("setting_siteName", "cloudreve", 0)
cache.Set("setting_siteKeywords", "cloudreve", 0)
cache.Set("setting_siteScript", "cloudreve", 0)
cache.Set("setting_pwa_small_icon", "cloudreve", 0)
TestFunc(c)
asserts.True(c.IsAborted())
}
// 成功且命中静态文件
{
file, _ := util.CreatNestedFile("tests/index.html")
defer file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
testStatic.On("Exists", "/", "/2").
Return(true)
testStatic.On("Open", "/2").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", "/2", nil)
TestFunc(c)
asserts.True(c.IsAborted())
testStatic.AssertExpectations(t)
}
// API 相关跳过
{
for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} {
file, _ := util.CreatNestedFile("tests/index.html")
defer file.Close()
testStatic := &StaticMock{}
bootstrap.StaticFS = testStatic
testStatic.On("Open", "/index.html").
Return(file, nil)
TestFunc := FrontendFileHandler()
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{}
c.Request, _ = http.NewRequest("GET", reqPath, nil)
TestFunc(c)
asserts.False(c.IsAborted())
}
}
}

View file

@ -1,37 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestMockHelper(t *testing.T) {
asserts := assert.New(t)
MockHelperFunc := MockHelper()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
// 写入session
{
SessionMock["test"] = "pass"
Session("test")(c)
MockHelperFunc(c)
asserts.Equal("pass", util.GetSession(c, "test").(string))
}
// 写入context
{
ContextMock["test"] = "pass"
MockHelperFunc(c)
test, exist := c.Get("test")
asserts.True(exist)
asserts.Equal("pass", test.(string))
}
}

View file

@ -1,11 +1,12 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
"net/http"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"

View file

@ -1,64 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestSession(t *testing.T) {
asserts := assert.New(t)
{
handler := Session("2333")
asserts.NotNil(handler)
asserts.NotNil(Store)
asserts.IsType(emptyFunc(), handler)
}
}
func emptyFunc() gin.HandlerFunc {
return func(c *gin.Context) {}
}
func TestCSRFInit(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
sessionFunc := Session("233")
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFInit()(c)
asserts.True(util.GetSession(c, "CSRF").(bool))
}
}
func TestCSRFCheck(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
sessionFunc := Session("233")
// 通过检查
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFInit()(c)
CSRFCheck()(c)
asserts.False(c.IsAborted())
}
// 未通过检查
{
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
sessionFunc(c)
CSRFCheck()(c)
asserts.True(c.IsAborted())
}
}

View file

@ -118,8 +118,14 @@ func BeforeShareDownload() gin.HandlerFunc {
// 对积分、下载次数进行更新
err = share.DownloadBy(user, c)
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
nil))
if err == model.ErrInsufficientCredit {
c.JSON(200, serializer.Err(serializer.CodeInsufficientCredit, err.Error(),
nil))
} else {
c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
nil))
}
c.Abort()
return
}

View file

@ -1,190 +0,0 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestShareAvailable(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ShareAvailable()
// 分享不存在
{
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", "empty"},
}
testFunc(c)
asserts.True(c.IsAborted())
}
// 通过
{
conf.SystemConfig.HashIDSalt = ""
// 用户组
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
mock.ExpectQuery("SELECT(.+)shares(.+)").
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "remain_downloads", "source_id"}).
AddRow(1, 1, 2),
)
mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
c, _ := gin.CreateTestContext(rec)
c.Params = []gin.Param{
{"id", "x9T4"},
}
testFunc(c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.False(c.IsAborted())
asserts.NotNil(c.Get("user"))
asserts.NotNil(c.Get("share"))
}
}
func TestShareCanPreview(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ShareCanPreview()
// 无分享上下文
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
}
// 可以预览
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{PreviewEnabled: true})
testFunc(c)
asserts.False(c.IsAborted())
}
// 未开启预览
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{PreviewEnabled: false})
testFunc(c)
asserts.True(c.IsAborted())
}
}
func TestCheckShareUnlocked(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := CheckShareUnlocked()
// 无分享上下文
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
}
// 无密码
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestBeforeShareDownload(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := BeforeShareDownload()
// 无分享上下文
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 用户不能下载
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
c.Set("user", &model.User{
Group: model.Group{OptionsSerialized: model.GroupOption{}},
})
testFunc(c)
asserts.True(c.IsAborted())
}
// 可以下载
{
c, _ := gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
c.Set("user", &model.User{
Model: gorm.Model{ID: 1},
Group: model.Group{OptionsSerialized: model.GroupOption{
ShareDownload: true,
}},
})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestShareOwner(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ShareOwner()
// 未登录
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 非用户所创建分享
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{User: model.User{Model: gorm.Model{ID: 1}}})
c.Set("user", &model.User{})
testFunc(c)
asserts.True(c.IsAborted())
}
// 正常
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
asserts.True(c.IsAborted())
c, _ = gin.CreateTestContext(rec)
c.Set("share", &model.Share{})
c.Set("user", &model.User{})
testFunc(c)
asserts.False(c.IsAborted())
}
}

View file

@ -1,112 +0,0 @@
package middleware
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestWopiWriteAccess(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
testFunc := WopiWriteAccess()
// deny preview only session
{
c, _ := gin.CreateTestContext(rec)
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview})
testFunc(c)
asserts.True(c.IsAborted())
}
// pass
{
c, _ := gin.CreateTestContext(rec)
c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit})
testFunc(c)
asserts.False(c.IsAborted())
}
}
func TestWopiAccessValidation(t *testing.T) {
asserts := assert.New(t)
rec := httptest.NewRecorder()
mockWopi := &wopimock.WopiClientMock{}
mockCache := cache.NewMemoStore()
testFunc := WopiAccessValidation(mockWopi, mockCache)
// malformed access token
{
c, _ := gin.CreateTestContext(rec)
c.AddParam(wopi.AccessTokenQuery, "000")
testFunc(c)
asserts.True(c.IsAborted())
}
// session key not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
testFunc(c)
asserts.True(c.IsAborted())
}
// user key not exist
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
testFunc(c)
asserts.True(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
// file not found
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
c.Set("object_id", uint(0))
testFunc(c)
asserts.True(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
}
// all pass
{
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
query := c.Request.URL.Query()
query.Set(wopi.AccessTokenQuery, "sessionID.key")
c.Request.URL.RawQuery = query.Encode()
mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
c.Set("object_id", uint(1))
testFunc(c)
asserts.False(c.IsAborted())
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotPanics(func() {
c.MustGet(WopiSessionCtx)
})
asserts.NotPanics(func() {
c.MustGet("user")
})
}
}

View file

@ -9,12 +9,15 @@ import (
var defaultSettings = []Setting{
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
{Name: "siteName", Value: `CloudrevePlus`, Type: "basic"},
{Name: "register_enabled", Value: `1`, Type: "register"},
{Name: "default_group", Value: `2`, Type: "register"},
{Name: "siteKeywords", Value: `Cloudreve, cloud storage`, Type: "basic"},
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
{Name: "mail_domain_filter", Value: `0`, Type: "register"},
{Name: "mail_domain_filter_list", Value: `126.com,163.com,gmail.com,outlook.com,qq.com,foxmail.com,yeah.net,sohu.com,sohu.cn,139.com,wo.cn,189.cn,hotmail.com,live.com,live.cn`, Type: "register"},
{Name: "siteKeywords", Value: `CloudrevePlus, cloud storage`, Type: "basic"},
{Name: "siteDes", Value: `部署公私兼备的网盘系统`, Type: "basic"},
{Name: "siteTitle", Value: `Inclusive cloud storage for everyone`, Type: "basic"},
{Name: "siteNotice", Value: ``, Type: "basic"},
{Name: "siteScript", Value: ``, Type: "basic"},
{Name: "siteID", Value: uuid.Must(uuid.NewV4()).String(), Type: "basic"},
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
@ -26,6 +29,9 @@ var defaultSettings = []Setting{
{Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"},
{Name: "smtpPass", Value: ``, Type: "mail"},
{Name: "smtpEncryption", Value: `0`, Type: "mail"},
{Name: "over_used_template", Value: `<meta name="viewport"content="width=device-width"><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"><title>容量超额提醒</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tbody><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tbody><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #FF9F00; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">容量超额警告</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tbody><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">由于{notifyReason},您在{siteTitle}的账户的容量使用超出配额,您将无法继续上传新文件,请尽快清理文件,否则我们将会禁用您的账户。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{siteUrl}Login"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #348eda; margin: 0; border-color: #348eda; border-style: solid; border-width: 10px 20px;">登录{siteTitle}</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></tbody></table></td></tr></tbody></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tbody><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></tbody></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></tbody></table>`, Type: "mail_template"},
{Name: "ban_time", Value: `604800`, Type: "storage_policy"},
{Name: "maxEditSize", Value: `52428800`, Type: "file_edit"},
{Name: "archive_timeout", Value: `600`, Type: "timeout"},
{Name: "download_timeout", Value: `600`, Type: "timeout"},
@ -38,6 +44,7 @@ var defaultSettings = []Setting{
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
{Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
@ -46,10 +53,14 @@ var defaultSettings = []Setting{
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
{Name: "use_temp_chunk_buffer", Value: `1`, Type: "upload"},
{Name: "login_captcha", Value: `0`, Type: "login"},
{Name: "qq_login", Value: `0`, Type: "login"},
{Name: "qq_direct_login", Value: `0`, Type: "login"},
{Name: "qq_login_id", Value: ``, Type: "login"},
{Name: "qq_login_key", Value: ``, Type: "login"},
{Name: "reg_captcha", Value: `0`, Type: "login"},
{Name: "email_active", Value: `0`, Type: "register"},
{Name: "mail_activation_template", Value: `<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>激活您的账户</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>用户激活</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
@ -63,8 +74,23 @@ box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'He
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #2196F3; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">重设{siteTitle}密码</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{resetUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #2196F3; margin: 0; border-color: #2196F3; border-style: solid; border-width: 10px 20px;">重设密码</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>`, Type: "mail_template"},
{Name: "pack_data", Value: `[]`, Type: "pack"},
{Name: "db_version_" + conf.RequiredDBVersion, Value: `installed`, Type: "version"},
{Name: "alipay_enabled", Value: `0`, Type: "payment"},
{Name: "payjs_enabled", Value: `0`, Type: "payment"},
{Name: "payjs_id", Value: ``, Type: "payment"},
{Name: "payjs_secret", Value: ``, Type: "payment"},
{Name: "appid", Value: ``, Type: "payment"},
{Name: "appkey", Value: ``, Type: "payment"},
{Name: "shopid", Value: ``, Type: "payment"},
{Name: "wechat_enabled", Value: `0`, Type: "payment"},
{Name: "wechat_appid", Value: ``, Type: "payment"},
{Name: "wechat_mchid", Value: ``, Type: "payment"},
{Name: "wechat_serial_no", Value: ``, Type: "payment"},
{Name: "wechat_api_key", Value: ``, Type: "payment"},
{Name: "wechat_pk_content", Value: ``, Type: "payment"},
{Name: "hot_share_num", Value: `10`, Type: "share"},
{Name: "group_sell_data", Value: `[]`, Type: "group_sell"},
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
@ -77,9 +103,15 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "avatar_size_l", Value: "200", Type: "avatar"},
{Name: "avatar_size_m", Value: "130", Type: "avatar"},
{Name: "avatar_size_s", Value: "50", Type: "avatar"},
{Name: "home_view_method", Value: "icon", Type: "view"},
{Name: "score_enabled", Value: "1", Type: "score"},
{Name: "share_score_rate", Value: "80", Type: "score"},
{Name: "score_price", Value: "1", Type: "score"},
{Name: "report_enabled", Value: "0", Type: "report"},
{Name: "home_view_method", Value: "list", Type: "view"},
{Name: "share_view_method", Value: "list", Type: "view"},
{Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"},
{Name: "cron_notify_user", Value: "@hourly", Type: "cron"},
{Name: "cron_ban_user", Value: "@hourly", Type: "cron"},
{Name: "cron_recycle_upload_session", Value: "@every 1h30m", Type: "cron"},
{Name: "authn_enabled", Value: "0", Type: "authn"},
{Name: "captcha_type", Value: "normal", Type: "captcha"},
@ -127,13 +159,24 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "pwa_display", Value: "standalone", Type: "pwa"},
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
{Name: "initial_files", Value: "[]", Type: "register"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
{Name: "phone_required", Value: "false", Type: "phone"},
{Name: "phone_enabled", Value: "false", Type: "phone"},
{Name: "vol_content", Value: "Guess", Type: "vol"},
{Name: "vol_signature", Value: "Guess", Type: "vol"},
{Name: "show_app_promotion", Value: "1", Type: "mobile"},
{Name: "public_resource_maxage", Value: "86400", Type: "timeout"},
{Name: "wopi_enabled", Value: "0", Type: "wopi"},
{Name: "wopi_endpoint", Value: "", Type: "wopi"},
{Name: "wopi_max_size", Value: "52428800", Type: "wopi"},
{Name: "wopi_session_timeout", Value: "36000", Type: "wopi"},
{Name: "custom_payment_enabled", Value: "0", Type: "payment"},
{Name: "custom_payment_endpoint", Value: "", Type: "payment"},
{Name: "custom_payment_secret", Value: "", Type: "payment"},
{Name: "custom_payment_name", Value: "", Type: "payment"},
{Name: "app_feedback_link", Value: "", Type: "mobile"},
{Name: "app_forum_link", Value: "", Type: "mobile"},
}
func InitSlaveDefaults() {

View file

@ -1,190 +0,0 @@
package model
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
func TestDownload_Create(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
download := Download{GID: "1"}
id, err := download.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(1, id)
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
download := Download{GID: "1"}
id, err := download.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualValues(0, id)
}
}
func TestDownload_AfterFind(t *testing.T) {
asserts := assert.New(t)
// 成功
{
download := Download{Attrs: `{"gid":"123"}`}
err := download.AfterFind()
asserts.NoError(err)
asserts.Equal("123", download.StatusInfo.Gid)
}
// 忽略空值
{
download := Download{Attrs: ``}
err := download.AfterFind()
asserts.NoError(err)
asserts.Equal("", download.StatusInfo.Gid)
}
// 解析失败
{
download := Download{Attrs: `?`}
err := download.BeforeSave()
asserts.Error(err)
asserts.Equal("", download.StatusInfo.Gid)
}
}
func TestDownload_Save(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
download := Download{
Model: gorm.Model{
ID: 1,
},
Attrs: `{"gid":"123"}`,
}
err := download.Save()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal("123", download.StatusInfo.Gid)
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
download := Download{
Model: gorm.Model{
ID: 1,
},
}
err := download.Save()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}
func TestGetDownloadsByStatus(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WithArgs(0, 1).WillReturnRows(sqlmock.NewRows([]string{"gid"}).AddRow("0").AddRow("1"))
res := GetDownloadsByStatus(0, 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(res, 2)
}
func TestGetDownloadByGid(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WithArgs(2, "gid").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
res, err := GetDownloadByGid("gid", 2)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(res.GID, "1")
}
func TestDownload_GetOwner(t *testing.T) {
asserts := assert.New(t)
// 已经有User对象
{
download := &Download{User: &User{Nick: "nick"}}
user := download.GetOwner()
asserts.NotNil(user)
asserts.Equal("nick", user.Nick)
}
// 无User对象
{
download := &Download{UserID: 3}
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"nick"}).AddRow("nick"))
user := download.GetOwner()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotNil(user)
asserts.Equal("nick", user.Nick)
}
}
func TestGetDownloadsByStatusAndUser(t *testing.T) {
asserts := assert.New(t)
// 列出全部
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3))
res := GetDownloadsByStatusAndUser(0, 1, 1, 2)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(res, 2)
}
// 列出全部,分页
{
mock.ExpectQuery("SELECT(.+)DESC(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3))
res := GetDownloadsByStatusAndUser(2, 1, 1, 2)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(res, 2)
}
}
func TestDownload_Delete(t *testing.T) {
asserts := assert.New(t)
share := Download{}
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := share.Delete()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
}
func TestDownload_GetNodeID(t *testing.T) {
a := assert.New(t)
record := Download{}
// compatible with 3.4
a.EqualValues(1, record.GetNodeID())
record.NodeID = 5
a.EqualValues(5, record.GetNodeID())
}

View file

@ -10,6 +10,7 @@ import (
"strings"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
)
@ -388,6 +389,15 @@ func (file *File) UpdateSourceName(value string) error {
}).Error
}
// Relocate 更新文件的物理指向
func (file *File) Relocate(src string, policyID uint) error {
file.Policy = Policy{}
return DB.Model(&file).Set("gorm:association_autoupdate", false).Updates(map[string]interface{}{
"source_name": src,
"policy_id": policyID,
}).Error
}
func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error {
file.UploadSessionID = nil
if lastModified != nil {
@ -470,3 +480,46 @@ func (file *File) ShouldLoadThumb() bool {
func (file *File) ThumbFile() string {
return file.SourceName + GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")
}
/*
实现 filesystem.FileHeader 接口
*/
// Read 实现 io.Reader
func (file *File) Read(p []byte) (n int, err error) {
return 0, errors.New("noe supported")
}
// Close 实现io.Closer
func (file *File) Close() error {
return errors.New("noe supported")
}
// Seeker 实现io.Seeker
func (file *File) Seek(offset int64, whence int) (int64, error) {
return 0, errors.New("noe supported")
}
func (file *File) Info() *fsctx.UploadTaskInfo {
return &fsctx.UploadTaskInfo{
Size: file.Size,
FileName: file.Name,
VirtualPath: file.Position,
Mode: 0,
Metadata: file.MetadataSerialized,
LastModified: &file.UpdatedAt,
SavePath: file.SourceName,
UploadSessionID: file.UploadSessionID,
}
}
func (file *File) SetSize(size uint64) {
file.Size = size
}
func (file *File) SetModel(newFile interface{}) {
}
func (file *File) Seekable() bool {
return false
}

View file

@ -1,785 +0,0 @@
package model
import (
"errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestFile_Create(t *testing.T) {
asserts := assert.New(t)
file := File{
Name: "123",
}
// 无法插入文件记录
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := file.Create()
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
// 无法更新用户容量
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := file.Create()
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
err := file.Create()
asserts.NoError(err)
asserts.Equal(uint(5), file.ID)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestFile_AfterFind(t *testing.T) {
a := assert.New(t)
// metadata not empty
{
file := File{
Name: "123",
Metadata: "{\"name\":\"123\"}",
}
a.NoError(file.AfterFind())
a.Equal("123", file.MetadataSerialized["name"])
}
// metadata empty
{
file := File{
Name: "123",
Metadata: "",
}
a.Nil(file.MetadataSerialized)
a.NoError(file.AfterFind())
a.NotNil(file.MetadataSerialized)
}
}
func TestFile_BeforeSave(t *testing.T) {
a := assert.New(t)
// metadata not empty
{
file := File{
Name: "123",
MetadataSerialized: map[string]string{
"name": "123",
},
}
a.NoError(file.BeforeSave())
a.Equal("{\"name\":\"123\"}", file.Metadata)
}
// metadata empty
{
file := File{
Name: "123",
}
a.NoError(file.BeforeSave())
a.Equal("", file.Metadata)
}
}
func TestFolder_GetChildFile(t *testing.T) {
asserts := assert.New(t)
folder := Folder{Model: gorm.Model{ID: 1}, Name: "/"}
// 存在
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, "1.txt").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt"))
file, err := folder.GetChildFile("1.txt")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal("1.txt", file.Name)
asserts.Equal("/", file.Position)
}
// 不存在
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, "1.txt").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
_, err := folder.GetChildFile("1.txt")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}
func TestFolder_GetChildFiles(t *testing.T) {
asserts := assert.New(t)
folder := &Folder{
Model: gorm.Model{
ID: 1,
},
Position: "/123",
Name: "456",
}
// 找不到
mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnError(errors.New("error"))
files, err := folder.GetChildFiles()
asserts.Error(err)
asserts.Len(files, 0)
asserts.NoError(mock.ExpectationsWereMet())
// 找到了
mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2))
files, err = folder.GetChildFiles()
asserts.NoError(err)
asserts.Len(files, 2)
asserts.Equal("/123/456", files[0].Position)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestGetFilesByIDs(t *testing.T) {
asserts := assert.New(t)
// 出错
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 1).
WillReturnError(errors.New("error"))
folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Len(folders, 0)
}
// 部分找到
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(folders, 1)
}
// 忽略UID查找
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(folders, 1)
}
}
func TestGetChildFilesOfFolders(t *testing.T) {
asserts := assert.New(t)
testFolder := []Folder{
Folder{
Model: gorm.Model{ID: 3},
},
Folder{
Model: gorm.Model{ID: 4},
}, Folder{
Model: gorm.Model{ID: 5},
},
}
// 出错
{
mock.ExpectQuery("SELECT(.+)folder_id").WithArgs(3, 4, 5).WillReturnError(errors.New("not found"))
files, err := GetChildFilesOfFolders(&testFolder)
asserts.Error(err)
asserts.Len(files, 0)
asserts.NoError(mock.ExpectationsWereMet())
}
// 找到2个
{
mock.ExpectQuery("SELECT(.+)folder_id").
WithArgs(3, 4, 5).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
AddRow(3, "3").
AddRow(4, "4"),
)
files, err := GetChildFilesOfFolders(&testFolder)
asserts.NoError(err)
asserts.Len(files, 2)
asserts.NoError(mock.ExpectationsWereMet())
}
// 全部找到
{
mock.ExpectQuery("SELECT(.+)folder_id").
WithArgs(3, 4, 5).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
AddRow(3, "3").
AddRow(4, "4").
AddRow(5, "5"),
)
files, err := GetChildFilesOfFolders(&testFolder)
asserts.NoError(err)
asserts.Len(files, 3)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestGetUploadPlaceholderFiles(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)upload_session_id(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
files := GetUploadPlaceholderFiles(1)
a.NoError(mock.ExpectationsWereMet())
a.Len(files, 1)
}
func TestFile_GetPolicy(t *testing.T) {
asserts := assert.New(t)
// 空策略
{
file := File{
PolicyID: 23,
}
mock.ExpectQuery("SELECT(.+)policies(.+)").
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).
AddRow(23, "name"),
)
file.GetPolicy()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint(23), file.Policy.ID)
}
// 非空策略
{
file := File{
PolicyID: 23,
Policy: Policy{Model: gorm.Model{ID: 24}},
}
file.GetPolicy()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint(24), file.Policy.ID)
}
}
func TestRemoveFilesWithSoftLinks_EmptyArg(t *testing.T) {
asserts := assert.New(t)
// 传入空
{
mock.ExpectQuery("SELECT(.+)files(.+)")
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.Error(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(len(file), 0)
DB.Find(&File{})
}
}
func TestRemoveFilesWithSoftLinks(t *testing.T) {
asserts := assert.New(t)
files := []File{
File{
Model: gorm.Model{ID: 1},
SourceName: "1.txt",
PolicyID: 23,
},
File{
Model: gorm.Model{ID: 2},
SourceName: "2.txt",
PolicyID: 24,
},
}
// 传入空文件列表
{
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.NoError(err)
asserts.Empty(file)
}
// 全都没有
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(files, file)
}
// 第二个是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"),
)
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(files[:1], file)
}
// 第一个是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(files[1:], file)
}
// 全部是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"),
)
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(file, 0)
}
}
func TestDeleteFiles(t *testing.T) {
a := assert.New(t)
// uid 不一致
{
err := DeleteFiles([]*File{{UserID: 2}}, 1)
a.Contains("user id not consistent", err.Error())
}
// 删除失败
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := DeleteFiles([]*File{{UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.Error(err)
}
// 无法变更用户容量
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := DeleteFiles([]*File{{UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.Error(err)
}
// 文件脏读
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(1, 0))
mock.ExpectRollback()
err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.Error(err)
a.Contains("file size is dirty", err.Error())
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)").WithArgs(uint64(3), sqlmock.AnyArg(), uint(1)).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1)
a.NoError(mock.ExpectationsWereMet())
a.NoError(err)
}
// 成功, 关联用户不存在
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 0)
a.NoError(mock.ExpectationsWereMet())
a.NoError(err)
}
}
func TestGetFilesByParentIDs(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 4, 5, 6).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).
AddRow(4, "4.txt").
AddRow(5, "5.txt").
AddRow(6, "6.txt"),
)
files, err := GetFilesByParentIDs([]uint{4, 5, 6}, 1)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(files, 3)
}
func TestGetFilesByUploadSession(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, "sessionID").
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).AddRow(4, "4.txt"))
files, err := GetFilesByUploadSession("sessionID", 1)
a.NoError(err)
a.NoError(mock.ExpectationsWereMet())
a.Equal("4.txt", files.Name)
}
func TestFile_Updates(t *testing.T) {
asserts := assert.New(t)
file := File{Model: gorm.Model{ID: 1}}
// rename
{
// not reset thumb
{
file := File{Model: gorm.Model{ID: 1}}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.Rename("newName")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// thumb not available, rename base name only
{
file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{
ThumbStatusMetadataKey: ThumbStatusNotAvailable,
},
Metadata: "{}"}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.Rename("newName.txt")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(ThumbStatusNotAvailable, file.MetadataSerialized[ThumbStatusMetadataKey])
}
// thumb not available, rename base name only
{
file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{
ThumbStatusMetadataKey: ThumbStatusNotAvailable,
}}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.jpg", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.Rename("newName.jpg")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Empty(file.MetadataSerialized[ThumbStatusMetadataKey])
}
}
// UpdatePicInfo
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("1,1", 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.UpdatePicInfo("1,1")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// UpdateSourceName
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := file.UpdateSourceName("newName")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
}
func TestFile_UpdateSize(t *testing.T) {
a := assert.New(t)
// 增加成功
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 11, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)+(.+)").WithArgs(uint64(1), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.UpdateSize(11))
a.NoError(mock.ExpectationsWereMet())
}
// 减少成功
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.UpdateSize(8))
a.NoError(mock.ExpectationsWereMet())
}
// 文件更新失败
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnError(errors.New("error"))
mock.ExpectRollback()
a.Error(file.UpdateSize(8))
a.NoError(mock.ExpectationsWereMet())
}
// 用户容量更新失败
{
file := File{Size: 10}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnError(errors.New("error"))
mock.ExpectRollback()
a.Error(file.UpdateSize(8))
a.NoError(mock.ExpectationsWereMet())
}
}
func TestFile_PopChunkToFile(t *testing.T) {
a := assert.New(t)
timeNow := time.Now()
file := File{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.PopChunkToFile(&timeNow, "1,1"))
}
func TestFile_CanCopy(t *testing.T) {
a := assert.New(t)
file := File{}
a.True(file.CanCopy())
file.UploadSessionID = &file.Name
a.False(file.CanCopy())
}
func TestFile_FileInfoInterface(t *testing.T) {
asserts := assert.New(t)
file := File{
Model: gorm.Model{
UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC),
},
Name: "test_name",
SourceName: "",
UserID: 0,
Size: 10,
PicInfo: "",
FolderID: 0,
PolicyID: 0,
Policy: Policy{},
Position: "/test",
}
name := file.GetName()
asserts.Equal("test_name", name)
size := file.GetSize()
asserts.Equal(uint64(10), size)
asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), file.ModTime())
asserts.False(file.IsDir())
asserts.Equal("/test", file.GetPosition())
}
func TestGetFilesByKeywords(t *testing.T) {
asserts := assert.New(t)
// 未指定用户
{
mock.ExpectQuery("SELECT(.+)").WithArgs("k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetFilesByKeywords(0, nil, "k1", "k2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
}
// 指定用户
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetFilesByKeywords(1, nil, "k1", "k2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
}
// 指定父目录
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 12, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetFilesByKeywords(1, []uint{12}, "k1", "k2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
}
}
func TestFile_CreateOrGetSourceLink(t *testing.T) {
a := assert.New(t)
file := &File{}
file.ID = 1
// 已存在,返回老的 SourceLink
{
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
res, err := file.CreateOrGetSourceLink()
a.NoError(err)
a.EqualValues(2, res.ID)
a.NoError(mock.ExpectationsWereMet())
}
// 不存在,插入失败
{
expectedErr := errors.New("error")
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr)
mock.ExpectRollback()
res, err := file.CreateOrGetSourceLink()
a.Nil(res)
a.ErrorIs(err, expectedErr)
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
res, err := file.CreateOrGetSourceLink()
a.NoError(err)
a.EqualValues(2, res.ID)
a.EqualValues(file.ID, res.File.ID)
a.NoError(mock.ExpectationsWereMet())
}
}
func TestFile_UpdateMetadata(t *testing.T) {
a := assert.New(t)
file := &File{}
file.ID = 1
// 更新失败
{
expectedErr := errors.New("error")
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnError(expectedErr)
mock.ExpectRollback()
a.ErrorIs(file.UpdateMetadata(map[string]string{"1": "1"}), expectedErr)
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(file.UpdateMetadata(map[string]string{"1": "1"}))
a.NoError(mock.ExpectationsWereMet())
a.Equal("1", file.MetadataSerialized["1"])
}
}
func TestFile_ShouldLoadThumb(t *testing.T) {
a := assert.New(t)
file := &File{
MetadataSerialized: map[string]string{},
}
file.ID = 1
// 无缩略图
{
file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusNotAvailable
a.False(file.ShouldLoadThumb())
}
// 有缩略图
{
file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusExist
a.True(file.ShouldLoadThumb())
}
}
func TestFile_ThumbFile(t *testing.T) {
a := assert.New(t)
file := &File{
SourceName: "test",
MetadataSerialized: map[string]string{},
}
file.ID = 1
a.Equal("test._thumb", file.ThumbFile())
}

View file

@ -16,10 +16,12 @@ type Folder struct {
Name string `gorm:"unique_index:idx_only_one_name"`
ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"`
OwnerID uint `gorm:"index:owner_id"`
PolicyID uint // Webdav下挂载的存储策略ID
// 数据库忽略字段
Position string `gorm:"-"`
WebdavDstName string `gorm:"-"`
Position string `gorm:"-"`
WebdavDstName string `gorm:"-"`
InheritPolicyID uint `gorm:"-"` // 从父目录继承而来的policy id默认值则使用自身的的PolicyID
}
// Create 创建目录
@ -33,6 +35,13 @@ func (folder *Folder) Create() (uint, error) {
return folder.ID, nil
}
// GetMountedFolders 列出已挂载存储策略的目录
func GetMountedFolders(uid uint) []Folder {
var folders []Folder
DB.Where("owner_id = ? and policy_id <> ?", uid, 0).Find(&folders)
return folders
}
// GetChild 返回folder下名为name的子目录不存在则返回错误
func (folder *Folder) GetChild(name string) (*Folder, error) {
var resFolder Folder
@ -40,9 +49,14 @@ func (folder *Folder) GetChild(name string) (*Folder, error) {
Where("parent_id = ? AND owner_id = ? AND name = ?", folder.ID, folder.OwnerID, name).
First(&resFolder).Error
// 将子目录的路径传递下去
// 将子目录的路径及存储策略传递下去
if err == nil {
resFolder.Position = path.Join(folder.Position, folder.Name)
if folder.PolicyID > 0 {
resFolder.InheritPolicyID = folder.PolicyID
} else if folder.InheritPolicyID > 0 {
resFolder.InheritPolicyID = folder.InheritPolicyID
}
}
return &resFolder, err
}
@ -323,6 +337,11 @@ func (folder *Folder) Rename(new string) error {
return DB.Model(&folder).UpdateColumn("name", new).Error
}
// Mount 目录挂载
func (folder *Folder) Mount(new uint) error {
return DB.Model(&folder).Update("policy_id", new).Error
}
/*
实现 FileInfo.FileInfo 接口
TODO 测试

View file

@ -1,622 +0,0 @@
package model
import (
"errors"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestFolder_Create(t *testing.T) {
asserts := assert.New(t)
folder := &Folder{
Name: "new folder",
}
// 不存在,插入成功
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
fid, err := folder.Create()
asserts.NoError(err)
asserts.Equal(uint(5), fid)
asserts.NoError(mock.ExpectationsWereMet())
// 插入失败
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
fid, err = folder.Create()
asserts.NoError(err)
asserts.Equal(uint(1), fid)
asserts.NoError(mock.ExpectationsWereMet())
// 存在,直接返回
mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5))
fid, err = folder.Create()
asserts.NoError(err)
asserts.Equal(uint(5), fid)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestFolder_GetChild(t *testing.T) {
asserts := assert.New(t)
folder := Folder{
Model: gorm.Model{ID: 5},
OwnerID: 1,
Name: "/",
}
// 目录存在
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(5, 1, "sub").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "sub"))
sub, err := folder.GetChild("sub")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(sub.Name, "sub")
asserts.Equal("/", sub.Position)
}
// 目录不存在
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(5, 1, "sub").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
sub, err := folder.GetChild("sub")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Equal(uint(0), sub.ID)
}
}
func TestFolder_GetChildFolder(t *testing.T) {
asserts := assert.New(t)
folder := &Folder{
Model: gorm.Model{
ID: 1,
},
Position: "/123",
Name: "456",
}
// 找不到
mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnError(errors.New("error"))
files, err := folder.GetChildFolder()
asserts.Error(err)
asserts.Len(files, 0)
asserts.NoError(mock.ExpectationsWereMet())
// 找到了
mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2))
files, err = folder.GetChildFolder()
asserts.NoError(err)
asserts.Len(files, 2)
asserts.Equal("/123/456", files[0].Position)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestGetRecursiveChildFolderSQLite(t *testing.T) {
conf.DatabaseConfig.Type = "sqlite"
asserts := assert.New(t)
// 测试目录结构
// 1
// 2 3
// 4 5 6
// 查询第一层
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "folder1"),
)
// 查询第二层
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).
AddRow(2, "folder2").
AddRow(3, "folder3"),
)
// 查询第三层
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}).
AddRow(4, "folder4").
AddRow(5, "folder5").
AddRow(6, "folder6"),
)
// 查询第四层
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 4, 5, 6).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name"}),
)
folders, err := GetRecursiveChildFolder([]uint{1}, 1, true)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(folders, 6)
}
func TestDeleteFolderByIDs(t *testing.T) {
asserts := assert.New(t)
// 出错
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := DeleteFolderByIDs([]uint{1, 2, 3})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("DELETE(.+)").
WillReturnResult(sqlmock.NewResult(0, 3))
mock.ExpectCommit()
err := DeleteFolderByIDs([]uint{1, 2, 3})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
}
func TestGetFoldersByIDs(t *testing.T) {
asserts := assert.New(t)
// 出错
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 1).
WillReturnError(errors.New("error"))
folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Len(folders, 0)
}
// 部分找到
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(folders, 1)
}
}
func TestFolder_MoveOrCopyFileTo(t *testing.T) {
asserts := assert.New(t)
// 当前目录
folder := Folder{
Model: gorm.Model{ID: 1},
OwnerID: 1,
Name: "test",
}
// 目标目录
dstFolder := Folder{
Model: gorm.Model{ID: 10},
Name: "dst",
}
// 复制文件
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(
1,
2,
3,
1,
1,
).WillReturnRows(
sqlmock.NewRows([]string{"id", "size", "upload_session_id"}).
AddRow(1, 10, nil).
AddRow(2, 20, nil).
AddRow(2, 20, &folder.Name),
)
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
storage, err := folder.MoveOrCopyFileTo(
[]uint{1, 2, 3},
&dstFolder,
true,
)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint64(30), storage)
}
// 复制文件, 检索文件出错
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(
1,
2,
1,
1,
).WillReturnError(errors.New("error"))
storage, err := folder.MoveOrCopyFileTo(
[]uint{1, 2},
&dstFolder,
true,
)
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint64(0), storage)
}
// 复制文件,第二个文件插入出错
{
mock.ExpectQuery("SELECT(.+)").
WithArgs(
1,
2,
1,
1,
).WillReturnRows(
sqlmock.NewRows([]string{"id", "size"}).
AddRow(1, 10).
AddRow(2, 20),
)
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
storage, err := folder.MoveOrCopyFileTo(
[]uint{1, 2},
&dstFolder,
true,
)
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint64(10), storage)
}
// 移动文件 成功
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1).
WillReturnResult(sqlmock.NewResult(1, 2))
mock.ExpectCommit()
storage, err := folder.MoveOrCopyFileTo(
[]uint{1, 2},
&dstFolder,
false,
)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(uint64(0), storage)
}
// 移动文件 出错
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1).
WillReturnError(errors.New("error"))
mock.ExpectRollback()
storage, err := folder.MoveOrCopyFileTo(
[]uint{1, 2},
&dstFolder,
false,
)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Equal(uint64(0), storage)
}
}
func TestFolder_CopyFolderTo(t *testing.T) {
conf.DatabaseConfig.Type = "mysql"
asserts := assert.New(t)
// 父目录
parFolder := Folder{
Model: gorm.Model{ID: 9},
OwnerID: 1,
}
// 目标目录
dstFolder := Folder{
Model: gorm.Model{ID: 10},
}
// 测试复制目录结构
// test(2)(5)
// 1(3)(6) 2.txt
// 3(4)(7) 4.txt 5.txt(上传中)
// 正常情况 成功
{
// GetRecursiveChildFolder
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
// 复制目录
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1))
mock.ExpectCommit()
// 查找子文件
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 4).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "folder_id", "size", "upload_session_id"}).
AddRow(1, "2.txt", 2, 10, nil).
AddRow(2, "3.txt", 3, 20, nil).
AddRow(3, "5.txt", 3, 20, &dstFolder.Name),
)
// 复制子文件
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
mock.ExpectCommit()
size, err := parFolder.CopyFolderTo(2, &dstFolder)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(uint64(30), size)
}
// 递归查询失败
{
// GetRecursiveChildFolder
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnError(errors.New("error"))
size, err := parFolder.CopyFolderTo(2, &dstFolder)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Equal(uint64(0), size)
}
// 父目录ID不存在
{
// GetRecursiveChildFolder
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 99))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
// 复制目录
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
size, err := parFolder.CopyFolderTo(2, &dstFolder)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Equal(uint64(0), size)
}
// 查询子文件失败
{
// GetRecursiveChildFolder
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
// 复制目录
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1))
mock.ExpectCommit()
// 查找子文件
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 4).
WillReturnError(errors.New("error"))
size, err := parFolder.CopyFolderTo(2, &dstFolder)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Equal(uint64(0), size)
}
// 复制文件 一个失败
{
// GetRecursiveChildFolder
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
// 复制目录
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1))
mock.ExpectCommit()
// 查找子文件
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2, 3, 4).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "folder_id", "size"}).
AddRow(1, "2.txt", 2, 10).
AddRow(2, "3.txt", 3, 20),
)
// 复制子文件
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
size, err := parFolder.CopyFolderTo(2, &dstFolder)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Equal(uint64(10), size)
}
}
func TestFolder_MoveOrCopyFolderTo_Move(t *testing.T) {
conf.DatabaseConfig.Type = "mysql"
asserts := assert.New(t)
// 父目录
parFolder := Folder{
Model: gorm.Model{ID: 9},
OwnerID: 1,
}
// 目标目录
dstFolder := Folder{
Model: gorm.Model{ID: 10},
OwnerID: 1,
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 9).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := parFolder.MoveFolderTo([]uint{1, 2}, &dstFolder)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// 移动自己到自己内部,失败
{
err := parFolder.MoveFolderTo([]uint{10, 2}, &dstFolder)
asserts.Error(err)
}
}
func TestFolder_FileInfoInterface(t *testing.T) {
asserts := assert.New(t)
folder := Folder{
Model: gorm.Model{
UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC),
},
Name: "test_name",
OwnerID: 0,
Position: "/test",
}
name := folder.GetName()
asserts.Equal("test_name", name)
size := folder.GetSize()
asserts.Equal(uint64(0), size)
asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), folder.ModTime())
asserts.True(folder.IsDir())
asserts.Equal("/test", folder.GetPosition())
}
func TestTraceRoot(t *testing.T) {
asserts := assert.New(t)
var parentId uint
parentId = 5
folder := Folder{
ParentID: &parentId,
OwnerID: 1,
Name: "test_name",
}
// 成功
{
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/"))
asserts.NoError(folder.TraceRoot())
asserts.Equal("/parent", folder.Position)
asserts.NoError(mock.ExpectationsWereMet())
}
// 出现错误
// 成功
{
mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
WillReturnError(errors.New("error"))
asserts.Error(folder.TraceRoot())
asserts.Equal("parent", folder.Position)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestFolder_Rename(t *testing.T) {
asserts := assert.New(t)
folder := Folder{
Model: gorm.Model{
ID: 1,
},
Name: "test_name",
OwnerID: 1,
Position: "/test",
}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
WithArgs("test_name_new", 1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := folder.Rename("test_name_new")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
// 出现错误
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
WithArgs("test_name_new", 1).
WillReturnError(errors.New("error"))
mock.ExpectRollback()
err := folder.Rename("test_name_new")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}

View file

@ -2,6 +2,7 @@ package model
import (
"encoding/json"
"github.com/jinzhu/gorm"
)
@ -29,11 +30,15 @@ type GroupOption struct {
DecompressSize uint64 `json:"decompress_size,omitempty"`
OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownload bool `json:"share_download,omitempty"`
ShareFree bool `json:"share_free,omitempty"`
Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
Relocate bool `json:"relocate,omitempty"` // 转移文件
SourceBatchSize int `json:"source_batch,omitempty"`
RedirectedSource bool `json:"redirected_source,omitempty"`
Aria2BatchSize int `json:"aria2_batch,omitempty"`
AvailableNodes []uint `json:"available_nodes,omitempty"`
SelectNode bool `json:"select_node,omitempty"`
AdvanceDelete bool `json:"advance_delete,omitempty"`
WebDAVProxy bool `json:"webdav_proxy,omitempty"`
}

View file

@ -1,77 +0,0 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"testing"
)
func TestGetGroupByID(t *testing.T) {
asserts := assert.New(t)
//找到用户组时
groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
AddRow(1, "管理员", "[1]")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
group, err := GetGroupByID(1)
asserts.NoError(err)
asserts.Equal(Group{
Model: gorm.Model{
ID: 1,
},
Name: "管理员",
Policies: "[1]",
PolicyList: []uint{1},
}, group)
//未找到用户时
mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found"))
group, err = GetGroupByID(1)
asserts.Error(err)
asserts.Equal(Group{}, group)
}
func TestGroup_AfterFind(t *testing.T) {
asserts := assert.New(t)
testCase := Group{
Model: gorm.Model{
ID: 1,
},
Name: "管理员",
Policies: "[1]",
}
err := testCase.AfterFind()
asserts.NoError(err)
asserts.Equal(testCase.PolicyList, []uint{1})
testCase.Policies = "[1,2,3,4,5]"
err = testCase.AfterFind()
asserts.NoError(err)
asserts.Equal(testCase.PolicyList, []uint{1, 2, 3, 4, 5})
testCase.Policies = "[1,2,3,4,5"
err = testCase.AfterFind()
asserts.Error(err)
testCase.Policies = "[]"
err = testCase.AfterFind()
asserts.NoError(err)
asserts.Equal(testCase.PolicyList, []uint{})
}
func TestGroup_BeforeSave(t *testing.T) {
asserts := assert.New(t)
group := Group{
PolicyList: []uint{1, 2, 3},
}
{
err := group.BeforeSave()
asserts.NoError(err)
asserts.Equal("[1,2,3]", group.Policies)
}
}

View file

@ -40,8 +40,8 @@ func migration() {
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
}
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{})
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{},
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Order{}, &Redeem{}, &Report{}, &Node{}, &SourceLink{})
// 创建初始存储策略
addDefaultPolicy()
@ -107,10 +107,13 @@ func addDefaultGroups() {
ArchiveDownload: true,
ArchiveTask: true,
ShareDownload: true,
ShareFree: true,
Aria2: true,
Relocate: true,
SourceBatchSize: 1000,
Aria2BatchSize: 50,
RedirectedSource: true,
SelectNode: true,
AdvanceDelete: true,
},
}

View file

@ -1,21 +0,0 @@
package model
import (
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestMigration(t *testing.T) {
asserts := assert.New(t)
conf.DatabaseConfig.Type = "sqlite"
DB, _ = gorm.Open("sqlite", ":memory:")
asserts.NotPanics(func() {
migration()
})
conf.DatabaseConfig.Type = "mysql"
DB = mockDB
}

View file

@ -1,64 +0,0 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestGetNodeByID(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetNodeByID(1)
a.NoError(err)
a.EqualValues(1, res.ID)
a.NoError(mock.ExpectationsWereMet())
}
func TestGetNodesByStatus(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive))
res, err := GetNodesByStatus(NodeActive)
a.NoError(err)
a.Len(res, 1)
a.EqualValues(NodeActive, res[0].Status)
a.NoError(mock.ExpectationsWereMet())
}
func TestNode_AfterFind(t *testing.T) {
a := assert.New(t)
node := &Node{}
// No aria2 options
{
a.NoError(node.AfterFind())
}
// with aria2 options
{
node.Aria2Options = `{"timeout":1}`
a.NoError(node.AfterFind())
a.Equal(1, node.Aria2OptionsSerialized.Timeout)
}
}
func TestNode_BeforeSave(t *testing.T) {
a := assert.New(t)
node := &Node{}
node.Aria2OptionsSerialized.Timeout = 1
a.NoError(node.BeforeSave())
a.Contains(node.Aria2Options, "1")
}
func TestNode_SetStatus(t *testing.T) {
a := assert.New(t)
node := &Node{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(node.SetStatus(NodeActive))
a.Equal(NodeActive, node.Status)
a.NoError(mock.ExpectationsWereMet())
}

59
models/order.go Executable file
View file

@ -0,0 +1,59 @@
package model
import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
)
const (
// PackOrderType 容量包订单
PackOrderType = iota
// GroupOrderType 用户组订单
GroupOrderType
// ScoreOrderType 积分充值订单
ScoreOrderType
)
const (
// OrderUnpaid 未支付
OrderUnpaid = iota
// OrderPaid 已支付
OrderPaid
// OrderCanceled 已取消
OrderCanceled
)
// Order 交易订单
type Order struct {
gorm.Model
UserID uint // 创建者ID
OrderNo string `gorm:"index:order_number"` // 商户自定义订单编号
Type int // 订单类型
Method string // 支付类型
ProductID int64 // 商品ID
Num int // 商品数量
Name string // 订单标题
Price int // 商品单价
Status int // 订单状态
}
// Create 创建订单记录
func (order *Order) Create() (uint, error) {
if err := DB.Create(order).Error; err != nil {
util.Log().Warning("Failed to insert order record: %s", err)
return 0, err
}
return order.ID, nil
}
// UpdateStatus 更新订单状态
func (order *Order) UpdateStatus(status int) {
DB.Model(order).Update("status", status)
}
// GetOrderByNo 根据商户订单号查询订单
func GetOrderByNo(id string) (*Order, error) {
var order Order
err := DB.Where("order_no = ?", id).First(&order).Error
return &order, err
}

View file

@ -3,14 +3,15 @@ package model
import (
"encoding/gob"
"encoding/json"
"github.com/gofrs/uuid"
"github.com/samber/lo"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gofrs/uuid"
"github.com/samber/lo"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
@ -73,6 +74,18 @@ type PolicyOption struct {
ThumbExts []string `json:"thumb_exts,omitempty"`
}
// thumbSuffix 支持缩略图处理的文件扩展名
var thumbSuffix = map[string][]string{
"local": {},
"qiniu": {".psd", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
"oss": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
"cos": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
"upyun": {".svg", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
"s3": {},
"remote": {},
"onedrive": {"*"},
}
func init() {
// 注册缓存用到的复杂结构
gob.Register(Policy{})
@ -179,6 +192,17 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string {
return fileRule
}
// IsThumbExist 给定文件名,返回此存储策略下是否可能存在缩略图
func (policy *Policy) IsThumbExist(name string) bool {
if list, ok := thumbSuffix[policy.Type]; ok {
if len(list) == 1 && list[0] == "*" {
return true
}
return util.ContainsString(list, strings.ToLower(filepath.Ext(name)))
}
return false
}
// IsDirectlyPreview 返回此策略下文件是否可以直接预览(不需要重定向)
func (policy *Policy) IsDirectlyPreview() bool {
return policy.Type == "local"

View file

@ -1,269 +0,0 @@
package model
import (
"encoding/json"
"strconv"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestGetPolicyByID(t *testing.T) {
asserts := assert.New(t)
cache.Deletes([]string{"22", "23"}, "policy_")
// 缓存未命中
{
rows := sqlmock.NewRows([]string{"name", "type", "options"}).
AddRow("默认存储策略", "local", "{\"od_redirect\":\"123\"}")
mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows)
policy, err := GetPolicyByID(uint(22))
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal("默认存储策略", policy.Name)
asserts.Equal("123", policy.OptionsSerialized.OauthRedirect)
rows = sqlmock.NewRows([]string{"name", "type", "options"})
mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows)
policy, err = GetPolicyByID(uint(23))
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
// 命中
{
policy, err := GetPolicyByID(uint(22))
asserts.NoError(err)
asserts.Equal("默认存储策略", policy.Name)
asserts.Equal("123", policy.OptionsSerialized.OauthRedirect)
}
}
func TestPolicy_BeforeSave(t *testing.T) {
asserts := assert.New(t)
testPolicy := Policy{
OptionsSerialized: PolicyOption{
OauthRedirect: "123",
},
}
expected, _ := json.Marshal(testPolicy.OptionsSerialized)
err := testPolicy.BeforeSave()
asserts.NoError(err)
asserts.Equal(string(expected), testPolicy.Options)
}
func TestPolicy_GeneratePath(t *testing.T) {
asserts := assert.New(t)
testPolicy := Policy{}
testPolicy.DirNameRule = "{randomkey16}"
asserts.Len(testPolicy.GeneratePath(1, "/"), 16)
testPolicy.DirNameRule = "{randomkey8}"
asserts.Len(testPolicy.GeneratePath(1, "/"), 8)
testPolicy.DirNameRule = "{timestamp}"
asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.FormatInt(time.Now().Unix(), 10))
testPolicy.DirNameRule = "{uid}"
asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.Itoa(int(1)))
testPolicy.DirNameRule = "{datetime}"
asserts.Len(testPolicy.GeneratePath(1, "/"), 14)
testPolicy.DirNameRule = "{date}"
asserts.Len(testPolicy.GeneratePath(1, "/"), 8)
testPolicy.DirNameRule = "123{date}ss{datetime}"
asserts.Len(testPolicy.GeneratePath(1, "/"), 27)
testPolicy.DirNameRule = "/1/{path}/456"
asserts.Condition(func() (success bool) {
res := testPolicy.GeneratePath(1, "/23")
return res == "/1/23/456" || res == "\\1\\23\\456"
})
}
func TestPolicy_GenerateFileName(t *testing.T) {
asserts := assert.New(t)
// 重命名关闭
{
testPolicy := Policy{
AutoRename: false,
}
testPolicy.FileNameRule = "{randomkey16}"
asserts.Equal("123.txt", testPolicy.GenerateFileName(1, "123.txt"))
testPolicy.Type = "oss"
asserts.Equal("origin", testPolicy.GenerateFileName(1, "origin"))
}
// 重命名开启
{
testPolicy := Policy{
AutoRename: true,
}
testPolicy.FileNameRule = "{randomkey16}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16)
testPolicy.FileNameRule = "{randomkey8}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8)
testPolicy.FileNameRule = "{timestamp}"
asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.FormatInt(time.Now().Unix(), 10))
testPolicy.FileNameRule = "{uid}"
asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.Itoa(int(1)))
testPolicy.FileNameRule = "{datetime}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 14)
testPolicy.FileNameRule = "{date}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8)
testPolicy.FileNameRule = "123{date}ss{datetime}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 27)
testPolicy.FileNameRule = "{originname_without_ext}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 3)
testPolicy.FileNameRule = "{originname_without_ext}_{randomkey8}{ext}"
asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16)
// 支持{originname}的策略
testPolicy.Type = "local"
testPolicy.FileNameRule = "123{originname}"
asserts.Equal("123123.txt", testPolicy.GenerateFileName(1, "123.txt"))
testPolicy.Type = "qiniu"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123.txt", testPolicy.GenerateFileName(1, "123.txt"))
testPolicy.Type = "oss"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
testPolicy.Type = "upyun"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
testPolicy.Type = "qiniu"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
testPolicy.Type = "local"
testPolicy.FileNameRule = "{uid}123{originname}"
asserts.Equal("1123", testPolicy.GenerateFileName(1, ""))
testPolicy.Type = "local"
testPolicy.FileNameRule = "{ext}123{uuid}"
asserts.Contains(testPolicy.GenerateFileName(1, "123.txt"), ".txt123")
}
}
func TestPolicy_IsDirectlyPreview(t *testing.T) {
asserts := assert.New(t)
policy := Policy{Type: "local"}
asserts.True(policy.IsDirectlyPreview())
policy.Type = "remote"
asserts.False(policy.IsDirectlyPreview())
}
func TestPolicy_ClearCache(t *testing.T) {
asserts := assert.New(t)
cache.Set("policy_202", 1, 0)
policy := Policy{Model: gorm.Model{ID: 202}}
policy.ClearCache()
_, ok := cache.Get("policy_202")
asserts.False(ok)
}
func TestPolicy_UpdateAccessKey(t *testing.T) {
asserts := assert.New(t)
policy := Policy{Model: gorm.Model{ID: 202}}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
policy.AccessKey = "123"
err := policy.SaveAndClearCache()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
func TestPolicy_Props(t *testing.T) {
asserts := assert.New(t)
policy := Policy{Type: "onedrive"}
policy.OptionsSerialized.PlaceholderWithSize = true
asserts.False(policy.IsThumbGenerateNeeded())
asserts.False(policy.IsTransitUpload(4))
asserts.False(policy.IsTransitUpload(5 * 1024 * 1024))
asserts.True(policy.CanStructureBeListed())
asserts.True(policy.IsUploadPlaceholderWithSize())
policy.Type = "local"
asserts.True(policy.IsThumbGenerateNeeded())
asserts.False(policy.CanStructureBeListed())
asserts.False(policy.IsUploadPlaceholderWithSize())
policy.Type = "remote"
asserts.True(policy.IsUploadPlaceholderWithSize())
}
func TestPolicy_UpdateAccessKeyAndClearCache(t *testing.T) {
a := assert.New(t)
cache.Set("policy_1331", Policy{}, 3600)
p := &Policy{}
p.ID = 1331
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs("ak", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.NoError(p.UpdateAccessKeyAndClearCache("ak"))
a.NoError(mock.ExpectationsWereMet())
_, ok := cache.Get("policy_1331")
a.False(ok)
}
func TestPolicy_CouldProxyThumb(t *testing.T) {
a := assert.New(t)
p := &Policy{Type: "local"}
// local policy
{
a.False(p.CouldProxyThumb())
}
// feature not enabled
{
p.Type = "remote"
cache.Set("setting_thumb_proxy_enabled", "0", 0)
a.False(p.CouldProxyThumb())
}
// list not contain current policy
{
p.ID = 2
cache.Set("setting_thumb_proxy_enabled", "1", 0)
cache.Set("setting_thumb_proxy_policy", "[1]", 0)
a.False(p.CouldProxyThumb())
}
// enabled
{
p.ID = 2
cache.Set("setting_thumb_proxy_enabled", "1", 0)
cache.Set("setting_thumb_proxy_policy", "[2]", 0)
a.True(p.CouldProxyThumb())
}
cache.Deletes([]string{"thumb_proxy_enabled", "thumb_proxy_policy"}, "setting_")
}

27
models/redeem.go Executable file
View file

@ -0,0 +1,27 @@
package model
import "github.com/jinzhu/gorm"
// Redeem 兑换码
type Redeem struct {
gorm.Model
Type int // 订单类型
ProductID int64 // 商品ID
Num int // 商品数量
Code string `gorm:"size:64,index:redeem_code"` // 兑换码
Used bool // 是否已被使用
}
// GetAvailableRedeem 根据code查找可用兑换码
func GetAvailableRedeem(code string) (*Redeem, error) {
redeem := &Redeem{}
result := DB.Where("code = ? and used = ?", code, false).First(redeem)
return redeem, result.Error
}
// Use 设定为已使用状态
func (redeem *Redeem) Use() {
DB.Model(redeem).Updates(map[string]interface{}{
"used": true,
})
}

21
models/report.go Executable file
View file

@ -0,0 +1,21 @@
package model
import (
"github.com/jinzhu/gorm"
)
// Report 举报模型
type Report struct {
gorm.Model
ShareID uint `gorm:"index:share_id"` // 对应分享ID
Reason int // 举报原因
Description string // 补充描述
// 关联模型
Share Share `gorm:"save_associations:false:false"`
}
// Create 创建举报
func (report *Report) Create() error {
return DB.Create(report).Error
}

View file

@ -5,5 +5,6 @@ import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
func Init() {
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
invoker.Register("OSSToPlus", UpgradeToPro(0))
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
}

View file

@ -1,39 +0,0 @@
package invoker
import (
"context"
"github.com/stretchr/testify/assert"
"testing"
)
type TestScript int
func (script TestScript) Run(ctx context.Context) {
}
func TestRunDBScript(t *testing.T) {
asserts := assert.New(t)
Register("test", TestScript(0))
// 不存在
{
asserts.Error(RunDBScript("else", context.Background()))
}
// 存在
{
asserts.NoError(RunDBScript("test", context.Background()))
}
}
func TestListPrefix(t *testing.T) {
asserts := assert.New(t)
Register("U1", TestScript(0))
Register("U2", TestScript(0))
Register("U3", TestScript(0))
Register("P1", TestScript(0))
res := ListPrefix("U")
asserts.Len(res, 3)
}

View file

@ -1,50 +0,0 @@
package scripts
import (
"context"
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestResetAdminPassword_Run(t *testing.T) {
asserts := assert.New(t)
script := ResetAdminPassword(0)
// 初始用户不存在
{
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}))
asserts.Panics(func() {
script.Run(context.Background())
})
asserts.NoError(mock.ExpectationsWereMet())
}
// 密码更新失败
{
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
asserts.Panics(func() {
script.Run(context.Background())
})
asserts.NoError(mock.ExpectationsWereMet())
}
// 成功
{
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.NotPanics(func() {
script.Run(context.Background())
})
asserts.NoError(mock.ExpectationsWereMet())
}
}

View file

@ -1,61 +0,0 @@
package scripts
import (
"context"
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
var mock sqlmock.Sqlmock
var mockDB *gorm.DB
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
mockDB = model.DB
defer db.Close()
m.Run()
}
func TestUserStorageCalibration_Run(t *testing.T) {
asserts := assert.New(t)
script := UserStorageCalibration(0)
// 容量异常
{
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(11))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
asserts.NoError(mock.ExpectationsWereMet())
}
// 容量正常
{
mock.ExpectQuery("SELECT(.+)users(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
asserts.NoError(mock.ExpectationsWereMet())
}
}

22
models/scripts/upgrade-pro.go Executable file
View file

@ -0,0 +1,22 @@
package scripts
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
)
type UpgradeToPro int
// Run 运行脚本从社区版升级至 Pro 版
func (script UpgradeToPro) Run(ctx context.Context) {
// folder.PolicyID 字段设为 0
model.DB.Model(model.Folder{}).UpdateColumn("policy_id", 0)
// shares.Score 字段设为0
model.DB.Model(model.Share{}).UpdateColumn("score", 0)
// user 表相关初始字段
model.DB.Model(model.User{}).Updates(map[string]interface{}{
"score": 0,
"previous_group_id": 0,
"open_id": "",
})
}

View file

@ -1,66 +0,0 @@
package scripts
import (
"context"
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestUpgradeTo340_Run(t *testing.T) {
a := assert.New(t)
script := UpgradeTo340(0)
// skip
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// node not found
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// success
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
AddRow("aria2_interval", "expected_aria2_interval").
AddRow("aria2_temp_path", "expected_aria2_temp_path").
AddRow("aria2_token", "expected_aria2_token").
AddRow("aria2_options", "{}"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// failed
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
AddRow("aria2_interval", "expected_aria2_interval").
AddRow("aria2_temp_path", "expected_aria2_temp_path").
AddRow("aria2_token", "expected_aria2_token").
AddRow("aria2_options", "{}"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
}

View file

@ -1,196 +0,0 @@
package model
import (
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
var mock sqlmock.Sqlmock
var mockDB *gorm.DB
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
DB, _ = gorm.Open("mysql", db)
mockDB = DB
defer db.Close()
m.Run()
}
func TestGetSettingByType(t *testing.T) {
cache.Store = cache.NewMemoStore()
asserts := assert.New(t)
//找到设置时
rows := sqlmock.NewRows([]string{"name", "value", "type"}).
AddRow("siteName", "Cloudreve", "basic").
AddRow("siteDes", "Something wonderful", "basic")
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings := GetSettingByType([]string{"basic"})
asserts.Equal(map[string]string{
"siteName": "Cloudreve",
"siteDes": "Something wonderful",
}, settings)
rows = sqlmock.NewRows([]string{"name", "value", "type"}).
AddRow("siteName", "Cloudreve", "basic").
AddRow("siteDes", "Something wonderful", "basic2")
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings = GetSettingByType([]string{"basic", "basic2"})
asserts.Equal(map[string]string{
"siteName": "Cloudreve",
"siteDes": "Something wonderful",
}, settings)
//找不到
rows = sqlmock.NewRows([]string{"name", "value", "type"})
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings = GetSettingByType([]string{"basic233"})
asserts.Equal(map[string]string{}, settings)
}
func TestGetSettingByNameWithDefault(t *testing.T) {
a := assert.New(t)
rows := sqlmock.NewRows([]string{"name", "value", "type"})
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings := GetSettingByNameWithDefault("123", "123321")
a.Equal("123321", settings)
}
func TestGetSettingByNames(t *testing.T) {
cache.Store = cache.NewMemoStore()
asserts := assert.New(t)
//找到设置时
rows := sqlmock.NewRows([]string{"name", "value", "type"}).
AddRow("siteName", "Cloudreve", "basic").
AddRow("siteDes", "Something wonderful", "basic")
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings := GetSettingByNames("siteName", "siteDes")
asserts.Equal(map[string]string{
"siteName": "Cloudreve",
"siteDes": "Something wonderful",
}, settings)
asserts.NoError(mock.ExpectationsWereMet())
//找到其中一个设置时
rows = sqlmock.NewRows([]string{"name", "value", "type"}).
AddRow("siteName2", "Cloudreve", "basic")
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings = GetSettingByNames("siteName2", "siteDes2333")
asserts.Equal(map[string]string{
"siteName2": "Cloudreve",
}, settings)
asserts.NoError(mock.ExpectationsWereMet())
//找不到设置时
rows = sqlmock.NewRows([]string{"name", "value", "type"})
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
settings = GetSettingByNames("siteName2333", "siteDes2333")
asserts.Equal(map[string]string{}, settings)
asserts.NoError(mock.ExpectationsWereMet())
// 一个设置命中缓存
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WithArgs("siteDes2").WillReturnRows(sqlmock.NewRows([]string{"name", "value", "type"}).
AddRow("siteDes2", "Cloudreve2", "basic"))
settings = GetSettingByNames("siteName", "siteDes2")
asserts.Equal(map[string]string{
"siteName": "Cloudreve",
"siteDes2": "Cloudreve2",
}, settings)
asserts.NoError(mock.ExpectationsWereMet())
}
// TestGetSettingByName 测试GetSettingByName
func TestGetSettingByName(t *testing.T) {
cache.Store = cache.NewMemoStore()
asserts := assert.New(t)
//找到设置时
rows := sqlmock.NewRows([]string{"name", "value", "type"}).
AddRow("siteName", "Cloudreve", "basic")
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
siteName := GetSettingByName("siteName")
asserts.Equal("Cloudreve", siteName)
asserts.NoError(mock.ExpectationsWereMet())
// 第二次查询应返回缓存内容
siteNameCache := GetSettingByName("siteName")
asserts.Equal("Cloudreve", siteNameCache)
asserts.NoError(mock.ExpectationsWereMet())
// 找不到设置
rows = sqlmock.NewRows([]string{"name", "value", "type"})
mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
siteName = GetSettingByName("siteName not exist")
asserts.Equal("", siteName)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestIsTrueVal(t *testing.T) {
asserts := assert.New(t)
asserts.True(IsTrueVal("1"))
asserts.True(IsTrueVal("true"))
asserts.False(IsTrueVal("0"))
asserts.False(IsTrueVal("false"))
}
func TestGetSiteURL(t *testing.T) {
asserts := assert.New(t)
// 正常
{
err := cache.Deletes([]string{"siteURL"}, "setting_")
asserts.NoError(err)
mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://drive.cloudreve.org"))
siteURL := GetSiteURL()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal("https://drive.cloudreve.org", siteURL.String())
}
// 失败 返回默认值
{
err := cache.Deletes([]string{"siteURL"}, "setting_")
asserts.NoError(err)
mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, ":][\\/\\]sdf"))
siteURL := GetSiteURL()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal("https://cloudreve.org", siteURL.String())
}
}
func TestGetIntSetting(t *testing.T) {
asserts := assert.New(t)
// 正常
{
cache.Set("setting_TestGetIntSetting", "10", 0)
res := GetIntSetting("TestGetIntSetting", 20)
asserts.Equal(10, res)
}
// 使用默认值
{
res := GetIntSetting("TestGetIntSetting_2", 20)
asserts.Equal(20, res)
}
}

View file

@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
"math"
"strings"
"time"
@ -13,6 +14,10 @@ import (
"github.com/jinzhu/gorm"
)
var (
ErrInsufficientCredit = errors.New("积分不足")
)
// Share 分享模型
type Share struct {
gorm.Model
@ -24,6 +29,7 @@ type Share struct {
Downloads int // 下载数
RemainDownloads int // 剩余下载配额,负值标识无限制
Expires *time.Time // 过期时间,空值表示无过期时间
Score int // 每人次下载扣除积分
PreviewEnabled bool // 是否允许直接预览
SourceName string `gorm:"index:source"` // 用于搜索的字段
@ -135,6 +141,12 @@ func (share *Share) CanBeDownloadBy(user *User) error {
}
return errors.New("your group has no permission to download")
}
// 需要积分但未登录
if share.Score > 0 && user.IsAnonymous() {
return errors.New("you must login to download")
}
return nil
}
@ -149,9 +161,12 @@ func (share *Share) WasDownloadedBy(user *User, c *gin.Context) (exist bool) {
return exist
}
// DownloadBy 增加下载次数,匿名用户不会缓存
// DownloadBy 增加下载次数、检查积分等,匿名用户不会缓存
func (share *Share) DownloadBy(user *User, c *gin.Context) error {
if !share.WasDownloadedBy(user, c) {
if err := share.Purchase(user); err != nil {
return err
}
share.Downloaded()
if !user.IsAnonymous() {
cache.Set(fmt.Sprintf("share_%d_%d", share.ID, user.ID), true,
@ -163,6 +178,25 @@ func (share *Share) DownloadBy(user *User, c *gin.Context) error {
return nil
}
// Purchase 使用积分购买分享
func (share *Share) Purchase(user *User) error {
// 不需要付积分
if share.Score == 0 || user.Group.OptionsSerialized.ShareFree || user.ID == share.UserID {
return nil
}
ok := user.PayScore(share.Score)
if !ok {
return ErrInsufficientCredit
}
scoreRate := GetIntSetting("share_score_rate", 100)
gainedScore := int(math.Ceil(float64(share.Score*scoreRate) / 100))
share.Creator().AddScore(gainedScore)
return nil
}
// Viewed 增加访问次数
func (share *Share) Viewed() {
share.Views++

View file

@ -1,321 +0,0 @@
package model
import (
"errors"
"net/http/httptest"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestShare_Create(t *testing.T) {
asserts := assert.New(t)
share := Share{UserID: 1}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
id, err := share.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(2, id)
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
id, err := share.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualValues(0, id)
}
}
func TestGetShareByHashID(t *testing.T) {
asserts := assert.New(t)
conf.SystemConfig.HashIDSalt = ""
// 成功
{
mock.ExpectQuery("SELECT(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res := GetShareByHashID("x9T4")
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotNil(res)
}
// 查询失败
{
mock.ExpectQuery("SELECT(.+)").
WillReturnError(errors.New("error"))
res := GetShareByHashID("x9T4")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Nil(res)
}
// ID解码失败
{
res := GetShareByHashID("empty")
asserts.Nil(res)
}
}
func TestShare_IsAvailable(t *testing.T) {
asserts := assert.New(t)
// 下载剩余次数为0
{
share := Share{}
asserts.False(share.IsAvailable())
}
// 时效过期
{
expires := time.Unix(10, 10)
share := Share{
RemainDownloads: -1,
Expires: &expires,
}
asserts.False(share.IsAvailable())
}
// 源对象为目录,但不存在
{
share := Share{
RemainDownloads: -1,
SourceID: 2,
IsDir: true,
}
mock.ExpectQuery("SELECT(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id"}))
asserts.False(share.IsAvailable())
asserts.NoError(mock.ExpectationsWereMet())
}
// 源对象为目录,存在
{
share := Share{
RemainDownloads: -1,
SourceID: 2,
IsDir: false,
}
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(13))
asserts.True(share.IsAvailable())
asserts.NoError(mock.ExpectationsWereMet())
}
// 用户被封禁
{
share := Share{
RemainDownloads: -1,
SourceID: 2,
IsDir: true,
User: User{Status: Baned},
}
asserts.False(share.IsAvailable())
}
}
func TestShare_GetCreator(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
share := Share{UserID: 1}
res := share.Creator()
asserts.NoError(mock.ExpectationsWereMet())
asserts.EqualValues(1, res.ID)
}
func TestShare_Source(t *testing.T) {
asserts := assert.New(t)
// 目录
{
share := Share{IsDir: true, SourceID: 3}
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
asserts.EqualValues(3, share.Source().(*Folder).ID)
asserts.NoError(mock.ExpectationsWereMet())
}
// 文件
{
share := Share{IsDir: false, SourceID: 3}
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
asserts.EqualValues(3, share.Source().(*File).ID)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestShare_CanBeDownloadBy(t *testing.T) {
asserts := assert.New(t)
share := Share{}
// 未登录,无权
{
user := &User{
Group: Group{
OptionsSerialized: GroupOption{
ShareDownload: false,
},
},
}
asserts.Error(share.CanBeDownloadBy(user))
}
// 已登录,无权
{
user := &User{
Model: gorm.Model{ID: 1},
Group: Group{
OptionsSerialized: GroupOption{
ShareDownload: false,
},
},
}
asserts.Error(share.CanBeDownloadBy(user))
}
// 成功
{
user := &User{
Model: gorm.Model{ID: 1},
Group: Group{
OptionsSerialized: GroupOption{
ShareDownload: true,
},
},
}
asserts.NoError(share.CanBeDownloadBy(user))
}
}
func TestShare_WasDownloadedBy(t *testing.T) {
asserts := assert.New(t)
share := Share{
Model: gorm.Model{ID: 1},
}
// 已登录,已下载
{
user := User{
Model: gorm.Model{
ID: 1,
},
}
r := httptest.NewRecorder()
c, _ := gin.CreateTestContext(r)
cache.Set("share_1_1", true, 0)
asserts.True(share.WasDownloadedBy(&user, c))
}
}
func TestShare_DownloadBy(t *testing.T) {
asserts := assert.New(t)
share := Share{
Model: gorm.Model{ID: 1},
}
user := User{
Model: gorm.Model{
ID: 1,
},
}
cache.Deletes([]string{"1_1"}, "share_")
r := httptest.NewRecorder()
c, _ := gin.CreateTestContext(r)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := share.DownloadBy(&user, c)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
_, ok := cache.Get("share_1_1")
asserts.True(ok)
}
func TestShare_Viewed(t *testing.T) {
asserts := assert.New(t)
share := Share{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
share.Viewed()
asserts.NoError(mock.ExpectationsWereMet())
asserts.EqualValues(1, share.Views)
}
func TestShare_UpdateAndDelete(t *testing.T) {
asserts := assert.New(t)
share := Share{}
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := share.Update(map[string]interface{}{"id": 1})
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := share.Delete()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := DeleteShareBySourceIDs([]uint{1}, true)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
}
func TestListShares(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(2))
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2))
res, total := ListShares(1, 1, 10, "desc", true)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(res, 2)
asserts.Equal(2, total)
}
func TestSearchShares(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)").
WithArgs("", sqlmock.AnyArg(), "%1%2%").
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, total := SearchShares(1, 10, "id", "1 2")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(res, 1)
asserts.Equal(1, total)
}

View file

@ -1,52 +0,0 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestSourceLink_Link(t *testing.T) {
a := assert.New(t)
s := &SourceLink{}
s.ID = 1
// 失败
{
s.File.Name = string([]byte{0x7f})
res, err := s.Link()
a.Error(err)
a.Empty(res)
}
// 成功
{
s.File.Name = "filename"
res, err := s.Link()
a.NoError(err)
a.Contains(res, s.Name)
}
}
func TestGetSourceLinkByID(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
res, err := GetSourceLinkByID(1)
a.NoError(err)
a.NotNil(res)
a.EqualValues(2, res.File.ID)
a.NoError(mock.ExpectationsWereMet())
}
func TestSourceLink_Downloaded(t *testing.T) {
a := assert.New(t)
s := &SourceLink{}
s.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
s.Downloaded()
a.NoError(mock.ExpectationsWereMet())
}

91
models/storage_pack.go Executable file
View file

@ -0,0 +1,91 @@
package model
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
"strconv"
"time"
)
// StoragePack 容量包模型
type StoragePack struct {
// 表字段
gorm.Model
Name string
UserID uint
ActiveTime *time.Time
ExpiredTime *time.Time `gorm:"index:expired"`
Size uint64
}
// Create 创建容量包
func (pack *StoragePack) Create() (uint, error) {
if err := DB.Create(pack).Error; err != nil {
util.Log().Warning("Failed to insert storage pack record: %s", err)
return 0, err
}
return pack.ID, nil
}
// GetAvailablePackSize 返回给定用户当前可用的容量包总容量
func (user *User) GetAvailablePackSize() uint64 {
var (
total uint64
firstExpire *time.Time
timeNow = time.Now()
ttl int64
)
// 尝试从缓存中读取
cacheKey := "pack_size_" + strconv.FormatUint(uint64(user.ID), 10)
if total, ok := cache.Get(cacheKey); ok {
return total.(uint64)
}
// 查找所有有效容量包
packs := user.GetAvailableStoragePacks()
// 计算总容量, 并找到其中最早的过期时间
for _, v := range packs {
total += v.Size
if firstExpire == nil {
firstExpire = v.ExpiredTime
continue
}
if v.ExpiredTime != nil && firstExpire.After(*v.ExpiredTime) {
firstExpire = v.ExpiredTime
}
}
// 用最早的过期时间计算缓存TTL并写入缓存
if firstExpire != nil {
ttl = firstExpire.Unix() - timeNow.Unix()
if ttl > 0 {
_ = cache.Set(cacheKey, total, int(ttl))
}
}
return total
}
// GetAvailableStoragePacks 返回用户可用的容量包
func (user *User) GetAvailableStoragePacks() []StoragePack {
var packs []StoragePack
timeNow := time.Now()
// 查找所有有效容量包
DB.Where("expired_time > ? AND user_id = ?", timeNow, user.ID).Find(&packs)
return packs
}
// GetExpiredStoragePack 获取已过期的容量包
func GetExpiredStoragePack() []StoragePack {
var packs []StoragePack
DB.Where("expired_time < ?", time.Now()).Find(&packs)
return packs
}
// Delete 删除容量包
func (pack *StoragePack) Delete() error {
return DB.Delete(&pack).Error
}

View file

@ -1,63 +0,0 @@
package model
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestTag_Create(t *testing.T) {
asserts := assert.New(t)
tag := Tag{}
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
id, err := tag.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(1, id)
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
id, err := tag.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualValues(0, id)
}
}
func TestDeleteTagByID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := DeleteTagByID(1, 2)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
}
func TestGetTagsByUID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetTagsByUID(1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Len(res, 1)
}
func TestGetTagsByID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag"))
res, err := GetTagsByID(1, 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues("tag", res.Name)
}

View file

@ -1,104 +0,0 @@
package model
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
func TestTask_Create(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task := Task{Props: "1"}
id, err := task.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(1, id)
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
task := Task{Props: "1"}
id, err := task.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualValues(0, id)
}
}
func TestTask_SetError(t *testing.T) {
asserts := assert.New(t)
task := Task{
Model: gorm.Model{ID: 1},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.NoError(task.SetError("error"))
asserts.NoError(mock.ExpectationsWereMet())
}
func TestTask_SetStatus(t *testing.T) {
asserts := assert.New(t)
task := Task{
Model: gorm.Model{ID: 1},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.NoError(task.SetStatus(1))
asserts.NoError(mock.ExpectationsWereMet())
}
func TestTask_SetProgress(t *testing.T) {
asserts := assert.New(t)
task := Task{
Model: gorm.Model{ID: 1},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.NoError(task.SetProgress(1))
asserts.NoError(mock.ExpectationsWereMet())
}
func TestGetTasksByID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res, err := GetTasksByID(1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(1, res.ID)
}
func TestListTasks(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5))
res, total := ListTasks(1, 1, 10, "")
asserts.NoError(mock.ExpectationsWereMet())
asserts.EqualValues(5, total)
asserts.Len(res, 1)
}
func TestGetTasksByStatus(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").
WithArgs(1, 2).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
res := GetTasksByStatus(1, 2)
a.NoError(mock.ExpectationsWereMet())
a.Len(res, 1)
}

View file

@ -7,6 +7,7 @@ import (
"encoding/hex"
"encoding/json"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
@ -28,20 +29,25 @@ const (
type User struct {
// 表字段
gorm.Model
Email string `gorm:"type:varchar(100);unique_index"`
Nick string `gorm:"size:50"`
Password string `json:"-"`
Status int
GroupID uint
Storage uint64
TwoFactor string
Avatar string
Options string `json:"-" gorm:"size:4294967295"`
Authn string `gorm:"size:4294967295"`
Email string `gorm:"type:varchar(100);unique_index"`
Nick string `gorm:"size:50"`
Password string `json:"-"`
Status int
GroupID uint
Storage uint64
OpenID string
TwoFactor string
Avatar string
Options string `json:"-" gorm:"size:4294967295"`
Authn string `gorm:"size:4294967295"`
Score int
PreviousGroupID uint // 初始用户组
GroupExpires *time.Time // 用户组过期日期
NotifyDate *time.Time // 通知超出配额时的日期
Phone string
// 关联模型
Group Group `gorm:"save_associations:false:false"`
Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"`
Group Group `gorm:"save_associations:false:false"`
// 数据库忽略字段
OptionsSerialized UserOption `gorm:"-"`
@ -53,8 +59,9 @@ func init() {
// UserOption 用户个性化配置字段
type UserOption struct {
ProfileOff bool `json:"profile_off,omitempty"`
PreferredTheme string `json:"preferred_theme,omitempty"`
ProfileOff bool `json:"profile_off,omitempty"`
PreferredPolicy uint `json:"preferred_policy,omitempty"`
PreferredTheme string `json:"preferred_theme,omitempty"`
}
// Root 获取用户的根目录
@ -99,6 +106,25 @@ func (user *User) ChangeStorage(tx *gorm.DB, operator string, size uint64) error
return tx.Model(user).Update("storage", gorm.Expr("storage "+operator+" ?", size)).Error
}
// PayScore 扣除积分,返回是否成功
func (user *User) PayScore(score int) bool {
if score == 0 {
return true
}
if score <= user.Score {
user.Score -= score
DB.Model(user).Update("score", gorm.Expr("score - ?", score))
return true
}
return false
}
// AddScore 增加积分
func (user *User) AddScore(score int) {
user.Score += score
DB.Model(user).Update("score", gorm.Expr("score + ?", score))
}
// IncreaseStorageWithoutCheck 忽略可用容量,增加用户已用容量
func (user *User) IncreaseStorageWithoutCheck(size uint64) {
if size == 0 {
@ -111,19 +137,58 @@ func (user *User) IncreaseStorageWithoutCheck(size uint64) {
// GetRemainingCapacity 获取剩余配额
func (user *User) GetRemainingCapacity() uint64 {
total := user.Group.MaxStorage
total := user.Group.MaxStorage + user.GetAvailablePackSize()
if total <= user.Storage {
return 0
}
return total - user.Storage
}
// GetPolicyID 获取用户当前的存储策略ID
func (user *User) GetPolicyID(prefer uint) uint {
if len(user.Group.PolicyList) > 0 {
return user.Group.PolicyList[0]
// GetPolicyID 获取给定目录的存储策略, 如果为 nil 则使用默认
func (user *User) GetPolicyID(folder *Folder) *Policy {
if user.IsAnonymous() {
return &Policy{Type: "anonymous"}
}
return 0
defaultPolicy := uint(1)
if len(user.Group.PolicyList) > 0 {
defaultPolicy = user.Group.PolicyList[0]
}
if folder != nil {
prefer := folder.PolicyID
if prefer == 0 && folder.InheritPolicyID > 0 {
prefer = folder.InheritPolicyID
}
if prefer > 0 && util.ContainsUint(user.Group.PolicyList, prefer) {
defaultPolicy = prefer
}
}
p, _ := GetPolicyByID(defaultPolicy)
return &p
}
// GetPolicyByPreference 在可用存储策略中优先获取 preference
func (user *User) GetPolicyByPreference(preference uint) *Policy {
if user.IsAnonymous() {
return &Policy{Type: "anonymous"}
}
defaultPolicy := uint(1)
if len(user.Group.PolicyList) > 0 {
defaultPolicy = user.Group.PolicyList[0]
}
if preference != 0 {
if util.ContainsUint(user.Group.PolicyList, preference) {
defaultPolicy = preference
}
}
p, _ := GetPolicyByID(defaultPolicy)
return &p
}
// GetUserByID 用ID获取用户
@ -183,6 +248,27 @@ func (user *User) AfterCreate(tx *gorm.DB) (err error) {
OwnerID: user.ID,
}
tx.Create(defaultFolder)
// 创建用户初始文件记录
initialFiles := GetSettingByNameFromTx(tx, "initial_files")
if initialFiles != "" {
initialFileIDs := make([]uint, 0)
if err := json.Unmarshal([]byte(initialFiles), &initialFileIDs); err != nil {
return err
}
if files, err := GetFilesByIDsFromTX(tx, initialFileIDs, 0); err == nil {
for _, file := range files {
file.ID = 0
file.UserID = user.ID
file.FolderID = defaultFolder.ID
user.Storage += file.Size
tx.Create(&file)
}
tx.Save(user)
}
}
return err
}
@ -193,12 +279,10 @@ func (user *User) AfterFind() (err error) {
err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized)
}
// 预加载存储策略
user.Policy, _ = GetPolicyByID(user.GetPolicyID(0))
return err
}
//SerializeOptions 将序列后的Option写入到数据库字段
// SerializeOptions 将序列后的Option写入到数据库字段
func (user *User) SerializeOptions() (err error) {
optionsValue, err := json.Marshal(&user.OptionsSerialized)
user.Options = string(optionsValue)
@ -261,7 +345,6 @@ func (user *User) SetPassword(password string) error {
// NewAnonymousUser 返回一个匿名用户
func NewAnonymousUser() *User {
user := User{}
user.Policy.Type = "anonymous"
user.Group, _ = GetGroupByID(3)
return &user
}
@ -271,6 +354,20 @@ func (user *User) IsAnonymous() bool {
return user.ID == 0
}
// Notified 更新用户容量超额通知日期
func (user *User) Notified() {
if user.NotifyDate == nil {
timeNow := time.Now()
user.NotifyDate = &timeNow
DB.Model(&user).Update("notify_date", user.NotifyDate)
}
}
// ClearNotified 清除用户通知标记
func (user *User) ClearNotified() {
DB.Model(&user).Update("notify_date", nil)
}
// SetStatus 设定用户状态
func (user *User) SetStatus(status int) {
DB.Model(&user).Update("status", status)
@ -288,3 +385,45 @@ func (user *User) UpdateOptions() error {
}
return user.Update(map[string]interface{}{"options": user.Options})
}
// GetGroupExpiredUsers 获取用户组过期的用户
func GetGroupExpiredUsers() []User {
var users []User
DB.Where("group_expires < ? and previous_group_id <> 0", time.Now()).Find(&users)
return users
}
// GetTolerantExpiredUser 获取超过宽容期的用户
func GetTolerantExpiredUser() []User {
var users []User
DB.Set("gorm:auto_preload", true).Where("notify_date < ?", time.Now().Add(
time.Duration(-GetIntSetting("ban_time", 10))*time.Second),
).Find(&users)
return users
}
// GroupFallback 回退到初始用户组
func (user *User) GroupFallback() {
if user.GroupExpires != nil && user.PreviousGroupID != 0 {
user.Group.ID = user.PreviousGroupID
DB.Model(&user).Updates(map[string]interface{}{
"group_expires": nil,
"previous_group_id": 0,
"group_id": user.PreviousGroupID,
})
}
}
// UpgradeGroup 升级用户组
func (user *User) UpgradeGroup(id uint, expires *time.Time) error {
user.Group.ID = id
previousGroupID := user.GroupID
if user.PreviousGroupID != 0 && user.GroupID == id {
previousGroupID = user.PreviousGroupID
}
return DB.Model(&user).Updates(map[string]interface{}{
"group_expires": expires,
"previous_group_id": previousGroupID,
"group_id": id,
}).Error
}

View file

@ -1,100 +0,0 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/duo-labs/webauthn/webauthn"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
func TestUser_RegisterAuthn(t *testing.T) {
asserts := assert.New(t)
credential := webauthn.Credential{}
user := User{
Model: gorm.Model{ID: 1},
}
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
user.RegisterAuthn(&credential)
asserts.NoError(mock.ExpectationsWereMet())
}
}
func TestUser_WebAuthnCredentials(t *testing.T) {
asserts := assert.New(t)
user := User{
Model: gorm.Model{ID: 1},
Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`,
}
{
credentials := user.WebAuthnCredentials()
asserts.Len(credentials, 1)
}
}
func TestUser_WebAuthnDisplayName(t *testing.T) {
asserts := assert.New(t)
user := User{
Model: gorm.Model{ID: 1},
Nick: "123",
}
{
nick := user.WebAuthnDisplayName()
asserts.Equal("123", nick)
}
}
func TestUser_WebAuthnIcon(t *testing.T) {
asserts := assert.New(t)
user := User{
Model: gorm.Model{ID: 1},
}
{
icon := user.WebAuthnIcon()
asserts.NotEmpty(icon)
}
}
func TestUser_WebAuthnID(t *testing.T) {
asserts := assert.New(t)
user := User{
Model: gorm.Model{ID: 1},
}
{
id := user.WebAuthnID()
asserts.Len(id, 8)
}
}
func TestUser_WebAuthnName(t *testing.T) {
asserts := assert.New(t)
user := User{
Model: gorm.Model{ID: 1},
Email: "abslant@foxmail.com",
}
{
name := user.WebAuthnName()
asserts.Equal("abslant@foxmail.com", name)
}
}
func TestUser_RemoveAuthn(t *testing.T) {
asserts := assert.New(t)
user := User{
Model: gorm.Model{ID: 1},
Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`,
}
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
user.RemoveAuthn("123")
asserts.NoError(mock.ExpectationsWereMet())
}
}

View file

@ -1,438 +0,0 @@
package model
import (
"encoding/json"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestGetUserByID(t *testing.T) {
asserts := assert.New(t)
cache.Deletes([]string{"1"}, "policy_")
//找到用户时
userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}).
AddRow(1, nil, "admin@cloudreve.org", "{}", 1)
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
AddRow(1, "管理员", "[1]")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
policyRows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "默认存储策略")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
user, err := GetUserByID(1)
asserts.NoError(err)
asserts.Equal(User{
Model: gorm.Model{
ID: 1,
DeletedAt: nil,
},
Email: "admin@cloudreve.org",
Options: "{}",
GroupID: 1,
Group: Group{
Model: gorm.Model{
ID: 1,
},
Name: "管理员",
Policies: "[1]",
PolicyList: []uint{1},
},
Policy: Policy{
Model: gorm.Model{
ID: 1,
},
OptionsSerialized: PolicyOption{
FileType: []string{},
},
Name: "默认存储策略",
},
}, user)
//未找到用户时
mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found"))
user, err = GetUserByID(1)
asserts.Error(err)
asserts.Equal(User{}, user)
}
func TestGetActiveUserByID(t *testing.T) {
asserts := assert.New(t)
cache.Deletes([]string{"1"}, "policy_")
//找到用户时
userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}).
AddRow(1, nil, "admin@cloudreve.org", "{}", 1)
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
AddRow(1, "管理员", "[1]")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
policyRows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "默认存储策略")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
user, err := GetActiveUserByID(1)
asserts.NoError(err)
asserts.Equal(User{
Model: gorm.Model{
ID: 1,
DeletedAt: nil,
},
Email: "admin@cloudreve.org",
Options: "{}",
GroupID: 1,
Group: Group{
Model: gorm.Model{
ID: 1,
},
Name: "管理员",
Policies: "[1]",
PolicyList: []uint{1},
},
Policy: Policy{
Model: gorm.Model{
ID: 1,
},
OptionsSerialized: PolicyOption{
FileType: []string{},
},
Name: "默认存储策略",
},
}, user)
//未找到用户时
mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found"))
user, err = GetActiveUserByID(1)
asserts.Error(err)
asserts.Equal(User{}, user)
}
func TestUser_SetPassword(t *testing.T) {
asserts := assert.New(t)
user := User{}
err := user.SetPassword("Cause Sega does what nintendon't")
asserts.NoError(err)
asserts.NotEmpty(user.Password)
}
func TestUser_CheckPassword(t *testing.T) {
asserts := assert.New(t)
user := User{}
err := user.SetPassword("Cause Sega does what nintendon't")
asserts.NoError(err)
//密码正确
res, err := user.CheckPassword("Cause Sega does what nintendon't")
asserts.NoError(err)
asserts.True(res)
//密码错误
res, err = user.CheckPassword("Cause Sega does what Nintendon't")
asserts.NoError(err)
asserts.False(res)
//密码字段为空
user = User{}
res, err = user.CheckPassword("Cause Sega does what nintendon't")
asserts.Error(err)
asserts.False(res)
// 未知密码类型
user = User{}
user.Password = "1:2:3"
res, err = user.CheckPassword("Cause Sega does what nintendon't")
asserts.Error(err)
asserts.False(res)
// V2密码错误
user = User{}
user.Password = "md5:2:3"
res, err = user.CheckPassword("Cause Sega does what nintendon't")
asserts.NoError(err)
asserts.False(res)
// V2密码正确
user = User{}
user.Password = "md5:d8446059f8846a2c111a7f53515665fb:sdshare"
res, err = user.CheckPassword("admin")
asserts.NoError(err)
asserts.True(res)
}
func TestNewUser(t *testing.T) {
asserts := assert.New(t)
newUser := NewUser()
asserts.IsType(User{}, newUser)
asserts.Empty(newUser.Avatar)
}
func TestUser_AfterFind(t *testing.T) {
asserts := assert.New(t)
cache.Deletes([]string{"0"}, "policy_")
policyRows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(144, "默认存储策略")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
newUser := NewUser()
err := newUser.AfterFind()
err = newUser.BeforeSave()
expected := UserOption{}
err = json.Unmarshal([]byte(newUser.Options), &expected)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(expected, newUser.OptionsSerialized)
asserts.Equal("默认存储策略", newUser.Policy.Name)
}
func TestUser_BeforeSave(t *testing.T) {
asserts := assert.New(t)
newUser := NewUser()
err := newUser.BeforeSave()
expected, err := json.Marshal(newUser.OptionsSerialized)
asserts.NoError(err)
asserts.Equal(string(expected), newUser.Options)
}
func TestUser_GetPolicyID(t *testing.T) {
asserts := assert.New(t)
newUser := NewUser()
newUser.Group.PolicyList = []uint{1}
asserts.EqualValues(1, newUser.GetPolicyID(0))
newUser.Group.PolicyList = nil
asserts.EqualValues(0, newUser.GetPolicyID(0))
newUser.Group.PolicyList = []uint{}
asserts.EqualValues(0, newUser.GetPolicyID(0))
}
func TestUser_GetRemainingCapacity(t *testing.T) {
asserts := assert.New(t)
newUser := NewUser()
cache.Set("pack_size_0", uint64(0), 0)
newUser.Group.MaxStorage = 100
asserts.Equal(uint64(100), newUser.GetRemainingCapacity())
newUser.Group.MaxStorage = 100
newUser.Storage = 1
asserts.Equal(uint64(99), newUser.GetRemainingCapacity())
newUser.Group.MaxStorage = 100
newUser.Storage = 100
asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
newUser.Group.MaxStorage = 100
newUser.Storage = 200
asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
}
func TestUser_DeductionCapacity(t *testing.T) {
asserts := assert.New(t)
cache.Deletes([]string{"1"}, "policy_")
userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}).
AddRow(1, nil, 0, "{}", 1)
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
AddRow(1, "管理员", "[1]")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
policyRows := sqlmock.NewRows([]string{"id", "name"}).
AddRow(1, "默认存储策略")
mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
newUser, err := GetUserByID(1)
newUser.Group.MaxStorage = 100
cache.Set("pack_size_1", uint64(0), 0)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(false, newUser.IncreaseStorage(101))
asserts.Equal(uint64(0), newUser.Storage)
asserts.Equal(true, newUser.IncreaseStorage(1))
asserts.Equal(uint64(1), newUser.Storage)
asserts.Equal(true, newUser.IncreaseStorage(99))
asserts.Equal(uint64(100), newUser.Storage)
asserts.Equal(false, newUser.IncreaseStorage(1))
asserts.Equal(uint64(100), newUser.Storage)
asserts.True(newUser.IncreaseStorage(0))
}
func TestUser_DeductionStorage(t *testing.T) {
asserts := assert.New(t)
// 减少零
{
user := User{Storage: 1}
asserts.True(user.DeductionStorage(0))
asserts.Equal(uint64(1), user.Storage)
}
// 正常
{
user := User{
Model: gorm.Model{ID: 1},
Storage: 10,
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs(5, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.True(user.DeductionStorage(5))
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint64(5), user.Storage)
}
// 减少的超出可用的
{
user := User{
Model: gorm.Model{ID: 1},
Storage: 10,
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs(0, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.False(user.DeductionStorage(20))
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint64(0), user.Storage)
}
}
func TestUser_IncreaseStorageWithoutCheck(t *testing.T) {
asserts := assert.New(t)
// 增加零
{
user := User{}
user.IncreaseStorageWithoutCheck(0)
asserts.Equal(uint64(0), user.Storage)
}
// 减少零
{
user := User{
Model: gorm.Model{ID: 1},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WithArgs(10, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
user.IncreaseStorageWithoutCheck(10)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(uint64(10), user.Storage)
}
}
func TestGetActiveUserByEmail(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
_, err := GetActiveUserByEmail("abslant@foxmail.com")
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestGetUserByEmail(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
_, err := GetUserByEmail("abslant@foxmail.com")
asserts.Error(err)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestUser_AfterCreate(t *testing.T) {
asserts := assert.New(t)
user := User{Model: gorm.Model{ID: 1}}
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := user.AfterCreate(DB)
asserts.NoError(err)
asserts.NoError(mock.ExpectationsWereMet())
}
func TestUser_Root(t *testing.T) {
asserts := assert.New(t)
user := User{Model: gorm.Model{ID: 1}}
// 根目录存在
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "根目录"))
root, err := user.Root()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal("根目录", root.Name)
}
// 根目录不存在
{
mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
_, err := user.Root()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}
func TestNewAnonymousUser(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
user := NewAnonymousUser()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NotNil(user)
asserts.EqualValues(3, user.Group.ID)
}
func TestUser_IsAnonymous(t *testing.T) {
asserts := assert.New(t)
user := User{}
asserts.True(user.IsAnonymous())
user.ID = 1
asserts.False(user.IsAnonymous())
}
func TestUser_SetStatus(t *testing.T) {
asserts := assert.New(t)
user := User{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
user.SetStatus(Baned)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Equal(Baned, user.Status)
}
func TestUser_UpdateOptions(t *testing.T) {
asserts := assert.New(t)
user := User{}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
asserts.NoError(user.UpdateOptions())
asserts.NoError(mock.ExpectationsWereMet())
}

View file

@ -46,3 +46,8 @@ func DeleteWebDAVAccountByID(id, uid uint) {
func UpdateWebDAVAccountByID(id, uid uint, updates map[string]interface{}) {
DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).Updates(updates)
}
// UpdateWebDAVAccountReadonlyByID 根据账户ID和UID更新账户的只读性
func UpdateWebDAVAccountReadonlyByID(id, uid uint, readonly bool) {
DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).UpdateColumn("readonly", readonly)
}

View file

@ -1,60 +0,0 @@
package model
import (
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestWebdav_Create(t *testing.T) {
asserts := assert.New(t)
// 成功
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task := Webdav{}
id, err := task.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.EqualValues(1, id)
}
// 失败
{
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
task := Webdav{}
id, err := task.Create()
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualValues(0, id)
}
}
func TestGetWebdavByPassword(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
_, err := GetWebdavByPassword("e", 1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
func TestListWebDAVAccounts(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
res := ListWebDAVAccounts(1)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Len(res, 0)
}
func TestDeleteWebDAVAccountByID(t *testing.T) {
asserts := assert.New(t)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
DeleteWebDAVAccountByID(1, 1)
asserts.NoError(mock.ExpectationsWereMet())
}

View file

@ -1,66 +0,0 @@
package aria2
import (
"database/sql"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestInit(t *testing.T) {
a := assert.New(t)
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
mockQueue := mq.NewMQ()
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
Init(false, mockPool, mockQueue)
a.NoError(mock.ExpectationsWereMet())
mockPool.AssertExpectations(t)
}
func TestTestRPCConnection(t *testing.T) {
a := assert.New(t)
// url not legal
{
res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
a.Error(err)
a.Empty(res.Version)
}
// rpc failed
{
res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
a.Error(err)
a.Empty(res.Version)
}
}
func TestGetLoadBalancer(t *testing.T) {
a := assert.New(t)
a.NotPanics(func() {
GetLoadBalancer()
})
}

View file

@ -1,54 +0,0 @@
package common
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/stretchr/testify/assert"
)
func TestDummyAria2(t *testing.T) {
a := assert.New(t)
d := &DummyAria2{}
a.NoError(d.Init())
res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
a.Empty(res)
a.Error(err)
_, err = d.Status(&model.Download{})
a.Error(err)
err = d.Cancel(&model.Download{})
a.Error(err)
err = d.Select(&model.Download{}, []int{})
a.Error(err)
configRes := d.GetConfig()
a.NotNil(configRes)
err = d.DeleteTempFile(&model.Download{})
a.Error(err)
}
func TestGetStatus(t *testing.T) {
a := assert.New(t)
a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "single"},
TotalLength: "100", CompletedLength: "50"}), Downloading)
a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
BitTorrent: rpc.BitTorrentInfo{Mode: "multi"},
TotalLength: "100", CompletedLength: "100"}), Seeding)
a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready)
a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused)
a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error)
a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled)
a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown)
}

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"strconv"
"time"
@ -52,6 +53,7 @@ func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
// Loop 开启监控循环
func (monitor *Monitor) Loop(mqClient mq.MQ) {
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
fmt.Println(cluster.Default)
// 首次循环立即更新
interval := 50 * time.Millisecond
@ -190,6 +192,10 @@ func (monitor *Monitor) ValidateFile() error {
}
defer fs.Recycle()
if err := fs.SetPolicyFromPath(monitor.Task.Dst); err != nil {
return fmt.Errorf("failed to switch policy to target dir: %w", err)
}
// 创建上下文环境
file := &fsctx.FileStream{
Size: monitor.Task.TotalSize,

View file

@ -1,447 +0,0 @@
package monitor
import (
"database/sql"
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestNewMonitor(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
// node not available
{
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", uint(1)).Return(nil)
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
task := &model.Download{
Model: gorm.Model{ID: 1},
}
NewMonitor(task, mockPool, mockMQ)
mockPool.AssertExpectations(t)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(task.Error)
}
// success
{
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
mockPool := &mocks.NodePoolMock{}
mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
task := &model.Download{
Model: gorm.Model{ID: 1},
}
NewMonitor(task, mockPool, mockMQ)
mockNode.AssertExpectations(t)
mockPool.AssertExpectations(t)
}
}
func TestMonitor_Loop(t *testing.T) {
a := assert.New(t)
mockMQ := mq.NewMQ()
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
retried: MAX_RETRY,
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
notifier: mockMQ.Subscribe("test", 1),
}
// into interval loop
{
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
m.Loop(mockMQ)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
// into notifier loop
{
m.Task.Error = ""
mockMQ.Publish("test", mq.Message{})
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
m.Loop(mockMQ)
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
}
func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
a := assert.New(t)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
for i := 0; i < MAX_RETRY; i++ {
a.False(m.Update())
}
mockNode.AssertExpectations(t)
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateMagentoFollow(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"next"},
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
a.Equal("next", m.Task.GID)
mockAria2.AssertExpectations(t)
}
func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateCompleted(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "complete",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
mockNode.On("ID").Return(uint(1))
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateError(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "error",
ErrorMessage: "error",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
a.NotEmpty(m.Task.Error)
}
func TestMonitor_UpdateActive(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "active",
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.False(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}
func TestMonitor_UpdateRemoved(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "removed",
}, nil)
mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.Equal(common.Canceled, m.Task.Status)
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}
func TestMonitor_UpdateUnknown(t *testing.T) {
a := assert.New(t)
mockAria2 := &mocks.Aria2Mock{}
mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
Status: "unknown",
}, nil)
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(mockAria2)
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
a.True(m.Update())
a.NoError(mock.ExpectationsWereMet())
mockAria2.AssertExpectations(t)
mockNode.AssertExpectations(t)
}
func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) {
a := assert.New(t)
status := rpc.StatusInfo{
Status: "completed",
TotalLength: "100",
CompletedLength: "50",
DownloadSpeed: "20",
}
mockNode := &mocks.NodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
m := &Monitor{
node: mockNode,
Task: &model.Download{Model: gorm.Model{ID: 1}},
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := m.UpdateTaskInfo(status)
a.Error(err)
a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t)
}
func TestMonitor_ValidateFile(t *testing.T) {
a := assert.New(t)
m := &Monitor{
Task: &model.Download{
Model: gorm.Model{ID: 1},
TotalSize: 100,
},
}
// failed to create filesystem
{
m.Task.User = &model.User{
Policy: model.Policy{
Type: "random",
},
}
a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile())
}
// User capacity not enough
{
m.Task.User = &model.User{
Group: model.Group{
MaxStorage: 99,
},
Policy: model.Policy{
Type: "local",
},
}
a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile())
}
// single file too big
{
m.Task.StatusInfo.Files = []rpc.FileInfo{
{
Length: "100",
Selected: "true",
},
}
m.Task.User = &model.User{
Group: model.Group{
MaxStorage: 100,
},
Policy: model.Policy{
Type: "local",
MaxSize: 99,
},
}
a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile())
}
// all pass
{
m.Task.StatusInfo.Files = []rpc.FileInfo{
{
Length: "100",
Selected: "true",
},
}
m.Task.User = &model.User{
Group: model.Group{
MaxStorage: 100,
},
Policy: model.Policy{
Type: "local",
MaxSize: 100,
},
}
a.NoError(m.ValidateFile())
}
}
func TestMonitor_Complete(t *testing.T) {
a := assert.New(t)
mockNode := &mocks.NodeMock{}
mockNode.On("ID").Return(uint(1))
mockPool := &mocks.TaskPoolMock{}
mockPool.On("Submit", testMock.Anything)
m := &Monitor{
node: mockNode,
Task: &model.Download{
Model: gorm.Model{ID: 1},
TotalSize: 100,
UserID: 9414,
},
}
m.Task.StatusInfo.Files = []rpc.FileInfo{
{
Length: "100",
Selected: "true",
},
}
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4))
mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
a.False(m.Complete(mockPool))
m.Task.StatusInfo.Status = "complete"
a.True(m.Complete(mockPool))
a.NoError(mock.ExpectationsWereMet())
mockNode.AssertExpectations(t)
mockPool.AssertExpectations(t)
}

View file

@ -1,136 +0,0 @@
package auth
import (
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/stretchr/testify/assert"
)
func TestSignURI(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 成功
{
sign, err := SignURI(General, "/api/v3/something?id=1", 0)
asserts.NoError(err)
queries := sign.Query()
asserts.Equal("1", queries.Get("id"))
asserts.NotEmpty(queries.Get("sign"))
}
// URI解码失败
{
sign, err := SignURI(General, "://dg.;'f]gh./'", 0)
asserts.Error(err)
asserts.Nil(sign)
}
}
func TestCheckURI(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 成功
{
sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10)
asserts.NoError(err)
asserts.NoError(CheckURI(General, sign))
}
// 过期
{
sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1)
asserts.NoError(err)
asserts.Error(CheckURI(General, sign))
}
}
func TestSignRequest(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 非上传请求
{
req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body."))
asserts.NoError(err)
req = SignRequest(General, req, 0)
asserts.NotEmpty(req.Header["Authorization"])
}
// 上传请求
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/slave/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
req = SignRequest(General, req, 10)
asserts.NotEmpty(req.Header["Authorization"])
}
}
func TestCheckRequest(t *testing.T) {
asserts := assert.New(t)
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
// 缺少请求头
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
err = CheckRequest(General, req)
asserts.Error(err)
asserts.Equal(ErrAuthHeaderMissing, err)
}
// 非上传请求 验证成功
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
req = SignRequest(General, req, 0)
err = CheckRequest(General, req)
asserts.NoError(err)
}
// 上传请求 验证成功
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
req.Header["X-Cr-Policy"] = []string{"I am Policy"}
req = SignRequest(General, req, 0)
err = CheckRequest(General, req)
asserts.NoError(err)
}
// 非上传请求 失败
{
req, err := http.NewRequest(
"POST",
"http://127.0.0.1/api/v3/upload",
strings.NewReader("I am body."),
)
asserts.NoError(err)
req = SignRequest(General, req, 0)
req.Body = ioutil.NopCloser(strings.NewReader("2333"))
err = CheckRequest(General, req)
asserts.Error(err)
}
}

View file

@ -1,94 +0,0 @@
package auth
import (
"database/sql"
"fmt"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
var mock sqlmock.Sqlmock
func TestMain(m *testing.M) {
// 设置gin为测试模式
gin.SetMode(gin.TestMode)
// 初始化sqlmock
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
mockDB, _ := gorm.Open("mysql", db)
model.DB = mockDB
defer db.Close()
m.Run()
}
func TestHMACAuth_Sign(t *testing.T) {
asserts := assert.New(t)
auth := HMACAuth{
SecretKey: []byte(util.RandStringRunes(256)),
}
asserts.NotEmpty(auth.Sign("content", 0))
}
func TestHMACAuth_Check(t *testing.T) {
asserts := assert.New(t)
auth := HMACAuth{
SecretKey: []byte(util.RandStringRunes(256)),
}
// 正常,永不过期
{
sign := auth.Sign("content", 0)
asserts.NoError(auth.Check("content", sign))
}
// 过期
{
sign := auth.Sign("content", 1)
asserts.Error(auth.Check("content", sign))
}
// 签名格式错误
{
sign := auth.Sign("content", 1)
asserts.Error(auth.Check("content", sign+":"))
}
// 过期日期格式错误
{
asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed"))
}
// 签名有误
{
asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10)))
}
}
func TestInit(t *testing.T) {
asserts := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
Init()
asserts.NoError(mock.ExpectationsWereMet())
// slave模式
conf.SystemConfig.Mode = "slave"
asserts.Panics(func() {
Init()
})
}

View file

@ -1,17 +0,0 @@
package authn
import (
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/stretchr/testify/assert"
)
func TestInit(t *testing.T) {
asserts := assert.New(t)
cache.Set("setting_siteURL", "http://cloudreve.org", 0)
cache.Set("setting_siteName", "Cloudreve", 0)
res, err := NewAuthnInstance()
asserts.NotNil(res)
asserts.NoError(err)
}

View file

@ -1,12 +0,0 @@
package balancer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewBalancer(t *testing.T) {
a := assert.New(t)
a.NotNil(NewBalancer(""))
a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
}

View file

@ -1,42 +0,0 @@
package balancer
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestRoundRobin_NextIndex(t *testing.T) {
a := assert.New(t)
r := &RoundRobin{}
total := 5
for i := 1; i < total; i++ {
a.Equal(i, r.NextIndex(total))
}
for i := 0; i < total; i++ {
a.Equal(i, r.NextIndex(total))
}
}
func TestRoundRobin_NextPeer(t *testing.T) {
a := assert.New(t)
r := &RoundRobin{}
// not slice
{
err, _ := r.NextPeer("s")
a.Equal(ErrInputNotSlice, err)
}
// no nodes
{
err, _ := r.NextPeer([]string{})
a.Equal(ErrNoAvaliableNode, err)
}
// pass
{
err, res := r.NextPeer([]string{"a"})
a.NoError(err)
a.Equal("a", res.(string))
}
}

1
pkg/cache/driver.go vendored
View file

@ -2,6 +2,7 @@ package cache
import (
"encoding/gob"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"

View file

@ -1,69 +0,0 @@
package cache
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestSet(t *testing.T) {
asserts := assert.New(t)
asserts.NoError(Set("123", "321", -1))
}
func TestGet(t *testing.T) {
asserts := assert.New(t)
asserts.NoError(Set("123", "321", -1))
value, ok := Get("123")
asserts.True(ok)
asserts.Equal("321", value)
value, ok = Get("not_exist")
asserts.False(ok)
}
func TestDeletes(t *testing.T) {
asserts := assert.New(t)
asserts.NoError(Set("123", "321", -1))
err := Deletes([]string{"123"}, "")
asserts.NoError(err)
_, exist := Get("123")
asserts.False(exist)
}
func TestGetSettings(t *testing.T) {
asserts := assert.New(t)
asserts.NoError(Set("test_1", "1", -1))
values, missed := GetSettings([]string{"1", "2"}, "test_")
asserts.Equal(map[string]string{"1": "1"}, values)
asserts.Equal([]string{"2"}, missed)
}
func TestSetSettings(t *testing.T) {
asserts := assert.New(t)
err := SetSettings(map[string]string{"3": "3", "4": "4"}, "test_")
asserts.NoError(err)
value1, _ := Get("test_3")
value2, _ := Get("test_4")
asserts.Equal("3", value1)
asserts.Equal("4", value2)
}
func TestInit(t *testing.T) {
asserts := assert.New(t)
asserts.NotPanics(func() {
Init()
})
}
func TestInitSlaveOverwrites(t *testing.T) {
asserts := assert.New(t)
asserts.NotPanics(func() {
InitSlaveOverwrites()
})
}

8
pkg/cache/memo.go vendored
View file

@ -133,7 +133,13 @@ func (store *MemoStore) Persist(path string) error {
return fmt.Errorf("failed to serialize cache: %s", err)
}
err = os.WriteFile(path, res, 0644)
// err = os.WriteFile(path, res, 0644)
file, err := util.CreatNestedFile(path)
if err == nil {
_, err = file.Write(res)
file.Chmod(0644)
file.Close()
}
return err
}

191
pkg/cache/memo_test.go vendored
View file

@ -1,191 +0,0 @@
package cache
import (
"github.com/stretchr/testify/assert"
"path/filepath"
"testing"
"time"
)
func TestNewMemoStore(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
asserts.NotNil(store)
asserts.NotNil(store.Store)
}
func TestMemoStore_Set(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
err := store.Set("KEY", "vAL", -1)
asserts.NoError(err)
val, ok := store.Store.Load("KEY")
asserts.True(ok)
asserts.Equal("vAL", val.(itemWithTTL).Value)
}
func TestMemoStore_Get(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
// 正常情况
{
_ = store.Set("string", "string_val", -1)
val, ok := store.Get("string")
asserts.Equal("string_val", val)
asserts.True(ok)
}
// Key不存在
{
val, ok := store.Get("something")
asserts.Equal(nil, val)
asserts.False(ok)
}
// 存储struct
{
type testStruct struct {
key int
}
test := testStruct{key: 233}
_ = store.Set("struct", test, -1)
val, ok := store.Get("struct")
asserts.True(ok)
res, ok := val.(testStruct)
asserts.True(ok)
asserts.Equal(test, res)
}
// 过期
{
_ = store.Set("string", "string_val", 1)
time.Sleep(time.Duration(2) * time.Second)
val, ok := store.Get("string")
asserts.Nil(val)
asserts.False(ok)
}
}
func TestMemoStore_Gets(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
err := store.Set("1", "1,val", -1)
err = store.Set("2", "2,val", -1)
err = store.Set("3", "3,val", -1)
err = store.Set("4", "4,val", -1)
asserts.NoError(err)
// 全部命中
{
values, miss := store.Gets([]string{"1", "2", "3", "4"}, "")
asserts.Len(values, 4)
asserts.Len(miss, 0)
}
// 命中一半
{
values, miss := store.Gets([]string{"1", "2", "9", "10"}, "")
asserts.Len(values, 2)
asserts.Equal([]string{"9", "10"}, miss)
}
}
func TestMemoStore_Sets(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
err := store.Sets(map[string]interface{}{
"1": "1.val",
"2": "2.val",
"3": "3.val",
"4": "4.val",
}, "test_")
asserts.NoError(err)
vals, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_")
asserts.Len(miss, 0)
asserts.Equal(map[string]interface{}{
"1": "1.val",
"2": "2.val",
"3": "3.val",
"4": "4.val",
}, vals)
}
func TestMemoStore_Delete(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
err := store.Sets(map[string]interface{}{
"1": "1.val",
"2": "2.val",
"3": "3.val",
"4": "4.val",
}, "test_")
asserts.NoError(err)
err = store.Delete([]string{"1", "2"}, "test_")
asserts.NoError(err)
values, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_")
asserts.Equal([]string{"1", "2"}, miss)
asserts.Equal(map[string]interface{}{"3": "3.val", "4": "4.val"}, values)
}
func TestMemoStore_GarbageCollect(t *testing.T) {
asserts := assert.New(t)
store := NewMemoStore()
store.Set("test", 1, 1)
time.Sleep(time.Duration(2000) * time.Millisecond)
store.GarbageCollect()
_, ok := store.Get("test")
asserts.False(ok)
}
func TestMemoStore_PersistFailed(t *testing.T) {
a := assert.New(t)
store := NewMemoStore()
type testStruct struct{ v string }
store.Set("test", 1, 0)
store.Set("test2", testStruct{v: "test"}, 0)
err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed"))
a.Error(err)
}
func TestMemoStore_PersistAndRestore(t *testing.T) {
a := assert.New(t)
store := NewMemoStore()
store.Set("test", 1, 0)
// already expired
store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1})
// expired after persist
store.Set("test3", 1, 1)
temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")
// Persist
err := store.Persist(temp)
a.NoError(err)
a.FileExists(temp)
time.Sleep(2 * time.Second)
// Restore
store2 := NewMemoStore()
err = store2.Restore(temp)
a.NoError(err)
test, testOk := store2.Get("test")
a.EqualValues(1, test)
a.True(testOk)
test2, test2Ok := store2.Get("test2")
a.Nil(test2)
a.False(test2Ok)
test3, test3Ok := store2.Get("test3")
a.Nil(test3)
a.False(test3Ok)
a.NoFileExists(temp)
}

View file

@ -1,324 +0,0 @@
package cache
import (
"errors"
"fmt"
"github.com/gomodule/redigo/redis"
"github.com/rafaeljusto/redigomock"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestNewRedisStore(t *testing.T) {
asserts := assert.New(t)
store := NewRedisStore(10, "tcp", "", "", "", "0")
asserts.NotNil(store)
asserts.Panics(func() {
store.pool.Dial()
})
testConn := redigomock.NewConn()
cmd := testConn.Command("PING").Expect("PONG")
err := store.pool.TestOnBorrow(testConn, time.Now())
if testConn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.NoError(err)
}
func TestRedisStore_Set(t *testing.T) {
asserts := assert.New(t)
conn := redigomock.NewConn()
pool := &redis.Pool{
Dial: func() (redis.Conn, error) { return conn, nil },
MaxIdle: 10,
}
store := &RedisStore{pool: pool}
// 正常情况
{
cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectStringSlice("OK")
err := store.Set("test", "test val", -1)
asserts.NoError(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 带有TTL
// 正常情况
{
cmd := conn.Command("SETEX", "test", 10, redigomock.NewAnyData()).ExpectStringSlice("OK")
err := store.Set("test", "test val", 10)
asserts.NoError(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 序列化出错
{
value := struct {
Key string
}{
Key: "123",
}
err := store.Set("test", value, -1)
asserts.Error(err)
}
// 命令执行失败
{
conn.Clear()
cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectError(errors.New("error"))
err := store.Set("test", "test val", -1)
asserts.Error(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 获取连接失败
{
store.pool = &redis.Pool{
Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
MaxIdle: 10,
}
err := store.Set("test", "123", -1)
asserts.Error(err)
}
}
func TestRedisStore_Get(t *testing.T) {
asserts := assert.New(t)
conn := redigomock.NewConn()
pool := &redis.Pool{
Dial: func() (redis.Conn, error) { return conn, nil },
MaxIdle: 10,
}
store := &RedisStore{pool: pool}
// 正常情况
{
expectVal, _ := serializer("test val")
cmd := conn.Command("GET", "test").Expect(expectVal)
val, ok := store.Get("test")
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.True(ok)
asserts.Equal("test val", val.(string))
}
// Key不存在
{
conn.Clear()
cmd := conn.Command("GET", "test").Expect(nil)
val, ok := store.Get("test")
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.False(ok)
asserts.Nil(val)
}
// 解码错误
{
conn.Clear()
cmd := conn.Command("GET", "test").Expect([]byte{0x20})
val, ok := store.Get("test")
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.False(ok)
asserts.Nil(val)
}
// 获取连接失败
{
store.pool = &redis.Pool{
Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
MaxIdle: 10,
}
val, ok := store.Get("test")
asserts.False(ok)
asserts.Nil(val)
}
}
func TestRedisStore_Gets(t *testing.T) {
asserts := assert.New(t)
conn := redigomock.NewConn()
pool := &redis.Pool{
Dial: func() (redis.Conn, error) { return conn, nil },
MaxIdle: 10,
}
store := &RedisStore{pool: pool}
// 全部命中
{
conn.Clear()
value1, _ := serializer("1")
value2, _ := serializer("2")
cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice(
value1, value2)
res, missed := store.Gets([]string{"1", "2"}, "test_")
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.Len(missed, 0)
asserts.Len(res, 2)
asserts.Equal("1", res["1"].(string))
asserts.Equal("2", res["2"].(string))
}
// 命中一个
{
conn.Clear()
value2, _ := serializer("2")
cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice(
nil, value2)
res, missed := store.Gets([]string{"1", "2"}, "test_")
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.Len(missed, 1)
asserts.Len(res, 1)
asserts.Equal("1", missed[0])
asserts.Equal("2", res["2"].(string))
}
// 命令出错
{
conn.Clear()
cmd := conn.Command("MGET", "test_1", "test_2").ExpectError(errors.New("error"))
res, missed := store.Gets([]string{"1", "2"}, "test_")
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
asserts.Len(missed, 2)
asserts.Len(res, 0)
}
// 连接出错
{
conn.Clear()
store.pool = &redis.Pool{
Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
MaxIdle: 10,
}
res, missed := store.Gets([]string{"1", "2"}, "test_")
asserts.Len(missed, 2)
asserts.Len(res, 0)
}
}
func TestRedisStore_Sets(t *testing.T) {
asserts := assert.New(t)
conn := redigomock.NewConn()
pool := &redis.Pool{
Dial: func() (redis.Conn, error) { return conn, nil },
MaxIdle: 10,
}
store := &RedisStore{pool: pool}
// 正常
{
cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK")
err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_")
asserts.NoError(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 序列化失败
{
conn.Clear()
value := struct {
Key string
}{
Key: "123",
}
err := store.Sets(map[string]interface{}{"1": value, "2": "2"}, "test_")
asserts.Error(err)
}
// 执行失败
{
cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error"))
err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_")
asserts.Error(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 连接失败
{
conn.Clear()
store.pool = &redis.Pool{
Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
MaxIdle: 10,
}
err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_")
asserts.Error(err)
}
}
func TestRedisStore_Delete(t *testing.T) {
asserts := assert.New(t)
conn := redigomock.NewConn()
pool := &redis.Pool{
Dial: func() (redis.Conn, error) { return conn, nil },
MaxIdle: 10,
}
store := &RedisStore{pool: pool}
// 正常
{
cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK")
err := store.Delete([]string{"1", "2", "3", "4"}, "test_")
asserts.NoError(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 命令执行失败
{
conn.Clear()
cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error"))
err := store.Delete([]string{"1", "2", "3", "4"}, "test_")
asserts.Error(err)
if conn.Stats(cmd) != 1 {
fmt.Println("Command was not used")
return
}
}
// 连接失败
{
conn.Clear()
store.pool = &redis.Pool{
Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
MaxIdle: 10,
}
err := store.Delete([]string{"1", "2", "3", "4"}, "test_")
asserts.Error(err)
}
}

View file

@ -4,6 +4,9 @@ import (
"bytes"
"encoding/gob"
"fmt"
"net/url"
"sync"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
@ -12,8 +15,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
"net/url"
"sync"
)
var DefaultController Controller

View file

@ -1,385 +0,0 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
)
func TestInitController(t *testing.T) {
assert.NotPanics(t, func() {
InitController()
})
}
func TestSlaveController_HandleHeartBeat(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: make(map[string]MasterInfo),
}
// first heart beat
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
Node: &model.Node{},
})
a.NoError(err)
_, err = c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "2",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
}
// second heart beat, no fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Empty(c.masters["1"].URL)
}
// second heart beat, fresh
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: "http://127.0.0.1",
Node: &model.Node{},
})
a.NoError(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
// second heart beat, fresh, url illegal
{
_, err := c.HandleHeartBeat(&serializer.NodePingReq{
SiteID: "1",
IsUpdate: true,
SiteURL: string([]byte{0x7f}),
Node: &model.Node{},
})
a.Error(err)
a.Len(c.masters, 2)
a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
}
}
type nodeMock struct {
testMock.Mock
}
func (n nodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n nodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n nodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n nodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n nodeMock) Kill() {
n.Called()
}
func (n nodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n nodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n nodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
func TestSlaveController_GetAria2Instance(t *testing.T) {
a := assert.New(t)
mockNode := &nodeMock{}
mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Instance: mockNode},
},
}
// node node found
{
res, err := c.GetAria2Instance("2")
a.Nil(res)
a.Equal(ErrMasterNotFound, err)
}
// node found
{
res, err := c.GetAria2Instance("1")
a.NotNil(res)
a.NoError(err)
mockNode.AssertExpectations(t)
}
}
type requestMock struct {
testMock.Mock
}
func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
return r.Called(method, target, body, opts).Get(0).(*request.Response)
}
func TestSlaveController_SendNotification(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
}
// gob encode error
{
type randomType struct{}
a.Error(c.SendNotification("1", "", mq.Message{
Content: randomType{},
}))
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Error(c.SendNotification("1", "s1", mq.Message{}))
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
a.NoError(c.SendNotification("1", "s3", mq.Message{}))
mockRequest.AssertExpectations(t)
}
}
func TestSlaveController_SubmitTask(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {
jobTracker: map[string]bool{},
},
},
}
// node not exit
{
a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
}
// success
{
submitted := false
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
submitted = true
}))
a.True(submitted)
}
// job already submitted
{
submitted := false
a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
submitted = true
}))
a.False(submitted)
}
}
func TestSlaveController_GetMasterInfo(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
res, err := c.GetMasterInfo("2")
a.Equal(ErrMasterNotFound, err)
a.Nil(res)
}
// success
{
res, err := c.GetMasterInfo("1")
a.NoError(err)
a.NotNil(res)
}
}
func TestSlaveController_GetOneDriveToken(t *testing.T) {
a := assert.New(t)
c := &slaveController{
masters: map[string]MasterInfo{
"1": {},
},
}
// node not exit
{
res, err := c.GetPolicyOauthToken("2", 1)
a.Equal(ErrMasterNotFound, err)
a.Empty(res)
}
// return none 200
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{StatusCode: http.StatusConflict},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetPolicyOauthToken("1", 1)
a.Error(err)
a.Empty(res)
mockRequest.AssertExpectations(t)
}
// master return error
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetPolicyOauthToken("1", 1)
a.Equal(1, err.(serializer.AppError).Code)
a.Empty(res)
mockRequest.AssertExpectations(t)
}
// success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
},
})
c := &slaveController{
masters: map[string]MasterInfo{
"1": {Client: mockRequest},
},
}
res, err := c.GetPolicyOauthToken("1", 1)
a.NoError(err)
a.Equal("expected", res)
mockRequest.AssertExpectations(t)
}
}

View file

@ -1,186 +0,0 @@
package cluster
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/stretchr/testify/assert"
"os"
"testing"
"time"
)
func TestMasterNode_Init(t *testing.T) {
a := assert.New(t)
m := &MasterNode{}
m.Init(&model.Node{Status: model.NodeSuspend})
a.Equal(model.NodeSuspend, m.DBModel().Status)
m.Init(&model.Node{Aria2Enabled: true})
}
func TestMasterNode_DummyMethods(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
m.Model.ID = 5
a.Equal(m.Model.ID, m.ID())
res, err := m.Ping(&serializer.NodePingReq{})
a.NoError(err)
a.NotNil(res)
a.True(m.IsActive())
a.True(m.IsMater())
m.SubscribeStatusChange(func(isActive bool, id uint) {})
}
func TestMasterNode_IsFeatureEnabled(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
a.False(m.IsFeatureEnabled("aria2"))
a.False(m.IsFeatureEnabled("random"))
m.Model.Aria2Enabled = true
a.True(m.IsFeatureEnabled("aria2"))
}
func TestMasterNode_AuthInstance(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
}
a.NotNil(m.MasterAuthInstance())
a.NotNil(m.SlaveAuthInstance())
}
func TestMasterNode_Kill(t *testing.T) {
m := &MasterNode{
Model: &model.Node{},
}
m.Kill()
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
m.Kill()
}
func TestMasterNode_GetAria2Instance(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{},
aria2RPC: rpcService{},
}
m.aria2RPC.parent = m
a.NotNil(m.GetAria2Instance())
m.Model.Aria2Enabled = true
a.NotNil(m.GetAria2Instance())
m.aria2RPC.Initialized = true
a.NotNil(m.GetAria2Instance())
}
func TestRpcService_Init(t *testing.T) {
a := assert.New(t)
m := &MasterNode{
Model: &model.Node{
Aria2OptionsSerialized: model.Aria2Option{
Options: "{",
},
},
aria2RPC: rpcService{},
}
m.aria2RPC.parent = m
// failed to decode address
{
m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
a.Error(m.aria2RPC.Init())
}
// failed to decode options
{
m.Model.Aria2OptionsSerialized.Server = ""
a.Error(m.aria2RPC.Init())
}
// failed to initialized
{
m.Model.Aria2OptionsSerialized.Server = ""
m.Model.Aria2OptionsSerialized.Options = "{}"
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
a.Error(m.aria2RPC.Init())
a.False(m.aria2RPC.Initialized)
}
}
func getTestRPCNode() *MasterNode {
m := &MasterNode{
Model: &model.Node{
Aria2OptionsSerialized: model.Aria2Option{},
},
aria2RPC: rpcService{
options: &clientOptions{
Options: map[string]interface{}{"1": "1"},
},
},
}
m.aria2RPC.parent = m
caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
m.aria2RPC.Caller = caller
return m
}
func TestRpcService_CreateTask(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
a.Error(err)
a.Empty(res)
}
func TestRpcService_Status(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
res, err := m.aria2RPC.Status(&model.Download{})
a.Error(err)
a.Empty(res)
}
func TestRpcService_Cancel(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
a.Error(m.aria2RPC.Cancel(&model.Download{}))
}
func TestRpcService_Select(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
a.NotNil(m.aria2RPC.GetConfig())
a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
}
func TestRpcService_DeleteTempFile(t *testing.T) {
a := assert.New(t)
m := getTestRPCNode()
fdName := "TestRpcService_DeleteTempFile"
a.NoError(os.Mkdir(fdName, 0644))
a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
time.Sleep(500 * time.Millisecond)
a.False(util.Exists(fdName))
}

View file

@ -1,17 +0,0 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
"testing"
)
func TestNewNodeFromDBModel(t *testing.T) {
a := assert.New(t)
a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
Type: model.SlaveNodeType,
}))
a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
Type: model.MasterNodeType,
}))
}

View file

@ -4,6 +4,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/samber/lo"
"sync"
)
@ -15,7 +16,7 @@ var featureGroup = []string{"aria2"}
// Pool 节点池
type Pool interface {
// Returns active node selected by given feature and load balancer
BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
BalanceNodeByFeature(feature string, lb balancer.Balancer, available []uint, prefer uint) (error, Node)
// Returns node by ID
GetNodeByID(id uint) Node
@ -174,11 +175,33 @@ func (pool *NodePool) Delete(id uint) {
}
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer,
available []uint, prefer uint) (error, Node) {
pool.lock.RLock()
defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok {
err, res := lb.NextPeer(nodes)
// Find nodes that are allowed to be used in user group
availableNodes := nodes
if len(available) > 0 {
idHash := make(map[uint]struct{}, len(available))
for _, id := range available {
idHash[id] = struct{}{}
}
availableNodes = lo.Filter[Node](nodes, func(node Node, index int) bool {
_, exist := idHash[node.ID()]
return exist
})
}
// Return preferred node if exists
if preferredNode, found := lo.Find[Node](availableNodes, func(node Node) bool {
return node.ID() == prefer
}); found {
return nil, preferredNode
}
err, res := lb.NextPeer(availableNodes)
if err == nil {
return nil, res.(Node)
}

View file

@ -1,161 +0,0 @@
package cluster
import (
"database/sql"
"errors"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"testing"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestInitFailed(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
Init()
a.NoError(mock.ExpectationsWereMet())
}
func TestInitSuccess(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType))
Init()
a.NoError(mock.ExpectationsWereMet())
}
func TestNodePool_GetNodeByID(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
mockNode := &nodeMock{}
// inactive
{
p.inactive[1] = mockNode
a.Equal(mockNode, p.GetNodeByID(1))
}
// active
{
delete(p.inactive, 1)
p.active[1] = mockNode
a.Equal(mockNode, p.GetNodeByID(1))
}
}
func TestNodePool_NodeStatusChange(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
n := &MasterNode{Model: &model.Node{}}
p.Init()
p.inactive[1] = n
p.nodeStatusChange(true, 1)
a.Len(p.inactive, 0)
a.Equal(n, p.active[1])
p.nodeStatusChange(false, 1)
a.Len(p.active, 0)
a.Equal(n, p.inactive[1])
p.nodeStatusChange(false, 1)
a.Len(p.active, 0)
a.Equal(n, p.inactive[1])
}
func TestNodePool_Add(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// new node
{
p.Add(&model.Node{})
a.Len(p.active, 1)
}
// old node
{
p.inactive[0] = p.active[0]
delete(p.active, 0)
p.Add(&model.Node{})
a.Len(p.active, 0)
a.Len(p.inactive, 1)
}
}
func TestNodePool_Delete(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// active
{
mockNode := &nodeMock{}
mockNode.On("Kill")
p.active[0] = mockNode
p.Delete(0)
a.Len(p.active, 0)
a.Len(p.inactive, 0)
mockNode.AssertExpectations(t)
}
p.Init()
// inactive
{
mockNode := &nodeMock{}
mockNode.On("Kill")
p.inactive[0] = mockNode
p.Delete(0)
a.Len(p.active, 0)
a.Len(p.inactive, 0)
mockNode.AssertExpectations(t)
}
}
func TestNodePool_BalanceNodeByFeature(t *testing.T) {
a := assert.New(t)
p := &NodePool{}
p.Init()
// success
{
p.featureMap["test"] = []Node{&MasterNode{}}
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
a.NoError(err)
a.Equal(p.featureMap["test"][0], res)
}
// NoNodes
{
p.featureMap["test"] = []Node{}
err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
a.Error(err)
a.Nil(res)
}
// No match feature
{
err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin"))
a.Error(err)
a.Nil(res)
}
}

View file

@ -1,559 +0,0 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
)
func TestSlaveNode_InitAndKill(t *testing.T) {
a := assert.New(t)
n := &SlaveNode{
callback: func(b bool, u uint) {
},
}
a.NotPanics(func() {
n.Init(&model.Node{})
time.Sleep(time.Millisecond * 500)
n.Init(&model.Node{})
n.Kill()
})
}
func TestSlaveNode_DummyMethods(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
m.Model.ID = 5
a.Equal(m.Model.ID, m.ID())
a.Equal(m.Model.ID, m.DBModel().ID)
a.False(m.IsActive())
a.False(m.IsMater())
m.SubscribeStatusChange(func(isActive bool, id uint) {})
}
func TestSlaveNode_IsFeatureEnabled(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.False(m.IsFeatureEnabled("aria2"))
a.False(m.IsFeatureEnabled("random"))
m.Model.Aria2Enabled = true
a.True(m.IsFeatureEnabled("aria2"))
}
func TestSlaveNode_Ping(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
// master return error code
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.Error(err)
a.Nil(res)
a.Equal(1, err.(serializer.AppError).Code)
}
// return unexpected json
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.Error(err)
a.Nil(res)
}
// return success
{
mockRequest := &requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.Ping(&serializer.NodePingReq{})
a.NoError(err)
a.NotNil(res)
}
}
func TestSlaveNode_GetAria2Instance(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.NotNil(m.GetAria2Instance())
m.Model.Aria2Enabled = true
a.NotNil(m.GetAria2Instance())
a.NotNil(m.GetAria2Instance())
}
func TestSlaveNode_StartPingLoop(t *testing.T) {
callbackCount := 0
finishedChan := make(chan struct{})
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m := &SlaveNode{
Active: true,
Model: &model.Node{},
callback: func(b bool, u uint) {
callbackCount++
if callbackCount == 2 {
close(finishedChan)
}
if callbackCount == 1 {
mockRequest.AssertExpectations(t)
mockRequest = requestMock{}
mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
},
})
}
},
}
cache.Set("setting_slave_ping_interval", "0", 0)
cache.Set("setting_slave_recover_interval", "0", 0)
cache.Set("setting_slave_node_retry", "1", 0)
m.caller.Client = &mockRequest
go func() {
select {
case <-finishedChan:
m.Kill()
}
}()
m.StartPingLoop()
mockRequest.AssertExpectations(t)
}
func TestSlaveNode_AuthInstance(t *testing.T) {
a := assert.New(t)
m := &SlaveNode{
Model: &model.Node{},
}
a.NotNil(m.MasterAuthInstance())
a.NotNil(m.SlaveAuthInstance())
}
func TestSlaveNode_ChangeStatus(t *testing.T) {
a := assert.New(t)
isActive := false
m := &SlaveNode{
Model: &model.Node{},
callback: func(b bool, u uint) {
isActive = b
},
}
a.NotPanics(func() {
m.changeStatus(false)
})
m.changeStatus(true)
a.True(isActive)
}
func getTestRPCNodeSlave() *SlaveNode {
m := &SlaveNode{
Model: &model.Node{},
}
m.caller.parent = m
return m
}
func TestSlaveCaller_CreateTask(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Empty(res)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Empty(res)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.CreateTask(&model.Download{}, nil)
a.Equal("res", res)
a.NoError(err)
}
}
func TestSlaveCaller_Status(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")),
},
})
m.caller.Client = mockRequest
res, err := m.caller.Status(&model.Download{})
a.Empty(res.Status)
a.NoError(err)
}
}
func TestSlaveCaller_Cancel(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Cancel(&model.Download{})
a.NoError(err)
}
}
func TestSlaveCaller_Select(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
m.caller.Init()
m.caller.GetConfig()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.Select(&model.Download{}, nil)
a.NoError(err)
}
}
func TestSlaveCaller_DeleteTempFile(t *testing.T) {
a := assert.New(t)
m := getTestRPCNodeSlave()
m.caller.Init()
m.caller.GetConfig()
// master return 404
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 404,
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.Error(err)
}
// master return error
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.Error(err)
}
// master return success
{
mockRequest := requestMock{}
mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
},
})
m.caller.Client = mockRequest
err := m.caller.DeleteTempFile(&model.Download{})
a.NoError(err)
}
}
func TestRemoteCallback(t *testing.T) {
asserts := assert.New(t)
// 回调成功
{
clientMock := requestmock.RequestMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 0})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.NoError(resp)
clientMock.AssertExpectations(t)
}
// 服务端返回业务错误
{
clientMock := requestmock.RequestMock{}
mockResp, _ := json.Marshal(serializer.Response{Code: 401})
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.EqualValues(401, resp.(serializer.AppError).Code)
clientMock.AssertExpectations(t)
}
// 无法解析回调响应
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// HTTP状态码非200
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 404,
Body: ioutil.NopCloser(strings.NewReader("mockResp")),
},
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
// 无法发起回调
{
clientMock := requestmock.RequestMock{}
clientMock.On(
"Request",
"POST",
"http://test/test/url",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
request.GeneralClient = clientMock
resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
asserts.Error(resp)
clientMock.AssertExpectations(t)
}
}

View file

@ -53,7 +53,7 @@ type slave struct {
type redis struct {
Network string
Server string
User string
User string
Password string
DB string
}

View file

@ -1,100 +0,0 @@
package conf
import (
"io/ioutil"
"os"
"testing"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/stretchr/testify/assert"
)
// 测试Init日志路径错误
func TestInitPanic(t *testing.T) {
asserts := assert.New(t)
// 日志路径不存在时
asserts.NotPanics(func() {
Init("not/exist/path/conf.ini")
})
asserts.True(util.Exists("not/exist/path/conf.ini"))
}
// TestInitDelimiterNotFound 日志路径存在但 Key 格式错误时
func TestInitDelimiterNotFound(t *testing.T) {
asserts := assert.New(t)
testCase := `[Database]
Type = mysql
User = root
Password233root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
defer func() { err = os.Remove("testConf.ini") }()
if err != nil {
panic(err)
}
asserts.Panics(func() {
Init("testConf.ini")
})
}
// TestInitNoPanic 日志路径存在且合法时
func TestInitNoPanic(t *testing.T) {
asserts := assert.New(t)
testCase := `
[System]
Listen = 3000
HashIDSalt = 1
[Database]
Type = mysql
User = root
Password = root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_
[OptionOverwrite]
key=value
`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
defer func() { err = os.Remove("testConf.ini") }()
if err != nil {
panic(err)
}
asserts.NotPanics(func() {
Init("testConf.ini")
})
asserts.Equal(OptionOverwrite["key"], "value")
}
func TestMapSection(t *testing.T) {
asserts := assert.New(t)
//正常情况
testCase := `
[System]
Listen = 3000
HashIDSalt = 1
[Database]
Type = mysql
User = root
Password:root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
defer func() { err = os.Remove("testConf.ini") }()
if err != nil {
panic(err)
}
Init("testConf.ini")
err = mapSection("Database", DatabaseConfig)
asserts.NoError(err)
}

View file

@ -1,16 +1,22 @@
package conf
// plusVersion 增强版版本号
const plusVersion = "+1.1"
// BackendVersion 当前后端版本号
var BackendVersion = "3.8.3"
const BackendVersion = "3.8.3" + plusVersion
// KeyVersion 授权版本号
const KeyVersion = "3.3.1"
// RequiredDBVersion 与当前版本匹配的数据库版本
var RequiredDBVersion = "3.8.1"
const RequiredDBVersion = "3.8.1+1.0-plus"
// RequiredStaticVersion 与当前版本匹配的静态资源版本
var RequiredStaticVersion = "3.8.3"
const RequiredStaticVersion = "3.8.3" + plusVersion
// IsPro 是否为Pro版本
var IsPro = "false"
// IsPlus 是否为Plus版本
const IsPlus = "true"
// LastCommit 最后commit id
var LastCommit = "a11f819"
const LastCommit = "88409cc"

View file

@ -23,6 +23,8 @@ func Init() {
// 读取cron日程设置
options := model.GetSettingByNames(
"cron_garbage_collect",
"cron_notify_user",
"cron_ban_user",
"cron_recycle_upload_session",
)
Cron := cron.New()
@ -31,6 +33,10 @@ func Init() {
switch k {
case "cron_garbage_collect":
handler = garbageCollect
case "cron_notify_user":
handler = notifyExpiredVAS
case "cron_ban_user":
handler = banOverusedUser
case "cron_recycle_upload_session":
handler = uploadSessionCollect
default:

83
pkg/crontab/vas.go Executable file
View file

@ -0,0 +1,83 @@
package crontab
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
func notifyExpiredVAS() {
checkStoragePack()
checkUserGroup()
util.Log().Info("Crontab job \"cron_notify_user\" complete.")
}
// banOverusedUser 封禁超出宽容期的用户
func banOverusedUser() {
users := model.GetTolerantExpiredUser()
for _, user := range users {
// 清除最后通知日期标记
user.ClearNotified()
// 检查容量是否超额
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
// 封禁用户
user.SetStatus(model.OveruseBaned)
}
}
}
// checkUserGroup 检查已过期用户组
func checkUserGroup() {
users := model.GetGroupExpiredUsers()
for _, user := range users {
// 将用户回退到初始用户组
user.GroupFallback()
// 重新加载用户
user, _ = model.GetUserByID(user.ID)
// 检查容量是否超额
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
// 如果超额,则通知用户
sendNotification(&user, "用户组过期")
// 更新最后通知日期
user.Notified()
}
}
}
// checkStoragePack 检查已过期的容量包
func checkStoragePack() {
packs := model.GetExpiredStoragePack()
for _, pack := range packs {
// 删除过期的容量包
pack.Delete()
//找到所属用户
user, err := model.GetUserByID(pack.UserID)
if err != nil {
util.Log().Warning("Crontab job failed to get user info of [UID=%d]: %s", pack.UserID, err)
continue
}
// 检查容量是否超额
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
// 如果超额,则通知用户
sendNotification(&user, "容量包过期")
// 更新最后通知日期
user.Notified()
}
}
}
func sendNotification(user *model.User, reason string) {
title, body := email.NewOveruseNotification(user.Nick, reason)
if err := email.Send(user.Email, title, body); err != nil {
util.Log().Warning("Failed to send notification email: %s", err)
}
}

View file

@ -5,6 +5,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/go-mail/mail"
"github.com/google/uuid"
)
// SMTP SMTP协议发送邮件
@ -50,6 +51,7 @@ func (client *SMTP) Send(to, title, body string) error {
m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name)
m.SetHeader("To", to)
m.SetHeader("Subject", title)
m.SetHeader("Message-ID", util.StrConcat(`"<`, uuid.NewString(), `@`, `cloudreveplus`, `>"`))
m.SetBody("text/html", body)
client.ch <- m
return nil

View file

@ -7,6 +7,20 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// NewOveruseNotification 新建超额提醒邮件
func NewOveruseNotification(userName, reason string) (string, string) {
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "over_used_template")
replace := map[string]string{
"{siteTitle}": options["siteName"],
"{userName}": userName,
"{notifyReason}": reason,
"{siteUrl}": options["siteURL"],
"{siteSecTitle}": options["siteTitle"],
}
return fmt.Sprintf("【%s】空间容量超额提醒", options["siteName"]),
util.Replace(replace, options["over_used_template"])
}
// NewActivationEmail 新建激活邮件
func NewActivationEmail(userName, activateURL string) (string, string) {
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template")

View file

@ -239,13 +239,6 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string)
reader = zipFile
}
// 重设存储策略
fs.Policy = &fs.User.Policy
err = fs.DispatchHandler()
if err != nil {
return err
}
var wg sync.WaitGroup
parallel := model.GetIntSetting("max_parallel_transfer", 4)
worker := make(chan int, parallel)

View file

@ -1,256 +0,0 @@
package filesystem
import (
"bytes"
"context"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
testMock "github.com/stretchr/testify/mock"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestFileSystem_Compress(t *testing.T) {
asserts := assert.New(t)
ctx := context.Background()
fs := FileSystem{
User: &model.User{Model: gorm.Model{ID: 1}},
}
// 成功
{
// 查找压缩父目录
mock.ExpectQuery("SELECT(.+)folders(.+)").
WithArgs(1, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent"))
// 查找顶级待压缩文件
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1, 1).
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "name", "source_name", "policy_id"}).
AddRow(1, "1.txt", "tests/file1.txt", 1),
)
asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
// 查找父目录子文件
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}))
// 查找子目录
mock.ExpectQuery("SELECT(.+)folders(.+)").
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "sub"))
// 查找子目录子文件
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}).
AddRow(2, "2.txt", "tests/file2.txt", 1),
)
// 查找上传策略
asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1))
w := &bytes.Buffer{}
err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
asserts.NoError(err)
asserts.NotEmpty(w.Len())
}
// 上下文取消
{
ctx, cancel := context.WithCancel(context.Background())
cancel()
// 查找压缩父目录
mock.ExpectQuery("SELECT(.+)folders(.+)").
WithArgs(1, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent"))
// 查找顶级待压缩文件
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1, 1).
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "name", "source_name", "policy_id"}).
AddRow(1, "1.txt", "tests/file1.txt", 1),
)
asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
w := &bytes.Buffer{}
err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
asserts.Error(err)
asserts.NotEmpty(w.Len())
}
// 限制父目录
{
ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{
Model: gorm.Model{ID: 3},
})
// 查找压缩父目录
mock.ExpectQuery("SELECT(.+)folders(.+)").
WithArgs(1, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(1, "parent", 3))
// 查找顶级待压缩文件
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1, 1).
WillReturnRows(
sqlmock.NewRows(
[]string{"id", "name", "source_name", "policy_id"}).
AddRow(1, "1.txt", "tests/file1.txt", 1),
)
asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
w := &bytes.Buffer{}
err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
asserts.Error(err)
asserts.Equal(ErrObjectNotExist, err)
asserts.Empty(w.Len())
}
}
type MockNopRSC string
func (m MockNopRSC) Read(b []byte) (int, error) {
return 0, errors.New("read error")
}
func (m MockNopRSC) Seek(n int64, offset int) (int64, error) {
return 0, errors.New("read error")
}
func (m MockNopRSC) Close() error {
return errors.New("read error")
}
type MockRSC struct {
rs io.ReadSeeker
}
func (m MockRSC) Read(b []byte) (int, error) {
return m.rs.Read(b)
}
func (m MockRSC) Seek(n int64, offset int) (int64, error) {
return m.rs.Seek(n, offset)
}
func (m MockRSC) Close() error {
return nil
}
var basepath string
func init() {
_, currentFile, _, _ := runtime.Caller(0)
basepath = filepath.Dir(currentFile)
}
func Path(rel string) string {
return filepath.Join(basepath, rel)
}
func TestFileSystem_Decompress(t *testing.T) {
asserts := assert.New(t)
ctx := context.Background()
fs := FileSystem{
User: &model.User{Model: gorm.Model{ID: 1}},
}
os.RemoveAll(util.RelativePath("tests/decompress"))
// 压缩文件不存在
{
// 查找根目录
mock.ExpectQuery("SELECT(.+)folders(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "/"))
// 查找压缩文件,未找到
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
err := fs.Decompress(ctx, "/1.zip", "/", "")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
// 无法下载压缩文件
{
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, errors.New("error"))
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.EqualError(err, "error")
}
// 无法创建临时压缩文件
{
cache.Set("setting_temp_path", "/tests:", 0)
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
// 无法写入压缩文件
{
cache.Set("setting_temp_path", "tests", 0)
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockNopRSC("1"), nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Contains(err.Error(), "read error")
}
// 无法重设上传策略
{
cache.Set("setting_temp_path", "tests", 0)
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{rs: strings.NewReader("read")}, nil)
fs.Handler = testHandler
err := fs.Decompress(ctx, "/1.zip", "/", "")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.True(util.IsEmpty(util.RelativePath("tests/decompress")))
}
// 无法上传,容量不足
{
cache.Set("setting_max_parallel_transfer", "1", 0)
zipFile, _ := os.Open(Path("tests/test.zip"))
fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
fs.FileTarget[0].Policy.ID = 1
fs.User.Policy.Type = "mock"
testHandler := new(FileHeaderMock)
testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil)
fs.Handler = testHandler
fs.Decompress(ctx, "/1.zip", "/", "")
zipFile.Close()
asserts.NoError(mock.ExpectationsWereMet())
testHandler.AssertExpectations(t)
}
}

View file

@ -1,61 +0,0 @@
package backoff
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"time"
)
func TestConstantBackoff_Next(t *testing.T) {
a := assert.New(t)
// General error
{
err := errors.New("error")
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
// Retryable error
{
err := &RetryableError{RetryAfter: time.Duration(1)}
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
}
func TestNewRetryableErrorFromHeader(t *testing.T) {
a := assert.New(t)
// no retry-after header
{
err := NewRetryableErrorFromHeader(nil, http.Header{})
a.Empty(err.RetryAfter)
}
// with retry-after header
{
header := http.Header{}
header.Add("retry-after", "120")
err := NewRetryableErrorFromHeader(nil, header)
a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter)
}
}

View file

@ -1,250 +0,0 @@
package chunk
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/stretchr/testify/assert"
"io"
"os"
"strings"
"testing"
)
func TestNewChunkGroup(t *testing.T) {
a := assert.New(t)
testCases := []struct {
fileSize uint64
chunkSize uint64
expectedInnerChunkSize uint64
expectedChunkNum uint64
expectedInfo [][2]int //Start, Index,Length
}{
{10, 0, 10, 1, [][2]int{{0, 10}}},
{0, 0, 0, 1, [][2]int{{0, 0}}},
{0, 10, 10, 1, [][2]int{{0, 0}}},
{50, 10, 10, 5, [][2]int{
{0, 10},
{10, 10},
{20, 10},
{30, 10},
{40, 10},
}},
{50, 50, 50, 1, [][2]int{
{0, 50},
}},
{50, 15, 15, 4, [][2]int{
{0, 15},
{15, 15},
{30, 15},
{45, 5},
}},
}
for index, testCase := range testCases {
file := &fsctx.FileStream{Size: testCase.fileSize}
chunkGroup := NewChunkGroup(file, testCase.chunkSize, &backoff.ConstantBackoff{}, true)
a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(),
"TestCase:%d,ChunkNum()", index)
a.EqualValues(testCase.expectedInnerChunkSize, chunkGroup.chunkSize,
"TestCase:%d,InnerChunkSize()", index)
a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(),
"TestCase:%d,len(Chunks)", index)
a.EqualValues(testCase.fileSize, chunkGroup.Total())
for cIndex, info := range testCase.expectedInfo {
a.True(chunkGroup.Next())
a.EqualValues(info[1], chunkGroup.Length(),
"TestCase:%d,Chunks[%d].Length()", index, cIndex)
a.EqualValues(info[0], chunkGroup.Start(),
"TestCase:%d,Chunks[%d].Start()", index, cIndex)
a.Equal(cIndex == len(testCase.expectedInfo)-1, chunkGroup.IsLast(),
"TestCase:%d,Chunks[%d].IsLast()", index, cIndex)
a.NotEmpty(chunkGroup.RangeHeader())
}
a.False(chunkGroup.Next())
}
}
func TestChunkGroup_TempAvailablet(t *testing.T) {
a := assert.New(t)
file := &fsctx.FileStream{Size: 1}
c := NewChunkGroup(file, 0, &backoff.ConstantBackoff{}, true)
a.False(c.TempAvailable())
f, err := os.CreateTemp("", "TestChunkGroup_TempAvailablet.*")
defer func() {
f.Close()
os.Remove(f.Name())
}()
a.NoError(err)
c.bufferTemp = f
a.False(c.TempAvailable())
f.Write([]byte("1"))
a.True(c.TempAvailable())
}
func TestChunkGroup_Process(t *testing.T) {
a := assert.New(t)
file := &fsctx.FileStream{Size: 10}
// success
{
file.File = io.NopCloser(strings.NewReader("1234567890"))
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{}, true)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("67890", string(res))
return nil
}))
a.False(c.Next())
a.Equal(2, count)
}
// retry, read from buffer file
{
file.File = io.NopCloser(strings.NewReader("1234567890"))
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, true)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("67890", string(res))
if count == 2 {
return errors.New("error")
}
return nil
}))
a.False(c.Next())
a.Equal(3, count)
}
// retry, read from seeker
{
f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
f.Write([]byte("1234567890"))
f.Seek(0, 0)
defer func() {
f.Close()
os.Remove(f.Name())
}()
file.File = f
file.Seeker = f
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("67890", string(res))
if count == 2 {
return errors.New("error")
}
return nil
}))
a.False(c.Next())
a.Equal(3, count)
}
// retry, seek error
{
f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
f.Write([]byte("1234567890"))
f.Seek(0, 0)
defer func() {
f.Close()
os.Remove(f.Name())
}()
file.File = f
file.Seeker = f
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
f.Close()
a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
if count == 2 {
return errors.New("error")
}
return nil
}))
a.False(c.Next())
a.Equal(2, count)
}
// retry, finally error
{
f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
f.Write([]byte("1234567890"))
f.Seek(0, 0)
defer func() {
f.Close()
os.Remove(f.Name())
}()
file.File = f
file.Seeker = f
c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
count := 0
a.True(c.Next())
a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
res, err := io.ReadAll(chunk)
a.NoError(err)
a.EqualValues("12345", string(res))
return nil
}))
a.True(c.Next())
a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
count++
return errors.New("error")
}))
a.False(c.Next())
a.Equal(4, count)
}
}

View file

@ -3,6 +3,7 @@ package driver
import (
"context"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"

View file

@ -1,338 +0,0 @@
package local
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"io"
"os"
"strings"
"testing"
)
func TestHandler_Put(t *testing.T) {
asserts := assert.New(t)
handler := Driver{}
defer func() {
os.Remove(util.RelativePath("TestHandler_Put.txt"))
os.Remove(util.RelativePath("inner/TestHandler_Put.txt"))
}()
testCases := []struct {
file fsctx.FileHeader
errContains string
}{
{&fsctx.FileStream{
SavePath: "TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")),
}, ""},
{&fsctx.FileStream{
SavePath: "TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")),
}, "file with the same name existed or unavailable"},
{&fsctx.FileStream{
SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")),
}, ""},
{&fsctx.FileStream{
Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("123")),
}, ""},
{&fsctx.FileStream{
AppendStart: 10,
Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("123")),
}, "size of unfinished uploaded chunks is not as expected"},
{&fsctx.FileStream{
Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("123")),
}, ""},
}
for _, testCase := range testCases {
err := handler.Put(context.Background(), testCase.file)
if testCase.errContains != "" {
asserts.Error(err)
asserts.Contains(err.Error(), testCase.errContains)
} else {
asserts.NoError(err)
asserts.True(util.Exists(util.RelativePath(testCase.file.Info().SavePath)))
}
}
}
func TestDriver_TruncateFailed(t *testing.T) {
a := assert.New(t)
h := Driver{}
a.Error(h.Truncate(context.Background(), "TestDriver_TruncateFailed", 0))
}
func TestHandler_Delete(t *testing.T) {
asserts := assert.New(t)
handler := Driver{}
ctx := context.Background()
filePath := util.RelativePath("TestHandler_Delete.file")
file, err := os.Create(filePath)
asserts.NoError(err)
_ = file.Close()
list, err := handler.Delete(ctx, []string{"TestHandler_Delete.file"})
asserts.Equal([]string{}, list)
asserts.NoError(err)
file, err = os.Create(filePath)
_ = file.Close()
file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0))
asserts.NoError(err)
list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file", "test.notexist"})
file.Close()
asserts.Equal([]string{}, list)
asserts.NoError(err)
list, err = handler.Delete(ctx, []string{"test.notexist"})
asserts.Equal([]string{}, list)
asserts.NoError(err)
file, err = os.Create(filePath)
asserts.NoError(err)
list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file"})
_ = file.Close()
asserts.Equal([]string{}, list)
asserts.NoError(err)
}
func TestHandler_Get(t *testing.T) {
asserts := assert.New(t)
handler := Driver{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 成功
file, err := os.Create(util.RelativePath("TestHandler_Get.txt"))
asserts.NoError(err)
_ = file.Close()
rs, err := handler.Get(ctx, "TestHandler_Get.txt")
asserts.NoError(err)
asserts.NotNil(rs)
// 文件不存在
rs, err = handler.Get(ctx, "TestHandler_Get_notExist.txt")
asserts.Error(err)
asserts.Nil(rs)
}
func TestHandler_Thumb(t *testing.T) {
asserts := assert.New(t)
handler := Driver{}
ctx := context.Background()
file, err := os.Create(util.RelativePath("TestHandler_Thumb._thumb"))
asserts.NoError(err)
file.Close()
f := &model.File{
SourceName: "TestHandler_Thumb",
MetadataSerialized: map[string]string{
model.ThumbStatusMetadataKey: model.ThumbStatusExist,
},
}
// 正常
{
thumb, err := handler.Thumb(ctx, f)
asserts.NoError(err)
asserts.NotNil(thumb.Content)
}
// file 不存在
{
f.SourceName = "not_exist"
_, err := handler.Thumb(ctx, f)
asserts.Error(err)
asserts.ErrorIs(err, driver.ErrorThumbNotExist)
}
// thumb not exist
{
f.MetadataSerialized[model.ThumbStatusMetadataKey] = model.ThumbStatusNotExist
_, err := handler.Thumb(ctx, f)
asserts.Error(err)
asserts.ErrorIs(err, driver.ErrorThumbNotExist)
}
}
func TestHandler_Source(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{},
}
ctx := context.Background()
auth.General = auth.HMACAuth{SecretKey: []byte("test")}
// 成功
{
file := model.File{
Model: gorm.Model{
ID: 1,
},
Name: "test.jpg",
}
ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
sourceURL, err := handler.Source(ctx, "", 0, false, 0)
asserts.NoError(err)
asserts.NotEmpty(sourceURL)
asserts.Contains(sourceURL, "sign=")
}
// 下载
{
file := model.File{
Model: gorm.Model{
ID: 1,
},
Name: "test.jpg",
}
ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
sourceURL, err := handler.Source(ctx, "", 0, true, 0)
asserts.NoError(err)
asserts.NotEmpty(sourceURL)
asserts.Contains(sourceURL, "sign=")
asserts.Contains(sourceURL, "download")
}
// 无法获取上下文
{
sourceURL, err := handler.Source(ctx, "", 0, false, 0)
asserts.Error(err)
asserts.Empty(sourceURL)
}
// 设定了CDN
{
handler.Policy.BaseURL = "https://cqu.edu.cn"
file := model.File{
Model: gorm.Model{
ID: 1,
},
Name: "test.jpg",
}
ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
sourceURL, err := handler.Source(ctx, "", 0, false, 0)
asserts.NoError(err)
asserts.NotEmpty(sourceURL)
asserts.Contains(sourceURL, "sign=")
asserts.Contains(sourceURL, "https://cqu.edu.cn")
}
// 设定了CDN解析失败
{
handler.Policy.BaseURL = string([]byte{0x7f})
file := model.File{
Model: gorm.Model{
ID: 1,
},
Name: "test.jpg",
}
ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
sourceURL, err := handler.Source(ctx, "", 0, false, 0)
asserts.Error(err)
asserts.Empty(sourceURL)
}
}
func TestHandler_GetDownloadURL(t *testing.T) {
asserts := assert.New(t)
handler := Driver{Policy: &model.Policy{}}
ctx := context.Background()
auth.General = auth.HMACAuth{SecretKey: []byte("test")}
// 成功
{
file := model.File{
Model: gorm.Model{
ID: 1,
},
Name: "test.jpg",
}
ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
downloadURL, err := handler.Source(ctx, "", 10, true, 0)
asserts.NoError(err)
asserts.Contains(downloadURL, "sign=")
}
// 无法获取上下文
{
downloadURL, err := handler.Source(ctx, "", 10, true, 0)
asserts.Error(err)
asserts.Empty(downloadURL)
}
}
func TestHandler_Token(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{},
}
ctx := context.Background()
upSession := &serializer.UploadSession{SavePath: "TestHandler_Token"}
_, err := handler.Token(ctx, 10, upSession, &fsctx.FileStream{})
asserts.NoError(err)
file, _ := os.Create("TestHandler_Token")
defer func() {
file.Close()
os.Remove("TestHandler_Token")
}()
_, err = handler.Token(ctx, 10, upSession, &fsctx.FileStream{})
asserts.Error(err)
asserts.Contains(err.Error(), "already exist")
}
func TestDriver_CancelToken(t *testing.T) {
a := assert.New(t)
handler := Driver{}
a.NoError(handler.CancelToken(context.Background(), &serializer.UploadSession{}))
}
func TestDriver_List(t *testing.T) {
asserts := assert.New(t)
handler := Driver{}
ctx := context.Background()
// 创建测试目录结构
for _, path := range []string{
"test/TestDriver_List/parent.txt",
"test/TestDriver_List/parent_folder2/sub2.txt",
"test/TestDriver_List/parent_folder1/sub_folder/sub1.txt",
"test/TestDriver_List/parent_folder1/sub_folder/sub2.txt",
} {
f, _ := util.CreatNestedFile(util.RelativePath(path))
f.Close()
}
// 非递归列出
{
res, err := handler.List(ctx, "test/TestDriver_List", false)
asserts.NoError(err)
asserts.Len(res, 3)
}
// 递归列出
{
res, err := handler.List(ctx, "test/TestDriver_List", true)
asserts.NoError(err)
asserts.Len(res, 7)
}
}

View file

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"io"
"net/http"
"net/url"
@ -15,6 +14,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,7 @@ package onedrive
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
model "github.com/cloudreve/Cloudreve/v3/models"

View file

@ -1,32 +0,0 @@
package onedrive
import (
"testing"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/stretchr/testify/assert"
)
func TestNewClient(t *testing.T) {
asserts := assert.New(t)
// getOAuthEndpoint失败
{
policy := model.Policy{
BaseURL: string([]byte{0x7f}),
}
res, err := NewClient(&policy)
asserts.Error(err)
asserts.Nil(res)
}
// 成功
{
policy := model.Policy{}
res, err := NewClient(&policy)
asserts.NoError(err)
asserts.NotNil(res)
asserts.NotNil(res.Credential)
asserts.NotNil(res.Endpoints)
asserts.NotNil(res.Endpoints.OAuthEndpoints)
}
}

View file

@ -1,420 +0,0 @@
package onedrive
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestDriver_Token(t *testing.T) {
asserts := assert.New(t)
h, _ := NewDriver(&model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
})
handler := h.(Driver)
// 分片上传 失败
{
cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 400,
Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{})
asserts.Error(err)
asserts.Nil(res)
}
// 分片上传 成功
{
cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
cache.Set("setting_onedrive_monitor_timeout", "600", 0)
cache.Set("setting_onedrive_callback_check", "20", 0)
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1"
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
go func() {
time.Sleep(time.Duration(1) * time.Second)
mq.GlobalMQ.Publish("TestDriver_Token", mq.Message{})
}()
res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{Key: "TestDriver_Token"}, &fsctx.FileStream{})
asserts.NoError(err)
asserts.Equal("123321", res.UploadURLs[0])
}
}
func TestDriver_Source(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
cache.Set("setting_onedrive_source_timeout", "1800", 0)
// 失败
{
res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
asserts.Error(err)
asserts.Empty(res)
}
// 命中缓存 成功
{
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1"
cache.Set("onedrive_source_0_123.jpg", "res", 1)
res, err := handler.Source(context.Background(), "123.jpg", 0, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err)
asserts.Equal("res", res)
}
// 命中缓存 上下文存在文件 成功
{
file := model.File{}
file.ID = 1
file.UpdatedAt = time.Now()
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1"
cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
res, err := handler.Source(ctx, "123.jpg", 1, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err)
asserts.Equal("res", res)
}
// 成功
{
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
asserts.NoError(err)
asserts.Equal("123321", res)
}
}
func TestDriver_List(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.AccessToken = "AccessToken"
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 非递归
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)),
},
})
handler.Client.Request = clientMock
res, err := handler.List(context.Background(), "/", false)
asserts.NoError(err)
asserts.Len(res, 1)
}
// 递归一次
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
"me/drive/root/children?$top=999999999",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"1","folder":{}}]}`)),
},
})
clientMock.On(
"Request",
"GET",
"me/drive/root:/1:/children?$top=999999999",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"2"}]}`)),
},
})
handler.Client.Request = clientMock
res, err := handler.List(context.Background(), "/", true)
asserts.NoError(err)
asserts.Len(res, 2)
}
}
func TestDriver_Thumb(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
file := &model.File{PicInfo: "1,1", Model: gorm.Model{ID: 1}}
// 失败
{
ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20})
res, err := handler.Thumb(ctx, file)
asserts.Error(err)
asserts.Empty(res.URL)
}
// 上下文错误
{
_, err := handler.Thumb(context.Background(), file)
asserts.Error(err)
}
}
func TestDriver_Delete(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 失败
{
_, err := handler.Delete(context.Background(), []string{"1"})
asserts.Error(err)
}
}
func TestDriver_Put(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 失败
{
err := handler.Put(context.Background(), &fsctx.FileStream{})
asserts.Error(err)
}
}
func TestDriver_Get(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 无法获取source
{
res, err := handler.Get(context.Background(), "123.txt")
asserts.Error(err)
asserts.Nil(res)
}
// 成功
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
handler.Client.Credential.AccessToken = "1"
driverClientMock := ClientMock{}
driverClientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`123`)),
},
})
handler.HTTPClient = driverClientMock
res, err := handler.Get(context.Background(), "123.txt")
clientMock.AssertExpectations(t)
asserts.NoError(err)
_, err = res.Seek(0, io.SeekEnd)
asserts.NoError(err)
content, err := ioutil.ReadAll(res)
asserts.NoError(err)
asserts.Equal("123", string(content))
}
func TestDriver_replaceSourceHost(t *testing.T) {
tests := []struct {
name string
origin string
cdn string
want string
wantErr bool
}{
{"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false},
{"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false},
{"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true},
{"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy := &model.Policy{}
policy.OptionsSerialized.OdProxy = tt.cdn
handler := Driver{
Policy: policy,
}
got, err := handler.replaceSourceHost(tt.origin)
if (err != nil) != tt.wantErr {
t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want)
}
})
}
}
func TestDriver_CancelToken(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 失败
{
err := handler.CancelToken(context.Background(), &serializer.UploadSession{})
asserts.Error(err)
}
}

View file

@ -0,0 +1,25 @@
package onedrive
import "sync"
// CredentialLock 针对存储策略凭证的锁
type CredentialLock interface {
Lock(uint)
Unlock(uint)
}
var GlobalMutex = mutexMap{}
type mutexMap struct {
locks sync.Map
}
func (m *mutexMap) Lock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Lock()
}
func (m *mutexMap) Unlock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Unlock()
}

View file

@ -1,386 +0,0 @@
package onedrive
import (
"context"
"database/sql"
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
var mock sqlmock.Sqlmock
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
defer db.Close()
m.Run()
}
func TestGetOAuthEndpoint(t *testing.T) {
asserts := assert.New(t)
// URL解析失败
{
client := Client{
Endpoints: &Endpoints{
OAuthURL: string([]byte{0x7f}),
},
}
res := client.getOAuthEndpoint()
asserts.Nil(res)
}
{
testCase := []struct {
OAuthURL string
token string
auth string
isChina bool
}{
{
OAuthURL: "http://login.live.com",
token: "https://login.live.com/oauth20_token.srf",
auth: "https://login.live.com/oauth20_authorize.srf",
isChina: false,
},
{
OAuthURL: "http://login.chinacloudapi.cn",
token: "https://login.chinacloudapi.cn/common/oauth2/v2.0/token",
auth: "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize",
isChina: true,
},
{
OAuthURL: "other",
token: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
auth: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
isChina: false,
},
}
for i, testCase := range testCase {
client := Client{
Endpoints: &Endpoints{
OAuthURL: testCase.OAuthURL,
},
}
res := client.getOAuthEndpoint()
asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i)
asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i)
asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i)
}
}
}
func TestClient_OAuthURL(t *testing.T) {
asserts := assert.New(t)
client := Client{
ClientID: "client_id",
Redirect: "http://cloudreve.org/callback",
Endpoints: &Endpoints{},
}
client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"}))
asserts.NoError(err)
query := res.Query()
asserts.Equal("client_id", query.Get("client_id"))
asserts.Equal("scope1 scope2", query.Get("scope"))
asserts.Equal(client.Redirect, query.Get("redirect_uri"))
}
type ClientMock struct {
testMock.Mock
}
func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
args := m.Called(method, target, body, opts)
return args.Get(0).(*request.Response)
}
type mockReader string
func (r mockReader) Read(b []byte) (int, error) {
return 0, errors.New("read error")
}
func TestClient_ObtainToken(t *testing.T) {
asserts := assert.New(t)
client := Client{
Endpoints: &Endpoints{},
ClientID: "ClientID",
ClientSecret: "ClientSecret",
Redirect: "Redirect",
}
client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
// 刷新Token 成功
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)),
},
})
client.Request = clientMock
res, err := client.ObtainToken(context.Background())
clientMock.AssertExpectations(t)
asserts.NoError(err)
asserts.NotNil(res)
asserts.Equal("i am token", res.AccessToken)
}
// 重新获取 无法发送请求
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
client.Request = clientMock
res, err := client.ObtainToken(context.Background(), WithCode("code"))
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Nil(res)
}
// 刷新Token 无法获取响应正文
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(mockReader("")),
},
})
client.Request = clientMock
res, err := client.ObtainToken(context.Background())
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Nil(res)
asserts.Equal("read error", err.Error())
}
// 刷新Token OneDrive返回错误
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 400,
Body: ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)),
},
})
client.Request = clientMock
res, err := client.ObtainToken(context.Background())
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Nil(res)
asserts.Equal("", err.Error())
}
// 刷新Token OneDrive未知响应
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 400,
Body: ioutil.NopCloser(strings.NewReader(``)),
},
})
client.Request = clientMock
res, err := client.ObtainToken(context.Background())
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Nil(res)
}
}
func TestClient_UpdateCredential(t *testing.T) {
asserts := assert.New(t)
client := Client{
Policy: &model.Policy{Model: gorm.Model{ID: 257}},
Endpoints: &Endpoints{},
ClientID: "TestClient_UpdateCredential",
ClientSecret: "ClientSecret",
Redirect: "Redirect",
Credential: &Credential{},
}
client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
// 无有效的RefreshToken
{
err := client.UpdateCredential(context.Background(), false)
asserts.Equal(ErrInvalidRefreshToken, err)
client.Credential = nil
err = client.UpdateCredential(context.Background(), false)
asserts.Equal(ErrInvalidRefreshToken, err)
}
// 成功
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)),
},
})
client.Request = clientMock
client.Credential = &Credential{
RefreshToken: "old_refresh_token",
}
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := client.UpdateCredential(context.Background(), false)
clientMock.AssertExpectations(t)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential")
asserts.True(ok)
cacheCredential := cacheRes.(Credential)
asserts.Equal("new_refresh_token", cacheCredential.RefreshToken)
asserts.Equal("i am token", cacheCredential.AccessToken)
}
// OneDrive返回错误
{
cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_")
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 400,
Body: ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)),
},
})
client.Request = clientMock
client.Credential = &Credential{
RefreshToken: "old_refresh_token",
}
err := client.UpdateCredential(context.Background(), false)
clientMock.AssertExpectations(t)
asserts.Error(err)
}
// 从缓存中获取
{
cache.Set("onedrive_TestClient_UpdateCredential", Credential{
ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(),
AccessToken: "AccessToken",
RefreshToken: "RefreshToken",
}, 0)
client.Credential = &Credential{
RefreshToken: "old_refresh_token",
}
err := client.UpdateCredential(context.Background(), false)
asserts.NoError(err)
asserts.Equal("AccessToken", client.Credential.AccessToken)
asserts.Equal("RefreshToken", client.Credential.RefreshToken)
}
// 无需重新获取
{
client.Credential = &Credential{
RefreshToken: "old_refresh_token",
AccessToken: "AccessToken2",
ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(),
}
err := client.UpdateCredential(context.Background(), false)
asserts.NoError(err)
asserts.Equal("AccessToken2", client.Credential.AccessToken)
}
// slave failed
{
mockController := &controllermock.SlaveControllerMock{}
mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error"))
client.ClusterController = mockController
err := client.UpdateCredential(context.Background(), true)
asserts.Error(err)
}
// slave success
{
mockController := &controllermock.SlaveControllerMock{}
mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil)
client.ClusterController = mockController
err := client.UpdateCredential(context.Background(), true)
asserts.NoError(err)
asserts.Equal("AccessToken3", client.Credential.AccessToken)
}
}

View file

@ -172,14 +172,24 @@ func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Size),
IsDir: false,
LastModify: object.LastModified,
})
if strings.HasSuffix(object.Key, "/") {
res = append(res, response.Object{
Name: path.Base(object.Key),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
} else {
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Size),
IsDir: false,
LastModify: object.LastModified,
})
}
}
return res, nil

View file

@ -1,262 +0,0 @@
package remote
import (
"context"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
"io/ioutil"
"net/http"
"strings"
"testing"
)
func TestNewClient(t *testing.T) {
a := assert.New(t)
policy := &model.Policy{}
// 无法解析服务端url
{
policy.Server = string([]byte{0x7f})
c, err := NewClient(policy)
a.Error(err)
a.Nil(c)
}
// 成功
{
policy.Server = ""
c, err := NewClient(policy)
a.NoError(err)
a.NotNil(c)
}
}
func TestRemoteClient_Upload(t *testing.T) {
a := assert.New(t)
c, _ := NewClient(&model.Policy{})
// 无法创建上传会话
{
clientMock := requestmock.RequestMock{}
c.(*remoteClient).httpClient = &clientMock
clientMock.On(
"Request",
"PUT",
"upload",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
err := c.Upload(context.Background(), &fsctx.FileStream{})
a.Error(err)
a.Contains(err.Error(), "error")
clientMock.AssertExpectations(t)
}
// 分片上传失败,成功删除上传会话
{
cache.Set("setting_chunk_retries", "1", 0)
clientMock := requestmock.RequestMock{}
c.(*remoteClient).httpClient = &clientMock
clientMock.On(
"Request",
"PUT",
"upload",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
},
})
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
clientMock.On(
"Request",
"DELETE",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
},
})
err := c.Upload(context.Background(), &fsctx.FileStream{})
a.Error(err)
a.Contains(err.Error(), "error")
clientMock.AssertExpectations(t)
}
// 分片上传失败,无法删除上传会话
{
cache.Set("setting_chunk_retries", "1", 0)
clientMock := requestmock.RequestMock{}
c.(*remoteClient).httpClient = &clientMock
clientMock.On(
"Request",
"PUT",
"upload",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
},
})
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error"),
})
clientMock.On(
"Request",
"DELETE",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: errors.New("error2"),
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
},
})
err := c.Upload(context.Background(), &fsctx.FileStream{})
a.Error(err)
a.Contains(err.Error(), "error")
clientMock.AssertExpectations(t)
}
// 成功
{
cache.Set("setting_chunk_retries", "1", 0)
clientMock := requestmock.RequestMock{}
c.(*remoteClient).httpClient = &clientMock
clientMock.On(
"Request",
"PUT",
"upload",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
},
})
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
},
})
err := c.Upload(context.Background(), &fsctx.FileStream{})
a.NoError(err)
clientMock.AssertExpectations(t)
}
}
func TestRemoteClient_CreateUploadSessionFailed(t *testing.T) {
a := assert.New(t)
c, _ := NewClient(&model.Policy{})
clientMock := requestmock.RequestMock{}
c.(*remoteClient).httpClient = &clientMock
clientMock.On(
"Request",
"PUT",
"upload",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)),
},
})
err := c.Upload(context.Background(), &fsctx.FileStream{})
a.Error(err)
a.Contains(err.Error(), "error")
clientMock.AssertExpectations(t)
}
func TestRemoteClient_UploadChunkFailed(t *testing.T) {
a := assert.New(t)
c, _ := NewClient(&model.Policy{})
clientMock := requestmock.RequestMock{}
c.(*remoteClient).httpClient = &clientMock
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)),
},
})
err := c.(*remoteClient).uploadChunk(context.Background(), "", 0, strings.NewReader(""), false, 0)
a.Error(err)
a.Contains(err.Error(), "error")
clientMock.AssertExpectations(t)
}
func TestRemoteClient_GetUploadURL(t *testing.T) {
a := assert.New(t)
c, _ := NewClient(&model.Policy{})
// url 解析失败
{
c.(*remoteClient).policy.Server = string([]byte{0x7f})
res, sign, err := c.GetUploadURL(0, "")
a.Error(err)
a.Empty(res)
a.Empty(sign)
}
// 成功
{
c.(*remoteClient).policy.Server = ""
res, sign, err := c.GetUploadURL(0, "")
a.NoError(err)
a.NotEmpty(res)
a.NotEmpty(sign)
}
}

Some files were not shown because too many files have changed in this diff Show more