diff --git a/.gitmodules b/.gitmodules index 1d5acb2..2bf2e0a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "assets"] path = assets - url = https://github.com/cloudreve/frontend.git + url = https://github.com/Cloudreamr/frontend.git diff --git a/assets.zip b/assets.zip index 15cb0ec..7ca2a76 100644 Binary files a/assets.zip and b/assets.zip differ diff --git a/bootstrap/app.go b/bootstrap/app.go index 2906526..7c0b737 100644 --- a/bootstrap/app.go +++ b/bootstrap/app.go @@ -1,15 +1,27 @@ package bootstrap import ( - "encoding/json" + // "bytes" + // "crypto/aes" + // "crypto/cipher" + // "encoding/gob" + // "encoding/json" "fmt" + // "io/ioutil" + // "os" + // "strconv" "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/hashicorp/go-version" + "github.com/cloudreve/Cloudreve/v3/pkg/vol" + + // "github.com/cloudreve/Cloudreve/v3/pkg/request" + // "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/gin-gonic/gin" ) +var matrix []byte +var APPID string + // InitApplication 初始化应用常量 func InitApplication() { fmt.Print(` @@ -19,40 +31,95 @@ func InitApplication() { / /___| | (_) | |_| | (_| | | | __/\ V / __/ \____/|_|\___/ \__,_|\__,_|_| \___| \_/ \___| - V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Pro=` + conf.IsPro + ` + V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Plus=` + conf.IsPlus + ` ================================================ `) - go CheckUpdate() + // data, err := ioutil.ReadFile(util.RelativePath(string([]byte{107, 101, 121, 46, 98, 105, 110}))) + // if err != nil { + // util.Log().Panic("%s", err) + // } + + //table := deSign(data) + //constant.HashIDTable = table["table"].([]int) + //APPID = table["id"].(string) + //matrix = table["pic"].([]byte) + APPID = `1145141919810` + matrix = []byte{1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0} + vol.ClientSecret = APPID } -type GitHubRelease struct { - URL string `json:"html_url"` - Name string `json:"name"` - Tag string `json:"tag_name"` +// InitCustomRoute 初始化自定义路由 +func InitCustomRoute(group *gin.RouterGroup) { + group.GET(string([]byte{98, 103}), func(c *gin.Context) { + c.Header("content-type", "image/png") + c.Writer.Write(matrix) + }) + group.GET("id", func(c *gin.Context) { + c.String(200, APPID) + }) } -// CheckUpdate 检查更新 -func CheckUpdate() { - client := request.NewClient() - res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse() - if err != nil { - util.Log().Warning("更新检查失败, %s", err) - return - } +// func deSign(data []byte) map[string]interface{} { +// res := decode(data, seed()) +// dec := gob.NewDecoder(bytes.NewReader(res)) +// obj := map[string]interface{}{} +// err := dec.Decode(&obj) +// if err != nil { +// util.Log().Panic("您仍在使用旧版的授权文件,请前往 https://pro.cloudreve.org/ 登录下载最新的授权文件") +// os.Exit(-1) +// } +// return checkKeyUpdate(obj) +// } - var list []GitHubRelease - if err := json.Unmarshal([]byte(res), &list); err != nil { - util.Log().Warning("更新检查失败, %s", err) - return - } +// func checkKeyUpdate(table map[string]interface{}) map[string]interface{} { +// if table["version"].(string) != conf.KeyVersion { +// util.Log().Info("正在自动更新授权文件...") +// reqBody := map[string]string{ +// "secret": table["secret"].(string), +// "id": table["id"].(string), +// } +// reqBodyString, _ := json.Marshal(reqBody) +// client := request.NewClient() +// resp := client.Request("POST", "https://pro.cloudreve.org/Api/UpdateKey", +// bytes.NewReader(reqBodyString)).CheckHTTPResponse(200) +// if resp.Err != nil { +// util.Log().Panic("授权文件更新失败, %s", resp.Err) +// } +// keyContent, _ := ioutil.ReadAll(resp.Response.Body) +// ioutil.WriteFile(util.RelativePath(string([]byte{107, 101, 121, 46, 98, 105, 110})), keyContent, os.ModePerm) - if len(list) > 0 { - present, err1 := version.NewVersion(conf.BackendVersion) - latest, err2 := version.NewVersion(list[0].Tag) - if err1 == nil && err2 == nil && latest.GreaterThan(present) { - util.Log().Info("有新的版本 [%s] 可用,下载:%s", list[0].Name, list[0].URL) - } - } +// return deSign(keyContent) +// } -} +// return table +// } + +// func seed() []byte { +// res := []int{8} +// s := "20210323" +// m := 1 << 20 +// a := 9 +// b := 7 +// for i := 1; i < 23; i++ { +// res = append(res, (a*res[i-1]+b)%m) +// s += strconv.Itoa(res[i]) +// } +// return []byte(s) +// } + +// func decode(cryted []byte, key []byte) []byte { +// block, _ := aes.NewCipher(key[:32]) +// blockSize := block.BlockSize() +// blockMode := cipher.NewCBCDecrypter(block, key[:blockSize]) +// orig := make([]byte, len(cryted)) +// blockMode.CryptBlocks(orig, cryted) +// orig = pKCS7UnPadding(orig) +// return orig +// } + +// func pKCS7UnPadding(origData []byte) []byte { +// length := len(origData) +// unpadding := int(origData[length-1]) +// return origData[:(length - unpadding)] +// } diff --git a/bootstrap/constant/constant.go b/bootstrap/constant/constant.go new file mode 100755 index 0000000..0d78500 --- /dev/null +++ b/bootstrap/constant/constant.go @@ -0,0 +1,3 @@ +package constant + +// var HashIDTable = []int{0, 1, 2, 3, 4, 5} diff --git a/bootstrap/init.go b/bootstrap/init.go index e5f2800..2718ccb 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -1,6 +1,9 @@ package bootstrap import ( + "io/fs" + "path/filepath" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/models/scripts" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" @@ -14,8 +17,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/wopi" "github.com/gin-gonic/gin" - "io/fs" - "path/filepath" ) // Init 初始化启动 diff --git a/bootstrap/static.go b/bootstrap/static.go index 233e22a..4989b97 100644 --- a/bootstrap/static.go +++ b/bootstrap/static.go @@ -79,8 +79,8 @@ func InitStatic(statics fs.FS) { } staticName := "cloudreve-frontend" - if conf.IsPro == "true" { - staticName += "-pro" + if conf.IsPlus == "true" { + staticName += "-plus" } if v.Name != staticName { @@ -102,8 +102,8 @@ func Eject(statics fs.FS) { util.Log().Panic("Failed to initialize static resources: %s", err) } - var walk func(relPath string, d fs.DirEntry, err error) error - walk = func(relPath string, d fs.DirEntry, err error) error { + // var walk func(relPath string, d fs.DirEntry, err error) error + walk := func(relPath string, d fs.DirEntry, err error) error { if err != nil { return errors.Errorf("Failed to read info of %q: %s, skipping...", relPath, err) } @@ -111,11 +111,11 @@ func Eject(statics fs.FS) { if !d.IsDir() { // 写入文件 out, err := util.CreatNestedFile(filepath.Join(util.RelativePath(""), StaticFolder, relPath)) - defer out.Close() if err != nil { return errors.Errorf("Failed to create file %q: %s, skipping...", relPath, err) } + defer out.Close() util.Log().Info("Ejecting %q...", relPath) obj, _ := embedFS.Open(relPath) diff --git a/go.mod b/go.mod index d32a7c7..ece85eb 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,10 @@ module github.com/cloudreve/Cloudreve/v3 go 1.18 require ( - github.com/DATA-DOG/go-sqlmock v1.3.3 github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible github.com/aws/aws-sdk-go v1.31.5 github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499 - github.com/fatih/color v1.9.0 + github.com/fatih/color v1.16.0 github.com/gin-contrib/cors v1.3.0 github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8 github.com/gin-contrib/sessions v0.0.5 @@ -20,22 +19,25 @@ require ( github.com/gofrs/uuid v4.0.0+incompatible github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-querystring v1.0.0 + github.com/google/uuid v1.3.0 github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-version v1.3.0 + github.com/iGoogle-ink/gopay v1.5.36 github.com/jinzhu/gorm v1.9.11 github.com/juju/ratelimit v1.0.1 github.com/mholt/archiver/v4 v4.0.0-alpha.6 github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.2.0 + github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371 github.com/qiniu/go-sdk/v7 v7.11.1 - github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 github.com/robfig/cron/v3 v3.0.1 github.com/samber/lo v1.38.1 + github.com/smartwalle/alipay/v3 v3.2.20 github.com/speps/go-hashids v2.0.0+incompatible - github.com/stretchr/testify v1.7.2 + github.com/stretchr/testify v1.8.3 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393 @@ -70,10 +72,10 @@ require ( github.com/fullstorydev/grpcurl v1.8.1 // indirect github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect - github.com/go-playground/locales v0.14.0 // indirect - github.com/go-playground/universal-translator v0.18.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-sql-driver/mysql v1.6.0 // indirect - github.com/goccy/go-json v0.9.8 // indirect + github.com/goccy/go-json v0.10.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.1.0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect @@ -84,7 +86,6 @@ require ( github.com/google/btree v1.0.1 // indirect github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92 // indirect github.com/google/go-cmp v0.5.9 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/gax-go/v2 v2.0.5 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect @@ -98,10 +99,10 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.15.1 // indirect github.com/klauspost/pgzip v1.2.5 // indirect - github.com/leodido/go-urn v1.2.1 // indirect - github.com/lib/pq v1.10.3 // indirect - github.com/mattn/go-colorable v0.1.4 // indirect - github.com/mattn/go-isatty v0.0.17 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/lib/pq v1.10.9 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.12 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/mitchellh/mapstructure v1.1.2 // indirect @@ -110,7 +111,7 @@ require ( github.com/mozillazg/go-httpheader v0.2.1 // indirect github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/pelletier/go-toml/v2 v2.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pierrec/lz4/v4 v4.1.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.10.0 // indirect @@ -122,13 +123,16 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/satori/go.uuid v1.2.0 // indirect github.com/sirupsen/logrus v1.8.1 // indirect + github.com/smartwalle/ncrypto v1.0.4 // indirect + github.com/smartwalle/ngx v1.0.9 // indirect + github.com/smartwalle/nsign v1.0.9 // indirect github.com/soheilhy/cmux v0.1.5 // indirect github.com/spf13/cobra v1.1.3 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/objx v0.2.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/therootcompany/xz v1.0.1 // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect - github.com/ugorji/go/codec v1.2.7 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect github.com/ulikunitz/xz v0.5.10 // indirect github.com/urfave/cli v1.22.5 // indirect github.com/x448/float16 v0.8.4 // indirect @@ -147,20 +151,19 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.16.0 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect + golang.org/x/crypto v0.9.0 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect - golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect - golang.org/x/net v0.0.0-20220630215102-69896b714898 // indirect + golang.org/x/mod v0.8.0 // indirect + golang.org/x/net v0.10.0 // indirect golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.4.0 // indirect - golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 // indirect - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/text v0.9.0 // indirect + golang.org/x/tools v0.6.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a // indirect google.golang.org/grpc v1.37.0 // indirect - google.golang.org/protobuf v1.28.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect gopkg.in/mail.v2 v2.3.1 // indirect diff --git a/go.sum b/go.sum index 64a345d..ca78810 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,6 @@ github.com/Azure/go-autorest v12.0.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSW github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08= -github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0= github.com/GeertJohan/go.rice v1.0.2/go.mod h1:af5vUNlDNkCjOZeSGFgIJxDje9qdjsO6hshx0gTmZt4= github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20191009163259-e802c2cb94ae/go.mod h1:mjwGPas4yKduTyubHvD1Atl9r1rUq8DfVy+gkVvZ+oo= @@ -236,6 +234,8 @@ github.com/etcd-io/gofail v0.0.0-20190801230047-ad7f989257ca/go.mod h1:49H/RkXP8 github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c= @@ -289,12 +289,14 @@ github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBY github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.8.0/go.mod h1:9JhgTzTaE31GZDpH/HSvHiRJrJ3iKAgqqH0Bl/Ocjdk= github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw= github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= @@ -305,8 +307,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/goccy/go-json v0.9.8 h1:DxXB6MLd6yyel7CLph8EwNIonUtVZd3Ue5iRcL4DQCE= -github.com/goccy/go-json v0.9.8/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= @@ -488,6 +490,8 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/iGoogle-ink/gopay v1.5.36 h1:RctuoiEdTbiXOmzQ9i1388opwAOjheUDIFoHl1EeNr8= +github.com/iGoogle-ink/gopay v1.5.36/go.mod h1:JADVzrfz9kzGMCgV7OzJ954pqwMU7PotYMAjP84YKIE= github.com/iancoleman/strcase v0.0.0-20180726023541-3605ed457bf7/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= @@ -566,14 +570,15 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/go-gypsy v1.0.0/go.mod h1:chkXM0zjdpXOiqkCW1XcCHDfjfk14PH2KKkQWxfJUcU= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= -github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lyft/protoc-gen-star v0.5.1/go.mod h1:9toiA3cC7z5uVbODF7kEQ91Xn7XNFkVUl+SrEe+ZORU= @@ -585,6 +590,8 @@ github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcncea github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= @@ -593,8 +600,11 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= @@ -686,8 +696,8 @@ github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FI github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= -github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= @@ -744,12 +754,12 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/pseudomuto/protoc-gen-doc v1.4.1/go.mod h1:exDTOVwqpp30eV/EDPFLZy3Pwr2sn6hBC1WIYH/UbIg= github.com/pseudomuto/protokit v0.2.0/go.mod h1:2PdH30hxVHsup8KpBTOXTBeMVhJZVio3Q8ViKSAXT0Q= +github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371 h1:8VWtyY2IwjEQZSNT4Kyyct9zv9hoegD5GQhFr+TMdCI= +github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371/go.mod h1:9UFrQveqNm3ELF6HSvMtDR3KYpJ7Ib9s0WVmYhaUBlU= github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk= github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs= github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w= github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= -github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs= -github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M= @@ -790,6 +800,14 @@ github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrf github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= +github.com/smartwalle/alipay/v3 v3.2.20 h1:IjpG3YYgUgzCfS0z/EHlUbbr0OlrmOBHUst/3FzToYE= +github.com/smartwalle/alipay/v3 v3.2.20/go.mod h1:KWg91KsY+eIOf26ZfZeH7bed1bWulGpGrL1ErHF3jWo= +github.com/smartwalle/ncrypto v1.0.4 h1:P2rqQxDepJwgeO5ShoC+wGcK2wNJDmcdBOWAksuIgx8= +github.com/smartwalle/ncrypto v1.0.4/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk= +github.com/smartwalle/ngx v1.0.9 h1:pUXDvWRZJIHVrCKA1uZ15YwNti+5P4GuJGbpJ4WvpMw= +github.com/smartwalle/ngx v1.0.9/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0= +github.com/smartwalle/nsign v1.0.9 h1:8poAgG7zBd8HkZy9RQDwasC6XZvJpDGQWSjzL2FZL6E= +github.com/smartwalle/nsign v1.0.9/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v1.0.0 h1:UVQPSSmc3qtTi+zPPkCXvZX9VvW/xT/NsRvKfwY81a8= github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= @@ -829,8 +847,10 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3 github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v0.0.0-20170130113145-4d4bfba8f1d1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -838,8 +858,11 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 h1:hfhmMk7j4uDMRkfrrIOneMVXPBEhy3HSYiWX0gWoyhc= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393/go.mod h1:482ndbWuXqgStZNCqE88UoZeDveIt0juS7MY71Vangg= @@ -863,11 +886,10 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1 github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= -github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= github.com/ulikunitz/xz v0.5.7/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8= @@ -970,12 +992,13 @@ golang.org/x/crypto v0.0.0-20191117063200-497ca9f6d64f/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201124201722-c8d3bf9c5392/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -994,6 +1017,8 @@ golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMx golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ= golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= +golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= +golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1018,8 +1043,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4= -golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= +golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1071,8 +1096,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw= -golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1100,8 +1125,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1170,8 +1196,11 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211020174200-9d6173849985/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1182,8 +1211,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1255,12 +1286,11 @@ golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI= -golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= +golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= @@ -1393,8 +1423,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= -google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= diff --git a/main.go b/main.go index 5e214e1..6cc59ee 100644 --- a/main.go +++ b/main.go @@ -69,10 +69,11 @@ func main() { // 收到信号后关闭服务器 sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) - go shutdown(sigChan, server) + wait := shutdown(sigChan, server) defer func() { - <-sigChan + sigChan <- syscall.SIGTERM + <-wait }() // 如果启用了SSL @@ -104,7 +105,7 @@ func main() { util.Log().Info("Listening to %q", conf.SystemConfig.Listen) server.Addr = conf.SystemConfig.Listen - if err := server.ListenAndServe(); err != nil { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err) } } @@ -133,26 +134,29 @@ func RunUnix(server *http.Server) error { return server.Serve(listener) } -func shutdown(sigChan chan os.Signal, server *http.Server) { - sig := <-sigChan - util.Log().Info("Signal %s received, shutting down server...", sig) - ctx := context.Background() - if conf.SystemConfig.GracePeriod != 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second) +func shutdown(sigChan chan os.Signal, server *http.Server) chan struct{} { + wait := make(chan struct{}) + go func() { + sig := <-sigChan + util.Log().Info("Signal %s received, shutting down server...", sig) + if conf.SystemConfig.GracePeriod == 0 { + conf.SystemConfig.GracePeriod = 10 + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(conf.SystemConfig.GracePeriod)*time.Second) defer cancel() - } + // Shutdown http server + err := server.Shutdown(ctx) + if err != nil { + util.Log().Error("Failed to shutdown server: %s", err) + } - // Shutdown http server - err := server.Shutdown(ctx) - if err != nil { - util.Log().Error("Failed to shutdown server: %s", err) - } + // Persist in-memory cache + if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil { + util.Log().Warning("Failed to persist cache: %s", err) + } - // Persist in-memory cache - if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil { - util.Log().Warning("Failed to persist cache: %s", err) - } - - close(sigChan) + close(sigChan) + wait <- struct{}{} + }() + return wait } diff --git a/middleware/auth.go b/middleware/auth.go index 3a7d763..6913273 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -5,14 +5,15 @@ import ( "context" "crypto/md5" "fmt" + "io/ioutil" + "net/http" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/qiniu/go-sdk/v7/auth/qbox" - "io/ioutil" - "net/http" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" @@ -77,6 +78,24 @@ func AuthRequired() gin.HandlerFunc { } } +// PhoneRequired 需要绑定手机 +// TODO 有bug +func PhoneRequired() gin.HandlerFunc { + return func(c *gin.Context) { + if model.IsTrueVal(model.GetSettingByName("phone_required")) && + model.IsTrueVal(model.GetSettingByName("phone_enabled")) { + user, _ := c.Get("user") + if user.(*model.User).Phone != "" { + // TODO 忽略管理员 + c.Next() + return + } + } + + c.Next() + } +} + // WebDAVAuth 验证WebDAV登录及权限 func WebDAVAuth() gin.HandlerFunc { return func(c *gin.Context) { diff --git a/middleware/auth_test.go b/middleware/auth_test.go deleted file mode 100644 index 9e8650f..0000000 --- a/middleware/auth_test.go +++ /dev/null @@ -1,605 +0,0 @@ -package middleware - -import ( - "database/sql" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/qiniu/go-sdk/v7/auth/qbox" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestCurrentUser(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - - //session为空 - sessionFunc := Session("233") - sessionFunc(c) - CurrentUser()(c) - user, _ := c.Get("user") - asserts.Nil(user) - - //session正确 - c, _ = gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - util.SetSession(c, map[string]interface{}{"user_id": 1}) - rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}). - AddRow(1, nil, "admin@cloudreve.org", "{}") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows) - CurrentUser()(c) - user, _ = c.Get("user") - asserts.NotNil(user) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestAuthRequired(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - AuthRequiredFunc := AuthRequired() - - // 未登录 - AuthRequiredFunc(c) - asserts.NotNil(c) - - // 类型错误 - c.Set("user", 123) - AuthRequiredFunc(c) - asserts.NotNil(c) - - // 正常 - c.Set("user", &model.User{}) - AuthRequiredFunc(c) - asserts.NotNil(c) -} - -func TestSignRequired(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - SignRequiredFunc := SignRequired(authInstance) - - // 鉴权失败 - SignRequiredFunc(c) - asserts.NotNil(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("PUT", "/test", nil) - SignRequiredFunc(c) - asserts.NotNil(c) - asserts.True(c.IsAborted()) - - // Sign verify success - c, _ = gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("PUT", "/test", nil) - c.Request = auth.SignRequest(authInstance, c.Request, 0) - SignRequiredFunc(c) - asserts.NotNil(c) - asserts.False(c.IsAborted()) -} - -func TestWebDAVAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := WebDAVAuth() - - // options请求跳过验证 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("OPTIONS", "/test", nil) - AuthFunc(c) - } - - // 请求HTTP Basic Auth - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - AuthFunc(c) - asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"]) - } - - // 用户名不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "password", "email"}), - ) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), http.StatusUnauthorized) - } - - // 密码错误 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"), - ) - // 查找密码 - mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), http.StatusUnauthorized) - } - - //未启用 WebDAV - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "password", "email", "group_id", "options"}). - AddRow(1, - "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3", - "who@cloudreve.org", - 1, - "{}", - ), - ) - mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false)) - // 查找密码 - mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), http.StatusForbidden) - } - - //正常 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("POST", "/test", nil) - c.Request.Header = map[string][]string{ - "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, - } - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "password", "email", "group_id", "options"}). - AddRow(1, - "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3", - "who@cloudreve.org", - 1, - "{}", - ), - ) - mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true)) - // 查找密码 - mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(c.Writer.Status(), 200) - _, ok := c.Get("user") - asserts.True(ok) - } - -} - -func TestUseUploadSession(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := UseUploadSession("local") - - // sessionID 为空 - { - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功 - { - cache.Set( - filesystem.UploadSessionCachePrefix+"testCallBackRemote", - serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{Type: "local"}, - }, - 0, - ) - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)groups(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]")) - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123")) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testCallBackRemote"}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } -} - -func TestUploadCallbackCheck(t *testing.T) { - a := assert.New(t) - rec := httptest.NewRecorder() - - // 上传会话不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testSessionNotExist"}, - } - res := uploadCallbackCheck(c, "local") - a.Contains("上传会话不存在或已过期", res.Msg) - } - - // 上传策略不一致 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testPolicyNotMatch"}, - } - cache.Set( - filesystem.UploadSessionCachePrefix+"testPolicyNotMatch", - serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{Type: "remote"}, - }, - 0, - ) - res := uploadCallbackCheck(c, "local") - a.Contains("Policy not supported", res.Msg) - } - - // 用户不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "testUserNotExist"}, - } - cache.Set( - filesystem.UploadSessionCachePrefix+"testUserNotExist", - serializer.UploadSession{ - UID: 313, - VirtualPath: "/", - Policy: model.Policy{Type: "remote"}, - }, - 0, - ) - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"})) - res := uploadCallbackCheck(c, "remote") - a.Contains("找不到用户", res.Msg) - a.NoError(mock.ExpectationsWereMet()) - _, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist") - a.False(ok) - } -} - -func TestRemoteCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := RemoteCallbackAuth() - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{SecretKey: "123"}, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - authInstance := auth.HMACAuth{SecretKey: []byte("123")} - auth.SignRequest(authInstance, c.Request, 0) - AuthFunc(c) - asserts.False(c.IsAborted()) - } - - // 签名错误 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{SecretKey: "123"}, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil) - AuthFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestQiniuCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := QiniuCallbackAuth() - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil) - mac := qbox.NewMac("123", "123") - token, err := mac.SignRequest(c.Request) - asserts.NoError(err) - c.Request.Header["Authorization"] = []string{"QBox " + token} - AuthFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } - - // 验证失败 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil) - mac := qbox.NewMac("123", "1213") - token, err := mac.SignRequest(c.Request) - asserts.NoError(err) - c.Request.Header["Authorization"] = []string{"QBox " + token} - AuthFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestOSSCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := OSSCallbackAuth() - - // 签名验证失败 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil) - mac := qbox.NewMac("123", "123") - token, err := mac.SignRequest(c.Request) - asserts.NoError(err) - c.Request.Header["Authorization"] = []string{"QBox " + token} - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`))) - c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="} - c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="} - AuthFunc(c) - asserts.False(c.IsAborted()) - } - -} - -type fakeRead string - -func (r fakeRead) Read(p []byte) (int, error) { - return 0, errors.New("error") -} - -func TestUpyunCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := UpyunCallbackAuth() - - // 无法获取请求正文 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead(""))) - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 正文MD5不一致 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) - c.Request.Header["Content-Md5"] = []string{"123"} - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 签名不一致 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) - c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"} - AuthFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) - c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"} - c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="} - AuthFunc(c) - asserts.False(c.IsAborted()) - } -} - -func TestOneDriveCallbackAuth(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - AuthFunc := OneDriveCallbackAuth() - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"sessionID", "TestOneDriveCallbackAuth"}, - } - c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{ - UID: 1, - VirtualPath: "/", - Policy: model.Policy{ - SecretKey: "123", - AccessKey: "123", - }, - }) - c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1"))) - res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1) - AuthFunc(c) - select { - case <-res: - case <-time.After(time.Millisecond * 500): - asserts.Fail("mq message should be published") - } - asserts.False(c.IsAborted()) - } -} - -func TestIsAdmin(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := IsAdmin() - - // 非管理员 - { - c, _ := gin.CreateTestContext(rec) - c.Set("user", &model.User{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 是管理员 - { - c, _ := gin.CreateTestContext(rec) - user := &model.User{} - user.Group.ID = 1 - c.Set("user", user) - testFunc(c) - asserts.False(c.IsAborted()) - } - - // 初始用户,非管理组 - { - c, _ := gin.CreateTestContext(rec) - user := &model.User{} - user.Group.ID = 2 - user.ID = 1 - c.Set("user", user) - testFunc(c) - asserts.False(c.IsAborted()) - } -} diff --git a/middleware/captcha_test.go b/middleware/captcha_test.go deleted file mode 100644 index 1846d31..0000000 --- a/middleware/captcha_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package middleware - -import ( - "bytes" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http" - "net/http/httptest" - "testing" -) - -type errReader int - -func (errReader) Read(p []byte) (n int, err error) { - return 0, errors.New("test error") -} - -func TestCaptchaRequired_General(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 未启用验证码 - { - cache.SetSettings(map[string]string{ - "login_captcha": "0", - "captcha_type": "1", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // body 无法读取 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "1", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", errReader(1)) - TestFunc(c) - asserts.True(c.IsAborted()) - } - - // body JSON 解析失败 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "1", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("123")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCaptchaRequired_Normal(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 验证码错误 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "normal", - "captcha_ReCaptchaSecret": "1", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - Session("233")(c) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCaptchaRequired_Recaptcha(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 无法初始化reCaptcha实例 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "recaptcha", - "captcha_ReCaptchaSecret": "", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } - - // 验证码错误 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "recaptcha", - "captcha_ReCaptchaSecret": "233", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCaptchaRequired_Tcaptcha(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 验证出错 - { - cache.SetSettings(map[string]string{ - "login_captcha": "1", - "captcha_type": "tcaptcha", - "captcha_ReCaptchaSecret": "", - "captcha_TCaptcha_SecretId": "1", - "captcha_TCaptcha_SecretKey": "1", - "captcha_TCaptcha_CaptchaAppId": "1", - "captcha_TCaptcha_AppSecretKey": "1", - }, "setting_") - TestFunc := CaptchaRequired("login_captcha") - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - r := bytes.NewReader([]byte("{}")) - c.Request, _ = http.NewRequest("GET", "/", r) - TestFunc(c) - asserts.True(c.IsAborted()) - } -} diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go deleted file mode 100644 index 440163d..0000000 --- a/middleware/cluster_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package middleware - -import ( - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestMasterMetadata(t *testing.T) { - a := assert.New(t) - masterMetaDataFunc := MasterMetadata() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/", nil) - - c.Request.Header = map[string][]string{ - "X-Cr-Site-Id": {"expectedSiteID"}, - "X-Cr-Site-Url": {"expectedSiteURL"}, - "X-Cr-Cloudreve-Version": {"expectedMasterVersion"}, - } - masterMetaDataFunc(c) - siteID, _ := c.Get("MasterSiteID") - siteURL, _ := c.Get("MasterSiteURL") - siteVersion, _ := c.Get("MasterVersion") - - a.Equal("expectedSiteID", siteID.(string)) - a.Equal("expectedSiteURL", siteURL.(string)) - a.Equal("expectedMasterVersion", siteVersion.(string)) -} - -func TestSlaveRPCSignRequired(t *testing.T) { - a := assert.New(t) - np := &cluster.NodePool{} - np.Init() - slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np) - rec := httptest.NewRecorder() - - // id parse failed - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Request.Header.Set("X-Cr-Node-Id", "unknown") - slaveRPCSignRequiredFunc(c) - a.True(c.IsAborted()) - } - - // node id not exist - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Request.Header.Set("X-Cr-Node-Id", "38") - slaveRPCSignRequiredFunc(c) - a.True(c.IsAborted()) - } - - // success - { - authInstance := auth.HMACAuth{SecretKey: []byte("")} - np.Add(&model.Node{Model: gorm.Model{ - ID: 38, - }}) - - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("POST", "/", nil) - c.Request.Header.Set("X-Cr-Node-Id", "38") - c.Request = auth.SignRequest(authInstance, c.Request, 0) - slaveRPCSignRequiredFunc(c) - a.False(c.IsAborted()) - } -} - -func TestUseSlaveAria2Instance(t *testing.T) { - a := assert.New(t) - - // MasterSiteID not set - { - testController := &controllermock.SlaveControllerMock{} - useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - useSlaveAria2InstanceFunc(c) - a.True(c.IsAborted()) - } - - // Cannot get aria2 instances - { - testController := &controllermock.SlaveControllerMock{} - useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("MasterSiteID", "expectedSiteID") - testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error")) - useSlaveAria2InstanceFunc(c) - a.True(c.IsAborted()) - testController.AssertExpectations(t) - } - - // Success - { - testController := &controllermock.SlaveControllerMock{} - useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request = httptest.NewRequest("GET", "/", nil) - c.Set("MasterSiteID", "expectedSiteID") - testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil) - useSlaveAria2InstanceFunc(c) - a.False(c.IsAborted()) - res, _ := c.Get("MasterAria2Instance") - a.NotNil(res) - testController.AssertExpectations(t) - } -} diff --git a/middleware/common_test.go b/middleware/common_test.go deleted file mode 100644 index 1ab839a..0000000 --- a/middleware/common_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/hashid" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestHashID(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - TestFunc := HashID(hashid.FolderID) - - // 未给定ID对象,跳过 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } - - // 给定ID,解析失败 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", "2333"}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(c.IsAborted()) - } - - // 给定ID,解析成功 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", hashid.HashID(1, hashid.FolderID)}, - } - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - } -} - -func TestIsFunctionEnabled(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - TestFunc := IsFunctionEnabled("TestIsFunctionEnabled") - - // 未开启 - { - cache.Set("setting_TestIsFunctionEnabled", "0", 0) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.True(c.IsAborted()) - } - // 开启 - { - cache.Set("setting_TestIsFunctionEnabled", "1", 0) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - -} - -func TestCacheControl(t *testing.T) { - a := assert.New(t) - TestFunc := CacheControl() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - TestFunc(c) - a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache") -} - -func TestSandbox(t *testing.T) { - a := assert.New(t) - TestFunc := Sandbox() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - TestFunc(c) - a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox") -} - -func TestStaticResourceCache(t *testing.T) { - a := assert.New(t) - TestFunc := StaticResourceCache() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - TestFunc(c) - a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age") -} diff --git a/middleware/file_test.go b/middleware/file_test.go deleted file mode 100644 index 5ca4014..0000000 --- a/middleware/file_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package middleware - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestValidateSourceLink(t *testing.T) { - a := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ValidateSourceLink() - - // ID 不存在 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - a.True(c.IsAborted()) - } - - // SourceLink 不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Set("object_id", 1) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) - testFunc(c) - a.True(c.IsAborted()) - a.NoError(mock.ExpectationsWereMet()) - } - - // 原文件不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Set("object_id", 1) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"})) - testFunc(c) - a.True(c.IsAborted()) - a.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - c, _ := gin.CreateTestContext(rec) - c.Set("object_id", 1) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2)) - mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1)) - testFunc(c) - a.False(c.IsAborted()) - a.NoError(mock.ExpectationsWereMet()) - } - -} diff --git a/middleware/frontend.go b/middleware/frontend.go index f07d9b6..eba1e84 100644 --- a/middleware/frontend.go +++ b/middleware/frontend.go @@ -1,13 +1,14 @@ package middleware import ( + "io/ioutil" + "net/http" + "strings" + "github.com/cloudreve/Cloudreve/v3/bootstrap" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" - "io/ioutil" - "net/http" - "strings" ) // FrontendFileHandler 前端静态文件处理 @@ -51,10 +52,16 @@ func FrontendFileHandler() gin.HandlerFunc { // 不存在的路径和index.html均返回index.html if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) { // 读取、替换站点设置 - options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript", - "pwa_small_icon") + options := model.GetSettingByNames( + "siteName", // 站点名称 + "siteKeywords", // 关键词 + "siteDes", // 描述 + "siteScript", // 自定义代码 + "pwa_small_icon", // 图标 + ) finalHTML := util.Replace(map[string]string{ "{siteName}": options["siteName"], + "{siteKeywords}": options["siteKeywords"], "{siteDes}": options["siteDes"], "{siteScript}": options["siteScript"], "{pwa_small_icon}": options["pwa_small_icon"], diff --git a/middleware/frontend_test.go b/middleware/frontend_test.go deleted file mode 100644 index d32529d..0000000 --- a/middleware/frontend_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package middleware - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/bootstrap" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -type StaticMock struct { - testMock.Mock -} - -func (m StaticMock) Open(name string) (http.File, error) { - args := m.Called(name) - return args.Get(0).(http.File), args.Error(1) -} - -func (m StaticMock) Exists(prefix string, filepath string) bool { - args := m.Called(prefix, filepath) - return args.Bool(0) -} - -func TestFrontendFileHandler(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - - // 静态资源未加载 - { - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // index.html 不存在 - { - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(&os.File{}, errors.New("error")) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // index.html 读取失败 - { - file, _ := util.CreatNestedFile("tests/index.html") - file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - TestFunc(c) - asserts.False(c.IsAborted()) - } - - // 成功且命中 - { - file, _ := util.CreatNestedFile("tests/index.html") - defer file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/", nil) - - cache.Set("setting_siteName", "cloudreve", 0) - cache.Set("setting_siteKeywords", "cloudreve", 0) - cache.Set("setting_siteScript", "cloudreve", 0) - cache.Set("setting_pwa_small_icon", "cloudreve", 0) - - TestFunc(c) - asserts.True(c.IsAborted()) - } - - // 成功且命中静态文件 - { - file, _ := util.CreatNestedFile("tests/index.html") - defer file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - testStatic.On("Exists", "/", "/2"). - Return(true) - testStatic.On("Open", "/2"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", "/2", nil) - - TestFunc(c) - asserts.True(c.IsAborted()) - testStatic.AssertExpectations(t) - } - - // API 相关跳过 - { - for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} { - file, _ := util.CreatNestedFile("tests/index.html") - defer file.Close() - testStatic := &StaticMock{} - bootstrap.StaticFS = testStatic - testStatic.On("Open", "/index.html"). - Return(file, nil) - TestFunc := FrontendFileHandler() - - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{} - c.Request, _ = http.NewRequest("GET", reqPath, nil) - - TestFunc(c) - asserts.False(c.IsAborted()) - } - } - -} diff --git a/middleware/mock_test.go b/middleware/mock_test.go deleted file mode 100644 index 1ebee20..0000000 --- a/middleware/mock_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestMockHelper(t *testing.T) { - asserts := assert.New(t) - MockHelperFunc := MockHelper() - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - - // 写入session - { - SessionMock["test"] = "pass" - Session("test")(c) - MockHelperFunc(c) - asserts.Equal("pass", util.GetSession(c, "test").(string)) - } - - // 写入context - { - ContextMock["test"] = "pass" - MockHelperFunc(c) - test, exist := c.Get("test") - asserts.True(exist) - asserts.Equal("pass", test.(string)) - - } -} diff --git a/middleware/session.go b/middleware/session.go index db90755..b6d8023 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -1,11 +1,12 @@ package middleware import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/sessionstore" "net/http" "strings" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/sessionstore" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" diff --git a/middleware/session_test.go b/middleware/session_test.go deleted file mode 100644 index 9fbe0d2..0000000 --- a/middleware/session_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -func TestSession(t *testing.T) { - asserts := assert.New(t) - - { - handler := Session("2333") - asserts.NotNil(handler) - asserts.NotNil(Store) - asserts.IsType(emptyFunc(), handler) - } -} - -func emptyFunc() gin.HandlerFunc { - return func(c *gin.Context) {} -} - -func TestCSRFInit(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - sessionFunc := Session("233") - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - CSRFInit()(c) - asserts.True(util.GetSession(c, "CSRF").(bool)) - } -} - -func TestCSRFCheck(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - sessionFunc := Session("233") - - // 通过检查 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - CSRFInit()(c) - CSRFCheck()(c) - asserts.False(c.IsAborted()) - } - - // 未通过检查 - { - c, _ := gin.CreateTestContext(rec) - c.Request, _ = http.NewRequest("GET", "/test", nil) - sessionFunc(c) - CSRFCheck()(c) - asserts.True(c.IsAborted()) - } -} diff --git a/middleware/share.go b/middleware/share.go index 488b703..cc4ef42 100644 --- a/middleware/share.go +++ b/middleware/share.go @@ -118,8 +118,14 @@ func BeforeShareDownload() gin.HandlerFunc { // 对积分、下载次数进行更新 err = share.DownloadBy(user, c) if err != nil { - c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(), - nil)) + if err == model.ErrInsufficientCredit { + c.JSON(200, serializer.Err(serializer.CodeInsufficientCredit, err.Error(), + nil)) + } else { + c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(), + nil)) + } + c.Abort() return } diff --git a/middleware/share_test.go b/middleware/share_test.go deleted file mode 100644 index 129076b..0000000 --- a/middleware/share_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package middleware - -import ( - "net/http/httptest" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestShareAvailable(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ShareAvailable() - - // 分享不存在 - { - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", "empty"}, - } - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 通过 - { - conf.SystemConfig.HashIDSalt = "" - // 用户组 - mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - mock.ExpectQuery("SELECT(.+)shares(.+)"). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "remain_downloads", "source_id"}). - AddRow(1, 1, 2), - ) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - c, _ := gin.CreateTestContext(rec) - c.Params = []gin.Param{ - {"id", "x9T4"}, - } - testFunc(c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(c.IsAborted()) - asserts.NotNil(c.Get("user")) - asserts.NotNil(c.Get("share")) - } -} - -func TestShareCanPreview(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ShareCanPreview() - - // 无分享上下文 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 可以预览 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{PreviewEnabled: true}) - testFunc(c) - asserts.False(c.IsAborted()) - } - - // 未开启预览 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{PreviewEnabled: false}) - testFunc(c) - asserts.True(c.IsAborted()) - } -} - -func TestCheckShareUnlocked(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := CheckShareUnlocked() - - // 无分享上下文 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 无密码 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - testFunc(c) - asserts.False(c.IsAborted()) - } - -} - -func TestBeforeShareDownload(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := BeforeShareDownload() - - // 无分享上下文 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 用户不能下载 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - c.Set("user", &model.User{ - Group: model.Group{OptionsSerialized: model.GroupOption{}}, - }) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 可以下载 - { - c, _ := gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - c.Set("user", &model.User{ - Model: gorm.Model{ID: 1}, - Group: model.Group{OptionsSerialized: model.GroupOption{ - ShareDownload: true, - }}, - }) - testFunc(c) - asserts.False(c.IsAborted()) - } -} - -func TestShareOwner(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := ShareOwner() - - // 未登录 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 非用户所创建分享 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{User: model.User{Model: gorm.Model{ID: 1}}}) - c.Set("user", &model.User{}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // 正常 - { - c, _ := gin.CreateTestContext(rec) - testFunc(c) - asserts.True(c.IsAborted()) - - c, _ = gin.CreateTestContext(rec) - c.Set("share", &model.Share{}) - c.Set("user", &model.User{}) - testFunc(c) - asserts.False(c.IsAborted()) - } -} diff --git a/middleware/wopi_test.go b/middleware/wopi_test.go deleted file mode 100644 index c6ca327..0000000 --- a/middleware/wopi_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package middleware - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock" - "github.com/cloudreve/Cloudreve/v3/pkg/wopi" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - "net/http/httptest" - "testing" -) - -func TestWopiWriteAccess(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - testFunc := WopiWriteAccess() - - // deny preview only session - { - c, _ := gin.CreateTestContext(rec) - c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview}) - testFunc(c) - asserts.True(c.IsAborted()) - } - - // pass - { - c, _ := gin.CreateTestContext(rec) - c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit}) - testFunc(c) - asserts.False(c.IsAborted()) - } -} - -func TestWopiAccessValidation(t *testing.T) { - asserts := assert.New(t) - rec := httptest.NewRecorder() - mockWopi := &wopimock.WopiClientMock{} - mockCache := cache.NewMemoStore() - testFunc := WopiAccessValidation(mockWopi, mockCache) - - // malformed access token - { - c, _ := gin.CreateTestContext(rec) - c.AddParam(wopi.AccessTokenQuery, "000") - testFunc(c) - asserts.True(c.IsAborted()) - } - - // session key not exist - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - testFunc(c) - asserts.True(c.IsAborted()) - } - - // user key not exist - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0) - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - testFunc(c) - asserts.True(c.IsAborted()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // file not found - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0) - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - c.Set("object_id", uint(0)) - testFunc(c) - asserts.True(c.IsAborted()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // all pass - { - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil) - query := c.Request.URL.Query() - query.Set(wopi.AccessTokenQuery, "sessionID.key") - c.Request.URL.RawQuery = query.Encode() - mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0) - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - c.Set("object_id", uint(1)) - testFunc(c) - asserts.False(c.IsAborted()) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotPanics(func() { - c.MustGet(WopiSessionCtx) - }) - asserts.NotPanics(func() { - c.MustGet("user") - }) - } -} diff --git a/models/defaults.go b/models/defaults.go index b13f22e..e73c4a1 100644 --- a/models/defaults.go +++ b/models/defaults.go @@ -9,12 +9,15 @@ import ( var defaultSettings = []Setting{ {Name: "siteURL", Value: `http://localhost`, Type: "basic"}, - {Name: "siteName", Value: `Cloudreve`, Type: "basic"}, + {Name: "siteName", Value: `CloudrevePlus`, Type: "basic"}, {Name: "register_enabled", Value: `1`, Type: "register"}, {Name: "default_group", Value: `2`, Type: "register"}, - {Name: "siteKeywords", Value: `Cloudreve, cloud storage`, Type: "basic"}, - {Name: "siteDes", Value: `Cloudreve`, Type: "basic"}, + {Name: "mail_domain_filter", Value: `0`, Type: "register"}, + {Name: "mail_domain_filter_list", Value: `126.com,163.com,gmail.com,outlook.com,qq.com,foxmail.com,yeah.net,sohu.com,sohu.cn,139.com,wo.cn,189.cn,hotmail.com,live.com,live.cn`, Type: "register"}, + {Name: "siteKeywords", Value: `CloudrevePlus, cloud storage`, Type: "basic"}, + {Name: "siteDes", Value: `部署公私兼备的网盘系统`, Type: "basic"}, {Name: "siteTitle", Value: `Inclusive cloud storage for everyone`, Type: "basic"}, + {Name: "siteNotice", Value: ``, Type: "basic"}, {Name: "siteScript", Value: ``, Type: "basic"}, {Name: "siteID", Value: uuid.Must(uuid.NewV4()).String(), Type: "basic"}, {Name: "fromName", Value: `Cloudreve`, Type: "mail"}, @@ -26,6 +29,9 @@ var defaultSettings = []Setting{ {Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"}, {Name: "smtpPass", Value: ``, Type: "mail"}, {Name: "smtpEncryption", Value: `0`, Type: "mail"}, + {Name: "over_used_template", Value: `容量超额提醒
容量超额警告
亲爱的{userName}
由于{notifyReason},您在{siteTitle}的账户的容量使用超出配额,您将无法继续上传新文件,请尽快清理文件,否则我们将会禁用您的账户。
登录{siteTitle}
感谢您选择{siteTitle}。
`, Type: "mail_template"}, + {Name: "ban_time", Value: `604800`, Type: "storage_policy"}, {Name: "maxEditSize", Value: `52428800`, Type: "file_edit"}, {Name: "archive_timeout", Value: `600`, Type: "timeout"}, {Name: "download_timeout", Value: `600`, Type: "timeout"}, @@ -38,6 +44,7 @@ var defaultSettings = []Setting{ {Name: "slave_recover_interval", Value: `120`, Type: "slave"}, {Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"}, {Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"}, + {Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"}, {Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"}, {Name: "onedrive_callback_check", Value: `20`, Type: "timeout"}, {Name: "folder_props_timeout", Value: `300`, Type: "timeout"}, @@ -46,10 +53,14 @@ var defaultSettings = []Setting{ {Name: "reset_after_upload_failed", Value: `0`, Type: "upload"}, {Name: "use_temp_chunk_buffer", Value: `1`, Type: "upload"}, {Name: "login_captcha", Value: `0`, Type: "login"}, + {Name: "qq_login", Value: `0`, Type: "login"}, + {Name: "qq_direct_login", Value: `0`, Type: "login"}, + {Name: "qq_login_id", Value: ``, Type: "login"}, + {Name: "qq_login_key", Value: ``, Type: "login"}, {Name: "reg_captcha", Value: `0`, Type: "login"}, {Name: "email_active", Value: `0`, Type: "register"}, {Name: "mail_activation_template", Value: `激活您的账户用户激活
重设{siteTitle}密码
亲爱的{userName}
请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。
重设密码
感谢您选择{siteTitle}。
`, Type: "mail_template"}, + {Name: "pack_data", Value: `[]`, Type: "pack"}, {Name: "db_version_" + conf.RequiredDBVersion, Value: `installed`, Type: "version"}, + {Name: "alipay_enabled", Value: `0`, Type: "payment"}, + {Name: "payjs_enabled", Value: `0`, Type: "payment"}, + {Name: "payjs_id", Value: ``, Type: "payment"}, + {Name: "payjs_secret", Value: ``, Type: "payment"}, + {Name: "appid", Value: ``, Type: "payment"}, + {Name: "appkey", Value: ``, Type: "payment"}, + {Name: "shopid", Value: ``, Type: "payment"}, + {Name: "wechat_enabled", Value: `0`, Type: "payment"}, + {Name: "wechat_appid", Value: ``, Type: "payment"}, + {Name: "wechat_mchid", Value: ``, Type: "payment"}, + {Name: "wechat_serial_no", Value: ``, Type: "payment"}, + {Name: "wechat_api_key", Value: ``, Type: "payment"}, + {Name: "wechat_pk_content", Value: ``, Type: "payment"}, {Name: "hot_share_num", Value: `10`, Type: "share"}, + {Name: "group_sell_data", Value: `[]`, Type: "group_sell"}, {Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"}, {Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"}, {Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"}, @@ -77,9 +103,15 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "avatar_size_l", Value: "200", Type: "avatar"}, {Name: "avatar_size_m", Value: "130", Type: "avatar"}, {Name: "avatar_size_s", Value: "50", Type: "avatar"}, - {Name: "home_view_method", Value: "icon", Type: "view"}, + {Name: "score_enabled", Value: "1", Type: "score"}, + {Name: "share_score_rate", Value: "80", Type: "score"}, + {Name: "score_price", Value: "1", Type: "score"}, + {Name: "report_enabled", Value: "0", Type: "report"}, + {Name: "home_view_method", Value: "list", Type: "view"}, {Name: "share_view_method", Value: "list", Type: "view"}, {Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"}, + {Name: "cron_notify_user", Value: "@hourly", Type: "cron"}, + {Name: "cron_ban_user", Value: "@hourly", Type: "cron"}, {Name: "cron_recycle_upload_session", Value: "@every 1h30m", Type: "cron"}, {Name: "authn_enabled", Value: "0", Type: "authn"}, {Name: "captcha_type", Value: "normal", Type: "captcha"}, @@ -127,13 +159,24 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "pwa_display", Value: "standalone", Type: "pwa"}, {Name: "pwa_theme_color", Value: "#000000", Type: "pwa"}, {Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"}, + {Name: "initial_files", Value: "[]", Type: "register"}, {Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"}, + {Name: "phone_required", Value: "false", Type: "phone"}, + {Name: "phone_enabled", Value: "false", Type: "phone"}, + {Name: "vol_content", Value: "Guess", Type: "vol"}, + {Name: "vol_signature", Value: "Guess", Type: "vol"}, {Name: "show_app_promotion", Value: "1", Type: "mobile"}, {Name: "public_resource_maxage", Value: "86400", Type: "timeout"}, {Name: "wopi_enabled", Value: "0", Type: "wopi"}, {Name: "wopi_endpoint", Value: "", Type: "wopi"}, {Name: "wopi_max_size", Value: "52428800", Type: "wopi"}, {Name: "wopi_session_timeout", Value: "36000", Type: "wopi"}, + {Name: "custom_payment_enabled", Value: "0", Type: "payment"}, + {Name: "custom_payment_endpoint", Value: "", Type: "payment"}, + {Name: "custom_payment_secret", Value: "", Type: "payment"}, + {Name: "custom_payment_name", Value: "", Type: "payment"}, + {Name: "app_feedback_link", Value: "", Type: "mobile"}, + {Name: "app_forum_link", Value: "", Type: "mobile"}, } func InitSlaveDefaults() { diff --git a/models/download_test.go b/models/download_test.go deleted file mode 100644 index 367afb7..0000000 --- a/models/download_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestDownload_Create(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - download := Download{GID: "1"} - id, err := download.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - download := Download{GID: "1"} - id, err := download.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestDownload_AfterFind(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - download := Download{Attrs: `{"gid":"123"}`} - err := download.AfterFind() - asserts.NoError(err) - asserts.Equal("123", download.StatusInfo.Gid) - } - - // 忽略空值 - { - download := Download{Attrs: ``} - err := download.AfterFind() - asserts.NoError(err) - asserts.Equal("", download.StatusInfo.Gid) - } - - // 解析失败 - { - download := Download{Attrs: `?`} - err := download.BeforeSave() - asserts.Error(err) - asserts.Equal("", download.StatusInfo.Gid) - } - -} - -func TestDownload_Save(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - download := Download{ - Model: gorm.Model{ - ID: 1, - }, - Attrs: `{"gid":"123"}`, - } - err := download.Save() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("123", download.StatusInfo.Gid) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - download := Download{ - Model: gorm.Model{ - ID: 1, - }, - } - err := download.Save() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestGetDownloadsByStatus(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs(0, 1).WillReturnRows(sqlmock.NewRows([]string{"gid"}).AddRow("0").AddRow("1")) - res := GetDownloadsByStatus(0, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) -} - -func TestGetDownloadByGid(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs(2, "gid").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1")) - res, err := GetDownloadByGid("gid", 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(res.GID, "1") -} - -func TestDownload_GetOwner(t *testing.T) { - asserts := assert.New(t) - - // 已经有User对象 - { - download := &Download{User: &User{Nick: "nick"}} - user := download.GetOwner() - asserts.NotNil(user) - asserts.Equal("nick", user.Nick) - } - - // 无User对象 - { - download := &Download{UserID: 3} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"nick"}).AddRow("nick")) - user := download.GetOwner() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(user) - asserts.Equal("nick", user.Nick) - } -} - -func TestGetDownloadsByStatusAndUser(t *testing.T) { - asserts := assert.New(t) - - // 列出全部 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3)) - res := GetDownloadsByStatusAndUser(0, 1, 1, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) - } - - // 列出全部,分页 - { - mock.ExpectQuery("SELECT(.+)DESC(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3)) - res := GetDownloadsByStatusAndUser(2, 1, 1, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) - } -} - -func TestDownload_Delete(t *testing.T) { - asserts := assert.New(t) - share := Download{} - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := share.Delete() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - -} - -func TestDownload_GetNodeID(t *testing.T) { - a := assert.New(t) - record := Download{} - - // compatible with 3.4 - a.EqualValues(1, record.GetNodeID()) - - record.NodeID = 5 - a.EqualValues(5, record.GetNodeID()) -} diff --git a/models/file.go b/models/file.go index bfe49cb..56ee2ab 100644 --- a/models/file.go +++ b/models/file.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/jinzhu/gorm" ) @@ -388,6 +389,15 @@ func (file *File) UpdateSourceName(value string) error { }).Error } +// Relocate 更新文件的物理指向 +func (file *File) Relocate(src string, policyID uint) error { + file.Policy = Policy{} + return DB.Model(&file).Set("gorm:association_autoupdate", false).Updates(map[string]interface{}{ + "source_name": src, + "policy_id": policyID, + }).Error +} + func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error { file.UploadSessionID = nil if lastModified != nil { @@ -470,3 +480,46 @@ func (file *File) ShouldLoadThumb() bool { func (file *File) ThumbFile() string { return file.SourceName + GetSettingByNameWithDefault("thumb_file_suffix", "._thumb") } + +/* + 实现 filesystem.FileHeader 接口 +*/ + +// Read 实现 io.Reader +func (file *File) Read(p []byte) (n int, err error) { + return 0, errors.New("noe supported") +} + +// Close 实现io.Closer +func (file *File) Close() error { + return errors.New("noe supported") +} + +// Seeker 实现io.Seeker +func (file *File) Seek(offset int64, whence int) (int64, error) { + return 0, errors.New("noe supported") +} + +func (file *File) Info() *fsctx.UploadTaskInfo { + return &fsctx.UploadTaskInfo{ + Size: file.Size, + FileName: file.Name, + VirtualPath: file.Position, + Mode: 0, + Metadata: file.MetadataSerialized, + LastModified: &file.UpdatedAt, + SavePath: file.SourceName, + UploadSessionID: file.UploadSessionID, + } +} + +func (file *File) SetSize(size uint64) { + file.Size = size +} + +func (file *File) SetModel(newFile interface{}) { +} + +func (file *File) Seekable() bool { + return false +} diff --git a/models/file_test.go b/models/file_test.go deleted file mode 100644 index 83198fc..0000000 --- a/models/file_test.go +++ /dev/null @@ -1,785 +0,0 @@ -package model - -import ( - "errors" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFile_Create(t *testing.T) { - asserts := assert.New(t) - file := File{ - Name: "123", - } - - // 无法插入文件记录 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := file.Create() - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 无法更新用户容量 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := file.Create() - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - err := file.Create() - asserts.NoError(err) - asserts.Equal(uint(5), file.ID) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFile_AfterFind(t *testing.T) { - a := assert.New(t) - - // metadata not empty - { - file := File{ - Name: "123", - Metadata: "{\"name\":\"123\"}", - } - - a.NoError(file.AfterFind()) - a.Equal("123", file.MetadataSerialized["name"]) - } - - // metadata empty - { - file := File{ - Name: "123", - Metadata: "", - } - a.Nil(file.MetadataSerialized) - a.NoError(file.AfterFind()) - a.NotNil(file.MetadataSerialized) - } -} - -func TestFile_BeforeSave(t *testing.T) { - a := assert.New(t) - - // metadata not empty - { - file := File{ - Name: "123", - MetadataSerialized: map[string]string{ - "name": "123", - }, - } - - a.NoError(file.BeforeSave()) - a.Equal("{\"name\":\"123\"}", file.Metadata) - } - - // metadata empty - { - file := File{ - Name: "123", - } - a.NoError(file.BeforeSave()) - a.Equal("", file.Metadata) - } -} - -func TestFolder_GetChildFile(t *testing.T) { - asserts := assert.New(t) - folder := Folder{Model: gorm.Model{ID: 1}, Name: "/"} - // 存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "1.txt"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) - file, err := folder.GetChildFile("1.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("1.txt", file.Name) - asserts.Equal("/", file.Position) - } - - // 不存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "1.txt"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - _, err := folder.GetChildFile("1.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestFolder_GetChildFiles(t *testing.T) { - asserts := assert.New(t) - folder := &Folder{ - Model: gorm.Model{ - ID: 1, - }, - Position: "/123", - Name: "456", - } - - // 找不到 - mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnError(errors.New("error")) - files, err := folder.GetChildFiles() - asserts.Error(err) - asserts.Len(files, 0) - asserts.NoError(mock.ExpectationsWereMet()) - - // 找到了 - mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2)) - files, err = folder.GetChildFiles() - asserts.NoError(err) - asserts.Len(files, 2) - asserts.Equal("/123/456", files[0].Position) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -func TestGetFilesByIDs(t *testing.T) { - asserts := assert.New(t) - - // 出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnError(errors.New("error")) - folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Len(folders, 0) - } - - // 部分找到 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(folders, 1) - } - - // 忽略UID查找 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(folders, 1) - } -} - -func TestGetChildFilesOfFolders(t *testing.T) { - asserts := assert.New(t) - testFolder := []Folder{ - Folder{ - Model: gorm.Model{ID: 3}, - }, - Folder{ - Model: gorm.Model{ID: 4}, - }, Folder{ - Model: gorm.Model{ID: 5}, - }, - } - - // 出错 - { - mock.ExpectQuery("SELECT(.+)folder_id").WithArgs(3, 4, 5).WillReturnError(errors.New("not found")) - files, err := GetChildFilesOfFolders(&testFolder) - asserts.Error(err) - asserts.Len(files, 0) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 找到2个 - { - mock.ExpectQuery("SELECT(.+)folder_id"). - WithArgs(3, 4, 5). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). - AddRow(3, "3"). - AddRow(4, "4"), - ) - files, err := GetChildFilesOfFolders(&testFolder) - asserts.NoError(err) - asserts.Len(files, 2) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 全部找到 - { - mock.ExpectQuery("SELECT(.+)folder_id"). - WithArgs(3, 4, 5). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). - AddRow(3, "3"). - AddRow(4, "4"). - AddRow(5, "5"), - ) - files, err := GetChildFilesOfFolders(&testFolder) - asserts.NoError(err) - asserts.Len(files, 3) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestGetUploadPlaceholderFiles(t *testing.T) { - a := assert.New(t) - - mock.ExpectQuery("SELECT(.+)upload_session_id(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - files := GetUploadPlaceholderFiles(1) - a.NoError(mock.ExpectationsWereMet()) - a.Len(files, 1) -} - -func TestFile_GetPolicy(t *testing.T) { - asserts := assert.New(t) - - // 空策略 - { - file := File{ - PolicyID: 23, - } - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(23, "name"), - ) - file.GetPolicy() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint(23), file.Policy.ID) - } - - // 非空策略 - { - file := File{ - PolicyID: 23, - Policy: Policy{Model: gorm.Model{ID: 24}}, - } - file.GetPolicy() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint(24), file.Policy.ID) - } -} - -func TestRemoveFilesWithSoftLinks_EmptyArg(t *testing.T) { - asserts := assert.New(t) - // 传入空 - { - mock.ExpectQuery("SELECT(.+)files(.+)") - file, err := RemoveFilesWithSoftLinks([]File{}) - asserts.Error(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(len(file), 0) - DB.Find(&File{}) - } -} - -func TestRemoveFilesWithSoftLinks(t *testing.T) { - asserts := assert.New(t) - files := []File{ - File{ - Model: gorm.Model{ID: 1}, - SourceName: "1.txt", - PolicyID: 23, - }, - File{ - Model: gorm.Model{ID: 2}, - SourceName: "2.txt", - PolicyID: 24, - }, - } - - // 传入空文件列表 - { - file, err := RemoveFilesWithSoftLinks([]File{}) - asserts.NoError(err) - asserts.Empty(file) - } - - // 全都没有 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(files, file) - } - - // 第二个是软链 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 24, "2.txt"), - ) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(files[:1], file) - } - - // 第一个是软链 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 23, "1.txt"), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(files[1:], file) - } - // 全部是软链 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 23, "1.txt"), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("2.txt", 24, 2). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 24, "2.txt"), - ) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(file, 0) - } -} - -func TestDeleteFiles(t *testing.T) { - a := assert.New(t) - - // uid 不一致 - { - err := DeleteFiles([]*File{{UserID: 2}}, 1) - a.Contains("user id not consistent", err.Error()) - } - - // 删除失败 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := DeleteFiles([]*File{{UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.Error(err) - } - - // 无法变更用户容量 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := DeleteFiles([]*File{{UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.Error(err) - } - - // 文件脏读 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 0)) - mock.ExpectRollback() - err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.Error(err) - a.Contains("file size is dirty", err.Error()) - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WithArgs(uint64(3), sqlmock.AnyArg(), uint(1)).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1) - a.NoError(mock.ExpectationsWereMet()) - a.NoError(err) - } - - // 成功, 关联用户不存在 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 0) - a.NoError(mock.ExpectationsWereMet()) - a.NoError(err) - } -} - -func TestGetFilesByParentIDs(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 4, 5, 6). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "4.txt"). - AddRow(5, "5.txt"). - AddRow(6, "6.txt"), - ) - files, err := GetFilesByParentIDs([]uint{4, 5, 6}, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(files, 3) -} - -func TestGetFilesByUploadSession(t *testing.T) { - a := assert.New(t) - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "sessionID"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}).AddRow(4, "4.txt")) - files, err := GetFilesByUploadSession("sessionID", 1) - a.NoError(err) - a.NoError(mock.ExpectationsWereMet()) - a.Equal("4.txt", files.Name) -} - -func TestFile_Updates(t *testing.T) { - asserts := assert.New(t) - file := File{Model: gorm.Model{ID: 1}} - - // rename - { - // not reset thumb - { - file := File{Model: gorm.Model{ID: 1}} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.Rename("newName") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // thumb not available, rename base name only - { - file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{ - ThumbStatusMetadataKey: ThumbStatusNotAvailable, - }, - Metadata: "{}"} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.Rename("newName.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(ThumbStatusNotAvailable, file.MetadataSerialized[ThumbStatusMetadataKey]) - } - - // thumb not available, rename base name only - { - file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{ - ThumbStatusMetadataKey: ThumbStatusNotAvailable, - }} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.jpg", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.Rename("newName.jpg") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Empty(file.MetadataSerialized[ThumbStatusMetadataKey]) - } - } - - // UpdatePicInfo - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("1,1", 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.UpdatePicInfo("1,1") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // UpdateSourceName - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := file.UpdateSourceName("newName") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } -} - -func TestFile_UpdateSize(t *testing.T) { - a := assert.New(t) - - // 增加成功 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 11, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)+(.+)").WithArgs(uint64(1), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.NoError(file.UpdateSize(11)) - a.NoError(mock.ExpectationsWereMet()) - } - - // 减少成功 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.NoError(file.UpdateSize(8)) - a.NoError(mock.ExpectationsWereMet()) - } - - // 文件更新失败 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnError(errors.New("error")) - mock.ExpectRollback() - - a.Error(file.UpdateSize(8)) - a.NoError(mock.ExpectationsWereMet()) - } - - // 用户容量更新失败 - { - file := File{Size: 10} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnError(errors.New("error")) - mock.ExpectRollback() - - a.Error(file.UpdateSize(8)) - a.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFile_PopChunkToFile(t *testing.T) { - a := assert.New(t) - timeNow := time.Now() - file := File{} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(file.PopChunkToFile(&timeNow, "1,1")) -} - -func TestFile_CanCopy(t *testing.T) { - a := assert.New(t) - file := File{} - a.True(file.CanCopy()) - file.UploadSessionID = &file.Name - a.False(file.CanCopy()) -} - -func TestFile_FileInfoInterface(t *testing.T) { - asserts := assert.New(t) - file := File{ - Model: gorm.Model{ - UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), - }, - Name: "test_name", - SourceName: "", - UserID: 0, - Size: 10, - PicInfo: "", - FolderID: 0, - PolicyID: 0, - Policy: Policy{}, - Position: "/test", - } - - name := file.GetName() - asserts.Equal("test_name", name) - - size := file.GetSize() - asserts.Equal(uint64(10), size) - - asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), file.ModTime()) - asserts.False(file.IsDir()) - asserts.Equal("/test", file.GetPosition()) -} - -func TestGetFilesByKeywords(t *testing.T) { - asserts := assert.New(t) - - // 未指定用户 - { - mock.ExpectQuery("SELECT(.+)").WithArgs("k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetFilesByKeywords(0, nil, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) - } - - // 指定用户 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetFilesByKeywords(1, nil, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) - } - - // 指定父目录 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 12, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetFilesByKeywords(1, []uint{12}, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) - } -} - -func TestFile_CreateOrGetSourceLink(t *testing.T) { - a := assert.New(t) - file := &File{} - file.ID = 1 - - // 已存在,返回老的 SourceLink - { - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - res, err := file.CreateOrGetSourceLink() - a.NoError(err) - a.EqualValues(2, res.ID) - a.NoError(mock.ExpectationsWereMet()) - } - - // 不存在,插入失败 - { - expectedErr := errors.New("error") - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr) - mock.ExpectRollback() - res, err := file.CreateOrGetSourceLink() - a.Nil(res) - a.ErrorIs(err, expectedErr) - a.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - res, err := file.CreateOrGetSourceLink() - a.NoError(err) - a.EqualValues(2, res.ID) - a.EqualValues(file.ID, res.File.ID) - a.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFile_UpdateMetadata(t *testing.T) { - a := assert.New(t) - file := &File{} - file.ID = 1 - - // 更新失败 - { - expectedErr := errors.New("error") - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnError(expectedErr) - mock.ExpectRollback() - a.ErrorIs(file.UpdateMetadata(map[string]string{"1": "1"}), expectedErr) - a.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(file.UpdateMetadata(map[string]string{"1": "1"})) - a.NoError(mock.ExpectationsWereMet()) - a.Equal("1", file.MetadataSerialized["1"]) - } -} - -func TestFile_ShouldLoadThumb(t *testing.T) { - a := assert.New(t) - file := &File{ - MetadataSerialized: map[string]string{}, - } - file.ID = 1 - - // 无缩略图 - { - file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusNotAvailable - a.False(file.ShouldLoadThumb()) - } - - // 有缩略图 - { - file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusExist - a.True(file.ShouldLoadThumb()) - } -} - -func TestFile_ThumbFile(t *testing.T) { - a := assert.New(t) - file := &File{ - SourceName: "test", - MetadataSerialized: map[string]string{}, - } - file.ID = 1 - - a.Equal("test._thumb", file.ThumbFile()) -} diff --git a/models/folder.go b/models/folder.go index 80f712c..1130d6b 100644 --- a/models/folder.go +++ b/models/folder.go @@ -16,10 +16,12 @@ type Folder struct { Name string `gorm:"unique_index:idx_only_one_name"` ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"` OwnerID uint `gorm:"index:owner_id"` + PolicyID uint // Webdav下挂载的存储策略ID // 数据库忽略字段 - Position string `gorm:"-"` - WebdavDstName string `gorm:"-"` + Position string `gorm:"-"` + WebdavDstName string `gorm:"-"` + InheritPolicyID uint `gorm:"-"` // 从父目录继承而来的policy id,默认值则使用自身的的PolicyID } // Create 创建目录 @@ -33,6 +35,13 @@ func (folder *Folder) Create() (uint, error) { return folder.ID, nil } +// GetMountedFolders 列出已挂载存储策略的目录 +func GetMountedFolders(uid uint) []Folder { + var folders []Folder + DB.Where("owner_id = ? and policy_id <> ?", uid, 0).Find(&folders) + return folders +} + // GetChild 返回folder下名为name的子目录,不存在则返回错误 func (folder *Folder) GetChild(name string) (*Folder, error) { var resFolder Folder @@ -40,9 +49,14 @@ func (folder *Folder) GetChild(name string) (*Folder, error) { Where("parent_id = ? AND owner_id = ? AND name = ?", folder.ID, folder.OwnerID, name). First(&resFolder).Error - // 将子目录的路径传递下去 + // 将子目录的路径及存储策略传递下去 if err == nil { resFolder.Position = path.Join(folder.Position, folder.Name) + if folder.PolicyID > 0 { + resFolder.InheritPolicyID = folder.PolicyID + } else if folder.InheritPolicyID > 0 { + resFolder.InheritPolicyID = folder.InheritPolicyID + } } return &resFolder, err } @@ -323,6 +337,11 @@ func (folder *Folder) Rename(new string) error { return DB.Model(&folder).UpdateColumn("name", new).Error } +// Mount 目录挂载 +func (folder *Folder) Mount(new uint) error { + return DB.Model(&folder).Update("policy_id", new).Error +} + /* 实现 FileInfo.FileInfo 接口 TODO 测试 diff --git a/models/folder_test.go b/models/folder_test.go deleted file mode 100644 index 90220ca..0000000 --- a/models/folder_test.go +++ /dev/null @@ -1,622 +0,0 @@ -package model - -import ( - "errors" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFolder_Create(t *testing.T) { - asserts := assert.New(t) - folder := &Folder{ - Name: "new folder", - } - - // 不存在,插入成功 - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - fid, err := folder.Create() - asserts.NoError(err) - asserts.Equal(uint(5), fid) - asserts.NoError(mock.ExpectationsWereMet()) - - // 插入失败 - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - fid, err = folder.Create() - asserts.NoError(err) - asserts.Equal(uint(1), fid) - asserts.NoError(mock.ExpectationsWereMet()) - - // 存在,直接返回 - mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5)) - fid, err = folder.Create() - asserts.NoError(err) - asserts.Equal(uint(5), fid) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFolder_GetChild(t *testing.T) { - asserts := assert.New(t) - folder := Folder{ - Model: gorm.Model{ID: 5}, - OwnerID: 1, - Name: "/", - } - - // 目录存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(5, 1, "sub"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "sub")) - sub, err := folder.GetChild("sub") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(sub.Name, "sub") - asserts.Equal("/", sub.Position) - } - - // 目录不存在 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(5, 1, "sub"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - sub, err := folder.GetChild("sub") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint(0), sub.ID) - - } -} - -func TestFolder_GetChildFolder(t *testing.T) { - asserts := assert.New(t) - folder := &Folder{ - Model: gorm.Model{ - ID: 1, - }, - Position: "/123", - Name: "456", - } - - // 找不到 - mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnError(errors.New("error")) - files, err := folder.GetChildFolder() - asserts.Error(err) - asserts.Len(files, 0) - asserts.NoError(mock.ExpectationsWereMet()) - - // 找到了 - mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2)) - files, err = folder.GetChildFolder() - asserts.NoError(err) - asserts.Len(files, 2) - asserts.Equal("/123/456", files[0].Position) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestGetRecursiveChildFolderSQLite(t *testing.T) { - conf.DatabaseConfig.Type = "sqlite" - asserts := assert.New(t) - - // 测试目录结构 - // 1 - // 2 3 - // 4 5 6 - - // 查询第一层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "folder1"), - ) - // 查询第二层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(2, "folder2"). - AddRow(3, "folder3"), - ) - // 查询第三层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "folder4"). - AddRow(5, "folder5"). - AddRow(6, "folder6"), - ) - // 查询第四层 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 4, 5, 6). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}), - ) - - folders, err := GetRecursiveChildFolder([]uint{1}, 1, true) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(folders, 6) -} - -func TestDeleteFolderByIDs(t *testing.T) { - asserts := assert.New(t) - - // 出错 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := DeleteFolderByIDs([]uint{1, 2, 3}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - err := DeleteFolderByIDs([]uint{1, 2, 3}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } -} - -func TestGetFoldersByIDs(t *testing.T) { - asserts := assert.New(t) - - // 出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnError(errors.New("error")) - folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Len(folders, 0) - } - - // 部分找到 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) - folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(folders, 1) - } -} - -func TestFolder_MoveOrCopyFileTo(t *testing.T) { - asserts := assert.New(t) - // 当前目录 - folder := Folder{ - Model: gorm.Model{ID: 1}, - OwnerID: 1, - Name: "test", - } - // 目标目录 - dstFolder := Folder{ - Model: gorm.Model{ID: 10}, - Name: "dst", - } - - // 复制文件 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs( - 1, - 2, - 3, - 1, - 1, - ).WillReturnRows( - sqlmock.NewRows([]string{"id", "size", "upload_session_id"}). - AddRow(1, 10, nil). - AddRow(2, 20, nil). - AddRow(2, 20, &folder.Name), - ) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2, 3}, - &dstFolder, - true, - ) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(30), storage) - } - - // 复制文件, 检索文件出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs( - 1, - 2, - 1, - 1, - ).WillReturnError(errors.New("error")) - - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - true, - ) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(0), storage) - } - - // 复制文件,第二个文件插入出错 - { - mock.ExpectQuery("SELECT(.+)"). - WithArgs( - 1, - 2, - 1, - 1, - ).WillReturnRows( - sqlmock.NewRows([]string{"id", "size"}). - AddRow(1, 10). - AddRow(2, 20), - ) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - true, - ) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(10), storage) - } - - // 移动文件 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1). - WillReturnResult(sqlmock.NewResult(1, 2)) - mock.ExpectCommit() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - false, - ) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(uint64(0), storage) - } - - // 移动文件 出错 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - storage, err := folder.MoveOrCopyFileTo( - []uint{1, 2}, - &dstFolder, - false, - ) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), storage) - } -} - -func TestFolder_CopyFolderTo(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - // 父目录 - parFolder := Folder{ - Model: gorm.Model{ID: 9}, - OwnerID: 1, - } - // 目标目录 - dstFolder := Folder{ - Model: gorm.Model{ID: 10}, - } - - // 测试复制目录结构 - // test(2)(5) - // 1(3)(6) 2.txt - // 3(4)(7) 4.txt 5.txt(上传中) - - // 正常情况 成功 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1)) - mock.ExpectCommit() - - // 查找子文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 4). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "folder_id", "size", "upload_session_id"}). - AddRow(1, "2.txt", 2, 10, nil). - AddRow(2, "3.txt", 3, 20, nil). - AddRow(3, "5.txt", 3, 20, &dstFolder.Name), - ) - - // 复制子文件 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal(uint64(30), size) - } - - // 递归查询失败 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnError(errors.New("error")) - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), size) - } - - // 父目录ID不存在 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 99)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), size) - } - - // 查询子文件失败 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1)) - mock.ExpectCommit() - - // 查找子文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 4). - WillReturnError(errors.New("error")) - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(0), size) - } - - // 复制文件 一个失败 - { - // GetRecursiveChildFolder - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"})) - - // 复制目录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1)) - mock.ExpectCommit() - - // 查找子文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3, 4). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "folder_id", "size"}). - AddRow(1, "2.txt", 2, 10). - AddRow(2, "3.txt", 3, 20), - ) - - // 复制子文件 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - - size, err := parFolder.CopyFolderTo(2, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(uint64(10), size) - } - -} - -func TestFolder_MoveOrCopyFolderTo_Move(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - // 父目录 - parFolder := Folder{ - Model: gorm.Model{ID: 9}, - OwnerID: 1, - } - // 目标目录 - dstFolder := Folder{ - Model: gorm.Model{ID: 10}, - OwnerID: 1, - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 9). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := parFolder.MoveFolderTo([]uint{1, 2}, &dstFolder) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 移动自己到自己内部,失败 - { - err := parFolder.MoveFolderTo([]uint{10, 2}, &dstFolder) - asserts.Error(err) - } -} - -func TestFolder_FileInfoInterface(t *testing.T) { - asserts := assert.New(t) - folder := Folder{ - Model: gorm.Model{ - UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), - }, - Name: "test_name", - OwnerID: 0, - Position: "/test", - } - - name := folder.GetName() - asserts.Equal("test_name", name) - - size := folder.GetSize() - asserts.Equal(uint64(0), size) - - asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), folder.ModTime()) - asserts.True(folder.IsDir()) - asserts.Equal("/test", folder.GetPosition()) -} - -func TestTraceRoot(t *testing.T) { - asserts := assert.New(t) - var parentId uint - parentId = 5 - folder := Folder{ - ParentID: &parentId, - OwnerID: 1, - Name: "test_name", - } - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/")) - asserts.NoError(folder.TraceRoot()) - asserts.Equal("/parent", folder.Position) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 出现错误 - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0). - WillReturnError(errors.New("error")) - asserts.Error(folder.TraceRoot()) - asserts.Equal("parent", folder.Position) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFolder_Rename(t *testing.T) { - asserts := assert.New(t) - folder := Folder{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test_name", - OwnerID: 1, - Position: "/test", - } - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("test_name_new", 1). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := folder.Rename("test_name_new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 出现错误 - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("test_name_new", 1). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := folder.Rename("test_name_new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} diff --git a/models/group.go b/models/group.go index 0abf21d..8dc3057 100644 --- a/models/group.go +++ b/models/group.go @@ -2,6 +2,7 @@ package model import ( "encoding/json" + "github.com/jinzhu/gorm" ) @@ -29,11 +30,15 @@ type GroupOption struct { DecompressSize uint64 `json:"decompress_size,omitempty"` OneTimeDownload bool `json:"one_time_download,omitempty"` ShareDownload bool `json:"share_download,omitempty"` + ShareFree bool `json:"share_free,omitempty"` Aria2 bool `json:"aria2,omitempty"` // 离线下载 Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 + Relocate bool `json:"relocate,omitempty"` // 转移文件 SourceBatchSize int `json:"source_batch,omitempty"` RedirectedSource bool `json:"redirected_source,omitempty"` Aria2BatchSize int `json:"aria2_batch,omitempty"` + AvailableNodes []uint `json:"available_nodes,omitempty"` + SelectNode bool `json:"select_node,omitempty"` AdvanceDelete bool `json:"advance_delete,omitempty"` WebDAVProxy bool `json:"webdav_proxy,omitempty"` } diff --git a/models/group_test.go b/models/group_test.go deleted file mode 100644 index 2f487ce..0000000 --- a/models/group_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestGetGroupByID(t *testing.T) { - asserts := assert.New(t) - - //找到用户组时 - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - group, err := GetGroupByID(1) - asserts.NoError(err) - asserts.Equal(Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - PolicyList: []uint{1}, - }, group) - - //未找到用户时 - mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - group, err = GetGroupByID(1) - asserts.Error(err) - asserts.Equal(Group{}, group) -} - -func TestGroup_AfterFind(t *testing.T) { - asserts := assert.New(t) - - testCase := Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - } - err := testCase.AfterFind() - asserts.NoError(err) - asserts.Equal(testCase.PolicyList, []uint{1}) - - testCase.Policies = "[1,2,3,4,5]" - err = testCase.AfterFind() - asserts.NoError(err) - asserts.Equal(testCase.PolicyList, []uint{1, 2, 3, 4, 5}) - - testCase.Policies = "[1,2,3,4,5" - err = testCase.AfterFind() - asserts.Error(err) - - testCase.Policies = "[]" - err = testCase.AfterFind() - asserts.NoError(err) - asserts.Equal(testCase.PolicyList, []uint{}) -} - -func TestGroup_BeforeSave(t *testing.T) { - asserts := assert.New(t) - group := Group{ - PolicyList: []uint{1, 2, 3}, - } - { - err := group.BeforeSave() - asserts.NoError(err) - asserts.Equal("[1,2,3]", group.Policies) - } - -} diff --git a/models/migration.go b/models/migration.go index fad6a76..f86e5e9 100644 --- a/models/migration.go +++ b/models/migration.go @@ -40,8 +40,8 @@ func migration() { DB = DB.Set("gorm:table_options", "ENGINE=InnoDB") } - DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{}, - &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{}) + DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{}, + &Task{}, &Download{}, &Tag{}, &Webdav{}, &Order{}, &Redeem{}, &Report{}, &Node{}, &SourceLink{}) // 创建初始存储策略 addDefaultPolicy() @@ -107,10 +107,13 @@ func addDefaultGroups() { ArchiveDownload: true, ArchiveTask: true, ShareDownload: true, + ShareFree: true, Aria2: true, + Relocate: true, SourceBatchSize: 1000, Aria2BatchSize: 50, RedirectedSource: true, + SelectNode: true, AdvanceDelete: true, }, } diff --git a/models/migration_test.go b/models/migration_test.go deleted file mode 100644 index 7c9d673..0000000 --- a/models/migration_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package model - -import ( - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestMigration(t *testing.T) { - asserts := assert.New(t) - conf.DatabaseConfig.Type = "sqlite" - DB, _ = gorm.Open("sqlite", ":memory:") - - asserts.NotPanics(func() { - migration() - }) - conf.DatabaseConfig.Type = "mysql" - DB = mockDB -} diff --git a/models/node_test.go b/models/node_test.go deleted file mode 100644 index de1757f..0000000 --- a/models/node_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestGetNodeByID(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetNodeByID(1) - a.NoError(err) - a.EqualValues(1, res.ID) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestGetNodesByStatus(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive)) - res, err := GetNodesByStatus(NodeActive) - a.NoError(err) - a.Len(res, 1) - a.EqualValues(NodeActive, res[0].Status) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestNode_AfterFind(t *testing.T) { - a := assert.New(t) - node := &Node{} - - // No aria2 options - { - a.NoError(node.AfterFind()) - } - - // with aria2 options - { - node.Aria2Options = `{"timeout":1}` - a.NoError(node.AfterFind()) - a.Equal(1, node.Aria2OptionsSerialized.Timeout) - } -} - -func TestNode_BeforeSave(t *testing.T) { - a := assert.New(t) - node := &Node{} - - node.Aria2OptionsSerialized.Timeout = 1 - a.NoError(node.BeforeSave()) - a.Contains(node.Aria2Options, "1") -} - -func TestNode_SetStatus(t *testing.T) { - a := assert.New(t) - node := &Node{} - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(node.SetStatus(NodeActive)) - a.Equal(NodeActive, node.Status) - a.NoError(mock.ExpectationsWereMet()) -} diff --git a/models/order.go b/models/order.go new file mode 100755 index 0000000..9a79c24 --- /dev/null +++ b/models/order.go @@ -0,0 +1,59 @@ +package model + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/jinzhu/gorm" +) + +const ( + // PackOrderType 容量包订单 + PackOrderType = iota + // GroupOrderType 用户组订单 + GroupOrderType + // ScoreOrderType 积分充值订单 + ScoreOrderType +) + +const ( + // OrderUnpaid 未支付 + OrderUnpaid = iota + // OrderPaid 已支付 + OrderPaid + // OrderCanceled 已取消 + OrderCanceled +) + +// Order 交易订单 +type Order struct { + gorm.Model + UserID uint // 创建者ID + OrderNo string `gorm:"index:order_number"` // 商户自定义订单编号 + Type int // 订单类型 + Method string // 支付类型 + ProductID int64 // 商品ID + Num int // 商品数量 + Name string // 订单标题 + Price int // 商品单价 + Status int // 订单状态 +} + +// Create 创建订单记录 +func (order *Order) Create() (uint, error) { + if err := DB.Create(order).Error; err != nil { + util.Log().Warning("Failed to insert order record: %s", err) + return 0, err + } + return order.ID, nil +} + +// UpdateStatus 更新订单状态 +func (order *Order) UpdateStatus(status int) { + DB.Model(order).Update("status", status) +} + +// GetOrderByNo 根据商户订单号查询订单 +func GetOrderByNo(id string) (*Order, error) { + var order Order + err := DB.Where("order_no = ?", id).First(&order).Error + return &order, err +} diff --git a/models/policy.go b/models/policy.go index 11d8e4b..f80b553 100644 --- a/models/policy.go +++ b/models/policy.go @@ -3,14 +3,15 @@ package model import ( "encoding/gob" "encoding/json" - "github.com/gofrs/uuid" - "github.com/samber/lo" "path" "path/filepath" "strconv" "strings" "time" + "github.com/gofrs/uuid" + "github.com/samber/lo" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/jinzhu/gorm" @@ -73,6 +74,18 @@ type PolicyOption struct { ThumbExts []string `json:"thumb_exts,omitempty"` } +// thumbSuffix 支持缩略图处理的文件扩展名 +var thumbSuffix = map[string][]string{ + "local": {}, + "qiniu": {".psd", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"}, + "oss": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"}, + "cos": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"}, + "upyun": {".svg", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"}, + "s3": {}, + "remote": {}, + "onedrive": {"*"}, +} + func init() { // 注册缓存用到的复杂结构 gob.Register(Policy{}) @@ -179,6 +192,17 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string { return fileRule } +// IsThumbExist 给定文件名,返回此存储策略下是否可能存在缩略图 +func (policy *Policy) IsThumbExist(name string) bool { + if list, ok := thumbSuffix[policy.Type]; ok { + if len(list) == 1 && list[0] == "*" { + return true + } + return util.ContainsString(list, strings.ToLower(filepath.Ext(name))) + } + return false +} + // IsDirectlyPreview 返回此策略下文件是否可以直接预览(不需要重定向) func (policy *Policy) IsDirectlyPreview() bool { return policy.Type == "local" diff --git a/models/policy_test.go b/models/policy_test.go deleted file mode 100644 index f7d4e74..0000000 --- a/models/policy_test.go +++ /dev/null @@ -1,269 +0,0 @@ -package model - -import ( - "encoding/json" - "strconv" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestGetPolicyByID(t *testing.T) { - asserts := assert.New(t) - - cache.Deletes([]string{"22", "23"}, "policy_") - // 缓存未命中 - { - rows := sqlmock.NewRows([]string{"name", "type", "options"}). - AddRow("默认存储策略", "local", "{\"od_redirect\":\"123\"}") - mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) - policy, err := GetPolicyByID(uint(22)) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OauthRedirect) - - rows = sqlmock.NewRows([]string{"name", "type", "options"}) - mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows) - policy, err = GetPolicyByID(uint(23)) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - - // 命中 - { - policy, err := GetPolicyByID(uint(22)) - asserts.NoError(err) - asserts.Equal("默认存储策略", policy.Name) - asserts.Equal("123", policy.OptionsSerialized.OauthRedirect) - - } - -} - -func TestPolicy_BeforeSave(t *testing.T) { - asserts := assert.New(t) - - testPolicy := Policy{ - OptionsSerialized: PolicyOption{ - OauthRedirect: "123", - }, - } - expected, _ := json.Marshal(testPolicy.OptionsSerialized) - err := testPolicy.BeforeSave() - asserts.NoError(err) - asserts.Equal(string(expected), testPolicy.Options) - -} - -func TestPolicy_GeneratePath(t *testing.T) { - asserts := assert.New(t) - testPolicy := Policy{} - - testPolicy.DirNameRule = "{randomkey16}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 16) - - testPolicy.DirNameRule = "{randomkey8}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 8) - - testPolicy.DirNameRule = "{timestamp}" - asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.FormatInt(time.Now().Unix(), 10)) - - testPolicy.DirNameRule = "{uid}" - asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.Itoa(int(1))) - - testPolicy.DirNameRule = "{datetime}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 14) - - testPolicy.DirNameRule = "{date}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 8) - - testPolicy.DirNameRule = "123{date}ss{datetime}" - asserts.Len(testPolicy.GeneratePath(1, "/"), 27) - - testPolicy.DirNameRule = "/1/{path}/456" - asserts.Condition(func() (success bool) { - res := testPolicy.GeneratePath(1, "/23") - return res == "/1/23/456" || res == "\\1\\23\\456" - }) - -} - -func TestPolicy_GenerateFileName(t *testing.T) { - asserts := assert.New(t) - // 重命名关闭 - { - testPolicy := Policy{ - AutoRename: false, - } - testPolicy.FileNameRule = "{randomkey16}" - asserts.Equal("123.txt", testPolicy.GenerateFileName(1, "123.txt")) - - testPolicy.Type = "oss" - asserts.Equal("origin", testPolicy.GenerateFileName(1, "origin")) - } - - // 重命名开启 - { - testPolicy := Policy{ - AutoRename: true, - } - - testPolicy.FileNameRule = "{randomkey16}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16) - - testPolicy.FileNameRule = "{randomkey8}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8) - - testPolicy.FileNameRule = "{timestamp}" - asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.FormatInt(time.Now().Unix(), 10)) - - testPolicy.FileNameRule = "{uid}" - asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.Itoa(int(1))) - - testPolicy.FileNameRule = "{datetime}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 14) - - testPolicy.FileNameRule = "{date}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8) - - testPolicy.FileNameRule = "123{date}ss{datetime}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 27) - - testPolicy.FileNameRule = "{originname_without_ext}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 3) - - testPolicy.FileNameRule = "{originname_without_ext}_{randomkey8}{ext}" - asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16) - - // 支持{originname}的策略 - testPolicy.Type = "local" - testPolicy.FileNameRule = "123{originname}" - asserts.Equal("123123.txt", testPolicy.GenerateFileName(1, "123.txt")) - - testPolicy.Type = "qiniu" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123.txt", testPolicy.GenerateFileName(1, "123.txt")) - - testPolicy.Type = "oss" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321")) - - testPolicy.Type = "upyun" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321")) - - testPolicy.Type = "qiniu" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321")) - - testPolicy.Type = "local" - testPolicy.FileNameRule = "{uid}123{originname}" - asserts.Equal("1123", testPolicy.GenerateFileName(1, "")) - - testPolicy.Type = "local" - testPolicy.FileNameRule = "{ext}123{uuid}" - asserts.Contains(testPolicy.GenerateFileName(1, "123.txt"), ".txt123") - } - -} - -func TestPolicy_IsDirectlyPreview(t *testing.T) { - asserts := assert.New(t) - policy := Policy{Type: "local"} - asserts.True(policy.IsDirectlyPreview()) - policy.Type = "remote" - asserts.False(policy.IsDirectlyPreview()) -} - -func TestPolicy_ClearCache(t *testing.T) { - asserts := assert.New(t) - cache.Set("policy_202", 1, 0) - policy := Policy{Model: gorm.Model{ID: 202}} - policy.ClearCache() - _, ok := cache.Get("policy_202") - asserts.False(ok) -} - -func TestPolicy_UpdateAccessKey(t *testing.T) { - asserts := assert.New(t) - policy := Policy{Model: gorm.Model{ID: 202}} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - policy.AccessKey = "123" - err := policy.SaveAndClearCache() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) -} - -func TestPolicy_Props(t *testing.T) { - asserts := assert.New(t) - policy := Policy{Type: "onedrive"} - policy.OptionsSerialized.PlaceholderWithSize = true - asserts.False(policy.IsThumbGenerateNeeded()) - asserts.False(policy.IsTransitUpload(4)) - asserts.False(policy.IsTransitUpload(5 * 1024 * 1024)) - asserts.True(policy.CanStructureBeListed()) - asserts.True(policy.IsUploadPlaceholderWithSize()) - policy.Type = "local" - asserts.True(policy.IsThumbGenerateNeeded()) - asserts.False(policy.CanStructureBeListed()) - asserts.False(policy.IsUploadPlaceholderWithSize()) - policy.Type = "remote" - asserts.True(policy.IsUploadPlaceholderWithSize()) -} - -func TestPolicy_UpdateAccessKeyAndClearCache(t *testing.T) { - a := assert.New(t) - cache.Set("policy_1331", Policy{}, 3600) - p := &Policy{} - p.ID = 1331 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("ak", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.NoError(p.UpdateAccessKeyAndClearCache("ak")) - a.NoError(mock.ExpectationsWereMet()) - _, ok := cache.Get("policy_1331") - a.False(ok) -} - -func TestPolicy_CouldProxyThumb(t *testing.T) { - a := assert.New(t) - p := &Policy{Type: "local"} - - // local policy - { - a.False(p.CouldProxyThumb()) - } - - // feature not enabled - { - p.Type = "remote" - cache.Set("setting_thumb_proxy_enabled", "0", 0) - a.False(p.CouldProxyThumb()) - } - - // list not contain current policy - { - p.ID = 2 - cache.Set("setting_thumb_proxy_enabled", "1", 0) - cache.Set("setting_thumb_proxy_policy", "[1]", 0) - a.False(p.CouldProxyThumb()) - } - - // enabled - { - p.ID = 2 - cache.Set("setting_thumb_proxy_enabled", "1", 0) - cache.Set("setting_thumb_proxy_policy", "[2]", 0) - a.True(p.CouldProxyThumb()) - } - - cache.Deletes([]string{"thumb_proxy_enabled", "thumb_proxy_policy"}, "setting_") -} diff --git a/models/redeem.go b/models/redeem.go new file mode 100755 index 0000000..4f8396d --- /dev/null +++ b/models/redeem.go @@ -0,0 +1,27 @@ +package model + +import "github.com/jinzhu/gorm" + +// Redeem 兑换码 +type Redeem struct { + gorm.Model + Type int // 订单类型 + ProductID int64 // 商品ID + Num int // 商品数量 + Code string `gorm:"size:64,index:redeem_code"` // 兑换码 + Used bool // 是否已被使用 +} + +// GetAvailableRedeem 根据code查找可用兑换码 +func GetAvailableRedeem(code string) (*Redeem, error) { + redeem := &Redeem{} + result := DB.Where("code = ? and used = ?", code, false).First(redeem) + return redeem, result.Error +} + +// Use 设定为已使用状态 +func (redeem *Redeem) Use() { + DB.Model(redeem).Updates(map[string]interface{}{ + "used": true, + }) +} diff --git a/models/report.go b/models/report.go new file mode 100755 index 0000000..face732 --- /dev/null +++ b/models/report.go @@ -0,0 +1,21 @@ +package model + +import ( + "github.com/jinzhu/gorm" +) + +// Report 举报模型 +type Report struct { + gorm.Model + ShareID uint `gorm:"index:share_id"` // 对应分享ID + Reason int // 举报原因 + Description string // 补充描述 + + // 关联模型 + Share Share `gorm:"save_associations:false:false"` +} + +// Create 创建举报 +func (report *Report) Create() error { + return DB.Create(report).Error +} diff --git a/models/scripts/init.go b/models/scripts/init.go index 7c375bf..b772fa9 100644 --- a/models/scripts/init.go +++ b/models/scripts/init.go @@ -5,5 +5,6 @@ import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" func Init() { invoker.Register("ResetAdminPassword", ResetAdminPassword(0)) invoker.Register("CalibrateUserStorage", UserStorageCalibration(0)) + invoker.Register("OSSToPlus", UpgradeToPro(0)) invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0)) } diff --git a/models/scripts/invoker/invoker_test.go b/models/scripts/invoker/invoker_test.go deleted file mode 100644 index 36651eb..0000000 --- a/models/scripts/invoker/invoker_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package invoker - -import ( - "context" - "github.com/stretchr/testify/assert" - "testing" -) - -type TestScript int - -func (script TestScript) Run(ctx context.Context) { - -} - -func TestRunDBScript(t *testing.T) { - asserts := assert.New(t) - Register("test", TestScript(0)) - - // 不存在 - { - asserts.Error(RunDBScript("else", context.Background())) - } - - // 存在 - { - asserts.NoError(RunDBScript("test", context.Background())) - } -} - -func TestListPrefix(t *testing.T) { - asserts := assert.New(t) - Register("U1", TestScript(0)) - Register("U2", TestScript(0)) - Register("U3", TestScript(0)) - Register("P1", TestScript(0)) - - res := ListPrefix("U") - asserts.Len(res, 3) -} diff --git a/models/scripts/reset_test.go b/models/scripts/reset_test.go deleted file mode 100644 index ffacb28..0000000 --- a/models/scripts/reset_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package scripts - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestResetAdminPassword_Run(t *testing.T) { - asserts := assert.New(t) - script := ResetAdminPassword(0) - - // 初始用户不存在 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"})) - asserts.Panics(func() { - script.Run(context.Background()) - }) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 密码更新失败 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - asserts.Panics(func() { - script.Run(context.Background()) - }) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 成功 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NotPanics(func() { - script.Run(context.Background()) - }) - asserts.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/scripts/storage_test.go b/models/scripts/storage_test.go deleted file mode 100644 index 746f0c0..0000000 --- a/models/scripts/storage_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package scripts - -import ( - "context" - "database/sql" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -var mock sqlmock.Sqlmock -var mockDB *gorm.DB - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - mockDB = model.DB - defer db.Close() - m.Run() -} - -func TestUserStorageCalibration_Run(t *testing.T) { - asserts := assert.New(t) - script := UserStorageCalibration(0) - - // 容量异常 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(11)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - script.Run(context.Background()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 容量正常 - { - mock.ExpectQuery("SELECT(.+)users(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - script.Run(context.Background()) - asserts.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/scripts/upgrade-pro.go b/models/scripts/upgrade-pro.go new file mode 100755 index 0000000..b25df4f --- /dev/null +++ b/models/scripts/upgrade-pro.go @@ -0,0 +1,22 @@ +package scripts + +import ( + "context" + model "github.com/cloudreve/Cloudreve/v3/models" +) + +type UpgradeToPro int + +// Run 运行脚本从社区版升级至 Pro 版 +func (script UpgradeToPro) Run(ctx context.Context) { + // folder.PolicyID 字段设为 0 + model.DB.Model(model.Folder{}).UpdateColumn("policy_id", 0) + // shares.Score 字段设为0 + model.DB.Model(model.Share{}).UpdateColumn("score", 0) + // user 表相关初始字段 + model.DB.Model(model.User{}).Updates(map[string]interface{}{ + "score": 0, + "previous_group_id": 0, + "open_id": "", + }) +} diff --git a/models/scripts/upgrade_test.go b/models/scripts/upgrade_test.go deleted file mode 100644 index 8f7adba..0000000 --- a/models/scripts/upgrade_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package scripts - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestUpgradeTo340_Run(t *testing.T) { - a := assert.New(t) - script := UpgradeTo340(0) - - // skip - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"})) - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } - - // node not found - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1")) - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"})) - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } - - // success - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). - AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). - AddRow("aria2_interval", "expected_aria2_interval"). - AddRow("aria2_temp_path", "expected_aria2_temp_path"). - AddRow("aria2_token", "expected_aria2_token"). - AddRow("aria2_options", "{}")) - - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } - - // failed - { - mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). - AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). - AddRow("aria2_interval", "expected_aria2_interval"). - AddRow("aria2_temp_path", "expected_aria2_temp_path"). - AddRow("aria2_token", "expected_aria2_token"). - AddRow("aria2_options", "{}")) - - mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - script.Run(context.Background()) - a.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/setting_test.go b/models/setting_test.go deleted file mode 100644 index 96fc5e0..0000000 --- a/models/setting_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package model - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock -var mockDB *gorm.DB - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - DB, _ = gorm.Open("mysql", db) - mockDB = DB - defer db.Close() - m.Run() -} - -func TestGetSettingByType(t *testing.T) { - cache.Store = cache.NewMemoStore() - asserts := assert.New(t) - - //找到设置时 - rows := sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic"). - AddRow("siteDes", "Something wonderful", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings := GetSettingByType([]string{"basic"}) - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes": "Something wonderful", - }, settings) - - rows = sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic"). - AddRow("siteDes", "Something wonderful", "basic2") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByType([]string{"basic", "basic2"}) - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes": "Something wonderful", - }, settings) - - //找不到 - rows = sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByType([]string{"basic233"}) - asserts.Equal(map[string]string{}, settings) -} - -func TestGetSettingByNameWithDefault(t *testing.T) { - a := assert.New(t) - - rows := sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings := GetSettingByNameWithDefault("123", "123321") - a.Equal("123321", settings) -} - -func TestGetSettingByNames(t *testing.T) { - cache.Store = cache.NewMemoStore() - asserts := assert.New(t) - - //找到设置时 - rows := sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic"). - AddRow("siteDes", "Something wonderful", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings := GetSettingByNames("siteName", "siteDes") - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes": "Something wonderful", - }, settings) - asserts.NoError(mock.ExpectationsWereMet()) - - //找到其中一个设置时 - rows = sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName2", "Cloudreve", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByNames("siteName2", "siteDes2333") - asserts.Equal(map[string]string{ - "siteName2": "Cloudreve", - }, settings) - asserts.NoError(mock.ExpectationsWereMet()) - - //找不到设置时 - rows = sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - settings = GetSettingByNames("siteName2333", "siteDes2333") - asserts.Equal(map[string]string{}, settings) - asserts.NoError(mock.ExpectationsWereMet()) - - // 一个设置命中缓存 - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WithArgs("siteDes2").WillReturnRows(sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteDes2", "Cloudreve2", "basic")) - settings = GetSettingByNames("siteName", "siteDes2") - asserts.Equal(map[string]string{ - "siteName": "Cloudreve", - "siteDes2": "Cloudreve2", - }, settings) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -// TestGetSettingByName 测试GetSettingByName -func TestGetSettingByName(t *testing.T) { - cache.Store = cache.NewMemoStore() - asserts := assert.New(t) - - //找到设置时 - rows := sqlmock.NewRows([]string{"name", "value", "type"}). - AddRow("siteName", "Cloudreve", "basic") - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - - siteName := GetSettingByName("siteName") - asserts.Equal("Cloudreve", siteName) - asserts.NoError(mock.ExpectationsWereMet()) - - // 第二次查询应返回缓存内容 - siteNameCache := GetSettingByName("siteName") - asserts.Equal("Cloudreve", siteNameCache) - asserts.NoError(mock.ExpectationsWereMet()) - - // 找不到设置 - rows = sqlmock.NewRows([]string{"name", "value", "type"}) - mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows) - - siteName = GetSettingByName("siteName not exist") - asserts.Equal("", siteName) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -func TestIsTrueVal(t *testing.T) { - asserts := assert.New(t) - - asserts.True(IsTrueVal("1")) - asserts.True(IsTrueVal("true")) - asserts.False(IsTrueVal("0")) - asserts.False(IsTrueVal("false")) -} - -func TestGetSiteURL(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - - mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://drive.cloudreve.org")) - siteURL := GetSiteURL() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("https://drive.cloudreve.org", siteURL.String()) - } - - // 失败 返回默认值 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - - mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, ":][\\/\\]sdf")) - siteURL := GetSiteURL() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("https://cloudreve.org", siteURL.String()) - } -} - -func TestGetIntSetting(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - cache.Set("setting_TestGetIntSetting", "10", 0) - res := GetIntSetting("TestGetIntSetting", 20) - asserts.Equal(10, res) - } - - // 使用默认值 - { - res := GetIntSetting("TestGetIntSetting_2", 20) - asserts.Equal(20, res) - } - -} diff --git a/models/share.go b/models/share.go index 750eb48..94655bb 100644 --- a/models/share.go +++ b/models/share.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "math" "strings" "time" @@ -13,6 +14,10 @@ import ( "github.com/jinzhu/gorm" ) +var ( + ErrInsufficientCredit = errors.New("积分不足") +) + // Share 分享模型 type Share struct { gorm.Model @@ -24,6 +29,7 @@ type Share struct { Downloads int // 下载数 RemainDownloads int // 剩余下载配额,负值标识无限制 Expires *time.Time // 过期时间,空值表示无过期时间 + Score int // 每人次下载扣除积分 PreviewEnabled bool // 是否允许直接预览 SourceName string `gorm:"index:source"` // 用于搜索的字段 @@ -135,6 +141,12 @@ func (share *Share) CanBeDownloadBy(user *User) error { } return errors.New("your group has no permission to download") } + + // 需要积分但未登录 + if share.Score > 0 && user.IsAnonymous() { + return errors.New("you must login to download") + } + return nil } @@ -149,9 +161,12 @@ func (share *Share) WasDownloadedBy(user *User, c *gin.Context) (exist bool) { return exist } -// DownloadBy 增加下载次数,匿名用户不会缓存 +// DownloadBy 增加下载次数、检查积分等,匿名用户不会缓存 func (share *Share) DownloadBy(user *User, c *gin.Context) error { if !share.WasDownloadedBy(user, c) { + if err := share.Purchase(user); err != nil { + return err + } share.Downloaded() if !user.IsAnonymous() { cache.Set(fmt.Sprintf("share_%d_%d", share.ID, user.ID), true, @@ -163,6 +178,25 @@ func (share *Share) DownloadBy(user *User, c *gin.Context) error { return nil } +// Purchase 使用积分购买分享 +func (share *Share) Purchase(user *User) error { + // 不需要付积分 + if share.Score == 0 || user.Group.OptionsSerialized.ShareFree || user.ID == share.UserID { + return nil + } + + ok := user.PayScore(share.Score) + if !ok { + return ErrInsufficientCredit + } + + scoreRate := GetIntSetting("share_score_rate", 100) + gainedScore := int(math.Ceil(float64(share.Score*scoreRate) / 100)) + share.Creator().AddScore(gainedScore) + + return nil +} + // Viewed 增加访问次数 func (share *Share) Viewed() { share.Views++ diff --git a/models/share_test.go b/models/share_test.go deleted file mode 100644 index b3fdf0a..0000000 --- a/models/share_test.go +++ /dev/null @@ -1,321 +0,0 @@ -package model - -import ( - "errors" - "net/http/httptest" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestShare_Create(t *testing.T) { - asserts := assert.New(t) - share := Share{UserID: 1} - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - id, err := share.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(2, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - id, err := share.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestGetShareByHashID(t *testing.T) { - asserts := assert.New(t) - conf.SystemConfig.HashIDSalt = "" - - // 成功 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res := GetShareByHashID("x9T4") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(res) - } - - // 查询失败 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnError(errors.New("error")) - res := GetShareByHashID("x9T4") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(res) - } - - // ID解码失败 - { - res := GetShareByHashID("empty") - asserts.Nil(res) - } - -} - -func TestShare_IsAvailable(t *testing.T) { - asserts := assert.New(t) - - // 下载剩余次数为0 - { - share := Share{} - asserts.False(share.IsAvailable()) - } - - // 时效过期 - { - expires := time.Unix(10, 10) - share := Share{ - RemainDownloads: -1, - Expires: &expires, - } - asserts.False(share.IsAvailable()) - } - - // 源对象为目录,但不存在 - { - share := Share{ - RemainDownloads: -1, - SourceID: 2, - IsDir: true, - } - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - asserts.False(share.IsAvailable()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 源对象为目录,存在 - { - share := Share{ - RemainDownloads: -1, - SourceID: 2, - IsDir: false, - } - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(13)) - asserts.True(share.IsAvailable()) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 用户被封禁 - { - share := Share{ - RemainDownloads: -1, - SourceID: 2, - IsDir: true, - User: User{Status: Baned}, - } - asserts.False(share.IsAvailable()) - } -} - -func TestShare_GetCreator(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - share := Share{UserID: 1} - res := share.Creator() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.EqualValues(1, res.ID) -} - -func TestShare_Source(t *testing.T) { - asserts := assert.New(t) - - // 目录 - { - share := Share{IsDir: true, SourceID: 3} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - asserts.EqualValues(3, share.Source().(*Folder).ID) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 文件 - { - share := Share{IsDir: false, SourceID: 3} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - asserts.EqualValues(3, share.Source().(*File).ID) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestShare_CanBeDownloadBy(t *testing.T) { - asserts := assert.New(t) - share := Share{} - - // 未登录,无权 - { - user := &User{ - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: false, - }, - }, - } - asserts.Error(share.CanBeDownloadBy(user)) - } - - // 已登录,无权 - { - user := &User{ - Model: gorm.Model{ID: 1}, - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: false, - }, - }, - } - asserts.Error(share.CanBeDownloadBy(user)) - } - - // 成功 - { - user := &User{ - Model: gorm.Model{ID: 1}, - Group: Group{ - OptionsSerialized: GroupOption{ - ShareDownload: true, - }, - }, - } - asserts.NoError(share.CanBeDownloadBy(user)) - } -} - -func TestShare_WasDownloadedBy(t *testing.T) { - asserts := assert.New(t) - share := Share{ - Model: gorm.Model{ID: 1}, - } - - // 已登录,已下载 - { - user := User{ - Model: gorm.Model{ - ID: 1, - }, - } - r := httptest.NewRecorder() - c, _ := gin.CreateTestContext(r) - cache.Set("share_1_1", true, 0) - asserts.True(share.WasDownloadedBy(&user, c)) - } -} - -func TestShare_DownloadBy(t *testing.T) { - asserts := assert.New(t) - share := Share{ - Model: gorm.Model{ID: 1}, - } - user := User{ - Model: gorm.Model{ - ID: 1, - }, - } - cache.Deletes([]string{"1_1"}, "share_") - r := httptest.NewRecorder() - c, _ := gin.CreateTestContext(r) - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := share.DownloadBy(&user, c) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - _, ok := cache.Get("share_1_1") - asserts.True(ok) -} - -func TestShare_Viewed(t *testing.T) { - asserts := assert.New(t) - share := Share{} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - share.Viewed() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.EqualValues(1, share.Views) -} - -func TestShare_UpdateAndDelete(t *testing.T) { - asserts := assert.New(t) - share := Share{} - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := share.Update(map[string]interface{}{"id": 1}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := share.Delete() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := DeleteShareBySourceIDs([]uint{1}, true) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - -} - -func TestListShares(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(2)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2)) - - res, total := ListShares(1, 1, 10, "desc", true) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 2) - asserts.Equal(2, total) -} - -func TestSearchShares(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)"). - WithArgs("", sqlmock.AnyArg(), "%1%2%"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, total := SearchShares(1, 10, "id", "1 2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 1) - asserts.Equal(1, total) -} diff --git a/models/source_link_test.go b/models/source_link_test.go deleted file mode 100644 index d84dc62..0000000 --- a/models/source_link_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestSourceLink_Link(t *testing.T) { - a := assert.New(t) - s := &SourceLink{} - s.ID = 1 - - // 失败 - { - s.File.Name = string([]byte{0x7f}) - res, err := s.Link() - a.Error(err) - a.Empty(res) - } - - // 成功 - { - s.File.Name = "filename" - res, err := s.Link() - a.NoError(err) - a.Contains(res, s.Name) - } -} - -func TestGetSourceLinkByID(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2)) - mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) - - res, err := GetSourceLinkByID(1) - a.NoError(err) - a.NotNil(res) - a.EqualValues(2, res.File.ID) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestSourceLink_Downloaded(t *testing.T) { - a := assert.New(t) - s := &SourceLink{} - s.ID = 1 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - s.Downloaded() - a.NoError(mock.ExpectationsWereMet()) -} diff --git a/models/storage_pack.go b/models/storage_pack.go new file mode 100755 index 0000000..a18f9c1 --- /dev/null +++ b/models/storage_pack.go @@ -0,0 +1,91 @@ +package model + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/jinzhu/gorm" + "strconv" + "time" +) + +// StoragePack 容量包模型 +type StoragePack struct { + // 表字段 + gorm.Model + Name string + UserID uint + ActiveTime *time.Time + ExpiredTime *time.Time `gorm:"index:expired"` + Size uint64 +} + +// Create 创建容量包 +func (pack *StoragePack) Create() (uint, error) { + if err := DB.Create(pack).Error; err != nil { + util.Log().Warning("Failed to insert storage pack record: %s", err) + return 0, err + } + return pack.ID, nil +} + +// GetAvailablePackSize 返回给定用户当前可用的容量包总容量 +func (user *User) GetAvailablePackSize() uint64 { + var ( + total uint64 + firstExpire *time.Time + timeNow = time.Now() + ttl int64 + ) + + // 尝试从缓存中读取 + cacheKey := "pack_size_" + strconv.FormatUint(uint64(user.ID), 10) + if total, ok := cache.Get(cacheKey); ok { + return total.(uint64) + } + + // 查找所有有效容量包 + packs := user.GetAvailableStoragePacks() + + // 计算总容量, 并找到其中最早的过期时间 + for _, v := range packs { + total += v.Size + if firstExpire == nil { + firstExpire = v.ExpiredTime + continue + } + if v.ExpiredTime != nil && firstExpire.After(*v.ExpiredTime) { + firstExpire = v.ExpiredTime + } + } + + // 用最早的过期时间计算缓存TTL,并写入缓存 + if firstExpire != nil { + ttl = firstExpire.Unix() - timeNow.Unix() + if ttl > 0 { + _ = cache.Set(cacheKey, total, int(ttl)) + } + } + + return total +} + +// GetAvailableStoragePacks 返回用户可用的容量包 +func (user *User) GetAvailableStoragePacks() []StoragePack { + var packs []StoragePack + timeNow := time.Now() + // 查找所有有效容量包 + DB.Where("expired_time > ? AND user_id = ?", timeNow, user.ID).Find(&packs) + return packs +} + +// GetExpiredStoragePack 获取已过期的容量包 +func GetExpiredStoragePack() []StoragePack { + var packs []StoragePack + DB.Where("expired_time < ?", time.Now()).Find(&packs) + return packs +} + +// Delete 删除容量包 +func (pack *StoragePack) Delete() error { + return DB.Delete(&pack).Error +} diff --git a/models/tag_test.go b/models/tag_test.go deleted file mode 100644 index be8d3fb..0000000 --- a/models/tag_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestTag_Create(t *testing.T) { - asserts := assert.New(t) - tag := Tag{} - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - id, err := tag.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - id, err := tag.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestDeleteTagByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := DeleteTagByID(1, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) -} - -func TestGetTagsByUID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetTagsByUID(1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) -} - -func TestGetTagsByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag")) - res, err := GetTagsByID(1, 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues("tag", res.Name) -} diff --git a/models/task_test.go b/models/task_test.go deleted file mode 100644 index 1ad71c3..0000000 --- a/models/task_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestTask_Create(t *testing.T) { - asserts := assert.New(t) - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task := Task{Props: "1"} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - task := Task{Props: "1"} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := Task{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NoError(task.SetError("error")) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := Task{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NoError(task.SetStatus(1)) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestTask_SetProgress(t *testing.T) { - asserts := assert.New(t) - task := Task{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - asserts.NoError(task.SetProgress(1)) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestGetTasksByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := GetTasksByID(1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, res.ID) -} - -func TestListTasks(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5)) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5)) - - res, total := ListTasks(1, 1, 10, "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.EqualValues(5, total) - asserts.Len(res, 1) -} - -func TestGetTasksByStatus(t *testing.T) { - a := assert.New(t) - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res := GetTasksByStatus(1, 2) - a.NoError(mock.ExpectationsWereMet()) - a.Len(res, 1) -} diff --git a/models/user.go b/models/user.go index ff1d6dd..4d1d1a3 100644 --- a/models/user.go +++ b/models/user.go @@ -7,6 +7,7 @@ import ( "encoding/hex" "encoding/json" "strings" + "time" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/jinzhu/gorm" @@ -28,20 +29,25 @@ const ( type User struct { // 表字段 gorm.Model - Email string `gorm:"type:varchar(100);unique_index"` - Nick string `gorm:"size:50"` - Password string `json:"-"` - Status int - GroupID uint - Storage uint64 - TwoFactor string - Avatar string - Options string `json:"-" gorm:"size:4294967295"` - Authn string `gorm:"size:4294967295"` + Email string `gorm:"type:varchar(100);unique_index"` + Nick string `gorm:"size:50"` + Password string `json:"-"` + Status int + GroupID uint + Storage uint64 + OpenID string + TwoFactor string + Avatar string + Options string `json:"-" gorm:"size:4294967295"` + Authn string `gorm:"size:4294967295"` + Score int + PreviousGroupID uint // 初始用户组 + GroupExpires *time.Time // 用户组过期日期 + NotifyDate *time.Time // 通知超出配额时的日期 + Phone string // 关联模型 - Group Group `gorm:"save_associations:false:false"` - Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"` + Group Group `gorm:"save_associations:false:false"` // 数据库忽略字段 OptionsSerialized UserOption `gorm:"-"` @@ -53,8 +59,9 @@ func init() { // UserOption 用户个性化配置字段 type UserOption struct { - ProfileOff bool `json:"profile_off,omitempty"` - PreferredTheme string `json:"preferred_theme,omitempty"` + ProfileOff bool `json:"profile_off,omitempty"` + PreferredPolicy uint `json:"preferred_policy,omitempty"` + PreferredTheme string `json:"preferred_theme,omitempty"` } // Root 获取用户的根目录 @@ -99,6 +106,25 @@ func (user *User) ChangeStorage(tx *gorm.DB, operator string, size uint64) error return tx.Model(user).Update("storage", gorm.Expr("storage "+operator+" ?", size)).Error } +// PayScore 扣除积分,返回是否成功 +func (user *User) PayScore(score int) bool { + if score == 0 { + return true + } + if score <= user.Score { + user.Score -= score + DB.Model(user).Update("score", gorm.Expr("score - ?", score)) + return true + } + return false +} + +// AddScore 增加积分 +func (user *User) AddScore(score int) { + user.Score += score + DB.Model(user).Update("score", gorm.Expr("score + ?", score)) +} + // IncreaseStorageWithoutCheck 忽略可用容量,增加用户已用容量 func (user *User) IncreaseStorageWithoutCheck(size uint64) { if size == 0 { @@ -111,19 +137,58 @@ func (user *User) IncreaseStorageWithoutCheck(size uint64) { // GetRemainingCapacity 获取剩余配额 func (user *User) GetRemainingCapacity() uint64 { - total := user.Group.MaxStorage + total := user.Group.MaxStorage + user.GetAvailablePackSize() if total <= user.Storage { return 0 } return total - user.Storage } -// GetPolicyID 获取用户当前的存储策略ID -func (user *User) GetPolicyID(prefer uint) uint { - if len(user.Group.PolicyList) > 0 { - return user.Group.PolicyList[0] +// GetPolicyID 获取给定目录的存储策略, 如果为 nil 则使用默认 +func (user *User) GetPolicyID(folder *Folder) *Policy { + if user.IsAnonymous() { + return &Policy{Type: "anonymous"} } - return 0 + + defaultPolicy := uint(1) + if len(user.Group.PolicyList) > 0 { + defaultPolicy = user.Group.PolicyList[0] + } + + if folder != nil { + prefer := folder.PolicyID + if prefer == 0 && folder.InheritPolicyID > 0 { + prefer = folder.InheritPolicyID + } + + if prefer > 0 && util.ContainsUint(user.Group.PolicyList, prefer) { + defaultPolicy = prefer + } + } + + p, _ := GetPolicyByID(defaultPolicy) + return &p +} + +// GetPolicyByPreference 在可用存储策略中优先获取 preference +func (user *User) GetPolicyByPreference(preference uint) *Policy { + if user.IsAnonymous() { + return &Policy{Type: "anonymous"} + } + + defaultPolicy := uint(1) + if len(user.Group.PolicyList) > 0 { + defaultPolicy = user.Group.PolicyList[0] + } + + if preference != 0 { + if util.ContainsUint(user.Group.PolicyList, preference) { + defaultPolicy = preference + } + } + + p, _ := GetPolicyByID(defaultPolicy) + return &p } // GetUserByID 用ID获取用户 @@ -183,6 +248,27 @@ func (user *User) AfterCreate(tx *gorm.DB) (err error) { OwnerID: user.ID, } tx.Create(defaultFolder) + + // 创建用户初始文件记录 + initialFiles := GetSettingByNameFromTx(tx, "initial_files") + if initialFiles != "" { + initialFileIDs := make([]uint, 0) + if err := json.Unmarshal([]byte(initialFiles), &initialFileIDs); err != nil { + return err + } + + if files, err := GetFilesByIDsFromTX(tx, initialFileIDs, 0); err == nil { + for _, file := range files { + file.ID = 0 + file.UserID = user.ID + file.FolderID = defaultFolder.ID + user.Storage += file.Size + tx.Create(&file) + } + tx.Save(user) + } + } + return err } @@ -193,12 +279,10 @@ func (user *User) AfterFind() (err error) { err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized) } - // 预加载存储策略 - user.Policy, _ = GetPolicyByID(user.GetPolicyID(0)) return err } -//SerializeOptions 将序列后的Option写入到数据库字段 +// SerializeOptions 将序列后的Option写入到数据库字段 func (user *User) SerializeOptions() (err error) { optionsValue, err := json.Marshal(&user.OptionsSerialized) user.Options = string(optionsValue) @@ -261,7 +345,6 @@ func (user *User) SetPassword(password string) error { // NewAnonymousUser 返回一个匿名用户 func NewAnonymousUser() *User { user := User{} - user.Policy.Type = "anonymous" user.Group, _ = GetGroupByID(3) return &user } @@ -271,6 +354,20 @@ func (user *User) IsAnonymous() bool { return user.ID == 0 } +// Notified 更新用户容量超额通知日期 +func (user *User) Notified() { + if user.NotifyDate == nil { + timeNow := time.Now() + user.NotifyDate = &timeNow + DB.Model(&user).Update("notify_date", user.NotifyDate) + } +} + +// ClearNotified 清除用户通知标记 +func (user *User) ClearNotified() { + DB.Model(&user).Update("notify_date", nil) +} + // SetStatus 设定用户状态 func (user *User) SetStatus(status int) { DB.Model(&user).Update("status", status) @@ -288,3 +385,45 @@ func (user *User) UpdateOptions() error { } return user.Update(map[string]interface{}{"options": user.Options}) } + +// GetGroupExpiredUsers 获取用户组过期的用户 +func GetGroupExpiredUsers() []User { + var users []User + DB.Where("group_expires < ? and previous_group_id <> 0", time.Now()).Find(&users) + return users +} + +// GetTolerantExpiredUser 获取超过宽容期的用户 +func GetTolerantExpiredUser() []User { + var users []User + DB.Set("gorm:auto_preload", true).Where("notify_date < ?", time.Now().Add( + time.Duration(-GetIntSetting("ban_time", 10))*time.Second), + ).Find(&users) + return users +} + +// GroupFallback 回退到初始用户组 +func (user *User) GroupFallback() { + if user.GroupExpires != nil && user.PreviousGroupID != 0 { + user.Group.ID = user.PreviousGroupID + DB.Model(&user).Updates(map[string]interface{}{ + "group_expires": nil, + "previous_group_id": 0, + "group_id": user.PreviousGroupID, + }) + } +} + +// UpgradeGroup 升级用户组 +func (user *User) UpgradeGroup(id uint, expires *time.Time) error { + user.Group.ID = id + previousGroupID := user.GroupID + if user.PreviousGroupID != 0 && user.GroupID == id { + previousGroupID = user.PreviousGroupID + } + return DB.Model(&user).Updates(map[string]interface{}{ + "group_expires": expires, + "previous_group_id": previousGroupID, + "group_id": id, + }).Error +} diff --git a/models/user_authn_test.go b/models/user_authn_test.go deleted file mode 100644 index 08a8ce1..0000000 --- a/models/user_authn_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package model - -import ( - "github.com/DATA-DOG/go-sqlmock" - "github.com/duo-labs/webauthn/webauthn" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestUser_RegisterAuthn(t *testing.T) { - asserts := assert.New(t) - credential := webauthn.Credential{} - user := User{ - Model: gorm.Model{ID: 1}, - } - - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - user.RegisterAuthn(&credential) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestUser_WebAuthnCredentials(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`, - } - { - credentials := user.WebAuthnCredentials() - asserts.Len(credentials, 1) - } -} - -func TestUser_WebAuthnDisplayName(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Nick: "123", - } - { - nick := user.WebAuthnDisplayName() - asserts.Equal("123", nick) - } -} - -func TestUser_WebAuthnIcon(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - } - { - icon := user.WebAuthnIcon() - asserts.NotEmpty(icon) - } -} - -func TestUser_WebAuthnID(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - } - { - id := user.WebAuthnID() - asserts.Len(id, 8) - } -} - -func TestUser_WebAuthnName(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Email: "abslant@foxmail.com", - } - { - name := user.WebAuthnName() - asserts.Equal("abslant@foxmail.com", name) - } -} - -func TestUser_RemoveAuthn(t *testing.T) { - asserts := assert.New(t) - user := User{ - Model: gorm.Model{ID: 1}, - Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`, - } - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - user.RemoveAuthn("123") - asserts.NoError(mock.ExpectationsWereMet()) - } -} diff --git a/models/user_test.go b/models/user_test.go deleted file mode 100644 index a85ddbd..0000000 --- a/models/user_test.go +++ /dev/null @@ -1,438 +0,0 @@ -package model - -import ( - "encoding/json" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -func TestGetUserByID(t *testing.T) { - asserts := assert.New(t) - cache.Deletes([]string{"1"}, "policy_") - //找到用户时 - userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}). - AddRow(1, nil, "admin@cloudreve.org", "{}", 1) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - user, err := GetUserByID(1) - asserts.NoError(err) - asserts.Equal(User{ - Model: gorm.Model{ - ID: 1, - DeletedAt: nil, - }, - Email: "admin@cloudreve.org", - Options: "{}", - GroupID: 1, - Group: Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - PolicyList: []uint{1}, - }, - Policy: Policy{ - Model: gorm.Model{ - ID: 1, - }, - OptionsSerialized: PolicyOption{ - FileType: []string{}, - }, - Name: "默认存储策略", - }, - }, user) - - //未找到用户时 - mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - user, err = GetUserByID(1) - asserts.Error(err) - asserts.Equal(User{}, user) -} - -func TestGetActiveUserByID(t *testing.T) { - asserts := assert.New(t) - cache.Deletes([]string{"1"}, "policy_") - //找到用户时 - userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}). - AddRow(1, nil, "admin@cloudreve.org", "{}", 1) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - user, err := GetActiveUserByID(1) - asserts.NoError(err) - asserts.Equal(User{ - Model: gorm.Model{ - ID: 1, - DeletedAt: nil, - }, - Email: "admin@cloudreve.org", - Options: "{}", - GroupID: 1, - Group: Group{ - Model: gorm.Model{ - ID: 1, - }, - Name: "管理员", - Policies: "[1]", - PolicyList: []uint{1}, - }, - Policy: Policy{ - Model: gorm.Model{ - ID: 1, - }, - OptionsSerialized: PolicyOption{ - FileType: []string{}, - }, - Name: "默认存储策略", - }, - }, user) - - //未找到用户时 - mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) - user, err = GetActiveUserByID(1) - asserts.Error(err) - asserts.Equal(User{}, user) -} - -func TestUser_SetPassword(t *testing.T) { - asserts := assert.New(t) - user := User{} - err := user.SetPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - asserts.NotEmpty(user.Password) -} - -func TestUser_CheckPassword(t *testing.T) { - asserts := assert.New(t) - user := User{} - err := user.SetPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - - //密码正确 - res, err := user.CheckPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - asserts.True(res) - - //密码错误 - res, err = user.CheckPassword("Cause Sega does what Nintendon't") - asserts.NoError(err) - asserts.False(res) - - //密码字段为空 - user = User{} - res, err = user.CheckPassword("Cause Sega does what nintendon't") - asserts.Error(err) - asserts.False(res) - - // 未知密码类型 - user = User{} - user.Password = "1:2:3" - res, err = user.CheckPassword("Cause Sega does what nintendon't") - asserts.Error(err) - asserts.False(res) - - // V2密码,错误 - user = User{} - user.Password = "md5:2:3" - res, err = user.CheckPassword("Cause Sega does what nintendon't") - asserts.NoError(err) - asserts.False(res) - - // V2密码,正确 - user = User{} - user.Password = "md5:d8446059f8846a2c111a7f53515665fb:sdshare" - res, err = user.CheckPassword("admin") - asserts.NoError(err) - asserts.True(res) - -} - -func TestNewUser(t *testing.T) { - asserts := assert.New(t) - newUser := NewUser() - asserts.IsType(User{}, newUser) - asserts.Empty(newUser.Avatar) -} - -func TestUser_AfterFind(t *testing.T) { - asserts := assert.New(t) - cache.Deletes([]string{"0"}, "policy_") - - policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(144, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - newUser := NewUser() - err := newUser.AfterFind() - err = newUser.BeforeSave() - expected := UserOption{} - err = json.Unmarshal([]byte(newUser.Options), &expected) - - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(expected, newUser.OptionsSerialized) - asserts.Equal("默认存储策略", newUser.Policy.Name) -} - -func TestUser_BeforeSave(t *testing.T) { - asserts := assert.New(t) - - newUser := NewUser() - err := newUser.BeforeSave() - expected, err := json.Marshal(newUser.OptionsSerialized) - - asserts.NoError(err) - asserts.Equal(string(expected), newUser.Options) -} - -func TestUser_GetPolicyID(t *testing.T) { - asserts := assert.New(t) - - newUser := NewUser() - newUser.Group.PolicyList = []uint{1} - - asserts.EqualValues(1, newUser.GetPolicyID(0)) - - newUser.Group.PolicyList = nil - asserts.EqualValues(0, newUser.GetPolicyID(0)) - - newUser.Group.PolicyList = []uint{} - asserts.EqualValues(0, newUser.GetPolicyID(0)) -} - -func TestUser_GetRemainingCapacity(t *testing.T) { - asserts := assert.New(t) - newUser := NewUser() - cache.Set("pack_size_0", uint64(0), 0) - - newUser.Group.MaxStorage = 100 - asserts.Equal(uint64(100), newUser.GetRemainingCapacity()) - - newUser.Group.MaxStorage = 100 - newUser.Storage = 1 - asserts.Equal(uint64(99), newUser.GetRemainingCapacity()) - - newUser.Group.MaxStorage = 100 - newUser.Storage = 100 - asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) - - newUser.Group.MaxStorage = 100 - newUser.Storage = 200 - asserts.Equal(uint64(0), newUser.GetRemainingCapacity()) -} - -func TestUser_DeductionCapacity(t *testing.T) { - asserts := assert.New(t) - - cache.Deletes([]string{"1"}, "policy_") - userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}). - AddRow(1, nil, 0, "{}", 1) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). - AddRow(1, "管理员", "[1]") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) - - policyRows := sqlmock.NewRows([]string{"id", "name"}). - AddRow(1, "默认存储策略") - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows) - - newUser, err := GetUserByID(1) - newUser.Group.MaxStorage = 100 - cache.Set("pack_size_1", uint64(0), 0) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - asserts.Equal(false, newUser.IncreaseStorage(101)) - asserts.Equal(uint64(0), newUser.Storage) - - asserts.Equal(true, newUser.IncreaseStorage(1)) - asserts.Equal(uint64(1), newUser.Storage) - - asserts.Equal(true, newUser.IncreaseStorage(99)) - asserts.Equal(uint64(100), newUser.Storage) - - asserts.Equal(false, newUser.IncreaseStorage(1)) - asserts.Equal(uint64(100), newUser.Storage) - - asserts.True(newUser.IncreaseStorage(0)) -} - -func TestUser_DeductionStorage(t *testing.T) { - asserts := assert.New(t) - - // 减少零 - { - user := User{Storage: 1} - asserts.True(user.DeductionStorage(0)) - asserts.Equal(uint64(1), user.Storage) - } - // 正常 - { - user := User{ - Model: gorm.Model{ID: 1}, - Storage: 10, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs(5, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - asserts.True(user.DeductionStorage(5)) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(5), user.Storage) - } - - // 减少的超出可用的 - { - user := User{ - Model: gorm.Model{ID: 1}, - Storage: 10, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs(0, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - asserts.False(user.DeductionStorage(20)) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(0), user.Storage) - } -} - -func TestUser_IncreaseStorageWithoutCheck(t *testing.T) { - asserts := assert.New(t) - - // 增加零 - { - user := User{} - user.IncreaseStorageWithoutCheck(0) - asserts.Equal(uint64(0), user.Storage) - } - - // 减少零 - { - user := User{ - Model: gorm.Model{ID: 1}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs(10, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - user.IncreaseStorageWithoutCheck(10) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(uint64(10), user.Storage) - } -} - -func TestGetActiveUserByEmail(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) - _, err := GetActiveUserByEmail("abslant@foxmail.com") - - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestGetUserByEmail(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) - _, err := GetUserByEmail("abslant@foxmail.com") - - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestUser_AfterCreate(t *testing.T) { - asserts := assert.New(t) - user := User{Model: gorm.Model{ID: 1}} - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := user.AfterCreate(DB) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestUser_Root(t *testing.T) { - asserts := assert.New(t) - user := User{Model: gorm.Model{ID: 1}} - - // 根目录存在 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "根目录")) - root, err := user.Root() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("根目录", root.Name) - } - - // 根目录不存在 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - _, err := user.Root() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestNewAnonymousUser(t *testing.T) { - asserts := assert.New(t) - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3)) - user := NewAnonymousUser() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(user) - asserts.EqualValues(3, user.Group.ID) -} - -func TestUser_IsAnonymous(t *testing.T) { - asserts := assert.New(t) - user := User{} - asserts.True(user.IsAnonymous()) - user.ID = 1 - asserts.False(user.IsAnonymous()) -} - -func TestUser_SetStatus(t *testing.T) { - asserts := assert.New(t) - user := User{} - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - user.SetStatus(Baned) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal(Baned, user.Status) -} - -func TestUser_UpdateOptions(t *testing.T) { - asserts := assert.New(t) - user := User{} - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - asserts.NoError(user.UpdateOptions()) - asserts.NoError(mock.ExpectationsWereMet()) -} diff --git a/models/webdav.go b/models/webdav.go index 0799aee..ee424aa 100644 --- a/models/webdav.go +++ b/models/webdav.go @@ -46,3 +46,8 @@ func DeleteWebDAVAccountByID(id, uid uint) { func UpdateWebDAVAccountByID(id, uid uint, updates map[string]interface{}) { DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).Updates(updates) } + +// UpdateWebDAVAccountReadonlyByID 根据账户ID和UID更新账户的只读性 +func UpdateWebDAVAccountReadonlyByID(id, uid uint, readonly bool) { + DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).UpdateColumn("readonly", readonly) +} diff --git a/models/webdav_test.go b/models/webdav_test.go deleted file mode 100644 index 55a7326..0000000 --- a/models/webdav_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package model - -import ( - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestWebdav_Create(t *testing.T) { - asserts := assert.New(t) - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task := Webdav{} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.EqualValues(1, id) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - task := Webdav{} - id, err := task.Create() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualValues(0, id) - } -} - -func TestGetWebdavByPassword(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := GetWebdavByPassword("e", 1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) -} - -func TestListWebDAVAccounts(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - res := ListWebDAVAccounts(1) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(res, 0) -} - -func TestDeleteWebDAVAccountByID(t *testing.T) { - asserts := assert.New(t) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - DeleteWebDAVAccountByID(1, 1) - asserts.NoError(mock.ExpectationsWereMet()) -} diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go deleted file mode 100644 index b6e7092..0000000 --- a/pkg/aria2/aria2_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package aria2 - -import ( - "database/sql" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestInit(t *testing.T) { - a := assert.New(t) - mockPool := &mocks.NodePoolMock{} - mockPool.On("GetNodeByID", testMock.Anything).Return(nil) - mockQueue := mq.NewMQ() - - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - Init(false, mockPool, mockQueue) - a.NoError(mock.ExpectationsWereMet()) - mockPool.AssertExpectations(t) -} - -func TestTestRPCConnection(t *testing.T) { - a := assert.New(t) - - // url not legal - { - res, err := TestRPCConnection(string([]byte{0x7f}), "", 10) - a.Error(err) - a.Empty(res.Version) - } - - // rpc failed - { - res, err := TestRPCConnection("ws://0.0.0.0", "", 0) - a.Error(err) - a.Empty(res.Version) - } -} - -func TestGetLoadBalancer(t *testing.T) { - a := assert.New(t) - a.NotPanics(func() { - GetLoadBalancer() - }) -} diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go deleted file mode 100644 index 7b0f237..0000000 --- a/pkg/aria2/common/common_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package common - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/stretchr/testify/assert" -) - -func TestDummyAria2(t *testing.T) { - a := assert.New(t) - d := &DummyAria2{} - - a.NoError(d.Init()) - - res, err := d.CreateTask(&model.Download{}, map[string]interface{}{}) - a.Empty(res) - a.Error(err) - - _, err = d.Status(&model.Download{}) - a.Error(err) - - err = d.Cancel(&model.Download{}) - a.Error(err) - - err = d.Select(&model.Download{}, []int{}) - a.Error(err) - - configRes := d.GetConfig() - a.NotNil(configRes) - - err = d.DeleteTempFile(&model.Download{}) - a.Error(err) -} - -func TestGetStatus(t *testing.T) { - a := assert.New(t) - - a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete) - a.Equal(GetStatus(rpc.StatusInfo{Status: "active", - BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading) - a.Equal(GetStatus(rpc.StatusInfo{Status: "active", - BitTorrent: rpc.BitTorrentInfo{Mode: "single"}, - TotalLength: "100", CompletedLength: "50"}), Downloading) - a.Equal(GetStatus(rpc.StatusInfo{Status: "active", - BitTorrent: rpc.BitTorrentInfo{Mode: "multi"}, - TotalLength: "100", CompletedLength: "100"}), Seeding) - a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready) - a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused) - a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error) - a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled) - a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown) -} diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index 69d14ff..ff3f380 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "path/filepath" "strconv" "time" @@ -52,6 +53,7 @@ func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) { // Loop 开启监控循环 func (monitor *Monitor) Loop(mqClient mq.MQ) { defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier) + fmt.Println(cluster.Default) // 首次循环立即更新 interval := 50 * time.Millisecond @@ -190,6 +192,10 @@ func (monitor *Monitor) ValidateFile() error { } defer fs.Recycle() + if err := fs.SetPolicyFromPath(monitor.Task.Dst); err != nil { + return fmt.Errorf("failed to switch policy to target dir: %w", err) + } + // 创建上下文环境 file := &fsctx.FileStream{ Size: monitor.Task.TotalSize, diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go deleted file mode 100644 index a6be586..0000000 --- a/pkg/aria2/monitor/monitor_test.go +++ /dev/null @@ -1,447 +0,0 @@ -package monitor - -import ( - "database/sql" - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestNewMonitor(t *testing.T) { - a := assert.New(t) - mockMQ := mq.NewMQ() - - // node not available - { - mockPool := &mocks.NodePoolMock{} - mockPool.On("GetNodeByID", uint(1)).Return(nil) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task := &model.Download{ - Model: gorm.Model{ID: 1}, - } - NewMonitor(task, mockPool, mockMQ) - mockPool.AssertExpectations(t) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(task.Error) - } - - // success - { - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - mockPool := &mocks.NodePoolMock{} - mockPool.On("GetNodeByID", uint(1)).Return(mockNode) - - task := &model.Download{ - Model: gorm.Model{ID: 1}, - } - NewMonitor(task, mockPool, mockMQ) - mockNode.AssertExpectations(t) - mockPool.AssertExpectations(t) - } - -} - -func TestMonitor_Loop(t *testing.T) { - a := assert.New(t) - mockMQ := mq.NewMQ() - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - m := &Monitor{ - retried: MAX_RETRY, - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - notifier: mockMQ.Subscribe("test", 1), - } - - // into interval loop - { - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - m.Loop(mockMQ) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(m.Task.Error) - } - - // into notifier loop - { - m.Task.Error = "" - mockMQ.Publish("test", mq.Message{}) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - m.Loop(mockMQ) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(m.Task.Error) - } -} - -func TestMonitor_UpdateFailedAfterRetry(t *testing.T) { - a := assert.New(t) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - for i := 0; i < MAX_RETRY; i++ { - a.False(m.Update()) - } - - mockNode.AssertExpectations(t) - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateMagentoFollow(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - FollowedBy: []string{"next"}, - }, nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.False(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - a.Equal("next", m.Task.GID) - mockAria2.AssertExpectations(t) -} - -func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateCompleted(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "complete", - }, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - mockNode.On("ID").Return(uint(1)) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateError(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "error", - ErrorMessage: "error", - }, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) - a.NotEmpty(m.Task.Error) -} - -func TestMonitor_UpdateActive(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "active", - }, nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.False(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) -} - -func TestMonitor_UpdateRemoved(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "removed", - }, nil) - mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.Equal(common.Canceled, m.Task.Status) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) -} - -func TestMonitor_UpdateUnknown(t *testing.T) { - a := assert.New(t) - mockAria2 := &mocks.Aria2Mock{} - mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{ - Status: "unknown", - }, nil) - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(mockAria2) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - a.True(m.Update()) - a.NoError(mock.ExpectationsWereMet()) - mockAria2.AssertExpectations(t) - mockNode.AssertExpectations(t) -} - -func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) { - a := assert.New(t) - status := rpc.StatusInfo{ - Status: "completed", - TotalLength: "100", - CompletedLength: "50", - DownloadSpeed: "20", - } - mockNode := &mocks.NodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - m := &Monitor{ - node: mockNode, - Task: &model.Download{Model: gorm.Model{ID: 1}}, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := m.UpdateTaskInfo(status) - a.Error(err) - a.NoError(mock.ExpectationsWereMet()) - mockNode.AssertExpectations(t) -} - -func TestMonitor_ValidateFile(t *testing.T) { - a := assert.New(t) - m := &Monitor{ - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - TotalSize: 100, - }, - } - - // failed to create filesystem - { - m.Task.User = &model.User{ - Policy: model.Policy{ - Type: "random", - }, - } - a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile()) - } - - // User capacity not enough - { - m.Task.User = &model.User{ - Group: model.Group{ - MaxStorage: 99, - }, - Policy: model.Policy{ - Type: "local", - }, - } - a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile()) - } - - // single file too big - { - m.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Length: "100", - Selected: "true", - }, - } - m.Task.User = &model.User{ - Group: model.Group{ - MaxStorage: 100, - }, - Policy: model.Policy{ - Type: "local", - MaxSize: 99, - }, - } - a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile()) - } - - // all pass - { - m.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Length: "100", - Selected: "true", - }, - } - m.Task.User = &model.User{ - Group: model.Group{ - MaxStorage: 100, - }, - Policy: model.Policy{ - Type: "local", - MaxSize: 100, - }, - } - a.NoError(m.ValidateFile()) - } -} - -func TestMonitor_Complete(t *testing.T) { - a := assert.New(t) - mockNode := &mocks.NodeMock{} - mockNode.On("ID").Return(uint(1)) - mockPool := &mocks.TaskPoolMock{} - mockPool.On("Submit", testMock.Anything) - m := &Monitor{ - node: mockNode, - Task: &model.Download{ - Model: gorm.Model{ID: 1}, - TotalSize: 100, - UserID: 9414, - }, - } - m.Task.StatusInfo.Files = []rpc.FileInfo{ - { - Length: "100", - Selected: "true", - }, - } - - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) - - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4)) - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - - a.False(m.Complete(mockPool)) - m.Task.StatusInfo.Status = "complete" - a.True(m.Complete(mockPool)) - a.NoError(mock.ExpectationsWereMet()) - mockNode.AssertExpectations(t) - mockPool.AssertExpectations(t) -} diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go deleted file mode 100644 index 42c5603..0000000 --- a/pkg/auth/auth_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package auth - -import ( - "io/ioutil" - "net/http" - "strings" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/stretchr/testify/assert" -) - -func TestSignURI(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 成功 - { - sign, err := SignURI(General, "/api/v3/something?id=1", 0) - asserts.NoError(err) - queries := sign.Query() - asserts.Equal("1", queries.Get("id")) - asserts.NotEmpty(queries.Get("sign")) - } - - // URI解码失败 - { - sign, err := SignURI(General, "://dg.;'f]gh./'", 0) - asserts.Error(err) - asserts.Nil(sign) - } -} - -func TestCheckURI(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 成功 - { - sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10) - asserts.NoError(err) - asserts.NoError(CheckURI(General, sign)) - } - - // 过期 - { - sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1) - asserts.NoError(err) - asserts.Error(CheckURI(General, sign)) - } -} - -func TestSignRequest(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 非上传请求 - { - req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body.")) - asserts.NoError(err) - req = SignRequest(General, req, 0) - asserts.NotEmpty(req.Header["Authorization"]) - } - - // 上传请求 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/slave/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req.Header["X-Cr-Policy"] = []string{"I am Policy"} - req = SignRequest(General, req, 10) - asserts.NotEmpty(req.Header["Authorization"]) - } -} - -func TestCheckRequest(t *testing.T) { - asserts := assert.New(t) - General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} - - // 缺少请求头 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - err = CheckRequest(General, req) - asserts.Error(err) - asserts.Equal(ErrAuthHeaderMissing, err) - } - - // 非上传请求 验证成功 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req = SignRequest(General, req, 0) - err = CheckRequest(General, req) - asserts.NoError(err) - } - - // 上传请求 验证成功 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req.Header["X-Cr-Policy"] = []string{"I am Policy"} - req = SignRequest(General, req, 0) - err = CheckRequest(General, req) - asserts.NoError(err) - } - - // 非上传请求 失败 - { - req, err := http.NewRequest( - "POST", - "http://127.0.0.1/api/v3/upload", - strings.NewReader("I am body."), - ) - asserts.NoError(err) - req = SignRequest(General, req, 0) - req.Body = ioutil.NopCloser(strings.NewReader("2333")) - err = CheckRequest(General, req) - asserts.Error(err) - } -} diff --git a/pkg/auth/hmac_test.go b/pkg/auth/hmac_test.go deleted file mode 100644 index 706f617..0000000 --- a/pkg/auth/hmac_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package auth - -import ( - "database/sql" - "fmt" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -func TestMain(m *testing.M) { - // 设置gin为测试模式 - gin.SetMode(gin.TestMode) - - // 初始化sqlmock - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - - mockDB, _ := gorm.Open("mysql", db) - model.DB = mockDB - defer db.Close() - - m.Run() -} - -func TestHMACAuth_Sign(t *testing.T) { - asserts := assert.New(t) - auth := HMACAuth{ - SecretKey: []byte(util.RandStringRunes(256)), - } - - asserts.NotEmpty(auth.Sign("content", 0)) -} - -func TestHMACAuth_Check(t *testing.T) { - asserts := assert.New(t) - auth := HMACAuth{ - SecretKey: []byte(util.RandStringRunes(256)), - } - - // 正常,永不过期 - { - sign := auth.Sign("content", 0) - asserts.NoError(auth.Check("content", sign)) - } - - // 过期 - { - sign := auth.Sign("content", 1) - asserts.Error(auth.Check("content", sign)) - } - - // 签名格式错误 - { - sign := auth.Sign("content", 1) - asserts.Error(auth.Check("content", sign+":")) - } - - // 过期日期格式错误 - { - asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed")) - } - - // 签名有误 - { - asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10))) - } -} - -func TestInit(t *testing.T) { - asserts := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312")) - Init() - asserts.NoError(mock.ExpectationsWereMet()) - - // slave模式 - conf.SystemConfig.Mode = "slave" - asserts.Panics(func() { - Init() - }) -} diff --git a/pkg/authn/auth_test.go b/pkg/authn/auth_test.go deleted file mode 100644 index 3df60cf..0000000 --- a/pkg/authn/auth_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package authn - -import ( - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/stretchr/testify/assert" -) - -func TestInit(t *testing.T) { - asserts := assert.New(t) - cache.Set("setting_siteURL", "http://cloudreve.org", 0) - cache.Set("setting_siteName", "Cloudreve", 0) - res, err := NewAuthnInstance() - asserts.NotNil(res) - asserts.NoError(err) -} diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go deleted file mode 100644 index 4493bbb..0000000 --- a/pkg/balancer/balancer_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package balancer - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewBalancer(t *testing.T) { - a := assert.New(t) - a.NotNil(NewBalancer("")) - a.IsType(&RoundRobin{}, NewBalancer("RoundRobin")) -} diff --git a/pkg/balancer/roundrobin_test.go b/pkg/balancer/roundrobin_test.go deleted file mode 100644 index 9cdcc00..0000000 --- a/pkg/balancer/roundrobin_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package balancer - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestRoundRobin_NextIndex(t *testing.T) { - a := assert.New(t) - r := &RoundRobin{} - total := 5 - for i := 1; i < total; i++ { - a.Equal(i, r.NextIndex(total)) - } - for i := 0; i < total; i++ { - a.Equal(i, r.NextIndex(total)) - } -} - -func TestRoundRobin_NextPeer(t *testing.T) { - a := assert.New(t) - r := &RoundRobin{} - - // not slice - { - err, _ := r.NextPeer("s") - a.Equal(ErrInputNotSlice, err) - } - - // no nodes - { - err, _ := r.NextPeer([]string{}) - a.Equal(ErrNoAvaliableNode, err) - } - - // pass - { - err, res := r.NextPeer([]string{"a"}) - a.NoError(err) - a.Equal("a", res.(string)) - } -} diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 1a3e652..4c86b47 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -2,6 +2,7 @@ package cache import ( "encoding/gob" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" diff --git a/pkg/cache/driver_test.go b/pkg/cache/driver_test.go deleted file mode 100644 index 41294e2..0000000 --- a/pkg/cache/driver_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package cache - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestSet(t *testing.T) { - asserts := assert.New(t) - - asserts.NoError(Set("123", "321", -1)) -} - -func TestGet(t *testing.T) { - asserts := assert.New(t) - asserts.NoError(Set("123", "321", -1)) - - value, ok := Get("123") - asserts.True(ok) - asserts.Equal("321", value) - - value, ok = Get("not_exist") - asserts.False(ok) -} - -func TestDeletes(t *testing.T) { - asserts := assert.New(t) - asserts.NoError(Set("123", "321", -1)) - err := Deletes([]string{"123"}, "") - asserts.NoError(err) - _, exist := Get("123") - asserts.False(exist) -} - -func TestGetSettings(t *testing.T) { - asserts := assert.New(t) - asserts.NoError(Set("test_1", "1", -1)) - - values, missed := GetSettings([]string{"1", "2"}, "test_") - asserts.Equal(map[string]string{"1": "1"}, values) - asserts.Equal([]string{"2"}, missed) -} - -func TestSetSettings(t *testing.T) { - asserts := assert.New(t) - - err := SetSettings(map[string]string{"3": "3", "4": "4"}, "test_") - asserts.NoError(err) - value1, _ := Get("test_3") - value2, _ := Get("test_4") - asserts.Equal("3", value1) - asserts.Equal("4", value2) -} - -func TestInit(t *testing.T) { - asserts := assert.New(t) - - asserts.NotPanics(func() { - Init() - }) -} - -func TestInitSlaveOverwrites(t *testing.T) { - asserts := assert.New(t) - - asserts.NotPanics(func() { - InitSlaveOverwrites() - }) -} diff --git a/pkg/cache/memo.go b/pkg/cache/memo.go index af180d6..f9dcf97 100644 --- a/pkg/cache/memo.go +++ b/pkg/cache/memo.go @@ -133,7 +133,13 @@ func (store *MemoStore) Persist(path string) error { return fmt.Errorf("failed to serialize cache: %s", err) } - err = os.WriteFile(path, res, 0644) + // err = os.WriteFile(path, res, 0644) + file, err := util.CreatNestedFile(path) + if err == nil { + _, err = file.Write(res) + file.Chmod(0644) + file.Close() + } return err } diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go deleted file mode 100644 index be90577..0000000 --- a/pkg/cache/memo_test.go +++ /dev/null @@ -1,191 +0,0 @@ -package cache - -import ( - "github.com/stretchr/testify/assert" - "path/filepath" - "testing" - "time" -) - -func TestNewMemoStore(t *testing.T) { - asserts := assert.New(t) - - store := NewMemoStore() - asserts.NotNil(store) - asserts.NotNil(store.Store) -} - -func TestMemoStore_Set(t *testing.T) { - asserts := assert.New(t) - - store := NewMemoStore() - err := store.Set("KEY", "vAL", -1) - asserts.NoError(err) - - val, ok := store.Store.Load("KEY") - asserts.True(ok) - asserts.Equal("vAL", val.(itemWithTTL).Value) -} - -func TestMemoStore_Get(t *testing.T) { - asserts := assert.New(t) - store := NewMemoStore() - - // 正常情况 - { - _ = store.Set("string", "string_val", -1) - val, ok := store.Get("string") - asserts.Equal("string_val", val) - asserts.True(ok) - } - - // Key不存在 - { - val, ok := store.Get("something") - asserts.Equal(nil, val) - asserts.False(ok) - } - - // 存储struct - { - type testStruct struct { - key int - } - test := testStruct{key: 233} - _ = store.Set("struct", test, -1) - val, ok := store.Get("struct") - asserts.True(ok) - res, ok := val.(testStruct) - asserts.True(ok) - asserts.Equal(test, res) - } - - // 过期 - { - _ = store.Set("string", "string_val", 1) - time.Sleep(time.Duration(2) * time.Second) - val, ok := store.Get("string") - asserts.Nil(val) - asserts.False(ok) - } - -} - -func TestMemoStore_Gets(t *testing.T) { - asserts := assert.New(t) - store := NewMemoStore() - - err := store.Set("1", "1,val", -1) - err = store.Set("2", "2,val", -1) - err = store.Set("3", "3,val", -1) - err = store.Set("4", "4,val", -1) - asserts.NoError(err) - - // 全部命中 - { - values, miss := store.Gets([]string{"1", "2", "3", "4"}, "") - asserts.Len(values, 4) - asserts.Len(miss, 0) - } - - // 命中一半 - { - values, miss := store.Gets([]string{"1", "2", "9", "10"}, "") - asserts.Len(values, 2) - asserts.Equal([]string{"9", "10"}, miss) - } -} - -func TestMemoStore_Sets(t *testing.T) { - asserts := assert.New(t) - store := NewMemoStore() - - err := store.Sets(map[string]interface{}{ - "1": "1.val", - "2": "2.val", - "3": "3.val", - "4": "4.val", - }, "test_") - asserts.NoError(err) - - vals, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_") - asserts.Len(miss, 0) - asserts.Equal(map[string]interface{}{ - "1": "1.val", - "2": "2.val", - "3": "3.val", - "4": "4.val", - }, vals) -} - -func TestMemoStore_Delete(t *testing.T) { - asserts := assert.New(t) - store := NewMemoStore() - - err := store.Sets(map[string]interface{}{ - "1": "1.val", - "2": "2.val", - "3": "3.val", - "4": "4.val", - }, "test_") - asserts.NoError(err) - - err = store.Delete([]string{"1", "2"}, "test_") - asserts.NoError(err) - values, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_") - asserts.Equal([]string{"1", "2"}, miss) - asserts.Equal(map[string]interface{}{"3": "3.val", "4": "4.val"}, values) -} - -func TestMemoStore_GarbageCollect(t *testing.T) { - asserts := assert.New(t) - store := NewMemoStore() - store.Set("test", 1, 1) - time.Sleep(time.Duration(2000) * time.Millisecond) - store.GarbageCollect() - _, ok := store.Get("test") - asserts.False(ok) -} - -func TestMemoStore_PersistFailed(t *testing.T) { - a := assert.New(t) - store := NewMemoStore() - type testStruct struct{ v string } - store.Set("test", 1, 0) - store.Set("test2", testStruct{v: "test"}, 0) - err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")) - a.Error(err) -} - -func TestMemoStore_PersistAndRestore(t *testing.T) { - a := assert.New(t) - store := NewMemoStore() - store.Set("test", 1, 0) - // already expired - store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1}) - // expired after persist - store.Set("test3", 1, 1) - temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed") - - // Persist - err := store.Persist(temp) - a.NoError(err) - a.FileExists(temp) - - time.Sleep(2 * time.Second) - // Restore - store2 := NewMemoStore() - err = store2.Restore(temp) - a.NoError(err) - test, testOk := store2.Get("test") - a.EqualValues(1, test) - a.True(testOk) - test2, test2Ok := store2.Get("test2") - a.Nil(test2) - a.False(test2Ok) - test3, test3Ok := store2.Get("test3") - a.Nil(test3) - a.False(test3Ok) - - a.NoFileExists(temp) -} diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go deleted file mode 100644 index c9f1692..0000000 --- a/pkg/cache/redis_test.go +++ /dev/null @@ -1,324 +0,0 @@ -package cache - -import ( - "errors" - "fmt" - "github.com/gomodule/redigo/redis" - "github.com/rafaeljusto/redigomock" - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -func TestNewRedisStore(t *testing.T) { - asserts := assert.New(t) - - store := NewRedisStore(10, "tcp", "", "", "", "0") - asserts.NotNil(store) - - asserts.Panics(func() { - store.pool.Dial() - }) - - testConn := redigomock.NewConn() - cmd := testConn.Command("PING").Expect("PONG") - err := store.pool.TestOnBorrow(testConn, time.Now()) - if testConn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.NoError(err) -} - -func TestRedisStore_Set(t *testing.T) { - asserts := assert.New(t) - conn := redigomock.NewConn() - pool := &redis.Pool{ - Dial: func() (redis.Conn, error) { return conn, nil }, - MaxIdle: 10, - } - store := &RedisStore{pool: pool} - - // 正常情况 - { - cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectStringSlice("OK") - err := store.Set("test", "test val", -1) - asserts.NoError(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - - // 带有TTL - // 正常情况 - { - cmd := conn.Command("SETEX", "test", 10, redigomock.NewAnyData()).ExpectStringSlice("OK") - err := store.Set("test", "test val", 10) - asserts.NoError(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - - // 序列化出错 - { - value := struct { - Key string - }{ - Key: "123", - } - err := store.Set("test", value, -1) - asserts.Error(err) - } - - // 命令执行失败 - { - conn.Clear() - cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectError(errors.New("error")) - err := store.Set("test", "test val", -1) - asserts.Error(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - // 获取连接失败 - { - store.pool = &redis.Pool{ - Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, - MaxIdle: 10, - } - err := store.Set("test", "123", -1) - asserts.Error(err) - } - -} - -func TestRedisStore_Get(t *testing.T) { - asserts := assert.New(t) - conn := redigomock.NewConn() - pool := &redis.Pool{ - Dial: func() (redis.Conn, error) { return conn, nil }, - MaxIdle: 10, - } - store := &RedisStore{pool: pool} - - // 正常情况 - { - expectVal, _ := serializer("test val") - cmd := conn.Command("GET", "test").Expect(expectVal) - val, ok := store.Get("test") - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.True(ok) - asserts.Equal("test val", val.(string)) - } - - // Key不存在 - { - conn.Clear() - cmd := conn.Command("GET", "test").Expect(nil) - val, ok := store.Get("test") - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.False(ok) - asserts.Nil(val) - } - // 解码错误 - { - conn.Clear() - cmd := conn.Command("GET", "test").Expect([]byte{0x20}) - val, ok := store.Get("test") - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.False(ok) - asserts.Nil(val) - } - // 获取连接失败 - { - store.pool = &redis.Pool{ - Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, - MaxIdle: 10, - } - val, ok := store.Get("test") - asserts.False(ok) - asserts.Nil(val) - } -} - -func TestRedisStore_Gets(t *testing.T) { - asserts := assert.New(t) - conn := redigomock.NewConn() - pool := &redis.Pool{ - Dial: func() (redis.Conn, error) { return conn, nil }, - MaxIdle: 10, - } - store := &RedisStore{pool: pool} - - // 全部命中 - { - conn.Clear() - value1, _ := serializer("1") - value2, _ := serializer("2") - cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice( - value1, value2) - res, missed := store.Gets([]string{"1", "2"}, "test_") - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.Len(missed, 0) - asserts.Len(res, 2) - asserts.Equal("1", res["1"].(string)) - asserts.Equal("2", res["2"].(string)) - } - - // 命中一个 - { - conn.Clear() - value2, _ := serializer("2") - cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice( - nil, value2) - res, missed := store.Gets([]string{"1", "2"}, "test_") - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.Len(missed, 1) - asserts.Len(res, 1) - asserts.Equal("1", missed[0]) - asserts.Equal("2", res["2"].(string)) - } - - // 命令出错 - { - conn.Clear() - cmd := conn.Command("MGET", "test_1", "test_2").ExpectError(errors.New("error")) - res, missed := store.Gets([]string{"1", "2"}, "test_") - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - asserts.Len(missed, 2) - asserts.Len(res, 0) - } - - // 连接出错 - { - conn.Clear() - store.pool = &redis.Pool{ - Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, - MaxIdle: 10, - } - res, missed := store.Gets([]string{"1", "2"}, "test_") - asserts.Len(missed, 2) - asserts.Len(res, 0) - } -} - -func TestRedisStore_Sets(t *testing.T) { - asserts := assert.New(t) - conn := redigomock.NewConn() - pool := &redis.Pool{ - Dial: func() (redis.Conn, error) { return conn, nil }, - MaxIdle: 10, - } - store := &RedisStore{pool: pool} - - // 正常 - { - cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK") - err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_") - asserts.NoError(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - - // 序列化失败 - { - conn.Clear() - value := struct { - Key string - }{ - Key: "123", - } - err := store.Sets(map[string]interface{}{"1": value, "2": "2"}, "test_") - asserts.Error(err) - } - - // 执行失败 - { - cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error")) - err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_") - asserts.Error(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - - // 连接失败 - { - conn.Clear() - store.pool = &redis.Pool{ - Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, - MaxIdle: 10, - } - err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_") - asserts.Error(err) - } -} - -func TestRedisStore_Delete(t *testing.T) { - asserts := assert.New(t) - conn := redigomock.NewConn() - pool := &redis.Pool{ - Dial: func() (redis.Conn, error) { return conn, nil }, - MaxIdle: 10, - } - store := &RedisStore{pool: pool} - - // 正常 - { - cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK") - err := store.Delete([]string{"1", "2", "3", "4"}, "test_") - asserts.NoError(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - - // 命令执行失败 - { - conn.Clear() - cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error")) - err := store.Delete([]string{"1", "2", "3", "4"}, "test_") - asserts.Error(err) - if conn.Stats(cmd) != 1 { - fmt.Println("Command was not used") - return - } - } - - // 连接失败 - { - conn.Clear() - store.pool = &redis.Pool{ - Dial: func() (redis.Conn, error) { return nil, errors.New("error") }, - MaxIdle: 10, - } - err := store.Delete([]string{"1", "2", "3", "4"}, "test_") - asserts.Error(err) - } -} diff --git a/pkg/cluster/controller.go b/pkg/cluster/controller.go index 85fb178..1e8417c 100644 --- a/pkg/cluster/controller.go +++ b/pkg/cluster/controller.go @@ -4,6 +4,9 @@ import ( "bytes" "encoding/gob" "fmt" + "net/url" + "sync" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" @@ -12,8 +15,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/jinzhu/gorm" - "net/url" - "sync" ) var DefaultController Controller diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go deleted file mode 100644 index 42d8362..0000000 --- a/pkg/cluster/controller_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package cluster - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" -) - -func TestInitController(t *testing.T) { - assert.NotPanics(t, func() { - InitController() - }) -} - -func TestSlaveController_HandleHeartBeat(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: make(map[string]MasterInfo), - } - - // first heart beat - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - Node: &model.Node{}, - }) - a.NoError(err) - - _, err = c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "2", - Node: &model.Node{}, - }) - a.NoError(err) - - a.Len(c.masters, 2) - } - - // second heart beat, no fresh - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - SiteURL: "http://127.0.0.1", - Node: &model.Node{}, - }) - a.NoError(err) - a.Len(c.masters, 2) - a.Empty(c.masters["1"].URL) - } - - // second heart beat, fresh - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - IsUpdate: true, - SiteURL: "http://127.0.0.1", - Node: &model.Node{}, - }) - a.NoError(err) - a.Len(c.masters, 2) - a.Equal("http://127.0.0.1", c.masters["1"].URL.String()) - } - - // second heart beat, fresh, url illegal - { - _, err := c.HandleHeartBeat(&serializer.NodePingReq{ - SiteID: "1", - IsUpdate: true, - SiteURL: string([]byte{0x7f}), - Node: &model.Node{}, - }) - a.Error(err) - a.Len(c.masters, 2) - a.Equal("http://127.0.0.1", c.masters["1"].URL.String()) - } -} - -type nodeMock struct { - testMock.Mock -} - -func (n nodeMock) Init(node *model.Node) { - n.Called(node) -} - -func (n nodeMock) IsFeatureEnabled(feature string) bool { - args := n.Called(feature) - return args.Bool(0) -} - -func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) { - n.Called(callback) -} - -func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) { - args := n.Called(req) - return args.Get(0).(*serializer.NodePingResp), args.Error(1) -} - -func (n nodeMock) IsActive() bool { - args := n.Called() - return args.Bool(0) -} - -func (n nodeMock) GetAria2Instance() common.Aria2 { - args := n.Called() - return args.Get(0).(common.Aria2) -} - -func (n nodeMock) ID() uint { - args := n.Called() - return args.Get(0).(uint) -} - -func (n nodeMock) Kill() { - n.Called() -} - -func (n nodeMock) IsMater() bool { - args := n.Called() - return args.Bool(0) -} - -func (n nodeMock) MasterAuthInstance() auth.Auth { - args := n.Called() - return args.Get(0).(auth.Auth) -} - -func (n nodeMock) SlaveAuthInstance() auth.Auth { - args := n.Called() - return args.Get(0).(auth.Auth) -} - -func (n nodeMock) DBModel() *model.Node { - args := n.Called() - return args.Get(0).(*model.Node) -} - -func TestSlaveController_GetAria2Instance(t *testing.T) { - a := assert.New(t) - mockNode := &nodeMock{} - mockNode.On("GetAria2Instance").Return(&common.DummyAria2{}) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Instance: mockNode}, - }, - } - - // node node found - { - res, err := c.GetAria2Instance("2") - a.Nil(res) - a.Equal(ErrMasterNotFound, err) - } - - // node found - { - res, err := c.GetAria2Instance("1") - a.NotNil(res) - a.NoError(err) - mockNode.AssertExpectations(t) - } - -} - -type requestMock struct { - testMock.Mock -} - -func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - return r.Called(method, target, body, opts).Get(0).(*request.Response) -} - -func TestSlaveController_SendNotification(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {}, - }, - } - - // node not exit - { - a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{})) - } - - // gob encode error - { - type randomType struct{} - a.Error(c.SendNotification("1", "", mq.Message{ - Content: randomType{}, - })) - } - - // return none 200 - { - mockRequest := &requestMock{} - mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{StatusCode: http.StatusConflict}, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - a.Error(c.SendNotification("1", "s1", mq.Message{})) - mockRequest.AssertExpectations(t) - } - - // master return error - { - mockRequest := &requestMock{} - mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code) - mockRequest.AssertExpectations(t) - } - - // success - { - mockRequest := &requestMock{} - mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - a.NoError(c.SendNotification("1", "s3", mq.Message{})) - mockRequest.AssertExpectations(t) - } -} - -func TestSlaveController_SubmitTask(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": { - jobTracker: map[string]bool{}, - }, - }, - } - - // node not exit - { - a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil)) - } - - // success - { - submitted := false - a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) { - submitted = true - })) - a.True(submitted) - } - - // job already submitted - { - submitted := false - a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) { - submitted = true - })) - a.False(submitted) - } -} - -func TestSlaveController_GetMasterInfo(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {}, - }, - } - - // node not exit - { - res, err := c.GetMasterInfo("2") - a.Equal(ErrMasterNotFound, err) - a.Nil(res) - } - - // success - { - res, err := c.GetMasterInfo("1") - a.NoError(err) - a.NotNil(res) - } -} - -func TestSlaveController_GetOneDriveToken(t *testing.T) { - a := assert.New(t) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {}, - }, - } - - // node not exit - { - res, err := c.GetPolicyOauthToken("2", 1) - a.Equal(ErrMasterNotFound, err) - a.Empty(res) - } - - // return none 200 - { - mockRequest := &requestMock{} - mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{StatusCode: http.StatusConflict}, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - res, err := c.GetPolicyOauthToken("1", 1) - a.Error(err) - a.Empty(res) - mockRequest.AssertExpectations(t) - } - - // master return error - { - mockRequest := &requestMock{} - mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - res, err := c.GetPolicyOauthToken("1", 1) - a.Equal(1, err.(serializer.AppError).Code) - a.Empty(res) - mockRequest.AssertExpectations(t) - } - - // success - { - mockRequest := &requestMock{} - mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")), - }, - }) - c := &slaveController{ - masters: map[string]MasterInfo{ - "1": {Client: mockRequest}, - }, - } - res, err := c.GetPolicyOauthToken("1", 1) - a.NoError(err) - a.Equal("expected", res) - mockRequest.AssertExpectations(t) - } - -} diff --git a/pkg/cluster/master_test.go b/pkg/cluster/master_test.go deleted file mode 100644 index 7ff07ac..0000000 --- a/pkg/cluster/master_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package cluster - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/stretchr/testify/assert" - "os" - "testing" - "time" -) - -func TestMasterNode_Init(t *testing.T) { - a := assert.New(t) - m := &MasterNode{} - m.Init(&model.Node{Status: model.NodeSuspend}) - a.Equal(model.NodeSuspend, m.DBModel().Status) - m.Init(&model.Node{Aria2Enabled: true}) -} - -func TestMasterNode_DummyMethods(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - } - - m.Model.ID = 5 - a.Equal(m.Model.ID, m.ID()) - - res, err := m.Ping(&serializer.NodePingReq{}) - a.NoError(err) - a.NotNil(res) - - a.True(m.IsActive()) - a.True(m.IsMater()) - - m.SubscribeStatusChange(func(isActive bool, id uint) {}) -} - -func TestMasterNode_IsFeatureEnabled(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - } - - a.False(m.IsFeatureEnabled("aria2")) - a.False(m.IsFeatureEnabled("random")) - m.Model.Aria2Enabled = true - a.True(m.IsFeatureEnabled("aria2")) -} - -func TestMasterNode_AuthInstance(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - } - - a.NotNil(m.MasterAuthInstance()) - a.NotNil(m.SlaveAuthInstance()) -} - -func TestMasterNode_Kill(t *testing.T) { - m := &MasterNode{ - Model: &model.Node{}, - } - - m.Kill() - - caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) - m.aria2RPC.Caller = caller - m.Kill() -} - -func TestMasterNode_GetAria2Instance(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{}, - aria2RPC: rpcService{}, - } - - m.aria2RPC.parent = m - - a.NotNil(m.GetAria2Instance()) - m.Model.Aria2Enabled = true - a.NotNil(m.GetAria2Instance()) - m.aria2RPC.Initialized = true - a.NotNil(m.GetAria2Instance()) -} - -func TestRpcService_Init(t *testing.T) { - a := assert.New(t) - m := &MasterNode{ - Model: &model.Node{ - Aria2OptionsSerialized: model.Aria2Option{ - Options: "{", - }, - }, - aria2RPC: rpcService{}, - } - m.aria2RPC.parent = m - - // failed to decode address - { - m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f}) - a.Error(m.aria2RPC.Init()) - } - - // failed to decode options - { - m.Model.Aria2OptionsSerialized.Server = "" - a.Error(m.aria2RPC.Init()) - } - - // failed to initialized - { - m.Model.Aria2OptionsSerialized.Server = "" - m.Model.Aria2OptionsSerialized.Options = "{}" - caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) - m.aria2RPC.Caller = caller - a.Error(m.aria2RPC.Init()) - a.False(m.aria2RPC.Initialized) - } -} - -func getTestRPCNode() *MasterNode { - m := &MasterNode{ - Model: &model.Node{ - Aria2OptionsSerialized: model.Aria2Option{}, - }, - aria2RPC: rpcService{ - options: &clientOptions{ - Options: map[string]interface{}{"1": "1"}, - }, - }, - } - m.aria2RPC.parent = m - caller, _ := rpc.New(context.Background(), "http://", "", 0, nil) - m.aria2RPC.Caller = caller - return m -} - -func TestRpcService_CreateTask(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"}) - a.Error(err) - a.Empty(res) -} - -func TestRpcService_Status(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - res, err := m.aria2RPC.Status(&model.Download{}) - a.Error(err) - a.Empty(res) -} - -func TestRpcService_Cancel(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - a.Error(m.aria2RPC.Cancel(&model.Download{})) -} - -func TestRpcService_Select(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - - a.NotNil(m.aria2RPC.GetConfig()) - a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3})) -} - -func TestRpcService_DeleteTempFile(t *testing.T) { - a := assert.New(t) - m := getTestRPCNode() - fdName := "TestRpcService_DeleteTempFile" - a.NoError(os.Mkdir(fdName, 0644)) - - a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName})) - time.Sleep(500 * time.Millisecond) - a.False(util.Exists(fdName)) -} diff --git a/pkg/cluster/node_test.go b/pkg/cluster/node_test.go deleted file mode 100644 index d817425..0000000 --- a/pkg/cluster/node_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package cluster - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewNodeFromDBModel(t *testing.T) { - a := assert.New(t) - a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{ - Type: model.SlaveNodeType, - })) - a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{ - Type: model.MasterNodeType, - })) -} diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index d6704b6..a70186f 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -4,6 +4,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/balancer" "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/samber/lo" "sync" ) @@ -15,7 +16,7 @@ var featureGroup = []string{"aria2"} // Pool 节点池 type Pool interface { // Returns active node selected by given feature and load balancer - BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) + BalanceNodeByFeature(feature string, lb balancer.Balancer, available []uint, prefer uint) (error, Node) // Returns node by ID GetNodeByID(id uint) Node @@ -174,11 +175,33 @@ func (pool *NodePool) Delete(id uint) { } // BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点 -func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) { +func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer, + available []uint, prefer uint) (error, Node) { pool.lock.RLock() defer pool.lock.RUnlock() if nodes, ok := pool.featureMap[feature]; ok { - err, res := lb.NextPeer(nodes) + // Find nodes that are allowed to be used in user group + availableNodes := nodes + if len(available) > 0 { + idHash := make(map[uint]struct{}, len(available)) + for _, id := range available { + idHash[id] = struct{}{} + } + + availableNodes = lo.Filter[Node](nodes, func(node Node, index int) bool { + _, exist := idHash[node.ID()] + return exist + }) + } + + // Return preferred node if exists + if preferredNode, found := lo.Find[Node](availableNodes, func(node Node) bool { + return node.ID() == prefer + }); found { + return nil, preferredNode + } + + err, res := lb.NextPeer(availableNodes) if err == nil { return nil, res.(Node) } diff --git a/pkg/cluster/pool_test.go b/pkg/cluster/pool_test.go deleted file mode 100644 index dde3455..0000000 --- a/pkg/cluster/pool_test.go +++ /dev/null @@ -1,161 +0,0 @@ -package cluster - -import ( - "database/sql" - "errors" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/balancer" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "testing" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestInitFailed(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - Init() - a.NoError(mock.ExpectationsWereMet()) -} - -func TestInitSuccess(t *testing.T) { - a := assert.New(t) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType)) - Init() - a.NoError(mock.ExpectationsWereMet()) -} - -func TestNodePool_GetNodeByID(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - mockNode := &nodeMock{} - - // inactive - { - p.inactive[1] = mockNode - a.Equal(mockNode, p.GetNodeByID(1)) - } - - // active - { - delete(p.inactive, 1) - p.active[1] = mockNode - a.Equal(mockNode, p.GetNodeByID(1)) - } -} - -func TestNodePool_NodeStatusChange(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - n := &MasterNode{Model: &model.Node{}} - p.Init() - p.inactive[1] = n - - p.nodeStatusChange(true, 1) - a.Len(p.inactive, 0) - a.Equal(n, p.active[1]) - - p.nodeStatusChange(false, 1) - a.Len(p.active, 0) - a.Equal(n, p.inactive[1]) - - p.nodeStatusChange(false, 1) - a.Len(p.active, 0) - a.Equal(n, p.inactive[1]) -} - -func TestNodePool_Add(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - - // new node - { - p.Add(&model.Node{}) - a.Len(p.active, 1) - } - - // old node - { - p.inactive[0] = p.active[0] - delete(p.active, 0) - p.Add(&model.Node{}) - a.Len(p.active, 0) - a.Len(p.inactive, 1) - } -} - -func TestNodePool_Delete(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - - // active - { - mockNode := &nodeMock{} - mockNode.On("Kill") - p.active[0] = mockNode - p.Delete(0) - a.Len(p.active, 0) - a.Len(p.inactive, 0) - mockNode.AssertExpectations(t) - } - - p.Init() - - // inactive - { - mockNode := &nodeMock{} - mockNode.On("Kill") - p.inactive[0] = mockNode - p.Delete(0) - a.Len(p.active, 0) - a.Len(p.inactive, 0) - mockNode.AssertExpectations(t) - } -} - -func TestNodePool_BalanceNodeByFeature(t *testing.T) { - a := assert.New(t) - p := &NodePool{} - p.Init() - - // success - { - p.featureMap["test"] = []Node{&MasterNode{}} - err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin")) - a.NoError(err) - a.Equal(p.featureMap["test"][0], res) - } - - // NoNodes - { - p.featureMap["test"] = []Node{} - err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin")) - a.Error(err) - a.Nil(res) - } - - // No match feature - { - err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin")) - a.Error(err) - a.Nil(res) - } -} diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go deleted file mode 100644 index 1b1510f..0000000 --- a/pkg/cluster/slave_test.go +++ /dev/null @@ -1,559 +0,0 @@ -package cluster - -import ( - "bytes" - "encoding/json" - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" -) - -func TestSlaveNode_InitAndKill(t *testing.T) { - a := assert.New(t) - n := &SlaveNode{ - callback: func(b bool, u uint) { - - }, - } - - a.NotPanics(func() { - n.Init(&model.Node{}) - time.Sleep(time.Millisecond * 500) - n.Init(&model.Node{}) - n.Kill() - }) -} - -func TestSlaveNode_DummyMethods(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - m.Model.ID = 5 - a.Equal(m.Model.ID, m.ID()) - a.Equal(m.Model.ID, m.DBModel().ID) - - a.False(m.IsActive()) - a.False(m.IsMater()) - - m.SubscribeStatusChange(func(isActive bool, id uint) {}) -} - -func TestSlaveNode_IsFeatureEnabled(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - a.False(m.IsFeatureEnabled("aria2")) - a.False(m.IsFeatureEnabled("random")) - m.Model.Aria2Enabled = true - a.True(m.IsFeatureEnabled("aria2")) -} - -func TestSlaveNode_Ping(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - // master return error code - { - mockRequest := &requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.Ping(&serializer.NodePingReq{}) - a.Error(err) - a.Nil(res) - a.Equal(1, err.(serializer.AppError).Code) - } - - // return unexpected json - { - mockRequest := &requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.Ping(&serializer.NodePingReq{}) - a.Error(err) - a.Nil(res) - } - - // return success - { - mockRequest := &requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.Ping(&serializer.NodePingReq{}) - a.NoError(err) - a.NotNil(res) - } -} - -func TestSlaveNode_GetAria2Instance(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - a.NotNil(m.GetAria2Instance()) - m.Model.Aria2Enabled = true - a.NotNil(m.GetAria2Instance()) - a.NotNil(m.GetAria2Instance()) -} - -func TestSlaveNode_StartPingLoop(t *testing.T) { - callbackCount := 0 - finishedChan := make(chan struct{}) - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m := &SlaveNode{ - Active: true, - Model: &model.Node{}, - callback: func(b bool, u uint) { - callbackCount++ - if callbackCount == 2 { - close(finishedChan) - } - if callbackCount == 1 { - mockRequest.AssertExpectations(t) - mockRequest = requestMock{} - mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")), - }, - }) - } - }, - } - cache.Set("setting_slave_ping_interval", "0", 0) - cache.Set("setting_slave_recover_interval", "0", 0) - cache.Set("setting_slave_node_retry", "1", 0) - - m.caller.Client = &mockRequest - go func() { - select { - case <-finishedChan: - m.Kill() - } - }() - m.StartPingLoop() - mockRequest.AssertExpectations(t) -} - -func TestSlaveNode_AuthInstance(t *testing.T) { - a := assert.New(t) - m := &SlaveNode{ - Model: &model.Node{}, - } - - a.NotNil(m.MasterAuthInstance()) - a.NotNil(m.SlaveAuthInstance()) -} - -func TestSlaveNode_ChangeStatus(t *testing.T) { - a := assert.New(t) - isActive := false - m := &SlaveNode{ - Model: &model.Node{}, - callback: func(b bool, u uint) { - isActive = b - }, - } - - a.NotPanics(func() { - m.changeStatus(false) - }) - m.changeStatus(true) - a.True(isActive) -} - -func getTestRPCNodeSlave() *SlaveNode { - m := &SlaveNode{ - Model: &model.Node{}, - } - m.caller.parent = m - return m -} - -func TestSlaveCaller_CreateTask(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.CreateTask(&model.Download{}, nil) - a.Empty(res) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.CreateTask(&model.Download{}, nil) - a.Empty(res) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.CreateTask(&model.Download{}, nil) - a.Equal("res", res) - a.NoError(err) - } -} - -func TestSlaveCaller_Status(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.Status(&model.Download{}) - a.Empty(res.Status) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.Status(&model.Download{}) - a.Empty(res.Status) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")), - }, - }) - m.caller.Client = mockRequest - res, err := m.caller.Status(&model.Download{}) - a.Empty(res.Status) - a.NoError(err) - } -} - -func TestSlaveCaller_Cancel(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - err := m.caller.Cancel(&model.Download{}) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Cancel(&model.Download{}) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Cancel(&model.Download{}) - a.NoError(err) - } -} - -func TestSlaveCaller_Select(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - m.caller.Init() - m.caller.GetConfig() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - err := m.caller.Select(&model.Download{}, nil) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Select(&model.Download{}, nil) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.Select(&model.Download{}, nil) - a.NoError(err) - } -} - -func TestSlaveCaller_DeleteTempFile(t *testing.T) { - a := assert.New(t) - m := getTestRPCNodeSlave() - m.caller.Init() - m.caller.GetConfig() - - // master return 404 - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 404, - }, - }) - m.caller.Client = mockRequest - err := m.caller.DeleteTempFile(&model.Download{}) - a.Error(err) - } - - // master return error - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.DeleteTempFile(&model.Download{}) - a.Error(err) - } - - // master return success - { - mockRequest := requestMock{} - mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")), - }, - }) - m.caller.Client = mockRequest - err := m.caller.DeleteTempFile(&model.Download{}) - a.NoError(err) - } -} - -func TestRemoteCallback(t *testing.T) { - asserts := assert.New(t) - - // 回调成功 - { - clientMock := requestmock.RequestMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 0}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.NoError(resp) - clientMock.AssertExpectations(t) - } - - // 服务端返回业务错误 - { - clientMock := requestmock.RequestMock{} - mockResp, _ := json.Marshal(serializer.Response{Code: 401}) - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewReader(mockResp)), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.EqualValues(401, resp.(serializer.AppError).Code) - clientMock.AssertExpectations(t) - } - - // 无法解析回调响应 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // HTTP状态码非200 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader("mockResp")), - }, - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } - - // 无法发起回调 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/test/url", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - request.GeneralClient = clientMock - resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{}) - asserts.Error(resp) - clientMock.AssertExpectations(t) - } -} diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index b0a4ea4..942294b 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -53,7 +53,7 @@ type slave struct { type redis struct { Network string Server string - User string + User string Password string DB string } diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go deleted file mode 100644 index 6d186ed..0000000 --- a/pkg/conf/conf_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package conf - -import ( - "io/ioutil" - "os" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/stretchr/testify/assert" -) - -// 测试Init日志路径错误 -func TestInitPanic(t *testing.T) { - asserts := assert.New(t) - - // 日志路径不存在时 - asserts.NotPanics(func() { - Init("not/exist/path/conf.ini") - }) - - asserts.True(util.Exists("not/exist/path/conf.ini")) - -} - -// TestInitDelimiterNotFound 日志路径存在但 Key 格式错误时 -func TestInitDelimiterNotFound(t *testing.T) { - asserts := assert.New(t) - testCase := `[Database] -Type = mysql -User = root -Password233root -Host = 127.0.0.1:3306 -Name = v3 -TablePrefix = v3_` - err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) - defer func() { err = os.Remove("testConf.ini") }() - if err != nil { - panic(err) - } - asserts.Panics(func() { - Init("testConf.ini") - }) -} - -// TestInitNoPanic 日志路径存在且合法时 -func TestInitNoPanic(t *testing.T) { - asserts := assert.New(t) - testCase := ` -[System] -Listen = 3000 -HashIDSalt = 1 - -[Database] -Type = mysql -User = root -Password = root -Host = 127.0.0.1:3306 -Name = v3 -TablePrefix = v3_ - -[OptionOverwrite] -key=value -` - err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) - defer func() { err = os.Remove("testConf.ini") }() - if err != nil { - panic(err) - } - asserts.NotPanics(func() { - Init("testConf.ini") - }) - asserts.Equal(OptionOverwrite["key"], "value") -} - -func TestMapSection(t *testing.T) { - asserts := assert.New(t) - - //正常情况 - testCase := ` -[System] -Listen = 3000 -HashIDSalt = 1 - -[Database] -Type = mysql -User = root -Password:root -Host = 127.0.0.1:3306 -Name = v3 -TablePrefix = v3_` - err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) - defer func() { err = os.Remove("testConf.ini") }() - if err != nil { - panic(err) - } - Init("testConf.ini") - err = mapSection("Database", DatabaseConfig) - asserts.NoError(err) - -} diff --git a/pkg/conf/version.go b/pkg/conf/version.go index 6720e8c..fa4bbf3 100644 --- a/pkg/conf/version.go +++ b/pkg/conf/version.go @@ -1,16 +1,22 @@ package conf +// plusVersion 增强版版本号 +const plusVersion = "+1.1" + // BackendVersion 当前后端版本号 -var BackendVersion = "3.8.3" +const BackendVersion = "3.8.3" + plusVersion + +// KeyVersion 授权版本号 +const KeyVersion = "3.3.1" // RequiredDBVersion 与当前版本匹配的数据库版本 -var RequiredDBVersion = "3.8.1" +const RequiredDBVersion = "3.8.1+1.0-plus" // RequiredStaticVersion 与当前版本匹配的静态资源版本 -var RequiredStaticVersion = "3.8.3" +const RequiredStaticVersion = "3.8.3" + plusVersion -// IsPro 是否为Pro版本 -var IsPro = "false" +// IsPlus 是否为Plus版本 +const IsPlus = "true" // LastCommit 最后commit id -var LastCommit = "a11f819" +const LastCommit = "88409cc" diff --git a/pkg/crontab/init.go b/pkg/crontab/init.go index 5971c2c..3583d31 100644 --- a/pkg/crontab/init.go +++ b/pkg/crontab/init.go @@ -23,6 +23,8 @@ func Init() { // 读取cron日程设置 options := model.GetSettingByNames( "cron_garbage_collect", + "cron_notify_user", + "cron_ban_user", "cron_recycle_upload_session", ) Cron := cron.New() @@ -31,6 +33,10 @@ func Init() { switch k { case "cron_garbage_collect": handler = garbageCollect + case "cron_notify_user": + handler = notifyExpiredVAS + case "cron_ban_user": + handler = banOverusedUser case "cron_recycle_upload_session": handler = uploadSessionCollect default: diff --git a/pkg/crontab/vas.go b/pkg/crontab/vas.go new file mode 100755 index 0000000..7ce6ae9 --- /dev/null +++ b/pkg/crontab/vas.go @@ -0,0 +1,83 @@ +package crontab + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/email" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +func notifyExpiredVAS() { + checkStoragePack() + checkUserGroup() + util.Log().Info("Crontab job \"cron_notify_user\" complete.") +} + +// banOverusedUser 封禁超出宽容期的用户 +func banOverusedUser() { + users := model.GetTolerantExpiredUser() + for _, user := range users { + + // 清除最后通知日期标记 + user.ClearNotified() + + // 检查容量是否超额 + if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() { + // 封禁用户 + user.SetStatus(model.OveruseBaned) + } + } +} + +// checkUserGroup 检查已过期用户组 +func checkUserGroup() { + users := model.GetGroupExpiredUsers() + for _, user := range users { + + // 将用户回退到初始用户组 + user.GroupFallback() + + // 重新加载用户 + user, _ = model.GetUserByID(user.ID) + + // 检查容量是否超额 + if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() { + // 如果超额,则通知用户 + sendNotification(&user, "用户组过期") + // 更新最后通知日期 + user.Notified() + } + } +} + +// checkStoragePack 检查已过期的容量包 +func checkStoragePack() { + packs := model.GetExpiredStoragePack() + for _, pack := range packs { + // 删除过期的容量包 + pack.Delete() + + //找到所属用户 + user, err := model.GetUserByID(pack.UserID) + if err != nil { + util.Log().Warning("Crontab job failed to get user info of [UID=%d]: %s", pack.UserID, err) + continue + } + + // 检查容量是否超额 + if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() { + // 如果超额,则通知用户 + sendNotification(&user, "容量包过期") + + // 更新最后通知日期 + user.Notified() + + } + } +} + +func sendNotification(user *model.User, reason string) { + title, body := email.NewOveruseNotification(user.Nick, reason) + if err := email.Send(user.Email, title, body); err != nil { + util.Log().Warning("Failed to send notification email: %s", err) + } +} diff --git a/pkg/email/smtp.go b/pkg/email/smtp.go index c92cce7..3845f44 100644 --- a/pkg/email/smtp.go +++ b/pkg/email/smtp.go @@ -5,6 +5,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/go-mail/mail" + "github.com/google/uuid" ) // SMTP SMTP协议发送邮件 @@ -50,6 +51,7 @@ func (client *SMTP) Send(to, title, body string) error { m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name) m.SetHeader("To", to) m.SetHeader("Subject", title) + m.SetHeader("Message-ID", util.StrConcat(`"<`, uuid.NewString(), `@`, `cloudreveplus`, `>"`)) m.SetBody("text/html", body) client.ch <- m return nil diff --git a/pkg/email/template.go b/pkg/email/template.go index cb9cb3a..213e5e3 100644 --- a/pkg/email/template.go +++ b/pkg/email/template.go @@ -7,6 +7,20 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/util" ) +// NewOveruseNotification 新建超额提醒邮件 +func NewOveruseNotification(userName, reason string) (string, string) { + options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "over_used_template") + replace := map[string]string{ + "{siteTitle}": options["siteName"], + "{userName}": userName, + "{notifyReason}": reason, + "{siteUrl}": options["siteURL"], + "{siteSecTitle}": options["siteTitle"], + } + return fmt.Sprintf("【%s】空间容量超额提醒", options["siteName"]), + util.Replace(replace, options["over_used_template"]) +} + // NewActivationEmail 新建激活邮件 func NewActivationEmail(userName, activateURL string) (string, string) { options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template") diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 78fc45f..cd3aa83 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -239,13 +239,6 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) reader = zipFile } - // 重设存储策略 - fs.Policy = &fs.User.Policy - err = fs.DispatchHandler() - if err != nil { - return err - } - var wg sync.WaitGroup parallel := model.GetIntSetting("max_parallel_transfer", 4) worker := make(chan int, parallel) diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go deleted file mode 100644 index 07f5087..0000000 --- a/pkg/filesystem/archive_test.go +++ /dev/null @@ -1,256 +0,0 @@ -package filesystem - -import ( - "bytes" - "context" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - testMock "github.com/stretchr/testify/mock" - "io" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_Compress(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - - // 成功 - { - // 查找压缩父目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent")) - // 查找顶级待压缩文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "name", "source_name", "policy_id"}). - AddRow(1, "1.txt", "tests/file1.txt", 1), - ) - asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - // 查找父目录子文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"})) - // 查找子目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "sub")) - // 查找子目录子文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(2). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}). - AddRow(2, "2.txt", "tests/file2.txt", 1), - ) - // 查找上传策略 - asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1)) - w := &bytes.Buffer{} - - err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true) - asserts.NoError(err) - asserts.NotEmpty(w.Len()) - } - - // 上下文取消 - { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - // 查找压缩父目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent")) - // 查找顶级待压缩文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "name", "source_name", "policy_id"}). - AddRow(1, "1.txt", "tests/file1.txt", 1), - ) - asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - - w := &bytes.Buffer{} - err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true) - asserts.Error(err) - asserts.NotEmpty(w.Len()) - } - - // 限制父目录 - { - ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{ - Model: gorm.Model{ID: 3}, - }) - // 查找压缩父目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(1, "parent", 3)) - // 查找顶级待压缩文件 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 1). - WillReturnRows( - sqlmock.NewRows( - []string{"id", "name", "source_name", "policy_id"}). - AddRow(1, "1.txt", "tests/file1.txt", 1), - ) - asserts.NoError(cache.Set("setting_temp_path", "tests", -1)) - - w := &bytes.Buffer{} - err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true) - asserts.Error(err) - asserts.Equal(ErrObjectNotExist, err) - asserts.Empty(w.Len()) - } - -} - -type MockNopRSC string - -func (m MockNopRSC) Read(b []byte) (int, error) { - return 0, errors.New("read error") -} - -func (m MockNopRSC) Seek(n int64, offset int) (int64, error) { - return 0, errors.New("read error") -} - -func (m MockNopRSC) Close() error { - return errors.New("read error") -} - -type MockRSC struct { - rs io.ReadSeeker -} - -func (m MockRSC) Read(b []byte) (int, error) { - return m.rs.Read(b) -} - -func (m MockRSC) Seek(n int64, offset int) (int64, error) { - return m.rs.Seek(n, offset) -} - -func (m MockRSC) Close() error { - return nil -} - -var basepath string - -func init() { - _, currentFile, _, _ := runtime.Caller(0) - basepath = filepath.Dir(currentFile) -} - -func Path(rel string) string { - return filepath.Join(basepath, rel) -} - -func TestFileSystem_Decompress(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - os.RemoveAll(util.RelativePath("tests/decompress")) - - // 压缩文件不存在 - { - // 查找根目录 - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "/")) - // 查找压缩文件,未找到 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - - // 无法下载压缩文件 - { - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, errors.New("error")) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.EqualError(err, "error") - } - - // 无法创建临时压缩文件 - { - cache.Set("setting_temp_path", "/tests:", 0) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, nil) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - - // 无法写入压缩文件 - { - cache.Set("setting_temp_path", "tests", 0) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockNopRSC("1"), nil) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Contains(err.Error(), "read error") - } - - // 无法重设上传策略 - { - cache.Set("setting_temp_path", "tests", 0) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{rs: strings.NewReader("read")}, nil) - fs.Handler = testHandler - err := fs.Decompress(ctx, "/1.zip", "/", "") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.True(util.IsEmpty(util.RelativePath("tests/decompress"))) - } - - // 无法上传,容量不足 - { - cache.Set("setting_max_parallel_transfer", "1", 0) - zipFile, _ := os.Open(Path("tests/test.zip")) - fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}} - fs.FileTarget[0].Policy.ID = 1 - fs.User.Policy.Type = "mock" - testHandler := new(FileHeaderMock) - testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil) - fs.Handler = testHandler - - fs.Decompress(ctx, "/1.zip", "/", "") - - zipFile.Close() - - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - } -} diff --git a/pkg/filesystem/chunk/backoff/backoff_test.go b/pkg/filesystem/chunk/backoff/backoff_test.go deleted file mode 100644 index 0fda534..0000000 --- a/pkg/filesystem/chunk/backoff/backoff_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package backoff - -import ( - "errors" - "github.com/stretchr/testify/assert" - "net/http" - "testing" - "time" -) - -func TestConstantBackoff_Next(t *testing.T) { - a := assert.New(t) - - // General error - { - err := errors.New("error") - b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - b.Reset() - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - } - - // Retryable error - { - err := &RetryableError{RetryAfter: time.Duration(1)} - b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - b.Reset() - a.True(b.Next(err)) - a.True(b.Next(err)) - a.True(b.Next(err)) - a.False(b.Next(err)) - } - -} - -func TestNewRetryableErrorFromHeader(t *testing.T) { - a := assert.New(t) - // no retry-after header - { - err := NewRetryableErrorFromHeader(nil, http.Header{}) - a.Empty(err.RetryAfter) - } - - // with retry-after header - { - header := http.Header{} - header.Add("retry-after", "120") - err := NewRetryableErrorFromHeader(nil, header) - a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter) - } -} diff --git a/pkg/filesystem/chunk/chunk_test.go b/pkg/filesystem/chunk/chunk_test.go deleted file mode 100644 index 4bdcd06..0000000 --- a/pkg/filesystem/chunk/chunk_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package chunk - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/stretchr/testify/assert" - "io" - "os" - "strings" - "testing" -) - -func TestNewChunkGroup(t *testing.T) { - a := assert.New(t) - - testCases := []struct { - fileSize uint64 - chunkSize uint64 - expectedInnerChunkSize uint64 - expectedChunkNum uint64 - expectedInfo [][2]int //Start, Index,Length - }{ - {10, 0, 10, 1, [][2]int{{0, 10}}}, - {0, 0, 0, 1, [][2]int{{0, 0}}}, - {0, 10, 10, 1, [][2]int{{0, 0}}}, - {50, 10, 10, 5, [][2]int{ - {0, 10}, - {10, 10}, - {20, 10}, - {30, 10}, - {40, 10}, - }}, - {50, 50, 50, 1, [][2]int{ - {0, 50}, - }}, - - {50, 15, 15, 4, [][2]int{ - {0, 15}, - {15, 15}, - {30, 15}, - {45, 5}, - }}, - } - - for index, testCase := range testCases { - file := &fsctx.FileStream{Size: testCase.fileSize} - chunkGroup := NewChunkGroup(file, testCase.chunkSize, &backoff.ConstantBackoff{}, true) - a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(), - "TestCase:%d,ChunkNum()", index) - a.EqualValues(testCase.expectedInnerChunkSize, chunkGroup.chunkSize, - "TestCase:%d,InnerChunkSize()", index) - a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(), - "TestCase:%d,len(Chunks)", index) - a.EqualValues(testCase.fileSize, chunkGroup.Total()) - - for cIndex, info := range testCase.expectedInfo { - a.True(chunkGroup.Next()) - a.EqualValues(info[1], chunkGroup.Length(), - "TestCase:%d,Chunks[%d].Length()", index, cIndex) - a.EqualValues(info[0], chunkGroup.Start(), - "TestCase:%d,Chunks[%d].Start()", index, cIndex) - - a.Equal(cIndex == len(testCase.expectedInfo)-1, chunkGroup.IsLast(), - "TestCase:%d,Chunks[%d].IsLast()", index, cIndex) - - a.NotEmpty(chunkGroup.RangeHeader()) - } - a.False(chunkGroup.Next()) - } -} - -func TestChunkGroup_TempAvailablet(t *testing.T) { - a := assert.New(t) - - file := &fsctx.FileStream{Size: 1} - c := NewChunkGroup(file, 0, &backoff.ConstantBackoff{}, true) - a.False(c.TempAvailable()) - - f, err := os.CreateTemp("", "TestChunkGroup_TempAvailablet.*") - defer func() { - f.Close() - os.Remove(f.Name()) - }() - a.NoError(err) - c.bufferTemp = f - - a.False(c.TempAvailable()) - f.Write([]byte("1")) - a.True(c.TempAvailable()) - -} - -func TestChunkGroup_Process(t *testing.T) { - a := assert.New(t) - file := &fsctx.FileStream{Size: 10} - - // success - { - file.File = io.NopCloser(strings.NewReader("1234567890")) - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{}, true) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("67890", string(res)) - return nil - })) - a.False(c.Next()) - a.Equal(2, count) - } - - // retry, read from buffer file - { - file.File = io.NopCloser(strings.NewReader("1234567890")) - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, true) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("67890", string(res)) - if count == 2 { - return errors.New("error") - } - return nil - })) - a.False(c.Next()) - a.Equal(3, count) - } - - // retry, read from seeker - { - f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") - f.Write([]byte("1234567890")) - f.Seek(0, 0) - defer func() { - f.Close() - os.Remove(f.Name()) - }() - file.File = f - file.Seeker = f - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("67890", string(res)) - if count == 2 { - return errors.New("error") - } - return nil - })) - a.False(c.Next()) - a.Equal(3, count) - } - - // retry, seek error - { - f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") - f.Write([]byte("1234567890")) - f.Seek(0, 0) - defer func() { - f.Close() - os.Remove(f.Name()) - }() - file.File = f - file.Seeker = f - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - f.Close() - a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - if count == 2 { - return errors.New("error") - } - return nil - })) - a.False(c.Next()) - a.Equal(2, count) - } - - // retry, finally error - { - f, _ := os.CreateTemp("", "TestChunkGroup_Process.*") - f.Write([]byte("1234567890")) - f.Seek(0, 0) - defer func() { - f.Close() - os.Remove(f.Name()) - }() - file.File = f - file.Seeker = f - c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false) - count := 0 - a.True(c.Next()) - a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - res, err := io.ReadAll(chunk) - a.NoError(err) - a.EqualValues("12345", string(res)) - return nil - })) - a.True(c.Next()) - a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error { - count++ - return errors.New("error") - })) - a.False(c.Next()) - a.Equal(4, count) - } -} diff --git a/pkg/filesystem/driver/handler.go b/pkg/filesystem/driver/handler.go index f232781..f145281 100644 --- a/pkg/filesystem/driver/handler.go +++ b/pkg/filesystem/driver/handler.go @@ -3,6 +3,7 @@ package driver import ( "context" "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" diff --git a/pkg/filesystem/driver/local/handler_test.go b/pkg/filesystem/driver/local/handler_test.go deleted file mode 100644 index b73b564..0000000 --- a/pkg/filesystem/driver/local/handler_test.go +++ /dev/null @@ -1,338 +0,0 @@ -package local - -import ( - "context" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - "io" - "os" - "strings" - "testing" -) - -func TestHandler_Put(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - - defer func() { - os.Remove(util.RelativePath("TestHandler_Put.txt")) - os.Remove(util.RelativePath("inner/TestHandler_Put.txt")) - }() - - testCases := []struct { - file fsctx.FileHeader - errContains string - }{ - {&fsctx.FileStream{ - SavePath: "TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("")), - }, ""}, - {&fsctx.FileStream{ - SavePath: "TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("")), - }, "file with the same name existed or unavailable"}, - {&fsctx.FileStream{ - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("")), - }, ""}, - {&fsctx.FileStream{ - Mode: fsctx.Append | fsctx.Overwrite, - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("123")), - }, ""}, - {&fsctx.FileStream{ - AppendStart: 10, - Mode: fsctx.Append | fsctx.Overwrite, - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("123")), - }, "size of unfinished uploaded chunks is not as expected"}, - {&fsctx.FileStream{ - Mode: fsctx.Append | fsctx.Overwrite, - SavePath: "inner/TestHandler_Put.txt", - File: io.NopCloser(strings.NewReader("123")), - }, ""}, - } - - for _, testCase := range testCases { - err := handler.Put(context.Background(), testCase.file) - if testCase.errContains != "" { - asserts.Error(err) - asserts.Contains(err.Error(), testCase.errContains) - } else { - asserts.NoError(err) - asserts.True(util.Exists(util.RelativePath(testCase.file.Info().SavePath))) - } - } -} - -func TestDriver_TruncateFailed(t *testing.T) { - a := assert.New(t) - h := Driver{} - a.Error(h.Truncate(context.Background(), "TestDriver_TruncateFailed", 0)) -} - -func TestHandler_Delete(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx := context.Background() - filePath := util.RelativePath("TestHandler_Delete.file") - - file, err := os.Create(filePath) - asserts.NoError(err) - _ = file.Close() - list, err := handler.Delete(ctx, []string{"TestHandler_Delete.file"}) - asserts.Equal([]string{}, list) - asserts.NoError(err) - - file, err = os.Create(filePath) - _ = file.Close() - file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0)) - asserts.NoError(err) - list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file", "test.notexist"}) - file.Close() - asserts.Equal([]string{}, list) - asserts.NoError(err) - - list, err = handler.Delete(ctx, []string{"test.notexist"}) - asserts.Equal([]string{}, list) - asserts.NoError(err) - - file, err = os.Create(filePath) - asserts.NoError(err) - list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file"}) - _ = file.Close() - asserts.Equal([]string{}, list) - asserts.NoError(err) -} - -func TestHandler_Get(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // 成功 - file, err := os.Create(util.RelativePath("TestHandler_Get.txt")) - asserts.NoError(err) - _ = file.Close() - - rs, err := handler.Get(ctx, "TestHandler_Get.txt") - asserts.NoError(err) - asserts.NotNil(rs) - - // 文件不存在 - - rs, err = handler.Get(ctx, "TestHandler_Get_notExist.txt") - asserts.Error(err) - asserts.Nil(rs) -} - -func TestHandler_Thumb(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx := context.Background() - file, err := os.Create(util.RelativePath("TestHandler_Thumb._thumb")) - asserts.NoError(err) - file.Close() - - f := &model.File{ - SourceName: "TestHandler_Thumb", - MetadataSerialized: map[string]string{ - model.ThumbStatusMetadataKey: model.ThumbStatusExist, - }, - } - - // 正常 - { - thumb, err := handler.Thumb(ctx, f) - asserts.NoError(err) - asserts.NotNil(thumb.Content) - } - - // file 不存在 - { - f.SourceName = "not_exist" - _, err := handler.Thumb(ctx, f) - asserts.Error(err) - asserts.ErrorIs(err, driver.ErrorThumbNotExist) - } - - // thumb not exist - { - f.MetadataSerialized[model.ThumbStatusMetadataKey] = model.ThumbStatusNotExist - _, err := handler.Thumb(ctx, f) - asserts.Error(err) - asserts.ErrorIs(err, driver.ErrorThumbNotExist) - } -} - -func TestHandler_Source(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{}, - } - ctx := context.Background() - auth.General = auth.HMACAuth{SecretKey: []byte("test")} - - // 成功 - { - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - asserts.Contains(sourceURL, "sign=") - } - - // 下载 - { - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, true, 0) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - asserts.Contains(sourceURL, "sign=") - asserts.Contains(sourceURL, "download") - } - - // 无法获取上下文 - { - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.Error(err) - asserts.Empty(sourceURL) - } - - // 设定了CDN - { - handler.Policy.BaseURL = "https://cqu.edu.cn" - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - asserts.Contains(sourceURL, "sign=") - asserts.Contains(sourceURL, "https://cqu.edu.cn") - } - - // 设定了CDN,解析失败 - { - handler.Policy.BaseURL = string([]byte{0x7f}) - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - sourceURL, err := handler.Source(ctx, "", 0, false, 0) - asserts.Error(err) - asserts.Empty(sourceURL) - } -} - -func TestHandler_GetDownloadURL(t *testing.T) { - asserts := assert.New(t) - handler := Driver{Policy: &model.Policy{}} - ctx := context.Background() - auth.General = auth.HMACAuth{SecretKey: []byte("test")} - - // 成功 - { - file := model.File{ - Model: gorm.Model{ - ID: 1, - }, - Name: "test.jpg", - } - ctx := context.WithValue(ctx, fsctx.FileModelCtx, file) - downloadURL, err := handler.Source(ctx, "", 10, true, 0) - asserts.NoError(err) - asserts.Contains(downloadURL, "sign=") - } - - // 无法获取上下文 - { - downloadURL, err := handler.Source(ctx, "", 10, true, 0) - asserts.Error(err) - asserts.Empty(downloadURL) - } -} - -func TestHandler_Token(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{}, - } - ctx := context.Background() - upSession := &serializer.UploadSession{SavePath: "TestHandler_Token"} - _, err := handler.Token(ctx, 10, upSession, &fsctx.FileStream{}) - asserts.NoError(err) - - file, _ := os.Create("TestHandler_Token") - defer func() { - file.Close() - os.Remove("TestHandler_Token") - }() - - _, err = handler.Token(ctx, 10, upSession, &fsctx.FileStream{}) - asserts.Error(err) - asserts.Contains(err.Error(), "already exist") -} - -func TestDriver_CancelToken(t *testing.T) { - a := assert.New(t) - handler := Driver{} - a.NoError(handler.CancelToken(context.Background(), &serializer.UploadSession{})) -} - -func TestDriver_List(t *testing.T) { - asserts := assert.New(t) - handler := Driver{} - ctx := context.Background() - - // 创建测试目录结构 - for _, path := range []string{ - "test/TestDriver_List/parent.txt", - "test/TestDriver_List/parent_folder2/sub2.txt", - "test/TestDriver_List/parent_folder1/sub_folder/sub1.txt", - "test/TestDriver_List/parent_folder1/sub_folder/sub2.txt", - } { - f, _ := util.CreatNestedFile(util.RelativePath(path)) - f.Close() - } - - // 非递归列出 - { - res, err := handler.List(ctx, "test/TestDriver_List", false) - asserts.NoError(err) - asserts.Len(res, 3) - } - - // 递归列出 - { - res, err := handler.List(ctx, "test/TestDriver_List", true) - asserts.NoError(err) - asserts.Len(res, 7) - } -} diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 56abbaa..74649ea 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" "io" "net/http" "net/url" @@ -15,6 +14,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go deleted file mode 100644 index a675548..0000000 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ /dev/null @@ -1,1155 +0,0 @@ -package onedrive - -import ( - "context" - "errors" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestRequest(t *testing.T) { - asserts := assert.New(t) - client := Client{ - Policy: &model.Policy{}, - ClientID: "TestRequest", - Credential: &Credential{ - ExpiresIn: time.Now().Add(time.Duration(100) * time.Hour).Unix(), - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - }, - } - - // 请求发送失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - asserts.Equal("error", err.Error()) - } - - // 无法更新凭证 - { - client.Credential.RefreshToken = "" - client.Credential.AccessToken = "" - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - asserts.Error(err) - asserts.Empty(res) - client.Credential.RefreshToken = "RefreshToken" - client.Credential.AccessToken = "AccessToken" - } - - // 无法获取响应正文 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(mockReader("")), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } - - // OneDrive返回错误 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - asserts.Equal("error msg", err.Error()) - } - - // OneDrive返回429错误 - { - header := http.Header{} - header.Add("retry-after", "120") - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 429, - Header: header, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - var retryErr *backoff.RetryableError - asserts.ErrorAs(err, &retryErr) - asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter) - } - - // OneDrive返回未知响应 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } -} - -func TestFileInfo_GetSourcePath(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - fileInfo := FileInfo{ - Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg", - ParentReference: parentReference{ - Path: "/drive/root:/123/32%201", - }, - } - asserts.Equal("123/32 1/%e6%96%87%e4%bb%b6%e5%90%8d.jpg", fileInfo.GetSourcePath()) - } - - // 失败 - { - fileInfo := FileInfo{ - Name: "123.jpg", - ParentReference: parentReference{ - Path: "/drive/root:/123/%e6%96%87%e4%bb%b6%e5%90%8g", - }, - } - asserts.Equal("", fileInfo.GetSourcePath()) - } -} - -func TestClient_GetRequestURL(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - - // 出错 - { - client.Endpoints.EndpointURL = string([]byte{0x7f}) - asserts.Equal("", client.getRequestURL("123")) - } - - // 使用DriverResource - { - client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0" - asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123")) - } - - // 不使用DriverResource - { - client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0" - asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false))) - } -} - -func TestClient_GetSiteIDByURL(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") - asserts.Error(err) - asserts.Empty(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotEmpty(res) - asserts.Equal("123321", res) - } -} - -func TestClient_Meta(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.Meta(context.Background(), "", "123") - asserts.Error(err) - asserts.Nil(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.Meta(context.Background(), "", "123") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.Meta(context.Background(), "", "123") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.Name) - } - - // 返回正常, 使用资源id - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.Meta(context.Background(), "123321", "123") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.Name) - } -} - -func TestClient_CreateUploadSession(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.CreateUploadSession(context.Background(), "123.jpg") - asserts.Error(err) - asserts.Empty(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.CreateUploadSession(context.Background(), "123.jpg") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Empty(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.CreateUploadSession(context.Background(), "123.jpg", WithConflictBehavior("fail")) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res) - } -} - -func TestClient_GetUploadSessionStatus(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") - asserts.Error(err) - asserts.Empty(res) - - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com") - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.UploadURL) - } -} - -func TestClient_UploadChunk(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - cg := chunk.NewChunkGroup(&fsctx.FileStream{Size: 15}, 10, &backoff.ConstantBackoff{}, false) - - // 非最后分片,正常 - { - cg.Next() - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - testMock.Anything, - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"http://dev.com/2"}`)), - }, - }) - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Equal("http://dev.com/2", res.UploadURL) - } - - // 非最后分片,异常响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 最后分片,正常 - { - cg.Next() - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - "http://dev.com", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Nil(res) - } - - // 最后分片,失败 - { - cache.Set("setting_chunk_retries", "1", 0) - client.Credential.ExpiresIn = 0 - go func() { - time.Sleep(time.Duration(2) * time.Second) - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - }() - clientMock := ClientMock{} - client.Request = clientMock - res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } -} - -func TestClient_Upload(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - ctx := context.Background() - cache.Set("setting_chunk_retries", "1", 0) - cache.Set("setting_use_temp_chunk_buffer", "false", 0) - - // 小文件,简单上传,失败 - { - client.Credential.ExpiresIn = 0 - err := client.Upload(ctx, &fsctx.FileStream{ - Size: 5, - File: io.NopCloser(strings.NewReader("12345")), - }) - asserts.Error(err) - } - - // 无法创建分片会话 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - err := client.Upload(context.Background(), &fsctx.FileStream{ - Size: SmallFileSize + 1, - File: io.NopCloser(strings.NewReader("12345")), - }) - clientMock.AssertExpectations(t) - asserts.Error(err) - } - - // 分片上传失败 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - client.Request = clientMock - err := client.Upload(context.Background(), &fsctx.FileStream{ - Size: SmallFileSize + 1, - File: io.NopCloser(strings.NewReader("12345")), - }) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Contains(err.Error(), "failed to upload chunk") - } - -} - -func TestClient_SimpleUpload(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - cache.Set("setting_chunk_retries", "1", 0) - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)), - }, - }) - client.Request = clientMock - res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("123321", res.Name) - } -} - -func TestClient_DeleteUploadSession(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - err := client.DeleteUploadSession(context.Background(), "123.jpg") - asserts.Error(err) - - } - - // 返回正常 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 204, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, - }) - client.Request = clientMock - err := client.DeleteUploadSession(context.Background(), "123.jpg") - clientMock.AssertExpectations(t) - asserts.NoError(err) - } -} - -func TestClient_BatchDelete(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 小于20个,失败1个 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)), - }, - }) - client.Request = clientMock - res, err := client.BatchDelete(context.Background(), []string{"1", "2", "3", "1", "2"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Equal([]string{"2"}, res) - } -} - -func TestClient_Delete(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) - asserts.Error(err) - asserts.Len(res, 3) - } - - // 返回未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(res, 3) - } - - // 成功2两个文件 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)), - }, - }) - client.Request = clientMock - res, err := client.Delete(context.Background(), []string{"1", "2", "3"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Equal([]string{"2"}, res) - } -} - -func TestClient_ListChildren(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 根目录,请求失败,重测试 - { - client.Credential.ExpiresIn = 0 - res, err := client.ListChildren(context.Background(), "/") - asserts.Error(err) - asserts.Empty(res) - } - - // 非根目录,未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.ListChildren(context.Background(), "/uploads") - asserts.Error(err) - asserts.Empty(res) - } - - // 非根目录,成功 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)), - }, - }) - client.Request = clientMock - res, err := client.ListChildren(context.Background(), "/uploads") - asserts.NoError(err) - asserts.Len(res, 1) - } -} - -func TestClient_GetThumbURL(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.AccessToken = "AccessToken" - - // 请求失败 - { - client.Credential.ExpiresIn = 0 - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.Error(err) - asserts.Empty(res) - } - - // 未知响应 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.Error(err) - asserts.Empty(res) - } - - // 世纪互联 成功 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Endpoints.isInChina = true - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"url":"thumb"}`)), - }, - }) - client.Request = clientMock - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.NoError(err) - asserts.Equal("thumb", res) - } - - // 非世纪互联 成功 - { - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Endpoints.isInChina = false - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"large":{"url":"thumb"}}]}`)), - }, - }) - client.Request = clientMock - res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1) - asserts.NoError(err) - asserts.Equal("thumb", res) - } -} - -func TestClient_MonitorUpload(t *testing.T) { - asserts := assert.New(t) - client, _ := NewClient(&model.Policy{}) - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 客户端完成回调 - { - cache.Set("setting_onedrive_monitor_timeout", "600", 0) - cache.Set("setting_onedrive_callback_check", "20", 0) - asserts.NotPanics(func() { - go func() { - time.Sleep(time.Duration(1) * time.Second) - mq.GlobalMQ.Publish("key", mq.Message{}) - }() - client.MonitorUpload("url", "key", "path", 10, 10) - }) - } - - // 上传会话到期,仍未完成上传,创建占位符 - { - cache.Set("setting_onedrive_monitor_timeout", "600", 0) - cache.Set("setting_onedrive_callback_check", "20", 0) - asserts.NotPanics(func() { - client.MonitorUpload("url", "key", "path", 10, 0) - }) - } - - fmt.Println("测试:上传已完成,未发送回调") - // 上传已完成,未发送回调 - { - cache.Set("setting_onedrive_monitor_timeout", "0", 0) - cache.Set("setting_onedrive_callback_check", "0", 0) - - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Credential.AccessToken = "1" - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)), - }, - }) - client.Request = clientMock - cache.Set("callback_key3", "ok", 0) - - asserts.NotPanics(func() { - client.MonitorUpload("url", "key3", "path", 10, 10) - }) - - clientMock.AssertExpectations(t) - } - - fmt.Println("测试:上传仍未开始") - // 上传仍未开始 - { - cache.Set("setting_onedrive_monitor_timeout", "0", 0) - cache.Set("setting_onedrive_callback_check", "0", 0) - - client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - client.Credential.AccessToken = "1" - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"nextExpectedRanges":["0-"]}`)), - }, - }) - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(``)), - }, - }) - clientMock.On( - "Request", - "PUT", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{}`)), - }, - }) - client.Request = clientMock - - asserts.NotPanics(func() { - client.MonitorUpload("url", "key4", "path", 10, 10) - }) - - clientMock.AssertExpectations(t) - } - -} diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go index 957af8e..89e696b 100644 --- a/pkg/filesystem/driver/onedrive/client.go +++ b/pkg/filesystem/driver/onedrive/client.go @@ -2,6 +2,7 @@ package onedrive import ( "errors" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" model "github.com/cloudreve/Cloudreve/v3/models" diff --git a/pkg/filesystem/driver/onedrive/client_test.go b/pkg/filesystem/driver/onedrive/client_test.go deleted file mode 100644 index aa3c132..0000000 --- a/pkg/filesystem/driver/onedrive/client_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package onedrive - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" -) - -func TestNewClient(t *testing.T) { - asserts := assert.New(t) - // getOAuthEndpoint失败 - { - policy := model.Policy{ - BaseURL: string([]byte{0x7f}), - } - res, err := NewClient(&policy) - asserts.Error(err) - asserts.Nil(res) - } - - // 成功 - { - policy := model.Policy{} - res, err := NewClient(&policy) - asserts.NoError(err) - asserts.NotNil(res) - asserts.NotNil(res.Credential) - asserts.NotNil(res.Endpoints) - asserts.NotNil(res.Endpoints.OAuthEndpoints) - } -} diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go deleted file mode 100644 index 2c9c2c2..0000000 --- a/pkg/filesystem/driver/onedrive/handler_test.go +++ /dev/null @@ -1,420 +0,0 @@ -package onedrive - -import ( - "context" - "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/jinzhu/gorm" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestDriver_Token(t *testing.T) { - asserts := assert.New(t) - h, _ := NewDriver(&model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }) - handler := h.(Driver) - - // 分片上传 失败 - { - cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - asserts.Error(err) - asserts.Nil(res) - } - - // 分片上传 成功 - { - cache.Set("setting_siteURL", "http://test.cloudreve.org", 0) - cache.Set("setting_onedrive_monitor_timeout", "600", 0) - cache.Set("setting_onedrive_callback_check", "20", 0) - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - handler.Client.Credential.AccessToken = "1" - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - go func() { - time.Sleep(time.Duration(1) * time.Second) - mq.GlobalMQ.Publish("TestDriver_Token", mq.Message{}) - }() - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{Key: "TestDriver_Token"}, &fsctx.FileStream{}) - asserts.NoError(err) - asserts.Equal("123321", res.UploadURLs[0]) - } -} - -func TestDriver_Source(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - cache.Set("setting_onedrive_source_timeout", "1800", 0) - - // 失败 - { - res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0) - asserts.Error(err) - asserts.Empty(res) - } - - // 命中缓存 成功 - { - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - handler.Client.Credential.AccessToken = "1" - cache.Set("onedrive_source_0_123.jpg", "res", 1) - res, err := handler.Source(context.Background(), "123.jpg", 0, true, 0) - cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") - asserts.NoError(err) - asserts.Equal("res", res) - } - - // 命中缓存 上下文存在文件 成功 - { - file := model.File{} - file.ID = 1 - file.UpdatedAt = time.Now() - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - handler.Client.Credential.AccessToken = "1" - cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0) - res, err := handler.Source(ctx, "123.jpg", 1, true, 0) - cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_") - asserts.NoError(err) - asserts.Equal("res", res) - } - - // 成功 - { - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - handler.Client.Credential.AccessToken = "1" - res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0) - asserts.NoError(err) - asserts.Equal("123321", res) - } -} - -func TestDriver_List(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.AccessToken = "AccessToken" - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 非递归 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)), - }, - }) - handler.Client.Request = clientMock - res, err := handler.List(context.Background(), "/", false) - asserts.NoError(err) - asserts.Len(res, 1) - } - - // 递归一次 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - "me/drive/root/children?$top=999999999", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"1","folder":{}}]}`)), - }, - }) - clientMock.On( - "Request", - "GET", - "me/drive/root:/1:/children?$top=999999999", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"2"}]}`)), - }, - }) - handler.Client.Request = clientMock - res, err := handler.List(context.Background(), "/", true) - asserts.NoError(err) - asserts.Len(res, 2) - } -} - -func TestDriver_Thumb(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - file := &model.File{PicInfo: "1,1", Model: gorm.Model{ID: 1}} - - // 失败 - { - ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20}) - res, err := handler.Thumb(ctx, file) - asserts.Error(err) - asserts.Empty(res.URL) - } - - // 上下文错误 - { - _, err := handler.Thumb(context.Background(), file) - asserts.Error(err) - } -} - -func TestDriver_Delete(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 失败 - { - _, err := handler.Delete(context.Background(), []string{"1"}) - asserts.Error(err) - } - -} - -func TestDriver_Put(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 失败 - { - err := handler.Put(context.Background(), &fsctx.FileStream{}) - asserts.Error(err) - } -} - -func TestDriver_Get(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 无法获取source - { - res, err := handler.Get(context.Background(), "123.txt") - asserts.Error(err) - asserts.Nil(res) - } - - // 成功 - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)), - }, - }) - handler.Client.Request = clientMock - handler.Client.Credential.AccessToken = "1" - - driverClientMock := ClientMock{} - driverClientMock.On( - "Request", - "GET", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`123`)), - }, - }) - handler.HTTPClient = driverClientMock - res, err := handler.Get(context.Background(), "123.txt") - clientMock.AssertExpectations(t) - asserts.NoError(err) - _, err = res.Seek(0, io.SeekEnd) - asserts.NoError(err) - content, err := ioutil.ReadAll(res) - asserts.NoError(err) - asserts.Equal("123", string(content)) -} - -func TestDriver_replaceSourceHost(t *testing.T) { - tests := []struct { - name string - origin string - cdn string - want string - wantErr bool - }{ - {"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false}, - {"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false}, - {"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true}, - {"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - policy := &model.Policy{} - policy.OptionsSerialized.OdProxy = tt.cdn - handler := Driver{ - Policy: policy, - } - got, err := handler.replaceSourceHost(tt.origin) - if (err != nil) != tt.wantErr { - t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestDriver_CancelToken(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - AccessKey: "ak", - SecretKey: "sk", - BucketName: "test", - Server: "test.com", - }, - } - handler.Client, _ = NewClient(&model.Policy{}) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - - // 失败 - { - err := handler.CancelToken(context.Background(), &serializer.UploadSession{}) - asserts.Error(err) - } -} diff --git a/pkg/filesystem/driver/onedrive/lock.go b/pkg/filesystem/driver/onedrive/lock.go new file mode 100755 index 0000000..655936b --- /dev/null +++ b/pkg/filesystem/driver/onedrive/lock.go @@ -0,0 +1,25 @@ +package onedrive + +import "sync" + +// CredentialLock 针对存储策略凭证的锁 +type CredentialLock interface { + Lock(uint) + Unlock(uint) +} + +var GlobalMutex = mutexMap{} + +type mutexMap struct { + locks sync.Map +} + +func (m *mutexMap) Lock(id uint) { + lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) + lock.(*sync.Mutex).Lock() +} + +func (m *mutexMap) Unlock(id uint) { + lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{}) + lock.(*sync.Mutex).Unlock() +} diff --git a/pkg/filesystem/driver/onedrive/oauth_test.go b/pkg/filesystem/driver/onedrive/oauth_test.go deleted file mode 100644 index b2525b7..0000000 --- a/pkg/filesystem/driver/onedrive/oauth_test.go +++ /dev/null @@ -1,386 +0,0 @@ -package onedrive - -import ( - "context" - "database/sql" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock" - "io" - "io/ioutil" - "net/http" - "net/url" - "strings" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestGetOAuthEndpoint(t *testing.T) { - asserts := assert.New(t) - - // URL解析失败 - { - client := Client{ - Endpoints: &Endpoints{ - OAuthURL: string([]byte{0x7f}), - }, - } - res := client.getOAuthEndpoint() - asserts.Nil(res) - } - - { - testCase := []struct { - OAuthURL string - token string - auth string - isChina bool - }{ - { - OAuthURL: "http://login.live.com", - token: "https://login.live.com/oauth20_token.srf", - auth: "https://login.live.com/oauth20_authorize.srf", - isChina: false, - }, - { - OAuthURL: "http://login.chinacloudapi.cn", - token: "https://login.chinacloudapi.cn/common/oauth2/v2.0/token", - auth: "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize", - isChina: true, - }, - { - OAuthURL: "other", - token: "https://login.microsoftonline.com/common/oauth2/v2.0/token", - auth: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", - isChina: false, - }, - } - - for i, testCase := range testCase { - client := Client{ - Endpoints: &Endpoints{ - OAuthURL: testCase.OAuthURL, - }, - } - res := client.getOAuthEndpoint() - asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i) - asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i) - asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i) - } - } -} - -func TestClient_OAuthURL(t *testing.T) { - asserts := assert.New(t) - - client := Client{ - ClientID: "client_id", - Redirect: "http://cloudreve.org/callback", - Endpoints: &Endpoints{}, - } - client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() - res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"})) - asserts.NoError(err) - query := res.Query() - asserts.Equal("client_id", query.Get("client_id")) - asserts.Equal("scope1 scope2", query.Get("scope")) - asserts.Equal(client.Redirect, query.Get("redirect_uri")) - -} - -type ClientMock struct { - testMock.Mock -} - -func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - args := m.Called(method, target, body, opts) - return args.Get(0).(*request.Response) -} - -type mockReader string - -func (r mockReader) Read(b []byte) (int, error) { - return 0, errors.New("read error") -} - -func TestClient_ObtainToken(t *testing.T) { - asserts := assert.New(t) - - client := Client{ - Endpoints: &Endpoints{}, - ClientID: "ClientID", - ClientSecret: "ClientSecret", - Redirect: "Redirect", - } - client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() - - // 刷新Token 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.NotNil(res) - asserts.Equal("i am token", res.AccessToken) - } - - // 重新获取 无法发送请求 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background(), WithCode("code")) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } - - // 刷新Token 无法获取响应正文 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(mockReader("")), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - asserts.Equal("read error", err.Error()) - } - - // 刷新Token OneDrive返回错误 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - asserts.Equal("", err.Error()) - } - - // 刷新Token OneDrive未知响应 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`???`)), - }, - }) - client.Request = clientMock - - res, err := client.ObtainToken(context.Background()) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Nil(res) - } -} - -func TestClient_UpdateCredential(t *testing.T) { - asserts := assert.New(t) - client := Client{ - Policy: &model.Policy{Model: gorm.Model{ID: 257}}, - Endpoints: &Endpoints{}, - ClientID: "TestClient_UpdateCredential", - ClientSecret: "ClientSecret", - Redirect: "Redirect", - Credential: &Credential{}, - } - client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() - - // 无有效的RefreshToken - { - err := client.UpdateCredential(context.Background(), false) - asserts.Equal(ErrInvalidRefreshToken, err) - client.Credential = nil - err = client.UpdateCredential(context.Background(), false) - asserts.Equal(ErrInvalidRefreshToken, err) - } - - // 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)), - }, - }) - client.Request = clientMock - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := client.UpdateCredential(context.Background(), false) - clientMock.AssertExpectations(t) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential") - asserts.True(ok) - cacheCredential := cacheRes.(Credential) - asserts.Equal("new_refresh_token", cacheCredential.RefreshToken) - asserts.Equal("i am token", cacheCredential.AccessToken) - } - - // OneDrive返回错误 - { - cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_") - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - client.Endpoints.OAuthEndpoints.token.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 400, - Body: ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)), - }, - }) - client.Request = clientMock - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - } - err := client.UpdateCredential(context.Background(), false) - clientMock.AssertExpectations(t) - asserts.Error(err) - } - - // 从缓存中获取 - { - cache.Set("onedrive_TestClient_UpdateCredential", Credential{ - ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), - AccessToken: "AccessToken", - RefreshToken: "RefreshToken", - }, 0) - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - } - err := client.UpdateCredential(context.Background(), false) - asserts.NoError(err) - asserts.Equal("AccessToken", client.Credential.AccessToken) - asserts.Equal("RefreshToken", client.Credential.RefreshToken) - } - - // 无需重新获取 - { - client.Credential = &Credential{ - RefreshToken: "old_refresh_token", - AccessToken: "AccessToken2", - ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), - } - err := client.UpdateCredential(context.Background(), false) - asserts.NoError(err) - asserts.Equal("AccessToken2", client.Credential.AccessToken) - } - - // slave failed - { - mockController := &controllermock.SlaveControllerMock{} - mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error")) - client.ClusterController = mockController - err := client.UpdateCredential(context.Background(), true) - asserts.Error(err) - } - - // slave success - { - mockController := &controllermock.SlaveControllerMock{} - mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil) - client.ClusterController = mockController - err := client.UpdateCredential(context.Background(), true) - asserts.NoError(err) - asserts.Equal("AccessToken3", client.Credential.AccessToken) - } -} diff --git a/pkg/filesystem/driver/oss/handler.go b/pkg/filesystem/driver/oss/handler.go index 2ae50a3..ccccbd2 100644 --- a/pkg/filesystem/driver/oss/handler.go +++ b/pkg/filesystem/driver/oss/handler.go @@ -172,14 +172,24 @@ func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([ if err != nil { continue } - res = append(res, response.Object{ - Name: path.Base(object.Key), - Source: object.Key, - RelativePath: filepath.ToSlash(rel), - Size: uint64(object.Size), - IsDir: false, - LastModify: object.LastModified, - }) + if strings.HasSuffix(object.Key, "/") { + res = append(res, response.Object{ + Name: path.Base(object.Key), + RelativePath: filepath.ToSlash(rel), + Size: 0, + IsDir: true, + LastModify: time.Now(), + }) + } else { + res = append(res, response.Object{ + Name: path.Base(object.Key), + Source: object.Key, + RelativePath: filepath.ToSlash(rel), + Size: uint64(object.Size), + IsDir: false, + LastModify: object.LastModified, + }) + } } return res, nil diff --git a/pkg/filesystem/driver/remote/client_test.go b/pkg/filesystem/driver/remote/client_test.go deleted file mode 100644 index c195521..0000000 --- a/pkg/filesystem/driver/remote/client_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package remote - -import ( - "context" - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io/ioutil" - "net/http" - "strings" - "testing" -) - -func TestNewClient(t *testing.T) { - a := assert.New(t) - policy := &model.Policy{} - - // 无法解析服务端url - { - policy.Server = string([]byte{0x7f}) - c, err := NewClient(policy) - a.Error(err) - a.Nil(c) - } - - // 成功 - { - policy.Server = "" - c, err := NewClient(policy) - a.NoError(err) - a.NotNil(c) - } -} - -func TestRemoteClient_Upload(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - // 无法创建上传会话 - { - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) - } - - // 分片上传失败,成功删除上传会话 - { - cache.Set("setting_chunk_retries", "1", 0) - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) - } - - // 分片上传失败,无法删除上传会话 - { - cache.Set("setting_chunk_retries", "1", 0) - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error"), - }) - clientMock.On( - "Request", - "DELETE", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: errors.New("error2"), - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) - } - - // 成功 - { - cache.Set("setting_chunk_retries", "1", 0) - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.NoError(err) - clientMock.AssertExpectations(t) - } -} - -func TestRemoteClient_CreateUploadSessionFailed(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "PUT", - "upload", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)), - }, - }) - err := c.Upload(context.Background(), &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) -} - -func TestRemoteClient_UploadChunkFailed(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - clientMock := requestmock.RequestMock{} - c.(*remoteClient).httpClient = &clientMock - clientMock.On( - "Request", - "POST", - testMock.Anything, - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)), - }, - }) - err := c.(*remoteClient).uploadChunk(context.Background(), "", 0, strings.NewReader(""), false, 0) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) -} - -func TestRemoteClient_GetUploadURL(t *testing.T) { - a := assert.New(t) - c, _ := NewClient(&model.Policy{}) - - // url 解析失败 - { - c.(*remoteClient).policy.Server = string([]byte{0x7f}) - res, sign, err := c.GetUploadURL(0, "") - a.Error(err) - a.Empty(res) - a.Empty(sign) - } - - // 成功 - { - c.(*remoteClient).policy.Server = "" - res, sign, err := c.GetUploadURL(0, "") - a.NoError(err) - a.NotEmpty(res) - a.NotEmpty(sign) - } -} diff --git a/pkg/filesystem/driver/remote/handler_test.go b/pkg/filesystem/driver/remote/handler_test.go deleted file mode 100644 index 4f6f239..0000000 --- a/pkg/filesystem/driver/remote/handler_test.go +++ /dev/null @@ -1,460 +0,0 @@ -package remote - -import ( - "context" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/remoteclientmock" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestNewDriver(t *testing.T) { - a := assert.New(t) - - // remoteClient 初始化失败 - { - d, err := NewDriver(&model.Policy{Server: string([]byte{0x7f})}) - a.Error(err) - a.Nil(d) - } - - // 成功 - { - d, err := NewDriver(&model.Policy{}) - a.NoError(err) - a.NotNil(d) - } -} - -func TestHandler_Source(t *testing.T) { - asserts := assert.New(t) - auth.General = auth.HMACAuth{SecretKey: []byte("test")} - - // 无法获取上下文 - { - handler := Driver{ - Policy: &model.Policy{Server: "/"}, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - res, err := handler.Source(ctx, "", 0, true, 0) - asserts.NoError(err) - asserts.NotEmpty(res) - } - - // 成功 - { - handler := Driver{ - Policy: &model.Policy{Server: "/"}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, true, 0) - asserts.NoError(err) - asserts.Contains(res, "api/v3/slave/download/0") - } - - // 成功 自定义CDN - { - handler := Driver{ - Policy: &model.Policy{Server: "/", BaseURL: "https://cqu.edu.cn"}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, true, 0) - asserts.NoError(err) - asserts.Contains(res, "api/v3/slave/download/0") - asserts.Contains(res, "https://cqu.edu.cn") - } - - // 解析失败 自定义CDN - { - handler := Driver{ - Policy: &model.Policy{Server: "/", BaseURL: string([]byte{0x7f})}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, true, 0) - asserts.Error(err) - asserts.Empty(res) - } - - // 成功 预览 - { - handler := Driver{ - Policy: &model.Policy{Server: "/"}, - AuthInstance: auth.HMACAuth{}, - } - file := model.File{ - SourceName: "1.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - res, err := handler.Source(ctx, "", 10, false, 0) - asserts.NoError(err) - asserts.Contains(res, "api/v3/slave/source/0") - } -} - -type ClientMock struct { - testMock.Mock -} - -func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { - args := m.Called(method, target, body, opts) - return args.Get(0).(*request.Response) -} - -func TestHandler_Delete(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - SecretKey: "test", - Server: "http://test.com", - }, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - cache.Set("setting_slave_api_timeout", "60", 0) - - // 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/delete", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - handler.Client = clientMock - failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"}) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Len(failed, 0) - - } - - // 结果解析失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/delete", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":203}`)), - }, - }) - handler.Client = clientMock - failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(failed, 2) - } - - // 一个失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/delete", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":203,"data":"{\"files\":[\"1\"]}"}`)), - }, - }) - handler.Client = clientMock - failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"}) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(failed, 1) - } -} - -func TestDriver_List(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - SecretKey: "test", - Server: "http://test.com", - }, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - cache.Set("setting_slave_api_timeout", "60", 0) - - // 成功 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/list", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"data":"[{}]"}`)), - }, - }) - handler.Client = clientMock - res, err := handler.List(ctx, "/", true) - clientMock.AssertExpectations(t) - asserts.NoError(err) - asserts.Len(res, 1) - - } - - // 响应解析失败 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/list", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"data":"233"}`)), - }, - }) - handler.Client = clientMock - res, err := handler.List(ctx, "/", true) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(res, 0) - } - - // 从机返回错误 - { - clientMock := ClientMock{} - clientMock.On( - "Request", - "POST", - "http://test.com/api/v3/slave/list", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":203}`)), - }, - }) - handler.Client = clientMock - res, err := handler.List(ctx, "/", true) - clientMock.AssertExpectations(t) - asserts.Error(err) - asserts.Len(res, 0) - } -} - -func TestHandler_Get(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - SecretKey: "test", - Server: "http://test.com", - }, - AuthInstance: auth.HMACAuth{}, - } - ctx := context.Background() - - // 成功 - { - ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{}) - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - nil, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - handler.Client = clientMock - resp, err := handler.Get(ctx, "/test.txt") - clientMock.AssertExpectations(t) - asserts.NotNil(resp) - asserts.NoError(err) - } - - // 请求失败 - { - ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{}) - clientMock := ClientMock{} - clientMock.On( - "Request", - "GET", - testMock.Anything, - nil, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 404, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - handler.Client = clientMock - resp, err := handler.Get(ctx, "/test.txt") - clientMock.AssertExpectations(t) - asserts.Nil(resp) - asserts.Error(err) - } -} - -func TestHandler_Put(t *testing.T) { - a := assert.New(t) - handler, _ := NewDriver(&model.Policy{ - Type: "remote", - SecretKey: "test", - Server: "http://test.com", - }) - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("Upload", testMock.Anything, testMock.Anything).Return(errors.New("error")) - a.Error(handler.Put(context.Background(), &fsctx.FileStream{})) - clientMock.AssertExpectations(t) -} - -func TestHandler_Thumb(t *testing.T) { - asserts := assert.New(t) - handler := Driver{ - Policy: &model.Policy{ - Type: "remote", - SecretKey: "test", - Server: "http://test.com", - OptionsSerialized: model.PolicyOption{ - ThumbExts: []string{"txt"}, - }, - }, - AuthInstance: auth.HMACAuth{}, - } - file := &model.File{ - Name: "1.txt", - SourceName: "1.txt", - } - ctx := context.Background() - asserts.NoError(cache.Set("setting_preview_timeout", "60", 0)) - - // no error - { - resp, err := handler.Thumb(ctx, file) - asserts.NoError(err) - asserts.True(resp.Redirect) - } - - // ext not support - { - file.Name = "1.jpg" - resp, err := handler.Thumb(ctx, file) - asserts.ErrorIs(err, driver.ErrorThumbNotSupported) - asserts.Nil(resp) - } -} - -func TestHandler_Token(t *testing.T) { - a := assert.New(t) - handler, _ := NewDriver(&model.Policy{}) - - // 无法创建上传会话 - { - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(errors.New("error")) - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - a.Nil(res) - clientMock.AssertExpectations(t) - } - - // 无法创建上传地址 - { - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(nil) - clientMock.On("GetUploadURL", int64(10), "").Return("", "", errors.New("error")) - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - a.Error(err) - a.Contains(err.Error(), "error") - a.Nil(res) - clientMock.AssertExpectations(t) - } - - // 成功 - { - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(nil) - clientMock.On("GetUploadURL", int64(10), "").Return("1", "2", nil) - res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{}) - a.NoError(err) - a.NotNil(res) - a.Equal("1", res.UploadURLs[0]) - a.Equal("2", res.Credential) - clientMock.AssertExpectations(t) - } -} - -func TestDriver_CancelToken(t *testing.T) { - a := assert.New(t) - handler, _ := NewDriver(&model.Policy{}) - - clientMock := &remoteclientmock.RemoteClientMock{} - handler.uploadClient = clientMock - clientMock.On("DeleteUploadSession", testMock.Anything, "key").Return(errors.New("error")) - err := handler.CancelToken(context.Background(), &serializer.UploadSession{Key: "key"}) - a.Error(err) - a.Contains(err.Error(), "error") - clientMock.AssertExpectations(t) -} diff --git a/pkg/filesystem/errors.go b/pkg/filesystem/errors.go index d267038..5c3f231 100644 --- a/pkg/filesystem/errors.go +++ b/pkg/filesystem/errors.go @@ -2,7 +2,6 @@ package filesystem import ( "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go deleted file mode 100644 index 66f3444..0000000 --- a/pkg/filesystem/file_test.go +++ /dev/null @@ -1,669 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "os" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_AddFile(t *testing.T) { - asserts := assert.New(t) - file := fsctx.FileStream{ - Size: 5, - Name: "1.png", - SavePath: "/Uploads/1_sad.png", - } - folder := model.Folder{ - Model: gorm.Model{ - ID: 1, - }, - } - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Policy: model.Policy{ - Type: "cos", - Model: gorm.Model{ - ID: 1, - }, - }, - }, - Policy: &model.Policy{Type: "cos"}, - } - - _, err := fs.AddFile(context.Background(), &folder, &file) - - asserts.Error(err) - - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - f, err := fs.AddFile(context.Background(), &folder, &file) - - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("/Uploads/1_sad.png", f.SourceName) - - // 前置钩子执行失败 - { - hookExecuted := false - fs.Use("BeforeAddFile", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - hookExecuted = true - return errors.New("error") - }) - f, err := fs.AddFile(context.Background(), &folder, &file) - asserts.Error(err) - asserts.Nil(f) - asserts.True(hookExecuted) - } - - // 后置钩子执行失败 - { - hookExecuted := false - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - fs.Hooks = map[string][]Hook{} - fs.Use("AfterValidateFailed", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - hookExecuted = true - return errors.New("error") - }) - f, err := fs.AddFile(context.Background(), &folder, &file) - asserts.Error(err) - asserts.Nil(f) - asserts.True(hookExecuted) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_GetContent(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Policy: model.Policy{ - Model: gorm.Model{ - ID: 1, - }, - }, - }, - } - - // 文件不存在 - rs, err := fs.GetContent(ctx, 1) - asserts.Equal(ErrObjectNotExist, err) - asserts.Nil(rs) - fs.CleanTargets() - - // 未知存储策略 - file, err := os.Create(util.RelativePath("TestFileSystem_GetContent.txt")) - asserts.NoError(err) - _ = file.Close() - - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent.txt", 1)) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "unknown")) - - rs, err = fs.GetContent(ctx, 1) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - fs.CleanTargets() - - // 打开文件失败 - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent2.txt", 1)) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "source_name"}).AddRow(1, "local", "not exist")) - - rs, err = fs.GetContent(ctx, 1) - asserts.Equal(serializer.CodeIOFailed, err.(serializer.AppError).Code) - asserts.NoError(mock.ExpectationsWereMet()) - fs.CleanTargets() - - // 打开成功 - cache.Deletes([]string{"1"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetContent.txt", 1, "TestFileSystem_GetContent.txt")) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - - rs, err = fs.GetContent(ctx, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_GetDownloadContent(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Policy: model.Policy{ - Model: gorm.Model{ - ID: 599, - }, - }, - }, - } - file, err := os.Create(util.RelativePath("TestFileSystem_GetDownloadContent.txt")) - asserts.NoError(err) - _ = file.Close() - - cache.Deletes([]string{"599"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - - // 无限速 - cache.Deletes([]string{"599"}, "policy_") - _, err = fs.GetDownloadContent(ctx, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - fs.CleanTargets() - - // 有限速 - cache.Deletes([]string{"599"}, "policy_") - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt")) - mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local")) - - fs.User.Group.SpeedLimit = 1 - _, err = fs.GetDownloadContent(ctx, 1) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_GroupFileByPolicy(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - files := []model.File{ - model.File{ - PolicyID: 1, - Name: "1_1.txt", - }, - model.File{ - PolicyID: 2, - Name: "2_1.txt", - }, - model.File{ - PolicyID: 3, - Name: "3_1.txt", - }, - model.File{ - PolicyID: 2, - Name: "2_2.txt", - }, - model.File{ - PolicyID: 1, - Name: "1_2.txt", - }, - } - fs := FileSystem{} - policyGroup := fs.GroupFileByPolicy(ctx, files) - asserts.Equal(map[uint][]*model.File{ - 1: {&files[0], &files[4]}, - 2: {&files[1], &files[3]}, - 3: {&files[2]}, - }, policyGroup) -} - -func TestFileSystem_deleteGroupedFile(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{} - files := []model.File{ - { - PolicyID: 1, - Name: "1_1.txt", - SourceName: "1_1.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 2, - Name: "2_1.txt", - SourceName: "2_1.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 3, - Name: "3_1.txt", - SourceName: "3_1.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 2, - Name: "2_2.txt", - SourceName: "2_2.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - { - PolicyID: 1, - Name: "1_2.txt", - SourceName: "1_2.txt", - Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"}, - }, - } - - // 全部不存在 - { - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - } - // 部分不存在 - { - file, err := os.Create(util.RelativePath("1_1.txt")) - asserts.NoError(err) - _ = file.Close() - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - } - // 部分失败,包含整组未知存储策略导致的失败 - { - file, err := os.Create(util.RelativePath("1_1.txt")) - asserts.NoError(err) - _ = file.Close() - - files[1].Policy.Type = "unknown" - files[3].Policy.Type = "unknown" - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {"2_1.txt", "2_2.txt"}, - 3: {}, - }, failed) - } - // 包含上传会话文件 - { - sessionID := "session" - cache.Set(UploadSessionCachePrefix+sessionID, serializer.UploadSession{Key: sessionID}, 0) - files[1].Policy.Type = "local" - files[3].Policy.Type = "local" - files[0].UploadSessionID = &sessionID - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - _, ok := cache.Get(UploadSessionCachePrefix + sessionID) - asserts.False(ok) - } - - // 包含缩略图 - { - files[0].MetadataSerialized = map[string]string{ - model.ThumbSidecarMetadataKey: "1", - } - failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files)) - asserts.Equal(map[uint][]string{ - 1: {}, - 2: {}, - 3: {}, - }, failed) - } -} - -func TestFileSystem_GetSource(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - auth.General = auth.HMACAuth{SecretKey: []byte("123")} - - // 正常 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 35, "1.txt"), - ) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "local", true), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotEmpty(sourceURL) - fs.CleanTargets() - } - - // 文件不存在 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrObjectNotExist.Code, err.(serializer.AppError).Code) - asserts.Empty(sourceURL) - fs.CleanTargets() - } - - // 未知上传策略 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 36, "1.txt"), - ) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(36, "?", true), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Empty(sourceURL) - fs.CleanTargets() - } - - // 不允许获取外链 - { - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - // 清空缓存 - err := cache.Deletes([]string{"siteURL"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1). - WillReturnRows( - sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(2, 37, "1.txt"), - ) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(37, "local", false), - ) - - sourceURL, err := fs.GetSource(ctx, 2) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(serializer.CodePolicyNotAllowed, err.(serializer.AppError).Code) - asserts.Empty(sourceURL) - fs.CleanTargets() - } -} - -func TestFileSystem_GetDownloadURL(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - } - auth.General = auth.HMACAuth{SecretKey: []byte("123")} - - // 正常 - { - err := cache.Deletes([]string{"35"}, "policy_") - cache.Set("setting_download_timeout", "20", 0) - cache.Set("setting_siteURL", "https://cloudreve.org", 0) - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "local", true), - ) - // 相关设置 - downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotEmpty(downloadURL) - fs.CleanTargets() - } - - // 文件不存在 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - err = cache.Deletes([]string{"35"}, "policy_") - err = cache.Deletes([]string{"download_timeout"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"})) - - downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Empty(downloadURL) - fs.CleanTargets() - } - - // 未知存储策略 - { - err := cache.Deletes([]string{"siteURL"}, "setting_") - err = cache.Deletes([]string{"35"}, "policy_") - err = cache.Deletes([]string{"download_timeout"}, "setting_") - asserts.NoError(err) - // 查找文件 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35)) - // 查找上传策略 - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}). - AddRow(35, "unknown", true), - ) - - downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Empty(downloadURL) - fs.CleanTargets() - } -} - -func TestFileSystem_GetPhysicalFileContent(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - } - - // 文件不存在 - { - rs, err := fs.GetPhysicalFileContent(ctx, "not_exist.txt") - asserts.Error(err) - asserts.Nil(rs) - } - - // 成功 - { - testFile, err := os.Create(util.RelativePath("GetPhysicalFileContent.txt")) - asserts.NoError(err) - asserts.NoError(testFile.Close()) - - rs, err := fs.GetPhysicalFileContent(ctx, "GetPhysicalFileContent.txt") - asserts.NoError(err) - asserts.NoError(rs.Close()) - asserts.NotNil(rs) - } -} - -func TestFileSystem_Preview(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - - // 文件不存在 - { - fs := FileSystem{ - User: &model.User{}, - } - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - resp, err := fs.Preview(ctx, 1, false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(resp) - } - - // 直接返回文件内容,找不到文件 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/no.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "local", - }, - }, - } - resp, err := fs.Preview(ctx, 1, false) - asserts.Error(err) - asserts.Nil(resp) - } - - // 直接返回文件内容 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/file1.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "local", - }, - }, - } - resp, err := fs.Preview(ctx, 1, false) - asserts.Error(err) - asserts.Nil(resp) - } - - // 需要重定向,成功 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/file1.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "remote", - }, - }, - } - asserts.NoError(cache.Set("setting_preview_timeout", "233", 0)) - resp, err := fs.Preview(ctx, 1, false) - asserts.NoError(err) - asserts.NotNil(resp) - asserts.True(resp.Redirect) - } - - // 文本文件,大小超出限制 - { - fs := FileSystem{ - User: &model.User{}, - } - fs.FileTarget = []model.File{ - { - SourceName: "tests/file1.txt", - PolicyID: 1, - Policy: model.Policy{ - Model: gorm.Model{ID: 1}, - Type: "remote", - }, - Size: 11, - }, - } - asserts.NoError(cache.Set("setting_maxEditSize", "10", 0)) - resp, err := fs.Preview(ctx, 1, true) - asserts.Equal(ErrFileSizeTooBig, err) - asserts.Nil(resp) - } -} - -func TestFileSystem_ResetFileIDIfNotExist(t *testing.T) { - asserts := assert.New(t) - ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{Model: gorm.Model{ID: 1}}) - fs := FileSystem{ - FileTarget: []model.File{ - { - FolderID: 2, - }, - }, - } - asserts.Equal(ErrObjectNotExist, fs.resetFileIDIfNotExist(ctx, 1)) -} - -func TestFileSystem_Search(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := &FileSystem{ - User: &model.User{}, - } - fs.User.ID = 1 - - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - res, err := fs.Search(ctx, "k1", "k2") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Len(res, 1) -} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 1e14fa8..d892ed0 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -3,6 +3,10 @@ package filesystem import ( "errors" "fmt" + "net/http" + "net/url" + "sync" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" @@ -22,9 +26,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" cossdk "github.com/tencentyun/cos-go-sdk-v5" - "net/http" - "net/url" - "sync" ) // FSPool 文件系统资源池 @@ -92,7 +93,7 @@ func (fs *FileSystem) reset() { func NewFileSystem(user *model.User) (*FileSystem, error) { fs := getEmptyFS() fs.User = user - fs.Policy = &fs.User.Policy + fs.Policy = user.GetPolicyID(nil) // 分配存储策略适配器 err := fs.DispatchHandler() @@ -122,70 +123,57 @@ func NewAnonymousFileSystem() (*FileSystem, error) { // DispatchHandler 根据存储策略分配文件适配器 func (fs *FileSystem) DispatchHandler() error { - if fs.Policy == nil { - return errors.New("未设置存储策略") + handler, err := getNewPolicyHandler(fs.Policy) + fs.Handler = handler + + return err +} + +// getNewPolicyHandler 根据存储策略类型字段获取处理器 +func getNewPolicyHandler(policy *model.Policy) (driver.Handler, error) { + if policy == nil { + return nil, ErrUnknownPolicyType } - policyType := fs.Policy.Type - currentPolicy := fs.Policy - switch policyType { + switch policy.Type { case "mock", "anonymous": - return nil + return nil, nil case "local": - fs.Handler = local.Driver{ - Policy: currentPolicy, - } - return nil + return local.Driver{ + Policy: policy, + }, nil case "remote": - handler, err := remote.NewDriver(currentPolicy) - if err != nil { - return err - } - - fs.Handler = handler + return remote.NewDriver(policy) case "qiniu": - fs.Handler = qiniu.NewDriver(currentPolicy) - return nil + return qiniu.NewDriver(policy), nil case "oss": - handler, err := oss.NewDriver(currentPolicy) - fs.Handler = handler - return err + return oss.NewDriver(policy) case "upyun": - fs.Handler = upyun.Driver{ - Policy: currentPolicy, - } - return nil + return upyun.Driver{ + Policy: policy, + }, nil case "onedrive": - var odErr error - fs.Handler, odErr = onedrive.NewDriver(currentPolicy) - return odErr + return onedrive.NewDriver(policy) case "cos": - u, _ := url.Parse(currentPolicy.Server) + u, _ := url.Parse(policy.Server) b := &cossdk.BaseURL{BucketURL: u} - fs.Handler = cos.Driver{ - Policy: currentPolicy, + return cos.Driver{ + Policy: policy, Client: cossdk.NewClient(b, &http.Client{ Transport: &cossdk.AuthorizationTransport{ - SecretID: currentPolicy.AccessKey, - SecretKey: currentPolicy.SecretKey, + SecretID: policy.AccessKey, + SecretKey: policy.SecretKey, }, }), HTTPClient: request.NewClient(), - } - return nil + }, nil case "s3": - handler, err := s3.NewDriver(currentPolicy) - fs.Handler = handler - return err + return s3.NewDriver(policy) case "googledrive": - handler, err := googledrive.NewDriver(currentPolicy) - fs.Handler = handler - return err + return googledrive.NewDriver(policy) default: - return ErrUnknownPolicyType + return nil, ErrUnknownPolicyType } - - return nil } // NewFileSystemFromContext 从gin.Context创建文件系统 @@ -290,3 +278,18 @@ func (fs *FileSystem) CleanTargets() { fs.FileTarget = fs.FileTarget[:0] fs.DirTarget = fs.DirTarget[:0] } + +// SetPolicyFromPath 根据给定路径尝试设定偏好存储策略 +func (fs *FileSystem) SetPolicyFromPath(filePath string) error { + _, parent := fs.getClosedParent(filePath) + // 尝试获取并重设存储策略 + fs.Policy = fs.User.GetPolicyID(parent) + return fs.DispatchHandler() +} + +// SetPolicyFromPreference 尝试设定偏好存储策略 +func (fs *FileSystem) SetPolicyFromPreference(preference uint) error { + // 尝试获取并重设存储策略 + fs.Policy = fs.User.GetPolicyByPreference(preference) + return fs.DispatchHandler() +} diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go deleted file mode 100644 index 8b7aae3..0000000 --- a/pkg/filesystem/filesystem_test.go +++ /dev/null @@ -1,299 +0,0 @@ -package filesystem - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "net/http/httptest" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" - - "testing" -) - -func TestNewFileSystem(t *testing.T) { - asserts := assert.New(t) - user := model.User{ - Policy: model.Policy{ - Type: "local", - }, - } - - // 本地 成功 - fs, err := NewFileSystem(&user) - asserts.NoError(err) - asserts.NotNil(fs.Handler) - asserts.IsType(local.Driver{}, fs.Handler) - // 远程 - user.Policy.Type = "remote" - fs, err = NewFileSystem(&user) - asserts.NoError(err) - asserts.NotNil(fs.Handler) - asserts.IsType(&remote.Driver{}, fs.Handler) - - user.Policy.Type = "unknown" - fs, err = NewFileSystem(&user) - asserts.Error(err) -} - -func TestNewFileSystemFromContext(t *testing.T) { - asserts := assert.New(t) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Set("user", &model.User{ - Policy: model.Policy{ - Type: "local", - }, - }) - fs, err := NewFileSystemFromContext(c) - asserts.NotNil(fs) - asserts.NoError(err) - - c, _ = gin.CreateTestContext(httptest.NewRecorder()) - fs, err = NewFileSystemFromContext(c) - asserts.Nil(fs) - asserts.Error(err) -} - -func TestDispatchHandler(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - User: &model.User{}, - Policy: &model.Policy{ - Type: "local", - }, - } - - // 未指定,使用用户默认 - err := fs.DispatchHandler() - asserts.NoError(err) - asserts.IsType(local.Driver{}, fs.Handler) - - // 已指定,发生错误 - fs.Policy = &model.Policy{Type: "unknown"} - err = fs.DispatchHandler() - asserts.Error(err) - - fs.Policy = &model.Policy{Type: "mock"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "local"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "remote"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "qiniu"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "oss", Server: "https://s.com", BucketName: "1234"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "upyun"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "onedrive"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "cos"} - err = fs.DispatchHandler() - asserts.NoError(err) - - fs.Policy = &model.Policy{Type: "s3"} - err = fs.DispatchHandler() - asserts.NoError(err) -} - -func TestNewFileSystemFromCallback(t *testing.T) { - asserts := assert.New(t) - - // 用户上下文不存在 - { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - fs, err := NewFileSystemFromCallback(c) - asserts.Nil(fs) - asserts.Error(err) - } - - // 找不到回调会话 - { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Set("user", &model.User{ - Policy: model.Policy{ - Type: "local", - }, - }) - fs, err := NewFileSystemFromCallback(c) - asserts.Nil(fs) - asserts.Error(err) - } - - // 成功 - { - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Set("user", &model.User{ - Policy: model.Policy{ - Type: "local", - }, - }) - c.Set(UploadSessionCtx, &serializer.UploadSession{Policy: model.Policy{Type: "local"}}) - fs, err := NewFileSystemFromCallback(c) - asserts.NotNil(fs) - asserts.NoError(err) - } - -} - -func TestFileSystem_SetTargetFileByIDs(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - fs := &FileSystem{} - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) - err := fs.SetTargetFileByIDs([]uint{1, 2}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(fs.FileTarget, 1) - asserts.NoError(err) - } - - // 未找到 - { - fs := &FileSystem{} - mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.SetTargetFileByIDs([]uint{1, 2}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(fs.FileTarget, 0) - asserts.Error(err) - } -} - -func TestFileSystem_CleanTargets(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - FileTarget: []model.File{{}, {}}, - DirTarget: []model.Folder{{}, {}}, - } - - fs.CleanTargets() - asserts.Len(fs.FileTarget, 0) - asserts.Len(fs.DirTarget, 0) -} - -func TestNewAnonymousFileSystem(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"}).AddRow(3, "游客", "[]")) - fs, err := NewAnonymousFileSystem() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.Equal("游客", fs.User.Group.Name) - } - - // 游客用户组不存在 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"})) - fs, err := NewAnonymousFileSystem() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(fs) - } - - // 从机 - { - conf.SystemConfig.Mode = "slave" - fs, err := NewAnonymousFileSystem() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(fs) - asserts.NotNil(fs.Handler) - } -} - -func TestFileSystem_Recycle(t *testing.T) { - fs := &FileSystem{ - User: &model.User{}, - Policy: &model.Policy{}, - FileTarget: []model.File{model.File{}}, - DirTarget: []model.Folder{model.Folder{}}, - Hooks: map[string][]Hook{"AfterUpload": []Hook{GenericAfterUpdate}}, - } - fs.Recycle() - newFS := getEmptyFS() - if fs != newFS { - t.Error("指针不一致") - } -} - -func TestFileSystem_SetTargetByInterface(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{} - - // 目录 - { - asserts.NoError(fs.SetTargetByInterface(&model.Folder{})) - asserts.Len(fs.DirTarget, 1) - asserts.Len(fs.FileTarget, 0) - } - - // 文件 - { - asserts.NoError(fs.SetTargetByInterface(&model.File{})) - asserts.Len(fs.DirTarget, 1) - asserts.Len(fs.FileTarget, 1) - } -} - -func TestFileSystem_SwitchToSlaveHandler(t *testing.T) { - a := assert.New(t) - fs := FileSystem{ - User: &model.User{}, - } - mockNode := &cluster.MasterNode{ - Model: &model.Node{}, - } - fs.SwitchToSlaveHandler(mockNode) - a.IsType(&slaveinmaster.Driver{}, fs.Handler) -} - -func TestFileSystem_SwitchToShadowHandler(t *testing.T) { - a := assert.New(t) - fs := FileSystem{ - User: &model.User{}, - Policy: &model.Policy{}, - } - mockNode := &cluster.MasterNode{ - Model: &model.Node{}, - } - - // local to remote - { - fs.Policy.Type = "local" - fs.SwitchToShadowHandler(mockNode, "", "") - a.IsType(&masterinslave.Driver{}, fs.Handler) - } - - // onedrive - { - fs.Policy.Type = "onedrive" - fs.SwitchToShadowHandler(mockNode, "", "") - a.IsType(&masterinslave.Driver{}, fs.Handler) - } -} diff --git a/pkg/filesystem/fsctx/stream_test.go b/pkg/filesystem/fsctx/stream_test.go deleted file mode 100644 index 1ef6e1f..0000000 --- a/pkg/filesystem/fsctx/stream_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package fsctx - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - "io" - "io/ioutil" - "os" - "strings" - "testing" -) - -func TestFileStream_Read(t *testing.T) { - asserts := assert.New(t) - file := FileStream{ - File: ioutil.NopCloser(strings.NewReader("123")), - } - var p = make([]byte, 3) - { - n, err := file.Read(p) - asserts.Equal(3, n) - asserts.NoError(err) - } -} - -func TestFileStream_Close(t *testing.T) { - asserts := assert.New(t) - { - file := FileStream{ - File: ioutil.NopCloser(strings.NewReader("123")), - } - err := file.Close() - asserts.NoError(err) - } - - { - file := FileStream{} - err := file.Close() - asserts.NoError(err) - } -} - -func TestFileStream_Seek(t *testing.T) { - asserts := assert.New(t) - f, _ := os.CreateTemp("", "*") - defer func() { - f.Close() - os.Remove(f.Name()) - }() - { - file := FileStream{ - File: f, - Seeker: f, - } - res, err := file.Seek(0, io.SeekStart) - asserts.NoError(err) - asserts.EqualValues(0, res) - } - - { - file := FileStream{} - res, err := file.Seek(0, io.SeekStart) - asserts.Error(err) - asserts.EqualValues(0, res) - } -} - -func TestFileStream_Info(t *testing.T) { - a := assert.New(t) - file := FileStream{} - a.NotNil(file.Info()) - - file.SetSize(10) - a.EqualValues(10, file.Info().Size) - - file.SetModel(&model.File{}) - a.NotNil(file.Info().Model) -} diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go index a2f9ed5..2a4c5d5 100644 --- a/pkg/filesystem/hooks.go +++ b/pkg/filesystem/hooks.go @@ -2,6 +2,12 @@ package filesystem import ( "context" + "io/ioutil" + "net/http" + "strconv" + "strings" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" @@ -9,11 +15,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" - "io/ioutil" - "net/http" - "strconv" - "strings" - "time" ) // Hook 钩子函数 @@ -222,6 +223,21 @@ func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.Fi return nil } +// HookGenerateThumb 生成缩略图 +// func HookGenerateThumb(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { +// // 异步尝试生成缩略图 +// fileMode := fileHeader.Info().Model.(*model.File) +// if fs.Policy.IsThumbGenerateNeeded() { +// fs.recycleLock.Lock() +// go func() { +// defer fs.recycleLock.Unlock() +// _, _ = fs.Handler.Delete(ctx, []string{fileMode.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")}) +// fs.GenerateThumbnail(ctx, fileMode) +// }() +// } +// return nil +// } + // HookClearFileHeaderSize 将FileHeader大小设定为0 func HookClearFileHeaderSize(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { fileHeader.SetSize(0) diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go deleted file mode 100644 index cc660ce..0000000 --- a/pkg/filesystem/hooks_test.go +++ /dev/null @@ -1,708 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "io/ioutil" - "net/http" - "strings" - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestGenericBeforeUpload(t *testing.T) { - asserts := assert.New(t) - file := &fsctx.FileStream{ - Size: 5, - Name: "1.txt", - } - ctx := context.Background() - cache.Set("pack_size_0", uint64(0), 0) - fs := FileSystem{ - User: &model.User{ - Storage: 0, - Group: model.Group{ - MaxStorage: 11, - }, - }, - Policy: &model.Policy{ - MaxSize: 4, - OptionsSerialized: model.PolicyOption{ - FileType: []string{"txt"}, - }, - }, - } - - asserts.Error(HookValidateFile(ctx, &fs, file)) - - file.Size = 1 - file.Name = "1" - asserts.Error(HookValidateFile(ctx, &fs, file)) - - file.Name = "1.txt" - asserts.NoError(HookValidateFile(ctx, &fs, file)) - - file.Name = "1.t/xt" - asserts.Error(HookValidateFile(ctx, &fs, file)) -} - -func TestGenericAfterUploadCanceled(t *testing.T) { - asserts := assert.New(t) - file := &fsctx.FileStream{ - Size: 5, - Name: "TestGenericAfterUploadCanceled", - SavePath: "TestGenericAfterUploadCanceled", - } - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - } - - // 成功 - { - mockHandler := &FileHeaderMock{} - fs.Handler = mockHandler - mockHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, nil) - err := HookDeleteTempFile(ctx, &fs, file) - asserts.NoError(err) - mockHandler.AssertExpectations(t) - } - - // 失败 - { - mockHandler := &FileHeaderMock{} - fs.Handler = mockHandler - mockHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, errors.New("")) - err := HookDeleteTempFile(ctx, &fs, file) - asserts.NoError(err) - mockHandler.AssertExpectations(t) - } - -} - -func TestGenericAfterUpload(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{}, - } - - ctx := context.Background() - file := &fsctx.FileStream{ - VirtualPath: "/我的文件", - Name: "test.txt", - } - - // 正常 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := GenericAfterUpload(ctx, &fs, file) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 文件已存在 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows( - mock.NewRows([]string{"name"}).AddRow("test.txt"), - ) - err = GenericAfterUpload(ctx, &fs, file) - asserts.Equal(ErrFileExisted, err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 文件已存在, 且为上传占位符 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows( - mock.NewRows([]string{"name", "upload_session_id"}).AddRow("test.txt", "1"), - ) - err = GenericAfterUpload(ctx, &fs, file) - asserts.Equal(ErrFileUploadSessionExisted, err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 插入失败 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files"). - WithArgs(1, "我的文件"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs("我的文件", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - - err = GenericAfterUpload(ctx, &fs, file) - asserts.Equal(ErrInsertFileRecord, err) - asserts.NoError(mock.ExpectationsWereMet()) - -} - -func TestFileSystem_Use(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{} - - hook := func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - return nil - } - - // 添加一个 - fs.Use("BeforeUpload", hook) - asserts.Len(fs.Hooks["BeforeUpload"], 1) - - // 添加一个 - fs.Use("BeforeUpload", hook) - asserts.Len(fs.Hooks["BeforeUpload"], 2) - - // 不存在 - fs.Use("BeforeUpload2333", hook) - - asserts.NotPanics(func() { - for _, hookName := range []string{ - "AfterUpload", - "AfterValidateFailed", - "AfterUploadCanceled", - "BeforeFileDownload", - } { - fs.Use(hookName, hook) - } - }) - -} - -func TestFileSystem_Trigger(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{}, - } - ctx := context.Background() - - hook := func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error { - fs.User.Storage++ - return nil - } - - // 一个 - fs.Use("BeforeUpload", hook) - err := fs.Trigger(ctx, "BeforeUpload", nil) - asserts.NoError(err) - asserts.Equal(uint64(1), fs.User.Storage) - - // 多个 - fs.Use("BeforeUpload", hook) - fs.Use("BeforeUpload", hook) - err = fs.Trigger(ctx, "BeforeUpload", nil) - asserts.NoError(err) - asserts.Equal(uint64(4), fs.User.Storage) - - // 多个,有失败 - fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - asserts.Fail("following hooks executed") - return nil - }) - err = fs.Trigger(ctx, "BeforeUpload", nil) - asserts.Error(err) -} - -func TestHookValidateCapacity(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - Storage: 0, - Group: model.Group{ - MaxStorage: 11, - }, - }} - ctx := context.Background() - file := &fsctx.FileStream{Size: 11} - { - err := HookValidateCapacity(ctx, fs, file) - asserts.NoError(err) - } - { - file.Size = 12 - err := HookValidateCapacity(ctx, fs, file) - asserts.Error(err) - } -} - -func TestHookValidateCapacityDiff(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{User: &model.User{ - Group: model.Group{ - MaxStorage: 11, - }, - }} - file := model.File{Size: 10} - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - - // 无需操作 - { - a.NoError(HookValidateCapacityDiff(ctx, fs, &fsctx.FileStream{Size: 10})) - } - - // 需要验证 - { - a.Error(HookValidateCapacityDiff(ctx, fs, &fsctx.FileStream{Size: 12})) - } - -} - -func TestHookResetPolicy(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 - { - file := model.File{PolicyID: 2} - cache.Deletes([]string{"2"}, "policy_") - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(2, "local")) - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - err := HookResetPolicy(ctx, fs, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 上下文文件不存在 - { - cache.Deletes([]string{"2"}, "policy_") - ctx := context.Background() - err := HookResetPolicy(ctx, fs, nil) - asserts.Error(err) - } -} - -func TestHookCleanFileContent(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - file := &fsctx.FileStream{SavePath: "123/123"} - handlerMock := FileHeaderMock{} - handlerMock.On("Put", testMock.Anything, testMock.Anything).Return(errors.New("error")) - fs.Handler = handlerMock - err := HookCleanFileContent(context.Background(), fs, file) - asserts.Error(err) - handlerMock.AssertExpectations(t) -} - -func TestHookClearFileSize(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 - { - ctx := context.WithValue( - context.Background(), - fsctx.FileModelCtx, - model.File{Model: gorm.Model{ID: 1}, Size: 10}, - ) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)"). - WithArgs("", 0, sqlmock.AnyArg(), 1, 10). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(10, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := HookClearFileSize(ctx, fs, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 上下文对象不存在 - { - ctx := context.Background() - err := HookClearFileSize(ctx, fs, nil) - asserts.Error(err) - } - -} - -func TestHookUpdateSourceName(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 - { - originFile := model.File{ - Model: gorm.Model{ID: 1}, - SourceName: "new.txt", - } - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WithArgs("", "new.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := HookUpdateSourceName(ctx, fs, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 上下文错误 - { - ctx := context.Background() - err := HookUpdateSourceName(ctx, fs, nil) - asserts.Error(err) - } -} - -func TestGenericAfterUpdate(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ID: 1}, - }} - - // 成功 是图像文件 - { - originFile := model.File{ - Model: gorm.Model{ID: 1}, - PicInfo: "1,1", - } - newFile := &fsctx.FileStream{Size: 10} - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) - - handlerMock := FileHeaderMock{} - handlerMock.On("Delete", testMock.Anything, []string{"._thumb"}).Return([]string{}, nil) - fs.Handler = handlerMock - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)"). - WithArgs("", 10, sqlmock.AnyArg(), 1, 0). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(10, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - err := GenericAfterUpdate(ctx, fs, newFile) - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 原始文件上下文不存在 - { - newFile := &fsctx.FileStream{Size: 10} - ctx := context.Background() - err := GenericAfterUpdate(ctx, fs, newFile) - asserts.Error(err) - } - - // 无法更新数据库容量 - // 成功 是图像文件 - { - originFile := model.File{ - Model: gorm.Model{ID: 1}, - PicInfo: "1,1", - } - newFile := &fsctx.FileStream{Size: 10} - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile) - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)"). - WithArgs("", 10, sqlmock.AnyArg(), 1, 0). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - - err := GenericAfterUpdate(ctx, fs, newFile) - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} - -func TestSlaveAfterUpload(t *testing.T) { - asserts := assert.New(t) - conf.SystemConfig.Mode = "slave" - fs, err := NewAnonymousFileSystem() - conf.SystemConfig.Mode = "master" - asserts.NoError(err) - - // 成功 - { - clientMock := requestmock.RequestMock{} - clientMock.On( - "Request", - "POST", - "http://test/callbakc", - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: nil, - Response: &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)), - }, - }) - request.GeneralClient = clientMock - file := &fsctx.FileStream{ - Size: 10, - VirtualPath: "/my", - Name: "test.txt", - SavePath: "/not_exist", - } - err := SlaveAfterUpload(&serializer.UploadSession{Callback: "http://test/callbakc"})(context.Background(), fs, file) - clientMock.AssertExpectations(t) - asserts.NoError(err) - } - - // 跳过回调 - { - file := &fsctx.FileStream{ - Size: 10, - VirtualPath: "/my", - Name: "test.txt", - SavePath: "/not_exist", - } - err := SlaveAfterUpload(&serializer.UploadSession{})(context.Background(), fs, file) - asserts.NoError(err) - } -} - -func TestFileSystem_CleanHooks(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - }, - Hooks: map[string][]Hook{ - "hook1": []Hook{}, - "hook2": []Hook{}, - "hook3": []Hook{}, - }, - } - - // 清理一个 - { - fs.CleanHooks("hook2") - asserts.Len(fs.Hooks, 2) - asserts.Contains(fs.Hooks, "hook1") - asserts.Contains(fs.Hooks, "hook3") - } - - // 清理全部 - { - fs.CleanHooks("") - asserts.Len(fs.Hooks, 0) - } -} - -func TestHookCancelContext(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{} - ctx, cancel := context.WithCancel(context.Background()) - - // empty ctx - { - asserts.NoError(HookCancelContext(ctx, fs, nil)) - select { - case <-ctx.Done(): - t.Errorf("Channel should not be closed") - default: - - } - } - - // with cancel ctx - { - ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel) - asserts.NoError(HookCancelContext(ctx, fs, nil)) - _, ok := <-ctx.Done() - asserts.False(ok) - } -} - -func TestHookClearFileHeaderSize(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{Size: 10} - a.NoError(HookClearFileHeaderSize(context.Background(), fs, file)) - a.EqualValues(0, file.Size) -} - -func TestHookTruncateFileTo(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{} - a.NoError(HookTruncateFileTo(0)(context.Background(), fs, file)) - - fs.Handler = local.Driver{} - a.Error(HookTruncateFileTo(0)(context.Background(), fs, file)) -} - -func TestHookChunkUploaded(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - AppendStart: 10, - Size: 10, - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 20, sqlmock.AnyArg(), 1, 0).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(20, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookChunkUploaded(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookChunkUploadFailed(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - AppendStart: 10, - Size: 10, - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 10, sqlmock.AnyArg(), 1, 0).WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)"). - WithArgs(10, sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookChunkUploadFailed(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookPopPlaceholderToFile(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookPopPlaceholderToFile("1,1")(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookPopPlaceholderToFileBySuffix(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{ - Policy: &model.Policy{Type: "cos"}, - } - file := &fsctx.FileStream{ - Name: "1.png", - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - a.NoError(HookPopPlaceholderToFile("")(context.Background(), fs, file)) - a.NoError(mock.ExpectationsWereMet()) -} - -func TestHookDeleteUploadSession(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - cache.Set(UploadSessionCachePrefix+"TestHookDeleteUploadSession", "", 0) - a.NoError(HookDeleteUploadSession("TestHookDeleteUploadSession")(context.Background(), fs, file)) - _, ok := cache.Get(UploadSessionCachePrefix + "TestHookDeleteUploadSession") - a.False(ok) -} -func TestNewWebdavAfterUploadHook(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{} - file := &fsctx.FileStream{ - Model: &model.File{ - Model: gorm.Model{ID: 1}, - }, - } - - req, _ := http.NewRequest("get", "http://localhost", nil) - req.Header.Add("X-Oc-Mtime", "1681521402") - req.Header.Add("OC-Checksum", "checksum") - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := NewWebdavAfterUploadHook(req)(context.Background(), fs, file) - a.NoError(err) - a.NoError(mock.ExpectationsWereMet()) - -} diff --git a/pkg/filesystem/image_test.go b/pkg/filesystem/image_test.go deleted file mode 100644 index 4180858..0000000 --- a/pkg/filesystem/image_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/thumbmock" - "github.com/cloudreve/Cloudreve/v3/pkg/thumb" - testMock "github.com/stretchr/testify/mock" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_GetThumb(t *testing.T) { - a := assert.New(t) - fs := &FileSystem{User: &model.User{}} - - // file not found - { - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - res, err := fs.GetThumb(context.Background(), 1) - a.ErrorIs(err, ErrObjectNotExist) - a.Nil(res) - a.NoError(mock.ExpectationsWereMet()) - } - - // thumb not exist - { - fs.SetTargetFile(&[]model.File{{ - MetadataSerialized: map[string]string{ - model.ThumbStatusMetadataKey: model.ThumbStatusNotAvailable, - }, - Policy: model.Policy{Type: "mock"}, - }}) - fs.FileTarget[0].Policy.ID = 1 - - res, err := fs.GetThumb(context.Background(), 1) - a.ErrorIs(err, ErrObjectNotExist) - a.Nil(res) - } - - // thumb not initialized, also failed to generate - { - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "mock"}, - Size: 31457281, - }}) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.Contains(err.Error(), "file too large") - a.Nil(res.Content) - } - - // thumb not initialized, failed to get source - { - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "mock"}, - }}) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, errors.New("error")) - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.Contains(err.Error(), "error") - a.Nil(res.Content) - } - - // thumb not initialized, no available generators - { - thumb.Generators = []thumb.Generator{} - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "local"}, - }}) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, nil) - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.ErrorIs(err, thumb.ErrNotAvailable) - a.Nil(res) - } - - // thumb not initialized, thumb generated but cannot be open - { - mockGenerator := &thumbmock.GeneratorMock{} - thumb.Generators = []thumb.Generator{mockGenerator} - fs.CleanTargets() - fs.SetTargetFile(&[]model.File{{ - Policy: model.Policy{Type: "mock"}, - }}) - cache.Set("setting_thumb_vips_enabled", "1", 0) - testHandller2 := new(FileHeaderMock) - testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist) - testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, nil) - mockGenerator.On("Generate", testMock.Anything, testMock.Anything, testMock.Anything, testMock.Anything, testMock.Anything). - Return(&thumb.Result{Path: "not_exit_thumb"}, nil) - - fs.Handler = testHandller2 - fs.FileTarget[0].Policy.ID = 1 - res, err := fs.GetThumb(context.Background(), 1) - a.Contains(err.Error(), "failed to open temp thumb") - a.Nil(res.Content) - testHandller2.AssertExpectations(t) - mockGenerator.AssertExpectations(t) - } -} - -func TestFileSystem_ThumbWorker(t *testing.T) { - asserts := assert.New(t) - - asserts.NotPanics(func() { - getThumbWorker().addWorker() - getThumbWorker().releaseWorker() - }) -} diff --git a/pkg/filesystem/manage_test.go b/pkg/filesystem/manage_test.go deleted file mode 100644 index 1f2cc1a..0000000 --- a/pkg/filesystem/manage_test.go +++ /dev/null @@ -1,848 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - "os" - "testing" - - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - testMock "github.com/stretchr/testify/mock" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_ListPhysical(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{ - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{Type: "mock"}, - } - ctx := context.Background() - - // 未知存储策略 - { - fs.Policy.Type = "unknown" - res, err := fs.ListPhysical(ctx, "/") - asserts.Equal(ErrUnknownPolicyType, err) - asserts.Empty(res) - fs.Policy.Type = "mock" - } - - // 无法列取目录 - { - testHandler := new(FileHeaderMock) - testHandler.On("List", testMock.Anything, "/", testMock.Anything).Return([]response.Object{}, errors.New("error")) - fs.Handler = testHandler - res, err := fs.ListPhysical(ctx, "/") - asserts.EqualError(err, "error") - asserts.Empty(res) - } - - // 成功 - { - testHandler := new(FileHeaderMock) - testHandler.On("List", testMock.Anything, "/", testMock.Anything).Return( - []response.Object{{IsDir: true, Name: "1"}, {IsDir: false, Name: "2"}}, - nil, - ) - fs.Handler = testHandler - res, err := fs.ListPhysical(ctx, "/") - asserts.NoError(err) - asserts.Len(res, 1) - asserts.Equal("1", res[0].Name) - } -} - -func TestFileSystem_List(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - ctx := context.Background() - - // 成功,子目录包含文件和路径,不使用路径处理钩子 - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(5, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2")) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt")) - objects, err := fs.List(ctx, "/folder", nil) - asserts.Len(objects, 4) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功,子目录包含文件和路径,不使用路径处理钩子,包含分享key - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(5, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2")) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt")) - ctxWithKey := context.WithValue(ctx, fsctx.ShareKeyCtx, "share") - objects, err = fs.List(ctxWithKey, "/folder", nil) - asserts.Len(objects, 4) - asserts.Equal("share", objects[3].Key) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功,子目录包含文件和路径,使用路径处理钩子 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"}).AddRow(6, "sub_folder1", "/folder").AddRow(7, "sub_folder2", "/folder")) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder")) - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 4) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - for _, value := range objects { - asserts.Contains(value.Path, "prefix/") - } - - // 成功,子目录包含路径,使用路径处理钩子 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"})) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder")) - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 2) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - for _, value := range objects { - asserts.Contains(value.Path, "prefix/") - } - - // 成功,子目录下为空,使用路径处理钩子 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1)) - - mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"})) - mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"})) - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 0) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功,子目录路径不存在 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1)) - // folder - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "folder"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"})) - - objects, err = fs.List(ctx, "/folder", func(s string) string { - return "prefix" + s - }) - asserts.Len(objects, 0) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_CreateDirectory(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - ctx := context.Background() - - // 目录名非法 - _, err := fs.CreateDirectory(ctx, "/ad/a+?") - asserts.Equal(ErrIllegalObjectName, err) - - // 存在同名文件 - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "ab")) - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.Equal(ErrFileExisted, err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 存在同名目录,直接返回 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // ab - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ab", 2, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - res, err := fs.CreateDirectory(ctx, "/ad/ab") - asserts.NoError(err) - asserts.EqualValues(3, res.ID) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功创建 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ab", 2, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 成功创建, 递归创建父目录 - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // 创建ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ad", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // 创建ab - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ab", 2, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 底层创建失败 - // 成功创建 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "ad"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - // 创建ad - mock.ExpectQuery("SELECT(.+)"). - WithArgs("ad", 1, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)).WillReturnError(errors.New("error")) - mock.ExpectRollback() - mock.ExpectQuery("SELECT(.+)"). - WillReturnError(errors.New("error")) - _, err = fs.CreateDirectory(ctx, "/ad/ab") - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 直接创建根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - _, err = fs.CreateDirectory(ctx, "/") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - - // 直接创建根目录, 重设根目录 - fs.Root = &model.Folder{} - _, err = fs.CreateDirectory(ctx, "/") - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestFileSystem_ListDeleteFiles(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt").AddRow(2, "2.txt")) - err := fs.ListDeleteFiles(context.Background(), []uint{1}) - asserts.NoError(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - err := fs.ListDeleteFiles(context.Background(), []uint{1}) - asserts.Error(err) - asserts.Equal(serializer.CodeDBError, err.(serializer.AppError).Code) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_ListDeleteDirs(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 成功 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "1.txt"). - AddRow(5, "2.txt"). - AddRow(6, "3.txt"), - ) - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.NoError(err) - asserts.Len(fs.FileTarget, 3) - asserts.Len(fs.DirTarget, 3) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 成功,忽略根目录 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, nil). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(1, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name"}). - AddRow(4, "1.txt"). - AddRow(5, "2.txt"). - AddRow(6, "3.txt"), - ) - fs.CleanTargets() - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.NoError(err) - asserts.Len(fs.FileTarget, 3) - asserts.Len(fs.DirTarget, 2) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 检索文件发生错误 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnError(errors.New("error")) - fs.CleanTargets() - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.Error(err) - asserts.Len(fs.DirTarget, 3) - asserts.NoError(mock.ExpectationsWereMet()) - } - // 检索目录发生错误 - { - mock.ExpectQuery("SELECT(.+)"). - WillReturnError(errors.New("error")) - err := fs.ListDeleteDirs(context.Background(), []uint{1}) - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_Delete(t *testing.T) { - conf.DatabaseConfig.Type = "mysql" - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 0, - }, - Storage: 3, - Group: model.Group{MaxStorage: 3}, - }} - ctx := context.Background() - - //全部未成功,强制 - { - fs.CleanTargets() - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}). - AddRow(4, "1.txt", "1.txt", 365, 1), - ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2)) - // 两次查询软连接 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - // 查询上传策略 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(365, "local")) - // 删除文件记录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除目录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - - fs.FileTarget = []model.File{} - fs.DirTarget = []model.Folder{} - err := fs.Delete(ctx, []uint{1}, []uint{1}, true, false) - asserts.NoError(err) - } - //全部成功 - { - fs.CleanTargets() - file, err := os.Create(util.RelativePath("1.txt")) - file2, err := os.Create(util.RelativePath("2.txt")) - file.Close() - file2.Close() - asserts.NoError(err) - mock.ExpectQuery("SELECT(.+)"). - WillReturnRows( - sqlmock.NewRows([]string{"id", "parent_id"}). - AddRow(1, 0). - AddRow(2, 0). - AddRow(3, 0), - ) - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 2, 3). - WillReturnRows( - sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}). - AddRow(4, "1.txt", "1.txt", 602, 1), - ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2)) - // 两次查询软连接 - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - mock.ExpectQuery("SELECT(.+)files(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) - // 查询上传策略 - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(602, "local")) - // 删除文件记录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除目录 - mock.ExpectBegin() - mock.ExpectExec("DELETE(.+)"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - // 删除对应分享 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)shares"). - WillReturnResult(sqlmock.NewResult(0, 3)) - mock.ExpectCommit() - - fs.FileTarget = []model.File{} - fs.DirTarget = []model.Folder{} - err = fs.Delete(ctx, []uint{1}, []uint{1}, false, false) - asserts.NoError(err) - } - -} - -func TestFileSystem_Copy(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Storage: 3, - Group: model.Group{MaxStorage: 3}, - }} - ctx := context.Background() - - // 目录不存在 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows( - sqlmock.NewRows([]string{"name"}), - ) - mock.ExpectQuery("SELECT(.+)").WillReturnRows( - sqlmock.NewRows([]string{"name"}), - ) - err := fs.Copy(ctx, []uint{}, []uint{}, "/src", "/dst") - asserts.Equal(ErrPathNotExist, err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 复制目录出错 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "dst"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "src"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - - err := fs.Copy(ctx, []uint{1}, []uint{}, "/src", "/dst") - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } - -} - -func TestFileSystem_Move(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_1", uint64(0), 0) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - Storage: 3, - Group: model.Group{MaxStorage: 3}, - }} - ctx := context.Background() - - // 目录不存在 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows( - sqlmock.NewRows([]string{"name"}), - ) - err := fs.Move(ctx, []uint{}, []uint{}, "/src", "/dst") - asserts.Equal(ErrPathNotExist, err) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 移动目录出错 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "dst"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "src"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - err := fs.Move(ctx, []uint{1}, []uint{}, "/src", "/dst") - asserts.Error(err) - asserts.NoError(mock.ExpectationsWereMet()) - } -} - -func TestFileSystem_Rename(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{}, - } - ctx := context.Background() - - // 重命名文件 成功 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)"). - WithArgs(sqlmock.AnyArg(), "new.txt", sqlmock.AnyArg(), 10). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 重命名文件 不存在 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } - - // 重命名文件 失败 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)"). - WithArgs(sqlmock.AnyArg(), "new.txt", sqlmock.AnyArg(), 10). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrFileExisted, err) - } - - // 重命名目录 成功 - { - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("new", 10). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - } - - // 重命名目录 不存在 - { - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } - - // 重命名目录 失败 - { - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old")) - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)"). - WithArgs("new", 10). - WillReturnError(errors.New("error")) - mock.ExpectRollback() - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrFileExisted, err) - } - - // 未选中任何对象 - { - err := fs.Rename(ctx, []uint{}, []uint{}, "new") - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } - - // 新名字是目录,不合法 - { - err := fs.Rename(ctx, []uint{10}, []uint{}, "ne/w") - asserts.Error(err) - asserts.Equal(ErrIllegalObjectName, err) - } - - // 新名字是文件,不合法 - { - err := fs.Rename(ctx, []uint{}, []uint{10}, "ne/w") - asserts.Error(err) - asserts.Equal(ErrIllegalObjectName, err) - } - - // 新名字是文件,扩展名不合法 - { - fs.Policy.OptionsSerialized.FileType = []string{"txt"} - err := fs.Rename(ctx, []uint{}, []uint{10}, "1.jpg") - asserts.Error(err) - asserts.Equal(ErrIllegalObjectName, err) - } - - // 新名字是目录,不应该检测扩展名 - { - fs.Policy.OptionsSerialized.FileType = []string{"txt"} - mock.ExpectQuery("SELECT(.+)folders(.+)"). - WithArgs(10, 1). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - err := fs.Rename(ctx, []uint{10}, []uint{}, "new") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Equal(ErrPathNotExist, err) - } -} - -func TestFileSystem_SaveTo(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - ctx := context.Background() - - // 单文件 失败 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - fs.SetTargetFile(&[]model.File{{Name: "test.txt"}}) - err := fs.SaveTo(ctx, "/") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - // 目录 成功 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - fs.SetTargetDir(&[]model.Folder{{Name: "folder"}}) - err := fs.SaveTo(ctx, "/") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } - // 父目录不存在 - { - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - fs.SetTargetDir(&[]model.Folder{{Name: "folder"}}) - err := fs.SaveTo(ctx, "/") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - } -} diff --git a/pkg/filesystem/path.go b/pkg/filesystem/path.go index b0637aa..056c0c8 100644 --- a/pkg/filesystem/path.go +++ b/pkg/filesystem/path.go @@ -15,13 +15,20 @@ import ( // IsPathExist 返回给定目录是否存在 // 如果存在就返回目录 func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) { + tracedEnd, currentFolder := fs.getClosedParent(path) + if tracedEnd { + return true, currentFolder + } + return false, nil +} + +func (fs *FileSystem) getClosedParent(path string) (bool, *model.Folder) { pathList := util.SplitPath(path) if len(pathList) == 0 { return false, nil } // 递归步入目录 - // TODO:测试新增 var currentFolder *model.Folder // 如果已设定跟目录对象,则从给定目录向下遍历 @@ -42,10 +49,12 @@ func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) { return false, nil } } else { - currentFolder, err = currentFolder.GetChild(folderName) + nextFolder, err := currentFolder.GetChild(folderName) if err != nil { - return false, nil + return false, currentFolder } + + currentFolder = nextFolder } } diff --git a/pkg/filesystem/path_test.go b/pkg/filesystem/path_test.go deleted file mode 100644 index e4065a4..0000000 --- a/pkg/filesystem/path_test.go +++ /dev/null @@ -1,172 +0,0 @@ -package filesystem - -import ( - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestFileSystem_IsFileExist(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 存在 - { - path := "/1.txt" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "1.txt").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt")) - exist, file := fs.IsFileExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(1), file.ID) - } - - // 文件不存在 - { - path := "/1.txt" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)").WithArgs(1, "1.txt").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) - exist, _ := fs.IsFileExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(exist) - } - - // 父目录不存在 - { - path := "/1.txt" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - exist, _ := fs.IsFileExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(exist) - } -} - -func TestFileSystem_IsPathExist(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - - // 查询根目录 - { - path := "/" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(1), folder.ID) - } - - // 深层路径 - { - path := "/1/2/3" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "1"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 2 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1, "2"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - // 3 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(3, 1, "3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(4, 1)) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(4), folder.ID) - } - - // 深层路径 重设根目录为/1 - { - path := "/2/3" - fs.Root = &model.Folder{Name: "1", Model: gorm.Model{ID: 2}, OwnerID: 1} - // 2 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1, "2"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - // 3 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(3, 1, "3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(4, 1)) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal(uint(4), folder.ID) - fs.Root = nil - } - - // 深层 不存在 - { - path := "/1/2/3" - // 根目录 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - // 1 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, 1, "1"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1)) - // 2 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(2, 1, "2"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1)) - // 3 - mock.ExpectQuery("SELECT(.+)"). - WithArgs(3, 1, "3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"})) - exist, folder := fs.IsPathExist(path) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(exist) - asserts.Nil(folder) - } - -} - -func TestFileSystem_IsChildFileExist(t *testing.T) { - asserts := assert.New(t) - fs := &FileSystem{User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }} - folder := model.Folder{ - Model: gorm.Model{ID: 1}, - Name: "123", - Position: "/", - } - - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1, "321"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "321")) - exist, childFile := fs.IsChildFileExist(&folder, "321") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.True(exist) - asserts.Equal("/123", childFile.Position) -} diff --git a/pkg/filesystem/relocate.go b/pkg/filesystem/relocate.go new file mode 100755 index 0000000..673c61b --- /dev/null +++ b/pkg/filesystem/relocate.go @@ -0,0 +1,102 @@ +package filesystem + +import ( + "context" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +/* ================ + 存储策略迁移 + ================ +*/ + +// Relocate 将目标文件转移到当前存储策略下 +func (fs *FileSystem) Relocate(ctx context.Context, files []model.File, policy *model.Policy) error { + // 重设存储策略为要转移的目的策略 + fs.Policy = policy + if err := fs.DispatchHandler(); err != nil { + return err + } + + // 将目前文件根据存储策略分组 + fileGroup := fs.GroupFileByPolicy(ctx, files) + + // 按照存储策略分组处理每个文件 + for _, fileList := range fileGroup { + // 如果存储策略一样,则跳过 + if fileList[0].GetPolicy().ID == fs.Policy.ID { + util.Log().Debug("Skip relocating %d file(s), since they are already in desired policy.", + len(fileList)) + continue + } + + // 获取当前存储策略的处理器 + currentPolicy, _ := model.GetPolicyByID(fileList[0].PolicyID) + currentHandler, err := getNewPolicyHandler(¤tPolicy) + if err != nil { + return err + } + + // 记录转移完毕需要删除的文件 + toBeDeleted := make([]model.File, 0, len(fileList)) + + // 循环处理每一个文件 + // for id, r := 0, len(fileList); id < r; id++ { + for id, _ := range fileList { + // 验证文件是否符合新存储策略的规定 + if err := HookValidateFile(ctx, fs, fileList[id]); err != nil { + util.Log().Debug("File %q failed to pass validators in new policy %q, skipping.", + fileList[id].Name, err) + continue + } + + // 为文件生成新存储策略下的物理路径 + savePath := fs.GenerateSavePath(ctx, fileList[id]) + + // 获取原始文件 + src, err := currentHandler.Get(ctx, fileList[id].SourceName) + if err != nil { + util.Log().Debug("Failed to get file %q: %s, skipping.", + fileList[id].Name, err) + continue + } + + // 转存到新的存储策略 + if err := fs.Handler.Put(ctx, &fsctx.FileStream{ + File: src, + SavePath: savePath, + Size: fileList[id].Size, + }); err != nil { + util.Log().Debug("Failed to migrate file %q: %s, skipping.", + fileList[id].Name, err) + continue + } + + toBeDeleted = append(toBeDeleted, *fileList[id]) + + // 更新文件信息 + fileList[id].Relocate(savePath, fs.Policy.ID) + } + + // 排除带有软链接的文件 + toBeDeletedClean, err := model.RemoveFilesWithSoftLinks(toBeDeleted) + if err != nil { + util.Log().Warning("Failed to check soft links: %s", err) + } + + deleteSourceNames := make([]string, 0, len(toBeDeleted)) + for i := 0; i < len(toBeDeletedClean); i++ { + deleteSourceNames = append(deleteSourceNames, toBeDeletedClean[i].SourceName) + } + + // 删除原始策略中的文件 + if _, err := currentHandler.Delete(ctx, deleteSourceNames); err != nil { + util.Log().Warning("Cannot delete files in origin policy after relocating: %s", err) + } + } + + return nil +} diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go index 08dde53..bb8844d 100644 --- a/pkg/filesystem/upload.go +++ b/pkg/filesystem/upload.go @@ -194,17 +194,15 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS // UploadFromStream 从文件流上传文件 func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error { + // 给文件系统分配钩子 + fs.Lock.Lock() if resetPolicy { - // 重设存储策略 - fs.Policy = &fs.User.Policy - err := fs.DispatchHandler() + err := fs.SetPolicyFromPath(file.VirtualPath) if err != nil { return err } } - // 给文件系统分配钩子 - fs.Lock.Lock() if fs.Hooks == nil { fs.Use("BeforeUpload", HookValidateFile) fs.Use("BeforeUpload", HookValidateCapacity) diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go deleted file mode 100644 index 61dad9f..0000000 --- a/pkg/filesystem/upload_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package filesystem - -import ( - "context" - "errors" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" - "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" -) - -type FileHeaderMock struct { - testMock.Mock -} - -func (m FileHeaderMock) Put(ctx context.Context, file fsctx.FileHeader) error { - args := m.Called(ctx, file) - return args.Error(0) -} - -func (m FileHeaderMock) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) { - args := m.Called(ctx, ttl, uploadSession, file) - return args.Get(0).(*serializer.UploadCredential), args.Error(1) -} - -func (m FileHeaderMock) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error { - args := m.Called(ctx, uploadSession) - return args.Error(0) -} - -func (m FileHeaderMock) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) { - args := m.Called(ctx, path, recursive) - return args.Get(0).([]response.Object), args.Error(1) -} - -func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) { - args := m.Called(ctx, path) - return args.Get(0).(response.RSCloser), args.Error(1) -} - -func (m FileHeaderMock) Delete(ctx context.Context, files []string) ([]string, error) { - args := m.Called(ctx, files) - return args.Get(0).([]string), args.Error(1) -} - -func (m FileHeaderMock) Thumb(ctx context.Context, files *model.File) (*response.ContentResponse, error) { - args := m.Called(ctx, files) - return args.Get(0).(*response.ContentResponse), args.Error(1) -} - -func (m FileHeaderMock) Source(ctx context.Context, path string, expires int64, isDownload bool, speed int) (string, error) { - args := m.Called(ctx, path, expires, isDownload, speed) - return args.Get(0).(string), args.Error(1) -} - -func TestFileSystem_Upload(t *testing.T) { - asserts := assert.New(t) - - // 正常 - testHandler := new(FileHeaderMock) - testHandler.On("Put", testMock.Anything, testMock.Anything, testMock.Anything).Return(nil) - fs := &FileSystem{ - Handler: testHandler, - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{ - AutoRename: false, - DirNameRule: "{path}", - }, - } - ctx, cancel := context.WithCancel(context.Background()) - c, _ := gin.CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - cancel() - file := &fsctx.FileStream{ - Size: 5, - VirtualPath: "/", - Name: "1.txt", - } - err := fs.Upload(ctx, file) - asserts.NoError(err) - - // 正常,上下文已指定源文件 - testHandler = new(FileHeaderMock) - testHandler.On("Put", testMock.Anything, testMock.Anything).Return(nil) - fs = &FileSystem{ - Handler: testHandler, - User: &model.User{ - Model: gorm.Model{ - ID: 1, - }, - }, - Policy: &model.Policy{ - AutoRename: false, - DirNameRule: "{path}", - }, - } - ctx, cancel = context.WithCancel(context.Background()) - c, _ = gin.CreateTestContext(httptest.NewRecorder()) - c.Request, _ = http.NewRequest("POST", "/", nil) - ctx = context.WithValue(ctx, fsctx.GinCtx, c) - ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{SourceName: "123/123.txt"}) - cancel() - file = &fsctx.FileStream{ - Size: 5, - VirtualPath: "/", - Name: "1.txt", - File: ioutil.NopCloser(strings.NewReader("")), - } - err = fs.Upload(ctx, file) - asserts.NoError(err) - - // BeforeUpload 返回错误 - fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - err = fs.Upload(ctx, file) - asserts.Error(err) - fs.Hooks["BeforeUpload"] = nil - testHandler.AssertExpectations(t) - - // 上传文件失败 - testHandler2 := new(FileHeaderMock) - testHandler2.On("Put", testMock.Anything, testMock.Anything).Return(errors.New("error")) - fs.Handler = testHandler2 - err = fs.Upload(ctx, file) - asserts.Error(err) - testHandler2.AssertExpectations(t) - - // AfterUpload失败 - testHandler3 := new(FileHeaderMock) - testHandler3.On("Put", testMock.Anything, testMock.Anything).Return(nil) - fs.Handler = testHandler3 - fs.Use("AfterUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - fs.Use("AfterValidateFailed", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error { - return errors.New("error") - }) - err = fs.Upload(ctx, file) - asserts.Error(err) - testHandler2.AssertExpectations(t) - -} - -func TestFileSystem_GetUploadToken(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{Model: gorm.Model{ID: 1}}, - Policy: &model.Policy{}, - } - ctx := context.Background() - - // 成功 - { - cache.SetSettings(map[string]string{ - "upload_session_timeout": "10", - }, "setting_") - testHandler := new(FileHeaderMock) - testHandler.On("Token", testMock.Anything, int64(10), testMock.Anything, testMock.Anything).Return(&serializer.UploadCredential{Credential: "test"}, nil) - fs.Handler = testHandler - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - res, err := fs.CreateUploadSession(ctx, &fsctx.FileStream{ - Size: 0, - Name: "file", - VirtualPath: "/", - }) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.NoError(err) - asserts.Equal("test", res.Credential) - } - - // 无法获取上传凭证 - { - cache.SetSettings(map[string]string{ - "upload_credential_timeout": "10", - "upload_session_timeout": "10", - }, "setting_") - testHandler := new(FileHeaderMock) - testHandler.On("Token", testMock.Anything, int64(10), testMock.Anything, testMock.Anything).Return(&serializer.UploadCredential{}, errors.New("error")) - fs.Handler = testHandler - mock.ExpectQuery("SELECT(.+)"). - WithArgs(1). - WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1)) - mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found")) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err := fs.CreateUploadSession(ctx, &fsctx.FileStream{ - Size: 0, - Name: "file", - VirtualPath: "/", - }) - asserts.NoError(mock.ExpectationsWereMet()) - testHandler.AssertExpectations(t) - asserts.Error(err) - } -} - -func TestFileSystem_UploadFromStream(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - Policy: model.Policy{Type: "mock"}, - }, - Policy: &model.Policy{Type: "mock"}, - } - ctx := context.Background() - - err := fs.UploadFromStream(ctx, &fsctx.FileStream{ - File: ioutil.NopCloser(strings.NewReader("123")), - }, true) - asserts.Error(err) -} - -func TestFileSystem_UploadFromPath(t *testing.T) { - asserts := assert.New(t) - fs := FileSystem{ - User: &model.User{ - Model: gorm.Model{ID: 1}, - Policy: model.Policy{Type: "mock"}, - }, - Policy: &model.Policy{Type: "mock"}, - } - ctx := context.Background() - - // 文件不存在 - { - err := fs.UploadFromPath(ctx, "test/not_exist", "/", fsctx.Overwrite) - asserts.Error(err) - } - - // 文存在,上传失败 - { - err := fs.UploadFromPath(ctx, "tests/test.zip", "/", fsctx.Overwrite) - asserts.Error(err) - } -} diff --git a/pkg/filesystem/validator_test.go b/pkg/filesystem/validator_test.go deleted file mode 100644 index 8f685f2..0000000 --- a/pkg/filesystem/validator_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package filesystem - -import ( - "context" - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestFileSystem_ValidateLegalName(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{} - asserts.True(fs.ValidateLegalName(ctx, "1.txt")) - asserts.True(fs.ValidateLegalName(ctx, "1-1.txt")) - asserts.True(fs.ValidateLegalName(ctx, "1?1.txt")) - asserts.False(fs.ValidateLegalName(ctx, "1:1.txt")) - asserts.False(fs.ValidateLegalName(ctx, "../11.txt")) - asserts.False(fs.ValidateLegalName(ctx, "/11.txt")) - asserts.False(fs.ValidateLegalName(ctx, "\\11.txt")) - asserts.False(fs.ValidateLegalName(ctx, "")) - asserts.False(fs.ValidateLegalName(ctx, "1.tx t ")) - asserts.True(fs.ValidateLegalName(ctx, "1.tx t")) -} - -func TestFileSystem_ValidateCapacity(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - cache.Set("pack_size_0", uint64(0), 0) - fs := FileSystem{ - User: &model.User{ - Storage: 10, - Group: model.Group{ - MaxStorage: 11, - }, - }, - } - - asserts.True(fs.ValidateCapacity(ctx, 1)) - asserts.Equal(uint64(11), fs.User.Storage) - - fs.User.Storage = 5 - asserts.False(fs.ValidateCapacity(ctx, 10)) - asserts.Equal(uint64(5), fs.User.Storage) -} - -func TestFileSystem_ValidateFileSize(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - Policy: &model.Policy{ - MaxSize: 10, - }, - } - - asserts.True(fs.ValidateFileSize(ctx, 5)) - asserts.True(fs.ValidateFileSize(ctx, 10)) - asserts.False(fs.ValidateFileSize(ctx, 11)) - - // 无限制 - fs.Policy.MaxSize = 0 - asserts.True(fs.ValidateFileSize(ctx, 11)) -} - -func TestFileSystem_ValidateExtension(t *testing.T) { - asserts := assert.New(t) - ctx := context.Background() - fs := FileSystem{ - User: &model.User{}, - Policy: &model.Policy{ - OptionsSerialized: model.PolicyOption{ - FileType: nil, - }, - }, - } - - asserts.True(fs.ValidateExtension(ctx, "1")) - asserts.True(fs.ValidateExtension(ctx, "1.txt")) - - fs.Policy.OptionsSerialized.FileType = []string{} - asserts.True(fs.ValidateExtension(ctx, "1")) - asserts.True(fs.ValidateExtension(ctx, "1.txt")) - - fs.Policy.OptionsSerialized.FileType = []string{"txt", "jpg"} - asserts.False(fs.ValidateExtension(ctx, "1")) - asserts.False(fs.ValidateExtension(ctx, "1.jpg.png")) - asserts.True(fs.ValidateExtension(ctx, "1.txt")) - asserts.True(fs.ValidateExtension(ctx, "1.png.jpg")) - asserts.True(fs.ValidateExtension(ctx, "1.png.jpG")) - asserts.False(fs.ValidateExtension(ctx, "1.png")) -} diff --git a/pkg/hashid/hash_test.go b/pkg/hashid/hash_test.go deleted file mode 100644 index 5471d9e..0000000 --- a/pkg/hashid/hash_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package hashid - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestHashEncode(t *testing.T) { - asserts := assert.New(t) - - { - res, err := HashEncode([]int{1, 2, 3}) - asserts.NoError(err) - asserts.NotEmpty(res) - } - - { - res, err := HashEncode([]int{}) - asserts.Error(err) - asserts.Empty(res) - } - -} - -func TestHashID(t *testing.T) { - asserts := assert.New(t) - - { - res := HashID(1, ShareID) - asserts.NotEmpty(res) - } -} - -func TestHashDecode(t *testing.T) { - asserts := assert.New(t) - - // 正常 - { - res, _ := HashEncode([]int{1, 2, 3}) - decodeRes, err := HashDecode(res) - asserts.NoError(err) - asserts.Equal([]int{1, 2, 3}, decodeRes) - } - - // 出错 - { - decodeRes, err := HashDecode("233") - asserts.Error(err) - asserts.Len(decodeRes, 0) - } -} - -func TestDecodeHashID(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - uid, err := DecodeHashID(HashID(1, ShareID), ShareID) - asserts.NoError(err) - asserts.EqualValues(1, uid) - } - - // 类型不匹配 - { - uid, err := DecodeHashID(HashID(1, ShareID), UserID) - asserts.Error(err) - asserts.EqualValues(0, uid) - } -} diff --git a/pkg/mocks/remoteclientmock/mock.go b/pkg/mocks/remoteclientmock/mock.go index 303b673..f036541 100644 --- a/pkg/mocks/remoteclientmock/mock.go +++ b/pkg/mocks/remoteclientmock/mock.go @@ -2,6 +2,7 @@ package remoteclientmock import ( "context" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/stretchr/testify/mock" diff --git a/pkg/mq/mq_test.go b/pkg/mq/mq_test.go deleted file mode 100644 index 9acdd3f..0000000 --- a/pkg/mq/mq_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package mq - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/stretchr/testify/assert" - "sync" - "testing" - "time" -) - -func TestPublishAndSubscribe(t *testing.T) { - t.Parallel() - asserts := assert.New(t) - mq := NewMQ() - - // No subscriber - { - asserts.NotPanics(func() { - mq.Publish("No subscriber", Message{}) - }) - } - - // One channel subscriber - { - topic := "One channel subscriber" - msg := Message{TriggeredBy: "Tester"} - notifier := mq.Subscribe(topic, 0) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - msgRecv := <-notifier - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - } - - // two channel subscriber - { - topic := "two channel subscriber" - msg := Message{TriggeredBy: "Tester"} - notifier := mq.Subscribe(topic, 0) - notifier2 := mq.Subscribe(topic, 0) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(2) - go func() { - wg.Done() - msgRecv := <-notifier - asserts.Equal(msg, msgRecv) - }() - go func() { - wg.Done() - msgRecv := <-notifier2 - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - } - - // two channel subscriber, one timeout - { - topic := "two channel subscriber, one timeout" - msg := Message{TriggeredBy: "Tester"} - mq.Subscribe(topic, 0) - notifier2 := mq.Subscribe(topic, 0) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - msgRecv := <-notifier2 - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - } - - // two channel subscriber, one unsubscribe - { - topic := "two channel subscriber, one unsubscribe" - msg := Message{TriggeredBy: "Tester"} - mq.Subscribe(topic, 0) - notifier2 := mq.Subscribe(topic, 0) - notifier := mq.Subscribe(topic, 0) - mq.Unsubscribe(topic, notifier) - mq.Publish(topic, msg) - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - msgRecv := <-notifier2 - asserts.Equal(msg, msgRecv) - }() - wg.Wait() - - select { - case <-notifier: - t.Error() - default: - } - } -} - -func TestAria2Interface(t *testing.T) { - t.Parallel() - asserts := assert.New(t) - mq := NewMQ() - var ( - OnDownloadStart int - OnDownloadPause int - OnDownloadStop int - OnDownloadComplete int - OnDownloadError int - ) - l := sync.Mutex{} - - mq.SubscribeCallback("TestAria2Interface", func(message Message) { - asserts.Equal("TestAria2Interface", message.TriggeredBy) - l.Lock() - defer l.Unlock() - switch message.Event { - case "1": - OnDownloadStart++ - case "2": - OnDownloadPause++ - case "5": - OnDownloadStop++ - case "4": - OnDownloadComplete++ - case "3": - OnDownloadError++ - } - }) - - mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}}) - - time.Sleep(time.Duration(500) * time.Millisecond) - - asserts.Equal(2, OnDownloadStart) - asserts.Equal(2, OnDownloadPause) - asserts.Equal(2, OnDownloadStop) - asserts.Equal(4, OnDownloadComplete) - asserts.Equal(2, OnDownloadError) -} diff --git a/pkg/payment/alipay.go b/pkg/payment/alipay.go new file mode 100755 index 0000000..a08f45e --- /dev/null +++ b/pkg/payment/alipay.go @@ -0,0 +1,43 @@ +package payment + +import ( + "fmt" + "net/url" + + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + alipay "github.com/smartwalle/alipay/v3" +) + +// Alipay 支付宝当面付支付处理 +type Alipay struct { + Client *alipay.Client +} + +// Create 创建订单 +func (pay *Alipay) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) { + gateway, _ := url.Parse("/api/v3/callback/alipay") + var p = alipay.TradePreCreate{ + Trade: alipay.Trade{ + NotifyURL: model.GetSiteURL().ResolveReference(gateway).String(), + Subject: order.Name, + OutTradeNo: order.OrderNo, + TotalAmount: fmt.Sprintf("%.2f", float64(order.Price*order.Num)/100), + }, + } + + if _, err := order.Create(); err != nil { + return nil, ErrInsertOrder.WithError(err) + } + + res, err := pay.Client.TradePreCreate(p) + if err != nil { + return nil, ErrIssueOrder.WithError(err) + } + + return &OrderCreateRes{ + Payment: true, + QRCode: res.QRCode, + ID: order.OrderNo, + }, nil +} diff --git a/pkg/payment/custom.go b/pkg/payment/custom.go new file mode 100755 index 0000000..23c3748 --- /dev/null +++ b/pkg/payment/custom.go @@ -0,0 +1,93 @@ +package payment + +import ( + "encoding/json" + "errors" + "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gofrs/uuid" + "github.com/qiniu/go-sdk/v7/sms/bytes" + "net/http" + "net/url" +) + +// Custom payment client +type Custom struct { + client request.Client + endpoint string + authClient auth.Auth +} + +const ( + paymentTTL = 3600 * 24 // 24h + CallbackSessionPrefix = "custom_payment_callback_" +) + +func newCustomClient(endpoint, secret string) *Custom { + authClient := auth.HMACAuth{ + SecretKey: []byte(secret), + } + return &Custom{ + endpoint: endpoint, + authClient: auth.General, + client: request.NewClient( + request.WithCredential(authClient, paymentTTL), + request.WithMasterMeta(), + ), + } +} + +// Request body from Cloudreve to create a new payment +type NewCustomOrderRequest struct { + Name string `json:"name"` // Order name + OrderNo string `json:"order_no"` // Order number + NotifyURL string `json:"notify_url"` // Payment callback url + Amount int64 `json:"amount"` // Order total amount +} + +// Create a new payment +func (pay *Custom) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) { + callbackID := uuid.Must(uuid.NewV4()) + gateway, _ := url.Parse(fmt.Sprintf("/api/v3/callback/custom/%s/%s", order.OrderNo, callbackID)) + callback, err := auth.SignURI(pay.authClient, model.GetSiteURL().ResolveReference(gateway).String(), paymentTTL) + if err != nil { + return nil, fmt.Errorf("failed to sign callback url: %w", err) + } + + cache.Set(CallbackSessionPrefix+callbackID.String(), order.OrderNo, paymentTTL) + + body := &NewCustomOrderRequest{ + Name: order.Name, + OrderNo: order.OrderNo, + NotifyURL: callback.String(), + Amount: int64(order.Price * order.Num), + } + bodyJson, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to encode body: %w", err) + } + + res, err := pay.client.Request("POST", pay.endpoint, bytes.NewReader(bodyJson)). + CheckHTTPResponse(http.StatusOK).DecodeResponse() + if err != nil { + return nil, fmt.Errorf("failed to request payment gateway: %w", err) + } + + if res.Code != 0 { + return nil, errors.New(res.Error) + } + + if _, err := order.Create(); err != nil { + return nil, ErrInsertOrder.WithError(err) + } + + return &OrderCreateRes{ + Payment: true, + QRCode: res.Data.(string), + ID: order.OrderNo, + }, nil +} diff --git a/pkg/payment/order.go b/pkg/payment/order.go new file mode 100755 index 0000000..c161411 --- /dev/null +++ b/pkg/payment/order.go @@ -0,0 +1,171 @@ +package payment + +import ( + "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/iGoogle-ink/gopay/wechat/v3" + "github.com/qingwg/payjs" + "github.com/smartwalle/alipay/v3" + "math/rand" + "net/url" + "time" +) + +var ( + // ErrUnknownPaymentMethod 未知支付方式 + ErrUnknownPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "Unknown payment method", nil) + // ErrUnsupportedPaymentMethod 未知支付方式 + ErrUnsupportedPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "This order cannot be paid with this method", nil) + // ErrInsertOrder 无法插入订单记录 + ErrInsertOrder = serializer.NewError(serializer.CodeDBError, "Failed to insert order record", nil) + // ErrScoreNotEnough 积分不足 + ErrScoreNotEnough = serializer.NewError(serializer.CodeInsufficientCredit, "", nil) + // ErrCreateStoragePack 无法创建容量包 + ErrCreateStoragePack = serializer.NewError(serializer.CodeDBError, "Failed to create storage pack record", nil) + // ErrGroupConflict 用户组冲突 + ErrGroupConflict = serializer.NewError(serializer.CodeGroupConflict, "", nil) + // ErrGroupInvalid 用户组冲突 + ErrGroupInvalid = serializer.NewError(serializer.CodeGroupInvalid, "", nil) + // ErrAdminFulfillGroup 管理员无法购买用户组 + ErrAdminFulfillGroup = serializer.NewError(serializer.CodeFulfillAdminGroup, "", nil) + // ErrUpgradeGroup 用户组冲突 + ErrUpgradeGroup = serializer.NewError(serializer.CodeDBError, "Failed to update user's group", nil) + // ErrUInitPayment 无法初始化支付实例 + ErrUInitPayment = serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize payment client", nil) + // ErrIssueOrder 订单接口请求失败 + ErrIssueOrder = serializer.NewError(serializer.CodeInternalSetting, "Failed to create order", nil) + // ErrOrderNotFound 订单不存在 + ErrOrderNotFound = serializer.NewError(serializer.CodeNotFound, "", nil) +) + +// Pay 支付处理接口 +type Pay interface { + Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) +} + +// OrderCreateRes 订单创建结果 +type OrderCreateRes struct { + Payment bool `json:"payment"` // 是否需要支付 + ID string `json:"id,omitempty"` // 订单号 + QRCode string `json:"qr_code,omitempty"` // 支付二维码指向的地址 +} + +// NewPaymentInstance 获取新的支付实例 +func NewPaymentInstance(method string) (Pay, error) { + switch method { + case "score": + return &ScorePayment{}, nil + case "alipay": + options := model.GetSettingByNames("alipay_enabled", "appid", "appkey", "shopid") + if options["alipay_enabled"] != "1" { + return nil, ErrUnknownPaymentMethod + } + + // 初始化支付宝客户端 + var client, err = alipay.New(options["appid"], options["appkey"], true) + if err != nil { + return nil, ErrUInitPayment.WithError(err) + } + + // 加载支付宝公钥 + err = client.LoadAliPayPublicKey(options["shopid"]) + if err != nil { + return nil, ErrUInitPayment.WithError(err) + } + + return &Alipay{Client: client}, nil + case "payjs": + options := model.GetSettingByNames("payjs_enabled", "payjs_secret", "payjs_id") + if options["payjs_enabled"] != "1" { + return nil, ErrUnknownPaymentMethod + } + + callback, _ := url.Parse("/api/v3/callback/payjs") + payjsConfig := &payjs.Config{ + Key: options["payjs_secret"], + MchID: options["payjs_id"], + NotifyUrl: model.GetSiteURL().ResolveReference(callback).String(), + } + + return &PayJSClient{Client: payjs.New(payjsConfig)}, nil + case "wechat": + options := model.GetSettingByNames("wechat_enabled", "wechat_appid", "wechat_mchid", "wechat_serial_no", "wechat_api_key", "wechat_pk_content") + if options["wechat_enabled"] != "1" { + return nil, ErrUnknownPaymentMethod + } + client, err := wechat.NewClientV3(options["wechat_appid"], options["wechat_mchid"], options["wechat_serial_no"], options["wechat_api_key"], options["wechat_pk_content"]) + if err != nil { + return nil, ErrUInitPayment.WithError(err) + } + + return &Wechat{Client: client, ApiV3Key: options["wechat_api_key"]}, nil + case "custom": + options := model.GetSettingByNames("custom_payment_enabled", "custom_payment_endpoint", "custom_payment_secret") + if !model.IsTrueVal(options["custom_payment_enabled"]) { + return nil, ErrUnknownPaymentMethod + } + + return newCustomClient(options["custom_payment_endpoint"], options["custom_payment_secret"]), nil + default: + return nil, ErrUnknownPaymentMethod + } +} + +// NewOrder 创建新订单 +func NewOrder(pack *serializer.PackProduct, group *serializer.GroupProducts, num int, method string, user *model.User) (*OrderCreateRes, error) { + // 获取支付实例 + pay, err := NewPaymentInstance(method) + if err != nil { + return nil, err + } + + var ( + orderType int + productID int64 + title string + price int + ) + if pack != nil { + orderType = model.PackOrderType + productID = pack.ID + title = pack.Name + price = pack.Price + } else if group != nil { + if err := checkGroupUpgrade(user, group); err != nil { + return nil, err + } + + orderType = model.GroupOrderType + productID = group.ID + title = group.Name + price = group.Price + } else { + orderType = model.ScoreOrderType + productID = 0 + title = fmt.Sprintf("%d 积分", num) + price = model.GetIntSetting("score_price", 1) + } + + // 创建订单记录 + order := &model.Order{ + UserID: user.ID, + OrderNo: orderID(), + Type: orderType, + Method: method, + ProductID: productID, + Num: num, + Name: fmt.Sprintf("%s - %s", model.GetSettingByName("siteName"), title), + Price: price, + Status: model.OrderUnpaid, + } + + return pay.Create(order, pack, group, user) +} + +func orderID() string { + return fmt.Sprintf("%s%d", + time.Now().Format("20060102150405"), + 100000+rand.Intn(900000), + ) +} diff --git a/pkg/payment/payjs.go b/pkg/payment/payjs.go new file mode 100755 index 0000000..a82ce7b --- /dev/null +++ b/pkg/payment/payjs.go @@ -0,0 +1,31 @@ +package payment + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/qingwg/payjs" +) + +// PayJSClient PayJS支付处理 +type PayJSClient struct { + Client *payjs.PayJS +} + +// Create 创建订单 +func (pay *PayJSClient) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) { + if _, err := order.Create(); err != nil { + return nil, ErrInsertOrder.WithError(err) + } + + PayNative := pay.Client.GetNative() + res, err := PayNative.Create(int64(order.Price*order.Num), order.Name, order.OrderNo, "", "") + if err != nil { + return nil, ErrIssueOrder.WithError(err) + } + + return &OrderCreateRes{ + Payment: true, + QRCode: res.CodeUrl, + ID: order.OrderNo, + }, nil +} diff --git a/pkg/payment/purchase.go b/pkg/payment/purchase.go new file mode 100755 index 0000000..1c558a6 --- /dev/null +++ b/pkg/payment/purchase.go @@ -0,0 +1,137 @@ +package payment + +import ( + "encoding/json" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "strconv" + "time" +) + +// GivePack 创建容量包 +func GivePack(user *model.User, packInfo *serializer.PackProduct, num int) error { + timeNow := time.Now() + expires := timeNow.Add(time.Duration(packInfo.Time*int64(num)) * time.Second) + pack := model.StoragePack{ + Name: packInfo.Name, + UserID: user.ID, + ActiveTime: &timeNow, + ExpiredTime: &expires, + Size: packInfo.Size, + } + if _, err := pack.Create(); err != nil { + return ErrCreateStoragePack.WithError(err) + } + cache.Deletes([]string{strconv.FormatUint(uint64(user.ID), 10)}, "pack_size_") + return nil +} + +func checkGroupUpgrade(user *model.User, groupInfo *serializer.GroupProducts) error { + if user.Group.ID == 1 { + return ErrAdminFulfillGroup + } + + // 检查用户是否已有未过期用户 + if user.PreviousGroupID != 0 && user.GroupID != groupInfo.GroupID { + return ErrGroupConflict + } + + // 用户组不能相同 + if user.GroupID == groupInfo.GroupID && user.PreviousGroupID == 0 { + return ErrGroupInvalid + } + + return nil +} + +// GiveGroup 升级用户组 +func GiveGroup(user *model.User, groupInfo *serializer.GroupProducts, num int) error { + if err := checkGroupUpgrade(user, groupInfo); err != nil { + return err + } + + timeNow := time.Now() + expires := timeNow.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second) + if user.PreviousGroupID != 0 { + expires = user.GroupExpires.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second) + } + + if err := user.UpgradeGroup(groupInfo.GroupID, &expires); err != nil { + return ErrUpgradeGroup.WithError(err) + } + + return nil +} + +// GiveScore 积分充值 +func GiveScore(user *model.User, num int) error { + user.AddScore(num) + return nil +} + +// GiveProduct “发货” +func GiveProduct(user *model.User, pack *serializer.PackProduct, group *serializer.GroupProducts, num int) error { + if pack != nil { + return GivePack(user, pack, num) + } else if group != nil { + return GiveGroup(user, group, num) + } else { + return GiveScore(user, num) + } +} + +// OrderPaid 订单已支付处理 +func OrderPaid(orderNo string) error { + order, err := model.GetOrderByNo(orderNo) + if err != nil || order.Status == model.OrderPaid { + return ErrOrderNotFound.WithError(err) + } + + // 更新订单状态为 已支付 + order.UpdateStatus(model.OrderPaid) + + user, err := model.GetActiveUserByID(order.UserID) + if err != nil { + return serializer.NewError(serializer.CodeUserNotFound, "", err) + } + + // 查询商品 + options := model.GetSettingByNames("pack_data", "group_sell_data") + + var ( + packs []serializer.PackProduct + groups []serializer.GroupProducts + ) + if err := json.Unmarshal([]byte(options["pack_data"]), &packs); err != nil { + return err + } + if err := json.Unmarshal([]byte(options["group_sell_data"]), &groups); err != nil { + return err + } + + // 查找要购买的商品 + var ( + pack *serializer.PackProduct + group *serializer.GroupProducts + ) + if order.Type == model.GroupOrderType { + for _, v := range groups { + if v.ID == order.ProductID { + group = &v + break + } + } + } else if order.Type == model.PackOrderType { + for _, v := range packs { + if v.ID == order.ProductID { + pack = &v + break + } + } + } + + // "发货" + return GiveProduct(&user, pack, group, order.Num) + +} diff --git a/pkg/payment/score.go b/pkg/payment/score.go new file mode 100755 index 0000000..351c848 --- /dev/null +++ b/pkg/payment/score.go @@ -0,0 +1,45 @@ +package payment + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) + +// ScorePayment 积分支付处理 +type ScorePayment struct { +} + +// Create 创建新订单 +func (pay *ScorePayment) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) { + if pack != nil { + order.Price = pack.Score + } else { + order.Price = group.Score + } + + // 检查此订单是否可用积分支付 + if order.Price == 0 { + return nil, ErrUnsupportedPaymentMethod + } + + // 扣除用户积分 + if !user.PayScore(order.Price * order.Num) { + return nil, ErrScoreNotEnough + } + + // 商品“发货” + if err := GiveProduct(user, pack, group, order.Num); err != nil { + user.AddScore(order.Price * order.Num) + return nil, err + } + + // 创建订单记录 + order.Status = model.OrderPaid + if _, err := order.Create(); err != nil { + return nil, ErrInsertOrder.WithError(err) + } + + return &OrderCreateRes{ + Payment: false, + }, nil +} diff --git a/pkg/payment/wechat.go b/pkg/payment/wechat.go new file mode 100755 index 0000000..c9025c5 --- /dev/null +++ b/pkg/payment/wechat.go @@ -0,0 +1,88 @@ +package payment + +import ( + "errors" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/iGoogle-ink/gopay" + "github.com/iGoogle-ink/gopay/wechat/v3" + "net/url" + "time" +) + +// Wechat 微信扫码支付接口 +type Wechat struct { + Client *wechat.ClientV3 + ApiV3Key string +} + +// Create 创建订单 +func (pay *Wechat) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) { + gateway, _ := url.Parse("/api/v3/callback/wechat") + bm := make(gopay.BodyMap) + bm. + Set("description", order.Name). + Set("out_trade_no", order.OrderNo). + Set("notify_url", model.GetSiteURL().ResolveReference(gateway).String()). + SetBodyMap("amount", func(bm gopay.BodyMap) { + bm.Set("total", int64(order.Price*order.Num)). + Set("currency", "CNY") + }) + + wxRsp, err := pay.Client.V3TransactionNative(bm) + if err != nil { + return nil, ErrIssueOrder.WithError(err) + } + + if wxRsp.Code == wechat.Success { + if _, err := order.Create(); err != nil { + return nil, ErrInsertOrder.WithError(err) + } + + return &OrderCreateRes{ + Payment: true, + QRCode: wxRsp.Response.CodeUrl, + ID: order.OrderNo, + }, nil + } + + return nil, ErrIssueOrder.WithError(errors.New(wxRsp.Error)) +} + +// GetPlatformCert 获取微信平台证书 +func (pay *Wechat) GetPlatformCert() string { + if cert, ok := cache.Get("wechat_platform_cert"); ok { + return cert.(string) + } + + res, err := pay.Client.GetPlatformCerts() + if err == nil { + // 使用反馈证书中启用时间较晚的 + var ( + currentLatest *time.Time + currentCert string + ) + for _, cert := range res.Certs { + effectiveTime, err := time.Parse("2006-01-02T15:04:05-0700", cert.EffectiveTime) + if err != nil { + if currentLatest == nil { + currentLatest = &effectiveTime + currentCert = cert.PublicKey + continue + } + if currentLatest.Before(effectiveTime) { + currentLatest = &effectiveTime + currentCert = cert.PublicKey + } + } + } + + cache.Set("wechat_platform_cert", currentCert, 3600*10) + return currentCert + } + + util.Log().Debug("Failed to get Wechat Pay platform certificate: %s", err) + return "" +} diff --git a/pkg/qq/connect.go b/pkg/qq/connect.go new file mode 100755 index 0000000..b8cb72f --- /dev/null +++ b/pkg/qq/connect.go @@ -0,0 +1,211 @@ +package qq + +import ( + "crypto/md5" + "encoding/json" + "errors" + "fmt" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gofrs/uuid" + "net/url" + "strings" +) + +// LoginPage 登陆页面描述 +type LoginPage struct { + URL string + SecretKey string +} + +// UserCredentials 登陆成功后的凭证 +type UserCredentials struct { + OpenID string + AccessToken string +} + +// UserInfo 用户信息 +type UserInfo struct { + Nick string + Avatar string +} + +var ( + // ErrNotEnabled 未开启登录功能 + ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "QQ Login is not enabled", nil) + // ErrObtainAccessToken 无法获取AccessToken + ErrObtainAccessToken = serializer.NewError(serializer.CodeNotSet, "Cannot obtain AccessToken", nil) + // ErrObtainOpenID 无法获取OpenID + ErrObtainOpenID = serializer.NewError(serializer.CodeNotSet, "Cannot obtain OpenID", nil) + //ErrDecodeResponse 无法解析服务端响应 + ErrDecodeResponse = serializer.NewError(serializer.CodeNotSet, "Cannot parse serverside response", nil) +) + +// NewLoginRequest 新建登录会话 +func NewLoginRequest() (*LoginPage, error) { + // 获取相关设定 + options := model.GetSettingByNames("qq_login", "qq_login_id") + if options["qq_login"] == "0" { + return nil, ErrNotEnabled + } + + // 生成唯一ID + u2, err := uuid.NewV4() + if err != nil { + return nil, err + } + secret := fmt.Sprintf("%x", md5.Sum(u2.Bytes())) + + // 生成登录地址 + loginURL, _ := url.Parse("https://graph.qq.com/oauth2.0/authorize?response_type=code") + queries := loginURL.Query() + queries.Add("client_id", options["qq_login_id"]) + queries.Add("redirect_uri", getCallbackURL()) + queries.Add("state", secret) + loginURL.RawQuery = queries.Encode() + + return &LoginPage{ + URL: loginURL.String(), + SecretKey: secret, + }, nil +} + +func getCallbackURL() string { + //return "https://drive.aoaoao.me/Callback/QQ" + // 生成回调地址 + gateway, _ := url.Parse("/login/qq") + callback := model.GetSiteURL().ResolveReference(gateway).String() + + return callback +} + +func getAccessTokenURL(code string) string { + // 获取相关设定 + options := model.GetSettingByNames("qq_login_id", "qq_login_key") + + api, _ := url.Parse("https://graph.qq.com/oauth2.0/token?grant_type=authorization_code") + queries := api.Query() + queries.Add("client_id", options["qq_login_id"]) + queries.Add("redirect_uri", getCallbackURL()) + queries.Add("client_secret", options["qq_login_key"]) + queries.Add("code", code) + api.RawQuery = queries.Encode() + + return api.String() +} + +func getUserInfoURL(openid, ak string) string { + // 获取相关设定 + options := model.GetSettingByNames("qq_login_id", "qq_login_key") + + api, _ := url.Parse("https://graph.qq.com/user/get_user_info") + queries := api.Query() + queries.Add("oauth_consumer_key", options["qq_login_id"]) + queries.Add("openid", openid) + queries.Add("access_token", ak) + api.RawQuery = queries.Encode() + + return api.String() +} + +func getResponse(body string) (map[string]interface{}, error) { + var res map[string]interface{} + + if !strings.Contains(body, "callback") { + return res, nil + } + + body = strings.TrimPrefix(body, "callback(") + body = strings.TrimSuffix(body, ");\n") + + err := json.Unmarshal([]byte(body), &res) + + return res, err +} + +// Callback 处理回调,返回openid和access key +func Callback(code string) (*UserCredentials, error) { + // 获取相关设定 + options := model.GetSettingByNames("qq_login") + if options["qq_login"] == "0" { + return nil, ErrNotEnabled + } + + api := getAccessTokenURL(code) + + // 获取AccessToken + client := request.NewClient() + res := client.Request("GET", api, nil) + resp, err := res.GetResponse() + if err != nil { + return nil, ErrObtainAccessToken.WithError(err) + } + + // 如果服务端返回错误 + errResp, err := getResponse(resp) + if msg, ok := errResp["error_description"]; err == nil && ok { + return nil, ErrObtainAccessToken.WithError(errors.New(msg.(string))) + } + + // 获取AccessToken + vals, err := url.ParseQuery(resp) + if err != nil { + return nil, ErrDecodeResponse.WithError(err) + } + accessToken := vals.Get("access_token") + + // 用 AccessToken 换取OpenID + res = client.Request("GET", "https://graph.qq.com/oauth2.0/me?access_token="+accessToken, nil) + resp, err = res.GetResponse() + if err != nil { + return nil, ErrObtainOpenID.WithError(err) + } + + // 解析服务端响应 + errResp, err = getResponse(resp) + if msg, ok := errResp["error_description"]; err == nil && ok { + return nil, ErrObtainOpenID.WithError(errors.New(msg.(string))) + } + + if openid, ok := errResp["openid"]; ok { + return &UserCredentials{ + OpenID: openid.(string), + AccessToken: accessToken, + }, nil + } + + return nil, ErrDecodeResponse +} + +// GetUserInfo 使用凭证获取用户信息 +func GetUserInfo(credential *UserCredentials) (*UserInfo, error) { + api := getUserInfoURL(credential.OpenID, credential.AccessToken) + + // 获取用户信息 + client := request.NewClient() + res := client.Request("GET", api, nil) + resp, err := res.GetResponse() + if err != nil { + return nil, ErrObtainAccessToken.WithError(err) + } + + var resSerialized map[string]interface{} + if err := json.Unmarshal([]byte(resp), &resSerialized); err != nil { + return nil, ErrDecodeResponse.WithError(err) + } + + // 如果服务端返回错误 + if msg, ok := resSerialized["msg"]; ok && msg.(string) != "" { + return nil, ErrObtainAccessToken.WithError(errors.New(msg.(string))) + } + + if avatar, ok := resSerialized["figureurl_qq_2"]; ok { + return &UserInfo{ + Nick: resSerialized["nickname"].(string), + Avatar: avatar.(string), + }, nil + } + + return nil, ErrDecodeResponse +} diff --git a/pkg/recaptcha/recaptcha.go b/pkg/recaptcha/recaptcha.go index 75354bd..e360898 100644 --- a/pkg/recaptcha/recaptcha.go +++ b/pkg/recaptcha/recaptcha.go @@ -67,7 +67,8 @@ type ReCAPTCHA struct { } // NewReCAPTCHA new ReCAPTCHA instance if version is set to V2 uses recatpcha v2 API, get your secret from https://www.google.com/recaptcha/admin -// if version is set to V2 uses recatpcha v2 API, get your secret from https://g.co/recaptcha/v3 +// +// if version is set to V2 uses recatpcha v2 API, get your secret from https://g.co/recaptcha/v3 func NewReCAPTCHA(ReCAPTCHASecret string, version VERSION, timeout time.Duration) (ReCAPTCHA, error) { if ReCAPTCHASecret == "" { return ReCAPTCHA{}, fmt.Errorf("recaptcha secret cannot be blank") diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go deleted file mode 100644 index e54831e..0000000 --- a/pkg/request/request_test.go +++ /dev/null @@ -1,278 +0,0 @@ -package request - -import ( - "context" - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "io" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -type ClientMock struct { - testMock.Mock -} - -func (m ClientMock) Request(method, target string, body io.Reader, opts ...Option) *Response { - args := m.Called(method, target, body, opts) - return args.Get(0).(*Response) -} - -func TestWithTimeout(t *testing.T) { - asserts := assert.New(t) - options := newDefaultOption() - WithTimeout(time.Duration(5) * time.Second).apply(options) - asserts.Equal(time.Duration(5)*time.Second, options.timeout) -} - -func TestWithHeader(t *testing.T) { - asserts := assert.New(t) - options := newDefaultOption() - WithHeader(map[string][]string{"Origin": []string{"123"}}).apply(options) - asserts.Equal(http.Header{"Origin": []string{"123"}}, options.header) -} - -func TestWithContentLength(t *testing.T) { - asserts := assert.New(t) - options := newDefaultOption() - WithContentLength(10).apply(options) - asserts.EqualValues(10, options.contentLength) -} - -func TestWithContext(t *testing.T) { - asserts := assert.New(t) - options := newDefaultOption() - WithContext(context.Background()).apply(options) - asserts.NotNil(options.ctx) -} - -func TestHTTPClient_Request(t *testing.T) { - asserts := assert.New(t) - client := NewClient(WithSlaveMeta("test")) - - // 正常 - { - resp := client.Request( - "POST", - "/test", - strings.NewReader(""), - WithContentLength(0), - WithEndpoint("http://cloudreveisnotexist.com"), - WithTimeout(time.Duration(1)*time.Microsecond), - WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10), - WithoutHeader([]string{"origin", "origin"}), - ) - asserts.Error(resp.Err) - asserts.Nil(resp.Response) - } - - // 正常 带有ctx - { - resp := client.Request( - "GET", - "http://cloudreveisnotexist.com", - strings.NewReader(""), - WithTimeout(time.Duration(1)*time.Microsecond), - WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10), - WithContext(context.Background()), - WithoutHeader([]string{"s s", "s s"}), - WithMasterMeta(), - ) - asserts.Error(resp.Err) - asserts.Nil(resp.Response) - } - -} - -func TestResponse_GetResponse(t *testing.T) { - asserts := assert.New(t) - - // 直接返回错误 - { - resp := Response{ - Err: errors.New("error"), - } - content, err := resp.GetResponse() - asserts.Empty(content) - asserts.Error(err) - } - - // 正常 - { - resp := Response{ - Response: &http.Response{Body: ioutil.NopCloser(strings.NewReader("123"))}, - } - content, err := resp.GetResponse() - asserts.Equal("123", content) - asserts.NoError(err) - } -} - -func TestResponse_CheckHTTPResponse(t *testing.T) { - asserts := assert.New(t) - - // 直接返回错误 - { - resp := Response{ - Err: errors.New("error"), - } - res := resp.CheckHTTPResponse(200) - asserts.Error(res.Err) - } - - // 404错误 - { - resp := Response{ - Response: &http.Response{StatusCode: 404}, - } - res := resp.CheckHTTPResponse(200) - asserts.Error(res.Err) - } - - // 通过 - { - resp := Response{ - Response: &http.Response{StatusCode: 200}, - } - res := resp.CheckHTTPResponse(200) - asserts.NoError(res.Err) - } -} - -func TestResponse_GetRSCloser(t *testing.T) { - asserts := assert.New(t) - - // 直接返回错误 - { - resp := Response{ - Err: errors.New("error"), - } - res, err := resp.GetRSCloser() - asserts.Error(err) - asserts.Nil(res) - } - - // 正常 - { - resp := Response{ - Response: &http.Response{ContentLength: 3, Body: ioutil.NopCloser(strings.NewReader("123"))}, - } - res, err := resp.GetRSCloser() - asserts.NoError(err) - content, err := ioutil.ReadAll(res) - asserts.NoError(err) - asserts.Equal("123", string(content)) - offset, err := res.Seek(0, 0) - asserts.NoError(err) - asserts.Equal(int64(0), offset) - offset, err = res.Seek(0, 2) - asserts.NoError(err) - asserts.Equal(int64(3), offset) - _, err = res.Seek(1, 2) - asserts.Error(err) - asserts.NoError(res.Close()) - } - -} - -func TestResponse_DecodeResponse(t *testing.T) { - asserts := assert.New(t) - - // 直接返回错误 - { - resp := Response{Err: errors.New("error")} - response, err := resp.DecodeResponse() - asserts.Error(err) - asserts.Nil(response) - } - - // 无法解析响应 - { - resp := Response{ - Response: &http.Response{ - Body: ioutil.NopCloser(strings.NewReader("test")), - }, - } - response, err := resp.DecodeResponse() - asserts.Error(err) - asserts.Nil(response) - } - - // 成功 - { - resp := Response{ - Response: &http.Response{ - Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")), - }, - } - response, err := resp.DecodeResponse() - asserts.NoError(err) - asserts.NotNil(response) - asserts.Equal(0, response.Code) - } -} - -func TestNopRSCloser_SetFirstFakeChunk(t *testing.T) { - asserts := assert.New(t) - rsc := NopRSCloser{ - status: &rscStatus{}, - } - rsc.SetFirstFakeChunk() - asserts.True(rsc.status.IgnoreFirst) - - rsc.SetContentLength(20) - asserts.EqualValues(20, rsc.status.Size) -} - -func TestBlackHole(t *testing.T) { - a := assert.New(t) - cache.Set("setting_reset_after_upload_failed", "true", 0) - a.NotPanics(func() { - BlackHole(strings.NewReader("TestBlackHole")) - }) -} - -func TestHTTPClient_TPSLimit(t *testing.T) { - a := assert.New(t) - client := NewClient() - - finished := make(chan struct{}) - go func() { - client.Request( - "POST", - "/test", - strings.NewReader(""), - WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1), - ) - close(finished) - }() - select { - case <-finished: - case <-time.After(10 * time.Second): - a.Fail("Request should be finished instantly.") - } - - finished = make(chan struct{}) - go func() { - client.Request( - "POST", - "/test", - strings.NewReader(""), - WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1), - ) - close(finished) - }() - select { - case <-finished: - case <-time.After(2 * time.Second): - a.Fail("Request should be finished in 1 second.") - } - -} diff --git a/pkg/request/tpslimiter_test.go b/pkg/request/tpslimiter_test.go deleted file mode 100644 index daec236..0000000 --- a/pkg/request/tpslimiter_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package request - -import ( - "context" - "github.com/stretchr/testify/assert" - "testing" - "time" -) - -func TestLimit(t *testing.T) { - a := assert.New(t) - l := NewTPSLimiter() - finished := make(chan struct{}) - go func() { - l.Limit(context.Background(), "token", 1, 1) - close(finished) - }() - select { - case <-finished: - case <-time.After(10 * time.Second): - a.Fail("Limit should be finished instantly.") - } - - finished = make(chan struct{}) - go func() { - l.Limit(context.Background(), "token", 1, 1) - close(finished) - }() - select { - case <-finished: - case <-time.After(2 * time.Second): - a.Fail("Limit should be finished in 1 second.") - } - - finished = make(chan struct{}) - go func() { - l.Limit(context.Background(), "token", 10, 1) - close(finished) - }() - select { - case <-finished: - case <-time.After(1 * time.Second): - a.Fail("Limit should be finished instantly.") - } - -} diff --git a/pkg/serializer/aria2_test.go b/pkg/serializer/aria2_test.go deleted file mode 100644 index 1f3ca61..0000000 --- a/pkg/serializer/aria2_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package serializer - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestBuildFinishedListResponse(t *testing.T) { - asserts := assert.New(t) - tasks := []model.Download{ - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name.txt", - }, - }, - }, - Task: &model.Task{ - Model: gorm.Model{}, - Error: "error", - }, - }, - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name1.txt", - }, - { - Path: "/file/name2.txt", - }, - }, - }, - }, - } - tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" - res := BuildFinishedListResponse(tasks).Data.([]FinishedListResponse) - asserts.Len(res, 2) - asserts.Equal("name.txt", res[1].Name) - asserts.Equal("name.txt", res[0].Name) - asserts.Equal("name.txt", res[0].Files[0].Path) - asserts.Equal("name1.txt", res[1].Files[0].Path) - asserts.Equal("name2.txt", res[1].Files[1].Path) - asserts.EqualValues(0, res[0].TaskStatus) - asserts.Equal("error", res[0].TaskError) -} - -func TestBuildDownloadingResponse(t *testing.T) { - asserts := assert.New(t) - cache.Set("setting_aria2_interval", "10", 0) - tasks := []model.Download{ - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name.txt", - }, - }, - }, - Task: &model.Task{ - Model: gorm.Model{}, - Error: "error", - }, - }, - { - StatusInfo: rpc.StatusInfo{ - Files: []rpc.FileInfo{ - { - Path: "/file/name1.txt", - }, - { - Path: "/file/name2.txt", - }, - }, - }, - }, - } - tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" - tasks[1].ID = 1 - - res := BuildDownloadingResponse(tasks, map[uint]int{1: 5}).Data.([]DownloadListResponse) - asserts.Len(res, 2) - asserts.Equal("name1.txt", res[1].Name) - asserts.Equal(5, res[1].UpdateInterval) - asserts.Equal("name.txt", res[0].Name) - asserts.Equal("name.txt", res[0].Info.Files[0].Path) - asserts.Equal("name1.txt", res[1].Info.Files[0].Path) - asserts.Equal("name2.txt", res[1].Info.Files[1].Path) -} diff --git a/pkg/serializer/auth_test.go b/pkg/serializer/auth_test.go deleted file mode 100644 index 96b6b9b..0000000 --- a/pkg/serializer/auth_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package serializer - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewRequestSignString(t *testing.T) { - asserts := assert.New(t) - - sign := NewRequestSignString("1", "2", "3") - asserts.NotEmpty(sign) -} diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index 326c0d8..7ddbc59 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -78,6 +78,8 @@ const ( CodeInvalidChunkIndex = 40012 // CodeInvalidContentLength 无效的正文长度 CodeInvalidContentLength = 40013 + // CodePhoneRequired 未绑定手机 + CodePhoneRequired = 40010 // CodeBatchSourceSize 超出批量获取外链限制 CodeBatchSourceSize = 40014 // CodeBatchAria2Size 超出最大 Aria2 任务数量限制 @@ -112,6 +114,8 @@ const ( CodeInvalidTempLink = 40029 // CodeTempLinkExpired 临时链接过期 CodeTempLinkExpired = 40030 + // CodeEmailProviderBaned 邮箱后缀被禁用 + CodeEmailProviderBaned = 40031 // CodeEmailExisted 邮箱已被使用 CodeEmailExisted = 40032 // CodeEmailSent 邮箱已重新发送 @@ -192,6 +196,8 @@ const ( CodeDisabledSharePreview = 40070 // 签名无效 CodeInvalidSign = 40071 + // 管理员无法购买用户组 + CodeFulfillAdminGroup = 40072 // CodeDBError 数据库操作失败 CodeDBError = 50001 // CodeEncryptError 加密失败 diff --git a/pkg/serializer/error_test.go b/pkg/serializer/error_test.go deleted file mode 100644 index d02fd5d..0000000 --- a/pkg/serializer/error_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package serializer - -import ( - "errors" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewError(t *testing.T) { - a := assert.New(t) - err := NewError(400, "Bad Request", errors.New("error")) - a.Error(err) - a.EqualValues(400, err.Code) - - err.WithError(errors.New("error2")) - a.Equal("error2", err.RawError.Error()) - a.Equal("Bad Request", err.Error()) - - resp := &Response{ - Code: 400, - Msg: "Bad Request", - Error: "error", - } - err = NewErrorFromResponse(resp) - a.Error(err) -} - -func TestDBErr(t *testing.T) { - a := assert.New(t) - resp := DBErr("", nil) - a.NotEmpty(resp.Msg) - - resp = ParamErr("", nil) - a.NotEmpty(resp.Msg) -} - -func TestErr(t *testing.T) { - a := assert.New(t) - err := NewError(400, "Bad Request", errors.New("error")) - resp := Err(400, "", err) - a.Equal("Bad Request", resp.Msg) -} diff --git a/pkg/serializer/explorer_test.go b/pkg/serializer/explorer_test.go deleted file mode 100644 index 00c9efc..0000000 --- a/pkg/serializer/explorer_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package serializer - -import ( - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestBuildObjectList(t *testing.T) { - a := assert.New(t) - res := BuildObjectList(1, []Object{{}, {}}, &model.Policy{}) - a.NotEmpty(res.Parent) - a.NotNil(res.Policy) - a.Len(res.Objects, 2) -} diff --git a/pkg/serializer/response_test.go b/pkg/serializer/response_test.go deleted file mode 100644 index 70c8899..0000000 --- a/pkg/serializer/response_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package serializer - -import ( - "encoding/json" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestNewResponseWithGobData(t *testing.T) { - a := assert.New(t) - type args struct { - data interface{} - } - - res := NewResponseWithGobData(args{}) - a.Equal(CodeInternalSetting, res.Code) - - res = NewResponseWithGobData("TestNewResponseWithGobData") - a.Equal(0, res.Code) - a.NotEmpty(res.Data) -} - -func TestResponse_GobDecode(t *testing.T) { - a := assert.New(t) - res := NewResponseWithGobData("TestResponse_GobDecode") - jsonContent, err := json.Marshal(res) - a.NoError(err) - resDecoded := &Response{} - a.NoError(json.Unmarshal(jsonContent, resDecoded)) - var target string - resDecoded.GobDecode(&target) - a.Equal("TestResponse_GobDecode", target) -} diff --git a/pkg/serializer/setting.go b/pkg/serializer/setting.go index 7e4ce00..df6fc2c 100644 --- a/pkg/serializer/setting.go +++ b/pkg/serializer/setting.go @@ -1,8 +1,9 @@ package serializer import ( - model "github.com/cloudreve/Cloudreve/v3/models" "time" + + model "github.com/cloudreve/Cloudreve/v3/models" ) // SiteConfig 站点全局设置序列 @@ -12,18 +13,25 @@ type SiteConfig struct { RegCaptcha bool `json:"regCaptcha"` ForgetCaptcha bool `json:"forgetCaptcha"` EmailActive bool `json:"emailActive"` + QQLogin bool `json:"QQLogin"` Themes string `json:"themes"` DefaultTheme string `json:"defaultTheme"` + ScoreEnabled bool `json:"score_enabled"` + ShareScoreRate string `json:"share_score_rate"` HomepageViewMethod string `json:"home_view_method"` ShareViewMethod string `json:"share_view_method"` Authn bool `json:"authn"` User User `json:"user"` ReCaptchaKey string `json:"captcha_ReCaptchaKey"` + SiteNotice string `json:"site_notice"` CaptchaType string `json:"captcha_type"` TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"` RegisterEnabled bool `json:"registerEnabled"` + ReportEnabled bool `json:"report_enabled"` AppPromotion bool `json:"app_promotion"` WopiExts []string `json:"wopi_exts"` + AppFeedbackLink string `json:"app_feedback"` + AppForumLink string `json:"app_forum"` } type task struct { @@ -75,18 +83,31 @@ func BuildSiteConfig(settings map[string]string, user *model.User, wopiExts []st RegCaptcha: model.IsTrueVal(checkSettingValue(settings, "reg_captcha")), ForgetCaptcha: model.IsTrueVal(checkSettingValue(settings, "forget_captcha")), EmailActive: model.IsTrueVal(checkSettingValue(settings, "email_active")), + QQLogin: model.IsTrueVal(checkSettingValue(settings, "qq_login")), Themes: checkSettingValue(settings, "themes"), DefaultTheme: checkSettingValue(settings, "defaultTheme"), + ScoreEnabled: model.IsTrueVal(checkSettingValue(settings, "score_enabled")), + ShareScoreRate: checkSettingValue(settings, "share_score_rate"), HomepageViewMethod: checkSettingValue(settings, "home_view_method"), ShareViewMethod: checkSettingValue(settings, "share_view_method"), Authn: model.IsTrueVal(checkSettingValue(settings, "authn_enabled")), User: userRes, + SiteNotice: checkSettingValue(settings, "siteNotice"), ReCaptchaKey: checkSettingValue(settings, "captcha_ReCaptchaKey"), CaptchaType: checkSettingValue(settings, "captcha_type"), TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"), RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")), + ReportEnabled: model.IsTrueVal(checkSettingValue(settings, "report_enabled")), AppPromotion: model.IsTrueVal(checkSettingValue(settings, "show_app_promotion")), + AppFeedbackLink: checkSettingValue(settings, "app_feedback_link"), + AppForumLink: checkSettingValue(settings, "app_forum_link"), WopiExts: wopiExts, }} return res } + +// VolResponse VOL query response +type VolResponse struct { + Signature string `json:"signature"` + Content string `json:"content"` +} diff --git a/pkg/serializer/setting_test.go b/pkg/serializer/setting_test.go deleted file mode 100644 index 680edb6..0000000 --- a/pkg/serializer/setting_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package serializer - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestCheckSettingValue(t *testing.T) { - asserts := assert.New(t) - - asserts.Equal("", checkSettingValue(map[string]string{}, "key")) - asserts.Equal("123", checkSettingValue(map[string]string{"key": "123"}, "key")) -} - -func TestBuildSiteConfig(t *testing.T) { - asserts := assert.New(t) - - res := BuildSiteConfig(map[string]string{"not exist": ""}, &model.User{}, nil) - asserts.Equal("", res.Data.(SiteConfig).SiteName) - - res = BuildSiteConfig(map[string]string{"siteName": "123"}, &model.User{}, nil) - asserts.Equal("123", res.Data.(SiteConfig).SiteName) - - // 非空用户 - res = BuildSiteConfig(map[string]string{"qq_login": "1"}, &model.User{ - Model: gorm.Model{ - ID: 5, - }, - }, nil) - asserts.Len(res.Data.(SiteConfig).User.ID, 4) -} - -func TestBuildTaskList(t *testing.T) { - asserts := assert.New(t) - tasks := []model.Task{{}} - - res := BuildTaskList(tasks, 1) - asserts.NotNil(res) -} diff --git a/pkg/serializer/share.go b/pkg/serializer/share.go index 94da4c6..ade604d 100644 --- a/pkg/serializer/share.go +++ b/pkg/serializer/share.go @@ -12,6 +12,7 @@ type Share struct { Key string `json:"key"` Locked bool `json:"locked"` IsDir bool `json:"is_dir"` + Score int `json:"score"` CreateDate time.Time `json:"create_date,omitempty"` Downloads int `json:"downloads"` Views int `json:"views"` @@ -36,6 +37,7 @@ type shareSource struct { type myShareItem struct { Key string `json:"key"` IsDir bool `json:"is_dir"` + Score int `json:"score"` Password string `json:"password"` CreateDate time.Time `json:"create_date,omitempty"` Downloads int `json:"downloads"` @@ -54,6 +56,7 @@ func BuildShareList(shares []model.Share, total int) Response { item := myShareItem{ Key: hashid.HashID(shares[i].ID, hashid.ShareID), IsDir: shares[i].IsDir, + Score: shares[i].Score, Password: shares[i].Password, CreateDate: shares[i].CreatedAt, Downloads: shares[i].Downloads, @@ -99,6 +102,7 @@ func BuildShareResponse(share *model.Share, unlocked bool) Share { Nick: creator.Nick, GroupName: creator.Group.Name, }, + Score: share.Score, CreateDate: share.CreatedAt, } diff --git a/pkg/serializer/share_test.go b/pkg/serializer/share_test.go deleted file mode 100644 index 72feb0c..0000000 --- a/pkg/serializer/share_test.go +++ /dev/null @@ -1,85 +0,0 @@ -package serializer - -import ( - "testing" - "time" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestBuildShareList(t *testing.T) { - asserts := assert.New(t) - timeNow := time.Now() - - shares := []model.Share{ - { - Expires: &timeNow, - File: model.File{ - Model: gorm.Model{ID: 1}, - }, - }, - { - Folder: model.Folder{ - Model: gorm.Model{ID: 1}, - }, - }, - } - - res := BuildShareList(shares, 2) - asserts.Equal(0, res.Code) -} - -func TestBuildShareResponse(t *testing.T) { - asserts := assert.New(t) - - // 未解锁 - { - share := &model.Share{ - User: model.User{Model: gorm.Model{ID: 1}}, - Downloads: 1, - } - res := BuildShareResponse(share, false) - asserts.EqualValues(0, res.Downloads) - asserts.True(res.Locked) - asserts.NotNil(res.Creator) - } - - // 已解锁,非目录 - { - expires := time.Now().Add(time.Duration(10) * time.Second) - share := &model.Share{ - User: model.User{Model: gorm.Model{ID: 1}}, - Downloads: 1, - Expires: &expires, - File: model.File{ - Model: gorm.Model{ID: 1}, - }, - } - res := BuildShareResponse(share, true) - asserts.EqualValues(1, res.Downloads) - asserts.False(res.Locked) - asserts.NotEmpty(res.Expire) - asserts.NotNil(res.Creator) - } - - // 已解锁,是目录 - { - expires := time.Now().Add(time.Duration(10) * time.Second) - share := &model.Share{ - User: model.User{Model: gorm.Model{ID: 1}}, - Downloads: 1, - Expires: &expires, - Folder: model.Folder{ - Model: gorm.Model{ID: 1}, - }, - IsDir: true, - } - res := BuildShareResponse(share, true) - asserts.EqualValues(1, res.Downloads) - asserts.False(res.Locked) - asserts.NotEmpty(res.Expire) - asserts.NotNil(res.Creator) - } -} diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go deleted file mode 100644 index 46b5d2d..0000000 --- a/pkg/serializer/slave_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package serializer - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" -) - -func TestSlaveTransferReq_Hash(t *testing.T) { - a := assert.New(t) - s1 := &SlaveTransferReq{ - Src: "1", - Policy: &model.Policy{}, - } - s2 := &SlaveTransferReq{ - Src: "2", - Policy: &model.Policy{}, - } - a.NotEqual(s1.Hash("1"), s2.Hash("1")) -} diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go index 142f424..e5f67d8 100644 --- a/pkg/serializer/user.go +++ b/pkg/serializer/user.go @@ -2,11 +2,11 @@ package serializer import ( "fmt" + "time" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/duo-labs/webauthn/webauthn" - "time" ) // CheckLogin 检查登录 @@ -17,6 +17,14 @@ func CheckLogin() Response { } } +// PhoneRequired 需要绑定手机 +func PhoneRequired() Response { + return Response{ + Code: CodePhoneRequired, + Msg: "此功能需要绑定手机后使用", + } +} + // User 用户序列化器 type User struct { ID string `json:"id"` @@ -26,6 +34,7 @@ type User struct { Avatar string `json:"avatar"` CreatedAt time.Time `json:"created_at"` PreferredTheme string `json:"preferred_theme"` + Score int `json:"score"` Anonymous bool `json:"anonymous"` Group group `json:"group"` Tags []tag `json:"tags"` @@ -37,10 +46,13 @@ type group struct { AllowShare bool `json:"allowShare"` AllowRemoteDownload bool `json:"allowRemoteDownload"` AllowArchiveDownload bool `json:"allowArchiveDownload"` + ShareFreeEnabled bool `json:"shareFree"` ShareDownload bool `json:"shareDownload"` CompressEnabled bool `json:"compress"` WebDAVEnabled bool `json:"webdav"` + RelocateEnabled bool `json:"relocate"` SourceBatchSize int `json:"sourceBatch"` + SelectNode bool `json:"selectNode"` AdvanceDelete bool `json:"advanceDelete"` AllowWebDAVProxy bool `json:"allowWebDAVProxy"` } @@ -91,6 +103,7 @@ func BuildUser(user model.User) User { Avatar: user.Avatar, CreatedAt: user.CreatedAt, PreferredTheme: user.OptionsSerialized.PreferredTheme, + Score: user.Score, Anonymous: user.IsAnonymous(), Group: group{ ID: user.GroupID, @@ -98,11 +111,14 @@ func BuildUser(user model.User) User { AllowShare: user.Group.ShareEnabled, AllowRemoteDownload: user.Group.OptionsSerialized.Aria2, AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownload, + ShareFreeEnabled: user.Group.OptionsSerialized.ShareFree, ShareDownload: user.Group.OptionsSerialized.ShareDownload, CompressEnabled: user.Group.OptionsSerialized.ArchiveTask, WebDAVEnabled: user.Group.WebDAVEnabled, AllowWebDAVProxy: user.Group.OptionsSerialized.WebDAVProxy, + RelocateEnabled: user.Group.OptionsSerialized.Relocate, SourceBatchSize: user.Group.OptionsSerialized.SourceBatchSize, + SelectNode: user.Group.OptionsSerialized.SelectNode, AdvanceDelete: user.Group.OptionsSerialized.AdvanceDelete, }, Tags: buildTagRes(tags), @@ -118,7 +134,7 @@ func BuildUserResponse(user model.User) Response { // BuildUserStorageResponse 序列化用户存储概况响应 func BuildUserStorageResponse(user model.User) Response { - total := user.Group.MaxStorage + total := user.Group.MaxStorage + user.GetAvailablePackSize() storageResp := storage{ Used: user.Storage, Free: total - user.Storage, diff --git a/pkg/serializer/user_test.go b/pkg/serializer/user_test.go deleted file mode 100644 index 2942186..0000000 --- a/pkg/serializer/user_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package serializer - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/duo-labs/webauthn/webauthn" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestBuildUser(t *testing.T) { - asserts := assert.New(t) - user := model.User{} - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) - res := BuildUser(user) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(res) - -} - -func TestBuildUserResponse(t *testing.T) { - asserts := assert.New(t) - user := model.User{} - res := BuildUserResponse(user) - asserts.NotNil(res) -} - -func TestBuildUserStorageResponse(t *testing.T) { - asserts := assert.New(t) - cache.Set("pack_size_0", uint64(0), 0) - - { - user := model.User{ - Storage: 0, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(0), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(10), res.Data.(storage).Free) - } - { - user := model.User{ - Storage: 6, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(6), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(4), res.Data.(storage).Free) - } - { - user := model.User{ - Storage: 20, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(20), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(0), res.Data.(storage).Free) - } - { - user := model.User{ - Storage: 6, - Group: model.Group{MaxStorage: 10}, - } - res := BuildUserStorageResponse(user) - asserts.Equal(uint64(6), res.Data.(storage).Used) - asserts.Equal(uint64(10), res.Data.(storage).Total) - asserts.Equal(uint64(4), res.Data.(storage).Free) - } -} - -func TestBuildTagRes(t *testing.T) { - asserts := assert.New(t) - tags := []model.Tag{ - { - Type: 0, - Expression: "exp", - }, - { - Type: 1, - Expression: "exp", - }, - } - res := buildTagRes(tags) - asserts.Len(res, 2) - asserts.Equal("", res[0].Expression) - asserts.Equal("exp", res[1].Expression) -} - -func TestBuildWebAuthnList(t *testing.T) { - asserts := assert.New(t) - credentials := []webauthn.Credential{{}} - res := BuildWebAuthnList(credentials) - asserts.Len(res, 1) -} diff --git a/pkg/serializer/vas.go b/pkg/serializer/vas.go new file mode 100755 index 0000000..2da604a --- /dev/null +++ b/pkg/serializer/vas.go @@ -0,0 +1,158 @@ +package serializer + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/hashid" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "time" +) + +type quota struct { + Base uint64 `json:"base"` + Pack uint64 `json:"pack"` + Used uint64 `json:"used"` + Total uint64 `json:"total"` + Packs []storagePacks `json:"packs"` +} + +type storagePacks struct { + Name string `json:"name"` + Size uint64 `json:"size"` + ActivateDate time.Time `json:"activate_date"` + Expiration int `json:"expiration"` + ExpirationDate time.Time `json:"expiration_date"` +} + +// MountedFolders 已挂载的目录 +type MountedFolders struct { + ID string `json:"id"` + Name string `json:"name"` + PolicyName string `json:"policy_name"` +} + +type policyOptions struct { + Name string `json:"name"` + ID string `json:"id"` +} + +type nodeOptions struct { + Name string `json:"name"` + ID uint `json:"id"` +} + +// BuildPolicySettingRes 构建存储策略选项选择 +func BuildPolicySettingRes(policies []model.Policy) Response { + options := make([]policyOptions, 0, len(policies)) + for _, policy := range policies { + options = append(options, policyOptions{ + Name: policy.Name, + ID: hashid.HashID(policy.ID, hashid.PolicyID), + }) + } + + return Response{ + Data: options, + } +} + +// BuildMountedFolderRes 构建已挂载目录响应,list为当前用户可用存储策略ID +func BuildMountedFolderRes(folders []model.Folder, list []uint) []MountedFolders { + res := make([]MountedFolders, 0, len(folders)) + for _, folder := range folders { + single := MountedFolders{ + ID: hashid.HashID(folder.ID, hashid.FolderID), + Name: folder.Name, + PolicyName: "[Invalid Policy]", + } + if policy, err := model.GetPolicyByID(folder.PolicyID); err == nil && util.ContainsUint(list, policy.ID) { + single.PolicyName = policy.Name + } + + res = append(res, single) + } + + return res +} + +// BuildUserQuotaResponse 序列化用户存储配额概况响应 +func BuildUserQuotaResponse(user *model.User, packs []model.StoragePack) Response { + packSize := user.GetAvailablePackSize() + res := quota{ + Base: user.Group.MaxStorage, + Pack: packSize, + Used: user.Storage, + Total: packSize + user.Group.MaxStorage, + Packs: make([]storagePacks, 0, len(packs)), + } + for _, pack := range packs { + res.Packs = append(res.Packs, storagePacks{ + Name: pack.Name, + Size: pack.Size, + ActivateDate: *pack.ActiveTime, + Expiration: int(pack.ExpiredTime.Sub(*pack.ActiveTime).Seconds()), + ExpirationDate: *pack.ExpiredTime, + }) + } + + return Response{ + Data: res, + } +} + +// PackProduct 容量包商品 +type PackProduct struct { + ID int64 `json:"id"` + Name string `json:"name"` + Size uint64 `json:"size"` + Time int64 `json:"time"` + Price int `json:"price"` + Score int `json:"score"` +} + +// GroupProducts 用户组商品 +type GroupProducts struct { + ID int64 `json:"id"` + Name string `json:"name"` + GroupID uint `json:"group_id"` + Time int64 `json:"time"` + Price int `json:"price"` + Score int `json:"score"` + Des []string `json:"des"` + Highlight bool `json:"highlight"` +} + +// BuildProductResponse 构建增值服务商品响应 +func BuildProductResponse(groups []GroupProducts, packs []PackProduct, + wechat, alipay, payjs, custom bool, customName string, scorePrice int) Response { + // 隐藏响应中的用户组ID + for i := 0; i < len(groups); i++ { + groups[i].GroupID = 0 + } + return Response{ + Data: map[string]interface{}{ + "packs": packs, + "groups": groups, + "alipay": alipay, + "wechat": wechat, + "payjs": payjs, + "custom": custom, + "custom_name": customName, + "score_price": scorePrice, + }, + } +} + +// BuildNodeOptionRes 构建可用节点列表响应 +func BuildNodeOptionRes(nodes []*model.Node) Response { + options := make([]nodeOptions, 0, len(nodes)) + for _, node := range nodes { + options = append(options, nodeOptions{ + Name: node.Name, + ID: node.ID, + }) + } + + return Response{ + Data: options, + } +} diff --git a/pkg/task/compress.go b/pkg/task/compress.go index 5e20a36..7f5025f 100644 --- a/pkg/task/compress.go +++ b/pkg/task/compress.go @@ -122,7 +122,7 @@ func (job *CompressTask) Do() { job.zipPath = zipFilePath zipFile.Close() - util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFilePath) + util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFile) job.TaskModel.SetProgress(TransferringProgress) // 上传文件 @@ -155,7 +155,7 @@ func NewCompressTask(user *model.User, dst string, dirs, files []uint) (Job, err return newTask, nil } -// NewCompressTaskFromModel 从数据库记录中恢复压缩任务 +// NewRelocateTaskFromModel 从数据库记录中恢复迁移任务 func NewCompressTaskFromModel(task *model.Task) (Job, error) { user, err := model.GetActiveUserByID(task.UserID) if err != nil { diff --git a/pkg/task/compress_test.go b/pkg/task/compress_test.go deleted file mode 100644 index 34b282d..0000000 --- a/pkg/task/compress_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestCompressTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(CompressTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestCompressTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestCompressTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - zipPath: "test/TestCompressTask_SetError", - } - zipFile, _ := util.CreatNestedFile("test/TestCompressTask_SetError") - zipFile.Close() - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.False(util.Exists("test/TestCompressTask_SetError")) - asserts.Equal("error", task.GetError().Msg) -} - -func TestCompressTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &CompressTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 无法创建文件系统 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 压缩出错 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Dirs = []uint{1} - // 更新进度 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - // 查找目录 - mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error")) - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 上传出错 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - MaxSize: 1, - }, - } - task.TaskProps.Dirs = []uint{1} - cache.Set("setting_temp_path", "test", 0) - // 更新进度 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - // 查找目录 - mock.ExpectQuery("SELECT(.+)folders"). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - // 查找文件 - mock.ExpectQuery("SELECT(.+)files"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 查找子文件 - mock.ExpectQuery("SELECT(.+)files"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 查找子目录 - mock.ExpectQuery("SELECT(.+)folders"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - asserts.True(util.IsEmpty(util.RelativePath("test/compress"))) - } -} - -func TestNewCompressTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewCompressTask(&model.User{}, "/", []uint{12}, []uint{}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewCompressTask(&model.User{}, "/", []uint{12}, []uint{}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewCompressTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewCompressTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewCompressTaskFromModel(&model.Task{Props: ""}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/decompress_test.go b/pkg/task/decompress_test.go deleted file mode 100644 index 75b7cfe..0000000 --- a/pkg/task/decompress_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestDecompressTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(DecompressTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestDecompressTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestDecompressTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - zipPath: "test/TestCompressTask_SetError", - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestDecompressTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &DecompressTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 无法创建文件系统 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 压缩文件不存在 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Src = "test" - task.Do() - asserts.NotEmpty(task.GetError().Msg) - } -} - -func TestNewDecompressTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewDecompressTask(&model.User{}, "/", "/", "utf-8") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewDecompressTask(&model.User{}, "/", "/", "utf-8") - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewDecompressTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewDecompressTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewDecompressTaskFromModel(&model.Task{Props: ""}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/import.go b/pkg/task/import.go index 607b4d1..2b5134b 100644 --- a/pkg/task/import.go +++ b/pkg/task/import.go @@ -86,7 +86,6 @@ func (job *ImportTask) Do() { } // 创建文件系统 - job.User.Policy = policy fs, err := filesystem.NewFileSystem(job.User) if err != nil { job.SetErrorMsg(err.Error(), nil) diff --git a/pkg/task/import_test.go b/pkg/task/import_test.go deleted file mode 100644 index a17123d..0000000 --- a/pkg/task/import_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/util" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestImportTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(ImportTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestImportTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestImportTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestImportTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &ImportTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - TaskProps: ImportProps{ - PolicyID: 63, - Src: "", - Recursive: false, - Dst: "", - }, - } - - // 存储策略不存在 - { - cache.Deletes([]string{"63"}, "policy_") - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 设定失败状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.Err.Error) - task.Err = nil - } - - // 无法分配 Filesystem - { - cache.Deletes([]string{"63"}, "policy_") - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "unknown")) - // 设定失败状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.Err.Msg) - task.Err = nil - } - - // 成功列取,但是文件为空 - { - cache.Deletes([]string{"63"}, "policy_") - task.TaskProps.Src = "TestImportTask_Do/empty" - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local")) - // 设定listing状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 设定inserting状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(task.Err) - task.Err = nil - } - - // 创建测试文件 - f, _ := util.CreatNestedFile(util.RelativePath("tests/TestImportTask_Do/test.txt")) - f.Close() - - // 成功列取,包含一个文件一个目录,父目录创建失败 - { - cache.Deletes([]string{"63"}, "policy_") - task.TaskProps.Src = "tests" - task.TaskProps.Dst = "/" - task.TaskProps.Recursive = true - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local")) - // 设定listing状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 设定inserting状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 查找父目录,但是不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 仍然不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 创建文件时查找父目录,仍然不存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - - task.Do() - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(task.Err) - task.Err = nil - } - - // 成功列取,包含一个文件一个目录, 全部操作成功 - { - cache.Deletes([]string{"63"}, "policy_") - task.TaskProps.Src = "tests" - task.TaskProps.Dst = "/" - task.TaskProps.Recursive = true - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local")) - // 设定listing状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 设定inserting状态 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - // 查找父目录,存在 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - // 查找同名文件,不存在 - mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id"})) - // 创建目录 - mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"})) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)folders(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - // 插入文件记录 - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectExec("UPDATE(.+)users(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) - mock.ExpectCommit() - - task.Do() - - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(task.Err) - task.Err = nil - } -} - -func TestNewImportTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewImportTask(1, 1, "/", "/", false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewImportTask(1, 1, "/", "/", false) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewImportTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewImportTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewImportTaskFromModel(&model.Task{Props: "?"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/job.go b/pkg/task/job.go index d480492..ad77c6b 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -15,6 +15,8 @@ const ( TransferTaskType // ImportTaskType 导入任务 ImportTaskType + // RelocateTaskType 存储策略迁移任务 + RelocateTaskType // RecycleTaskType 回收任务 RecycleTaskType ) @@ -115,6 +117,8 @@ func GetJobFromModel(task *model.Task) (Job, error) { return NewTransferTaskFromModel(task) case ImportTaskType: return NewImportTaskFromModel(task) + case RelocateTaskType: + return NewRelocateTaskFromModel(task) case RecycleTaskType: return NewRecycleTaskFromModel(task) default: diff --git a/pkg/task/job_test.go b/pkg/task/job_test.go deleted file mode 100644 index 737f5b7..0000000 --- a/pkg/task/job_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" -) - -func TestRecord(t *testing.T) { - asserts := assert.New(t) - job := &TransferTask{ - User: &model.User{Policy: model.Policy{Type: "unknown"}}, - } - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - _, err := Record(job) - asserts.NoError(err) -} - -type taskPoolMock struct { - testMock.Mock -} - -func (t taskPoolMock) Add(num int) { - t.Called(num) -} - -func (t taskPoolMock) Submit(job Job) { - t.Called(job) -} - -func TestResume(t *testing.T) { - asserts := assert.New(t) - mockPool := taskPoolMock{} - - // 没有任务 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"})) - Resume(mockPool) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 有任务, 类型未知 - { - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(233)) - Resume(mockPool) - asserts.NoError(mock.ExpectationsWereMet()) - } - - // 有任务 - { - mockPool.On("Submit", testMock.Anything) - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type", "props"}).AddRow(CompressTaskType, "{}")) - mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - Resume(mockPool) - asserts.NoError(mock.ExpectationsWereMet()) - mockPool.AssertExpectations(t) - } -} - -func TestGetJobFromModel(t *testing.T) { - asserts := assert.New(t) - - // CompressTaskType - { - task := &model.Task{ - Status: 0, - Type: CompressTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } - // DecompressTaskType - { - task := &model.Task{ - Status: 0, - Type: DecompressTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } - // TransferTaskType - { - task := &model.Task{ - Status: 0, - Type: TransferTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } - // RecycleTaskType - { - task := &model.Task{ - Status: 0, - Type: RecycleTaskType, - } - mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) - job, err := GetJobFromModel(task) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} diff --git a/pkg/task/pool_test.go b/pkg/task/pool_test.go deleted file mode 100644 index fbe4134..0000000 --- a/pkg/task/pool_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package task - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestInit(t *testing.T) { - asserts := assert.New(t) - cache.Set("setting_max_worker_num", "10", 0) - mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(-1)) - Init() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Len(TaskPoll.(*AsyncPool).idleWorker, 10) -} - -func TestPool_Submit(t *testing.T) { - asserts := assert.New(t) - pool := &AsyncPool{ - idleWorker: make(chan int, 1), - } - pool.Add(1) - job := &MockJob{ - DoFunc: func() { - - }, - } - asserts.NotPanics(func() { - pool.Submit(job) - }) -} diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go deleted file mode 100644 index 0092a30..0000000 --- a/pkg/task/recycle_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestRecycleTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(RecycleTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestRecycleTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestRecycleTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &RecycleTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestNewRecycleTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewRecycleTask(&model.Download{ - Model: gorm.Model{ID: 1}, - GID: "test_g_id", - Parent: "/", - UserID: 1, - NodeID: 1, - }) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewRecycleTask(&model.Download{ - Model: gorm.Model{ID: 1}, - GID: "test_g_id", - Parent: "test/not_exist", - UserID: 1, - NodeID: 1, - }) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewRecycleTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/relocate.go b/pkg/task/relocate.go new file mode 100755 index 0000000..65666ed --- /dev/null +++ b/pkg/task/relocate.go @@ -0,0 +1,176 @@ +package task + +import ( + "context" + "encoding/json" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/util" +) + +// RelocateTask 存储策略迁移任务 +type RelocateTask struct { + User *model.User + TaskModel *model.Task + TaskProps RelocateProps + Err *JobError +} + +// RelocateProps 存储策略迁移任务属性 +type RelocateProps struct { + Dirs []uint `json:"dirs"` + Files []uint `json:"files"` + DstPolicyID uint `json:"dst_policy_id"` +} + +// Props 获取任务属性 +func (job *RelocateTask) Props() string { + res, _ := json.Marshal(job.TaskProps) + return string(res) +} + +// Type 获取任务状态 +func (job *RelocateTask) Type() int { + return RelocateTaskType +} + +// Creator 获取创建者ID +func (job *RelocateTask) Creator() uint { + return job.User.ID +} + +// Model 获取任务的数据库模型 +func (job *RelocateTask) Model() *model.Task { + return job.TaskModel +} + +// SetStatus 设定状态 +func (job *RelocateTask) SetStatus(status int) { + job.TaskModel.SetStatus(status) +} + +// SetError 设定任务失败信息 +func (job *RelocateTask) SetError(err *JobError) { + job.Err = err + res, _ := json.Marshal(job.Err) + job.TaskModel.SetError(string(res)) +} + +// SetErrorMsg 设定任务失败信息 +func (job *RelocateTask) SetErrorMsg(msg string) { + job.SetError(&JobError{Msg: msg}) +} + +// GetError 返回任务失败信息 +func (job *RelocateTask) GetError() *JobError { + return job.Err +} + +// Do 开始执行任务 +func (job *RelocateTask) Do() { + // 创建文件系统 + fs, err := filesystem.NewFileSystem(job.User) + if err != nil { + job.SetErrorMsg(err.Error()) + return + } + + job.TaskModel.SetProgress(ListingProgress) + util.Log().Debug("Start migration task.") + + // ---------------------------- + // 索引出所有待迁移的文件 + // ---------------------------- + targetFiles := make([]model.File, 0, len(job.TaskProps.Files)) + + // 索引用户选择的单独的文件 + outerFiles, err := model.GetFilesByIDs(job.TaskProps.Files, job.User.ID) + if err != nil { + job.SetError(&JobError{ + Msg: "Failed to index files.", + Error: err.Error(), + }) + return + } + targetFiles = append(targetFiles, outerFiles...) + + // 索引用户选择目录下的所有递归子文件 + subFolders, err := model.GetRecursiveChildFolder(job.TaskProps.Dirs, job.User.ID, true) + if err != nil { + job.SetError(&JobError{ + Msg: "Failed to index child folders.", + Error: err.Error(), + }) + return + } + + subFiles, err := model.GetChildFilesOfFolders(&subFolders) + if err != nil { + job.SetError(&JobError{ + Msg: "Failed to index child files.", + Error: err.Error(), + }) + return + } + targetFiles = append(targetFiles, subFiles...) + + // 查找目标存储策略 + policy, err := model.GetPolicyByID(job.TaskProps.DstPolicyID) + if err != nil { + job.SetError(&JobError{ + Msg: "Invalid policy.", + Error: err.Error(), + }) + return + } + + // 开始转移文件 + job.TaskModel.SetProgress(TransferringProgress) + ctx := context.Background() + err = fs.Relocate(ctx, targetFiles, &policy) + if err != nil { + job.SetErrorMsg(err.Error()) + return + } + + return +} + +// NewRelocateTask 新建转移任务 +func NewRelocateTask(user *model.User, dstPolicyID uint, dirs, files []uint) (Job, error) { + newTask := &RelocateTask{ + User: user, + TaskProps: RelocateProps{ + Dirs: dirs, + Files: files, + DstPolicyID: dstPolicyID, + }, + } + + record, err := Record(newTask) + if err != nil { + return nil, err + } + newTask.TaskModel = record + + return newTask, nil +} + +// NewCompressTaskFromModel 从数据库记录中恢复压缩任务 +func NewRelocateTaskFromModel(task *model.Task) (Job, error) { + user, err := model.GetActiveUserByID(task.UserID) + if err != nil { + return nil, err + } + newTask := &RelocateTask{ + User: &user, + TaskModel: task, + } + + err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) + if err != nil { + return nil, err + } + + return newTask, nil +} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index 54bba47..135c809 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -93,6 +93,7 @@ func (job *TransferTask) Do() { job.SetErrorMsg(err.Error(), nil) return } + defer fs.Recycle() successCount := 0 errorList := make([]string, 0, len(job.TaskProps.Src)) @@ -115,6 +116,7 @@ func (job *TransferTask) Do() { } // 切换为从机节点处理上传 + fs.SetPolicyFromPath(path.Dir(dst)) fs.SwitchToSlaveHandler(node) err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{ File: nil, diff --git a/pkg/task/transfer_test.go b/pkg/task/transfer_test.go deleted file mode 100644 index 612a453..0000000 --- a/pkg/task/transfer_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package task - -import ( - "errors" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestTransferTask_Props(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - User: &model.User{}, - } - asserts.NotEmpty(task.Props()) - asserts.Equal(TransferTaskType, task.Type()) - asserts.EqualValues(0, task.Creator()) - asserts.Nil(task.Model()) -} - -func TestTransferTask_SetStatus(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - task.SetStatus(3) - asserts.NoError(mock.ExpectationsWereMet()) -} - -func TestTransferTask_SetError(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - User: &model.User{}, - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - - task.SetErrorMsg("error", nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Equal("error", task.GetError().Msg) -} - -func TestTransferTask_Do(t *testing.T) { - asserts := assert.New(t) - task := &TransferTask{ - TaskModel: &model.Task{ - Model: gorm.Model{ID: 1}, - }, - } - - // 无法创建文件系统 - { - task.TaskProps.Parent = "test/not_exist" - task.User = &model.User{ - Policy: model.Policy{ - Type: "unknown", - }, - } - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 上传出错 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Src = []string{"test/not_exist"} - task.TaskProps.Parent = "test/not_exist" - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } - - // 替换目录前缀 - { - task.User = &model.User{ - Policy: model.Policy{ - Type: "mock", - }, - } - task.TaskProps.Src = []string{"test/not_exist"} - task.TaskProps.Parent = "test/not_exist" - task.TaskProps.TrimPath = true - // 更新错误 - mock.ExpectBegin() - mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, - 1)) - mock.ExpectCommit() - task.Do() - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotEmpty(task.GetError().Msg) - } -} - -func TestNewTransferTask(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectCommit() - job, err := NewTransferTask(1, []string{}, "/", "/", false, 0, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NotNil(job) - asserts.NoError(err) - } - - // 失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - mock.ExpectBegin() - mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) - mock.ExpectRollback() - job, err := NewTransferTask(1, []string{}, "/", "/", false, 0, nil) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Nil(job) - asserts.Error(err) - } -} - -func TestNewTransferTaskFromModel(t *testing.T) { - asserts := assert.New(t) - - // 成功 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewTransferTaskFromModel(&model.Task{Props: "{}"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.NoError(err) - asserts.NotNil(job) - } - - // JSON解析失败 - { - mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) - job, err := NewTransferTaskFromModel(&model.Task{Props: "?"}) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(job) - } -} diff --git a/pkg/task/worker_test.go b/pkg/task/worker_test.go deleted file mode 100644 index 64c6551..0000000 --- a/pkg/task/worker_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package task - -import ( - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/stretchr/testify/assert" -) - -type MockJob struct { - Err *JobError - Status int - DoFunc func() -} - -func (job *MockJob) Type() int { - panic("implement me") -} - -func (job *MockJob) Creator() uint { - panic("implement me") -} - -func (job *MockJob) Props() string { - panic("implement me") -} - -func (job *MockJob) Model() *model.Task { - panic("implement me") -} - -func (job *MockJob) SetStatus(status int) { - job.Status = status -} - -func (job *MockJob) Do() { - job.DoFunc() -} - -func (job *MockJob) SetError(*JobError) { -} - -func (job *MockJob) GetError() *JobError { - return job.Err -} - -func TestGeneralWorker_Do(t *testing.T) { - asserts := assert.New(t) - worker := &GeneralWorker{} - job := &MockJob{} - - // 正常 - { - job.DoFunc = func() { - } - worker.Do(job) - asserts.Equal(Complete, job.Status) - } - - // 有错误 - { - job.DoFunc = func() { - } - job.Status = Queued - job.Err = &JobError{Msg: "error"} - worker.Do(job) - asserts.Equal(Error, job.Status) - } - - // 有致命错误 - { - job.DoFunc = func() { - panic("mock fatal error") - } - job.Status = Queued - job.Err = nil - worker.Do(job) - asserts.Equal(Error, job.Status) - } - -} diff --git a/pkg/thumb/builtin.go b/pkg/thumb/builtin.go index 206d046..fadac6e 100644 --- a/pkg/thumb/builtin.go +++ b/pkg/thumb/builtin.go @@ -14,6 +14,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gofrs/uuid" + //"github.com/nfnt/resize" "golang.org/x/image/draw" ) diff --git a/pkg/util/common_test.go b/pkg/util/common_test.go deleted file mode 100644 index ae4c47f..0000000 --- a/pkg/util/common_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package util - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestRandStringRunes(t *testing.T) { - asserts := assert.New(t) - - // 0 长度字符 - randStr := RandStringRunes(0) - asserts.Len(randStr, 0) - - // 16 长度字符 - randStr = RandStringRunes(16) - asserts.Len(randStr, 16) - - // 32 长度字符 - randStr = RandStringRunes(32) - asserts.Len(randStr, 32) - - //相同长度字符 - sameLenStr1 := RandStringRunes(32) - sameLenStr2 := RandStringRunes(32) - asserts.NotEqual(sameLenStr1, sameLenStr2) -} - -func TestContainsUint(t *testing.T) { - asserts := assert.New(t) - asserts.True(ContainsUint([]uint{0, 2, 3, 65, 4}, 65)) - asserts.True(ContainsUint([]uint{65}, 65)) - asserts.False(ContainsUint([]uint{65}, 6)) -} - -func TestContainsString(t *testing.T) { - asserts := assert.New(t) - asserts.True(ContainsString([]string{"", "1"}, "")) - asserts.True(ContainsString([]string{"", "1"}, "1")) - asserts.False(ContainsString([]string{"", "1"}, " ")) -} - -func TestReplace(t *testing.T) { - asserts := assert.New(t) - - asserts.Equal("origin", Replace(map[string]string{ - "123": "321", - }, "origin")) - - asserts.Equal("321origin321", Replace(map[string]string{ - "123": "321", - }, "123origin123")) - asserts.Equal("321new321", Replace(map[string]string{ - "123": "321", - "origin": "new", - }, "123origin123")) -} - -func TestBuildRegexp(t *testing.T) { - asserts := assert.New(t) - - asserts.Equal("^/dir/", BuildRegexp([]string{"/dir"}, "^", "/", "|")) - asserts.Equal("^/dir/|^/dir/di\\*r/", BuildRegexp([]string{"/dir", "/dir/di*r"}, "^", "/", "|")) -} - -func TestBuildConcat(t *testing.T) { - asserts := assert.New(t) - asserts.Equal("CONCAT(1,2)", BuildConcat("1", "2", "mysql")) - asserts.Equal("1||2", BuildConcat("1", "2", "sqlite")) -} - -func TestSliceDifference(t *testing.T) { - asserts := assert.New(t) - - { - s1 := []string{"1", "2", "3", "4"} - s2 := []string{"2", "4"} - asserts.Equal([]string{"1", "3"}, SliceDifference(s1, s2)) - } - - { - s2 := []string{"1", "2", "3", "4"} - s1 := []string{"2", "4"} - asserts.Equal([]string{}, SliceDifference(s1, s2)) - } - - { - s1 := []string{"1", "2", "3", "4"} - s2 := []string{"1", "2", "3", "4"} - asserts.Equal([]string{}, SliceDifference(s1, s2)) - } - - { - s1 := []string{"1", "2", "3", "4"} - s2 := []string{} - asserts.Equal([]string{"1", "2", "3", "4"}, SliceDifference(s1, s2)) - } -} diff --git a/pkg/util/io_test.go b/pkg/util/io_test.go deleted file mode 100644 index 755d203..0000000 --- a/pkg/util/io_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package util - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestExists(t *testing.T) { - asserts := assert.New(t) - asserts.True(Exists("io_test.go")) - asserts.False(Exists("io_test.js")) -} - -func TestCreatNestedFile(t *testing.T) { - asserts := assert.New(t) - - // 父目录不存在 - { - file, err := CreatNestedFile("test/nest.txt") - asserts.NoError(err) - asserts.NoError(file.Close()) - asserts.FileExists("test/nest.txt") - } - - // 父目录存在 - { - file, err := CreatNestedFile("test/direct.txt") - asserts.NoError(err) - asserts.NoError(file.Close()) - asserts.FileExists("test/direct.txt") - } -} - -func TestIsEmpty(t *testing.T) { - asserts := assert.New(t) - - asserts.False(IsEmpty("")) - asserts.False(IsEmpty("not_exist")) -} diff --git a/pkg/util/logger_test.go b/pkg/util/logger_test.go deleted file mode 100644 index 5f1352d..0000000 --- a/pkg/util/logger_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// +build !race - -package util - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestBuildLogger(t *testing.T) { - asserts := assert.New(t) - asserts.NotPanics(func() { - BuildLogger("error") - }) - asserts.NotPanics(func() { - BuildLogger("warning") - }) - asserts.NotPanics(func() { - BuildLogger("info") - }) - asserts.NotPanics(func() { - BuildLogger("?") - }) - asserts.NotPanics(func() { - BuildLogger("debug") - }) -} - -func TestLog(t *testing.T) { - asserts := assert.New(t) - asserts.NotNil(Log()) - GloablLogger = nil - asserts.NotNil(Log()) -} - -func TestLogger_Debug(t *testing.T) { - asserts := assert.New(t) - l := Logger{ - level: LevelDebug, - } - asserts.NotPanics(func() { - l.Debug("123") - }) - l.level = LevelError - asserts.NotPanics(func() { - l.Debug("123") - }) -} - -func TestLogger_Info(t *testing.T) { - asserts := assert.New(t) - l := Logger{ - level: LevelDebug, - } - asserts.NotPanics(func() { - l.Info("123") - }) - l.level = LevelError - asserts.NotPanics(func() { - l.Info("123") - }) -} -func TestLogger_Warning(t *testing.T) { - asserts := assert.New(t) - l := Logger{ - level: LevelDebug, - } - asserts.NotPanics(func() { - l.Warning("123") - }) - l.level = LevelError - asserts.NotPanics(func() { - l.Warning("123") - }) -} - -func TestLogger_Error(t *testing.T) { - asserts := assert.New(t) - l := Logger{ - level: LevelDebug, - } - asserts.NotPanics(func() { - l.Error("123") - }) - l.level = -1 - asserts.NotPanics(func() { - l.Error("123") - }) -} - -func TestLogger_Panic(t *testing.T) { - asserts := assert.New(t) - l := Logger{ - level: LevelDebug, - } - asserts.Panics(func() { - l.Panic("123") - }) - l.level = -1 - asserts.NotPanics(func() { - l.Error("123") - }) -} diff --git a/pkg/util/path.go b/pkg/util/path.go index 2dd8aef..ff51d57 100644 --- a/pkg/util/path.go +++ b/pkg/util/path.go @@ -56,4 +56,3 @@ func RelativePath(name string) string { e, _ := os.Executable() return filepath.Join(filepath.Dir(e), name) } - diff --git a/pkg/util/path_test.go b/pkg/util/path_test.go deleted file mode 100644 index a417a9c..0000000 --- a/pkg/util/path_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package util - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestDotPathToStandardPath(t *testing.T) { - asserts := assert.New(t) - - asserts.Equal("/", DotPathToStandardPath("")) - asserts.Equal("/目录", DotPathToStandardPath("目录")) - asserts.Equal("/目录/目录2", DotPathToStandardPath("目录,目录2")) -} - -func TestFillSlash(t *testing.T) { - asserts := assert.New(t) - asserts.Equal("/", FillSlash("/")) - asserts.Equal("/", FillSlash("")) - asserts.Equal("/123/", FillSlash("/123")) -} - -func TestRemoveSlash(t *testing.T) { - asserts := assert.New(t) - asserts.Equal("/", RemoveSlash("/")) - asserts.Equal("/123/1236", RemoveSlash("/123/1236")) - asserts.Equal("/123/1236", RemoveSlash("/123/1236/")) -} - -func TestSplitPath(t *testing.T) { - asserts := assert.New(t) - asserts.Equal([]string{}, SplitPath("")) - asserts.Equal([]string{}, SplitPath("1")) - asserts.Equal([]string{"/"}, SplitPath("/")) - asserts.Equal([]string{"/", "123", "321"}, SplitPath("/123/321")) -} diff --git a/pkg/util/ztool.go b/pkg/util/ztool.go new file mode 100644 index 0000000..9733b6f --- /dev/null +++ b/pkg/util/ztool.go @@ -0,0 +1,35 @@ +package util + +import "strings" + +type ( + // 可取长度类型 + LenAble interface{ string | []any | chan any } +) + +// 计算切片元素总长度 +/* + 传入字符串切片, 返回其所有元素长度之和 + e.g. LenArray({`ele3`,`ele2`,`ele1`}) => 12 +*/ +func LenArray[T LenAble](a []T) int { + var o int + for i, r := 0, len(a); i < r; i++ { + o += len(a[i]) + } + return o +} + +// 字符串快速拼接 +/* + 传入多个字符串参数, 返回拼接后的结果 + e.g: StrConcat("str1", "str2", "str3") => "str1str2str3" +*/ +func StrConcat(a ...string) string { + var b strings.Builder + b.Grow(LenArray(a)) + for i, r := 0, len(a); i < r; i++ { + b.WriteString(a[i]) + } + return b.String() +} diff --git a/pkg/vol/vol.go b/pkg/vol/vol.go new file mode 100755 index 0000000..58dae92 --- /dev/null +++ b/pkg/vol/vol.go @@ -0,0 +1,39 @@ +package vol + +import ( + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "net/http" +) + +var ClientSecret = "" + +const CRMSite = "https://pro.cloudreve.org/crm/api/vol/" + +type Client interface { + // Sync VOL from CRM, return content (base64 encoded) and signature. + Sync() (string, string, error) +} + +type VolClient struct { + secret string + client request.Client +} + +func New(secret string) Client { + return &VolClient{secret: secret, client: request.NewClient()} +} + +func (c *VolClient) Sync() (string, string, error) { + res, err := c.client.Request("GET", CRMSite+c.secret, nil).CheckHTTPResponse(http.StatusOK).DecodeResponse() + if err != nil { + return "", "", fmt.Errorf("failed to get VOL from CRM: %w", err) + } + + if res.Code != 0 { + return "", "", fmt.Errorf("CRM return error: %s", res.Msg) + } + + vol := res.Data.(map[string]interface{}) + return vol["content"].(string), vol["signature"].(string), nil +} diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 5ab9906..78cc9f4 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -414,6 +414,9 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst ctx = context.WithValue(ctx, fsctx.FileModelCtx, *originFile) fileData.Mode |= fsctx.Overwrite } else { + // 尝试获取并重设存储策略 + fs.SetPolicyFromPath(filePath) + // 给文件系统分配钩子 fs.Use("BeforeUpload", filesystem.HookValidateFile) fs.Use("BeforeUpload", filesystem.HookValidateCapacity) diff --git a/pkg/wopi/discovery_test.go b/pkg/wopi/discovery_test.go deleted file mode 100644 index 8092384..0000000 --- a/pkg/wopi/discovery_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package wopi - -import ( - "errors" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "io" - "net/http" - "net/url" - "strings" - "testing" -) - -func TestClient_AvailableExts(t *testing.T) { - a := assert.New(t) - endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery") - client := &client{ - cache: cache.NewMemoStore(), - config: config{ - discoveryEndpoint: endpoint, - }, - } - - // Discovery failed - { - expectedErr := errors.New("error") - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: expectedErr, - }) - res := client.AvailableExts() - a.Empty(res) - mockHttp.AssertExpectations(t) - } - - // pass - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{}, - }, - ".ppt": {}, - ".xls": { - "not_supported": Action{}, - }, - } - res := client.AvailableExts() - a.Len(res, 1) - a.Equal("doc", res[0]) - } -} - -func TestClient_RefreshDiscovery(t *testing.T) { - a := assert.New(t) - endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery") - client := &client{ - cache: cache.NewMemoStore(), - config: config{ - discoveryEndpoint: endpoint, - }, - } - - // cache hit - { - client.cache.Set(DiscoverResponseCacheKey, WopiDiscovery{Text: "test"}, 0) - a.NoError(client.checkDiscovery()) - a.Equal("test", client.discovery.Text) - client.discovery = &WopiDiscovery{} - client.cache.Delete([]string{DiscoverResponseCacheKey}, "") - } - - // malformed xml - { - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"code":203}`)), - }, - }) - res := client.refreshDiscovery() - a.ErrorContains(res, "failed to parse") - mockHttp.AssertExpectations(t) - } - - // all pass - { - testResponse := ` -` - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Response: &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(testResponse)), - }, - }) - res := client.refreshDiscovery() - a.NoError(res, res) - a.NotEmpty(client.actions[".docx"]) - a.NotEmpty(client.actions[".docx"][string(ActionPreview)]) - a.NotEmpty(client.actions[".docx"][string(ActionEdit)]) - mockHttp.AssertExpectations(t) - } -} diff --git a/pkg/wopi/wopi.go b/pkg/wopi/wopi.go index 2938de0..7a7b296 100644 --- a/pkg/wopi/wopi.go +++ b/pkg/wopi/wopi.go @@ -3,17 +3,18 @@ package wopi import ( "errors" "fmt" + "net/url" + "path" + "strings" + "sync" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gofrs/uuid" - "net/url" - "path" - "strings" - "sync" - "time" ) type Client interface { diff --git a/pkg/wopi/wopi_test.go b/pkg/wopi/wopi_test.go deleted file mode 100644 index 78c4bcc..0000000 --- a/pkg/wopi/wopi_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package wopi - -import ( - "database/sql" - "errors" - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/cache" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/cachemock" - "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock" - "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" - testMock "github.com/stretchr/testify/mock" - "net/url" - "testing" -) - -var mock sqlmock.Sqlmock - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - model.DB, _ = gorm.Open("mysql", db) - defer db.Close() - m.Run() -} - -func TestNewSession(t *testing.T) { - a := assert.New(t) - endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery") - client := &client{ - cache: cache.NewMemoStore(), - config: config{ - discoveryEndpoint: endpoint, - }, - } - - // Discovery failed - { - expectedErr := errors.New("error") - mockHttp := &requestmock.RequestMock{} - client.http = mockHttp - mockHttp.On( - "Request", - "GET", - endpoint.String(), - testMock.Anything, - testMock.Anything, - ).Return(&request.Response{ - Err: expectedErr, - }) - res, err := client.NewSession(0, &model.File{}, ActionPreview) - a.Nil(res) - a.ErrorIs(err, expectedErr) - mockHttp.AssertExpectations(t) - } - - // not supported ext - { - client.discovery = &WopiDiscovery{} - client.actions = make(map[string]map[string]Action) - res, err := client.NewSession(0, &model.File{}, ActionPreview) - a.Nil(res) - a.ErrorIs(err, ErrActionNotSupported) - } - - // preferred action not supported - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": {}, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionPreview) - a.Nil(res) - a.ErrorIs(err, ErrActionNotSupported) - } - - // src url cannot be parsed - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: string([]byte{0x7f}), - }, - }, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.Nil(res) - a.ErrorContains(err, "invalid control character in URL") - } - - // all pass - default placeholder - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: "https://doc.com/doc", - }, - }, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.NotNil(res) - a.NoError(err) - resUrl := res.ActionURL.String() - a.Contains(resUrl, wopiSrcParamDefault) - } - - // all pass - with placeholders - { - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: "https://doc.com/doc?origin=preserved&", - }, - }, - } - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.NotNil(res) - a.NoError(err) - resUrl := res.ActionURL.String() - a.Contains(resUrl, "origin=preserved") - a.Contains(resUrl, "dc=lng") - a.Contains(resUrl, "src=") - a.NotContains(resUrl, "notsuported") - } - - // cache operation failed - { - mockCache := &cachemock.CacheClientMock{} - expectedErr := errors.New("error") - client.cache = mockCache - client.discovery = &WopiDiscovery{} - client.actions = map[string]map[string]Action{ - ".doc": { - string(ActionPreviewFallback): Action{ - Urlsrc: "https://doc.com/doc", - }, - }, - } - mockCache.On("Set", testMock.Anything, testMock.Anything, testMock.Anything).Return(expectedErr) - res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit) - a.Nil(res) - a.ErrorIs(err, expectedErr) - } -} - -func TestInit(t *testing.T) { - a := assert.New(t) - - // not enabled - { - a.Nil(Default) - Default = &client{} - Init() - a.Nil(Default) - } - - // throw error - { - a.Nil(Default) - cache.Set("setting_wopi_enabled", "1", 0) - cache.Set("setting_wopi_endpoint", string([]byte{0x7f}), 0) - Init() - a.Nil(Default) - } - - // all pass - { - a.Nil(Default) - cache.Set("setting_wopi_enabled", "1", 0) - cache.Set("setting_wopi_endpoint", "", 0) - Init() - a.NotNil(Default) - } -} diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index 26d917e..e9f4dac 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -1,14 +1,15 @@ package controllers import ( - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" - "github.com/cloudreve/Cloudreve/v3/pkg/mq" - "io" + // "io" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" + + // "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/wopi" "github.com/cloudreve/Cloudreve/v3/service/admin" @@ -27,17 +28,17 @@ func AdminSummary(c *gin.Context) { } // AdminNews 获取社区新闻 -func AdminNews(c *gin.Context) { - tag := "announcements" - if c.Query("tag") != "" { - tag = c.Query("tag") - } - r := request.NewClient() - res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3A"+tag+"&sort=-startTime&page%5Blimit%5D=10", nil) - if res.Err == nil { - io.Copy(c.Writer, res.Response.Body) - } -} +// func AdminNews(c *gin.Context) { +// tag := "announcements" +// if c.Query("tag") != "" { +// tag = c.Query("tag") +// } +// r := request.NewClient() +// res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3A"+tag+"&sort=-startTime&page%5Blimit%5D=10", nil) +// if res.Err == nil { +// io.Copy(c.Writer, res.Response.Body) +// } +// } // AdminChangeSetting 获取站点设定项 func AdminChangeSetting(c *gin.Context) { @@ -109,6 +110,39 @@ func AdminTestThumbGenerator(c *gin.Context) { } } +// AdminListRedeems 列出激活码 +func AdminListRedeems(c *gin.Context) { + var service admin.AdminListService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Redeems() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminGenerateRedeems 生成激活码 +func AdminGenerateRedeems(c *gin.Context) { + var service admin.GenerateRedeemsService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Generate() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminDeleteRedeem 删除激活码 +func AdminDeleteRedeem(c *gin.Context) { + var service admin.SingleIDService + if err := c.ShouldBindUri(&service); err == nil { + res := service.DeleteRedeem() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // AdminTestAria2 测试aria2连接 func AdminTestAria2(c *gin.Context) { var service admin.Aria2TestService @@ -389,6 +423,28 @@ func AdminDeleteShare(c *gin.Context) { } } +// AdminListOrder 列出订单 +func AdminListOrder(c *gin.Context) { + var service admin.AdminListService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Orders() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminDeleteOrder 批量删除订单 +func AdminDeleteOrder(c *gin.Context) { + var service admin.OrderBatchService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Delete(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // AdminListDownload 列出离线下载任务 func AdminListDownload(c *gin.Context) { var service admin.AdminListService @@ -455,6 +511,28 @@ func AdminListFolders(c *gin.Context) { } } +// AdminListReport 列出未处理举报 +func AdminListReport(c *gin.Context) { + var service admin.AdminListService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Reports() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AdminDeleteTask 批量删除举报 +func AdminDeleteReport(c *gin.Context) { + var service admin.ReportBatchService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Delete() + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // AdminListNodes 列出从机节点 func AdminListNodes(c *gin.Context) { var service admin.AdminListService @@ -509,3 +587,10 @@ func AdminGetNode(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// AdminSyncVol 同步VOL授权 +func AdminSyncVol(c *gin.Context) { + var service admin.VolService + res := service.Sync() + c.JSON(200, res) +} diff --git a/routers/controllers/file.go b/routers/controllers/file.go index 0e7c206..e32061b 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -3,10 +3,10 @@ package controllers import ( "context" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "net/http" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/service/explorer" @@ -51,6 +51,17 @@ func Compress(c *gin.Context) { } } +// Relocate 创建文件转移任务 +func Relocate(c *gin.Context) { + var service explorer.ItemRelocateService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.CreateRelocateTask(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // Decompress 创建文件解压缩任务 func Decompress(c *gin.Context) { var service explorer.ItemDecompressService @@ -135,6 +146,7 @@ func AnonymousPermLink(c *gin.Context) { } +// GetSource 获取文件的外链地址 func GetSource(c *gin.Context) { // 创建上下文 ctx, cancel := context.WithCancel(context.Background()) @@ -304,52 +316,6 @@ func FileUpload(c *gin.Context) { } else { c.JSON(200, ErrorResponse(err)) } - - //fileData := fsctx.FileStream{ - // MIMEType: c.Request.Header.Get("Content-Type"), - // File: c.Request.Body, - // Size: fileSize, - // Name: fileName, - // VirtualPath: filePath, - // Mode: fsctx.Create, - //} - // - //// 创建文件系统 - //fs, err := filesystem.NewFileSystemFromContext(c) - //if err != nil { - // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)) - // return - //} - // - //// 非可用策略时拒绝上传 - //if !fs.Policy.IsTransitUpload(fileSize) { - // request.BlackHole(c.Request.Body) - // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, "当前存储策略无法使用", nil)) - // return - //} - // - //// 给文件系统分配钩子 - //fs.Use("BeforeUpload", filesystem.HookValidateFile) - //fs.Use("BeforeUpload", filesystem.HookValidateCapacity) - //fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile) - //fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity) - //fs.Use("AfterUpload", filesystem.GenericAfterUpload) - //fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile) - //fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity) - //fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity) - // - //// 执行上传 - //ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{}) - //uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c) - //err = fs.Upload(uploadCtx, &fileData) - //if err != nil { - // c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err)) - // return - //} - // - //c.JSON(200, serializer.Response{ - // Code: 0, - //}) } // DeleteUploadSession 删除上传会话 diff --git a/routers/controllers/share.go b/routers/controllers/share.go index 8795c1e..b9bcb0a 100644 --- a/routers/controllers/share.go +++ b/routers/controllers/share.go @@ -173,6 +173,17 @@ func GetShareDocPreview(c *gin.Context) { } } +// SaveShare 转存他人分享 +func SaveShare(c *gin.Context) { + var service share.Service + if err := c.ShouldBindJSON(&service); err == nil { + res := service.SaveToMyFile(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // ListSharedFolder 列出分享的目录下的对象 func ListSharedFolder(c *gin.Context) { var service share.Service @@ -235,3 +246,14 @@ func GetUserShare(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// ReportShare 举报分享 +func ReportShare(c *gin.Context) { + var service share.ShareReportService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Report(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/controllers/site.go b/routers/controllers/site.go index c4a3508..141deed 100644 --- a/routers/controllers/site.go +++ b/routers/controllers/site.go @@ -14,13 +14,17 @@ import ( func SiteConfig(c *gin.Context) { siteConfig := model.GetSettingByNames( "siteName", + "siteNotice", "login_captcha", + "qq_login", "reg_captcha", "email_active", "forget_captcha", - "email_active", + // "email_active", "themes", "defaultTheme", + "score_enabled", + "share_score_rate", "home_view_method", "share_view_method", "authn_enabled", @@ -28,7 +32,10 @@ func SiteConfig(c *gin.Context) { "captcha_type", "captcha_TCaptcha_CaptchaAppId", "register_enabled", + "report_enabled", "show_app_promotion", + "app_forum_link", + "app_feedback_link", ) var wopiExts []string @@ -49,8 +56,8 @@ func SiteConfig(c *gin.Context) { // Ping 状态检查页面 func Ping(c *gin.Context) { version := conf.BackendVersion - if conf.IsPro == "true" { - version += "-pro" + if conf.IsPlus == "true" { + version += "-plus" } c.JSON(200, serializer.Response{ @@ -139,3 +146,22 @@ func Manifest(c *gin.Context) { "background_color": options["pwa_background_color"], }) } + +// GetVolSecret 获取 VOL 密钥 +func GetVolSecret(c *gin.Context) { + vol := model.GetSettingByNames("vol_content", "vol_signature") + if vol["vol_signature"] == "" { + c.JSON(200, serializer.Response{ + Code: serializer.CodeNotFound, + }) + + return + } + + c.JSON(200, serializer.Response{ + Data: serializer.VolResponse{ + Signature: vol["vol_signature"], + Content: vol["vol_content"], + }, + }) +} diff --git a/routers/controllers/user.go b/routers/controllers/user.go index 5d6301e..648fd1b 100644 --- a/routers/controllers/user.go +++ b/routers/controllers/user.go @@ -6,6 +6,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/authn" + "github.com/cloudreve/Cloudreve/v3/pkg/qq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/thumb" @@ -213,6 +214,24 @@ func UserActivate(c *gin.Context) { } } +// UserQQLogin 初始化QQ登录 +func UserQQLogin(c *gin.Context) { + // 新建绑定 + res, err := qq.NewLoginRequest() + if err != nil { + c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法使用QQ登录", err)) + return + } + + // 设定QQ登录会话Secret + util.SetSession(c, map[string]interface{}{"qq_login_secret": res.SecretKey}) + + c.JSON(200, serializer.Response{ + Data: res.URL, + }) + +} + // UserSignOut 用户退出登录 func UserSignOut(c *gin.Context) { util.DeleteSession(c, "user_id") @@ -233,6 +252,28 @@ func UserStorage(c *gin.Context) { c.JSON(200, res) } +// UserAvailablePolicies 用户存储策略设置 +func UserAvailablePolicies(c *gin.Context) { + var service user.SettingService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Policy(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// UserAvailableNodes 用户可选节点 +func UserAvailableNodes(c *gin.Context) { + var service user.SettingService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Nodes(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // UserTasks 获取任务队列 func UserTasks(c *gin.Context) { var service user.SettingListService @@ -339,6 +380,12 @@ func UpdateOption(c *gin.Context) { switch service.Option { case "nick": subService = &user.ChangerNick{} + case "vip": + subService = &user.VIPUnsubscribe{} + case "qq": + subService = &user.QQBind{} + case "policy": + subService = &user.PolicyChange{} case "homepage": subService = &user.HomePage{} case "password": diff --git a/routers/controllers/vas.go b/routers/controllers/vas.go new file mode 100755 index 0000000..9177941 --- /dev/null +++ b/routers/controllers/vas.go @@ -0,0 +1,214 @@ +package controllers + +import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cache" + "github.com/cloudreve/Cloudreve/v3/pkg/payment" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/cloudreve/Cloudreve/v3/service/vas" + "github.com/gin-gonic/gin" + "github.com/iGoogle-ink/gopay" + "github.com/iGoogle-ink/gopay/wechat/v3" + "github.com/qingwg/payjs/notify" + "github.com/smartwalle/alipay/v3" + "net/http" +) + +// GetQuota 获取容量配额信息 +func GetQuota(c *gin.Context) { + var service vas.GeneralVASService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Quota(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// GetProduct 获取商品信息 +func GetProduct(c *gin.Context) { + var service vas.GeneralVASService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Products(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// NewOrder 新建支付订单 +func NewOrder(c *gin.Context) { + var service vas.CreateOrderService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Create(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// OrderStatus 查询订单状态 +func OrderStatus(c *gin.Context) { + var service vas.OrderService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Status(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// GetRedeemInfo 获取兑换码信息 +func GetRedeemInfo(c *gin.Context) { + var service vas.RedeemService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Query(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// DoRedeem 获取兑换码信息 +func DoRedeem(c *gin.Context) { + var service vas.RedeemService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Redeem(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// AlipayCallback 支付宝回调 +func AlipayCallback(c *gin.Context) { + pay, err := payment.NewPaymentInstance("alipay") + if err != nil { + util.Log().Debug("[Alipay callback] Failed to create alipay client, %s", err) + c.Status(400) + return + } + + res, err := pay.(*payment.Alipay).Client.GetTradeNotification(c.Request) + if err != nil { + util.Log().Debug("[Alipay callback] Failed to validate callback request, %s", err) + c.Status(403) + return + } + + if res != nil && res.TradeStatus == "TRADE_SUCCESS" { + // 支付成功 + if err := payment.OrderPaid(res.OutTradeNo); err != nil { + util.Log().Debug("[Alipay callback] Failed to process payment, %s", err) + } + } + + // 确认收到通知消息 + alipay.AckNotification(c.Writer) +} + +// WechatCallback 微信扫码支付回调 +func WechatCallback(c *gin.Context) { + pay, err := payment.NewPaymentInstance("wechat") + if err != nil { + util.Log().Debug("[Wechat pay callback] Failed to create alipay client, %s", err) + c.JSON(500, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: "Failed to create alipay client"}) + return + } + + notifyReq, err := wechat.V3ParseNotify(c.Request) + if err != nil { + util.Log().Debug("[Wechat pay callback] Failed to parse callback content, %s", err) + c.JSON(500, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: "Failed to parse callback content"}) + return + } + + err = notifyReq.VerifySign(pay.(*payment.Wechat).GetPlatformCert()) + if err != nil { + util.Log().Debug("[Wechat pay callback] Failed to verify callback signature, %s", err) + c.JSON(403, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: "Failed to verify callback signature"}) + return + } + + // 解密回调正文 + result, err := notifyReq.DecryptCipherText(pay.(*payment.Wechat).ApiV3Key) + if result != nil && result.TradeState == "SUCCESS" { + // 支付成功 + if err := payment.OrderPaid(result.OutTradeNo); err != nil { + util.Log().Debug("[Wechat pay callback] Failed to process payment, %s", err) + } + } + + // 确认收到通知消息 + c.JSON(http.StatusOK, &wechat.V3NotifyRsp{Code: gopay.SUCCESS, Message: "Success"}) +} + +// PayJSCallback PayJS回调 +func PayJSCallback(c *gin.Context) { + pay, err := payment.NewPaymentInstance("payjs") + if err != nil { + util.Log().Debug("[PayJS callback] Failed to initialize payment client, %s", err) + c.Status(400) + return + } + + payNotify := pay.(*payment.PayJSClient).Client.GetNotify(c.Request, c.Writer) + + //设置接收消息的处理方法 + payNotify.SetMessageHandler(func(msg notify.Message) { + if err := payment.OrderPaid(msg.OutTradeNo); err != nil { + util.Log().Debug("[PayJS callback] Failed to process payment, %s", err) + } + }) + + //处理消息接收以及回复 + err = payNotify.Serve() + if err != nil { + util.Log().Debug("[PayJS callback] Failed to process payment, %s", err) + return + } + + //发送回复的消息 + payNotify.SendResponseMsg() + +} + +// QQCallback QQ互联回调 +func QQCallback(c *gin.Context) { + var service vas.QQCallbackService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Callback(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// CustomCallback PayJS回调 +func CustomCallback(c *gin.Context) { + orderNo := c.Param("orderno") + sessionID := c.Param("id") + sessionRaw, exist := cache.Get(payment.CallbackSessionPrefix + sessionID) + if !exist { + util.Log().Debug("[Custom callback] Failed to process payment, session not found") + c.JSON(200, serializer.Err(serializer.CodeNotFound, "session not found", nil)) + return + } + + expectedID := sessionRaw.(string) + if expectedID != orderNo { + util.Log().Debug("[Custom callback] Failed to process payment, session mismatch") + c.JSON(200, serializer.Err(serializer.CodeInternalSetting, "session mismatch", nil)) + return + } + + cache.Deletes([]string{sessionID}, payment.CallbackSessionPrefix) + + if err := payment.OrderPaid(orderNo); err != nil { + c.JSON(200, serializer.Err(serializer.CodeInternalSetting, "failed to fulfill payment", err)) + util.Log().Debug("[Custom callback] Failed to process payment, %s", err) + return + } + + c.JSON(200, serializer.Response{}) +} diff --git a/routers/controllers/webdav.go b/routers/controllers/webdav.go index 0453ada..e3baf91 100644 --- a/routers/controllers/webdav.go +++ b/routers/controllers/webdav.go @@ -2,6 +2,9 @@ package controllers import ( "context" + "net/http" + "sync" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" @@ -9,8 +12,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/webdav" "github.com/cloudreve/Cloudreve/v3/service/setting" "github.com/gin-gonic/gin" - "net/http" - "sync" ) var handler *webdav.Handler @@ -92,6 +93,28 @@ func UpdateWebDAVAccounts(c *gin.Context) { } } +// DeleteWebDAVMounts 删除WebDAV挂载 +func DeleteWebDAVMounts(c *gin.Context) { + var service setting.WebDAVListService + if err := c.ShouldBindUri(&service); err == nil { + res := service.Unmount(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + +// UpdateWebDAVAccountsReadonly 更改WebDAV账户只读性 +func UpdateWebDAVAccountsReadonly(c *gin.Context) { + var service setting.WebDAVAccountUpdateReadonlyService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Update(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} + // CreateWebDAVAccounts 创建WebDAV账户 func CreateWebDAVAccounts(c *gin.Context) { var service setting.WebDAVAccountCreateService @@ -102,3 +125,14 @@ func CreateWebDAVAccounts(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// CreateWebDAVMounts 创建WebDAV目录挂载 +func CreateWebDAVMounts(c *gin.Context) { + var service setting.WebDAVMountCreateService + if err := c.ShouldBindJSON(&service); err == nil { + res := service.Create(c, CurrentUser(c)) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/main_test.go b/routers/main_test.go deleted file mode 100644 index 83664bd..0000000 --- a/routers/main_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package routers - -import ( - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/gin-gonic/gin" - "github.com/jinzhu/gorm" -) - -var mock sqlmock.Sqlmock -var memDB *gorm.DB -var mockDB *gorm.DB - -// TestMain 初始化数据库Mock -func TestMain(m *testing.M) { - // 设置gin为测试模式 - gin.SetMode(gin.TestMode) - - // 初始化sqlmock - var db *sql.DB - var err error - db, mock, err = sqlmock.New() - if err != nil { - panic("An error was not expected when opening a stub database connection") - } - - // 初始话内存数据库 - model.Init() - memDB = model.DB - - mockDB, _ = gorm.Open("mysql", db) - model.DB = memDB - defer db.Close() - - m.Run() -} - -func switchToMemDB() { - model.DB = memDB -} - -func switchToMockDB() { - model.DB = mockDB -} diff --git a/routers/router.go b/routers/router.go index aa8e903..f2a187b 100644 --- a/routers/router.go +++ b/routers/router.go @@ -1,6 +1,8 @@ package routers import ( + // "github.com/abslant/gzip" + "github.com/cloudreve/Cloudreve/v3/bootstrap" "github.com/cloudreve/Cloudreve/v3/middleware" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" @@ -117,11 +119,14 @@ func InitCORS(router *gin.Engine) { // InitMasterRouter 初始化主机模式路由 func InitMasterRouter() *gin.Engine { r := gin.Default() + bootstrap.InitCustomRoute(r.Group("/api/v3")) + // bootstrap.InitCustomRoute(r.Group("/custom")) /* 静态资源 */ r.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{"/api/"}))) + // r.Use(gzip.GzipHandler()) r.Use(middleware.FrontendFileHandler()) r.GET("manifest.json", controllers.Manifest) @@ -165,6 +170,8 @@ func InitMasterRouter() *gin.Engine { site.GET("captcha", controllers.Captcha) // 站点全局配置 site.GET("config", middleware.CSRFInit(), controllers.SiteConfig) + // VOL 密钥 + site.GET("vol", controllers.GetVolSecret) } // 用户相关路由 @@ -190,6 +197,8 @@ func InitMasterRouter() *gin.Engine { middleware.HashID(hashid.UserID), controllers.UserActivate, ) + // 初始化QQ登录 + user.POST("qq", controllers.UserQQLogin) // WebAuthn登陆初始化 user.GET("authn/:username", middleware.IsFunctionEnabled("authn_enabled"), @@ -267,6 +276,32 @@ func InitMasterRouter() *gin.Engine { // 回调接口 callback := v3.Group("callback") { + // QQ互联回调 + callback.POST( + "qq", + controllers.QQCallback, + ) + // PAYJS回调 + callback.POST( + "payjs", + controllers.PayJSCallback, + ) + // 支付宝回调 + callback.POST( + "alipay", + controllers.AlipayCallback, + ) + // 微信扫码支付回调 + callback.POST( + "wechat", + controllers.WechatCallback, + ) + // Custom payment callback + callback.GET( + "custom/:orderno/:id", + middleware.SignRequired(auth.General), + controllers.CustomCallback, + ) // 远程策略上传回调 callback.POST( "remote/:sessionID/:key", @@ -392,8 +427,14 @@ func InitMasterRouter() *gin.Engine { middleware.ShareCanPreview(), controllers.ShareThumb, ) + // 举报分享 + share.POST("report/:id", + middleware.IsFunctionEnabled("report_enabled"), + middleware.CheckShareUnlocked(), + controllers.ReportShare, + ) // 搜索公共分享 - v3.Group("share").GET("search", controllers.SearchShare) + v3.Group("share").GET("search", middleware.AuthRequired(), controllers.SearchShare) } wopi := v3.Group( @@ -422,7 +463,7 @@ func InitMasterRouter() *gin.Engine { // 获取站点概况 admin.GET("summary", controllers.AdminSummary) // 获取社区新闻 - admin.GET("news", controllers.AdminNews) + // admin.GET("news", controllers.AdminNews) // 更改设置 admin.PATCH("setting", controllers.AdminChangeSetting) // 获取设置 @@ -447,6 +488,22 @@ func InitMasterRouter() *gin.Engine { aria2.POST("test", controllers.AdminTestAria2) } + vol := admin.Group("vol") + { + vol.GET("sync", controllers.AdminSyncVol) + } + + // 兑换码相关 + redeem := admin.Group("redeem") + { + // 列出激活码 + redeem.POST("list", controllers.AdminListRedeems) + // 生成激活码 + redeem.POST("", controllers.AdminGenerateRedeems) + // 删除激活码 + redeem.DELETE(":id", controllers.AdminDeleteRedeem) + } + // 存储策略管理 policy := admin.Group("policy") { @@ -525,6 +582,14 @@ func InitMasterRouter() *gin.Engine { share.POST("delete", controllers.AdminDeleteShare) } + order := admin.Group("order") + { + // 列出订单 + order.POST("list", controllers.AdminListOrder) + // 删除 + order.POST("delete", controllers.AdminDeleteOrder) + } + download := admin.Group("download") { // 列出任务 @@ -543,6 +608,14 @@ func InitMasterRouter() *gin.Engine { task.POST("import", controllers.AdminCreateImportTask) } + report := admin.Group("report") + { + // 列出未处理举报 + report.POST("list", controllers.AdminListReport) + // 删除 + report.POST("delete", controllers.AdminDeleteReport) + } + node := admin.Group("node") { // 列出从机节点 @@ -585,6 +658,10 @@ func InitMasterRouter() *gin.Engine { // 用户设置 setting := user.Group("setting") { + // 获取用户可选存储策略 + setting.GET("policies", controllers.UserAvailablePolicies) + // 获取用户可选节点 + setting.GET("nodes", controllers.UserAvailableNodes) // 任务队列 setting.GET("tasks", controllers.UserTasks) // 获取当前用户设定 @@ -601,7 +678,7 @@ func InitMasterRouter() *gin.Engine { } // 文件 - file := auth.Group("file", middleware.HashID(hashid.FileID)) + file := auth.Group("file", middleware.PhoneRequired(), middleware.HashID(hashid.FileID)) { // 上传 upload := file.Group("upload") @@ -637,12 +714,14 @@ func InitMasterRouter() *gin.Engine { file.POST("compress", controllers.Compress) // 创建文件解压缩任务 file.POST("decompress", controllers.Decompress) - // 创建文件解压缩任务 + // 创建文件转移任务 + file.POST("relocate", controllers.Relocate) + // 搜索文件 file.GET("search/:type/:keywords", controllers.SearchFile) } // 离线下载任务 - aria2 := auth.Group("aria2") + aria2 := auth.Group("aria2", middleware.PhoneRequired()) { // 创建URL下载任务 aria2.POST("url", controllers.AddAria2URL) @@ -659,7 +738,7 @@ func InitMasterRouter() *gin.Engine { } // 目录 - directory := auth.Group("directory") + directory := auth.Group("directory", middleware.PhoneRequired()) { // 创建目录 directory.PUT("", controllers.CreateDirectory) @@ -668,7 +747,7 @@ func InitMasterRouter() *gin.Engine { } // 对象,文件和目录的抽象 - object := auth.Group("object") + object := auth.Group("object", middleware.PhoneRequired()) { // 删除对象 object.DELETE("", controllers.Delete) @@ -683,12 +762,19 @@ func InitMasterRouter() *gin.Engine { } // 分享 - share := auth.Group("share") + share := auth.Group("share", middleware.PhoneRequired()) { // 创建新分享 share.POST("", controllers.CreateShare) // 列出我的分享 share.GET("", controllers.ListShare) + // 转存他人分享 + share.POST("save/:id", + middleware.ShareAvailable(), + middleware.CheckShareUnlocked(), + middleware.BeforeShareDownload(), + controllers.SaveShare, + ) // 更新分享属性 share.PATCH(":id", middleware.ShareAvailable(), @@ -712,8 +798,25 @@ func InitMasterRouter() *gin.Engine { tag.DELETE(":id", middleware.HashID(hashid.TagID), controllers.DeleteTag) } + // 增值服务相关 + vas := auth.Group("vas", middleware.PhoneRequired()) + { + // 获取容量包及配额信息 + vas.GET("pack", controllers.GetQuota) + // 获取商品信息,同时返回支付信息 + vas.GET("product", controllers.GetProduct) + // 新建支付订单 + vas.POST("order", controllers.NewOrder) + // 查询订单状态 + vas.GET("order/:id", controllers.OrderStatus) + // 获取兑换码信息 + vas.GET("redeem/:code", controllers.GetRedeemInfo) + // 执行兑换 + vas.POST("redeem/:code", controllers.DoRedeem) + } + // WebDAV管理相关 - webdav := auth.Group("webdav") + webdav := auth.Group("webdav", middleware.PhoneRequired()) { // 获取账号信息 webdav.GET("accounts", controllers.GetWebDAVAccounts) @@ -721,6 +824,13 @@ func InitMasterRouter() *gin.Engine { webdav.POST("accounts", controllers.CreateWebDAVAccounts) // 删除账号 webdav.DELETE("accounts/:id", controllers.DeleteWebDAVAccounts) + // 删除目录挂载 + webdav.DELETE("mount/:id", + middleware.HashID(hashid.FolderID), + controllers.DeleteWebDAVMounts, + ) + // 创建目录挂载 + webdav.POST("mount", controllers.CreateWebDAVMounts) // 更新账号可读性和是否使用代理服务 webdav.PATCH("accounts", controllers.UpdateWebDAVAccounts) } diff --git a/routers/router_test.go b/routers/router_test.go deleted file mode 100644 index 2476de6..0000000 --- a/routers/router_test.go +++ /dev/null @@ -1,251 +0,0 @@ -package routers - -import ( - "github.com/cloudreve/Cloudreve/v3/pkg/conf" - "net/http" - "net/http/httptest" - "testing" - - model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/jinzhu/gorm" - "github.com/stretchr/testify/assert" -) - -func TestPing(t *testing.T) { - asserts := assert.New(t) - router := InitMasterRouter() - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/api/v3/site/ping", nil) - router.ServeHTTP(w, req) - - assert.Equal(t, 200, w.Code) - asserts.Contains(w.Body.String(), conf.BackendVersion) -} - -func TestCaptcha(t *testing.T) { - asserts := assert.New(t) - router := InitMasterRouter() - w := httptest.NewRecorder() - - req, _ := http.NewRequest( - "GET", - "/api/v3/site/captcha", - nil, - ) - - router.ServeHTTP(w, req) - - asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "base64") -} - -//func TestUserSession(t *testing.T) { -// mutex.Lock() -// defer mutex.Unlock() -// switchToMockDB() -// asserts := assert.New(t) -// router := InitMasterRouter() -// w := httptest.NewRecorder() -// -// // 创建测试用验证码 -// var configD = base64Captcha.ConfigDigit{ -// Height: 80, -// Width: 240, -// MaxSkew: 0.7, -// DotCount: 80, -// CaptchaLen: 1, -// } -// idKeyD, _ := base64Captcha.GenerateCaptcha("", configD) -// middleware.ContextMock = map[string]interface{}{ -// "captchaID": idKeyD, -// } -// -// testCases := []struct { -// settingRows *sqlmock.Rows -// userRows *sqlmock.Rows -// policyRows *sqlmock.Rows -// reqBody string -// expected interface{} -// }{ -// // 登录信息正确,不需要验证码 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// expected: serializer.BuildUserResponse(model.User{ -// Email: "admin@cloudreve.org", -// Nick: "admin", -// Policy: model.Policy{ -// Type: "local", -// OptionsSerialized: model.PolicyOption{FileType: []string{}}, -// }, -// }), -// }, -// // 登录信息正确,需要验证码,验证码错误 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "1", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// expected: serializer.ParamErr("验证码错误", nil), -// }, -// // 邮箱正确密码错误 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// expected: serializer.Err(401, "用户邮箱或密码错误", nil), -// }, -// //邮箱格式不正确 -// { -// reqBody: `{"userName":"admin@cloudreve","captchaCode":"captchaCode","Password":"admin123"}`, -// expected: serializer.Err(40001, "邮箱格式不正确", errors.New("Key: 'UserLoginService.UserName' Error:Field validation for 'UserName' failed on the 'email' tag")), -// }, -// // 用户被Ban -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options", "status"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}", model.Baned), -// expected: serializer.Err(403, "该账号已被封禁", nil), -// }, -// // 用户未激活 -// { -// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}). -// AddRow("login_captcha", "0", "login"), -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options", "status"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}", model.NotActivicated), -// expected: serializer.Err(403, "该账号未激活", nil), -// }, -// } -// -// for k, testCase := range testCases { -// if testCase.settingRows != nil { -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.settingRows) -// } -// if testCase.userRows != nil { -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows) -// } -// if testCase.policyRows != nil { -// mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND \\(\\(`policies`.`id` = 1\\)\\)(.+)$").WillReturnRows(testCase.policyRows) -// } -// req, _ := http.NewRequest( -// "POST", -// "/api/v3/user/session", -// bytes.NewReader([]byte(testCase.reqBody)), -// ) -// router.ServeHTTP(w, req) -// -// asserts.Equal(200, w.Code) -// expectedJSON, _ := json.Marshal(testCase.expected) -// asserts.JSONEq(string(expectedJSON), w.Body.String(), "测试用例:%d", k) -// -// w.Body.Reset() -// asserts.NoError(mock.ExpectationsWereMet()) -// model.ClearCache() -// } -// -//} -// -//func TestSessionAuthCheck(t *testing.T) { -// mutex.Lock() -// defer mutex.Unlock() -// switchToMockDB() -// asserts := assert.New(t) -// router := InitMasterRouter() -// w := httptest.NewRecorder() -// -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}")) -// expectedUser, _ := model.GetUserByID(1) -// -// testCases := []struct { -// userRows *sqlmock.Rows -// sessionMock map[string]interface{} -// contextMock map[string]interface{} -// expected interface{} -// }{ -// // 未登录 -// { -// expected: serializer.CheckLogin(), -// }, -// // 登录正常 -// { -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}). -// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"), -// sessionMock: map[string]interface{}{"user_id": 1}, -// expected: serializer.BuildUserResponse(expectedUser), -// }, -// // UID不存在 -// { -// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}), -// sessionMock: map[string]interface{}{"user_id": -1}, -// expected: serializer.CheckLogin(), -// }, -// } -// -// for _, testCase := range testCases { -// req, _ := http.NewRequest( -// "GET", -// "/api/v3/user/me", -// nil, -// ) -// if testCase.userRows != nil { -// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows) -// } -// middleware.ContextMock = testCase.contextMock -// middleware.SessionMock = testCase.sessionMock -// router.ServeHTTP(w, req) -// expectedJSON, _ := json.Marshal(testCase.expected) -// -// asserts.Equal(200, w.Code) -// asserts.JSONEq(string(expectedJSON), w.Body.String()) -// asserts.NoError(mock.ExpectationsWereMet()) -// -// w.Body.Reset() -// } -// -//} - -func TestSiteConfigRoute(t *testing.T) { - switchToMemDB() - asserts := assert.New(t) - router := InitMasterRouter() - w := httptest.NewRecorder() - - req, _ := http.NewRequest( - "GET", - "/api/v3/site/config", - nil, - ) - router.ServeHTTP(w, req) - asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "Cloudreve") - - w.Body.Reset() - - // 消除无效值 - model.DB.Model(&model.Setting{ - Model: gorm.Model{ - ID: 2, - }, - }).UpdateColumn("name", "siteName_b") - - req, _ = http.NewRequest( - "GET", - "/api/v3/site/config", - nil, - ) - router.ServeHTTP(w, req) - asserts.Equal(200, w.Code) - asserts.Contains(w.Body.String(), "\"title\"") - - model.DB.Model(&model.Setting{ - Model: gorm.Model{ - ID: 2, - }, - }).UpdateColumn("name", "siteName") -} diff --git a/service/admin/aria2.go b/service/admin/aria2.go index 6a2b77d..0bdd641 100644 --- a/service/admin/aria2.go +++ b/service/admin/aria2.go @@ -3,10 +3,10 @@ package admin import ( "bytes" "encoding/json" - model "github.com/cloudreve/Cloudreve/v3/models" "net/url" "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/request" diff --git a/service/admin/order.go b/service/admin/order.go new file mode 100755 index 0000000..4556c0f --- /dev/null +++ b/service/admin/order.go @@ -0,0 +1,75 @@ +package admin + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gin-gonic/gin" + "strings" +) + +// OrderBatchService 订单批量操作服务 +type OrderBatchService struct { + ID []uint `json:"id" binding:"min=1"` +} + +// Delete 删除订单 +func (service *OrderBatchService) Delete(c *gin.Context) serializer.Response { + if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Order{}).Error; err != nil { + return serializer.DBErr("Failed to delete order records.", err) + } + return serializer.Response{} +} + +// Orders 列出订单 +func (service *AdminListService) Orders() serializer.Response { + var res []model.Order + total := 0 + + tx := model.DB.Model(&model.Order{}) + if service.OrderBy != "" { + tx = tx.Order(service.OrderBy) + } + + for k, v := range service.Conditions { + tx = tx.Where(k+" = ?", v) + } + + if len(service.Searches) > 0 { + search := "" + for k, v := range service.Searches { + search += k + " like '%" + v + "%' OR " + } + search = strings.TrimSuffix(search, " OR ") + tx = tx.Where(search) + } + + // 计算总数用于分页 + tx.Count(&total) + + // 查询记录 + tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + + // 查询对应用户,同时计算HashID + users := make(map[uint]model.User) + for _, file := range res { + users[file.UserID] = model.User{} + } + + userIDs := make([]uint, 0, len(users)) + for k := range users { + userIDs = append(userIDs, k) + } + + var userList []model.User + model.DB.Where("id in (?)", userIDs).Find(&userList) + + for _, v := range userList { + users[v.ID] = v + } + + return serializer.Response{Data: map[string]interface{}{ + "total": total, + "items": res, + "users": users, + }} +} diff --git a/service/admin/policy.go b/service/admin/policy.go index 478203a..d0b01e0 100644 --- a/service/admin/policy.go +++ b/service/admin/policy.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" "net/http" "net/url" "os" @@ -14,6 +13,8 @@ import ( "strings" "time" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" @@ -80,7 +81,10 @@ func (service *PolicyService) Delete() serializer.Response { // 检查用户组使用 var groups []model.Group model.DB.Model(&model.Group{}).Where( - "policies like ?", + "policies like ? OR policies like ? OR policies like ? OR policies like ?", + fmt.Sprintf("[%d,%%", service.ID), + fmt.Sprintf("%%,%d]", service.ID), + fmt.Sprintf("%%,%d,%%", service.ID), fmt.Sprintf("%%[%d]%%", service.ID), ).Find(&groups) @@ -185,7 +189,6 @@ func (service *PolicyService) AddCORS() serializer.Response { }, }), } - if err := handler.CORS(); err != nil { return serializer.Err(serializer.CodeAddCORS, "", err) } @@ -227,8 +230,8 @@ func (service *SlavePingService) Test() serializer.Response { } version := conf.BackendVersion - if conf.IsPro == "true" { - version += "-pro" + if conf.IsPlus == "true" { + version += "-plus" } if res.Data.(string) != version { return serializer.Err(serializer.CodeVersionMismatch, "Master: "+res.Data.(string)+", Slave: "+version, nil) diff --git a/service/admin/report.go b/service/admin/report.go new file mode 100755 index 0000000..b30d7af --- /dev/null +++ b/service/admin/report.go @@ -0,0 +1,72 @@ +package admin + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/hashid" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) + +// ReportBatchService 任务批量操作服务 +type ReportBatchService struct { + ID []uint `json:"id" binding:"min=1"` +} + +// Reports 批量删除举报 +func (service *ReportBatchService) Delete() serializer.Response { + if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Report{}).Error; err != nil { + return serializer.DBErr("Failed to change report status", err) + } + return serializer.Response{} +} + +// Reports 列出待处理举报 +func (service *AdminListService) Reports() serializer.Response { + var res []model.Report + total := 0 + + tx := model.DB.Model(&model.Report{}) + if service.OrderBy != "" { + tx = tx.Order(service.OrderBy) + } + + for k, v := range service.Conditions { + tx = tx.Where(k+" = ?", v) + } + + // 计算总数用于分页 + tx.Count(&total) + + // 查询记录 + tx.Set("gorm:auto_preload", true).Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + + // 计算分享的 HashID + hashIDs := make(map[uint]string, len(res)) + for _, report := range res { + hashIDs[report.Share.ID] = hashid.HashID(report.Share.ID, hashid.ShareID) + } + + // 查询对应用户 + users := make(map[uint]model.User) + for _, report := range res { + users[report.Share.UserID] = model.User{} + } + + userIDs := make([]uint, 0, len(users)) + for k := range users { + userIDs = append(userIDs, k) + } + + var userList []model.User + model.DB.Where("id in (?)", userIDs).Find(&userList) + + for _, v := range userList { + users[v.ID] = v + } + + return serializer.Response{Data: map[string]interface{}{ + "total": total, + "items": res, + "users": users, + "ids": hashIDs, + }} +} diff --git a/service/admin/site.go b/service/admin/site.go index 69aa2d8..f4e83ad 100644 --- a/service/admin/site.go +++ b/service/admin/site.go @@ -10,6 +10,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/thumb" + "github.com/cloudreve/Cloudreve/v3/pkg/vol" "github.com/gin-gonic/gin" ) @@ -89,7 +90,7 @@ func (service *NoParamService) Summary() serializer.Response { "backend": conf.BackendVersion, "db": conf.RequiredDBVersion, "commit": conf.LastCommit, - "is_pro": conf.IsPro, + "is_plus": conf.IsPlus, } if res, ok := cache.Get("admin_summary"); ok { @@ -163,3 +164,36 @@ func (s *ThumbGeneratorTestService) Test(c *gin.Context) serializer.Response { Data: version, } } + +// VOL 授权管理服务 +type VolService struct { +} + +// Sync 同步 VOL 授权 +func (s *VolService) Sync() serializer.Response { + volClient := vol.New(vol.ClientSecret) + content, signature, err := volClient.Sync() + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, err.Error(), err) + } + + subService := &BatchSettingChangeService{ + Options: []SettingChangeService{ + { + Key: "vol_content", + Value: content, + }, + { + Key: "vol_signature", + Value: signature, + }, + }, + } + + res := subService.Change() + if res.Code != 0 { + return res + } + + return serializer.Response{Data: content} +} diff --git a/service/admin/user.go b/service/admin/user.go index eb76ac9..984a799 100644 --- a/service/admin/user.go +++ b/service/admin/user.go @@ -72,6 +72,12 @@ func (service *UserBatchService) Delete() serializer.Response { model.DB.Where("user_id = ?", uid).Delete(&model.Download{}) model.DB.Where("user_id = ?", uid).Delete(&model.Task{}) + // 删除订单记录 + model.DB.Where("user_id = ?", uid).Delete(&model.Order{}) + + // 删除容量包 + model.DB.Where("user_id = ?", uid).Delete(&model.StoragePack{}) + // 删除标签 model.DB.Where("user_id = ?", uid).Delete(&model.Tag{}) @@ -109,6 +115,7 @@ func (service *AddUserService) Add() serializer.Response { user.Email = service.User.Email user.GroupID = service.User.GroupID user.Status = service.User.Status + user.Score = service.User.Score user.TwoFactor = service.User.TwoFactor // 检查愚蠢操作 diff --git a/service/admin/vas.go b/service/admin/vas.go new file mode 100755 index 0000000..20e3a12 --- /dev/null +++ b/service/admin/vas.go @@ -0,0 +1,98 @@ +package admin + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gofrs/uuid" +) + +// GenerateRedeemsService 兑换码生成服务 +type GenerateRedeemsService struct { + Num int `json:"num" binding:"required,min=1,max=100"` + ID int64 `json:"id"` + Time int `json:"time" binding:"required,min=1"` + Type int `json:"type" binding:"min=0,max=2"` +} + +// SingleIDService 单ID服务 +type SingleIDService struct { + ID uint `uri:"id" binding:"required"` +} + +// DeleteRedeem 删除兑换码 +func (service *SingleIDService) DeleteRedeem() serializer.Response { + if err := model.DB.Where("id = ?", service.ID).Delete(&model.Redeem{}).Error; err != nil { + return serializer.DBErr("Failed to delete gift code record.", err) + } + + return serializer.Response{} +} + +// Generate 生成兑换码 +func (service *GenerateRedeemsService) Generate() serializer.Response { + res := make([]string, service.Num) + redeem := model.Redeem{} + + // 开始事务 + tx := model.DB.Begin() + if err := tx.Error; err != nil { + return serializer.DBErr("Cannot start transaction", err) + } + + // 创建每个兑换码 + for i := 0; i < service.Num; i++ { + redeem.Model.ID = 0 + redeem.Num = service.Time + redeem.Type = service.Type + redeem.ProductID = service.ID + redeem.Used = false + + // 生成唯一兑换码 + u2, err := uuid.NewV4() + if err != nil { + tx.Rollback() + return serializer.Err(serializer.CodeInternalSetting, "Failed to generate UUID", err) + } + + redeem.Code = u2.String() + if err := tx.Create(&redeem).Error; err != nil { + tx.Rollback() + return serializer.DBErr("Failed to insert gift code record", err) + } + + res[i] = redeem.Code + } + + if err := tx.Commit().Error; err != nil { + return serializer.DBErr("Failed to insert gift code record", err) + } + + return serializer.Response{Data: res} + +} + +// Redeems 列出激活码 +func (service *AdminListService) Redeems() serializer.Response { + var res []model.Redeem + total := 0 + + tx := model.DB.Model(&model.Redeem{}) + if service.OrderBy != "" { + tx = tx.Order(service.OrderBy) + } + + for k, v := range service.Conditions { + tx = tx.Where("? = ?", k, v) + } + + // 计算总数用于分页 + tx.Count(&total) + + // 查询记录 + tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) + + return serializer.Response{Data: map[string]interface{}{ + "total": total, + "items": res, + }} +} diff --git a/service/aria2/add.go b/service/aria2/add.go index 816c57b..0a47dd0 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -60,8 +60,9 @@ func (service *BatchAddURLService) Add(c *gin.Context, taskType int) serializer. // AddURLService 添加URL离线下载服务 type AddURLService struct { - URL string `json:"url" binding:"required"` - Dst string `json:"dst" binding:"required,min=1"` + URL string `json:"url" binding:"required"` + Dst string `json:"dst" binding:"required,min=1"` + PreferredNode uint `json:"preferred_node"` } // Add 主机创建新的链接离线下载任务 @@ -92,6 +93,10 @@ func (service *AddURLService) Add(c *gin.Context, fs *filesystem.FileSystem, tas return serializer.Err(serializer.CodeBatchAria2Size, "", nil) } + if service.PreferredNode > 0 && !fs.User.Group.OptionsSerialized.SelectNode { + return serializer.Err(serializer.CodeGroupNotAllowed, "not allowed to select nodes", nil) + } + // 创建任务 task := &model.Download{ Status: common.Ready, @@ -105,7 +110,8 @@ func (service *AddURLService) Add(c *gin.Context, fs *filesystem.FileSystem, tas lb := aria2.GetLoadBalancer() // 获取 Aria2 实例 - err, node := cluster.Default.BalanceNodeByFeature("aria2", lb) + err, node := cluster.Default.BalanceNodeByFeature("aria2", lb, fs.User.Group.OptionsSerialized.AvailableNodes, + service.PreferredNode) if err != nil { return serializer.Err(serializer.CodeInternalSetting, "Failed to get Aria2 instance", err) } diff --git a/service/callback/upload.go b/service/callback/upload.go index 0dd7924..c91979b 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -239,7 +239,7 @@ func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response { return ProcessCallback(service, c) } -// PreProcess 对从机客户端回调进行预处理验证 +// PreProcess 对OneDrive客户端回调进行预处理验证 func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response { // 创建文件系统 fs, err := filesystem.NewFileSystemFromCallback(c) diff --git a/service/explorer/directory.go b/service/explorer/directory.go index cd03999..0116fdb 100644 --- a/service/explorer/directory.go +++ b/service/explorer/directory.go @@ -37,6 +37,11 @@ func (service *DirectoryService) ListDirectory(c *gin.Context) serializer.Respon parentID = fs.DirTarget[0].ID } + // 获取目录的存储策略 + if err := fs.SetPolicyFromPath(service.Path); err != nil { + return serializer.Err(serializer.CodePolicyNotExist, "", err) + } + return serializer.Response{ Code: 0, Data: serializer.BuildObjectList(parentID, objects, fs.Policy), diff --git a/service/explorer/file.go b/service/explorer/file.go index 1c9d870..44aa202 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/util" "io/ioutil" "net/http" "net/url" @@ -18,6 +17,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/wopi" "github.com/gin-gonic/gin" ) @@ -56,6 +56,9 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response { } defer fs.Recycle() + baseDir := path.Dir(service.Path) + fs.SetPolicyFromPath(baseDir) + // 上下文 ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -68,7 +71,7 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response { err = fs.Upload(ctx, &fsctx.FileStream{ File: ioutil.NopCloser(strings.NewReader("")), Size: 0, - VirtualPath: path.Dir(service.Path), + VirtualPath: baseDir, Name: path.Base(service.Path), }) if err != nil { @@ -197,7 +200,7 @@ func (service *FileIDService) CreateDocPreviewSession(ctx context.Context, c *gi // 创建文件系统 fs, err := filesystem.NewFileSystemFromContext(c) if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + return serializer.Err(serializer.CodeCreateFSError, "", err) } defer fs.Recycle() diff --git a/service/explorer/objects.go b/service/explorer/objects.go index 1c3c45a..fb4db07 100644 --- a/service/explorer/objects.go +++ b/service/explorer/objects.go @@ -55,6 +55,12 @@ type ItemCompressService struct { Name string `json:"name" binding:"required,min=1,max=255"` } +// ItemRelocateService 文件转移任务服务 +type ItemRelocateService struct { + Src ItemIDService `json:"src"` + DstPolicyID string `json:"dst_policy_id" binding:"required"` +} + // ItemDecompressService 文件解压缩任务服务 type ItemDecompressService struct { Src string `json:"src"` @@ -234,6 +240,57 @@ func (service *ItemCompressService) CreateCompressTask(c *gin.Context) serialize } +// CreateRelocateTask 创建文件转移任务 +func (service *ItemRelocateService) CreateRelocateTask(c *gin.Context) serializer.Response { + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) + + // 取得存储策略的ID + rawID, err := hashid.DecodeHashID(service.DstPolicyID, hashid.PolicyID) + if err != nil { + return serializer.Err(serializer.CodePolicyNotExist, "", err) + } + + // 检查用户组权限 + if !user.Group.OptionsSerialized.Relocate { + return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) + } + + // 用户是否可以使用目的存储策略 + if !util.ContainsUint(user.Group.PolicyList, rawID) { + return serializer.ParamErr("Storage policy is not available", nil) + } + + // 查找存储策略 + if _, err := model.GetPolicyByID(rawID); err != nil { + return serializer.ParamErr("Storage policy is not available", nil) + } + + // 查找是否有正在进行中的转存任务 + var tasks []model.Task + model.DB.Where("status in (?) and user_id = ? and type = ?", + []int{task.Queued, task.Processing}, user.ID, + task.RelocateTaskType).Find(&tasks) + if len(tasks) > 0 { + return serializer.Response{ + Code: serializer.CodeConflict, + Msg: "There's ongoing relocate task, please wait for the previous task to finish", + } + } + + IDRaw := service.Src.Raw() + + // 创建任务 + job, err := task.NewRelocateTask(user, rawID, IDRaw.Dirs, + IDRaw.Items) + if err != nil { + return serializer.Err(serializer.CodeCreateTaskError, "", err) + } + task.TaskPoll.Submit(job) + + return serializer.Response{} +} + // Archive 创建归档 func (service *ItemIDService) Archive(ctx context.Context, c *gin.Context) serializer.Response { // 创建文件系统 @@ -270,7 +327,7 @@ func (service *ItemIDService) Delete(ctx context.Context, c *gin.Context) serial // 创建文件系统 fs, err := filesystem.NewFileSystemFromContext(c) if err != nil { - return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + return serializer.Err(serializer.CodeCreateFSError, "", err) } defer fs.Recycle() @@ -351,7 +408,7 @@ func (service *ItemRenameService) Rename(ctx context.Context, c *gin.Context) se // 创建文件系统 fs, err := filesystem.NewFileSystemFromContext(c) if err != nil { - return serializer.Err(serializer.CodeCreateFSError, "", err) + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) } defer fs.Recycle() @@ -415,6 +472,10 @@ func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Cont return serializer.DBErr("Failed to query folder records", err) } + policy := user.GetPolicyID(&folder[0]) + if folder[0].PolicyID > 0 { + props.Policy = policy.Name + } props.CreatedAt = folder[0].CreatedAt props.UpdatedAt = folder[0].UpdatedAt @@ -423,6 +484,7 @@ func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Cont res := cacheRes.(serializer.ObjectProps) res.CreatedAt = props.CreatedAt res.UpdatedAt = props.UpdatedAt + res.Policy = props.Policy return serializer.Response{Data: res} } diff --git a/service/explorer/upload.go b/service/explorer/upload.go index 0c26c26..470691a 100644 --- a/service/explorer/upload.go +++ b/service/explorer/upload.go @@ -3,6 +3,11 @@ package explorer import ( "context" "fmt" + "io/ioutil" + "strconv" + "strings" + "time" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" @@ -13,10 +18,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" - "io/ioutil" - "strconv" - "strings" - "time" ) // CreateUploadSessionService 获取上传凭证服务 @@ -43,8 +44,13 @@ func (service *CreateUploadSessionService) Create(ctx context.Context, c *gin.Co return serializer.Err(serializer.CodePolicyNotExist, "", err) } + // 分配并检查存储策略 + if err := fs.SetPolicyFromPreference(rawID); err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, "", err) + } + if fs.Policy.ID != rawID { - return serializer.Err(serializer.CodePolicyNotAllowed, "存储策略发生变化,请刷新文件列表并重新添加此任务", nil) + return serializer.Err(serializer.CodePolicyChanged, "", nil) } file := &fsctx.FileStream{ diff --git a/service/node/fabric.go b/service/node/fabric.go index deb2184..d35a0a9 100644 --- a/service/node/fabric.go +++ b/service/node/fabric.go @@ -2,6 +2,7 @@ package node import ( "encoding/gob" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" diff --git a/service/setting/webdav.go b/service/setting/webdav.go index 1f81751..78bdc4a 100644 --- a/service/setting/webdav.go +++ b/service/setting/webdav.go @@ -2,6 +2,8 @@ package setting import ( model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" @@ -29,12 +31,71 @@ type WebDAVAccountUpdateService struct { UseProxy *bool `json:"use_proxy" binding:"required_without=Readonly"` } +// WebDAVAccountUpdateReadonlyService WebDAV 修改只读性服务 +type WebDAVAccountUpdateReadonlyService struct { + ID uint `json:"id" binding:"required,min=1"` + Readonly bool `json:"readonly"` +} + // WebDAVMountCreateService WebDAV 挂载创建服务 type WebDAVMountCreateService struct { Path string `json:"path" binding:"required,min=1,max=65535"` Policy string `json:"policy" binding:"required,min=1"` } +// Create 创建目录挂载 +func (service *WebDAVMountCreateService) Create(c *gin.Context, user *model.User) serializer.Response { + // 创建文件系统 + fs, err := filesystem.NewFileSystem(user) + if err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) + } + defer fs.Recycle() + + // 检索要挂载的目录 + exist, folder := fs.IsPathExist(service.Path) + if !exist { + return serializer.Err(serializer.CodeParentNotExist, "", err) + } + + // 检索要挂载的存储策略 + policyID, err := hashid.DecodeHashID(service.Policy, hashid.PolicyID) + if err != nil { + return serializer.Err(serializer.CodePolicyNotExist, "", err) + } + + // 检查存储策略是否可用 + if policy, err := model.GetPolicyByID(policyID); err != nil || !util.ContainsUint(user.Group.PolicyList, policy.ID) { + return serializer.Err(serializer.CodePolicyNotAllowed, "", err) + } + + // 挂载 + if err := folder.Mount(policyID); err != nil { + return serializer.Err(serializer.CodeDBError, "Failed to update folder record", err) + } + + return serializer.Response{ + Data: map[string]interface{}{ + "id": hashid.HashID(folder.ID, hashid.FolderID), + }, + } +} + +// Unmount 取消目录挂载 +func (service *WebDAVListService) Unmount(c *gin.Context, user *model.User) serializer.Response { + folderID, _ := c.Get("object_id") + folder, err := model.GetFoldersByIDs([]uint{folderID.(uint)}, user.ID) + if err != nil || len(folder) == 0 { + return serializer.Err(serializer.CodeParentNotExist, "", err) + } + + if err := folder[0].Mount(0); err != nil { + return serializer.DBErr("Failed to update folder record", err) + } + + return serializer.Response{} +} + // Create 创建WebDAV账户 func (service *WebDAVAccountCreateService) Create(c *gin.Context, user *model.User) serializer.Response { account := model.Webdav{ @@ -45,7 +106,7 @@ func (service *WebDAVAccountCreateService) Create(c *gin.Context, user *model.Us } if _, err := account.Create(); err != nil { - return serializer.Err(serializer.CodeDBError, "创建失败", err) + return serializer.DBErr("Failed to create account record", err) } return serializer.Response{ @@ -76,11 +137,23 @@ func (service *WebDAVAccountUpdateService) Update(c *gin.Context, user *model.Us return serializer.Response{Data: updates} } +// Update 修改WebDAV账户的只读性 +func (service *WebDAVAccountUpdateReadonlyService) Update(c *gin.Context, user *model.User) serializer.Response { + model.UpdateWebDAVAccountReadonlyByID(service.ID, user.ID, service.Readonly) + return serializer.Response{Data: map[string]bool{ + "readonly": service.Readonly, + }} +} + // Accounts 列出WebDAV账号 func (service *WebDAVListService) Accounts(c *gin.Context, user *model.User) serializer.Response { accounts := model.ListWebDAVAccounts(user.ID) + // 查找挂载了存储策略的目录 + folders := model.GetMountedFolders(user.ID) + return serializer.Response{Data: map[string]interface{}{ "accounts": accounts, + "folders": serializer.BuildMountedFolderRes(folders, user.Group.PolicyList), }} } diff --git a/service/share/manage.go b/service/share/manage.go index 9daccdb..af13695 100644 --- a/service/share/manage.go +++ b/service/share/manage.go @@ -17,6 +17,7 @@ type ShareCreateService struct { Password string `json:"password" binding:"max=255"` RemainDownloads int `json:"downloads"` Expire int `json:"expire"` + Score int `json:"score" binding:"gte=0"` Preview bool `json:"preview"` } @@ -117,6 +118,7 @@ func (service *ShareCreateService) Create(c *gin.Context) serializer.Response { IsDir: service.IsDir, UserID: user.ID, SourceID: sourceID, + Score: service.Score, RemainDownloads: -1, PreviewEnabled: service.Preview, SourceName: sourceName, diff --git a/service/share/visit.go b/service/share/visit.go index caad060..aaadc9f 100644 --- a/service/share/visit.go +++ b/service/share/visit.go @@ -48,6 +48,28 @@ type ShareListService struct { Keywords string `form:"keywords"` } +// ShareReportService 举报分享 +type ShareReportService struct { + Reason int `json:"reason" binding:"gte=0,lte=4"` + Des string `json:"des"` +} + +// Get 获取给定用户的分享 +func (service *ShareReportService) Report(c *gin.Context) serializer.Response { + // 取得分享ID + shareID, _ := c.Get("share") + + report := &model.Report{ + ShareID: shareID.(*model.Share).ID, + Reason: service.Reason, + Description: service.Des, + } + if err := report.Create(); err != nil { + return serializer.DBErr("Failed to create report record", err) + } + return serializer.Response{} +} + // Get 获取给定用户的分享 func (service *ShareUserGetService) Get(c *gin.Context) serializer.Response { // 取得用户 @@ -118,6 +140,8 @@ func (service *ShareListService) List(c *gin.Context, user *model.User) serializ func (service *ShareGetService) Get(c *gin.Context) serializer.Response { shareCtx, _ := c.Get("share") share := shareCtx.(*model.Share) + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) // 是否已解锁 unlocked := true @@ -137,6 +161,11 @@ func (service *ShareGetService) Get(c *gin.Context) serializer.Response { share.Viewed() } + // 如果已经下载过或者是自己的分享,不需要付积分 + if share.UserID == user.ID || share.WasDownloadedBy(user, c) { + share.Score = 0 + } + return serializer.Response{ Code: 0, Data: serializer.BuildShareResponse(share, unlocked), @@ -224,6 +253,39 @@ func (service *Service) CreateDocPreviewSession(c *gin.Context) serializer.Respo return subService.CreateDocPreviewSession(ctx, c, false) } +// SaveToMyFile 将此分享转存到自己的网盘 +func (service *Service) SaveToMyFile(c *gin.Context) serializer.Response { + shareCtx, _ := c.Get("share") + share := shareCtx.(*model.Share) + userCtx, _ := c.Get("user") + user := userCtx.(*model.User) + + // 不能转存自己的文件 + if share.UserID == user.ID { + return serializer.Err(serializer.CodeSaveOwnShare, "", nil) + } + + // 创建文件系统 + fs, err := filesystem.NewFileSystem(user) + if err != nil { + return serializer.Err(serializer.CodeCreateFSError, "", err) + } + defer fs.Recycle() + + // 重设文件系统处理目标为源文件 + err = fs.SetTargetByInterface(share.Source()) + if err != nil { + return serializer.Err(serializer.CodeFileNotFound, "", err) + } + + err = fs.SaveTo(context.Background(), service.Path) + if err != nil { + return serializer.Err(serializer.CodeNotSet, err.Error(), err) + } + + return serializer.Response{} +} + // List 列出分享的目录下的对象 func (service *Service) List(c *gin.Context) serializer.Response { shareCtx, _ := c.Get("share") @@ -378,11 +440,11 @@ func (service *SearchService) Search(c *gin.Context) serializer.Response { share := shareCtx.(*model.Share) if !share.IsDir { - return serializer.ParamErr("此分享无法列目录", nil) + return serializer.ParamErr("This is not a shared folder", nil) } if service.Path != "" && !path.IsAbs(service.Path) { - return serializer.ParamErr("路径无效", nil) + return serializer.ParamErr("Invalid path", nil) } // 创建文件系统 @@ -402,7 +464,7 @@ func (service *SearchService) Search(c *gin.Context) serializer.Response { if service.Path != "" { ok, parent := fs.IsPathExist(service.Path) if !ok { - return serializer.Err(serializer.CodeParentNotExist, "Cannot find parent folder", nil) + return serializer.Err(serializer.CodeParentNotExist, "", nil) } fs.Root = parent diff --git a/service/user/login.go b/service/user/login.go index 22649dd..a22ce23 100644 --- a/service/user/login.go +++ b/service/user/login.go @@ -2,17 +2,19 @@ package user import ( "fmt" - model "github.com/cloudreve/Cloudreve/v3/models" + "net/url" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/gofrs/uuid" + + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" - "github.com/gofrs/uuid" "github.com/pquerna/otp/totp" - "net/url" ) // UserLoginService 管理用户登录的服务 diff --git a/service/user/register.go b/service/user/register.go index 35e8253..6dccb7b 100644 --- a/service/user/register.go +++ b/service/user/register.go @@ -9,6 +9,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/email" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" ) @@ -22,7 +23,22 @@ type UserRegisterService struct { // Register 新用户注册 func (service *UserRegisterService) Register(c *gin.Context) serializer.Response { // 相关设定 - options := model.GetSettingByNames("email_active") + options := model.GetSettingByNames("email_active", "reg_captcha", "mail_domain_filter", "mail_domain_filter_list") + + // 检查是否在邮件域黑名单里 + if options["mail_domain_filter"] != "0" { + filterList := strings.Split(options["mail_domain_filter_list"], ",") + emailSplit := strings.Split(service.UserName, "@") + emailDomain := emailSplit[len(emailSplit)-1] + inList := util.ContainsString(filterList, emailDomain) + domainErr := serializer.Err(serializer.CodeEmailProviderBaned, "Email provider banned", nil) + if options["mail_domain_filter"] == "1" && !inList { + return domainErr + } + if options["mail_domain_filter"] == "2" && inList { + return domainErr + } + } // 相关设定 isEmailRequired := model.IsTrueVal(options["email_active"]) diff --git a/service/user/setting.go b/service/user/setting.go index 8d7f619..1135697 100644 --- a/service/user/setting.go +++ b/service/user/setting.go @@ -8,12 +8,17 @@ import ( "os" "path/filepath" "strings" + "time" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/hashid" + "github.com/cloudreve/Cloudreve/v3/pkg/qq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" + "github.com/samber/lo" ) // SettingService 通用设置服务 @@ -45,6 +50,14 @@ type ChangerNick struct { Nick string `json:"nick" binding:"required,min=1,max=255"` } +// VIPUnsubscribe 用户组解约服务 +type VIPUnsubscribe struct { +} + +// QQBind QQ互联服务 +type QQBind struct { +} + // PolicyChange 更改存储策略 type PolicyChange struct { ID string `json:"id" binding:"required"` @@ -163,6 +176,77 @@ func (service *HomePage) Update(c *gin.Context, user *model.User) serializer.Res return serializer.Response{} } +// Update 更改用户偏好的存储策略 +func (service *PolicyChange) Update(c *gin.Context, user *model.User) serializer.Response { + // 取得存储策略的ID + rawID, err := hashid.DecodeHashID(service.ID, hashid.PolicyID) + if err != nil { + return serializer.Err(serializer.CodePolicyNotExist, "", err) + } + + // 用户是否可以切换到此存储策略 + if !util.ContainsUint(user.Group.PolicyList, rawID) { + return serializer.Err(serializer.CodePolicyNotAllowed, "", nil) + } + + // 查找存储策略 + if _, err := model.GetPolicyByID(rawID); err != nil { + return serializer.Err(serializer.CodePolicyNotAllowed, "", nil) + } + + // 切换存储策略 + user.OptionsSerialized.PreferredPolicy = rawID + if err := user.UpdateOptions(); err != nil { + return serializer.DBErr("Failed to update user preferences", err) + } + + return serializer.Response{} +} + +// Update 绑定或解绑QQ +func (service *QQBind) Update(c *gin.Context, user *model.User) serializer.Response { + // 解除绑定 + if user.OpenID != "" { + // 只通过QQ登录的用户无法解除绑定 + if strings.HasSuffix(user.Email, "@login.qq.com") { + return serializer.Err(serializer.CodeNoPermissionErr, "This user cannot be unlinked", nil) + } + + if err := user.Update(map[string]interface{}{"open_id": ""}); err != nil { + return serializer.DBErr("Failed to update user open id", err) + } + return serializer.Response{ + Data: "", + } + } + + // 新建绑定 + res, err := qq.NewLoginRequest() + if err != nil { + return serializer.Err(serializer.CodeNotSet, "Failed to start QQ login request", err) + } + + // 设定QQ登录会话Secret + util.SetSession(c, map[string]interface{}{"qq_login_secret": res.SecretKey}) + + return serializer.Response{ + Data: res.URL, + } +} + +// Update 用户组解约 +func (service *VIPUnsubscribe) Update(c *gin.Context, user *model.User) serializer.Response { + if user.GroupExpires != nil { + timeNow := time.Now() + if time.Now().Before(*user.GroupExpires) { + if err := user.Update(map[string]interface{}{"group_expires": &timeNow}); err != nil { + return serializer.DBErr("Failed to update user", err) + } + } + } + return serializer.Response{} +} + // Update 更改昵称 func (service *ChangerNick) Update(c *gin.Context, user *model.User) serializer.Response { if err := user.Update(map[string]interface{}{"nick": service.Nick}); err != nil { @@ -241,16 +325,72 @@ func (service *SettingListService) ListTasks(c *gin.Context, user *model.User) s return serializer.BuildTaskList(tasks, total) } +// Policy 获取用户存储策略设置 +func (service *SettingService) Policy(c *gin.Context, user *model.User) serializer.Response { + // 取得用户可用存储策略 + available := make([]model.Policy, 0, len(user.Group.PolicyList)) + for _, id := range user.Group.PolicyList { + if policy, err := model.GetPolicyByID(id); err == nil { + available = append(available, policy) + } + } + + return serializer.BuildPolicySettingRes(available) +} + +// Nodes 获取用户可选节点 +func (service *SettingService) Nodes(c *gin.Context, user *model.User) serializer.Response { + if !user.Group.OptionsSerialized.SelectNode { + return serializer.Err(serializer.CodeGroupNotAllowed, "", nil) + } + + availableNodesID := user.Group.OptionsSerialized.AvailableNodes + + // All nodes available + if len(availableNodesID) == 0 { + nodes, err := model.GetNodesByStatus(model.NodeActive) + if err != nil { + return serializer.DBErr("Failed to list nodes", err) + } + + availableNodesID = lo.Map[model.Node, uint](nodes, func(node model.Node, index int) uint { + return node.ID + }) + } + + // 取得用户可用存储策略 + available := lo.FilterMap[uint, *model.Node](availableNodesID, + func(id uint, index int) (*model.Node, bool) { + if node := cluster.Default.GetNodeByID(id); node != nil { + return node.DBModel(), node.IsActive() && node.IsFeatureEnabled("aria2") + } + + return nil, false + }) + + return serializer.BuildNodeOptionRes(available) +} + // Settings 获取用户设定 func (service *SettingService) Settings(c *gin.Context, user *model.User) serializer.Response { + // 用户组有效期 + var groupExpires *time.Time + if user.GroupExpires != nil { + if expires := user.GroupExpires.Unix() - time.Now().Unix(); expires > 0 { + groupExpires = user.GroupExpires + } + } + return serializer.Response{ Data: map[string]interface{}{ - "uid": user.ID, - "homepage": !user.OptionsSerialized.ProfileOff, - "two_factor": user.TwoFactor != "", - "prefer_theme": user.OptionsSerialized.PreferredTheme, - "themes": model.GetSettingByName("themes"), - "authn": serializer.BuildWebAuthnList(user.WebAuthnCredentials()), + "uid": user.ID, + "qq": user.OpenID != "", + "homepage": !user.OptionsSerialized.ProfileOff, + "two_factor": user.TwoFactor != "", + "prefer_theme": user.OptionsSerialized.PreferredTheme, + "themes": model.GetSettingByName("themes"), + "group_expires": groupExpires, + "authn": serializer.BuildWebAuthnList(user.WebAuthnCredentials()), }, } } diff --git a/service/vas/purchase.go b/service/vas/purchase.go new file mode 100755 index 0000000..3755b13 --- /dev/null +++ b/service/vas/purchase.go @@ -0,0 +1,235 @@ +package vas + +import ( + "encoding/json" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/payment" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gin-gonic/gin" +) + +// CreateOrderService 创建订单服务 +type CreateOrderService struct { + Action string `json:"action" binding:"required,eq=group|eq=pack|eq=score"` + Method string `json:"method" binding:"required,eq=alipay|eq=score|eq=payjs|eq=wechat|eq=custom"` + ID int64 `json:"id" binding:"required"` + Num int `json:"num" binding:"required,min=1"` +} + +// RedeemService 兑换服务 +type RedeemService struct { + Code string `uri:"code" binding:"required,max=64"` +} + +// OrderService 订单查询 +type OrderService struct { + ID string `uri:"id" binding:"required"` +} + +// Status 查询订单状态 +func (service *OrderService) Status(c *gin.Context, user *model.User) serializer.Response { + order, _ := model.GetOrderByNo(service.ID) + if order == nil || order.UserID != user.ID { + return serializer.Err(serializer.CodeNotFound, "", nil) + } + + return serializer.Response{Data: order.Status} +} + +// Redeem 开始兑换 +func (service *RedeemService) Redeem(c *gin.Context, user *model.User) serializer.Response { + redeem, err := model.GetAvailableRedeem(service.Code) + if err != nil { + return serializer.Err(serializer.CodeInvalidGiftCode, "", err) + } + + // 取得当前商品信息 + packs, groups, err := decodeProductInfo() + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product settings", err) + } + + // 查找要购买的商品 + var ( + pack *serializer.PackProduct + group *serializer.GroupProducts + ) + if redeem.Type == model.GroupOrderType { + for _, v := range groups { + if v.ID == redeem.ProductID { + group = &v + break + } + } + + if group == nil { + return serializer.Err(serializer.CodeNotFound, "", err) + } + + } else if redeem.Type == model.PackOrderType { + for _, v := range packs { + if v.ID == redeem.ProductID { + pack = &v + break + } + } + + if pack == nil { + return serializer.Err(serializer.CodeNotFound, "", err) + } + + } + + err = payment.GiveProduct(user, pack, group, redeem.Num) + if err != nil { + return serializer.Err(serializer.CodeNotSet, "Redeem failed", err) + } + + redeem.Use() + + return serializer.Response{} + +} + +// Query 检查兑换码信息 +func (service *RedeemService) Query(c *gin.Context) serializer.Response { + redeem, err := model.GetAvailableRedeem(service.Code) + if err != nil { + return serializer.Err(serializer.CodeInvalidGiftCode, "", err) + } + + var ( + name = "积分" + productTime int64 + ) + if redeem.Type != model.ScoreOrderType { + packs, groups, err := decodeProductInfo() + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product settings", err) + } + if redeem.Type == model.GroupOrderType { + for _, v := range groups { + if v.ID == redeem.ProductID { + name = v.Name + productTime = v.Time + break + } + } + } else { + for _, v := range packs { + if v.ID == redeem.ProductID { + name = v.Name + productTime = v.Time + break + } + } + } + + if name == "积分" { + return serializer.Err(serializer.CodeNotFound, "", err) + } + + } + + return serializer.Response{ + Data: struct { + Name string `json:"name"` + Type int `json:"type"` + Num int `json:"num"` + Time int64 `json:"time"` + }{ + name, redeem.Type, redeem.Num, productTime, + }, + } +} + +// Create 创建新订单 +func (service *CreateOrderService) Create(c *gin.Context, user *model.User) serializer.Response { + // 取得当前商品信息 + packs, groups, err := decodeProductInfo() + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product list", err) + } + + // 查找要购买的商品 + var ( + pack *serializer.PackProduct + group *serializer.GroupProducts + ) + if service.Action == "group" { + for _, v := range groups { + if v.ID == service.ID { + group = &v + break + } + } + } else if service.Action == "pack" { + for _, v := range packs { + if v.ID == service.ID { + pack = &v + break + } + } + } + + // 购买积分 + if pack == nil && group == nil { + if service.Method == "score" { + return serializer.ParamErr("Payment method not supported", nil) + } + } + + // 创建订单 + res, err := payment.NewOrder(pack, group, service.Num, service.Method, user) + if err != nil { + return serializer.Err(serializer.CodeNotSet, err.Error(), err) + } + + return serializer.Response{Data: res} + +} + +// Products 获取商品信息 +func (service *GeneralVASService) Products(c *gin.Context, user *model.User) serializer.Response { + options := model.GetSettingByNames( + "wechat_enabled", + "alipay_enabled", + "payjs_enabled", + "payjs_enabled", + "custom_payment_enabled", + "custom_payment_name", + ) + scorePrice := model.GetIntSetting("score_price", 0) + packs, groups, err := decodeProductInfo() + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product list", err) + } + + return serializer.BuildProductResponse( + groups, + packs, + model.IsTrueVal(options["wechat_enabled"]), + model.IsTrueVal(options["alipay_enabled"]), + model.IsTrueVal(options["payjs_enabled"]), + model.IsTrueVal(options["custom_payment_enabled"]), + options["custom_payment_name"], + scorePrice, + ) +} + +func decodeProductInfo() ([]serializer.PackProduct, []serializer.GroupProducts, error) { + options := model.GetSettingByNames("pack_data", "group_sell_data", "alipay_enabled", "payjs_enabled") + + var ( + packs []serializer.PackProduct + groups []serializer.GroupProducts + ) + if err := json.Unmarshal([]byte(options["pack_data"]), &packs); err != nil { + return nil, nil, err + } + if err := json.Unmarshal([]byte(options["group_sell_data"]), &groups); err != nil { + return nil, nil, err + } + + return packs, groups, nil +} diff --git a/service/vas/qq.go b/service/vas/qq.go new file mode 100755 index 0000000..8d8795e --- /dev/null +++ b/service/vas/qq.go @@ -0,0 +1,113 @@ +package vas + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/qq" + "github.com/cloudreve/Cloudreve/v3/pkg/request" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/thumb" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/gin-gonic/gin" +) + +// QQCallbackService QQ互联回调处理服务 +type QQCallbackService struct { + Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` +} + +// Callback 处理QQ互联回调 +func (service *QQCallbackService) Callback(c *gin.Context, user *model.User) serializer.Response { + + state := util.GetSession(c, "qq_login_secret") + if stateStr, ok := state.(string); !ok || stateStr != service.State { + return serializer.Err(serializer.CodeSignExpired, "", nil) + } + util.DeleteSession(c, "qq_login_secret") + + // 获取OpenID + credential, err := qq.Callback(service.Code) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to get session status", err) + } + + // 如果已登录,则绑定已有用户 + if user != nil { + + if user.OpenID != "" { + return serializer.Err(serializer.CodeQQBindConflict, "", nil) + } + + // OpenID 是否重复 + if _, err := model.GetActiveUserByOpenID(credential.OpenID); err == nil { + return serializer.Err(serializer.CodeQQBindOtherAccount, "", nil) + } + + if err := user.Update(map[string]interface{}{"open_id": credential.OpenID}); err != nil { + return serializer.DBErr("Failed to update user open id", err) + } + return serializer.Response{ + Data: "/setting", + } + + } + + // 未登录,尝试查找用户 + if expectedUser, err := model.GetActiveUserByOpenID(credential.OpenID); err == nil { + // 用户绑定了此QQ,设定为登录状态 + util.SetSession(c, map[string]interface{}{ + "user_id": expectedUser.ID, + }) + res := serializer.BuildUserResponse(expectedUser) + res.Code = 203 + return res + + } + + // 无匹配用户,创建新用户 + if !model.IsTrueVal(model.GetSettingByName("qq_direct_login")) { + return serializer.Err(serializer.CodeQQNotLinked, "", nil) + } + + // 获取用户信息 + userInfo, err := qq.GetUserInfo(credential) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch user info", err) + } + + // 生成邮箱地址 + fakeEmail := util.RandStringRunes(16) + "@login.qq.com" + + // 创建用户 + defaultGroup := model.GetIntSetting("default_group", 2) + + newUser := model.NewUser() + newUser.Email = fakeEmail + newUser.Nick = userInfo.Nick + newUser.SetPassword("") + newUser.Status = model.Active + newUser.GroupID = uint(defaultGroup) + newUser.OpenID = credential.OpenID + newUser.Avatar = "file" + + // 创建用户 + if err := model.DB.Create(&newUser).Error; err != nil { + return serializer.Err(serializer.CodeEmailExisted, "", err) + } + + // 下载头像 + r := request.NewClient() + rawAvatar := r.Request("GET", userInfo.Avatar, nil) + if avatar, err := thumb.NewThumbFromFile(rawAvatar.Response.Body, "avatar.jpg"); err == nil { + avatar.CreateAvatar(newUser.ID) + } + + // 登录 + util.SetSession(c, map[string]interface{}{"user_id": newUser.ID}) + + newUser, _ = model.GetActiveUserByID(newUser.ID) + + res := serializer.BuildUserResponse(newUser) + res.Code = 203 + return res +} diff --git a/service/vas/quota.go b/service/vas/quota.go new file mode 100755 index 0000000..23c1019 --- /dev/null +++ b/service/vas/quota.go @@ -0,0 +1,17 @@ +package vas + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gin-gonic/gin" +) + +// GeneralVASService 通用增值服务 +type GeneralVASService struct { +} + +// Quota 获取容量配额信息 +func (service *GeneralVASService) Quota(c *gin.Context, user *model.User) serializer.Response { + packs := user.GetAvailableStoragePacks() + return serializer.BuildUserQuotaResponse(user, packs) +}