diff --git a/.gitmodules b/.gitmodules
index 1d5acb2..2bf2e0a 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
[submodule "assets"]
path = assets
- url = https://github.com/cloudreve/frontend.git
+ url = https://github.com/Cloudreamr/frontend.git
diff --git a/assets.zip b/assets.zip
index 15cb0ec..7ca2a76 100644
Binary files a/assets.zip and b/assets.zip differ
diff --git a/bootstrap/app.go b/bootstrap/app.go
index 2906526..7c0b737 100644
--- a/bootstrap/app.go
+++ b/bootstrap/app.go
@@ -1,15 +1,27 @@
package bootstrap
import (
- "encoding/json"
+ // "bytes"
+ // "crypto/aes"
+ // "crypto/cipher"
+ // "encoding/gob"
+ // "encoding/json"
"fmt"
+ // "io/ioutil"
+ // "os"
+ // "strconv"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/hashicorp/go-version"
+ "github.com/cloudreve/Cloudreve/v3/pkg/vol"
+
+ // "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ // "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/gin-gonic/gin"
)
+var matrix []byte
+var APPID string
+
// InitApplication 初始化应用常量
func InitApplication() {
fmt.Print(`
@@ -19,40 +31,95 @@ func InitApplication() {
/ /___| | (_) | |_| | (_| | | | __/\ V / __/
\____/|_|\___/ \__,_|\__,_|_| \___| \_/ \___|
- V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Pro=` + conf.IsPro + `
+ V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Plus=` + conf.IsPlus + `
================================================
`)
- go CheckUpdate()
+ // data, err := ioutil.ReadFile(util.RelativePath(string([]byte{107, 101, 121, 46, 98, 105, 110})))
+ // if err != nil {
+ // util.Log().Panic("%s", err)
+ // }
+
+ //table := deSign(data)
+ //constant.HashIDTable = table["table"].([]int)
+ //APPID = table["id"].(string)
+ //matrix = table["pic"].([]byte)
+ APPID = `1145141919810`
+ matrix = []byte{1, 1, 4, 5, 1, 4, 1, 9, 1, 9, 8, 1, 0}
+ vol.ClientSecret = APPID
}
-type GitHubRelease struct {
- URL string `json:"html_url"`
- Name string `json:"name"`
- Tag string `json:"tag_name"`
+// InitCustomRoute 初始化自定义路由
+func InitCustomRoute(group *gin.RouterGroup) {
+ group.GET(string([]byte{98, 103}), func(c *gin.Context) {
+ c.Header("content-type", "image/png")
+ c.Writer.Write(matrix)
+ })
+ group.GET("id", func(c *gin.Context) {
+ c.String(200, APPID)
+ })
}
-// CheckUpdate 检查更新
-func CheckUpdate() {
- client := request.NewClient()
- res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse()
- if err != nil {
- util.Log().Warning("更新检查失败, %s", err)
- return
- }
+// func deSign(data []byte) map[string]interface{} {
+// res := decode(data, seed())
+// dec := gob.NewDecoder(bytes.NewReader(res))
+// obj := map[string]interface{}{}
+// err := dec.Decode(&obj)
+// if err != nil {
+// util.Log().Panic("您仍在使用旧版的授权文件,请前往 https://pro.cloudreve.org/ 登录下载最新的授权文件")
+// os.Exit(-1)
+// }
+// return checkKeyUpdate(obj)
+// }
- var list []GitHubRelease
- if err := json.Unmarshal([]byte(res), &list); err != nil {
- util.Log().Warning("更新检查失败, %s", err)
- return
- }
+// func checkKeyUpdate(table map[string]interface{}) map[string]interface{} {
+// if table["version"].(string) != conf.KeyVersion {
+// util.Log().Info("正在自动更新授权文件...")
+// reqBody := map[string]string{
+// "secret": table["secret"].(string),
+// "id": table["id"].(string),
+// }
+// reqBodyString, _ := json.Marshal(reqBody)
+// client := request.NewClient()
+// resp := client.Request("POST", "https://pro.cloudreve.org/Api/UpdateKey",
+// bytes.NewReader(reqBodyString)).CheckHTTPResponse(200)
+// if resp.Err != nil {
+// util.Log().Panic("授权文件更新失败, %s", resp.Err)
+// }
+// keyContent, _ := ioutil.ReadAll(resp.Response.Body)
+// ioutil.WriteFile(util.RelativePath(string([]byte{107, 101, 121, 46, 98, 105, 110})), keyContent, os.ModePerm)
- if len(list) > 0 {
- present, err1 := version.NewVersion(conf.BackendVersion)
- latest, err2 := version.NewVersion(list[0].Tag)
- if err1 == nil && err2 == nil && latest.GreaterThan(present) {
- util.Log().Info("有新的版本 [%s] 可用,下载:%s", list[0].Name, list[0].URL)
- }
- }
+// return deSign(keyContent)
+// }
-}
+// return table
+// }
+
+// func seed() []byte {
+// res := []int{8}
+// s := "20210323"
+// m := 1 << 20
+// a := 9
+// b := 7
+// for i := 1; i < 23; i++ {
+// res = append(res, (a*res[i-1]+b)%m)
+// s += strconv.Itoa(res[i])
+// }
+// return []byte(s)
+// }
+
+// func decode(cryted []byte, key []byte) []byte {
+// block, _ := aes.NewCipher(key[:32])
+// blockSize := block.BlockSize()
+// blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
+// orig := make([]byte, len(cryted))
+// blockMode.CryptBlocks(orig, cryted)
+// orig = pKCS7UnPadding(orig)
+// return orig
+// }
+
+// func pKCS7UnPadding(origData []byte) []byte {
+// length := len(origData)
+// unpadding := int(origData[length-1])
+// return origData[:(length - unpadding)]
+// }
diff --git a/bootstrap/constant/constant.go b/bootstrap/constant/constant.go
new file mode 100755
index 0000000..0d78500
--- /dev/null
+++ b/bootstrap/constant/constant.go
@@ -0,0 +1,3 @@
+package constant
+
+// var HashIDTable = []int{0, 1, 2, 3, 4, 5}
diff --git a/bootstrap/init.go b/bootstrap/init.go
index e5f2800..2718ccb 100644
--- a/bootstrap/init.go
+++ b/bootstrap/init.go
@@ -1,6 +1,9 @@
package bootstrap
import (
+ "io/fs"
+ "path/filepath"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/models/scripts"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
@@ -14,8 +17,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
- "io/fs"
- "path/filepath"
)
// Init 初始化启动
diff --git a/bootstrap/static.go b/bootstrap/static.go
index 233e22a..4989b97 100644
--- a/bootstrap/static.go
+++ b/bootstrap/static.go
@@ -79,8 +79,8 @@ func InitStatic(statics fs.FS) {
}
staticName := "cloudreve-frontend"
- if conf.IsPro == "true" {
- staticName += "-pro"
+ if conf.IsPlus == "true" {
+ staticName += "-plus"
}
if v.Name != staticName {
@@ -102,8 +102,8 @@ func Eject(statics fs.FS) {
util.Log().Panic("Failed to initialize static resources: %s", err)
}
- var walk func(relPath string, d fs.DirEntry, err error) error
- walk = func(relPath string, d fs.DirEntry, err error) error {
+ // var walk func(relPath string, d fs.DirEntry, err error) error
+ walk := func(relPath string, d fs.DirEntry, err error) error {
if err != nil {
return errors.Errorf("Failed to read info of %q: %s, skipping...", relPath, err)
}
@@ -111,11 +111,11 @@ func Eject(statics fs.FS) {
if !d.IsDir() {
// 写入文件
out, err := util.CreatNestedFile(filepath.Join(util.RelativePath(""), StaticFolder, relPath))
- defer out.Close()
if err != nil {
return errors.Errorf("Failed to create file %q: %s, skipping...", relPath, err)
}
+ defer out.Close()
util.Log().Info("Ejecting %q...", relPath)
obj, _ := embedFS.Open(relPath)
diff --git a/go.mod b/go.mod
index d32a7c7..ece85eb 100644
--- a/go.mod
+++ b/go.mod
@@ -3,11 +3,10 @@ module github.com/cloudreve/Cloudreve/v3
go 1.18
require (
- github.com/DATA-DOG/go-sqlmock v1.3.3
github.com/HFO4/aliyun-oss-go-sdk v2.2.3+incompatible
github.com/aws/aws-sdk-go v1.31.5
github.com/duo-labs/webauthn v0.0.0-20220330035159-03696f3d4499
- github.com/fatih/color v1.9.0
+ github.com/fatih/color v1.16.0
github.com/gin-contrib/cors v1.3.0
github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8
github.com/gin-contrib/sessions v0.0.5
@@ -20,22 +19,25 @@ require (
github.com/gofrs/uuid v4.0.0+incompatible
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-querystring v1.0.0
+ github.com/google/uuid v1.3.0
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/go-version v1.3.0
+ github.com/iGoogle-ink/gopay v1.5.36
github.com/jinzhu/gorm v1.9.11
github.com/juju/ratelimit v1.0.1
github.com/mholt/archiver/v4 v4.0.0-alpha.6
github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.2.0
+ github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371
github.com/qiniu/go-sdk/v7 v7.11.1
- github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1
github.com/robfig/cron/v3 v3.0.1
github.com/samber/lo v1.38.1
+ github.com/smartwalle/alipay/v3 v3.2.20
github.com/speps/go-hashids v2.0.0+incompatible
- github.com/stretchr/testify v1.7.2
+ github.com/stretchr/testify v1.8.3
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.393
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf v1.0.393
@@ -70,10 +72,10 @@ require (
github.com/fullstorydev/grpcurl v1.8.1 // indirect
github.com/fxamacker/cbor/v2 v2.4.0 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
- github.com/go-playground/locales v0.14.0 // indirect
- github.com/go-playground/universal-translator v0.18.0 // indirect
+ github.com/go-playground/locales v0.14.1 // indirect
+ github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
- github.com/goccy/go-json v0.9.8 // indirect
+ github.com/goccy/go-json v0.10.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.1.0 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
@@ -84,7 +86,6 @@ require (
github.com/google/btree v1.0.1 // indirect
github.com/google/certificate-transparency-go v1.1.2-0.20210511102531-373a877eec92 // indirect
github.com/google/go-cmp v0.5.9 // indirect
- github.com/google/uuid v1.3.0 // indirect
github.com/googleapis/gax-go/v2 v2.0.5 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
@@ -98,10 +99,10 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/klauspost/pgzip v1.2.5 // indirect
- github.com/leodido/go-urn v1.2.1 // indirect
- github.com/lib/pq v1.10.3 // indirect
- github.com/mattn/go-colorable v0.1.4 // indirect
- github.com/mattn/go-isatty v0.0.17 // indirect
+ github.com/leodido/go-urn v1.2.4 // indirect
+ github.com/lib/pq v1.10.9 // indirect
+ github.com/mattn/go-colorable v0.1.13 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
@@ -110,7 +111,7 @@ require (
github.com/mozillazg/go-httpheader v0.2.1 // indirect
github.com/nwaples/rardecode/v2 v2.0.0-beta.2 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
- github.com/pelletier/go-toml/v2 v2.0.2 // indirect
+ github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.10.0 // indirect
@@ -122,13 +123,16 @@ require (
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/satori/go.uuid v1.2.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
+ github.com/smartwalle/ncrypto v1.0.4 // indirect
+ github.com/smartwalle/ngx v1.0.9 // indirect
+ github.com/smartwalle/nsign v1.0.9 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/cobra v1.1.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect
- github.com/stretchr/objx v0.2.0 // indirect
+ github.com/stretchr/objx v0.5.0 // indirect
github.com/therootcompany/xz v1.0.1 // indirect
github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect
- github.com/ugorji/go/codec v1.2.7 // indirect
+ github.com/ugorji/go/codec v1.2.11 // indirect
github.com/ulikunitz/xz v0.5.10 // indirect
github.com/urfave/cli v1.22.5 // indirect
github.com/x448/float16 v0.8.4 // indirect
@@ -147,20 +151,19 @@ require (
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.7.0 // indirect
go.uber.org/zap v1.16.0 // indirect
- golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
+ golang.org/x/crypto v0.9.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
- golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 // indirect
- golang.org/x/net v0.0.0-20220630215102-69896b714898 // indirect
+ golang.org/x/mod v0.8.0 // indirect
+ golang.org/x/net v0.10.0 // indirect
golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c // indirect
- golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
- golang.org/x/sys v0.4.0 // indirect
- golang.org/x/text v0.3.7 // indirect
- golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 // indirect
- golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
+ golang.org/x/sync v0.1.0 // indirect
+ golang.org/x/sys v0.17.0 // indirect
+ golang.org/x/text v0.9.0 // indirect
+ golang.org/x/tools v0.6.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20210510173355-fb37daa5cd7a // indirect
google.golang.org/grpc v1.37.0 // indirect
- google.golang.org/protobuf v1.28.0 // indirect
+ google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/cheggaaa/pb.v1 v1.0.28 // indirect
gopkg.in/mail.v2 v2.3.1 // indirect
diff --git a/go.sum b/go.sum
index 64a345d..ca78810 100644
--- a/go.sum
+++ b/go.sum
@@ -62,8 +62,6 @@ github.com/Azure/go-autorest v12.0.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSW
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
-github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08=
-github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0=
github.com/GeertJohan/go.rice v1.0.2/go.mod h1:af5vUNlDNkCjOZeSGFgIJxDje9qdjsO6hshx0gTmZt4=
github.com/GoogleCloudPlatform/cloudsql-proxy v0.0.0-20191009163259-e802c2cb94ae/go.mod h1:mjwGPas4yKduTyubHvD1Atl9r1rUq8DfVy+gkVvZ+oo=
@@ -236,6 +234,8 @@ github.com/etcd-io/gofail v0.0.0-20190801230047-ad7f989257ca/go.mod h1:49H/RkXP8
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
+github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM=
+github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/form3tech-oss/jwt-go v3.2.3+incompatible h1:7ZaBxOI7TMoYBfyA3cQHErNNyAWIKUMIwqxEtgHOs5c=
@@ -289,12 +289,14 @@ github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBY
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
-github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
+github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
-github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho=
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
+github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
+github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.8.0/go.mod h1:9JhgTzTaE31GZDpH/HSvHiRJrJ3iKAgqqH0Bl/Ocjdk=
github.com/go-playground/validator/v10 v10.11.0 h1:0W+xRM511GY47Yy3bZUbJVitCNg2BOGlCyvTqsp/xIw=
github.com/go-playground/validator/v10 v10.11.0/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
@@ -305,8 +307,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
-github.com/goccy/go-json v0.9.8 h1:DxXB6MLd6yyel7CLph8EwNIonUtVZd3Ue5iRcL4DQCE=
-github.com/goccy/go-json v0.9.8/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
+github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
+github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
@@ -488,6 +490,8 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO
github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo=
github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4=
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
+github.com/iGoogle-ink/gopay v1.5.36 h1:RctuoiEdTbiXOmzQ9i1388opwAOjheUDIFoHl1EeNr8=
+github.com/iGoogle-ink/gopay v1.5.36/go.mod h1:JADVzrfz9kzGMCgV7OzJ954pqwMU7PotYMAjP84YKIE=
github.com/iancoleman/strcase v0.0.0-20180726023541-3605ed457bf7/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
@@ -566,14 +570,15 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/go-gypsy v1.0.0/go.mod h1:chkXM0zjdpXOiqkCW1XcCHDfjfk14PH2KKkQWxfJUcU=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
-github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
+github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
+github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/letsencrypt/pkcs11key/v4 v4.0.0/go.mod h1:EFUvBDay26dErnNb70Nd0/VW3tJiIbETBPTl9ATXQag=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg=
-github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
+github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lyft/protoc-gen-star v0.5.1/go.mod h1:9toiA3cC7z5uVbODF7kEQ91Xn7XNFkVUl+SrEe+ZORU=
@@ -585,6 +590,8 @@ github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcncea
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-ieproxy v0.0.0-20190610004146-91bb50d98149/go.mod h1:31jz6HNzdxOmlERGGEc4v/dMssOfmp2p5bT/okiKFFc=
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
@@ -593,8 +600,11 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
-github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
-github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
+github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -686,8 +696,8 @@ github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FI
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
-github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw=
-github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI=
+github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
+github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac=
github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
@@ -744,12 +754,12 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/pseudomuto/protoc-gen-doc v1.4.1/go.mod h1:exDTOVwqpp30eV/EDPFLZy3Pwr2sn6hBC1WIYH/UbIg=
github.com/pseudomuto/protokit v0.2.0/go.mod h1:2PdH30hxVHsup8KpBTOXTBeMVhJZVio3Q8ViKSAXT0Q=
+github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371 h1:8VWtyY2IwjEQZSNT4Kyyct9zv9hoegD5GQhFr+TMdCI=
+github.com/qingwg/payjs v0.0.0-20190928033402-c53dbe16b371/go.mod h1:9UFrQveqNm3ELF6HSvMtDR3KYpJ7Ib9s0WVmYhaUBlU=
github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdkk=
github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs=
github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w=
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
-github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs=
-github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M=
@@ -790,6 +800,14 @@ github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrf
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
+github.com/smartwalle/alipay/v3 v3.2.20 h1:IjpG3YYgUgzCfS0z/EHlUbbr0OlrmOBHUst/3FzToYE=
+github.com/smartwalle/alipay/v3 v3.2.20/go.mod h1:KWg91KsY+eIOf26ZfZeH7bed1bWulGpGrL1ErHF3jWo=
+github.com/smartwalle/ncrypto v1.0.4 h1:P2rqQxDepJwgeO5ShoC+wGcK2wNJDmcdBOWAksuIgx8=
+github.com/smartwalle/ncrypto v1.0.4/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
+github.com/smartwalle/ngx v1.0.9 h1:pUXDvWRZJIHVrCKA1uZ15YwNti+5P4GuJGbpJ4WvpMw=
+github.com/smartwalle/ngx v1.0.9/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
+github.com/smartwalle/nsign v1.0.9 h1:8poAgG7zBd8HkZy9RQDwasC6XZvJpDGQWSjzL2FZL6E=
+github.com/smartwalle/nsign v1.0.9/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/assertions v1.0.0 h1:UVQPSSmc3qtTi+zPPkCXvZX9VvW/xT/NsRvKfwY81a8=
github.com/smartystreets/assertions v1.0.0/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM=
@@ -829,8 +847,10 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3
github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v0.0.0-20170130113145-4d4bfba8f1d1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -838,8 +858,11 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s=
-github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
+github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393 h1:hfhmMk7j4uDMRkfrrIOneMVXPBEhy3HSYiWX0gWoyhc=
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/captcha v1.0.393/go.mod h1:482ndbWuXqgStZNCqE88UoZeDveIt0juS7MY71Vangg=
@@ -863,11 +886,10 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1
github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
-github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
-github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
-github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
+github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
+github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8=
github.com/ulikunitz/xz v0.5.7/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
github.com/ulikunitz/xz v0.5.10 h1:t92gobL9l3HE202wg3rlk19F6X+JOxl9BBrCCMYEYd8=
@@ -970,12 +992,13 @@ golang.org/x/crypto v0.0.0-20191117063200-497ca9f6d64f/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201124201722-c8d3bf9c5392/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
+golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY=
-golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
+golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -994,6 +1017,8 @@ golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMx
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
+golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
+golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -1018,8 +1043,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57 h1:LQmS1nU0twXLA96Kt7U9qtHJEbBk3z6Q0V4UXjZkpr4=
-golang.org/x/mod v0.6.0-dev.0.20211013180041-c96bc1413d57/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
+golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
+golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -1071,8 +1096,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.0.0-20220630215102-69896b714898 h1:K7wO6V1IrczY9QOQ2WkVpw4JQSwCd52UsxVEirZUfiw=
-golang.org/x/net v0.0.0-20220630215102-69896b714898/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
+golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -1100,8 +1125,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
+golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -1170,8 +1196,11 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211020174200-9d6173849985/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
-golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
+golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
+golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1182,8 +1211,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
+golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
+golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1255,12 +1286,11 @@ golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
-golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023 h1:0c3L82FDQ5rt1bjTBlchS8t6RQ6299/+5bWMnRLh+uI=
-golang.org/x/tools v0.1.8-0.20211029000441-d6a9af8af023/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
+golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
+golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
@@ -1393,8 +1423,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.25.1-0.20200805231151-a709e31e5d12/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
-google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
-google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
+google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
diff --git a/main.go b/main.go
index 5e214e1..6cc59ee 100644
--- a/main.go
+++ b/main.go
@@ -69,10 +69,11 @@ func main() {
// 收到信号后关闭服务器
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
- go shutdown(sigChan, server)
+ wait := shutdown(sigChan, server)
defer func() {
- <-sigChan
+ sigChan <- syscall.SIGTERM
+ <-wait
}()
// 如果启用了SSL
@@ -104,7 +105,7 @@ func main() {
util.Log().Info("Listening to %q", conf.SystemConfig.Listen)
server.Addr = conf.SystemConfig.Listen
- if err := server.ListenAndServe(); err != nil {
+ if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err)
}
}
@@ -133,26 +134,29 @@ func RunUnix(server *http.Server) error {
return server.Serve(listener)
}
-func shutdown(sigChan chan os.Signal, server *http.Server) {
- sig := <-sigChan
- util.Log().Info("Signal %s received, shutting down server...", sig)
- ctx := context.Background()
- if conf.SystemConfig.GracePeriod != 0 {
- var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.SystemConfig.GracePeriod)*time.Second)
+func shutdown(sigChan chan os.Signal, server *http.Server) chan struct{} {
+ wait := make(chan struct{})
+ go func() {
+ sig := <-sigChan
+ util.Log().Info("Signal %s received, shutting down server...", sig)
+ if conf.SystemConfig.GracePeriod == 0 {
+ conf.SystemConfig.GracePeriod = 10
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(conf.SystemConfig.GracePeriod)*time.Second)
defer cancel()
- }
+ // Shutdown http server
+ err := server.Shutdown(ctx)
+ if err != nil {
+ util.Log().Error("Failed to shutdown server: %s", err)
+ }
- // Shutdown http server
- err := server.Shutdown(ctx)
- if err != nil {
- util.Log().Error("Failed to shutdown server: %s", err)
- }
+ // Persist in-memory cache
+ if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil {
+ util.Log().Warning("Failed to persist cache: %s", err)
+ }
- // Persist in-memory cache
- if err := cache.Store.Persist(filepath.Join(model.GetSettingByName("temp_path"), cache.DefaultCacheFile)); err != nil {
- util.Log().Warning("Failed to persist cache: %s", err)
- }
-
- close(sigChan)
+ close(sigChan)
+ wait <- struct{}{}
+ }()
+ return wait
}
diff --git a/middleware/auth.go b/middleware/auth.go
index 3a7d763..6913273 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -5,14 +5,15 @@ import (
"context"
"crypto/md5"
"fmt"
+ "io/ioutil"
+ "net/http"
+
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/qiniu/go-sdk/v7/auth/qbox"
- "io/ioutil"
- "net/http"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
@@ -77,6 +78,24 @@ func AuthRequired() gin.HandlerFunc {
}
}
+// PhoneRequired 需要绑定手机
+// TODO 有bug
+func PhoneRequired() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if model.IsTrueVal(model.GetSettingByName("phone_required")) &&
+ model.IsTrueVal(model.GetSettingByName("phone_enabled")) {
+ user, _ := c.Get("user")
+ if user.(*model.User).Phone != "" {
+ // TODO 忽略管理员
+ c.Next()
+ return
+ }
+ }
+
+ c.Next()
+ }
+}
+
// WebDAVAuth 验证WebDAV登录及权限
func WebDAVAuth() gin.HandlerFunc {
return func(c *gin.Context) {
diff --git a/middleware/auth_test.go b/middleware/auth_test.go
deleted file mode 100644
index 9e8650f..0000000
--- a/middleware/auth_test.go
+++ /dev/null
@@ -1,605 +0,0 @@
-package middleware
-
-import (
- "database/sql"
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/qiniu/go-sdk/v7/auth/qbox"
- "io/ioutil"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestCurrentUser(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
-
- //session为空
- sessionFunc := Session("233")
- sessionFunc(c)
- CurrentUser()(c)
- user, _ := c.Get("user")
- asserts.Nil(user)
-
- //session正确
- c, _ = gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
- sessionFunc(c)
- util.SetSession(c, map[string]interface{}{"user_id": 1})
- rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}).
- AddRow(1, nil, "admin@cloudreve.org", "{}")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows)
- CurrentUser()(c)
- user, _ = c.Get("user")
- asserts.NotNil(user)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestAuthRequired(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
- AuthRequiredFunc := AuthRequired()
-
- // 未登录
- AuthRequiredFunc(c)
- asserts.NotNil(c)
-
- // 类型错误
- c.Set("user", 123)
- AuthRequiredFunc(c)
- asserts.NotNil(c)
-
- // 正常
- c.Set("user", &model.User{})
- AuthRequiredFunc(c)
- asserts.NotNil(c)
-}
-
-func TestSignRequired(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
- authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
- SignRequiredFunc := SignRequired(authInstance)
-
- // 鉴权失败
- SignRequiredFunc(c)
- asserts.NotNil(c)
- asserts.True(c.IsAborted())
-
- c, _ = gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("PUT", "/test", nil)
- SignRequiredFunc(c)
- asserts.NotNil(c)
- asserts.True(c.IsAborted())
-
- // Sign verify success
- c, _ = gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("PUT", "/test", nil)
- c.Request = auth.SignRequest(authInstance, c.Request, 0)
- SignRequiredFunc(c)
- asserts.NotNil(c)
- asserts.False(c.IsAborted())
-}
-
-func TestWebDAVAuth(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := WebDAVAuth()
-
- // options请求跳过验证
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
- AuthFunc(c)
- }
-
- // 请求HTTP Basic Auth
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("POST", "/test", nil)
- AuthFunc(c)
- asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
- }
-
- // 用户名不存在
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("POST", "/test", nil)
- c.Request.Header = map[string][]string{
- "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "password", "email"}),
- )
- AuthFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
- }
-
- // 密码错误
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("POST", "/test", nil)
- c.Request.Header = map[string][]string{
- "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
- )
- // 查找密码
- mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- AuthFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
- }
-
- //未启用 WebDAV
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("POST", "/test", nil)
- c.Request.Header = map[string][]string{
- "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(
- sqlmock.NewRows(
- []string{"id", "password", "email", "group_id", "options"}).
- AddRow(1,
- "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
- "who@cloudreve.org",
- 1,
- "{}",
- ),
- )
- mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false))
- // 查找密码
- mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- AuthFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(c.Writer.Status(), http.StatusForbidden)
- }
-
- //正常
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("POST", "/test", nil)
- c.Request.Header = map[string][]string{
- "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(
- sqlmock.NewRows(
- []string{"id", "password", "email", "group_id", "options"}).
- AddRow(1,
- "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
- "who@cloudreve.org",
- 1,
- "{}",
- ),
- )
- mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true))
- // 查找密码
- mock.ExpectQuery("SELECT(.+)webdav(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- AuthFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(c.Writer.Status(), 200)
- _, ok := c.Get("user")
- asserts.True(ok)
- }
-
-}
-
-func TestUseUploadSession(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := UseUploadSession("local")
-
- // sessionID 为空
- {
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/sessionID", nil)
- authInstance := auth.HMACAuth{SecretKey: []byte("123")}
- auth.SignRequest(authInstance, c.Request, 0)
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 成功
- {
- cache.Set(
- filesystem.UploadSessionCachePrefix+"testCallBackRemote",
- serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{Type: "local"},
- },
- 0,
- )
- cache.Deletes([]string{"1"}, "policy_")
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)groups(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[513]"))
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "secret_key"}).AddRow(2, "123"))
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"sessionID", "testCallBackRemote"},
- }
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
- authInstance := auth.HMACAuth{SecretKey: []byte("123")}
- auth.SignRequest(authInstance, c.Request, 0)
- AuthFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(c.IsAborted())
- }
-}
-
-func TestUploadCallbackCheck(t *testing.T) {
- a := assert.New(t)
- rec := httptest.NewRecorder()
-
- // 上传会话不存在
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"sessionID", "testSessionNotExist"},
- }
- res := uploadCallbackCheck(c, "local")
- a.Contains("上传会话不存在或已过期", res.Msg)
- }
-
- // 上传策略不一致
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"sessionID", "testPolicyNotMatch"},
- }
- cache.Set(
- filesystem.UploadSessionCachePrefix+"testPolicyNotMatch",
- serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{Type: "remote"},
- },
- 0,
- )
- res := uploadCallbackCheck(c, "local")
- a.Contains("Policy not supported", res.Msg)
- }
-
- // 用户不存在
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"sessionID", "testUserNotExist"},
- }
- cache.Set(
- filesystem.UploadSessionCachePrefix+"testUserNotExist",
- serializer.UploadSession{
- UID: 313,
- VirtualPath: "/",
- Policy: model.Policy{Type: "remote"},
- },
- 0,
- )
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}))
- res := uploadCallbackCheck(c, "remote")
- a.Contains("找不到用户", res.Msg)
- a.NoError(mock.ExpectationsWereMet())
- _, ok := cache.Get(filesystem.UploadSessionCachePrefix + "testUserNotExist")
- a.False(ok)
- }
-}
-
-func TestRemoteCallbackAuth(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := RemoteCallbackAuth()
-
- // 成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{SecretKey: "123"},
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
- authInstance := auth.HMACAuth{SecretKey: []byte("123")}
- auth.SignRequest(authInstance, c.Request, 0)
- AuthFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // 签名错误
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{SecretKey: "123"},
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/remote/testCallBackRemote", nil)
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-}
-
-func TestQiniuCallbackAuth(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := QiniuCallbackAuth()
-
- // 成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
- mac := qbox.NewMac("123", "123")
- token, err := mac.SignRequest(c.Request)
- asserts.NoError(err)
- c.Request.Header["Authorization"] = []string{"QBox " + token}
- AuthFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(c.IsAborted())
- }
-
- // 验证失败
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/qiniu/testCallBackQiniu", nil)
- mac := qbox.NewMac("123", "1213")
- token, err := mac.SignRequest(c.Request)
- asserts.NoError(err)
- c.Request.Header["Authorization"] = []string{"QBox " + token}
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-}
-
-func TestOSSCallbackAuth(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := OSSCallbackAuth()
-
- // 签名验证失败
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/testCallBackOSS", nil)
- mac := qbox.NewMac("123", "123")
- token, err := mac.SignRequest(c.Request)
- asserts.NoError(err)
- c.Request.Header["Authorization"] = []string{"QBox " + token}
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/oss/TnXx5E5VyfJUyM1UdkdDu1rtnJ34EbmH", ioutil.NopCloser(strings.NewReader(`{"name":"2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","source_name":"1/1_hFRtDLgM_2f7b2ccf30e9270ea920f1ab8a4037a546a2f0d5.jpg","size":114020,"pic_info":"810,539"}`)))
- c.Request.Header["Authorization"] = []string{"e5LwzwTkP9AFAItT4YzvdJOHd0Y0wqTMWhsV/h5SG90JYGAmMd+8LQyj96R+9qUfJWjMt6suuUh7LaOryR87Dw=="}
- c.Request.Header["X-Oss-Pub-Key-Url"] = []string{"aHR0cHM6Ly9nb3NzcHVibGljLmFsaWNkbi5jb20vY2FsbGJhY2tfcHViX2tleV92MS5wZW0="}
- AuthFunc(c)
- asserts.False(c.IsAborted())
- }
-
-}
-
-type fakeRead string
-
-func (r fakeRead) Read(p []byte) (int, error) {
- return 0, errors.New("error")
-}
-
-func TestUpyunCallbackAuth(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := UpyunCallbackAuth()
-
- // 无法获取请求正文
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(fakeRead("")))
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 正文MD5不一致
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
- c.Request.Header["Content-Md5"] = []string{"123"}
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 签名不一致
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
- c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
- AuthFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
- c.Request.Header["Content-Md5"] = []string{"c4ca4238a0b923820dcc509a6f75849b"}
- c.Request.Header["Authorization"] = []string{"UPYUN 123:GWueK9x493BKFFk5gmfdO2Mn6EM="}
- AuthFunc(c)
- asserts.False(c.IsAborted())
- }
-}
-
-func TestOneDriveCallbackAuth(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- AuthFunc := OneDriveCallbackAuth()
-
- // 成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"sessionID", "TestOneDriveCallbackAuth"},
- }
- c.Set(filesystem.UploadSessionCtx, &serializer.UploadSession{
- UID: 1,
- VirtualPath: "/",
- Policy: model.Policy{
- SecretKey: "123",
- AccessKey: "123",
- },
- })
- c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/TestOneDriveCallbackAuth", ioutil.NopCloser(strings.NewReader("1")))
- res := mq.GlobalMQ.Subscribe("TestOneDriveCallbackAuth", 1)
- AuthFunc(c)
- select {
- case <-res:
- case <-time.After(time.Millisecond * 500):
- asserts.Fail("mq message should be published")
- }
- asserts.False(c.IsAborted())
- }
-}
-
-func TestIsAdmin(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := IsAdmin()
-
- // 非管理员
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("user", &model.User{})
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 是管理员
- {
- c, _ := gin.CreateTestContext(rec)
- user := &model.User{}
- user.Group.ID = 1
- c.Set("user", user)
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // 初始用户,非管理组
- {
- c, _ := gin.CreateTestContext(rec)
- user := &model.User{}
- user.Group.ID = 2
- user.ID = 1
- c.Set("user", user)
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-}
diff --git a/middleware/captcha_test.go b/middleware/captcha_test.go
deleted file mode 100644
index 1846d31..0000000
--- a/middleware/captcha_test.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
-type errReader int
-
-func (errReader) Read(p []byte) (n int, err error) {
- return 0, errors.New("test error")
-}
-
-func TestCaptchaRequired_General(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
-
- // 未启用验证码
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "0",
- "captcha_type": "1",
- "captcha_ReCaptchaSecret": "1",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/", nil)
- TestFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // body 无法读取
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "1",
- "captcha_type": "1",
- "captcha_ReCaptchaSecret": "1",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/", errReader(1))
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // body JSON 解析失败
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "1",
- "captcha_type": "1",
- "captcha_ReCaptchaSecret": "1",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- r := bytes.NewReader([]byte("123"))
- c.Request, _ = http.NewRequest("GET", "/", r)
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-}
-
-func TestCaptchaRequired_Normal(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
-
- // 验证码错误
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "1",
- "captcha_type": "normal",
- "captcha_ReCaptchaSecret": "1",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- r := bytes.NewReader([]byte("{}"))
- c.Request, _ = http.NewRequest("GET", "/", r)
- Session("233")(c)
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-}
-
-func TestCaptchaRequired_Recaptcha(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
-
- // 无法初始化reCaptcha实例
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "1",
- "captcha_type": "recaptcha",
- "captcha_ReCaptchaSecret": "",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- r := bytes.NewReader([]byte("{}"))
- c.Request, _ = http.NewRequest("GET", "/", r)
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 验证码错误
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "1",
- "captcha_type": "recaptcha",
- "captcha_ReCaptchaSecret": "233",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- r := bytes.NewReader([]byte("{}"))
- c.Request, _ = http.NewRequest("GET", "/", r)
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-}
-
-func TestCaptchaRequired_Tcaptcha(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
-
- // 验证出错
- {
- cache.SetSettings(map[string]string{
- "login_captcha": "1",
- "captcha_type": "tcaptcha",
- "captcha_ReCaptchaSecret": "",
- "captcha_TCaptcha_SecretId": "1",
- "captcha_TCaptcha_SecretKey": "1",
- "captcha_TCaptcha_CaptchaAppId": "1",
- "captcha_TCaptcha_AppSecretKey": "1",
- }, "setting_")
- TestFunc := CaptchaRequired("login_captcha")
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- r := bytes.NewReader([]byte("{}"))
- c.Request, _ = http.NewRequest("GET", "/", r)
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-}
diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go
deleted file mode 100644
index 440163d..0000000
--- a/middleware/cluster_test.go
+++ /dev/null
@@ -1,120 +0,0 @@
-package middleware
-
-import (
- "errors"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "net/http/httptest"
- "testing"
-)
-
-func TestMasterMetadata(t *testing.T) {
- a := assert.New(t)
- masterMetaDataFunc := MasterMetadata()
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/", nil)
-
- c.Request.Header = map[string][]string{
- "X-Cr-Site-Id": {"expectedSiteID"},
- "X-Cr-Site-Url": {"expectedSiteURL"},
- "X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
- }
- masterMetaDataFunc(c)
- siteID, _ := c.Get("MasterSiteID")
- siteURL, _ := c.Get("MasterSiteURL")
- siteVersion, _ := c.Get("MasterVersion")
-
- a.Equal("expectedSiteID", siteID.(string))
- a.Equal("expectedSiteURL", siteURL.(string))
- a.Equal("expectedMasterVersion", siteVersion.(string))
-}
-
-func TestSlaveRPCSignRequired(t *testing.T) {
- a := assert.New(t)
- np := &cluster.NodePool{}
- np.Init()
- slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
- rec := httptest.NewRecorder()
-
- // id parse failed
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Request.Header.Set("X-Cr-Node-Id", "unknown")
- slaveRPCSignRequiredFunc(c)
- a.True(c.IsAborted())
- }
-
- // node id not exist
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Request.Header.Set("X-Cr-Node-Id", "38")
- slaveRPCSignRequiredFunc(c)
- a.True(c.IsAborted())
- }
-
- // success
- {
- authInstance := auth.HMACAuth{SecretKey: []byte("")}
- np.Add(&model.Node{Model: gorm.Model{
- ID: 38,
- }})
-
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("POST", "/", nil)
- c.Request.Header.Set("X-Cr-Node-Id", "38")
- c.Request = auth.SignRequest(authInstance, c.Request, 0)
- slaveRPCSignRequiredFunc(c)
- a.False(c.IsAborted())
- }
-}
-
-func TestUseSlaveAria2Instance(t *testing.T) {
- a := assert.New(t)
-
- // MasterSiteID not set
- {
- testController := &controllermock.SlaveControllerMock{}
- useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- useSlaveAria2InstanceFunc(c)
- a.True(c.IsAborted())
- }
-
- // Cannot get aria2 instances
- {
- testController := &controllermock.SlaveControllerMock{}
- useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("MasterSiteID", "expectedSiteID")
- testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
- useSlaveAria2InstanceFunc(c)
- a.True(c.IsAborted())
- testController.AssertExpectations(t)
- }
-
- // Success
- {
- testController := &controllermock.SlaveControllerMock{}
- useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("MasterSiteID", "expectedSiteID")
- testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
- useSlaveAria2InstanceFunc(c)
- a.False(c.IsAborted())
- res, _ := c.Get("MasterAria2Instance")
- a.NotNil(res)
- testController.AssertExpectations(t)
- }
-}
diff --git a/middleware/common_test.go b/middleware/common_test.go
deleted file mode 100644
index 1ab839a..0000000
--- a/middleware/common_test.go
+++ /dev/null
@@ -1,105 +0,0 @@
-package middleware
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
-)
-
-func TestHashID(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- TestFunc := HashID(hashid.FolderID)
-
- // 未给定ID对象,跳过
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
- TestFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(c.IsAborted())
- }
-
- // 给定ID,解析失败
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"id", "2333"},
- }
- c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
- TestFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.True(c.IsAborted())
- }
-
- // 给定ID,解析成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"id", hashid.HashID(1, hashid.FolderID)},
- }
- c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
- TestFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(c.IsAborted())
- }
-}
-
-func TestIsFunctionEnabled(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- TestFunc := IsFunctionEnabled("TestIsFunctionEnabled")
-
- // 未开启
- {
- cache.Set("setting_TestIsFunctionEnabled", "0", 0)
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
- // 开启
- {
- cache.Set("setting_TestIsFunctionEnabled", "1", 0)
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil)
- TestFunc(c)
- asserts.False(c.IsAborted())
- }
-
-}
-
-func TestCacheControl(t *testing.T) {
- a := assert.New(t)
- TestFunc := CacheControl()
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- TestFunc(c)
- a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
-}
-
-func TestSandbox(t *testing.T) {
- a := assert.New(t)
- TestFunc := Sandbox()
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- TestFunc(c)
- a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
-}
-
-func TestStaticResourceCache(t *testing.T) {
- a := assert.New(t)
- TestFunc := StaticResourceCache()
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- TestFunc(c)
- a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
-}
diff --git a/middleware/file_test.go b/middleware/file_test.go
deleted file mode 100644
index 5ca4014..0000000
--- a/middleware/file_test.go
+++ /dev/null
@@ -1,57 +0,0 @@
-package middleware
-
-import (
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "net/http/httptest"
- "testing"
-)
-
-func TestValidateSourceLink(t *testing.T) {
- a := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := ValidateSourceLink()
-
- // ID 不存在
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- a.True(c.IsAborted())
- }
-
- // SourceLink 不存在
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("object_id", 1)
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
- testFunc(c)
- a.True(c.IsAborted())
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 原文件不存在
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("object_id", 1)
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
- testFunc(c)
- a.True(c.IsAborted())
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 成功
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("object_id", 1)
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
- mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
- testFunc(c)
- a.False(c.IsAborted())
- a.NoError(mock.ExpectationsWereMet())
- }
-
-}
diff --git a/middleware/frontend.go b/middleware/frontend.go
index f07d9b6..eba1e84 100644
--- a/middleware/frontend.go
+++ b/middleware/frontend.go
@@ -1,13 +1,14 @@
package middleware
import (
+ "io/ioutil"
+ "net/http"
+ "strings"
+
"github.com/cloudreve/Cloudreve/v3/bootstrap"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
- "io/ioutil"
- "net/http"
- "strings"
)
// FrontendFileHandler 前端静态文件处理
@@ -51,10 +52,16 @@ func FrontendFileHandler() gin.HandlerFunc {
// 不存在的路径和index.html均返回index.html
if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) {
// 读取、替换站点设置
- options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript",
- "pwa_small_icon")
+ options := model.GetSettingByNames(
+ "siteName", // 站点名称
+ "siteKeywords", // 关键词
+ "siteDes", // 描述
+ "siteScript", // 自定义代码
+ "pwa_small_icon", // 图标
+ )
finalHTML := util.Replace(map[string]string{
"{siteName}": options["siteName"],
+ "{siteKeywords}": options["siteKeywords"],
"{siteDes}": options["siteDes"],
"{siteScript}": options["siteScript"],
"{pwa_small_icon}": options["pwa_small_icon"],
diff --git a/middleware/frontend_test.go b/middleware/frontend_test.go
deleted file mode 100644
index d32529d..0000000
--- a/middleware/frontend_test.go
+++ /dev/null
@@ -1,144 +0,0 @@
-package middleware
-
-import (
- "errors"
- "github.com/cloudreve/Cloudreve/v3/bootstrap"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "net/http"
- "net/http/httptest"
- "os"
- "testing"
-)
-
-type StaticMock struct {
- testMock.Mock
-}
-
-func (m StaticMock) Open(name string) (http.File, error) {
- args := m.Called(name)
- return args.Get(0).(http.File), args.Error(1)
-}
-
-func (m StaticMock) Exists(prefix string, filepath string) bool {
- args := m.Called(prefix, filepath)
- return args.Bool(0)
-}
-
-func TestFrontendFileHandler(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
-
- // 静态资源未加载
- {
- TestFunc := FrontendFileHandler()
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/", nil)
- TestFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // index.html 不存在
- {
- testStatic := &StaticMock{}
- bootstrap.StaticFS = testStatic
- testStatic.On("Open", "/index.html").
- Return(&os.File{}, errors.New("error"))
- TestFunc := FrontendFileHandler()
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/", nil)
- TestFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // index.html 读取失败
- {
- file, _ := util.CreatNestedFile("tests/index.html")
- file.Close()
- testStatic := &StaticMock{}
- bootstrap.StaticFS = testStatic
- testStatic.On("Open", "/index.html").
- Return(file, nil)
- TestFunc := FrontendFileHandler()
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/", nil)
- TestFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // 成功且命中
- {
- file, _ := util.CreatNestedFile("tests/index.html")
- defer file.Close()
- testStatic := &StaticMock{}
- bootstrap.StaticFS = testStatic
- testStatic.On("Open", "/index.html").
- Return(file, nil)
- TestFunc := FrontendFileHandler()
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/", nil)
-
- cache.Set("setting_siteName", "cloudreve", 0)
- cache.Set("setting_siteKeywords", "cloudreve", 0)
- cache.Set("setting_siteScript", "cloudreve", 0)
- cache.Set("setting_pwa_small_icon", "cloudreve", 0)
-
- TestFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 成功且命中静态文件
- {
- file, _ := util.CreatNestedFile("tests/index.html")
- defer file.Close()
- testStatic := &StaticMock{}
- bootstrap.StaticFS = testStatic
- testStatic.On("Open", "/index.html").
- Return(file, nil)
- testStatic.On("Exists", "/", "/2").
- Return(true)
- testStatic.On("Open", "/2").
- Return(file, nil)
- TestFunc := FrontendFileHandler()
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", "/2", nil)
-
- TestFunc(c)
- asserts.True(c.IsAborted())
- testStatic.AssertExpectations(t)
- }
-
- // API 相关跳过
- {
- for _, reqPath := range []string{"/api/user", "/manifest.json", "/dav/path"} {
- file, _ := util.CreatNestedFile("tests/index.html")
- defer file.Close()
- testStatic := &StaticMock{}
- bootstrap.StaticFS = testStatic
- testStatic.On("Open", "/index.html").
- Return(file, nil)
- TestFunc := FrontendFileHandler()
-
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{}
- c.Request, _ = http.NewRequest("GET", reqPath, nil)
-
- TestFunc(c)
- asserts.False(c.IsAborted())
- }
- }
-
-}
diff --git a/middleware/mock_test.go b/middleware/mock_test.go
deleted file mode 100644
index 1ebee20..0000000
--- a/middleware/mock_test.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package middleware
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
-)
-
-func TestMockHelper(t *testing.T) {
- asserts := assert.New(t)
- MockHelperFunc := MockHelper()
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
-
- // 写入session
- {
- SessionMock["test"] = "pass"
- Session("test")(c)
- MockHelperFunc(c)
- asserts.Equal("pass", util.GetSession(c, "test").(string))
- }
-
- // 写入context
- {
- ContextMock["test"] = "pass"
- MockHelperFunc(c)
- test, exist := c.Get("test")
- asserts.True(exist)
- asserts.Equal("pass", test.(string))
-
- }
-}
diff --git a/middleware/session.go b/middleware/session.go
index db90755..b6d8023 100644
--- a/middleware/session.go
+++ b/middleware/session.go
@@ -1,11 +1,12 @@
package middleware
import (
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
"net/http"
"strings"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
+
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
diff --git a/middleware/session_test.go b/middleware/session_test.go
deleted file mode 100644
index 9fbe0d2..0000000
--- a/middleware/session_test.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package middleware
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
-)
-
-func TestSession(t *testing.T) {
- asserts := assert.New(t)
-
- {
- handler := Session("2333")
- asserts.NotNil(handler)
- asserts.NotNil(Store)
- asserts.IsType(emptyFunc(), handler)
- }
-}
-
-func emptyFunc() gin.HandlerFunc {
- return func(c *gin.Context) {}
-}
-
-func TestCSRFInit(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- sessionFunc := Session("233")
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
- sessionFunc(c)
- CSRFInit()(c)
- asserts.True(util.GetSession(c, "CSRF").(bool))
- }
-}
-
-func TestCSRFCheck(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- sessionFunc := Session("233")
-
- // 通过检查
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
- sessionFunc(c)
- CSRFInit()(c)
- CSRFCheck()(c)
- asserts.False(c.IsAborted())
- }
-
- // 未通过检查
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request, _ = http.NewRequest("GET", "/test", nil)
- sessionFunc(c)
- CSRFCheck()(c)
- asserts.True(c.IsAborted())
- }
-}
diff --git a/middleware/share.go b/middleware/share.go
index 488b703..cc4ef42 100644
--- a/middleware/share.go
+++ b/middleware/share.go
@@ -118,8 +118,14 @@ func BeforeShareDownload() gin.HandlerFunc {
// 对积分、下载次数进行更新
err = share.DownloadBy(user, c)
if err != nil {
- c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
- nil))
+ if err == model.ErrInsufficientCredit {
+ c.JSON(200, serializer.Err(serializer.CodeInsufficientCredit, err.Error(),
+ nil))
+ } else {
+ c.JSON(200, serializer.Err(serializer.CodeGroupNotAllowed, err.Error(),
+ nil))
+ }
+
c.Abort()
return
}
diff --git a/middleware/share_test.go b/middleware/share_test.go
deleted file mode 100644
index 129076b..0000000
--- a/middleware/share_test.go
+++ /dev/null
@@ -1,190 +0,0 @@
-package middleware
-
-import (
- "net/http/httptest"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestShareAvailable(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := ShareAvailable()
-
- // 分享不存在
- {
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"id", "empty"},
- }
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 通过
- {
- conf.SystemConfig.HashIDSalt = ""
- // 用户组
- mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
- mock.ExpectQuery("SELECT(.+)shares(.+)").
- WillReturnRows(
- sqlmock.NewRows(
- []string{"id", "remain_downloads", "source_id"}).
- AddRow(1, 1, 2),
- )
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
- c, _ := gin.CreateTestContext(rec)
- c.Params = []gin.Param{
- {"id", "x9T4"},
- }
- testFunc(c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(c.IsAborted())
- asserts.NotNil(c.Get("user"))
- asserts.NotNil(c.Get("share"))
- }
-}
-
-func TestShareCanPreview(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := ShareCanPreview()
-
- // 无分享上下文
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 可以预览
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("share", &model.Share{PreviewEnabled: true})
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-
- // 未开启预览
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("share", &model.Share{PreviewEnabled: false})
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-}
-
-func TestCheckShareUnlocked(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := CheckShareUnlocked()
-
- // 无分享上下文
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 无密码
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("share", &model.Share{})
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-
-}
-
-func TestBeforeShareDownload(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := BeforeShareDownload()
-
- // 无分享上下文
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- asserts.True(c.IsAborted())
-
- c, _ = gin.CreateTestContext(rec)
- c.Set("share", &model.Share{})
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 用户不能下载
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("share", &model.Share{})
- c.Set("user", &model.User{
- Group: model.Group{OptionsSerialized: model.GroupOption{}},
- })
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 可以下载
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set("share", &model.Share{})
- c.Set("user", &model.User{
- Model: gorm.Model{ID: 1},
- Group: model.Group{OptionsSerialized: model.GroupOption{
- ShareDownload: true,
- }},
- })
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-}
-
-func TestShareOwner(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := ShareOwner()
-
- // 未登录
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- asserts.True(c.IsAborted())
-
- c, _ = gin.CreateTestContext(rec)
- c.Set("share", &model.Share{})
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 非用户所创建分享
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- asserts.True(c.IsAborted())
-
- c, _ = gin.CreateTestContext(rec)
- c.Set("share", &model.Share{User: model.User{Model: gorm.Model{ID: 1}}})
- c.Set("user", &model.User{})
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // 正常
- {
- c, _ := gin.CreateTestContext(rec)
- testFunc(c)
- asserts.True(c.IsAborted())
-
- c, _ = gin.CreateTestContext(rec)
- c.Set("share", &model.Share{})
- c.Set("user", &model.User{})
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-}
diff --git a/middleware/wopi_test.go b/middleware/wopi_test.go
deleted file mode 100644
index c6ca327..0000000
--- a/middleware/wopi_test.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package middleware
-
-import (
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/wopimock"
- "github.com/cloudreve/Cloudreve/v3/pkg/wopi"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
- "net/http/httptest"
- "testing"
-)
-
-func TestWopiWriteAccess(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- testFunc := WopiWriteAccess()
-
- // deny preview only session
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionPreview})
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // pass
- {
- c, _ := gin.CreateTestContext(rec)
- c.Set(WopiSessionCtx, &wopi.SessionCache{Action: wopi.ActionEdit})
- testFunc(c)
- asserts.False(c.IsAborted())
- }
-}
-
-func TestWopiAccessValidation(t *testing.T) {
- asserts := assert.New(t)
- rec := httptest.NewRecorder()
- mockWopi := &wopimock.WopiClientMock{}
- mockCache := cache.NewMemoStore()
- testFunc := WopiAccessValidation(mockWopi, mockCache)
-
- // malformed access token
- {
- c, _ := gin.CreateTestContext(rec)
- c.AddParam(wopi.AccessTokenQuery, "000")
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // session key not exist
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
- query := c.Request.URL.Query()
- query.Set(wopi.AccessTokenQuery, "sessionID.key")
- c.Request.URL.RawQuery = query.Encode()
- testFunc(c)
- asserts.True(c.IsAborted())
- }
-
- // user key not exist
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
- query := c.Request.URL.Query()
- query.Set(wopi.AccessTokenQuery, "sessionID.key")
- c.Request.URL.RawQuery = query.Encode()
- mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
- testFunc(c)
- asserts.True(c.IsAborted())
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // file not found
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
- query := c.Request.URL.Query()
- query.Set(wopi.AccessTokenQuery, "sessionID.key")
- c.Request.URL.RawQuery = query.Encode()
- mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- c.Set("object_id", uint(0))
- testFunc(c)
- asserts.True(c.IsAborted())
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // all pass
- {
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest("GET", "/wopi/files/1?access_token=", nil)
- query := c.Request.URL.Query()
- query.Set(wopi.AccessTokenQuery, "sessionID.key")
- c.Request.URL.RawQuery = query.Encode()
- mockCache.Set(wopi.SessionCachePrefix+"sessionID", wopi.SessionCache{UserID: 1, FileID: 1}, 0)
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- c.Set("object_id", uint(1))
- testFunc(c)
- asserts.False(c.IsAborted())
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotPanics(func() {
- c.MustGet(WopiSessionCtx)
- })
- asserts.NotPanics(func() {
- c.MustGet("user")
- })
- }
-}
diff --git a/models/defaults.go b/models/defaults.go
index b13f22e..e73c4a1 100644
--- a/models/defaults.go
+++ b/models/defaults.go
@@ -9,12 +9,15 @@ import (
var defaultSettings = []Setting{
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
- {Name: "siteName", Value: `Cloudreve`, Type: "basic"},
+ {Name: "siteName", Value: `CloudrevePlus`, Type: "basic"},
{Name: "register_enabled", Value: `1`, Type: "register"},
{Name: "default_group", Value: `2`, Type: "register"},
- {Name: "siteKeywords", Value: `Cloudreve, cloud storage`, Type: "basic"},
- {Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
+ {Name: "mail_domain_filter", Value: `0`, Type: "register"},
+ {Name: "mail_domain_filter_list", Value: `126.com,163.com,gmail.com,outlook.com,qq.com,foxmail.com,yeah.net,sohu.com,sohu.cn,139.com,wo.cn,189.cn,hotmail.com,live.com,live.cn`, Type: "register"},
+ {Name: "siteKeywords", Value: `CloudrevePlus, cloud storage`, Type: "basic"},
+ {Name: "siteDes", Value: `部署公私兼备的网盘系统`, Type: "basic"},
{Name: "siteTitle", Value: `Inclusive cloud storage for everyone`, Type: "basic"},
+ {Name: "siteNotice", Value: ``, Type: "basic"},
{Name: "siteScript", Value: ``, Type: "basic"},
{Name: "siteID", Value: uuid.Must(uuid.NewV4()).String(), Type: "basic"},
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
@@ -26,6 +29,9 @@ var defaultSettings = []Setting{
{Name: "smtpUser", Value: `no-reply@acg.blue`, Type: "mail"},
{Name: "smtpPass", Value: ``, Type: "mail"},
{Name: "smtpEncryption", Value: `0`, Type: "mail"},
+ {Name: "over_used_template", Value: `
容量超额提醒 | 容量超额警告 | 亲爱的{userName}: | 由于{notifyReason},您在{siteTitle}的账户的容量使用超出配额,您将无法继续上传新文件,请尽快清理文件,否则我们将会禁用您的账户。 | 登录{siteTitle} | 感谢您选择{siteTitle}。 |
|
| |
`, Type: "mail_template"},
+ {Name: "ban_time", Value: `604800`, Type: "storage_policy"},
{Name: "maxEditSize", Value: `52428800`, Type: "file_edit"},
{Name: "archive_timeout", Value: `600`, Type: "timeout"},
{Name: "download_timeout", Value: `600`, Type: "timeout"},
@@ -38,6 +44,7 @@ var defaultSettings = []Setting{
{Name: "slave_recover_interval", Value: `120`, Type: "slave"},
{Name: "slave_transfer_timeout", Value: `172800`, Type: "timeout"},
{Name: "onedrive_monitor_timeout", Value: `600`, Type: "timeout"},
+ {Name: "onedrive_source_timeout", Value: `1800`, Type: "timeout"},
{Name: "share_download_session_timeout", Value: `2073600`, Type: "timeout"},
{Name: "onedrive_callback_check", Value: `20`, Type: "timeout"},
{Name: "folder_props_timeout", Value: `300`, Type: "timeout"},
@@ -46,10 +53,14 @@ var defaultSettings = []Setting{
{Name: "reset_after_upload_failed", Value: `0`, Type: "upload"},
{Name: "use_temp_chunk_buffer", Value: `1`, Type: "upload"},
{Name: "login_captcha", Value: `0`, Type: "login"},
+ {Name: "qq_login", Value: `0`, Type: "login"},
+ {Name: "qq_direct_login", Value: `0`, Type: "login"},
+ {Name: "qq_login_id", Value: ``, Type: "login"},
+ {Name: "qq_login_key", Value: ``, Type: "login"},
{Name: "reg_captcha", Value: `0`, Type: "login"},
{Name: "email_active", Value: `0`, Type: "register"},
{Name: "mail_activation_template", Value: `激活您的账户用户激活 | | 重设{siteTitle}密码 | 亲爱的{userName}: | 请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。 | 重设密码 | 感谢您选择{siteTitle}。 |
|
| |
`, Type: "mail_template"},
+ {Name: "pack_data", Value: `[]`, Type: "pack"},
{Name: "db_version_" + conf.RequiredDBVersion, Value: `installed`, Type: "version"},
+ {Name: "alipay_enabled", Value: `0`, Type: "payment"},
+ {Name: "payjs_enabled", Value: `0`, Type: "payment"},
+ {Name: "payjs_id", Value: ``, Type: "payment"},
+ {Name: "payjs_secret", Value: ``, Type: "payment"},
+ {Name: "appid", Value: ``, Type: "payment"},
+ {Name: "appkey", Value: ``, Type: "payment"},
+ {Name: "shopid", Value: ``, Type: "payment"},
+ {Name: "wechat_enabled", Value: `0`, Type: "payment"},
+ {Name: "wechat_appid", Value: ``, Type: "payment"},
+ {Name: "wechat_mchid", Value: ``, Type: "payment"},
+ {Name: "wechat_serial_no", Value: ``, Type: "payment"},
+ {Name: "wechat_api_key", Value: ``, Type: "payment"},
+ {Name: "wechat_pk_content", Value: ``, Type: "payment"},
{Name: "hot_share_num", Value: `10`, Type: "share"},
+ {Name: "group_sell_data", Value: `[]`, Type: "group_sell"},
{Name: "gravatar_server", Value: `https://www.gravatar.com/`, Type: "avatar"},
{Name: "defaultTheme", Value: `#3f51b5`, Type: "basic"},
{Name: "themes", Value: `{"#3f51b5":{"palette":{"primary":{"main":"#3f51b5"},"secondary":{"main":"#f50057"}}},"#2196f3":{"palette":{"primary":{"main":"#2196f3"},"secondary":{"main":"#FFC107"}}},"#673AB7":{"palette":{"primary":{"main":"#673AB7"},"secondary":{"main":"#2196F3"}}},"#E91E63":{"palette":{"primary":{"main":"#E91E63"},"secondary":{"main":"#42A5F5","contrastText":"#fff"}}},"#FF5722":{"palette":{"primary":{"main":"#FF5722"},"secondary":{"main":"#3F51B5"}}},"#FFC107":{"palette":{"primary":{"main":"#FFC107"},"secondary":{"main":"#26C6DA"}}},"#8BC34A":{"palette":{"primary":{"main":"#8BC34A","contrastText":"#fff"},"secondary":{"main":"#FF8A65","contrastText":"#fff"}}},"#009688":{"palette":{"primary":{"main":"#009688"},"secondary":{"main":"#4DD0E1","contrastText":"#fff"}}},"#607D8B":{"palette":{"primary":{"main":"#607D8B"},"secondary":{"main":"#F06292"}}},"#795548":{"palette":{"primary":{"main":"#795548"},"secondary":{"main":"#4CAF50","contrastText":"#fff"}}}}`, Type: "basic"},
@@ -77,9 +103,15 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "avatar_size_l", Value: "200", Type: "avatar"},
{Name: "avatar_size_m", Value: "130", Type: "avatar"},
{Name: "avatar_size_s", Value: "50", Type: "avatar"},
- {Name: "home_view_method", Value: "icon", Type: "view"},
+ {Name: "score_enabled", Value: "1", Type: "score"},
+ {Name: "share_score_rate", Value: "80", Type: "score"},
+ {Name: "score_price", Value: "1", Type: "score"},
+ {Name: "report_enabled", Value: "0", Type: "report"},
+ {Name: "home_view_method", Value: "list", Type: "view"},
{Name: "share_view_method", Value: "list", Type: "view"},
{Name: "cron_garbage_collect", Value: "@hourly", Type: "cron"},
+ {Name: "cron_notify_user", Value: "@hourly", Type: "cron"},
+ {Name: "cron_ban_user", Value: "@hourly", Type: "cron"},
{Name: "cron_recycle_upload_session", Value: "@every 1h30m", Type: "cron"},
{Name: "authn_enabled", Value: "0", Type: "authn"},
{Name: "captcha_type", Value: "normal", Type: "captcha"},
@@ -127,13 +159,24 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "pwa_display", Value: "standalone", Type: "pwa"},
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
+ {Name: "initial_files", Value: "[]", Type: "register"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
+ {Name: "phone_required", Value: "false", Type: "phone"},
+ {Name: "phone_enabled", Value: "false", Type: "phone"},
+ {Name: "vol_content", Value: "Guess", Type: "vol"},
+ {Name: "vol_signature", Value: "Guess", Type: "vol"},
{Name: "show_app_promotion", Value: "1", Type: "mobile"},
{Name: "public_resource_maxage", Value: "86400", Type: "timeout"},
{Name: "wopi_enabled", Value: "0", Type: "wopi"},
{Name: "wopi_endpoint", Value: "", Type: "wopi"},
{Name: "wopi_max_size", Value: "52428800", Type: "wopi"},
{Name: "wopi_session_timeout", Value: "36000", Type: "wopi"},
+ {Name: "custom_payment_enabled", Value: "0", Type: "payment"},
+ {Name: "custom_payment_endpoint", Value: "", Type: "payment"},
+ {Name: "custom_payment_secret", Value: "", Type: "payment"},
+ {Name: "custom_payment_name", Value: "", Type: "payment"},
+ {Name: "app_feedback_link", Value: "", Type: "mobile"},
+ {Name: "app_forum_link", Value: "", Type: "mobile"},
}
func InitSlaveDefaults() {
diff --git a/models/download_test.go b/models/download_test.go
deleted file mode 100644
index 367afb7..0000000
--- a/models/download_test.go
+++ /dev/null
@@ -1,190 +0,0 @@
-package model
-
-import (
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestDownload_Create(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- download := Download{GID: "1"}
- id, err := download.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues(1, id)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- download := Download{GID: "1"}
- id, err := download.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.EqualValues(0, id)
- }
-}
-
-func TestDownload_AfterFind(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- download := Download{Attrs: `{"gid":"123"}`}
- err := download.AfterFind()
- asserts.NoError(err)
- asserts.Equal("123", download.StatusInfo.Gid)
- }
-
- // 忽略空值
- {
- download := Download{Attrs: ``}
- err := download.AfterFind()
- asserts.NoError(err)
- asserts.Equal("", download.StatusInfo.Gid)
- }
-
- // 解析失败
- {
- download := Download{Attrs: `?`}
- err := download.BeforeSave()
- asserts.Error(err)
- asserts.Equal("", download.StatusInfo.Gid)
- }
-
-}
-
-func TestDownload_Save(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- download := Download{
- Model: gorm.Model{
- ID: 1,
- },
- Attrs: `{"gid":"123"}`,
- }
- err := download.Save()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal("123", download.StatusInfo.Gid)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- download := Download{
- Model: gorm.Model{
- ID: 1,
- },
- }
- err := download.Save()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-}
-
-func TestGetDownloadsByStatus(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WithArgs(0, 1).WillReturnRows(sqlmock.NewRows([]string{"gid"}).AddRow("0").AddRow("1"))
- res := GetDownloadsByStatus(0, 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(res, 2)
-}
-
-func TestGetDownloadByGid(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WithArgs(2, "gid").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1"))
- res, err := GetDownloadByGid("gid", 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(res.GID, "1")
-}
-
-func TestDownload_GetOwner(t *testing.T) {
- asserts := assert.New(t)
-
- // 已经有User对象
- {
- download := &Download{User: &User{Nick: "nick"}}
- user := download.GetOwner()
- asserts.NotNil(user)
- asserts.Equal("nick", user.Nick)
- }
-
- // 无User对象
- {
- download := &Download{UserID: 3}
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"nick"}).AddRow("nick"))
- user := download.GetOwner()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(user)
- asserts.Equal("nick", user.Nick)
- }
-}
-
-func TestGetDownloadsByStatusAndUser(t *testing.T) {
- asserts := assert.New(t)
-
- // 列出全部
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3))
- res := GetDownloadsByStatusAndUser(0, 1, 1, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(res, 2)
- }
-
- // 列出全部,分页
- {
- mock.ExpectQuery("SELECT(.+)DESC(.+)").WithArgs(1, 1, 2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(3))
- res := GetDownloadsByStatusAndUser(2, 1, 1, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(res, 2)
- }
-}
-
-func TestDownload_Delete(t *testing.T) {
- asserts := assert.New(t)
- share := Download{}
-
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := share.Delete()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
-}
-
-func TestDownload_GetNodeID(t *testing.T) {
- a := assert.New(t)
- record := Download{}
-
- // compatible with 3.4
- a.EqualValues(1, record.GetNodeID())
-
- record.NodeID = 5
- a.EqualValues(5, record.GetNodeID())
-}
diff --git a/models/file.go b/models/file.go
index bfe49cb..56ee2ab 100644
--- a/models/file.go
+++ b/models/file.go
@@ -10,6 +10,7 @@ import (
"strings"
"time"
+ "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
)
@@ -388,6 +389,15 @@ func (file *File) UpdateSourceName(value string) error {
}).Error
}
+// Relocate 更新文件的物理指向
+func (file *File) Relocate(src string, policyID uint) error {
+ file.Policy = Policy{}
+ return DB.Model(&file).Set("gorm:association_autoupdate", false).Updates(map[string]interface{}{
+ "source_name": src,
+ "policy_id": policyID,
+ }).Error
+}
+
func (file *File) PopChunkToFile(lastModified *time.Time, picInfo string) error {
file.UploadSessionID = nil
if lastModified != nil {
@@ -470,3 +480,46 @@ func (file *File) ShouldLoadThumb() bool {
func (file *File) ThumbFile() string {
return file.SourceName + GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")
}
+
+/*
+ 实现 filesystem.FileHeader 接口
+*/
+
+// Read 实现 io.Reader
+func (file *File) Read(p []byte) (n int, err error) {
+ return 0, errors.New("noe supported")
+}
+
+// Close 实现io.Closer
+func (file *File) Close() error {
+ return errors.New("noe supported")
+}
+
+// Seeker 实现io.Seeker
+func (file *File) Seek(offset int64, whence int) (int64, error) {
+ return 0, errors.New("noe supported")
+}
+
+func (file *File) Info() *fsctx.UploadTaskInfo {
+ return &fsctx.UploadTaskInfo{
+ Size: file.Size,
+ FileName: file.Name,
+ VirtualPath: file.Position,
+ Mode: 0,
+ Metadata: file.MetadataSerialized,
+ LastModified: &file.UpdatedAt,
+ SavePath: file.SourceName,
+ UploadSessionID: file.UploadSessionID,
+ }
+}
+
+func (file *File) SetSize(size uint64) {
+ file.Size = size
+}
+
+func (file *File) SetModel(newFile interface{}) {
+}
+
+func (file *File) Seekable() bool {
+ return false
+}
diff --git a/models/file_test.go b/models/file_test.go
deleted file mode 100644
index 83198fc..0000000
--- a/models/file_test.go
+++ /dev/null
@@ -1,785 +0,0 @@
-package model
-
-import (
- "errors"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFile_Create(t *testing.T) {
- asserts := assert.New(t)
- file := File{
- Name: "123",
- }
-
- // 无法插入文件记录
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := file.Create()
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 无法更新用户容量
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := file.Create()
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(0, 1))
- mock.ExpectCommit()
- err := file.Create()
- asserts.NoError(err)
- asserts.Equal(uint(5), file.ID)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFile_AfterFind(t *testing.T) {
- a := assert.New(t)
-
- // metadata not empty
- {
- file := File{
- Name: "123",
- Metadata: "{\"name\":\"123\"}",
- }
-
- a.NoError(file.AfterFind())
- a.Equal("123", file.MetadataSerialized["name"])
- }
-
- // metadata empty
- {
- file := File{
- Name: "123",
- Metadata: "",
- }
- a.Nil(file.MetadataSerialized)
- a.NoError(file.AfterFind())
- a.NotNil(file.MetadataSerialized)
- }
-}
-
-func TestFile_BeforeSave(t *testing.T) {
- a := assert.New(t)
-
- // metadata not empty
- {
- file := File{
- Name: "123",
- MetadataSerialized: map[string]string{
- "name": "123",
- },
- }
-
- a.NoError(file.BeforeSave())
- a.Equal("{\"name\":\"123\"}", file.Metadata)
- }
-
- // metadata empty
- {
- file := File{
- Name: "123",
- }
- a.NoError(file.BeforeSave())
- a.Equal("", file.Metadata)
- }
-}
-
-func TestFolder_GetChildFile(t *testing.T) {
- asserts := assert.New(t)
- folder := Folder{Model: gorm.Model{ID: 1}, Name: "/"}
- // 存在
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, "1.txt").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt"))
- file, err := folder.GetChildFile("1.txt")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal("1.txt", file.Name)
- asserts.Equal("/", file.Position)
- }
-
- // 不存在
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, "1.txt").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- _, err := folder.GetChildFile("1.txt")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-}
-
-func TestFolder_GetChildFiles(t *testing.T) {
- asserts := assert.New(t)
- folder := &Folder{
- Model: gorm.Model{
- ID: 1,
- },
- Position: "/123",
- Name: "456",
- }
-
- // 找不到
- mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnError(errors.New("error"))
- files, err := folder.GetChildFiles()
- asserts.Error(err)
- asserts.Len(files, 0)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 找到了
- mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2))
- files, err = folder.GetChildFiles()
- asserts.NoError(err)
- asserts.Len(files, 2)
- asserts.Equal("/123/456", files[0].Position)
- asserts.NoError(mock.ExpectationsWereMet())
-
-}
-
-func TestGetFilesByIDs(t *testing.T) {
- asserts := assert.New(t)
-
- // 出错
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 1).
- WillReturnError(errors.New("error"))
- folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Len(folders, 0)
- }
-
- // 部分找到
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
- folders, err := GetFilesByIDs([]uint{1, 2, 3}, 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(folders, 1)
- }
-
- // 忽略UID查找
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
- folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(folders, 1)
- }
-}
-
-func TestGetChildFilesOfFolders(t *testing.T) {
- asserts := assert.New(t)
- testFolder := []Folder{
- Folder{
- Model: gorm.Model{ID: 3},
- },
- Folder{
- Model: gorm.Model{ID: 4},
- }, Folder{
- Model: gorm.Model{ID: 5},
- },
- }
-
- // 出错
- {
- mock.ExpectQuery("SELECT(.+)folder_id").WithArgs(3, 4, 5).WillReturnError(errors.New("not found"))
- files, err := GetChildFilesOfFolders(&testFolder)
- asserts.Error(err)
- asserts.Len(files, 0)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 找到2个
- {
- mock.ExpectQuery("SELECT(.+)folder_id").
- WithArgs(3, 4, 5).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
- AddRow(3, "3").
- AddRow(4, "4"),
- )
- files, err := GetChildFilesOfFolders(&testFolder)
- asserts.NoError(err)
- asserts.Len(files, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 全部找到
- {
- mock.ExpectQuery("SELECT(.+)folder_id").
- WithArgs(3, 4, 5).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
- AddRow(3, "3").
- AddRow(4, "4").
- AddRow(5, "5"),
- )
- files, err := GetChildFilesOfFolders(&testFolder)
- asserts.NoError(err)
- asserts.Len(files, 3)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestGetUploadPlaceholderFiles(t *testing.T) {
- a := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)upload_session_id(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
- files := GetUploadPlaceholderFiles(1)
- a.NoError(mock.ExpectationsWereMet())
- a.Len(files, 1)
-}
-
-func TestFile_GetPolicy(t *testing.T) {
- asserts := assert.New(t)
-
- // 空策略
- {
- file := File{
- PolicyID: 23,
- }
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(23, "name"),
- )
- file.GetPolicy()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint(23), file.Policy.ID)
- }
-
- // 非空策略
- {
- file := File{
- PolicyID: 23,
- Policy: Policy{Model: gorm.Model{ID: 24}},
- }
- file.GetPolicy()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint(24), file.Policy.ID)
- }
-}
-
-func TestRemoveFilesWithSoftLinks_EmptyArg(t *testing.T) {
- asserts := assert.New(t)
- // 传入空
- {
- mock.ExpectQuery("SELECT(.+)files(.+)")
- file, err := RemoveFilesWithSoftLinks([]File{})
- asserts.Error(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(len(file), 0)
- DB.Find(&File{})
- }
-}
-
-func TestRemoveFilesWithSoftLinks(t *testing.T) {
- asserts := assert.New(t)
- files := []File{
- File{
- Model: gorm.Model{ID: 1},
- SourceName: "1.txt",
- PolicyID: 23,
- },
- File{
- Model: gorm.Model{ID: 2},
- SourceName: "2.txt",
- PolicyID: 24,
- },
- }
-
- // 传入空文件列表
- {
- file, err := RemoveFilesWithSoftLinks([]File{})
- asserts.NoError(err)
- asserts.Empty(file)
- }
-
- // 全都没有
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("1.txt", 23, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("2.txt", 24, 2).
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- file, err := RemoveFilesWithSoftLinks(files)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(files, file)
- }
-
- // 第二个是软链
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("1.txt", 23, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("2.txt", 24, 2).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(3, 24, "2.txt"),
- )
- file, err := RemoveFilesWithSoftLinks(files)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(files[:1], file)
- }
-
- // 第一个是软链
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("1.txt", 23, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(3, 23, "1.txt"),
- )
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("2.txt", 24, 2).
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- file, err := RemoveFilesWithSoftLinks(files)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(files[1:], file)
- }
- // 全部是软链
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("1.txt", 23, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(3, 23, "1.txt"),
- )
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs("2.txt", 24, 2).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(3, 24, "2.txt"),
- )
- file, err := RemoveFilesWithSoftLinks(files)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(file, 0)
- }
-}
-
-func TestDeleteFiles(t *testing.T) {
- a := assert.New(t)
-
- // uid 不一致
- {
- err := DeleteFiles([]*File{{UserID: 2}}, 1)
- a.Contains("user id not consistent", err.Error())
- }
-
- // 删除失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := DeleteFiles([]*File{{UserID: 1}}, 1)
- a.NoError(mock.ExpectationsWereMet())
- a.Error(err)
- }
-
- // 无法变更用户容量
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := DeleteFiles([]*File{{UserID: 1}}, 1)
- a.NoError(mock.ExpectationsWereMet())
- a.Error(err)
- }
-
- // 文件脏读
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 0))
- mock.ExpectRollback()
- err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1)
- a.NoError(mock.ExpectationsWereMet())
- a.Error(err)
- a.Contains("file size is dirty", err.Error())
- }
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WithArgs(uint64(3), sqlmock.AnyArg(), uint(1)).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 1)
- a.NoError(mock.ExpectationsWereMet())
- a.NoError(err)
- }
-
- // 成功, 关联用户不存在
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
- err := DeleteFiles([]*File{{Size: 1, UserID: 1}, {Size: 2, UserID: 1}}, 0)
- a.NoError(mock.ExpectationsWereMet())
- a.NoError(err)
- }
-}
-
-func TestGetFilesByParentIDs(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 4, 5, 6).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(4, "4.txt").
- AddRow(5, "5.txt").
- AddRow(6, "6.txt"),
- )
- files, err := GetFilesByParentIDs([]uint{4, 5, 6}, 1)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(files, 3)
-}
-
-func TestGetFilesByUploadSession(t *testing.T) {
- a := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, "sessionID").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).AddRow(4, "4.txt"))
- files, err := GetFilesByUploadSession("sessionID", 1)
- a.NoError(err)
- a.NoError(mock.ExpectationsWereMet())
- a.Equal("4.txt", files.Name)
-}
-
-func TestFile_Updates(t *testing.T) {
- asserts := assert.New(t)
- file := File{Model: gorm.Model{ID: 1}}
-
- // rename
- {
- // not reset thumb
- {
- file := File{Model: gorm.Model{ID: 1}}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := file.Rename("newName")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // thumb not available, rename base name only
- {
- file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{
- ThumbStatusMetadataKey: ThumbStatusNotAvailable,
- },
- Metadata: "{}"}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := file.Rename("newName.txt")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(ThumbStatusNotAvailable, file.MetadataSerialized[ThumbStatusMetadataKey])
- }
-
- // thumb not available, rename base name only
- {
- file := File{Model: gorm.Model{ID: 1}, Name: "1.txt", MetadataSerialized: map[string]string{
- ThumbStatusMetadataKey: ThumbStatusNotAvailable,
- }}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").WithArgs("{}", "newName.jpg", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := file.Rename("newName.jpg")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Empty(file.MetadataSerialized[ThumbStatusMetadataKey])
- }
- }
-
- // UpdatePicInfo
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs("1,1", 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := file.UpdatePicInfo("1,1")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // UpdateSourceName
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs("", "newName", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := file.UpdateSourceName("newName")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-}
-
-func TestFile_UpdateSize(t *testing.T) {
- a := assert.New(t)
-
- // 增加成功
- {
- file := File{Size: 10}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 11, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)+(.+)").WithArgs(uint64(1), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.NoError(file.UpdateSize(11))
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 减少成功
- {
- file := File{Size: 10}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.NoError(file.UpdateSize(8))
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 文件更新失败
- {
- file := File{Size: 10}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnError(errors.New("error"))
- mock.ExpectRollback()
-
- a.Error(file.UpdateSize(8))
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 用户容量更新失败
- {
- file := File{Size: 10}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 8, sqlmock.AnyArg(), 10).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)-(.+)").WithArgs(uint64(2), sqlmock.AnyArg()).WillReturnError(errors.New("error"))
- mock.ExpectRollback()
-
- a.Error(file.UpdateSize(8))
- a.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFile_PopChunkToFile(t *testing.T) {
- a := assert.New(t)
- timeNow := time.Now()
- file := File{}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(file.PopChunkToFile(&timeNow, "1,1"))
-}
-
-func TestFile_CanCopy(t *testing.T) {
- a := assert.New(t)
- file := File{}
- a.True(file.CanCopy())
- file.UploadSessionID = &file.Name
- a.False(file.CanCopy())
-}
-
-func TestFile_FileInfoInterface(t *testing.T) {
- asserts := assert.New(t)
- file := File{
- Model: gorm.Model{
- UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC),
- },
- Name: "test_name",
- SourceName: "",
- UserID: 0,
- Size: 10,
- PicInfo: "",
- FolderID: 0,
- PolicyID: 0,
- Policy: Policy{},
- Position: "/test",
- }
-
- name := file.GetName()
- asserts.Equal("test_name", name)
-
- size := file.GetSize()
- asserts.Equal(uint64(10), size)
-
- asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), file.ModTime())
- asserts.False(file.IsDir())
- asserts.Equal("/test", file.GetPosition())
-}
-
-func TestGetFilesByKeywords(t *testing.T) {
- asserts := assert.New(t)
-
- // 未指定用户
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs("k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := GetFilesByKeywords(0, nil, "k1", "k2")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(res, 1)
- }
-
- // 指定用户
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := GetFilesByKeywords(1, nil, "k1", "k2")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(res, 1)
- }
-
- // 指定父目录
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 12, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := GetFilesByKeywords(1, []uint{12}, "k1", "k2")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(res, 1)
- }
-}
-
-func TestFile_CreateOrGetSourceLink(t *testing.T) {
- a := assert.New(t)
- file := &File{}
- file.ID = 1
-
- // 已存在,返回老的 SourceLink
- {
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
- res, err := file.CreateOrGetSourceLink()
- a.NoError(err)
- a.EqualValues(2, res.ID)
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 不存在,插入失败
- {
- expectedErr := errors.New("error")
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr)
- mock.ExpectRollback()
- res, err := file.CreateOrGetSourceLink()
- a.Nil(res)
- a.ErrorIs(err, expectedErr)
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
- res, err := file.CreateOrGetSourceLink()
- a.NoError(err)
- a.EqualValues(2, res.ID)
- a.EqualValues(file.ID, res.File.ID)
- a.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFile_UpdateMetadata(t *testing.T) {
- a := assert.New(t)
- file := &File{}
- file.ID = 1
-
- // 更新失败
- {
- expectedErr := errors.New("error")
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnError(expectedErr)
- mock.ExpectRollback()
- a.ErrorIs(file.UpdateMetadata(map[string]string{"1": "1"}), expectedErr)
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs(sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(file.UpdateMetadata(map[string]string{"1": "1"}))
- a.NoError(mock.ExpectationsWereMet())
- a.Equal("1", file.MetadataSerialized["1"])
- }
-}
-
-func TestFile_ShouldLoadThumb(t *testing.T) {
- a := assert.New(t)
- file := &File{
- MetadataSerialized: map[string]string{},
- }
- file.ID = 1
-
- // 无缩略图
- {
- file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusNotAvailable
- a.False(file.ShouldLoadThumb())
- }
-
- // 有缩略图
- {
- file.MetadataSerialized[ThumbStatusMetadataKey] = ThumbStatusExist
- a.True(file.ShouldLoadThumb())
- }
-}
-
-func TestFile_ThumbFile(t *testing.T) {
- a := assert.New(t)
- file := &File{
- SourceName: "test",
- MetadataSerialized: map[string]string{},
- }
- file.ID = 1
-
- a.Equal("test._thumb", file.ThumbFile())
-}
diff --git a/models/folder.go b/models/folder.go
index 80f712c..1130d6b 100644
--- a/models/folder.go
+++ b/models/folder.go
@@ -16,10 +16,12 @@ type Folder struct {
Name string `gorm:"unique_index:idx_only_one_name"`
ParentID *uint `gorm:"index:parent_id;unique_index:idx_only_one_name"`
OwnerID uint `gorm:"index:owner_id"`
+ PolicyID uint // Webdav下挂载的存储策略ID
// 数据库忽略字段
- Position string `gorm:"-"`
- WebdavDstName string `gorm:"-"`
+ Position string `gorm:"-"`
+ WebdavDstName string `gorm:"-"`
+ InheritPolicyID uint `gorm:"-"` // 从父目录继承而来的policy id,默认值则使用自身的的PolicyID
}
// Create 创建目录
@@ -33,6 +35,13 @@ func (folder *Folder) Create() (uint, error) {
return folder.ID, nil
}
+// GetMountedFolders 列出已挂载存储策略的目录
+func GetMountedFolders(uid uint) []Folder {
+ var folders []Folder
+ DB.Where("owner_id = ? and policy_id <> ?", uid, 0).Find(&folders)
+ return folders
+}
+
// GetChild 返回folder下名为name的子目录,不存在则返回错误
func (folder *Folder) GetChild(name string) (*Folder, error) {
var resFolder Folder
@@ -40,9 +49,14 @@ func (folder *Folder) GetChild(name string) (*Folder, error) {
Where("parent_id = ? AND owner_id = ? AND name = ?", folder.ID, folder.OwnerID, name).
First(&resFolder).Error
- // 将子目录的路径传递下去
+ // 将子目录的路径及存储策略传递下去
if err == nil {
resFolder.Position = path.Join(folder.Position, folder.Name)
+ if folder.PolicyID > 0 {
+ resFolder.InheritPolicyID = folder.PolicyID
+ } else if folder.InheritPolicyID > 0 {
+ resFolder.InheritPolicyID = folder.InheritPolicyID
+ }
}
return &resFolder, err
}
@@ -323,6 +337,11 @@ func (folder *Folder) Rename(new string) error {
return DB.Model(&folder).UpdateColumn("name", new).Error
}
+// Mount 目录挂载
+func (folder *Folder) Mount(new uint) error {
+ return DB.Model(&folder).Update("policy_id", new).Error
+}
+
/*
实现 FileInfo.FileInfo 接口
TODO 测试
diff --git a/models/folder_test.go b/models/folder_test.go
deleted file mode 100644
index 90220ca..0000000
--- a/models/folder_test.go
+++ /dev/null
@@ -1,622 +0,0 @@
-package model
-
-import (
- "errors"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFolder_Create(t *testing.T) {
- asserts := assert.New(t)
- folder := &Folder{
- Name: "new folder",
- }
-
- // 不存在,插入成功
- mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
- fid, err := folder.Create()
- asserts.NoError(err)
- asserts.Equal(uint(5), fid)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 插入失败
- mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- fid, err = folder.Create()
- asserts.NoError(err)
- asserts.Equal(uint(1), fid)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 存在,直接返回
- mock.ExpectQuery("SELECT(.+)folders(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5))
- fid, err = folder.Create()
- asserts.NoError(err)
- asserts.Equal(uint(5), fid)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestFolder_GetChild(t *testing.T) {
- asserts := assert.New(t)
- folder := Folder{
- Model: gorm.Model{ID: 5},
- OwnerID: 1,
- Name: "/",
- }
-
- // 目录存在
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(5, 1, "sub").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "sub"))
- sub, err := folder.GetChild("sub")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(sub.Name, "sub")
- asserts.Equal("/", sub.Position)
- }
-
- // 目录不存在
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(5, 1, "sub").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- sub, err := folder.GetChild("sub")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(uint(0), sub.ID)
-
- }
-}
-
-func TestFolder_GetChildFolder(t *testing.T) {
- asserts := assert.New(t)
- folder := &Folder{
- Model: gorm.Model{
- ID: 1,
- },
- Position: "/123",
- Name: "456",
- }
-
- // 找不到
- mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnError(errors.New("error"))
- files, err := folder.GetChildFolder()
- asserts.Error(err)
- asserts.Len(files, 0)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 找到了
- mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2))
- files, err = folder.GetChildFolder()
- asserts.NoError(err)
- asserts.Len(files, 2)
- asserts.Equal("/123/456", files[0].Position)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestGetRecursiveChildFolderSQLite(t *testing.T) {
- conf.DatabaseConfig.Type = "sqlite"
- asserts := assert.New(t)
-
- // 测试目录结构
- // 1
- // 2 3
- // 4 5 6
-
- // 查询第一层
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(1, "folder1"),
- )
- // 查询第二层
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(2, "folder2").
- AddRow(3, "folder3"),
- )
- // 查询第三层
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(4, "folder4").
- AddRow(5, "folder5").
- AddRow(6, "folder6"),
- )
- // 查询第四层
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 4, 5, 6).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}),
- )
-
- folders, err := GetRecursiveChildFolder([]uint{1}, 1, true)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(folders, 6)
-}
-
-func TestDeleteFolderByIDs(t *testing.T) {
- asserts := assert.New(t)
-
- // 出错
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := DeleteFolderByIDs([]uint{1, 2, 3})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
- err := DeleteFolderByIDs([]uint{1, 2, 3})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-}
-
-func TestGetFoldersByIDs(t *testing.T) {
- asserts := assert.New(t)
-
- // 出错
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 1).
- WillReturnError(errors.New("error"))
- folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Len(folders, 0)
- }
-
- // 部分找到
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
- folders, err := GetFoldersByIDs([]uint{1, 2, 3}, 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(folders, 1)
- }
-}
-
-func TestFolder_MoveOrCopyFileTo(t *testing.T) {
- asserts := assert.New(t)
- // 当前目录
- folder := Folder{
- Model: gorm.Model{ID: 1},
- OwnerID: 1,
- Name: "test",
- }
- // 目标目录
- dstFolder := Folder{
- Model: gorm.Model{ID: 10},
- Name: "dst",
- }
-
- // 复制文件
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(
- 1,
- 2,
- 3,
- 1,
- 1,
- ).WillReturnRows(
- sqlmock.NewRows([]string{"id", "size", "upload_session_id"}).
- AddRow(1, 10, nil).
- AddRow(2, 20, nil).
- AddRow(2, 20, &folder.Name),
- )
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- storage, err := folder.MoveOrCopyFileTo(
- []uint{1, 2, 3},
- &dstFolder,
- true,
- )
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint64(30), storage)
- }
-
- // 复制文件, 检索文件出错
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(
- 1,
- 2,
- 1,
- 1,
- ).WillReturnError(errors.New("error"))
-
- storage, err := folder.MoveOrCopyFileTo(
- []uint{1, 2},
- &dstFolder,
- true,
- )
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint64(0), storage)
- }
-
- // 复制文件,第二个文件插入出错
- {
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(
- 1,
- 2,
- 1,
- 1,
- ).WillReturnRows(
- sqlmock.NewRows([]string{"id", "size"}).
- AddRow(1, 10).
- AddRow(2, 20),
- )
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- storage, err := folder.MoveOrCopyFileTo(
- []uint{1, 2},
- &dstFolder,
- true,
- )
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint64(10), storage)
- }
-
- // 移动文件 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1).
- WillReturnResult(sqlmock.NewResult(1, 2))
- mock.ExpectCommit()
- storage, err := folder.MoveOrCopyFileTo(
- []uint{1, 2},
- &dstFolder,
- false,
- )
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(uint64(0), storage)
- }
-
- // 移动文件 出错
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 1).
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- storage, err := folder.MoveOrCopyFileTo(
- []uint{1, 2},
- &dstFolder,
- false,
- )
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(uint64(0), storage)
- }
-}
-
-func TestFolder_CopyFolderTo(t *testing.T) {
- conf.DatabaseConfig.Type = "mysql"
- asserts := assert.New(t)
- // 父目录
- parFolder := Folder{
- Model: gorm.Model{ID: 9},
- OwnerID: 1,
- }
- // 目标目录
- dstFolder := Folder{
- Model: gorm.Model{ID: 10},
- }
-
- // 测试复制目录结构
- // test(2)(5)
- // 1(3)(6) 2.txt
- // 3(4)(7) 4.txt 5.txt(上传中)
-
- // 正常情况 成功
- {
- // GetRecursiveChildFolder
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
-
- // 复制目录
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1))
- mock.ExpectCommit()
-
- // 查找子文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 4).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name", "folder_id", "size", "upload_session_id"}).
- AddRow(1, "2.txt", 2, 10, nil).
- AddRow(2, "3.txt", 3, 20, nil).
- AddRow(3, "5.txt", 3, 20, &dstFolder.Name),
- )
-
- // 复制子文件
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
- mock.ExpectCommit()
-
- size, err := parFolder.CopyFolderTo(2, &dstFolder)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal(uint64(30), size)
- }
-
- // 递归查询失败
- {
- // GetRecursiveChildFolder
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnError(errors.New("error"))
-
- size, err := parFolder.CopyFolderTo(2, &dstFolder)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(uint64(0), size)
- }
-
- // 父目录ID不存在
- {
- // GetRecursiveChildFolder
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 99))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
-
- // 复制目录
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
-
- size, err := parFolder.CopyFolderTo(2, &dstFolder)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(uint64(0), size)
- }
-
- // 查询子文件失败
- {
- // GetRecursiveChildFolder
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
-
- // 复制目录
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1))
- mock.ExpectCommit()
-
- // 查找子文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 4).
- WillReturnError(errors.New("error"))
-
- size, err := parFolder.CopyFolderTo(2, &dstFolder)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(uint64(0), size)
- }
-
- // 复制文件 一个失败
- {
- // GetRecursiveChildFolder
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(2, 9))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(3, 2))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 3).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}).AddRow(4, 3))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 4).WillReturnRows(sqlmock.NewRows([]string{"id", "parent_id"}))
-
- // 复制目录
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(6, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(7, 1))
- mock.ExpectCommit()
-
- // 查找子文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3, 4).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name", "folder_id", "size"}).
- AddRow(1, "2.txt", 2, 10).
- AddRow(2, "3.txt", 3, 20),
- )
-
- // 复制子文件
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
-
- size, err := parFolder.CopyFolderTo(2, &dstFolder)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(uint64(10), size)
- }
-
-}
-
-func TestFolder_MoveOrCopyFolderTo_Move(t *testing.T) {
- conf.DatabaseConfig.Type = "mysql"
- asserts := assert.New(t)
- // 父目录
- parFolder := Folder{
- Model: gorm.Model{ID: 9},
- OwnerID: 1,
- }
- // 目标目录
- dstFolder := Folder{
- Model: gorm.Model{ID: 10},
- OwnerID: 1,
- }
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WithArgs(10, sqlmock.AnyArg(), 1, 2, 1, 9).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := parFolder.MoveFolderTo([]uint{1, 2}, &dstFolder)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 移动自己到自己内部,失败
- {
- err := parFolder.MoveFolderTo([]uint{10, 2}, &dstFolder)
- asserts.Error(err)
- }
-}
-
-func TestFolder_FileInfoInterface(t *testing.T) {
- asserts := assert.New(t)
- folder := Folder{
- Model: gorm.Model{
- UpdatedAt: time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC),
- },
- Name: "test_name",
- OwnerID: 0,
- Position: "/test",
- }
-
- name := folder.GetName()
- asserts.Equal("test_name", name)
-
- size := folder.GetSize()
- asserts.Equal(uint64(0), size)
-
- asserts.Equal(time.Date(2019, 12, 21, 12, 40, 0, 0, time.UTC), folder.ModTime())
- asserts.True(folder.IsDir())
- asserts.Equal("/test", folder.GetPosition())
-}
-
-func TestTraceRoot(t *testing.T) {
- asserts := assert.New(t)
- var parentId uint
- parentId = 5
- folder := Folder{
- ParentID: &parentId,
- OwnerID: 1,
- Name: "test_name",
- }
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "/"))
- asserts.NoError(folder.TraceRoot())
- asserts.Equal("/parent", folder.Position)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 出现错误
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(5, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(5, "parent", 1))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 0).
- WillReturnError(errors.New("error"))
- asserts.Error(folder.TraceRoot())
- asserts.Equal("parent", folder.Position)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFolder_Rename(t *testing.T) {
- asserts := assert.New(t)
- folder := Folder{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "test_name",
- OwnerID: 1,
- Position: "/test",
- }
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
- WithArgs("test_name_new", 1).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := folder.Rename("test_name_new")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 出现错误
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
- WithArgs("test_name_new", 1).
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := folder.Rename("test_name_new")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-}
diff --git a/models/group.go b/models/group.go
index 0abf21d..8dc3057 100644
--- a/models/group.go
+++ b/models/group.go
@@ -2,6 +2,7 @@ package model
import (
"encoding/json"
+
"github.com/jinzhu/gorm"
)
@@ -29,11 +30,15 @@ type GroupOption struct {
DecompressSize uint64 `json:"decompress_size,omitempty"`
OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownload bool `json:"share_download,omitempty"`
+ ShareFree bool `json:"share_free,omitempty"`
Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
+ Relocate bool `json:"relocate,omitempty"` // 转移文件
SourceBatchSize int `json:"source_batch,omitempty"`
RedirectedSource bool `json:"redirected_source,omitempty"`
Aria2BatchSize int `json:"aria2_batch,omitempty"`
+ AvailableNodes []uint `json:"available_nodes,omitempty"`
+ SelectNode bool `json:"select_node,omitempty"`
AdvanceDelete bool `json:"advance_delete,omitempty"`
WebDAVProxy bool `json:"webdav_proxy,omitempty"`
}
diff --git a/models/group_test.go b/models/group_test.go
deleted file mode 100644
index 2f487ce..0000000
--- a/models/group_test.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package model
-
-import (
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/jinzhu/gorm"
- "github.com/pkg/errors"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestGetGroupByID(t *testing.T) {
- asserts := assert.New(t)
-
- //找到用户组时
- groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
- AddRow(1, "管理员", "[1]")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
-
- group, err := GetGroupByID(1)
- asserts.NoError(err)
- asserts.Equal(Group{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "管理员",
- Policies: "[1]",
- PolicyList: []uint{1},
- }, group)
-
- //未找到用户时
- mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found"))
- group, err = GetGroupByID(1)
- asserts.Error(err)
- asserts.Equal(Group{}, group)
-}
-
-func TestGroup_AfterFind(t *testing.T) {
- asserts := assert.New(t)
-
- testCase := Group{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "管理员",
- Policies: "[1]",
- }
- err := testCase.AfterFind()
- asserts.NoError(err)
- asserts.Equal(testCase.PolicyList, []uint{1})
-
- testCase.Policies = "[1,2,3,4,5]"
- err = testCase.AfterFind()
- asserts.NoError(err)
- asserts.Equal(testCase.PolicyList, []uint{1, 2, 3, 4, 5})
-
- testCase.Policies = "[1,2,3,4,5"
- err = testCase.AfterFind()
- asserts.Error(err)
-
- testCase.Policies = "[]"
- err = testCase.AfterFind()
- asserts.NoError(err)
- asserts.Equal(testCase.PolicyList, []uint{})
-}
-
-func TestGroup_BeforeSave(t *testing.T) {
- asserts := assert.New(t)
- group := Group{
- PolicyList: []uint{1, 2, 3},
- }
- {
- err := group.BeforeSave()
- asserts.NoError(err)
- asserts.Equal("[1,2,3]", group.Policies)
- }
-
-}
diff --git a/models/migration.go b/models/migration.go
index fad6a76..f86e5e9 100644
--- a/models/migration.go
+++ b/models/migration.go
@@ -40,8 +40,8 @@ func migration() {
DB = DB.Set("gorm:table_options", "ENGINE=InnoDB")
}
- DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
- &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{})
+ DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &StoragePack{}, &Share{},
+ &Task{}, &Download{}, &Tag{}, &Webdav{}, &Order{}, &Redeem{}, &Report{}, &Node{}, &SourceLink{})
// 创建初始存储策略
addDefaultPolicy()
@@ -107,10 +107,13 @@ func addDefaultGroups() {
ArchiveDownload: true,
ArchiveTask: true,
ShareDownload: true,
+ ShareFree: true,
Aria2: true,
+ Relocate: true,
SourceBatchSize: 1000,
Aria2BatchSize: 50,
RedirectedSource: true,
+ SelectNode: true,
AdvanceDelete: true,
},
}
diff --git a/models/migration_test.go b/models/migration_test.go
deleted file mode 100644
index 7c9d673..0000000
--- a/models/migration_test.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package model
-
-import (
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestMigration(t *testing.T) {
- asserts := assert.New(t)
- conf.DatabaseConfig.Type = "sqlite"
- DB, _ = gorm.Open("sqlite", ":memory:")
-
- asserts.NotPanics(func() {
- migration()
- })
- conf.DatabaseConfig.Type = "mysql"
- DB = mockDB
-}
diff --git a/models/node_test.go b/models/node_test.go
deleted file mode 100644
index de1757f..0000000
--- a/models/node_test.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package model
-
-import (
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestGetNodeByID(t *testing.T) {
- a := assert.New(t)
- mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := GetNodeByID(1)
- a.NoError(err)
- a.EqualValues(1, res.ID)
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestGetNodesByStatus(t *testing.T) {
- a := assert.New(t)
- mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(NodeActive))
- res, err := GetNodesByStatus(NodeActive)
- a.NoError(err)
- a.Len(res, 1)
- a.EqualValues(NodeActive, res[0].Status)
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestNode_AfterFind(t *testing.T) {
- a := assert.New(t)
- node := &Node{}
-
- // No aria2 options
- {
- a.NoError(node.AfterFind())
- }
-
- // with aria2 options
- {
- node.Aria2Options = `{"timeout":1}`
- a.NoError(node.AfterFind())
- a.Equal(1, node.Aria2OptionsSerialized.Timeout)
- }
-}
-
-func TestNode_BeforeSave(t *testing.T) {
- a := assert.New(t)
- node := &Node{}
-
- node.Aria2OptionsSerialized.Timeout = 1
- a.NoError(node.BeforeSave())
- a.Contains(node.Aria2Options, "1")
-}
-
-func TestNode_SetStatus(t *testing.T) {
- a := assert.New(t)
- node := &Node{}
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)nodes").WithArgs(NodeActive, sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(node.SetStatus(NodeActive))
- a.Equal(NodeActive, node.Status)
- a.NoError(mock.ExpectationsWereMet())
-}
diff --git a/models/order.go b/models/order.go
new file mode 100755
index 0000000..9a79c24
--- /dev/null
+++ b/models/order.go
@@ -0,0 +1,59 @@
+package model
+
+import (
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/jinzhu/gorm"
+)
+
+const (
+ // PackOrderType 容量包订单
+ PackOrderType = iota
+ // GroupOrderType 用户组订单
+ GroupOrderType
+ // ScoreOrderType 积分充值订单
+ ScoreOrderType
+)
+
+const (
+ // OrderUnpaid 未支付
+ OrderUnpaid = iota
+ // OrderPaid 已支付
+ OrderPaid
+ // OrderCanceled 已取消
+ OrderCanceled
+)
+
+// Order 交易订单
+type Order struct {
+ gorm.Model
+ UserID uint // 创建者ID
+ OrderNo string `gorm:"index:order_number"` // 商户自定义订单编号
+ Type int // 订单类型
+ Method string // 支付类型
+ ProductID int64 // 商品ID
+ Num int // 商品数量
+ Name string // 订单标题
+ Price int // 商品单价
+ Status int // 订单状态
+}
+
+// Create 创建订单记录
+func (order *Order) Create() (uint, error) {
+ if err := DB.Create(order).Error; err != nil {
+ util.Log().Warning("Failed to insert order record: %s", err)
+ return 0, err
+ }
+ return order.ID, nil
+}
+
+// UpdateStatus 更新订单状态
+func (order *Order) UpdateStatus(status int) {
+ DB.Model(order).Update("status", status)
+}
+
+// GetOrderByNo 根据商户订单号查询订单
+func GetOrderByNo(id string) (*Order, error) {
+ var order Order
+ err := DB.Where("order_no = ?", id).First(&order).Error
+ return &order, err
+}
diff --git a/models/policy.go b/models/policy.go
index 11d8e4b..f80b553 100644
--- a/models/policy.go
+++ b/models/policy.go
@@ -3,14 +3,15 @@ package model
import (
"encoding/gob"
"encoding/json"
- "github.com/gofrs/uuid"
- "github.com/samber/lo"
"path"
"path/filepath"
"strconv"
"strings"
"time"
+ "github.com/gofrs/uuid"
+ "github.com/samber/lo"
+
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
@@ -73,6 +74,18 @@ type PolicyOption struct {
ThumbExts []string `json:"thumb_exts,omitempty"`
}
+// thumbSuffix 支持缩略图处理的文件扩展名
+var thumbSuffix = map[string][]string{
+ "local": {},
+ "qiniu": {".psd", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
+ "oss": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
+ "cos": {".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
+ "upyun": {".svg", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".tiff", ".bmp"},
+ "s3": {},
+ "remote": {},
+ "onedrive": {"*"},
+}
+
func init() {
// 注册缓存用到的复杂结构
gob.Register(Policy{})
@@ -179,6 +192,17 @@ func (policy *Policy) GenerateFileName(uid uint, origin string) string {
return fileRule
}
+// IsThumbExist 给定文件名,返回此存储策略下是否可能存在缩略图
+func (policy *Policy) IsThumbExist(name string) bool {
+ if list, ok := thumbSuffix[policy.Type]; ok {
+ if len(list) == 1 && list[0] == "*" {
+ return true
+ }
+ return util.ContainsString(list, strings.ToLower(filepath.Ext(name)))
+ }
+ return false
+}
+
// IsDirectlyPreview 返回此策略下文件是否可以直接预览(不需要重定向)
func (policy *Policy) IsDirectlyPreview() bool {
return policy.Type == "local"
diff --git a/models/policy_test.go b/models/policy_test.go
deleted file mode 100644
index f7d4e74..0000000
--- a/models/policy_test.go
+++ /dev/null
@@ -1,269 +0,0 @@
-package model
-
-import (
- "encoding/json"
- "strconv"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestGetPolicyByID(t *testing.T) {
- asserts := assert.New(t)
-
- cache.Deletes([]string{"22", "23"}, "policy_")
- // 缓存未命中
- {
- rows := sqlmock.NewRows([]string{"name", "type", "options"}).
- AddRow("默认存储策略", "local", "{\"od_redirect\":\"123\"}")
- mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows)
- policy, err := GetPolicyByID(uint(22))
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("默认存储策略", policy.Name)
- asserts.Equal("123", policy.OptionsSerialized.OauthRedirect)
-
- rows = sqlmock.NewRows([]string{"name", "type", "options"})
- mock.ExpectQuery("^SELECT(.+)").WillReturnRows(rows)
- policy, err = GetPolicyByID(uint(23))
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-
- // 命中
- {
- policy, err := GetPolicyByID(uint(22))
- asserts.NoError(err)
- asserts.Equal("默认存储策略", policy.Name)
- asserts.Equal("123", policy.OptionsSerialized.OauthRedirect)
-
- }
-
-}
-
-func TestPolicy_BeforeSave(t *testing.T) {
- asserts := assert.New(t)
-
- testPolicy := Policy{
- OptionsSerialized: PolicyOption{
- OauthRedirect: "123",
- },
- }
- expected, _ := json.Marshal(testPolicy.OptionsSerialized)
- err := testPolicy.BeforeSave()
- asserts.NoError(err)
- asserts.Equal(string(expected), testPolicy.Options)
-
-}
-
-func TestPolicy_GeneratePath(t *testing.T) {
- asserts := assert.New(t)
- testPolicy := Policy{}
-
- testPolicy.DirNameRule = "{randomkey16}"
- asserts.Len(testPolicy.GeneratePath(1, "/"), 16)
-
- testPolicy.DirNameRule = "{randomkey8}"
- asserts.Len(testPolicy.GeneratePath(1, "/"), 8)
-
- testPolicy.DirNameRule = "{timestamp}"
- asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.FormatInt(time.Now().Unix(), 10))
-
- testPolicy.DirNameRule = "{uid}"
- asserts.Equal(testPolicy.GeneratePath(1, "/"), strconv.Itoa(int(1)))
-
- testPolicy.DirNameRule = "{datetime}"
- asserts.Len(testPolicy.GeneratePath(1, "/"), 14)
-
- testPolicy.DirNameRule = "{date}"
- asserts.Len(testPolicy.GeneratePath(1, "/"), 8)
-
- testPolicy.DirNameRule = "123{date}ss{datetime}"
- asserts.Len(testPolicy.GeneratePath(1, "/"), 27)
-
- testPolicy.DirNameRule = "/1/{path}/456"
- asserts.Condition(func() (success bool) {
- res := testPolicy.GeneratePath(1, "/23")
- return res == "/1/23/456" || res == "\\1\\23\\456"
- })
-
-}
-
-func TestPolicy_GenerateFileName(t *testing.T) {
- asserts := assert.New(t)
- // 重命名关闭
- {
- testPolicy := Policy{
- AutoRename: false,
- }
- testPolicy.FileNameRule = "{randomkey16}"
- asserts.Equal("123.txt", testPolicy.GenerateFileName(1, "123.txt"))
-
- testPolicy.Type = "oss"
- asserts.Equal("origin", testPolicy.GenerateFileName(1, "origin"))
- }
-
- // 重命名开启
- {
- testPolicy := Policy{
- AutoRename: true,
- }
-
- testPolicy.FileNameRule = "{randomkey16}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16)
-
- testPolicy.FileNameRule = "{randomkey8}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8)
-
- testPolicy.FileNameRule = "{timestamp}"
- asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.FormatInt(time.Now().Unix(), 10))
-
- testPolicy.FileNameRule = "{uid}"
- asserts.Equal(testPolicy.GenerateFileName(1, "123.txt"), strconv.Itoa(int(1)))
-
- testPolicy.FileNameRule = "{datetime}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 14)
-
- testPolicy.FileNameRule = "{date}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 8)
-
- testPolicy.FileNameRule = "123{date}ss{datetime}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 27)
-
- testPolicy.FileNameRule = "{originname_without_ext}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 3)
-
- testPolicy.FileNameRule = "{originname_without_ext}_{randomkey8}{ext}"
- asserts.Len(testPolicy.GenerateFileName(1, "123.txt"), 16)
-
- // 支持{originname}的策略
- testPolicy.Type = "local"
- testPolicy.FileNameRule = "123{originname}"
- asserts.Equal("123123.txt", testPolicy.GenerateFileName(1, "123.txt"))
-
- testPolicy.Type = "qiniu"
- testPolicy.FileNameRule = "{uid}123{originname}"
- asserts.Equal("1123123.txt", testPolicy.GenerateFileName(1, "123.txt"))
-
- testPolicy.Type = "oss"
- testPolicy.FileNameRule = "{uid}123{originname}"
- asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
-
- testPolicy.Type = "upyun"
- testPolicy.FileNameRule = "{uid}123{originname}"
- asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
-
- testPolicy.Type = "qiniu"
- testPolicy.FileNameRule = "{uid}123{originname}"
- asserts.Equal("1123123321", testPolicy.GenerateFileName(1, "123321"))
-
- testPolicy.Type = "local"
- testPolicy.FileNameRule = "{uid}123{originname}"
- asserts.Equal("1123", testPolicy.GenerateFileName(1, ""))
-
- testPolicy.Type = "local"
- testPolicy.FileNameRule = "{ext}123{uuid}"
- asserts.Contains(testPolicy.GenerateFileName(1, "123.txt"), ".txt123")
- }
-
-}
-
-func TestPolicy_IsDirectlyPreview(t *testing.T) {
- asserts := assert.New(t)
- policy := Policy{Type: "local"}
- asserts.True(policy.IsDirectlyPreview())
- policy.Type = "remote"
- asserts.False(policy.IsDirectlyPreview())
-}
-
-func TestPolicy_ClearCache(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("policy_202", 1, 0)
- policy := Policy{Model: gorm.Model{ID: 202}}
- policy.ClearCache()
- _, ok := cache.Get("policy_202")
- asserts.False(ok)
-}
-
-func TestPolicy_UpdateAccessKey(t *testing.T) {
- asserts := assert.New(t)
- policy := Policy{Model: gorm.Model{ID: 202}}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- policy.AccessKey = "123"
- err := policy.SaveAndClearCache()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
-}
-
-func TestPolicy_Props(t *testing.T) {
- asserts := assert.New(t)
- policy := Policy{Type: "onedrive"}
- policy.OptionsSerialized.PlaceholderWithSize = true
- asserts.False(policy.IsThumbGenerateNeeded())
- asserts.False(policy.IsTransitUpload(4))
- asserts.False(policy.IsTransitUpload(5 * 1024 * 1024))
- asserts.True(policy.CanStructureBeListed())
- asserts.True(policy.IsUploadPlaceholderWithSize())
- policy.Type = "local"
- asserts.True(policy.IsThumbGenerateNeeded())
- asserts.False(policy.CanStructureBeListed())
- asserts.False(policy.IsUploadPlaceholderWithSize())
- policy.Type = "remote"
- asserts.True(policy.IsUploadPlaceholderWithSize())
-}
-
-func TestPolicy_UpdateAccessKeyAndClearCache(t *testing.T) {
- a := assert.New(t)
- cache.Set("policy_1331", Policy{}, 3600)
- p := &Policy{}
- p.ID = 1331
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs("ak", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.NoError(p.UpdateAccessKeyAndClearCache("ak"))
- a.NoError(mock.ExpectationsWereMet())
- _, ok := cache.Get("policy_1331")
- a.False(ok)
-}
-
-func TestPolicy_CouldProxyThumb(t *testing.T) {
- a := assert.New(t)
- p := &Policy{Type: "local"}
-
- // local policy
- {
- a.False(p.CouldProxyThumb())
- }
-
- // feature not enabled
- {
- p.Type = "remote"
- cache.Set("setting_thumb_proxy_enabled", "0", 0)
- a.False(p.CouldProxyThumb())
- }
-
- // list not contain current policy
- {
- p.ID = 2
- cache.Set("setting_thumb_proxy_enabled", "1", 0)
- cache.Set("setting_thumb_proxy_policy", "[1]", 0)
- a.False(p.CouldProxyThumb())
- }
-
- // enabled
- {
- p.ID = 2
- cache.Set("setting_thumb_proxy_enabled", "1", 0)
- cache.Set("setting_thumb_proxy_policy", "[2]", 0)
- a.True(p.CouldProxyThumb())
- }
-
- cache.Deletes([]string{"thumb_proxy_enabled", "thumb_proxy_policy"}, "setting_")
-}
diff --git a/models/redeem.go b/models/redeem.go
new file mode 100755
index 0000000..4f8396d
--- /dev/null
+++ b/models/redeem.go
@@ -0,0 +1,27 @@
+package model
+
+import "github.com/jinzhu/gorm"
+
+// Redeem 兑换码
+type Redeem struct {
+ gorm.Model
+ Type int // 订单类型
+ ProductID int64 // 商品ID
+ Num int // 商品数量
+ Code string `gorm:"size:64,index:redeem_code"` // 兑换码
+ Used bool // 是否已被使用
+}
+
+// GetAvailableRedeem 根据code查找可用兑换码
+func GetAvailableRedeem(code string) (*Redeem, error) {
+ redeem := &Redeem{}
+ result := DB.Where("code = ? and used = ?", code, false).First(redeem)
+ return redeem, result.Error
+}
+
+// Use 设定为已使用状态
+func (redeem *Redeem) Use() {
+ DB.Model(redeem).Updates(map[string]interface{}{
+ "used": true,
+ })
+}
diff --git a/models/report.go b/models/report.go
new file mode 100755
index 0000000..face732
--- /dev/null
+++ b/models/report.go
@@ -0,0 +1,21 @@
+package model
+
+import (
+ "github.com/jinzhu/gorm"
+)
+
+// Report 举报模型
+type Report struct {
+ gorm.Model
+ ShareID uint `gorm:"index:share_id"` // 对应分享ID
+ Reason int // 举报原因
+ Description string // 补充描述
+
+ // 关联模型
+ Share Share `gorm:"save_associations:false:false"`
+}
+
+// Create 创建举报
+func (report *Report) Create() error {
+ return DB.Create(report).Error
+}
diff --git a/models/scripts/init.go b/models/scripts/init.go
index 7c375bf..b772fa9 100644
--- a/models/scripts/init.go
+++ b/models/scripts/init.go
@@ -5,5 +5,6 @@ import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
func Init() {
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
+ invoker.Register("OSSToPlus", UpgradeToPro(0))
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
}
diff --git a/models/scripts/invoker/invoker_test.go b/models/scripts/invoker/invoker_test.go
deleted file mode 100644
index 36651eb..0000000
--- a/models/scripts/invoker/invoker_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package invoker
-
-import (
- "context"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-type TestScript int
-
-func (script TestScript) Run(ctx context.Context) {
-
-}
-
-func TestRunDBScript(t *testing.T) {
- asserts := assert.New(t)
- Register("test", TestScript(0))
-
- // 不存在
- {
- asserts.Error(RunDBScript("else", context.Background()))
- }
-
- // 存在
- {
- asserts.NoError(RunDBScript("test", context.Background()))
- }
-}
-
-func TestListPrefix(t *testing.T) {
- asserts := assert.New(t)
- Register("U1", TestScript(0))
- Register("U2", TestScript(0))
- Register("U3", TestScript(0))
- Register("P1", TestScript(0))
-
- res := ListPrefix("U")
- asserts.Len(res, 3)
-}
diff --git a/models/scripts/reset_test.go b/models/scripts/reset_test.go
deleted file mode 100644
index ffacb28..0000000
--- a/models/scripts/reset_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package scripts
-
-import (
- "context"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestResetAdminPassword_Run(t *testing.T) {
- asserts := assert.New(t)
- script := ResetAdminPassword(0)
-
- // 初始用户不存在
- {
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}))
- asserts.Panics(func() {
- script.Run(context.Background())
- })
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 密码更新失败
- {
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- asserts.Panics(func() {
- script.Run(context.Background())
- })
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- asserts.NotPanics(func() {
- script.Run(context.Background())
- })
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
diff --git a/models/scripts/storage_test.go b/models/scripts/storage_test.go
deleted file mode 100644
index 746f0c0..0000000
--- a/models/scripts/storage_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package scripts
-
-import (
- "context"
- "database/sql"
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-var mock sqlmock.Sqlmock
-var mockDB *gorm.DB
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- mockDB = model.DB
- defer db.Close()
- m.Run()
-}
-
-func TestUserStorageCalibration_Run(t *testing.T) {
- asserts := assert.New(t)
- script := UserStorageCalibration(0)
-
- // 容量异常
- {
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(11))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- script.Run(context.Background())
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 容量正常
- {
- mock.ExpectQuery("SELECT(.+)users(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10))
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- script.Run(context.Background())
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
diff --git a/models/scripts/upgrade-pro.go b/models/scripts/upgrade-pro.go
new file mode 100755
index 0000000..b25df4f
--- /dev/null
+++ b/models/scripts/upgrade-pro.go
@@ -0,0 +1,22 @@
+package scripts
+
+import (
+ "context"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+)
+
+type UpgradeToPro int
+
+// Run 运行脚本从社区版升级至 Pro 版
+func (script UpgradeToPro) Run(ctx context.Context) {
+ // folder.PolicyID 字段设为 0
+ model.DB.Model(model.Folder{}).UpdateColumn("policy_id", 0)
+ // shares.Score 字段设为0
+ model.DB.Model(model.Share{}).UpdateColumn("score", 0)
+ // user 表相关初始字段
+ model.DB.Model(model.User{}).Updates(map[string]interface{}{
+ "score": 0,
+ "previous_group_id": 0,
+ "open_id": "",
+ })
+}
diff --git a/models/scripts/upgrade_test.go b/models/scripts/upgrade_test.go
deleted file mode 100644
index 8f7adba..0000000
--- a/models/scripts/upgrade_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package scripts
-
-import (
- "context"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestUpgradeTo340_Run(t *testing.T) {
- a := assert.New(t)
- script := UpgradeTo340(0)
-
- // skip
- {
- mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
- script.Run(context.Background())
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // node not found
- {
- mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
- mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- script.Run(context.Background())
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // success
- {
- mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
- AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
- AddRow("aria2_interval", "expected_aria2_interval").
- AddRow("aria2_temp_path", "expected_aria2_temp_path").
- AddRow("aria2_token", "expected_aria2_token").
- AddRow("aria2_options", "{}"))
-
- mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- script.Run(context.Background())
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // failed
- {
- mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
- AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
- AddRow("aria2_interval", "expected_aria2_interval").
- AddRow("aria2_temp_path", "expected_aria2_temp_path").
- AddRow("aria2_token", "expected_aria2_token").
- AddRow("aria2_options", "{}"))
-
- mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- script.Run(context.Background())
- a.NoError(mock.ExpectationsWereMet())
- }
-}
diff --git a/models/setting_test.go b/models/setting_test.go
deleted file mode 100644
index 96fc5e0..0000000
--- a/models/setting_test.go
+++ /dev/null
@@ -1,196 +0,0 @@
-package model
-
-import (
- "database/sql"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-var mock sqlmock.Sqlmock
-var mockDB *gorm.DB
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- DB, _ = gorm.Open("mysql", db)
- mockDB = DB
- defer db.Close()
- m.Run()
-}
-
-func TestGetSettingByType(t *testing.T) {
- cache.Store = cache.NewMemoStore()
- asserts := assert.New(t)
-
- //找到设置时
- rows := sqlmock.NewRows([]string{"name", "value", "type"}).
- AddRow("siteName", "Cloudreve", "basic").
- AddRow("siteDes", "Something wonderful", "basic")
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings := GetSettingByType([]string{"basic"})
- asserts.Equal(map[string]string{
- "siteName": "Cloudreve",
- "siteDes": "Something wonderful",
- }, settings)
-
- rows = sqlmock.NewRows([]string{"name", "value", "type"}).
- AddRow("siteName", "Cloudreve", "basic").
- AddRow("siteDes", "Something wonderful", "basic2")
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings = GetSettingByType([]string{"basic", "basic2"})
- asserts.Equal(map[string]string{
- "siteName": "Cloudreve",
- "siteDes": "Something wonderful",
- }, settings)
-
- //找不到
- rows = sqlmock.NewRows([]string{"name", "value", "type"})
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings = GetSettingByType([]string{"basic233"})
- asserts.Equal(map[string]string{}, settings)
-}
-
-func TestGetSettingByNameWithDefault(t *testing.T) {
- a := assert.New(t)
-
- rows := sqlmock.NewRows([]string{"name", "value", "type"})
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings := GetSettingByNameWithDefault("123", "123321")
- a.Equal("123321", settings)
-}
-
-func TestGetSettingByNames(t *testing.T) {
- cache.Store = cache.NewMemoStore()
- asserts := assert.New(t)
-
- //找到设置时
- rows := sqlmock.NewRows([]string{"name", "value", "type"}).
- AddRow("siteName", "Cloudreve", "basic").
- AddRow("siteDes", "Something wonderful", "basic")
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings := GetSettingByNames("siteName", "siteDes")
- asserts.Equal(map[string]string{
- "siteName": "Cloudreve",
- "siteDes": "Something wonderful",
- }, settings)
- asserts.NoError(mock.ExpectationsWereMet())
-
- //找到其中一个设置时
- rows = sqlmock.NewRows([]string{"name", "value", "type"}).
- AddRow("siteName2", "Cloudreve", "basic")
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings = GetSettingByNames("siteName2", "siteDes2333")
- asserts.Equal(map[string]string{
- "siteName2": "Cloudreve",
- }, settings)
- asserts.NoError(mock.ExpectationsWereMet())
-
- //找不到设置时
- rows = sqlmock.NewRows([]string{"name", "value", "type"})
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
- settings = GetSettingByNames("siteName2333", "siteDes2333")
- asserts.Equal(map[string]string{}, settings)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 一个设置命中缓存
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WithArgs("siteDes2").WillReturnRows(sqlmock.NewRows([]string{"name", "value", "type"}).
- AddRow("siteDes2", "Cloudreve2", "basic"))
- settings = GetSettingByNames("siteName", "siteDes2")
- asserts.Equal(map[string]string{
- "siteName": "Cloudreve",
- "siteDes2": "Cloudreve2",
- }, settings)
- asserts.NoError(mock.ExpectationsWereMet())
-
-}
-
-// TestGetSettingByName 测试GetSettingByName
-func TestGetSettingByName(t *testing.T) {
- cache.Store = cache.NewMemoStore()
- asserts := assert.New(t)
-
- //找到设置时
- rows := sqlmock.NewRows([]string{"name", "value", "type"}).
- AddRow("siteName", "Cloudreve", "basic")
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
-
- siteName := GetSettingByName("siteName")
- asserts.Equal("Cloudreve", siteName)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 第二次查询应返回缓存内容
- siteNameCache := GetSettingByName("siteName")
- asserts.Equal("Cloudreve", siteNameCache)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 找不到设置
- rows = sqlmock.NewRows([]string{"name", "value", "type"})
- mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND(.+)$").WillReturnRows(rows)
-
- siteName = GetSettingByName("siteName not exist")
- asserts.Equal("", siteName)
- asserts.NoError(mock.ExpectationsWereMet())
-
-}
-
-func TestIsTrueVal(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.True(IsTrueVal("1"))
- asserts.True(IsTrueVal("true"))
- asserts.False(IsTrueVal("0"))
- asserts.False(IsTrueVal("false"))
-}
-
-func TestGetSiteURL(t *testing.T) {
- asserts := assert.New(t)
-
- // 正常
- {
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- asserts.NoError(err)
-
- mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "https://drive.cloudreve.org"))
- siteURL := GetSiteURL()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("https://drive.cloudreve.org", siteURL.String())
- }
-
- // 失败 返回默认值
- {
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- asserts.NoError(err)
-
- mock.ExpectQuery("SELECT(.+)").WithArgs("siteURL").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, ":][\\/\\]sdf"))
- siteURL := GetSiteURL()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("https://cloudreve.org", siteURL.String())
- }
-}
-
-func TestGetIntSetting(t *testing.T) {
- asserts := assert.New(t)
-
- // 正常
- {
- cache.Set("setting_TestGetIntSetting", "10", 0)
- res := GetIntSetting("TestGetIntSetting", 20)
- asserts.Equal(10, res)
- }
-
- // 使用默认值
- {
- res := GetIntSetting("TestGetIntSetting_2", 20)
- asserts.Equal(20, res)
- }
-
-}
diff --git a/models/share.go b/models/share.go
index 750eb48..94655bb 100644
--- a/models/share.go
+++ b/models/share.go
@@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
+ "math"
"strings"
"time"
@@ -13,6 +14,10 @@ import (
"github.com/jinzhu/gorm"
)
+var (
+ ErrInsufficientCredit = errors.New("积分不足")
+)
+
// Share 分享模型
type Share struct {
gorm.Model
@@ -24,6 +29,7 @@ type Share struct {
Downloads int // 下载数
RemainDownloads int // 剩余下载配额,负值标识无限制
Expires *time.Time // 过期时间,空值表示无过期时间
+ Score int // 每人次下载扣除积分
PreviewEnabled bool // 是否允许直接预览
SourceName string `gorm:"index:source"` // 用于搜索的字段
@@ -135,6 +141,12 @@ func (share *Share) CanBeDownloadBy(user *User) error {
}
return errors.New("your group has no permission to download")
}
+
+ // 需要积分但未登录
+ if share.Score > 0 && user.IsAnonymous() {
+ return errors.New("you must login to download")
+ }
+
return nil
}
@@ -149,9 +161,12 @@ func (share *Share) WasDownloadedBy(user *User, c *gin.Context) (exist bool) {
return exist
}
-// DownloadBy 增加下载次数,匿名用户不会缓存
+// DownloadBy 增加下载次数、检查积分等,匿名用户不会缓存
func (share *Share) DownloadBy(user *User, c *gin.Context) error {
if !share.WasDownloadedBy(user, c) {
+ if err := share.Purchase(user); err != nil {
+ return err
+ }
share.Downloaded()
if !user.IsAnonymous() {
cache.Set(fmt.Sprintf("share_%d_%d", share.ID, user.ID), true,
@@ -163,6 +178,25 @@ func (share *Share) DownloadBy(user *User, c *gin.Context) error {
return nil
}
+// Purchase 使用积分购买分享
+func (share *Share) Purchase(user *User) error {
+ // 不需要付积分
+ if share.Score == 0 || user.Group.OptionsSerialized.ShareFree || user.ID == share.UserID {
+ return nil
+ }
+
+ ok := user.PayScore(share.Score)
+ if !ok {
+ return ErrInsufficientCredit
+ }
+
+ scoreRate := GetIntSetting("share_score_rate", 100)
+ gainedScore := int(math.Ceil(float64(share.Score*scoreRate) / 100))
+ share.Creator().AddScore(gainedScore)
+
+ return nil
+}
+
// Viewed 增加访问次数
func (share *Share) Viewed() {
share.Views++
diff --git a/models/share_test.go b/models/share_test.go
deleted file mode 100644
index b3fdf0a..0000000
--- a/models/share_test.go
+++ /dev/null
@@ -1,321 +0,0 @@
-package model
-
-import (
- "errors"
- "net/http/httptest"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestShare_Create(t *testing.T) {
- asserts := assert.New(t)
- share := Share{UserID: 1}
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
- id, err := share.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues(2, id)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- id, err := share.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.EqualValues(0, id)
- }
-}
-
-func TestGetShareByHashID(t *testing.T) {
- asserts := assert.New(t)
- conf.SystemConfig.HashIDSalt = ""
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res := GetShareByHashID("x9T4")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(res)
- }
-
- // 查询失败
- {
- mock.ExpectQuery("SELECT(.+)").
- WillReturnError(errors.New("error"))
- res := GetShareByHashID("x9T4")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(res)
- }
-
- // ID解码失败
- {
- res := GetShareByHashID("empty")
- asserts.Nil(res)
- }
-
-}
-
-func TestShare_IsAvailable(t *testing.T) {
- asserts := assert.New(t)
-
- // 下载剩余次数为0
- {
- share := Share{}
- asserts.False(share.IsAvailable())
- }
-
- // 时效过期
- {
- expires := time.Unix(10, 10)
- share := Share{
- RemainDownloads: -1,
- Expires: &expires,
- }
- asserts.False(share.IsAvailable())
- }
-
- // 源对象为目录,但不存在
- {
- share := Share{
- RemainDownloads: -1,
- SourceID: 2,
- IsDir: true,
- }
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id"}))
- asserts.False(share.IsAvailable())
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 源对象为目录,存在
- {
- share := Share{
- RemainDownloads: -1,
- SourceID: 2,
- IsDir: false,
- }
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(13))
- asserts.True(share.IsAvailable())
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 用户被封禁
- {
- share := Share{
- RemainDownloads: -1,
- SourceID: 2,
- IsDir: true,
- User: User{Status: Baned},
- }
- asserts.False(share.IsAvailable())
- }
-}
-
-func TestShare_GetCreator(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- share := Share{UserID: 1}
- res := share.Creator()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.EqualValues(1, res.ID)
-}
-
-func TestShare_Source(t *testing.T) {
- asserts := assert.New(t)
-
- // 目录
- {
- share := Share{IsDir: true, SourceID: 3}
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
- asserts.EqualValues(3, share.Source().(*Folder).ID)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 文件
- {
- share := Share{IsDir: false, SourceID: 3}
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
- asserts.EqualValues(3, share.Source().(*File).ID)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestShare_CanBeDownloadBy(t *testing.T) {
- asserts := assert.New(t)
- share := Share{}
-
- // 未登录,无权
- {
- user := &User{
- Group: Group{
- OptionsSerialized: GroupOption{
- ShareDownload: false,
- },
- },
- }
- asserts.Error(share.CanBeDownloadBy(user))
- }
-
- // 已登录,无权
- {
- user := &User{
- Model: gorm.Model{ID: 1},
- Group: Group{
- OptionsSerialized: GroupOption{
- ShareDownload: false,
- },
- },
- }
- asserts.Error(share.CanBeDownloadBy(user))
- }
-
- // 成功
- {
- user := &User{
- Model: gorm.Model{ID: 1},
- Group: Group{
- OptionsSerialized: GroupOption{
- ShareDownload: true,
- },
- },
- }
- asserts.NoError(share.CanBeDownloadBy(user))
- }
-}
-
-func TestShare_WasDownloadedBy(t *testing.T) {
- asserts := assert.New(t)
- share := Share{
- Model: gorm.Model{ID: 1},
- }
-
- // 已登录,已下载
- {
- user := User{
- Model: gorm.Model{
- ID: 1,
- },
- }
- r := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(r)
- cache.Set("share_1_1", true, 0)
- asserts.True(share.WasDownloadedBy(&user, c))
- }
-}
-
-func TestShare_DownloadBy(t *testing.T) {
- asserts := assert.New(t)
- share := Share{
- Model: gorm.Model{ID: 1},
- }
- user := User{
- Model: gorm.Model{
- ID: 1,
- },
- }
- cache.Deletes([]string{"1_1"}, "share_")
- r := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(r)
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- err := share.DownloadBy(&user, c)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- _, ok := cache.Get("share_1_1")
- asserts.True(ok)
-}
-
-func TestShare_Viewed(t *testing.T) {
- asserts := assert.New(t)
- share := Share{}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- share.Viewed()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.EqualValues(1, share.Views)
-}
-
-func TestShare_UpdateAndDelete(t *testing.T) {
- asserts := assert.New(t)
- share := Share{}
-
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := share.Update(map[string]interface{}{"id": 1})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := share.Delete()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := DeleteShareBySourceIDs([]uint{1}, true)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
-}
-
-func TestListShares(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2).AddRow(2))
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1).AddRow(2))
-
- res, total := ListShares(1, 1, 10, "desc", true)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(res, 2)
- asserts.Equal(2, total)
-}
-
-func TestSearchShares(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("", sqlmock.AnyArg(), "%1%2%").
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, total := SearchShares(1, 10, "id", "1 2")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(res, 1)
- asserts.Equal(1, total)
-}
diff --git a/models/source_link_test.go b/models/source_link_test.go
deleted file mode 100644
index d84dc62..0000000
--- a/models/source_link_test.go
+++ /dev/null
@@ -1,52 +0,0 @@
-package model
-
-import (
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestSourceLink_Link(t *testing.T) {
- a := assert.New(t)
- s := &SourceLink{}
- s.ID = 1
-
- // 失败
- {
- s.File.Name = string([]byte{0x7f})
- res, err := s.Link()
- a.Error(err)
- a.Empty(res)
- }
-
- // 成功
- {
- s.File.Name = "filename"
- res, err := s.Link()
- a.NoError(err)
- a.Contains(res, s.Name)
- }
-}
-
-func TestGetSourceLinkByID(t *testing.T) {
- a := assert.New(t)
- mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
- mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
-
- res, err := GetSourceLinkByID(1)
- a.NoError(err)
- a.NotNil(res)
- a.EqualValues(2, res.File.ID)
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestSourceLink_Downloaded(t *testing.T) {
- a := assert.New(t)
- s := &SourceLink{}
- s.ID = 1
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- s.Downloaded()
- a.NoError(mock.ExpectationsWereMet())
-}
diff --git a/models/storage_pack.go b/models/storage_pack.go
new file mode 100755
index 0000000..a18f9c1
--- /dev/null
+++ b/models/storage_pack.go
@@ -0,0 +1,91 @@
+package model
+
+import (
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/jinzhu/gorm"
+ "strconv"
+ "time"
+)
+
+// StoragePack 容量包模型
+type StoragePack struct {
+ // 表字段
+ gorm.Model
+ Name string
+ UserID uint
+ ActiveTime *time.Time
+ ExpiredTime *time.Time `gorm:"index:expired"`
+ Size uint64
+}
+
+// Create 创建容量包
+func (pack *StoragePack) Create() (uint, error) {
+ if err := DB.Create(pack).Error; err != nil {
+ util.Log().Warning("Failed to insert storage pack record: %s", err)
+ return 0, err
+ }
+ return pack.ID, nil
+}
+
+// GetAvailablePackSize 返回给定用户当前可用的容量包总容量
+func (user *User) GetAvailablePackSize() uint64 {
+ var (
+ total uint64
+ firstExpire *time.Time
+ timeNow = time.Now()
+ ttl int64
+ )
+
+ // 尝试从缓存中读取
+ cacheKey := "pack_size_" + strconv.FormatUint(uint64(user.ID), 10)
+ if total, ok := cache.Get(cacheKey); ok {
+ return total.(uint64)
+ }
+
+ // 查找所有有效容量包
+ packs := user.GetAvailableStoragePacks()
+
+ // 计算总容量, 并找到其中最早的过期时间
+ for _, v := range packs {
+ total += v.Size
+ if firstExpire == nil {
+ firstExpire = v.ExpiredTime
+ continue
+ }
+ if v.ExpiredTime != nil && firstExpire.After(*v.ExpiredTime) {
+ firstExpire = v.ExpiredTime
+ }
+ }
+
+ // 用最早的过期时间计算缓存TTL,并写入缓存
+ if firstExpire != nil {
+ ttl = firstExpire.Unix() - timeNow.Unix()
+ if ttl > 0 {
+ _ = cache.Set(cacheKey, total, int(ttl))
+ }
+ }
+
+ return total
+}
+
+// GetAvailableStoragePacks 返回用户可用的容量包
+func (user *User) GetAvailableStoragePacks() []StoragePack {
+ var packs []StoragePack
+ timeNow := time.Now()
+ // 查找所有有效容量包
+ DB.Where("expired_time > ? AND user_id = ?", timeNow, user.ID).Find(&packs)
+ return packs
+}
+
+// GetExpiredStoragePack 获取已过期的容量包
+func GetExpiredStoragePack() []StoragePack {
+ var packs []StoragePack
+ DB.Where("expired_time < ?", time.Now()).Find(&packs)
+ return packs
+}
+
+// Delete 删除容量包
+func (pack *StoragePack) Delete() error {
+ return DB.Delete(&pack).Error
+}
diff --git a/models/tag_test.go b/models/tag_test.go
deleted file mode 100644
index be8d3fb..0000000
--- a/models/tag_test.go
+++ /dev/null
@@ -1,63 +0,0 @@
-package model
-
-import (
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestTag_Create(t *testing.T) {
- asserts := assert.New(t)
- tag := Tag{}
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- id, err := tag.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues(1, id)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- id, err := tag.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.EqualValues(0, id)
- }
-}
-
-func TestDeleteTagByID(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := DeleteTagByID(1, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
-}
-
-func TestGetTagsByUID(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := GetTagsByUID(1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(res, 1)
-}
-
-func TestGetTagsByID(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("tag"))
- res, err := GetTagsByID(1, 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues("tag", res.Name)
-}
diff --git a/models/task_test.go b/models/task_test.go
deleted file mode 100644
index 1ad71c3..0000000
--- a/models/task_test.go
+++ /dev/null
@@ -1,104 +0,0 @@
-package model
-
-import (
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestTask_Create(t *testing.T) {
- asserts := assert.New(t)
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task := Task{Props: "1"}
- id, err := task.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues(1, id)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- task := Task{Props: "1"}
- id, err := task.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.EqualValues(0, id)
- }
-}
-
-func TestTask_SetError(t *testing.T) {
- asserts := assert.New(t)
- task := Task{
- Model: gorm.Model{ID: 1},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- asserts.NoError(task.SetError("error"))
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestTask_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- task := Task{
- Model: gorm.Model{ID: 1},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- asserts.NoError(task.SetStatus(1))
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestTask_SetProgress(t *testing.T) {
- asserts := assert.New(t)
- task := Task{
- Model: gorm.Model{ID: 1},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- asserts.NoError(task.SetProgress(1))
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestGetTasksByID(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := GetTasksByID(1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues(1, res.ID)
-}
-
-func TestListTasks(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5))
-
- res, total := ListTasks(1, 1, 10, "")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.EqualValues(5, total)
- asserts.Len(res, 1)
-}
-
-func TestGetTasksByStatus(t *testing.T) {
- a := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2).
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res := GetTasksByStatus(1, 2)
- a.NoError(mock.ExpectationsWereMet())
- a.Len(res, 1)
-}
diff --git a/models/user.go b/models/user.go
index ff1d6dd..4d1d1a3 100644
--- a/models/user.go
+++ b/models/user.go
@@ -7,6 +7,7 @@ import (
"encoding/hex"
"encoding/json"
"strings"
+ "time"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/jinzhu/gorm"
@@ -28,20 +29,25 @@ const (
type User struct {
// 表字段
gorm.Model
- Email string `gorm:"type:varchar(100);unique_index"`
- Nick string `gorm:"size:50"`
- Password string `json:"-"`
- Status int
- GroupID uint
- Storage uint64
- TwoFactor string
- Avatar string
- Options string `json:"-" gorm:"size:4294967295"`
- Authn string `gorm:"size:4294967295"`
+ Email string `gorm:"type:varchar(100);unique_index"`
+ Nick string `gorm:"size:50"`
+ Password string `json:"-"`
+ Status int
+ GroupID uint
+ Storage uint64
+ OpenID string
+ TwoFactor string
+ Avatar string
+ Options string `json:"-" gorm:"size:4294967295"`
+ Authn string `gorm:"size:4294967295"`
+ Score int
+ PreviousGroupID uint // 初始用户组
+ GroupExpires *time.Time // 用户组过期日期
+ NotifyDate *time.Time // 通知超出配额时的日期
+ Phone string
// 关联模型
- Group Group `gorm:"save_associations:false:false"`
- Policy Policy `gorm:"PRELOAD:false,association_autoupdate:false"`
+ Group Group `gorm:"save_associations:false:false"`
// 数据库忽略字段
OptionsSerialized UserOption `gorm:"-"`
@@ -53,8 +59,9 @@ func init() {
// UserOption 用户个性化配置字段
type UserOption struct {
- ProfileOff bool `json:"profile_off,omitempty"`
- PreferredTheme string `json:"preferred_theme,omitempty"`
+ ProfileOff bool `json:"profile_off,omitempty"`
+ PreferredPolicy uint `json:"preferred_policy,omitempty"`
+ PreferredTheme string `json:"preferred_theme,omitempty"`
}
// Root 获取用户的根目录
@@ -99,6 +106,25 @@ func (user *User) ChangeStorage(tx *gorm.DB, operator string, size uint64) error
return tx.Model(user).Update("storage", gorm.Expr("storage "+operator+" ?", size)).Error
}
+// PayScore 扣除积分,返回是否成功
+func (user *User) PayScore(score int) bool {
+ if score == 0 {
+ return true
+ }
+ if score <= user.Score {
+ user.Score -= score
+ DB.Model(user).Update("score", gorm.Expr("score - ?", score))
+ return true
+ }
+ return false
+}
+
+// AddScore 增加积分
+func (user *User) AddScore(score int) {
+ user.Score += score
+ DB.Model(user).Update("score", gorm.Expr("score + ?", score))
+}
+
// IncreaseStorageWithoutCheck 忽略可用容量,增加用户已用容量
func (user *User) IncreaseStorageWithoutCheck(size uint64) {
if size == 0 {
@@ -111,19 +137,58 @@ func (user *User) IncreaseStorageWithoutCheck(size uint64) {
// GetRemainingCapacity 获取剩余配额
func (user *User) GetRemainingCapacity() uint64 {
- total := user.Group.MaxStorage
+ total := user.Group.MaxStorage + user.GetAvailablePackSize()
if total <= user.Storage {
return 0
}
return total - user.Storage
}
-// GetPolicyID 获取用户当前的存储策略ID
-func (user *User) GetPolicyID(prefer uint) uint {
- if len(user.Group.PolicyList) > 0 {
- return user.Group.PolicyList[0]
+// GetPolicyID 获取给定目录的存储策略, 如果为 nil 则使用默认
+func (user *User) GetPolicyID(folder *Folder) *Policy {
+ if user.IsAnonymous() {
+ return &Policy{Type: "anonymous"}
}
- return 0
+
+ defaultPolicy := uint(1)
+ if len(user.Group.PolicyList) > 0 {
+ defaultPolicy = user.Group.PolicyList[0]
+ }
+
+ if folder != nil {
+ prefer := folder.PolicyID
+ if prefer == 0 && folder.InheritPolicyID > 0 {
+ prefer = folder.InheritPolicyID
+ }
+
+ if prefer > 0 && util.ContainsUint(user.Group.PolicyList, prefer) {
+ defaultPolicy = prefer
+ }
+ }
+
+ p, _ := GetPolicyByID(defaultPolicy)
+ return &p
+}
+
+// GetPolicyByPreference 在可用存储策略中优先获取 preference
+func (user *User) GetPolicyByPreference(preference uint) *Policy {
+ if user.IsAnonymous() {
+ return &Policy{Type: "anonymous"}
+ }
+
+ defaultPolicy := uint(1)
+ if len(user.Group.PolicyList) > 0 {
+ defaultPolicy = user.Group.PolicyList[0]
+ }
+
+ if preference != 0 {
+ if util.ContainsUint(user.Group.PolicyList, preference) {
+ defaultPolicy = preference
+ }
+ }
+
+ p, _ := GetPolicyByID(defaultPolicy)
+ return &p
}
// GetUserByID 用ID获取用户
@@ -183,6 +248,27 @@ func (user *User) AfterCreate(tx *gorm.DB) (err error) {
OwnerID: user.ID,
}
tx.Create(defaultFolder)
+
+ // 创建用户初始文件记录
+ initialFiles := GetSettingByNameFromTx(tx, "initial_files")
+ if initialFiles != "" {
+ initialFileIDs := make([]uint, 0)
+ if err := json.Unmarshal([]byte(initialFiles), &initialFileIDs); err != nil {
+ return err
+ }
+
+ if files, err := GetFilesByIDsFromTX(tx, initialFileIDs, 0); err == nil {
+ for _, file := range files {
+ file.ID = 0
+ file.UserID = user.ID
+ file.FolderID = defaultFolder.ID
+ user.Storage += file.Size
+ tx.Create(&file)
+ }
+ tx.Save(user)
+ }
+ }
+
return err
}
@@ -193,12 +279,10 @@ func (user *User) AfterFind() (err error) {
err = json.Unmarshal([]byte(user.Options), &user.OptionsSerialized)
}
- // 预加载存储策略
- user.Policy, _ = GetPolicyByID(user.GetPolicyID(0))
return err
}
-//SerializeOptions 将序列后的Option写入到数据库字段
+// SerializeOptions 将序列后的Option写入到数据库字段
func (user *User) SerializeOptions() (err error) {
optionsValue, err := json.Marshal(&user.OptionsSerialized)
user.Options = string(optionsValue)
@@ -261,7 +345,6 @@ func (user *User) SetPassword(password string) error {
// NewAnonymousUser 返回一个匿名用户
func NewAnonymousUser() *User {
user := User{}
- user.Policy.Type = "anonymous"
user.Group, _ = GetGroupByID(3)
return &user
}
@@ -271,6 +354,20 @@ func (user *User) IsAnonymous() bool {
return user.ID == 0
}
+// Notified 更新用户容量超额通知日期
+func (user *User) Notified() {
+ if user.NotifyDate == nil {
+ timeNow := time.Now()
+ user.NotifyDate = &timeNow
+ DB.Model(&user).Update("notify_date", user.NotifyDate)
+ }
+}
+
+// ClearNotified 清除用户通知标记
+func (user *User) ClearNotified() {
+ DB.Model(&user).Update("notify_date", nil)
+}
+
// SetStatus 设定用户状态
func (user *User) SetStatus(status int) {
DB.Model(&user).Update("status", status)
@@ -288,3 +385,45 @@ func (user *User) UpdateOptions() error {
}
return user.Update(map[string]interface{}{"options": user.Options})
}
+
+// GetGroupExpiredUsers 获取用户组过期的用户
+func GetGroupExpiredUsers() []User {
+ var users []User
+ DB.Where("group_expires < ? and previous_group_id <> 0", time.Now()).Find(&users)
+ return users
+}
+
+// GetTolerantExpiredUser 获取超过宽容期的用户
+func GetTolerantExpiredUser() []User {
+ var users []User
+ DB.Set("gorm:auto_preload", true).Where("notify_date < ?", time.Now().Add(
+ time.Duration(-GetIntSetting("ban_time", 10))*time.Second),
+ ).Find(&users)
+ return users
+}
+
+// GroupFallback 回退到初始用户组
+func (user *User) GroupFallback() {
+ if user.GroupExpires != nil && user.PreviousGroupID != 0 {
+ user.Group.ID = user.PreviousGroupID
+ DB.Model(&user).Updates(map[string]interface{}{
+ "group_expires": nil,
+ "previous_group_id": 0,
+ "group_id": user.PreviousGroupID,
+ })
+ }
+}
+
+// UpgradeGroup 升级用户组
+func (user *User) UpgradeGroup(id uint, expires *time.Time) error {
+ user.Group.ID = id
+ previousGroupID := user.GroupID
+ if user.PreviousGroupID != 0 && user.GroupID == id {
+ previousGroupID = user.PreviousGroupID
+ }
+ return DB.Model(&user).Updates(map[string]interface{}{
+ "group_expires": expires,
+ "previous_group_id": previousGroupID,
+ "group_id": id,
+ }).Error
+}
diff --git a/models/user_authn_test.go b/models/user_authn_test.go
deleted file mode 100644
index 08a8ce1..0000000
--- a/models/user_authn_test.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package model
-
-import (
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/duo-labs/webauthn/webauthn"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestUser_RegisterAuthn(t *testing.T) {
- asserts := assert.New(t)
- credential := webauthn.Credential{}
- user := User{
- Model: gorm.Model{ID: 1},
- }
-
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- user.RegisterAuthn(&credential)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestUser_WebAuthnCredentials(t *testing.T) {
- asserts := assert.New(t)
- user := User{
- Model: gorm.Model{ID: 1},
- Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`,
- }
- {
- credentials := user.WebAuthnCredentials()
- asserts.Len(credentials, 1)
- }
-}
-
-func TestUser_WebAuthnDisplayName(t *testing.T) {
- asserts := assert.New(t)
- user := User{
- Model: gorm.Model{ID: 1},
- Nick: "123",
- }
- {
- nick := user.WebAuthnDisplayName()
- asserts.Equal("123", nick)
- }
-}
-
-func TestUser_WebAuthnIcon(t *testing.T) {
- asserts := assert.New(t)
- user := User{
- Model: gorm.Model{ID: 1},
- }
- {
- icon := user.WebAuthnIcon()
- asserts.NotEmpty(icon)
- }
-}
-
-func TestUser_WebAuthnID(t *testing.T) {
- asserts := assert.New(t)
- user := User{
- Model: gorm.Model{ID: 1},
- }
- {
- id := user.WebAuthnID()
- asserts.Len(id, 8)
- }
-}
-
-func TestUser_WebAuthnName(t *testing.T) {
- asserts := assert.New(t)
- user := User{
- Model: gorm.Model{ID: 1},
- Email: "abslant@foxmail.com",
- }
- {
- name := user.WebAuthnName()
- asserts.Equal("abslant@foxmail.com", name)
- }
-}
-
-func TestUser_RemoveAuthn(t *testing.T) {
- asserts := assert.New(t)
- user := User{
- Model: gorm.Model{ID: 1},
- Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`,
- }
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- user.RemoveAuthn("123")
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
diff --git a/models/user_test.go b/models/user_test.go
deleted file mode 100644
index a85ddbd..0000000
--- a/models/user_test.go
+++ /dev/null
@@ -1,438 +0,0 @@
-package model
-
-import (
- "encoding/json"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/jinzhu/gorm"
- "github.com/pkg/errors"
- "github.com/stretchr/testify/assert"
-)
-
-func TestGetUserByID(t *testing.T) {
- asserts := assert.New(t)
- cache.Deletes([]string{"1"}, "policy_")
- //找到用户时
- userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}).
- AddRow(1, nil, "admin@cloudreve.org", "{}", 1)
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
-
- groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
- AddRow(1, "管理员", "[1]")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
-
- policyRows := sqlmock.NewRows([]string{"id", "name"}).
- AddRow(1, "默认存储策略")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
-
- user, err := GetUserByID(1)
- asserts.NoError(err)
- asserts.Equal(User{
- Model: gorm.Model{
- ID: 1,
- DeletedAt: nil,
- },
- Email: "admin@cloudreve.org",
- Options: "{}",
- GroupID: 1,
- Group: Group{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "管理员",
- Policies: "[1]",
- PolicyList: []uint{1},
- },
- Policy: Policy{
- Model: gorm.Model{
- ID: 1,
- },
- OptionsSerialized: PolicyOption{
- FileType: []string{},
- },
- Name: "默认存储策略",
- },
- }, user)
-
- //未找到用户时
- mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found"))
- user, err = GetUserByID(1)
- asserts.Error(err)
- asserts.Equal(User{}, user)
-}
-
-func TestGetActiveUserByID(t *testing.T) {
- asserts := assert.New(t)
- cache.Deletes([]string{"1"}, "policy_")
- //找到用户时
- userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}).
- AddRow(1, nil, "admin@cloudreve.org", "{}", 1)
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
-
- groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
- AddRow(1, "管理员", "[1]")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
-
- policyRows := sqlmock.NewRows([]string{"id", "name"}).
- AddRow(1, "默认存储策略")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
-
- user, err := GetActiveUserByID(1)
- asserts.NoError(err)
- asserts.Equal(User{
- Model: gorm.Model{
- ID: 1,
- DeletedAt: nil,
- },
- Email: "admin@cloudreve.org",
- Options: "{}",
- GroupID: 1,
- Group: Group{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "管理员",
- Policies: "[1]",
- PolicyList: []uint{1},
- },
- Policy: Policy{
- Model: gorm.Model{
- ID: 1,
- },
- OptionsSerialized: PolicyOption{
- FileType: []string{},
- },
- Name: "默认存储策略",
- },
- }, user)
-
- //未找到用户时
- mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found"))
- user, err = GetActiveUserByID(1)
- asserts.Error(err)
- asserts.Equal(User{}, user)
-}
-
-func TestUser_SetPassword(t *testing.T) {
- asserts := assert.New(t)
- user := User{}
- err := user.SetPassword("Cause Sega does what nintendon't")
- asserts.NoError(err)
- asserts.NotEmpty(user.Password)
-}
-
-func TestUser_CheckPassword(t *testing.T) {
- asserts := assert.New(t)
- user := User{}
- err := user.SetPassword("Cause Sega does what nintendon't")
- asserts.NoError(err)
-
- //密码正确
- res, err := user.CheckPassword("Cause Sega does what nintendon't")
- asserts.NoError(err)
- asserts.True(res)
-
- //密码错误
- res, err = user.CheckPassword("Cause Sega does what Nintendon't")
- asserts.NoError(err)
- asserts.False(res)
-
- //密码字段为空
- user = User{}
- res, err = user.CheckPassword("Cause Sega does what nintendon't")
- asserts.Error(err)
- asserts.False(res)
-
- // 未知密码类型
- user = User{}
- user.Password = "1:2:3"
- res, err = user.CheckPassword("Cause Sega does what nintendon't")
- asserts.Error(err)
- asserts.False(res)
-
- // V2密码,错误
- user = User{}
- user.Password = "md5:2:3"
- res, err = user.CheckPassword("Cause Sega does what nintendon't")
- asserts.NoError(err)
- asserts.False(res)
-
- // V2密码,正确
- user = User{}
- user.Password = "md5:d8446059f8846a2c111a7f53515665fb:sdshare"
- res, err = user.CheckPassword("admin")
- asserts.NoError(err)
- asserts.True(res)
-
-}
-
-func TestNewUser(t *testing.T) {
- asserts := assert.New(t)
- newUser := NewUser()
- asserts.IsType(User{}, newUser)
- asserts.Empty(newUser.Avatar)
-}
-
-func TestUser_AfterFind(t *testing.T) {
- asserts := assert.New(t)
- cache.Deletes([]string{"0"}, "policy_")
-
- policyRows := sqlmock.NewRows([]string{"id", "name"}).
- AddRow(144, "默认存储策略")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
-
- newUser := NewUser()
- err := newUser.AfterFind()
- err = newUser.BeforeSave()
- expected := UserOption{}
- err = json.Unmarshal([]byte(newUser.Options), &expected)
-
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(expected, newUser.OptionsSerialized)
- asserts.Equal("默认存储策略", newUser.Policy.Name)
-}
-
-func TestUser_BeforeSave(t *testing.T) {
- asserts := assert.New(t)
-
- newUser := NewUser()
- err := newUser.BeforeSave()
- expected, err := json.Marshal(newUser.OptionsSerialized)
-
- asserts.NoError(err)
- asserts.Equal(string(expected), newUser.Options)
-}
-
-func TestUser_GetPolicyID(t *testing.T) {
- asserts := assert.New(t)
-
- newUser := NewUser()
- newUser.Group.PolicyList = []uint{1}
-
- asserts.EqualValues(1, newUser.GetPolicyID(0))
-
- newUser.Group.PolicyList = nil
- asserts.EqualValues(0, newUser.GetPolicyID(0))
-
- newUser.Group.PolicyList = []uint{}
- asserts.EqualValues(0, newUser.GetPolicyID(0))
-}
-
-func TestUser_GetRemainingCapacity(t *testing.T) {
- asserts := assert.New(t)
- newUser := NewUser()
- cache.Set("pack_size_0", uint64(0), 0)
-
- newUser.Group.MaxStorage = 100
- asserts.Equal(uint64(100), newUser.GetRemainingCapacity())
-
- newUser.Group.MaxStorage = 100
- newUser.Storage = 1
- asserts.Equal(uint64(99), newUser.GetRemainingCapacity())
-
- newUser.Group.MaxStorage = 100
- newUser.Storage = 100
- asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
-
- newUser.Group.MaxStorage = 100
- newUser.Storage = 200
- asserts.Equal(uint64(0), newUser.GetRemainingCapacity())
-}
-
-func TestUser_DeductionCapacity(t *testing.T) {
- asserts := assert.New(t)
-
- cache.Deletes([]string{"1"}, "policy_")
- userRows := sqlmock.NewRows([]string{"id", "deleted_at", "storage", "options", "group_id"}).
- AddRow(1, nil, 0, "{}", 1)
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows)
- groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}).
- AddRow(1, "管理员", "[1]")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows)
-
- policyRows := sqlmock.NewRows([]string{"id", "name"}).
- AddRow(1, "默认存储策略")
- mock.ExpectQuery("^SELECT (.+)").WillReturnRows(policyRows)
-
- newUser, err := GetUserByID(1)
- newUser.Group.MaxStorage = 100
- cache.Set("pack_size_1", uint64(0), 0)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- asserts.Equal(false, newUser.IncreaseStorage(101))
- asserts.Equal(uint64(0), newUser.Storage)
-
- asserts.Equal(true, newUser.IncreaseStorage(1))
- asserts.Equal(uint64(1), newUser.Storage)
-
- asserts.Equal(true, newUser.IncreaseStorage(99))
- asserts.Equal(uint64(100), newUser.Storage)
-
- asserts.Equal(false, newUser.IncreaseStorage(1))
- asserts.Equal(uint64(100), newUser.Storage)
-
- asserts.True(newUser.IncreaseStorage(0))
-}
-
-func TestUser_DeductionStorage(t *testing.T) {
- asserts := assert.New(t)
-
- // 减少零
- {
- user := User{Storage: 1}
- asserts.True(user.DeductionStorage(0))
- asserts.Equal(uint64(1), user.Storage)
- }
- // 正常
- {
- user := User{
- Model: gorm.Model{ID: 1},
- Storage: 10,
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs(5, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- asserts.True(user.DeductionStorage(5))
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint64(5), user.Storage)
- }
-
- // 减少的超出可用的
- {
- user := User{
- Model: gorm.Model{ID: 1},
- Storage: 10,
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs(0, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- asserts.False(user.DeductionStorage(20))
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint64(0), user.Storage)
- }
-}
-
-func TestUser_IncreaseStorageWithoutCheck(t *testing.T) {
- asserts := assert.New(t)
-
- // 增加零
- {
- user := User{}
- user.IncreaseStorageWithoutCheck(0)
- asserts.Equal(uint64(0), user.Storage)
- }
-
- // 减少零
- {
- user := User{
- Model: gorm.Model{ID: 1},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs(10, sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- user.IncreaseStorageWithoutCheck(10)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(uint64(10), user.Storage)
- }
-}
-
-func TestGetActiveUserByEmail(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
- _, err := GetActiveUserByEmail("abslant@foxmail.com")
-
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestGetUserByEmail(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"}))
- _, err := GetUserByEmail("abslant@foxmail.com")
-
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestUser_AfterCreate(t *testing.T) {
- asserts := assert.New(t)
- user := User{Model: gorm.Model{ID: 1}}
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := user.AfterCreate(DB)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestUser_Root(t *testing.T) {
- asserts := assert.New(t)
- user := User{Model: gorm.Model{ID: 1}}
-
- // 根目录存在
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "根目录"))
- root, err := user.Root()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal("根目录", root.Name)
- }
-
- // 根目录不存在
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- _, err := user.Root()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-}
-
-func TestNewAnonymousUser(t *testing.T) {
- asserts := assert.New(t)
-
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(3))
- user := NewAnonymousUser()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(user)
- asserts.EqualValues(3, user.Group.ID)
-}
-
-func TestUser_IsAnonymous(t *testing.T) {
- asserts := assert.New(t)
- user := User{}
- asserts.True(user.IsAnonymous())
- user.ID = 1
- asserts.False(user.IsAnonymous())
-}
-
-func TestUser_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- user := User{}
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- user.SetStatus(Baned)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal(Baned, user.Status)
-}
-
-func TestUser_UpdateOptions(t *testing.T) {
- asserts := assert.New(t)
- user := User{}
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- asserts.NoError(user.UpdateOptions())
- asserts.NoError(mock.ExpectationsWereMet())
-}
diff --git a/models/webdav.go b/models/webdav.go
index 0799aee..ee424aa 100644
--- a/models/webdav.go
+++ b/models/webdav.go
@@ -46,3 +46,8 @@ func DeleteWebDAVAccountByID(id, uid uint) {
func UpdateWebDAVAccountByID(id, uid uint, updates map[string]interface{}) {
DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).Updates(updates)
}
+
+// UpdateWebDAVAccountReadonlyByID 根据账户ID和UID更新账户的只读性
+func UpdateWebDAVAccountReadonlyByID(id, uid uint, readonly bool) {
+ DB.Model(&Webdav{Model: gorm.Model{ID: id}, UserID: uid}).UpdateColumn("readonly", readonly)
+}
diff --git a/models/webdav_test.go b/models/webdav_test.go
deleted file mode 100644
index 55a7326..0000000
--- a/models/webdav_test.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package model
-
-import (
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestWebdav_Create(t *testing.T) {
- asserts := assert.New(t)
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task := Webdav{}
- id, err := task.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.EqualValues(1, id)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- task := Webdav{}
- id, err := task.Create()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.EqualValues(0, id)
- }
-}
-
-func TestGetWebdavByPassword(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- _, err := GetWebdavByPassword("e", 1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
-}
-
-func TestListWebDAVAccounts(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- res := ListWebDAVAccounts(1)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(res, 0)
-}
-
-func TestDeleteWebDAVAccountByID(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- DeleteWebDAVAccountByID(1, 1)
- asserts.NoError(mock.ExpectationsWereMet())
-}
diff --git a/pkg/aria2/aria2_test.go b/pkg/aria2/aria2_test.go
deleted file mode 100644
index b6e7092..0000000
--- a/pkg/aria2/aria2_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package aria2
-
-import (
- "database/sql"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestInit(t *testing.T) {
- a := assert.New(t)
- mockPool := &mocks.NodePoolMock{}
- mockPool.On("GetNodeByID", testMock.Anything).Return(nil)
- mockQueue := mq.NewMQ()
-
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- Init(false, mockPool, mockQueue)
- a.NoError(mock.ExpectationsWereMet())
- mockPool.AssertExpectations(t)
-}
-
-func TestTestRPCConnection(t *testing.T) {
- a := assert.New(t)
-
- // url not legal
- {
- res, err := TestRPCConnection(string([]byte{0x7f}), "", 10)
- a.Error(err)
- a.Empty(res.Version)
- }
-
- // rpc failed
- {
- res, err := TestRPCConnection("ws://0.0.0.0", "", 0)
- a.Error(err)
- a.Empty(res.Version)
- }
-}
-
-func TestGetLoadBalancer(t *testing.T) {
- a := assert.New(t)
- a.NotPanics(func() {
- GetLoadBalancer()
- })
-}
diff --git a/pkg/aria2/common/common_test.go b/pkg/aria2/common/common_test.go
deleted file mode 100644
index 7b0f237..0000000
--- a/pkg/aria2/common/common_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package common
-
-import (
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
- "github.com/stretchr/testify/assert"
-)
-
-func TestDummyAria2(t *testing.T) {
- a := assert.New(t)
- d := &DummyAria2{}
-
- a.NoError(d.Init())
-
- res, err := d.CreateTask(&model.Download{}, map[string]interface{}{})
- a.Empty(res)
- a.Error(err)
-
- _, err = d.Status(&model.Download{})
- a.Error(err)
-
- err = d.Cancel(&model.Download{})
- a.Error(err)
-
- err = d.Select(&model.Download{}, []int{})
- a.Error(err)
-
- configRes := d.GetConfig()
- a.NotNil(configRes)
-
- err = d.DeleteTempFile(&model.Download{})
- a.Error(err)
-}
-
-func TestGetStatus(t *testing.T) {
- a := assert.New(t)
-
- a.Equal(GetStatus(rpc.StatusInfo{Status: "complete"}), Complete)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
- BitTorrent: rpc.BitTorrentInfo{Mode: ""}}), Downloading)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
- BitTorrent: rpc.BitTorrentInfo{Mode: "single"},
- TotalLength: "100", CompletedLength: "50"}), Downloading)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "active",
- BitTorrent: rpc.BitTorrentInfo{Mode: "multi"},
- TotalLength: "100", CompletedLength: "100"}), Seeding)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "waiting"}), Ready)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "paused"}), Paused)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "error"}), Error)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "removed"}), Canceled)
- a.Equal(GetStatus(rpc.StatusInfo{Status: "unknown"}), Unknown)
-}
diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go
index 69d14ff..ff3f380 100644
--- a/pkg/aria2/monitor/monitor.go
+++ b/pkg/aria2/monitor/monitor.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
"path/filepath"
"strconv"
"time"
@@ -52,6 +53,7 @@ func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
// Loop 开启监控循环
func (monitor *Monitor) Loop(mqClient mq.MQ) {
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
+ fmt.Println(cluster.Default)
// 首次循环立即更新
interval := 50 * time.Millisecond
@@ -190,6 +192,10 @@ func (monitor *Monitor) ValidateFile() error {
}
defer fs.Recycle()
+ if err := fs.SetPolicyFromPath(monitor.Task.Dst); err != nil {
+ return fmt.Errorf("failed to switch policy to target dir: %w", err)
+ }
+
// 创建上下文环境
file := &fsctx.FileStream{
Size: monitor.Task.TotalSize,
diff --git a/pkg/aria2/monitor/monitor_test.go b/pkg/aria2/monitor/monitor_test.go
deleted file mode 100644
index a6be586..0000000
--- a/pkg/aria2/monitor/monitor_test.go
+++ /dev/null
@@ -1,447 +0,0 @@
-package monitor
-
-import (
- "database/sql"
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestNewMonitor(t *testing.T) {
- a := assert.New(t)
- mockMQ := mq.NewMQ()
-
- // node not available
- {
- mockPool := &mocks.NodePoolMock{}
- mockPool.On("GetNodeByID", uint(1)).Return(nil)
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- task := &model.Download{
- Model: gorm.Model{ID: 1},
- }
- NewMonitor(task, mockPool, mockMQ)
- mockPool.AssertExpectations(t)
- a.NoError(mock.ExpectationsWereMet())
- a.NotEmpty(task.Error)
- }
-
- // success
- {
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
- mockPool := &mocks.NodePoolMock{}
- mockPool.On("GetNodeByID", uint(1)).Return(mockNode)
-
- task := &model.Download{
- Model: gorm.Model{ID: 1},
- }
- NewMonitor(task, mockPool, mockMQ)
- mockNode.AssertExpectations(t)
- mockPool.AssertExpectations(t)
- }
-
-}
-
-func TestMonitor_Loop(t *testing.T) {
- a := assert.New(t)
- mockMQ := mq.NewMQ()
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
- m := &Monitor{
- retried: MAX_RETRY,
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- notifier: mockMQ.Subscribe("test", 1),
- }
-
- // into interval loop
- {
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- m.Loop(mockMQ)
- a.NoError(mock.ExpectationsWereMet())
- a.NotEmpty(m.Task.Error)
- }
-
- // into notifier loop
- {
- m.Task.Error = ""
- mockMQ.Publish("test", mq.Message{})
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- m.Loop(mockMQ)
- a.NoError(mock.ExpectationsWereMet())
- a.NotEmpty(m.Task.Error)
- }
-}
-
-func TestMonitor_UpdateFailedAfterRetry(t *testing.T) {
- a := assert.New(t)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- for i := 0; i < MAX_RETRY; i++ {
- a.False(m.Update())
- }
-
- mockNode.AssertExpectations(t)
- a.True(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- a.NotEmpty(m.Task.Error)
-}
-
-func TestMonitor_UpdateMagentoFollow(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
- FollowedBy: []string{"next"},
- }, nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.False(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- a.Equal("next", m.Task.GID)
- mockAria2.AssertExpectations(t)
-}
-
-func TestMonitor_UpdateFailedToUpdateInfo(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
- mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.True(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- mockAria2.AssertExpectations(t)
- mockNode.AssertExpectations(t)
- a.NotEmpty(m.Task.Error)
-}
-
-func TestMonitor_UpdateCompleted(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
- Status: "complete",
- }, nil)
- mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- mockNode.On("ID").Return(uint(1))
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.True(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- mockAria2.AssertExpectations(t)
- mockNode.AssertExpectations(t)
- a.NotEmpty(m.Task.Error)
-}
-
-func TestMonitor_UpdateError(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
- Status: "error",
- ErrorMessage: "error",
- }, nil)
- mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.True(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- mockAria2.AssertExpectations(t)
- mockNode.AssertExpectations(t)
- a.NotEmpty(m.Task.Error)
-}
-
-func TestMonitor_UpdateActive(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
- Status: "active",
- }, nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.False(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- mockAria2.AssertExpectations(t)
- mockNode.AssertExpectations(t)
-}
-
-func TestMonitor_UpdateRemoved(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
- Status: "removed",
- }, nil)
- mockAria2.On("DeleteTempFile", testMock.Anything).Return(nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.True(m.Update())
- a.Equal(common.Canceled, m.Task.Status)
- a.NoError(mock.ExpectationsWereMet())
- mockAria2.AssertExpectations(t)
- mockNode.AssertExpectations(t)
-}
-
-func TestMonitor_UpdateUnknown(t *testing.T) {
- a := assert.New(t)
- mockAria2 := &mocks.Aria2Mock{}
- mockAria2.On("Status", testMock.Anything).Return(rpc.StatusInfo{
- Status: "unknown",
- }, nil)
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(mockAria2)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- a.True(m.Update())
- a.NoError(mock.ExpectationsWereMet())
- mockAria2.AssertExpectations(t)
- mockNode.AssertExpectations(t)
-}
-
-func TestMonitor_UpdateTaskInfoValidateFailed(t *testing.T) {
- a := assert.New(t)
- status := rpc.StatusInfo{
- Status: "completed",
- TotalLength: "100",
- CompletedLength: "50",
- DownloadSpeed: "20",
- }
- mockNode := &mocks.NodeMock{}
- mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{Model: gorm.Model{ID: 1}},
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- err := m.UpdateTaskInfo(status)
- a.Error(err)
- a.NoError(mock.ExpectationsWereMet())
- mockNode.AssertExpectations(t)
-}
-
-func TestMonitor_ValidateFile(t *testing.T) {
- a := assert.New(t)
- m := &Monitor{
- Task: &model.Download{
- Model: gorm.Model{ID: 1},
- TotalSize: 100,
- },
- }
-
- // failed to create filesystem
- {
- m.Task.User = &model.User{
- Policy: model.Policy{
- Type: "random",
- },
- }
- a.Equal(filesystem.ErrUnknownPolicyType, m.ValidateFile())
- }
-
- // User capacity not enough
- {
- m.Task.User = &model.User{
- Group: model.Group{
- MaxStorage: 99,
- },
- Policy: model.Policy{
- Type: "local",
- },
- }
- a.Equal(filesystem.ErrInsufficientCapacity, m.ValidateFile())
- }
-
- // single file too big
- {
- m.Task.StatusInfo.Files = []rpc.FileInfo{
- {
- Length: "100",
- Selected: "true",
- },
- }
- m.Task.User = &model.User{
- Group: model.Group{
- MaxStorage: 100,
- },
- Policy: model.Policy{
- Type: "local",
- MaxSize: 99,
- },
- }
- a.Equal(filesystem.ErrFileSizeTooBig, m.ValidateFile())
- }
-
- // all pass
- {
- m.Task.StatusInfo.Files = []rpc.FileInfo{
- {
- Length: "100",
- Selected: "true",
- },
- }
- m.Task.User = &model.User{
- Group: model.Group{
- MaxStorage: 100,
- },
- Policy: model.Policy{
- Type: "local",
- MaxSize: 100,
- },
- }
- a.NoError(m.ValidateFile())
- }
-}
-
-func TestMonitor_Complete(t *testing.T) {
- a := assert.New(t)
- mockNode := &mocks.NodeMock{}
- mockNode.On("ID").Return(uint(1))
- mockPool := &mocks.TaskPoolMock{}
- mockPool.On("Submit", testMock.Anything)
- m := &Monitor{
- node: mockNode,
- Task: &model.Download{
- Model: gorm.Model{ID: 1},
- TotalSize: 100,
- UserID: 9414,
- },
- }
- m.Task.StatusInfo.Files = []rpc.FileInfo{
- {
- Length: "100",
- Selected: "true",
- },
- }
-
- mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
-
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)downloads").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- mock.ExpectQuery("SELECT(.+)tasks").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "status"}).AddRow(1, 2, 4))
- mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(9414))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)tasks").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
-
- a.False(m.Complete(mockPool))
- m.Task.StatusInfo.Status = "complete"
- a.True(m.Complete(mockPool))
- a.NoError(mock.ExpectationsWereMet())
- mockNode.AssertExpectations(t)
- mockPool.AssertExpectations(t)
-}
diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go
deleted file mode 100644
index 42c5603..0000000
--- a/pkg/auth/auth_test.go
+++ /dev/null
@@ -1,136 +0,0 @@
-package auth
-
-import (
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/stretchr/testify/assert"
-)
-
-func TestSignURI(t *testing.T) {
- asserts := assert.New(t)
- General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
-
- // 成功
- {
- sign, err := SignURI(General, "/api/v3/something?id=1", 0)
- asserts.NoError(err)
- queries := sign.Query()
- asserts.Equal("1", queries.Get("id"))
- asserts.NotEmpty(queries.Get("sign"))
- }
-
- // URI解码失败
- {
- sign, err := SignURI(General, "://dg.;'f]gh./'", 0)
- asserts.Error(err)
- asserts.Nil(sign)
- }
-}
-
-func TestCheckURI(t *testing.T) {
- asserts := assert.New(t)
- General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
-
- // 成功
- {
- sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10)
- asserts.NoError(err)
- asserts.NoError(CheckURI(General, sign))
- }
-
- // 过期
- {
- sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1)
- asserts.NoError(err)
- asserts.Error(CheckURI(General, sign))
- }
-}
-
-func TestSignRequest(t *testing.T) {
- asserts := assert.New(t)
- General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
-
- // 非上传请求
- {
- req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body."))
- asserts.NoError(err)
- req = SignRequest(General, req, 0)
- asserts.NotEmpty(req.Header["Authorization"])
- }
-
- // 上传请求
- {
- req, err := http.NewRequest(
- "POST",
- "http://127.0.0.1/api/v3/slave/upload",
- strings.NewReader("I am body."),
- )
- asserts.NoError(err)
- req.Header["X-Cr-Policy"] = []string{"I am Policy"}
- req = SignRequest(General, req, 10)
- asserts.NotEmpty(req.Header["Authorization"])
- }
-}
-
-func TestCheckRequest(t *testing.T) {
- asserts := assert.New(t)
- General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
-
- // 缺少请求头
- {
- req, err := http.NewRequest(
- "POST",
- "http://127.0.0.1/api/v3/upload",
- strings.NewReader("I am body."),
- )
- asserts.NoError(err)
- err = CheckRequest(General, req)
- asserts.Error(err)
- asserts.Equal(ErrAuthHeaderMissing, err)
- }
-
- // 非上传请求 验证成功
- {
- req, err := http.NewRequest(
- "POST",
- "http://127.0.0.1/api/v3/upload",
- strings.NewReader("I am body."),
- )
- asserts.NoError(err)
- req = SignRequest(General, req, 0)
- err = CheckRequest(General, req)
- asserts.NoError(err)
- }
-
- // 上传请求 验证成功
- {
- req, err := http.NewRequest(
- "POST",
- "http://127.0.0.1/api/v3/upload",
- strings.NewReader("I am body."),
- )
- asserts.NoError(err)
- req.Header["X-Cr-Policy"] = []string{"I am Policy"}
- req = SignRequest(General, req, 0)
- err = CheckRequest(General, req)
- asserts.NoError(err)
- }
-
- // 非上传请求 失败
- {
- req, err := http.NewRequest(
- "POST",
- "http://127.0.0.1/api/v3/upload",
- strings.NewReader("I am body."),
- )
- asserts.NoError(err)
- req = SignRequest(General, req, 0)
- req.Body = ioutil.NopCloser(strings.NewReader("2333"))
- err = CheckRequest(General, req)
- asserts.Error(err)
- }
-}
diff --git a/pkg/auth/hmac_test.go b/pkg/auth/hmac_test.go
deleted file mode 100644
index 706f617..0000000
--- a/pkg/auth/hmac_test.go
+++ /dev/null
@@ -1,94 +0,0 @@
-package auth
-
-import (
- "database/sql"
- "fmt"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-var mock sqlmock.Sqlmock
-
-func TestMain(m *testing.M) {
- // 设置gin为测试模式
- gin.SetMode(gin.TestMode)
-
- // 初始化sqlmock
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
-
- mockDB, _ := gorm.Open("mysql", db)
- model.DB = mockDB
- defer db.Close()
-
- m.Run()
-}
-
-func TestHMACAuth_Sign(t *testing.T) {
- asserts := assert.New(t)
- auth := HMACAuth{
- SecretKey: []byte(util.RandStringRunes(256)),
- }
-
- asserts.NotEmpty(auth.Sign("content", 0))
-}
-
-func TestHMACAuth_Check(t *testing.T) {
- asserts := assert.New(t)
- auth := HMACAuth{
- SecretKey: []byte(util.RandStringRunes(256)),
- }
-
- // 正常,永不过期
- {
- sign := auth.Sign("content", 0)
- asserts.NoError(auth.Check("content", sign))
- }
-
- // 过期
- {
- sign := auth.Sign("content", 1)
- asserts.Error(auth.Check("content", sign))
- }
-
- // 签名格式错误
- {
- sign := auth.Sign("content", 1)
- asserts.Error(auth.Check("content", sign+":"))
- }
-
- // 过期日期格式错误
- {
- asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed"))
- }
-
- // 签名有误
- {
- asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10)))
- }
-}
-
-func TestInit(t *testing.T) {
- asserts := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
- Init()
- asserts.NoError(mock.ExpectationsWereMet())
-
- // slave模式
- conf.SystemConfig.Mode = "slave"
- asserts.Panics(func() {
- Init()
- })
-}
diff --git a/pkg/authn/auth_test.go b/pkg/authn/auth_test.go
deleted file mode 100644
index 3df60cf..0000000
--- a/pkg/authn/auth_test.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package authn
-
-import (
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/stretchr/testify/assert"
-)
-
-func TestInit(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("setting_siteURL", "http://cloudreve.org", 0)
- cache.Set("setting_siteName", "Cloudreve", 0)
- res, err := NewAuthnInstance()
- asserts.NotNil(res)
- asserts.NoError(err)
-}
diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go
deleted file mode 100644
index 4493bbb..0000000
--- a/pkg/balancer/balancer_test.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package balancer
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestNewBalancer(t *testing.T) {
- a := assert.New(t)
- a.NotNil(NewBalancer(""))
- a.IsType(&RoundRobin{}, NewBalancer("RoundRobin"))
-}
diff --git a/pkg/balancer/roundrobin_test.go b/pkg/balancer/roundrobin_test.go
deleted file mode 100644
index 9cdcc00..0000000
--- a/pkg/balancer/roundrobin_test.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package balancer
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestRoundRobin_NextIndex(t *testing.T) {
- a := assert.New(t)
- r := &RoundRobin{}
- total := 5
- for i := 1; i < total; i++ {
- a.Equal(i, r.NextIndex(total))
- }
- for i := 0; i < total; i++ {
- a.Equal(i, r.NextIndex(total))
- }
-}
-
-func TestRoundRobin_NextPeer(t *testing.T) {
- a := assert.New(t)
- r := &RoundRobin{}
-
- // not slice
- {
- err, _ := r.NextPeer("s")
- a.Equal(ErrInputNotSlice, err)
- }
-
- // no nodes
- {
- err, _ := r.NextPeer([]string{})
- a.Equal(ErrNoAvaliableNode, err)
- }
-
- // pass
- {
- err, res := r.NextPeer([]string{"a"})
- a.NoError(err)
- a.Equal("a", res.(string))
- }
-}
diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go
index 1a3e652..4c86b47 100644
--- a/pkg/cache/driver.go
+++ b/pkg/cache/driver.go
@@ -2,6 +2,7 @@ package cache
import (
"encoding/gob"
+
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
diff --git a/pkg/cache/driver_test.go b/pkg/cache/driver_test.go
deleted file mode 100644
index 41294e2..0000000
--- a/pkg/cache/driver_test.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package cache
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestSet(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.NoError(Set("123", "321", -1))
-}
-
-func TestGet(t *testing.T) {
- asserts := assert.New(t)
- asserts.NoError(Set("123", "321", -1))
-
- value, ok := Get("123")
- asserts.True(ok)
- asserts.Equal("321", value)
-
- value, ok = Get("not_exist")
- asserts.False(ok)
-}
-
-func TestDeletes(t *testing.T) {
- asserts := assert.New(t)
- asserts.NoError(Set("123", "321", -1))
- err := Deletes([]string{"123"}, "")
- asserts.NoError(err)
- _, exist := Get("123")
- asserts.False(exist)
-}
-
-func TestGetSettings(t *testing.T) {
- asserts := assert.New(t)
- asserts.NoError(Set("test_1", "1", -1))
-
- values, missed := GetSettings([]string{"1", "2"}, "test_")
- asserts.Equal(map[string]string{"1": "1"}, values)
- asserts.Equal([]string{"2"}, missed)
-}
-
-func TestSetSettings(t *testing.T) {
- asserts := assert.New(t)
-
- err := SetSettings(map[string]string{"3": "3", "4": "4"}, "test_")
- asserts.NoError(err)
- value1, _ := Get("test_3")
- value2, _ := Get("test_4")
- asserts.Equal("3", value1)
- asserts.Equal("4", value2)
-}
-
-func TestInit(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.NotPanics(func() {
- Init()
- })
-}
-
-func TestInitSlaveOverwrites(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.NotPanics(func() {
- InitSlaveOverwrites()
- })
-}
diff --git a/pkg/cache/memo.go b/pkg/cache/memo.go
index af180d6..f9dcf97 100644
--- a/pkg/cache/memo.go
+++ b/pkg/cache/memo.go
@@ -133,7 +133,13 @@ func (store *MemoStore) Persist(path string) error {
return fmt.Errorf("failed to serialize cache: %s", err)
}
- err = os.WriteFile(path, res, 0644)
+ // err = os.WriteFile(path, res, 0644)
+ file, err := util.CreatNestedFile(path)
+ if err == nil {
+ _, err = file.Write(res)
+ file.Chmod(0644)
+ file.Close()
+ }
return err
}
diff --git a/pkg/cache/memo_test.go b/pkg/cache/memo_test.go
deleted file mode 100644
index be90577..0000000
--- a/pkg/cache/memo_test.go
+++ /dev/null
@@ -1,191 +0,0 @@
-package cache
-
-import (
- "github.com/stretchr/testify/assert"
- "path/filepath"
- "testing"
- "time"
-)
-
-func TestNewMemoStore(t *testing.T) {
- asserts := assert.New(t)
-
- store := NewMemoStore()
- asserts.NotNil(store)
- asserts.NotNil(store.Store)
-}
-
-func TestMemoStore_Set(t *testing.T) {
- asserts := assert.New(t)
-
- store := NewMemoStore()
- err := store.Set("KEY", "vAL", -1)
- asserts.NoError(err)
-
- val, ok := store.Store.Load("KEY")
- asserts.True(ok)
- asserts.Equal("vAL", val.(itemWithTTL).Value)
-}
-
-func TestMemoStore_Get(t *testing.T) {
- asserts := assert.New(t)
- store := NewMemoStore()
-
- // 正常情况
- {
- _ = store.Set("string", "string_val", -1)
- val, ok := store.Get("string")
- asserts.Equal("string_val", val)
- asserts.True(ok)
- }
-
- // Key不存在
- {
- val, ok := store.Get("something")
- asserts.Equal(nil, val)
- asserts.False(ok)
- }
-
- // 存储struct
- {
- type testStruct struct {
- key int
- }
- test := testStruct{key: 233}
- _ = store.Set("struct", test, -1)
- val, ok := store.Get("struct")
- asserts.True(ok)
- res, ok := val.(testStruct)
- asserts.True(ok)
- asserts.Equal(test, res)
- }
-
- // 过期
- {
- _ = store.Set("string", "string_val", 1)
- time.Sleep(time.Duration(2) * time.Second)
- val, ok := store.Get("string")
- asserts.Nil(val)
- asserts.False(ok)
- }
-
-}
-
-func TestMemoStore_Gets(t *testing.T) {
- asserts := assert.New(t)
- store := NewMemoStore()
-
- err := store.Set("1", "1,val", -1)
- err = store.Set("2", "2,val", -1)
- err = store.Set("3", "3,val", -1)
- err = store.Set("4", "4,val", -1)
- asserts.NoError(err)
-
- // 全部命中
- {
- values, miss := store.Gets([]string{"1", "2", "3", "4"}, "")
- asserts.Len(values, 4)
- asserts.Len(miss, 0)
- }
-
- // 命中一半
- {
- values, miss := store.Gets([]string{"1", "2", "9", "10"}, "")
- asserts.Len(values, 2)
- asserts.Equal([]string{"9", "10"}, miss)
- }
-}
-
-func TestMemoStore_Sets(t *testing.T) {
- asserts := assert.New(t)
- store := NewMemoStore()
-
- err := store.Sets(map[string]interface{}{
- "1": "1.val",
- "2": "2.val",
- "3": "3.val",
- "4": "4.val",
- }, "test_")
- asserts.NoError(err)
-
- vals, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_")
- asserts.Len(miss, 0)
- asserts.Equal(map[string]interface{}{
- "1": "1.val",
- "2": "2.val",
- "3": "3.val",
- "4": "4.val",
- }, vals)
-}
-
-func TestMemoStore_Delete(t *testing.T) {
- asserts := assert.New(t)
- store := NewMemoStore()
-
- err := store.Sets(map[string]interface{}{
- "1": "1.val",
- "2": "2.val",
- "3": "3.val",
- "4": "4.val",
- }, "test_")
- asserts.NoError(err)
-
- err = store.Delete([]string{"1", "2"}, "test_")
- asserts.NoError(err)
- values, miss := store.Gets([]string{"1", "2", "3", "4"}, "test_")
- asserts.Equal([]string{"1", "2"}, miss)
- asserts.Equal(map[string]interface{}{"3": "3.val", "4": "4.val"}, values)
-}
-
-func TestMemoStore_GarbageCollect(t *testing.T) {
- asserts := assert.New(t)
- store := NewMemoStore()
- store.Set("test", 1, 1)
- time.Sleep(time.Duration(2000) * time.Millisecond)
- store.GarbageCollect()
- _, ok := store.Get("test")
- asserts.False(ok)
-}
-
-func TestMemoStore_PersistFailed(t *testing.T) {
- a := assert.New(t)
- store := NewMemoStore()
- type testStruct struct{ v string }
- store.Set("test", 1, 0)
- store.Set("test2", testStruct{v: "test"}, 0)
- err := store.Persist(filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed"))
- a.Error(err)
-}
-
-func TestMemoStore_PersistAndRestore(t *testing.T) {
- a := assert.New(t)
- store := NewMemoStore()
- store.Set("test", 1, 0)
- // already expired
- store.Store.Store("test2", itemWithTTL{Value: "test", Expires: 1})
- // expired after persist
- store.Set("test3", 1, 1)
- temp := filepath.Join(t.TempDir(), "TestMemoStore_PersistFailed")
-
- // Persist
- err := store.Persist(temp)
- a.NoError(err)
- a.FileExists(temp)
-
- time.Sleep(2 * time.Second)
- // Restore
- store2 := NewMemoStore()
- err = store2.Restore(temp)
- a.NoError(err)
- test, testOk := store2.Get("test")
- a.EqualValues(1, test)
- a.True(testOk)
- test2, test2Ok := store2.Get("test2")
- a.Nil(test2)
- a.False(test2Ok)
- test3, test3Ok := store2.Get("test3")
- a.Nil(test3)
- a.False(test3Ok)
-
- a.NoFileExists(temp)
-}
diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go
deleted file mode 100644
index c9f1692..0000000
--- a/pkg/cache/redis_test.go
+++ /dev/null
@@ -1,324 +0,0 @@
-package cache
-
-import (
- "errors"
- "fmt"
- "github.com/gomodule/redigo/redis"
- "github.com/rafaeljusto/redigomock"
- "github.com/stretchr/testify/assert"
- "testing"
- "time"
-)
-
-func TestNewRedisStore(t *testing.T) {
- asserts := assert.New(t)
-
- store := NewRedisStore(10, "tcp", "", "", "", "0")
- asserts.NotNil(store)
-
- asserts.Panics(func() {
- store.pool.Dial()
- })
-
- testConn := redigomock.NewConn()
- cmd := testConn.Command("PING").Expect("PONG")
- err := store.pool.TestOnBorrow(testConn, time.Now())
- if testConn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.NoError(err)
-}
-
-func TestRedisStore_Set(t *testing.T) {
- asserts := assert.New(t)
- conn := redigomock.NewConn()
- pool := &redis.Pool{
- Dial: func() (redis.Conn, error) { return conn, nil },
- MaxIdle: 10,
- }
- store := &RedisStore{pool: pool}
-
- // 正常情况
- {
- cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectStringSlice("OK")
- err := store.Set("test", "test val", -1)
- asserts.NoError(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
-
- // 带有TTL
- // 正常情况
- {
- cmd := conn.Command("SETEX", "test", 10, redigomock.NewAnyData()).ExpectStringSlice("OK")
- err := store.Set("test", "test val", 10)
- asserts.NoError(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
-
- // 序列化出错
- {
- value := struct {
- Key string
- }{
- Key: "123",
- }
- err := store.Set("test", value, -1)
- asserts.Error(err)
- }
-
- // 命令执行失败
- {
- conn.Clear()
- cmd := conn.Command("SET", "test", redigomock.NewAnyData()).ExpectError(errors.New("error"))
- err := store.Set("test", "test val", -1)
- asserts.Error(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
- // 获取连接失败
- {
- store.pool = &redis.Pool{
- Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
- MaxIdle: 10,
- }
- err := store.Set("test", "123", -1)
- asserts.Error(err)
- }
-
-}
-
-func TestRedisStore_Get(t *testing.T) {
- asserts := assert.New(t)
- conn := redigomock.NewConn()
- pool := &redis.Pool{
- Dial: func() (redis.Conn, error) { return conn, nil },
- MaxIdle: 10,
- }
- store := &RedisStore{pool: pool}
-
- // 正常情况
- {
- expectVal, _ := serializer("test val")
- cmd := conn.Command("GET", "test").Expect(expectVal)
- val, ok := store.Get("test")
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.True(ok)
- asserts.Equal("test val", val.(string))
- }
-
- // Key不存在
- {
- conn.Clear()
- cmd := conn.Command("GET", "test").Expect(nil)
- val, ok := store.Get("test")
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.False(ok)
- asserts.Nil(val)
- }
- // 解码错误
- {
- conn.Clear()
- cmd := conn.Command("GET", "test").Expect([]byte{0x20})
- val, ok := store.Get("test")
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.False(ok)
- asserts.Nil(val)
- }
- // 获取连接失败
- {
- store.pool = &redis.Pool{
- Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
- MaxIdle: 10,
- }
- val, ok := store.Get("test")
- asserts.False(ok)
- asserts.Nil(val)
- }
-}
-
-func TestRedisStore_Gets(t *testing.T) {
- asserts := assert.New(t)
- conn := redigomock.NewConn()
- pool := &redis.Pool{
- Dial: func() (redis.Conn, error) { return conn, nil },
- MaxIdle: 10,
- }
- store := &RedisStore{pool: pool}
-
- // 全部命中
- {
- conn.Clear()
- value1, _ := serializer("1")
- value2, _ := serializer("2")
- cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice(
- value1, value2)
- res, missed := store.Gets([]string{"1", "2"}, "test_")
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.Len(missed, 0)
- asserts.Len(res, 2)
- asserts.Equal("1", res["1"].(string))
- asserts.Equal("2", res["2"].(string))
- }
-
- // 命中一个
- {
- conn.Clear()
- value2, _ := serializer("2")
- cmd := conn.Command("MGET", "test_1", "test_2").ExpectSlice(
- nil, value2)
- res, missed := store.Gets([]string{"1", "2"}, "test_")
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.Len(missed, 1)
- asserts.Len(res, 1)
- asserts.Equal("1", missed[0])
- asserts.Equal("2", res["2"].(string))
- }
-
- // 命令出错
- {
- conn.Clear()
- cmd := conn.Command("MGET", "test_1", "test_2").ExpectError(errors.New("error"))
- res, missed := store.Gets([]string{"1", "2"}, "test_")
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- asserts.Len(missed, 2)
- asserts.Len(res, 0)
- }
-
- // 连接出错
- {
- conn.Clear()
- store.pool = &redis.Pool{
- Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
- MaxIdle: 10,
- }
- res, missed := store.Gets([]string{"1", "2"}, "test_")
- asserts.Len(missed, 2)
- asserts.Len(res, 0)
- }
-}
-
-func TestRedisStore_Sets(t *testing.T) {
- asserts := assert.New(t)
- conn := redigomock.NewConn()
- pool := &redis.Pool{
- Dial: func() (redis.Conn, error) { return conn, nil },
- MaxIdle: 10,
- }
- store := &RedisStore{pool: pool}
-
- // 正常
- {
- cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK")
- err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_")
- asserts.NoError(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
-
- // 序列化失败
- {
- conn.Clear()
- value := struct {
- Key string
- }{
- Key: "123",
- }
- err := store.Sets(map[string]interface{}{"1": value, "2": "2"}, "test_")
- asserts.Error(err)
- }
-
- // 执行失败
- {
- cmd := conn.Command("MSET", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error"))
- err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_")
- asserts.Error(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
-
- // 连接失败
- {
- conn.Clear()
- store.pool = &redis.Pool{
- Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
- MaxIdle: 10,
- }
- err := store.Sets(map[string]interface{}{"1": "1", "2": "2"}, "test_")
- asserts.Error(err)
- }
-}
-
-func TestRedisStore_Delete(t *testing.T) {
- asserts := assert.New(t)
- conn := redigomock.NewConn()
- pool := &redis.Pool{
- Dial: func() (redis.Conn, error) { return conn, nil },
- MaxIdle: 10,
- }
- store := &RedisStore{pool: pool}
-
- // 正常
- {
- cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectSlice("OK")
- err := store.Delete([]string{"1", "2", "3", "4"}, "test_")
- asserts.NoError(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
-
- // 命令执行失败
- {
- conn.Clear()
- cmd := conn.Command("DEL", redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData(), redigomock.NewAnyData()).ExpectError(errors.New("error"))
- err := store.Delete([]string{"1", "2", "3", "4"}, "test_")
- asserts.Error(err)
- if conn.Stats(cmd) != 1 {
- fmt.Println("Command was not used")
- return
- }
- }
-
- // 连接失败
- {
- conn.Clear()
- store.pool = &redis.Pool{
- Dial: func() (redis.Conn, error) { return nil, errors.New("error") },
- MaxIdle: 10,
- }
- err := store.Delete([]string{"1", "2", "3", "4"}, "test_")
- asserts.Error(err)
- }
-}
diff --git a/pkg/cluster/controller.go b/pkg/cluster/controller.go
index 85fb178..1e8417c 100644
--- a/pkg/cluster/controller.go
+++ b/pkg/cluster/controller.go
@@ -4,6 +4,9 @@ import (
"bytes"
"encoding/gob"
"fmt"
+ "net/url"
+ "sync"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
@@ -12,8 +15,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
- "net/url"
- "sync"
)
var DefaultController Controller
diff --git a/pkg/cluster/controller_test.go b/pkg/cluster/controller_test.go
deleted file mode 100644
index 42d8362..0000000
--- a/pkg/cluster/controller_test.go
+++ /dev/null
@@ -1,385 +0,0 @@
-package cluster
-
-import (
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
-)
-
-func TestInitController(t *testing.T) {
- assert.NotPanics(t, func() {
- InitController()
- })
-}
-
-func TestSlaveController_HandleHeartBeat(t *testing.T) {
- a := assert.New(t)
- c := &slaveController{
- masters: make(map[string]MasterInfo),
- }
-
- // first heart beat
- {
- _, err := c.HandleHeartBeat(&serializer.NodePingReq{
- SiteID: "1",
- Node: &model.Node{},
- })
- a.NoError(err)
-
- _, err = c.HandleHeartBeat(&serializer.NodePingReq{
- SiteID: "2",
- Node: &model.Node{},
- })
- a.NoError(err)
-
- a.Len(c.masters, 2)
- }
-
- // second heart beat, no fresh
- {
- _, err := c.HandleHeartBeat(&serializer.NodePingReq{
- SiteID: "1",
- SiteURL: "http://127.0.0.1",
- Node: &model.Node{},
- })
- a.NoError(err)
- a.Len(c.masters, 2)
- a.Empty(c.masters["1"].URL)
- }
-
- // second heart beat, fresh
- {
- _, err := c.HandleHeartBeat(&serializer.NodePingReq{
- SiteID: "1",
- IsUpdate: true,
- SiteURL: "http://127.0.0.1",
- Node: &model.Node{},
- })
- a.NoError(err)
- a.Len(c.masters, 2)
- a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
- }
-
- // second heart beat, fresh, url illegal
- {
- _, err := c.HandleHeartBeat(&serializer.NodePingReq{
- SiteID: "1",
- IsUpdate: true,
- SiteURL: string([]byte{0x7f}),
- Node: &model.Node{},
- })
- a.Error(err)
- a.Len(c.masters, 2)
- a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
- }
-}
-
-type nodeMock struct {
- testMock.Mock
-}
-
-func (n nodeMock) Init(node *model.Node) {
- n.Called(node)
-}
-
-func (n nodeMock) IsFeatureEnabled(feature string) bool {
- args := n.Called(feature)
- return args.Bool(0)
-}
-
-func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
- n.Called(callback)
-}
-
-func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
- args := n.Called(req)
- return args.Get(0).(*serializer.NodePingResp), args.Error(1)
-}
-
-func (n nodeMock) IsActive() bool {
- args := n.Called()
- return args.Bool(0)
-}
-
-func (n nodeMock) GetAria2Instance() common.Aria2 {
- args := n.Called()
- return args.Get(0).(common.Aria2)
-}
-
-func (n nodeMock) ID() uint {
- args := n.Called()
- return args.Get(0).(uint)
-}
-
-func (n nodeMock) Kill() {
- n.Called()
-}
-
-func (n nodeMock) IsMater() bool {
- args := n.Called()
- return args.Bool(0)
-}
-
-func (n nodeMock) MasterAuthInstance() auth.Auth {
- args := n.Called()
- return args.Get(0).(auth.Auth)
-}
-
-func (n nodeMock) SlaveAuthInstance() auth.Auth {
- args := n.Called()
- return args.Get(0).(auth.Auth)
-}
-
-func (n nodeMock) DBModel() *model.Node {
- args := n.Called()
- return args.Get(0).(*model.Node)
-}
-
-func TestSlaveController_GetAria2Instance(t *testing.T) {
- a := assert.New(t)
- mockNode := &nodeMock{}
- mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Instance: mockNode},
- },
- }
-
- // node node found
- {
- res, err := c.GetAria2Instance("2")
- a.Nil(res)
- a.Equal(ErrMasterNotFound, err)
- }
-
- // node found
- {
- res, err := c.GetAria2Instance("1")
- a.NotNil(res)
- a.NoError(err)
- mockNode.AssertExpectations(t)
- }
-
-}
-
-type requestMock struct {
- testMock.Mock
-}
-
-func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
- return r.Called(method, target, body, opts).Get(0).(*request.Response)
-}
-
-func TestSlaveController_SendNotification(t *testing.T) {
- a := assert.New(t)
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {},
- },
- }
-
- // node not exit
- {
- a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
- }
-
- // gob encode error
- {
- type randomType struct{}
- a.Error(c.SendNotification("1", "", mq.Message{
- Content: randomType{},
- }))
- }
-
- // return none 200
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{StatusCode: http.StatusConflict},
- })
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Client: mockRequest},
- },
- }
- a.Error(c.SendNotification("1", "s1", mq.Message{}))
- mockRequest.AssertExpectations(t)
- }
-
- // master return error
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Client: mockRequest},
- },
- }
- a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
- mockRequest.AssertExpectations(t)
- }
-
- // success
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
- },
- })
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Client: mockRequest},
- },
- }
- a.NoError(c.SendNotification("1", "s3", mq.Message{}))
- mockRequest.AssertExpectations(t)
- }
-}
-
-func TestSlaveController_SubmitTask(t *testing.T) {
- a := assert.New(t)
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {
- jobTracker: map[string]bool{},
- },
- },
- }
-
- // node not exit
- {
- a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
- }
-
- // success
- {
- submitted := false
- a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
- submitted = true
- }))
- a.True(submitted)
- }
-
- // job already submitted
- {
- submitted := false
- a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
- submitted = true
- }))
- a.False(submitted)
- }
-}
-
-func TestSlaveController_GetMasterInfo(t *testing.T) {
- a := assert.New(t)
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {},
- },
- }
-
- // node not exit
- {
- res, err := c.GetMasterInfo("2")
- a.Equal(ErrMasterNotFound, err)
- a.Nil(res)
- }
-
- // success
- {
- res, err := c.GetMasterInfo("1")
- a.NoError(err)
- a.NotNil(res)
- }
-}
-
-func TestSlaveController_GetOneDriveToken(t *testing.T) {
- a := assert.New(t)
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {},
- },
- }
-
- // node not exit
- {
- res, err := c.GetPolicyOauthToken("2", 1)
- a.Equal(ErrMasterNotFound, err)
- a.Empty(res)
- }
-
- // return none 200
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{StatusCode: http.StatusConflict},
- })
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Client: mockRequest},
- },
- }
- res, err := c.GetPolicyOauthToken("1", 1)
- a.Error(err)
- a.Empty(res)
- mockRequest.AssertExpectations(t)
- }
-
- // master return error
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Client: mockRequest},
- },
- }
- res, err := c.GetPolicyOauthToken("1", 1)
- a.Equal(1, err.(serializer.AppError).Code)
- a.Empty(res)
- mockRequest.AssertExpectations(t)
- }
-
- // success
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
- },
- })
- c := &slaveController{
- masters: map[string]MasterInfo{
- "1": {Client: mockRequest},
- },
- }
- res, err := c.GetPolicyOauthToken("1", 1)
- a.NoError(err)
- a.Equal("expected", res)
- mockRequest.AssertExpectations(t)
- }
-
-}
diff --git a/pkg/cluster/master_test.go b/pkg/cluster/master_test.go
deleted file mode 100644
index 7ff07ac..0000000
--- a/pkg/cluster/master_test.go
+++ /dev/null
@@ -1,186 +0,0 @@
-package cluster
-
-import (
- "context"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/stretchr/testify/assert"
- "os"
- "testing"
- "time"
-)
-
-func TestMasterNode_Init(t *testing.T) {
- a := assert.New(t)
- m := &MasterNode{}
- m.Init(&model.Node{Status: model.NodeSuspend})
- a.Equal(model.NodeSuspend, m.DBModel().Status)
- m.Init(&model.Node{Aria2Enabled: true})
-}
-
-func TestMasterNode_DummyMethods(t *testing.T) {
- a := assert.New(t)
- m := &MasterNode{
- Model: &model.Node{},
- }
-
- m.Model.ID = 5
- a.Equal(m.Model.ID, m.ID())
-
- res, err := m.Ping(&serializer.NodePingReq{})
- a.NoError(err)
- a.NotNil(res)
-
- a.True(m.IsActive())
- a.True(m.IsMater())
-
- m.SubscribeStatusChange(func(isActive bool, id uint) {})
-}
-
-func TestMasterNode_IsFeatureEnabled(t *testing.T) {
- a := assert.New(t)
- m := &MasterNode{
- Model: &model.Node{},
- }
-
- a.False(m.IsFeatureEnabled("aria2"))
- a.False(m.IsFeatureEnabled("random"))
- m.Model.Aria2Enabled = true
- a.True(m.IsFeatureEnabled("aria2"))
-}
-
-func TestMasterNode_AuthInstance(t *testing.T) {
- a := assert.New(t)
- m := &MasterNode{
- Model: &model.Node{},
- }
-
- a.NotNil(m.MasterAuthInstance())
- a.NotNil(m.SlaveAuthInstance())
-}
-
-func TestMasterNode_Kill(t *testing.T) {
- m := &MasterNode{
- Model: &model.Node{},
- }
-
- m.Kill()
-
- caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
- m.aria2RPC.Caller = caller
- m.Kill()
-}
-
-func TestMasterNode_GetAria2Instance(t *testing.T) {
- a := assert.New(t)
- m := &MasterNode{
- Model: &model.Node{},
- aria2RPC: rpcService{},
- }
-
- m.aria2RPC.parent = m
-
- a.NotNil(m.GetAria2Instance())
- m.Model.Aria2Enabled = true
- a.NotNil(m.GetAria2Instance())
- m.aria2RPC.Initialized = true
- a.NotNil(m.GetAria2Instance())
-}
-
-func TestRpcService_Init(t *testing.T) {
- a := assert.New(t)
- m := &MasterNode{
- Model: &model.Node{
- Aria2OptionsSerialized: model.Aria2Option{
- Options: "{",
- },
- },
- aria2RPC: rpcService{},
- }
- m.aria2RPC.parent = m
-
- // failed to decode address
- {
- m.Model.Aria2OptionsSerialized.Server = string([]byte{0x7f})
- a.Error(m.aria2RPC.Init())
- }
-
- // failed to decode options
- {
- m.Model.Aria2OptionsSerialized.Server = ""
- a.Error(m.aria2RPC.Init())
- }
-
- // failed to initialized
- {
- m.Model.Aria2OptionsSerialized.Server = ""
- m.Model.Aria2OptionsSerialized.Options = "{}"
- caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
- m.aria2RPC.Caller = caller
- a.Error(m.aria2RPC.Init())
- a.False(m.aria2RPC.Initialized)
- }
-}
-
-func getTestRPCNode() *MasterNode {
- m := &MasterNode{
- Model: &model.Node{
- Aria2OptionsSerialized: model.Aria2Option{},
- },
- aria2RPC: rpcService{
- options: &clientOptions{
- Options: map[string]interface{}{"1": "1"},
- },
- },
- }
- m.aria2RPC.parent = m
- caller, _ := rpc.New(context.Background(), "http://", "", 0, nil)
- m.aria2RPC.Caller = caller
- return m
-}
-
-func TestRpcService_CreateTask(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNode()
-
- res, err := m.aria2RPC.CreateTask(&model.Download{}, map[string]interface{}{"1": "1"})
- a.Error(err)
- a.Empty(res)
-}
-
-func TestRpcService_Status(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNode()
-
- res, err := m.aria2RPC.Status(&model.Download{})
- a.Error(err)
- a.Empty(res)
-}
-
-func TestRpcService_Cancel(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNode()
-
- a.Error(m.aria2RPC.Cancel(&model.Download{}))
-}
-
-func TestRpcService_Select(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNode()
-
- a.NotNil(m.aria2RPC.GetConfig())
- a.Error(m.aria2RPC.Select(&model.Download{}, []int{1, 2, 3}))
-}
-
-func TestRpcService_DeleteTempFile(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNode()
- fdName := "TestRpcService_DeleteTempFile"
- a.NoError(os.Mkdir(fdName, 0644))
-
- a.NoError(m.aria2RPC.DeleteTempFile(&model.Download{Parent: fdName}))
- time.Sleep(500 * time.Millisecond)
- a.False(util.Exists(fdName))
-}
diff --git a/pkg/cluster/node_test.go b/pkg/cluster/node_test.go
deleted file mode 100644
index d817425..0000000
--- a/pkg/cluster/node_test.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package cluster
-
-import (
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestNewNodeFromDBModel(t *testing.T) {
- a := assert.New(t)
- a.IsType(&SlaveNode{}, NewNodeFromDBModel(&model.Node{
- Type: model.SlaveNodeType,
- }))
- a.IsType(&MasterNode{}, NewNodeFromDBModel(&model.Node{
- Type: model.MasterNodeType,
- }))
-}
diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go
index d6704b6..a70186f 100644
--- a/pkg/cluster/pool.go
+++ b/pkg/cluster/pool.go
@@ -4,6 +4,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/samber/lo"
"sync"
)
@@ -15,7 +16,7 @@ var featureGroup = []string{"aria2"}
// Pool 节点池
type Pool interface {
// Returns active node selected by given feature and load balancer
- BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node)
+ BalanceNodeByFeature(feature string, lb balancer.Balancer, available []uint, prefer uint) (error, Node)
// Returns node by ID
GetNodeByID(id uint) Node
@@ -174,11 +175,33 @@ func (pool *NodePool) Delete(id uint) {
}
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
-func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, Node) {
+func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer,
+ available []uint, prefer uint) (error, Node) {
pool.lock.RLock()
defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok {
- err, res := lb.NextPeer(nodes)
+ // Find nodes that are allowed to be used in user group
+ availableNodes := nodes
+ if len(available) > 0 {
+ idHash := make(map[uint]struct{}, len(available))
+ for _, id := range available {
+ idHash[id] = struct{}{}
+ }
+
+ availableNodes = lo.Filter[Node](nodes, func(node Node, index int) bool {
+ _, exist := idHash[node.ID()]
+ return exist
+ })
+ }
+
+ // Return preferred node if exists
+ if preferredNode, found := lo.Find[Node](availableNodes, func(node Node) bool {
+ return node.ID() == prefer
+ }); found {
+ return nil, preferredNode
+ }
+
+ err, res := lb.NextPeer(availableNodes)
if err == nil {
return nil, res.(Node)
}
diff --git a/pkg/cluster/pool_test.go b/pkg/cluster/pool_test.go
deleted file mode 100644
index dde3455..0000000
--- a/pkg/cluster/pool_test.go
+++ /dev/null
@@ -1,161 +0,0 @@
-package cluster
-
-import (
- "database/sql"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/balancer"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestInitFailed(t *testing.T) {
- a := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
- Init()
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestInitSuccess(t *testing.T) {
- a := assert.New(t)
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "aria2_enabled", "type"}).AddRow(1, true, model.MasterNodeType))
- Init()
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestNodePool_GetNodeByID(t *testing.T) {
- a := assert.New(t)
- p := &NodePool{}
- p.Init()
- mockNode := &nodeMock{}
-
- // inactive
- {
- p.inactive[1] = mockNode
- a.Equal(mockNode, p.GetNodeByID(1))
- }
-
- // active
- {
- delete(p.inactive, 1)
- p.active[1] = mockNode
- a.Equal(mockNode, p.GetNodeByID(1))
- }
-}
-
-func TestNodePool_NodeStatusChange(t *testing.T) {
- a := assert.New(t)
- p := &NodePool{}
- n := &MasterNode{Model: &model.Node{}}
- p.Init()
- p.inactive[1] = n
-
- p.nodeStatusChange(true, 1)
- a.Len(p.inactive, 0)
- a.Equal(n, p.active[1])
-
- p.nodeStatusChange(false, 1)
- a.Len(p.active, 0)
- a.Equal(n, p.inactive[1])
-
- p.nodeStatusChange(false, 1)
- a.Len(p.active, 0)
- a.Equal(n, p.inactive[1])
-}
-
-func TestNodePool_Add(t *testing.T) {
- a := assert.New(t)
- p := &NodePool{}
- p.Init()
-
- // new node
- {
- p.Add(&model.Node{})
- a.Len(p.active, 1)
- }
-
- // old node
- {
- p.inactive[0] = p.active[0]
- delete(p.active, 0)
- p.Add(&model.Node{})
- a.Len(p.active, 0)
- a.Len(p.inactive, 1)
- }
-}
-
-func TestNodePool_Delete(t *testing.T) {
- a := assert.New(t)
- p := &NodePool{}
- p.Init()
-
- // active
- {
- mockNode := &nodeMock{}
- mockNode.On("Kill")
- p.active[0] = mockNode
- p.Delete(0)
- a.Len(p.active, 0)
- a.Len(p.inactive, 0)
- mockNode.AssertExpectations(t)
- }
-
- p.Init()
-
- // inactive
- {
- mockNode := &nodeMock{}
- mockNode.On("Kill")
- p.inactive[0] = mockNode
- p.Delete(0)
- a.Len(p.active, 0)
- a.Len(p.inactive, 0)
- mockNode.AssertExpectations(t)
- }
-}
-
-func TestNodePool_BalanceNodeByFeature(t *testing.T) {
- a := assert.New(t)
- p := &NodePool{}
- p.Init()
-
- // success
- {
- p.featureMap["test"] = []Node{&MasterNode{}}
- err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
- a.NoError(err)
- a.Equal(p.featureMap["test"][0], res)
- }
-
- // NoNodes
- {
- p.featureMap["test"] = []Node{}
- err, res := p.BalanceNodeByFeature("test", balancer.NewBalancer("round-robin"))
- a.Error(err)
- a.Nil(res)
- }
-
- // No match feature
- {
- err, res := p.BalanceNodeByFeature("test2", balancer.NewBalancer("round-robin"))
- a.Error(err)
- a.Nil(res)
- }
-}
diff --git a/pkg/cluster/slave_test.go b/pkg/cluster/slave_test.go
deleted file mode 100644
index 1b1510f..0000000
--- a/pkg/cluster/slave_test.go
+++ /dev/null
@@ -1,559 +0,0 @@
-package cluster
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
- "time"
-)
-
-func TestSlaveNode_InitAndKill(t *testing.T) {
- a := assert.New(t)
- n := &SlaveNode{
- callback: func(b bool, u uint) {
-
- },
- }
-
- a.NotPanics(func() {
- n.Init(&model.Node{})
- time.Sleep(time.Millisecond * 500)
- n.Init(&model.Node{})
- n.Kill()
- })
-}
-
-func TestSlaveNode_DummyMethods(t *testing.T) {
- a := assert.New(t)
- m := &SlaveNode{
- Model: &model.Node{},
- }
-
- m.Model.ID = 5
- a.Equal(m.Model.ID, m.ID())
- a.Equal(m.Model.ID, m.DBModel().ID)
-
- a.False(m.IsActive())
- a.False(m.IsMater())
-
- m.SubscribeStatusChange(func(isActive bool, id uint) {})
-}
-
-func TestSlaveNode_IsFeatureEnabled(t *testing.T) {
- a := assert.New(t)
- m := &SlaveNode{
- Model: &model.Node{},
- }
-
- a.False(m.IsFeatureEnabled("aria2"))
- a.False(m.IsFeatureEnabled("random"))
- m.Model.Aria2Enabled = true
- a.True(m.IsFeatureEnabled("aria2"))
-}
-
-func TestSlaveNode_Ping(t *testing.T) {
- a := assert.New(t)
- m := &SlaveNode{
- Model: &model.Node{},
- }
-
- // master return error code
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.Ping(&serializer.NodePingReq{})
- a.Error(err)
- a.Nil(res)
- a.Equal(1, err.(serializer.AppError).Code)
- }
-
- // return unexpected json
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"233\"}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.Ping(&serializer.NodePingReq{})
- a.Error(err)
- a.Nil(res)
- }
-
- // return success
- {
- mockRequest := &requestMock{}
- mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.Ping(&serializer.NodePingReq{})
- a.NoError(err)
- a.NotNil(res)
- }
-}
-
-func TestSlaveNode_GetAria2Instance(t *testing.T) {
- a := assert.New(t)
- m := &SlaveNode{
- Model: &model.Node{},
- }
-
- a.NotNil(m.GetAria2Instance())
- m.Model.Aria2Enabled = true
- a.NotNil(m.GetAria2Instance())
- a.NotNil(m.GetAria2Instance())
-}
-
-func TestSlaveNode_StartPingLoop(t *testing.T) {
- callbackCount := 0
- finishedChan := make(chan struct{})
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 404,
- },
- })
- m := &SlaveNode{
- Active: true,
- Model: &model.Node{},
- callback: func(b bool, u uint) {
- callbackCount++
- if callbackCount == 2 {
- close(finishedChan)
- }
- if callbackCount == 1 {
- mockRequest.AssertExpectations(t)
- mockRequest = requestMock{}
- mockRequest.On("Request", "POST", "heartbeat", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"{}\"}")),
- },
- })
- }
- },
- }
- cache.Set("setting_slave_ping_interval", "0", 0)
- cache.Set("setting_slave_recover_interval", "0", 0)
- cache.Set("setting_slave_node_retry", "1", 0)
-
- m.caller.Client = &mockRequest
- go func() {
- select {
- case <-finishedChan:
- m.Kill()
- }
- }()
- m.StartPingLoop()
- mockRequest.AssertExpectations(t)
-}
-
-func TestSlaveNode_AuthInstance(t *testing.T) {
- a := assert.New(t)
- m := &SlaveNode{
- Model: &model.Node{},
- }
-
- a.NotNil(m.MasterAuthInstance())
- a.NotNil(m.SlaveAuthInstance())
-}
-
-func TestSlaveNode_ChangeStatus(t *testing.T) {
- a := assert.New(t)
- isActive := false
- m := &SlaveNode{
- Model: &model.Node{},
- callback: func(b bool, u uint) {
- isActive = b
- },
- }
-
- a.NotPanics(func() {
- m.changeStatus(false)
- })
- m.changeStatus(true)
- a.True(isActive)
-}
-
-func getTestRPCNodeSlave() *SlaveNode {
- m := &SlaveNode{
- Model: &model.Node{},
- }
- m.caller.parent = m
- return m
-}
-
-func TestSlaveCaller_CreateTask(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNodeSlave()
-
- // master return 404
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 404,
- },
- })
- m.caller.Client = mockRequest
- res, err := m.caller.CreateTask(&model.Download{}, nil)
- a.Empty(res)
- a.Error(err)
- }
-
- // master return error
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.caller.CreateTask(&model.Download{}, nil)
- a.Empty(res)
- a.Error(err)
- }
-
- // master return success
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/task", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.caller.CreateTask(&model.Download{}, nil)
- a.Equal("res", res)
- a.NoError(err)
- }
-}
-
-func TestSlaveCaller_Status(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNodeSlave()
-
- // master return 404
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 404,
- },
- })
- m.caller.Client = mockRequest
- res, err := m.caller.Status(&model.Download{})
- a.Empty(res.Status)
- a.Error(err)
- }
-
- // master return error
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.caller.Status(&model.Download{})
- a.Empty(res.Status)
- a.Error(err)
- }
-
- // master return success
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/status", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"re456456s\"}")),
- },
- })
- m.caller.Client = mockRequest
- res, err := m.caller.Status(&model.Download{})
- a.Empty(res.Status)
- a.NoError(err)
- }
-}
-
-func TestSlaveCaller_Cancel(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNodeSlave()
-
- // master return 404
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 404,
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.Cancel(&model.Download{})
- a.Error(err)
- }
-
- // master return error
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.Cancel(&model.Download{})
- a.Error(err)
- }
-
- // master return success
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/cancel", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.Cancel(&model.Download{})
- a.NoError(err)
- }
-}
-
-func TestSlaveCaller_Select(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNodeSlave()
- m.caller.Init()
- m.caller.GetConfig()
-
- // master return 404
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 404,
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.Select(&model.Download{}, nil)
- a.Error(err)
- }
-
- // master return error
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.Select(&model.Download{}, nil)
- a.Error(err)
- }
-
- // master return success
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/select", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.Select(&model.Download{}, nil)
- a.NoError(err)
- }
-}
-
-func TestSlaveCaller_DeleteTempFile(t *testing.T) {
- a := assert.New(t)
- m := getTestRPCNodeSlave()
- m.caller.Init()
- m.caller.GetConfig()
-
- // master return 404
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 404,
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.DeleteTempFile(&model.Download{})
- a.Error(err)
- }
-
- // master return error
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.DeleteTempFile(&model.Download{})
- a.Error(err)
- }
-
- // master return success
- {
- mockRequest := requestMock{}
- mockRequest.On("Request", "POST", "aria2/delete", testMock.Anything, testMock.Anything).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("{\"data\":\"res\"}")),
- },
- })
- m.caller.Client = mockRequest
- err := m.caller.DeleteTempFile(&model.Download{})
- a.NoError(err)
- }
-}
-
-func TestRemoteCallback(t *testing.T) {
- asserts := assert.New(t)
-
- // 回调成功
- {
- clientMock := requestmock.RequestMock{}
- mockResp, _ := json.Marshal(serializer.Response{Code: 0})
- clientMock.On(
- "Request",
- "POST",
- "http://test/test/url",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
- },
- })
- request.GeneralClient = clientMock
- resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
- asserts.NoError(resp)
- clientMock.AssertExpectations(t)
- }
-
- // 服务端返回业务错误
- {
- clientMock := requestmock.RequestMock{}
- mockResp, _ := json.Marshal(serializer.Response{Code: 401})
- clientMock.On(
- "Request",
- "POST",
- "http://test/test/url",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(bytes.NewReader(mockResp)),
- },
- })
- request.GeneralClient = clientMock
- resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
- asserts.EqualValues(401, resp.(serializer.AppError).Code)
- clientMock.AssertExpectations(t)
- }
-
- // 无法解析回调响应
- {
- clientMock := requestmock.RequestMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test/test/url",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader("mockResp")),
- },
- })
- request.GeneralClient = clientMock
- resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
- asserts.Error(resp)
- clientMock.AssertExpectations(t)
- }
-
- // HTTP状态码非200
- {
- clientMock := requestmock.RequestMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test/test/url",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 404,
- Body: ioutil.NopCloser(strings.NewReader("mockResp")),
- },
- })
- request.GeneralClient = clientMock
- resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
- asserts.Error(resp)
- clientMock.AssertExpectations(t)
- }
-
- // 无法发起回调
- {
- clientMock := requestmock.RequestMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test/test/url",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error"),
- })
- request.GeneralClient = clientMock
- resp := RemoteCallback("http://test/test/url", serializer.UploadCallback{})
- asserts.Error(resp)
- clientMock.AssertExpectations(t)
- }
-}
diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go
index b0a4ea4..942294b 100644
--- a/pkg/conf/conf.go
+++ b/pkg/conf/conf.go
@@ -53,7 +53,7 @@ type slave struct {
type redis struct {
Network string
Server string
- User string
+ User string
Password string
DB string
}
diff --git a/pkg/conf/conf_test.go b/pkg/conf/conf_test.go
deleted file mode 100644
index 6d186ed..0000000
--- a/pkg/conf/conf_test.go
+++ /dev/null
@@ -1,100 +0,0 @@
-package conf
-
-import (
- "io/ioutil"
- "os"
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/stretchr/testify/assert"
-)
-
-// 测试Init日志路径错误
-func TestInitPanic(t *testing.T) {
- asserts := assert.New(t)
-
- // 日志路径不存在时
- asserts.NotPanics(func() {
- Init("not/exist/path/conf.ini")
- })
-
- asserts.True(util.Exists("not/exist/path/conf.ini"))
-
-}
-
-// TestInitDelimiterNotFound 日志路径存在但 Key 格式错误时
-func TestInitDelimiterNotFound(t *testing.T) {
- asserts := assert.New(t)
- testCase := `[Database]
-Type = mysql
-User = root
-Password233root
-Host = 127.0.0.1:3306
-Name = v3
-TablePrefix = v3_`
- err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
- defer func() { err = os.Remove("testConf.ini") }()
- if err != nil {
- panic(err)
- }
- asserts.Panics(func() {
- Init("testConf.ini")
- })
-}
-
-// TestInitNoPanic 日志路径存在且合法时
-func TestInitNoPanic(t *testing.T) {
- asserts := assert.New(t)
- testCase := `
-[System]
-Listen = 3000
-HashIDSalt = 1
-
-[Database]
-Type = mysql
-User = root
-Password = root
-Host = 127.0.0.1:3306
-Name = v3
-TablePrefix = v3_
-
-[OptionOverwrite]
-key=value
-`
- err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
- defer func() { err = os.Remove("testConf.ini") }()
- if err != nil {
- panic(err)
- }
- asserts.NotPanics(func() {
- Init("testConf.ini")
- })
- asserts.Equal(OptionOverwrite["key"], "value")
-}
-
-func TestMapSection(t *testing.T) {
- asserts := assert.New(t)
-
- //正常情况
- testCase := `
-[System]
-Listen = 3000
-HashIDSalt = 1
-
-[Database]
-Type = mysql
-User = root
-Password:root
-Host = 127.0.0.1:3306
-Name = v3
-TablePrefix = v3_`
- err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
- defer func() { err = os.Remove("testConf.ini") }()
- if err != nil {
- panic(err)
- }
- Init("testConf.ini")
- err = mapSection("Database", DatabaseConfig)
- asserts.NoError(err)
-
-}
diff --git a/pkg/conf/version.go b/pkg/conf/version.go
index 6720e8c..fa4bbf3 100644
--- a/pkg/conf/version.go
+++ b/pkg/conf/version.go
@@ -1,16 +1,22 @@
package conf
+// plusVersion 增强版版本号
+const plusVersion = "+1.1"
+
// BackendVersion 当前后端版本号
-var BackendVersion = "3.8.3"
+const BackendVersion = "3.8.3" + plusVersion
+
+// KeyVersion 授权版本号
+const KeyVersion = "3.3.1"
// RequiredDBVersion 与当前版本匹配的数据库版本
-var RequiredDBVersion = "3.8.1"
+const RequiredDBVersion = "3.8.1+1.0-plus"
// RequiredStaticVersion 与当前版本匹配的静态资源版本
-var RequiredStaticVersion = "3.8.3"
+const RequiredStaticVersion = "3.8.3" + plusVersion
-// IsPro 是否为Pro版本
-var IsPro = "false"
+// IsPlus 是否为Plus版本
+const IsPlus = "true"
// LastCommit 最后commit id
-var LastCommit = "a11f819"
+const LastCommit = "88409cc"
diff --git a/pkg/crontab/init.go b/pkg/crontab/init.go
index 5971c2c..3583d31 100644
--- a/pkg/crontab/init.go
+++ b/pkg/crontab/init.go
@@ -23,6 +23,8 @@ func Init() {
// 读取cron日程设置
options := model.GetSettingByNames(
"cron_garbage_collect",
+ "cron_notify_user",
+ "cron_ban_user",
"cron_recycle_upload_session",
)
Cron := cron.New()
@@ -31,6 +33,10 @@ func Init() {
switch k {
case "cron_garbage_collect":
handler = garbageCollect
+ case "cron_notify_user":
+ handler = notifyExpiredVAS
+ case "cron_ban_user":
+ handler = banOverusedUser
case "cron_recycle_upload_session":
handler = uploadSessionCollect
default:
diff --git a/pkg/crontab/vas.go b/pkg/crontab/vas.go
new file mode 100755
index 0000000..7ce6ae9
--- /dev/null
+++ b/pkg/crontab/vas.go
@@ -0,0 +1,83 @@
+package crontab
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/email"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+)
+
+func notifyExpiredVAS() {
+ checkStoragePack()
+ checkUserGroup()
+ util.Log().Info("Crontab job \"cron_notify_user\" complete.")
+}
+
+// banOverusedUser 封禁超出宽容期的用户
+func banOverusedUser() {
+ users := model.GetTolerantExpiredUser()
+ for _, user := range users {
+
+ // 清除最后通知日期标记
+ user.ClearNotified()
+
+ // 检查容量是否超额
+ if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
+ // 封禁用户
+ user.SetStatus(model.OveruseBaned)
+ }
+ }
+}
+
+// checkUserGroup 检查已过期用户组
+func checkUserGroup() {
+ users := model.GetGroupExpiredUsers()
+ for _, user := range users {
+
+ // 将用户回退到初始用户组
+ user.GroupFallback()
+
+ // 重新加载用户
+ user, _ = model.GetUserByID(user.ID)
+
+ // 检查容量是否超额
+ if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
+ // 如果超额,则通知用户
+ sendNotification(&user, "用户组过期")
+ // 更新最后通知日期
+ user.Notified()
+ }
+ }
+}
+
+// checkStoragePack 检查已过期的容量包
+func checkStoragePack() {
+ packs := model.GetExpiredStoragePack()
+ for _, pack := range packs {
+ // 删除过期的容量包
+ pack.Delete()
+
+ //找到所属用户
+ user, err := model.GetUserByID(pack.UserID)
+ if err != nil {
+ util.Log().Warning("Crontab job failed to get user info of [UID=%d]: %s", pack.UserID, err)
+ continue
+ }
+
+ // 检查容量是否超额
+ if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
+ // 如果超额,则通知用户
+ sendNotification(&user, "容量包过期")
+
+ // 更新最后通知日期
+ user.Notified()
+
+ }
+ }
+}
+
+func sendNotification(user *model.User, reason string) {
+ title, body := email.NewOveruseNotification(user.Nick, reason)
+ if err := email.Send(user.Email, title, body); err != nil {
+ util.Log().Warning("Failed to send notification email: %s", err)
+ }
+}
diff --git a/pkg/email/smtp.go b/pkg/email/smtp.go
index c92cce7..3845f44 100644
--- a/pkg/email/smtp.go
+++ b/pkg/email/smtp.go
@@ -5,6 +5,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/go-mail/mail"
+ "github.com/google/uuid"
)
// SMTP SMTP协议发送邮件
@@ -50,6 +51,7 @@ func (client *SMTP) Send(to, title, body string) error {
m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name)
m.SetHeader("To", to)
m.SetHeader("Subject", title)
+ m.SetHeader("Message-ID", util.StrConcat(`"<`, uuid.NewString(), `@`, `cloudreveplus`, `>"`))
m.SetBody("text/html", body)
client.ch <- m
return nil
diff --git a/pkg/email/template.go b/pkg/email/template.go
index cb9cb3a..213e5e3 100644
--- a/pkg/email/template.go
+++ b/pkg/email/template.go
@@ -7,6 +7,20 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
+// NewOveruseNotification 新建超额提醒邮件
+func NewOveruseNotification(userName, reason string) (string, string) {
+ options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "over_used_template")
+ replace := map[string]string{
+ "{siteTitle}": options["siteName"],
+ "{userName}": userName,
+ "{notifyReason}": reason,
+ "{siteUrl}": options["siteURL"],
+ "{siteSecTitle}": options["siteTitle"],
+ }
+ return fmt.Sprintf("【%s】空间容量超额提醒", options["siteName"]),
+ util.Replace(replace, options["over_used_template"])
+}
+
// NewActivationEmail 新建激活邮件
func NewActivationEmail(userName, activateURL string) (string, string) {
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template")
diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go
index 78fc45f..cd3aa83 100644
--- a/pkg/filesystem/archive.go
+++ b/pkg/filesystem/archive.go
@@ -239,13 +239,6 @@ func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string)
reader = zipFile
}
- // 重设存储策略
- fs.Policy = &fs.User.Policy
- err = fs.DispatchHandler()
- if err != nil {
- return err
- }
-
var wg sync.WaitGroup
parallel := model.GetIntSetting("max_parallel_transfer", 4)
worker := make(chan int, parallel)
diff --git a/pkg/filesystem/archive_test.go b/pkg/filesystem/archive_test.go
deleted file mode 100644
index 07f5087..0000000
--- a/pkg/filesystem/archive_test.go
+++ /dev/null
@@ -1,256 +0,0 @@
-package filesystem
-
-import (
- "bytes"
- "context"
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- testMock "github.com/stretchr/testify/mock"
- "io"
- "os"
- "path/filepath"
- "runtime"
- "strings"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFileSystem_Compress(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
-
- // 成功
- {
- // 查找压缩父目录
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent"))
- // 查找顶级待压缩文件
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1, 1).
- WillReturnRows(
- sqlmock.NewRows(
- []string{"id", "name", "source_name", "policy_id"}).
- AddRow(1, "1.txt", "tests/file1.txt", 1),
- )
- asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
- // 查找父目录子文件
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}))
- // 查找子目录
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "sub"))
- // 查找子目录子文件
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(2).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id"}).
- AddRow(2, "2.txt", "tests/file2.txt", 1),
- )
- // 查找上传策略
- asserts.NoError(cache.Set("policy_1", model.Policy{Type: "local"}, -1))
- w := &bytes.Buffer{}
-
- err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
- asserts.NoError(err)
- asserts.NotEmpty(w.Len())
- }
-
- // 上下文取消
- {
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- // 查找压缩父目录
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "parent"))
- // 查找顶级待压缩文件
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1, 1).
- WillReturnRows(
- sqlmock.NewRows(
- []string{"id", "name", "source_name", "policy_id"}).
- AddRow(1, "1.txt", "tests/file1.txt", 1),
- )
- asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
-
- w := &bytes.Buffer{}
- err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
- asserts.Error(err)
- asserts.NotEmpty(w.Len())
- }
-
- // 限制父目录
- {
- ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{
- Model: gorm.Model{ID: 3},
- })
- // 查找压缩父目录
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "parent_id"}).AddRow(1, "parent", 3))
- // 查找顶级待压缩文件
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1, 1).
- WillReturnRows(
- sqlmock.NewRows(
- []string{"id", "name", "source_name", "policy_id"}).
- AddRow(1, "1.txt", "tests/file1.txt", 1),
- )
- asserts.NoError(cache.Set("setting_temp_path", "tests", -1))
-
- w := &bytes.Buffer{}
- err := fs.Compress(ctx, w, []uint{1}, []uint{1}, true)
- asserts.Error(err)
- asserts.Equal(ErrObjectNotExist, err)
- asserts.Empty(w.Len())
- }
-
-}
-
-type MockNopRSC string
-
-func (m MockNopRSC) Read(b []byte) (int, error) {
- return 0, errors.New("read error")
-}
-
-func (m MockNopRSC) Seek(n int64, offset int) (int64, error) {
- return 0, errors.New("read error")
-}
-
-func (m MockNopRSC) Close() error {
- return errors.New("read error")
-}
-
-type MockRSC struct {
- rs io.ReadSeeker
-}
-
-func (m MockRSC) Read(b []byte) (int, error) {
- return m.rs.Read(b)
-}
-
-func (m MockRSC) Seek(n int64, offset int) (int64, error) {
- return m.rs.Seek(n, offset)
-}
-
-func (m MockRSC) Close() error {
- return nil
-}
-
-var basepath string
-
-func init() {
- _, currentFile, _, _ := runtime.Caller(0)
- basepath = filepath.Dir(currentFile)
-}
-
-func Path(rel string) string {
- return filepath.Join(basepath, rel)
-}
-
-func TestFileSystem_Decompress(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
- os.RemoveAll(util.RelativePath("tests/decompress"))
-
- // 压缩文件不存在
- {
- // 查找根目录
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "/"))
- // 查找压缩文件,未找到
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- err := fs.Decompress(ctx, "/1.zip", "/", "")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-
- // 无法下载压缩文件
- {
- fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
- fs.FileTarget[0].Policy.ID = 1
- testHandler := new(FileHeaderMock)
- testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, errors.New("error"))
- fs.Handler = testHandler
- err := fs.Decompress(ctx, "/1.zip", "/", "")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.EqualError(err, "error")
- }
-
- // 无法创建临时压缩文件
- {
- cache.Set("setting_temp_path", "/tests:", 0)
- fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
- fs.FileTarget[0].Policy.ID = 1
- testHandler := new(FileHeaderMock)
- testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{}, nil)
- fs.Handler = testHandler
- err := fs.Decompress(ctx, "/1.zip", "/", "")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-
- // 无法写入压缩文件
- {
- cache.Set("setting_temp_path", "tests", 0)
- fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
- fs.FileTarget[0].Policy.ID = 1
- testHandler := new(FileHeaderMock)
- testHandler.On("Get", testMock.Anything, "1.zip").Return(MockNopRSC("1"), nil)
- fs.Handler = testHandler
- err := fs.Decompress(ctx, "/1.zip", "/", "")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Contains(err.Error(), "read error")
- }
-
- // 无法重设上传策略
- {
- cache.Set("setting_temp_path", "tests", 0)
- fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
- fs.FileTarget[0].Policy.ID = 1
- testHandler := new(FileHeaderMock)
- testHandler.On("Get", testMock.Anything, "1.zip").Return(MockRSC{rs: strings.NewReader("read")}, nil)
- fs.Handler = testHandler
- err := fs.Decompress(ctx, "/1.zip", "/", "")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.True(util.IsEmpty(util.RelativePath("tests/decompress")))
- }
-
- // 无法上传,容量不足
- {
- cache.Set("setting_max_parallel_transfer", "1", 0)
- zipFile, _ := os.Open(Path("tests/test.zip"))
- fs.FileTarget = []model.File{{SourceName: "1.zip", Policy: model.Policy{Type: "mock"}}}
- fs.FileTarget[0].Policy.ID = 1
- fs.User.Policy.Type = "mock"
- testHandler := new(FileHeaderMock)
- testHandler.On("Get", testMock.Anything, "1.zip").Return(zipFile, nil)
- fs.Handler = testHandler
-
- fs.Decompress(ctx, "/1.zip", "/", "")
-
- zipFile.Close()
-
- asserts.NoError(mock.ExpectationsWereMet())
- testHandler.AssertExpectations(t)
- }
-}
diff --git a/pkg/filesystem/chunk/backoff/backoff_test.go b/pkg/filesystem/chunk/backoff/backoff_test.go
deleted file mode 100644
index 0fda534..0000000
--- a/pkg/filesystem/chunk/backoff/backoff_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package backoff
-
-import (
- "errors"
- "github.com/stretchr/testify/assert"
- "net/http"
- "testing"
- "time"
-)
-
-func TestConstantBackoff_Next(t *testing.T) {
- a := assert.New(t)
-
- // General error
- {
- err := errors.New("error")
- b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.False(b.Next(err))
- b.Reset()
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.False(b.Next(err))
- }
-
- // Retryable error
- {
- err := &RetryableError{RetryAfter: time.Duration(1)}
- b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.False(b.Next(err))
- b.Reset()
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.True(b.Next(err))
- a.False(b.Next(err))
- }
-
-}
-
-func TestNewRetryableErrorFromHeader(t *testing.T) {
- a := assert.New(t)
- // no retry-after header
- {
- err := NewRetryableErrorFromHeader(nil, http.Header{})
- a.Empty(err.RetryAfter)
- }
-
- // with retry-after header
- {
- header := http.Header{}
- header.Add("retry-after", "120")
- err := NewRetryableErrorFromHeader(nil, header)
- a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter)
- }
-}
diff --git a/pkg/filesystem/chunk/chunk_test.go b/pkg/filesystem/chunk/chunk_test.go
deleted file mode 100644
index 4bdcd06..0000000
--- a/pkg/filesystem/chunk/chunk_test.go
+++ /dev/null
@@ -1,250 +0,0 @@
-package chunk
-
-import (
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/stretchr/testify/assert"
- "io"
- "os"
- "strings"
- "testing"
-)
-
-func TestNewChunkGroup(t *testing.T) {
- a := assert.New(t)
-
- testCases := []struct {
- fileSize uint64
- chunkSize uint64
- expectedInnerChunkSize uint64
- expectedChunkNum uint64
- expectedInfo [][2]int //Start, Index,Length
- }{
- {10, 0, 10, 1, [][2]int{{0, 10}}},
- {0, 0, 0, 1, [][2]int{{0, 0}}},
- {0, 10, 10, 1, [][2]int{{0, 0}}},
- {50, 10, 10, 5, [][2]int{
- {0, 10},
- {10, 10},
- {20, 10},
- {30, 10},
- {40, 10},
- }},
- {50, 50, 50, 1, [][2]int{
- {0, 50},
- }},
-
- {50, 15, 15, 4, [][2]int{
- {0, 15},
- {15, 15},
- {30, 15},
- {45, 5},
- }},
- }
-
- for index, testCase := range testCases {
- file := &fsctx.FileStream{Size: testCase.fileSize}
- chunkGroup := NewChunkGroup(file, testCase.chunkSize, &backoff.ConstantBackoff{}, true)
- a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(),
- "TestCase:%d,ChunkNum()", index)
- a.EqualValues(testCase.expectedInnerChunkSize, chunkGroup.chunkSize,
- "TestCase:%d,InnerChunkSize()", index)
- a.EqualValues(testCase.expectedChunkNum, chunkGroup.Num(),
- "TestCase:%d,len(Chunks)", index)
- a.EqualValues(testCase.fileSize, chunkGroup.Total())
-
- for cIndex, info := range testCase.expectedInfo {
- a.True(chunkGroup.Next())
- a.EqualValues(info[1], chunkGroup.Length(),
- "TestCase:%d,Chunks[%d].Length()", index, cIndex)
- a.EqualValues(info[0], chunkGroup.Start(),
- "TestCase:%d,Chunks[%d].Start()", index, cIndex)
-
- a.Equal(cIndex == len(testCase.expectedInfo)-1, chunkGroup.IsLast(),
- "TestCase:%d,Chunks[%d].IsLast()", index, cIndex)
-
- a.NotEmpty(chunkGroup.RangeHeader())
- }
- a.False(chunkGroup.Next())
- }
-}
-
-func TestChunkGroup_TempAvailablet(t *testing.T) {
- a := assert.New(t)
-
- file := &fsctx.FileStream{Size: 1}
- c := NewChunkGroup(file, 0, &backoff.ConstantBackoff{}, true)
- a.False(c.TempAvailable())
-
- f, err := os.CreateTemp("", "TestChunkGroup_TempAvailablet.*")
- defer func() {
- f.Close()
- os.Remove(f.Name())
- }()
- a.NoError(err)
- c.bufferTemp = f
-
- a.False(c.TempAvailable())
- f.Write([]byte("1"))
- a.True(c.TempAvailable())
-
-}
-
-func TestChunkGroup_Process(t *testing.T) {
- a := assert.New(t)
- file := &fsctx.FileStream{Size: 10}
-
- // success
- {
- file.File = io.NopCloser(strings.NewReader("1234567890"))
- c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{}, true)
- count := 0
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("12345", string(res))
- return nil
- }))
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("67890", string(res))
- return nil
- }))
- a.False(c.Next())
- a.Equal(2, count)
- }
-
- // retry, read from buffer file
- {
- file.File = io.NopCloser(strings.NewReader("1234567890"))
- c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, true)
- count := 0
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("12345", string(res))
- return nil
- }))
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("67890", string(res))
- if count == 2 {
- return errors.New("error")
- }
- return nil
- }))
- a.False(c.Next())
- a.Equal(3, count)
- }
-
- // retry, read from seeker
- {
- f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
- f.Write([]byte("1234567890"))
- f.Seek(0, 0)
- defer func() {
- f.Close()
- os.Remove(f.Name())
- }()
- file.File = f
- file.Seeker = f
- c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
- count := 0
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("12345", string(res))
- return nil
- }))
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("67890", string(res))
- if count == 2 {
- return errors.New("error")
- }
- return nil
- }))
- a.False(c.Next())
- a.Equal(3, count)
- }
-
- // retry, seek error
- {
- f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
- f.Write([]byte("1234567890"))
- f.Seek(0, 0)
- defer func() {
- f.Close()
- os.Remove(f.Name())
- }()
- file.File = f
- file.Seeker = f
- c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
- count := 0
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("12345", string(res))
- return nil
- }))
- a.True(c.Next())
- f.Close()
- a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- if count == 2 {
- return errors.New("error")
- }
- return nil
- }))
- a.False(c.Next())
- a.Equal(2, count)
- }
-
- // retry, finally error
- {
- f, _ := os.CreateTemp("", "TestChunkGroup_Process.*")
- f.Write([]byte("1234567890"))
- f.Seek(0, 0)
- defer func() {
- f.Close()
- os.Remove(f.Name())
- }()
- file.File = f
- file.Seeker = f
- c := NewChunkGroup(file, 5, &backoff.ConstantBackoff{Max: 2}, false)
- count := 0
- a.True(c.Next())
- a.NoError(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- res, err := io.ReadAll(chunk)
- a.NoError(err)
- a.EqualValues("12345", string(res))
- return nil
- }))
- a.True(c.Next())
- a.Error(c.Process(func(c *ChunkGroup, chunk io.Reader) error {
- count++
- return errors.New("error")
- }))
- a.False(c.Next())
- a.Equal(4, count)
- }
-}
diff --git a/pkg/filesystem/driver/handler.go b/pkg/filesystem/driver/handler.go
index f232781..f145281 100644
--- a/pkg/filesystem/driver/handler.go
+++ b/pkg/filesystem/driver/handler.go
@@ -3,6 +3,7 @@ package driver
import (
"context"
"fmt"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
diff --git a/pkg/filesystem/driver/local/handler_test.go b/pkg/filesystem/driver/local/handler_test.go
deleted file mode 100644
index b73b564..0000000
--- a/pkg/filesystem/driver/local/handler_test.go
+++ /dev/null
@@ -1,338 +0,0 @@
-package local
-
-import (
- "context"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- "io"
- "os"
- "strings"
- "testing"
-)
-
-func TestHandler_Put(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{}
-
- defer func() {
- os.Remove(util.RelativePath("TestHandler_Put.txt"))
- os.Remove(util.RelativePath("inner/TestHandler_Put.txt"))
- }()
-
- testCases := []struct {
- file fsctx.FileHeader
- errContains string
- }{
- {&fsctx.FileStream{
- SavePath: "TestHandler_Put.txt",
- File: io.NopCloser(strings.NewReader("")),
- }, ""},
- {&fsctx.FileStream{
- SavePath: "TestHandler_Put.txt",
- File: io.NopCloser(strings.NewReader("")),
- }, "file with the same name existed or unavailable"},
- {&fsctx.FileStream{
- SavePath: "inner/TestHandler_Put.txt",
- File: io.NopCloser(strings.NewReader("")),
- }, ""},
- {&fsctx.FileStream{
- Mode: fsctx.Append | fsctx.Overwrite,
- SavePath: "inner/TestHandler_Put.txt",
- File: io.NopCloser(strings.NewReader("123")),
- }, ""},
- {&fsctx.FileStream{
- AppendStart: 10,
- Mode: fsctx.Append | fsctx.Overwrite,
- SavePath: "inner/TestHandler_Put.txt",
- File: io.NopCloser(strings.NewReader("123")),
- }, "size of unfinished uploaded chunks is not as expected"},
- {&fsctx.FileStream{
- Mode: fsctx.Append | fsctx.Overwrite,
- SavePath: "inner/TestHandler_Put.txt",
- File: io.NopCloser(strings.NewReader("123")),
- }, ""},
- }
-
- for _, testCase := range testCases {
- err := handler.Put(context.Background(), testCase.file)
- if testCase.errContains != "" {
- asserts.Error(err)
- asserts.Contains(err.Error(), testCase.errContains)
- } else {
- asserts.NoError(err)
- asserts.True(util.Exists(util.RelativePath(testCase.file.Info().SavePath)))
- }
- }
-}
-
-func TestDriver_TruncateFailed(t *testing.T) {
- a := assert.New(t)
- h := Driver{}
- a.Error(h.Truncate(context.Background(), "TestDriver_TruncateFailed", 0))
-}
-
-func TestHandler_Delete(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{}
- ctx := context.Background()
- filePath := util.RelativePath("TestHandler_Delete.file")
-
- file, err := os.Create(filePath)
- asserts.NoError(err)
- _ = file.Close()
- list, err := handler.Delete(ctx, []string{"TestHandler_Delete.file"})
- asserts.Equal([]string{}, list)
- asserts.NoError(err)
-
- file, err = os.Create(filePath)
- _ = file.Close()
- file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0))
- asserts.NoError(err)
- list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file", "test.notexist"})
- file.Close()
- asserts.Equal([]string{}, list)
- asserts.NoError(err)
-
- list, err = handler.Delete(ctx, []string{"test.notexist"})
- asserts.Equal([]string{}, list)
- asserts.NoError(err)
-
- file, err = os.Create(filePath)
- asserts.NoError(err)
- list, err = handler.Delete(ctx, []string{"TestHandler_Delete.file"})
- _ = file.Close()
- asserts.Equal([]string{}, list)
- asserts.NoError(err)
-}
-
-func TestHandler_Get(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{}
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- // 成功
- file, err := os.Create(util.RelativePath("TestHandler_Get.txt"))
- asserts.NoError(err)
- _ = file.Close()
-
- rs, err := handler.Get(ctx, "TestHandler_Get.txt")
- asserts.NoError(err)
- asserts.NotNil(rs)
-
- // 文件不存在
-
- rs, err = handler.Get(ctx, "TestHandler_Get_notExist.txt")
- asserts.Error(err)
- asserts.Nil(rs)
-}
-
-func TestHandler_Thumb(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{}
- ctx := context.Background()
- file, err := os.Create(util.RelativePath("TestHandler_Thumb._thumb"))
- asserts.NoError(err)
- file.Close()
-
- f := &model.File{
- SourceName: "TestHandler_Thumb",
- MetadataSerialized: map[string]string{
- model.ThumbStatusMetadataKey: model.ThumbStatusExist,
- },
- }
-
- // 正常
- {
- thumb, err := handler.Thumb(ctx, f)
- asserts.NoError(err)
- asserts.NotNil(thumb.Content)
- }
-
- // file 不存在
- {
- f.SourceName = "not_exist"
- _, err := handler.Thumb(ctx, f)
- asserts.Error(err)
- asserts.ErrorIs(err, driver.ErrorThumbNotExist)
- }
-
- // thumb not exist
- {
- f.MetadataSerialized[model.ThumbStatusMetadataKey] = model.ThumbStatusNotExist
- _, err := handler.Thumb(ctx, f)
- asserts.Error(err)
- asserts.ErrorIs(err, driver.ErrorThumbNotExist)
- }
-}
-
-func TestHandler_Source(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{},
- }
- ctx := context.Background()
- auth.General = auth.HMACAuth{SecretKey: []byte("test")}
-
- // 成功
- {
- file := model.File{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "test.jpg",
- }
- ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
- sourceURL, err := handler.Source(ctx, "", 0, false, 0)
- asserts.NoError(err)
- asserts.NotEmpty(sourceURL)
- asserts.Contains(sourceURL, "sign=")
- }
-
- // 下载
- {
- file := model.File{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "test.jpg",
- }
- ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
- sourceURL, err := handler.Source(ctx, "", 0, true, 0)
- asserts.NoError(err)
- asserts.NotEmpty(sourceURL)
- asserts.Contains(sourceURL, "sign=")
- asserts.Contains(sourceURL, "download")
- }
-
- // 无法获取上下文
- {
- sourceURL, err := handler.Source(ctx, "", 0, false, 0)
- asserts.Error(err)
- asserts.Empty(sourceURL)
- }
-
- // 设定了CDN
- {
- handler.Policy.BaseURL = "https://cqu.edu.cn"
- file := model.File{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "test.jpg",
- }
- ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
- sourceURL, err := handler.Source(ctx, "", 0, false, 0)
- asserts.NoError(err)
- asserts.NotEmpty(sourceURL)
- asserts.Contains(sourceURL, "sign=")
- asserts.Contains(sourceURL, "https://cqu.edu.cn")
- }
-
- // 设定了CDN,解析失败
- {
- handler.Policy.BaseURL = string([]byte{0x7f})
- file := model.File{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "test.jpg",
- }
- ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
- sourceURL, err := handler.Source(ctx, "", 0, false, 0)
- asserts.Error(err)
- asserts.Empty(sourceURL)
- }
-}
-
-func TestHandler_GetDownloadURL(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{Policy: &model.Policy{}}
- ctx := context.Background()
- auth.General = auth.HMACAuth{SecretKey: []byte("test")}
-
- // 成功
- {
- file := model.File{
- Model: gorm.Model{
- ID: 1,
- },
- Name: "test.jpg",
- }
- ctx := context.WithValue(ctx, fsctx.FileModelCtx, file)
- downloadURL, err := handler.Source(ctx, "", 10, true, 0)
- asserts.NoError(err)
- asserts.Contains(downloadURL, "sign=")
- }
-
- // 无法获取上下文
- {
- downloadURL, err := handler.Source(ctx, "", 10, true, 0)
- asserts.Error(err)
- asserts.Empty(downloadURL)
- }
-}
-
-func TestHandler_Token(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{},
- }
- ctx := context.Background()
- upSession := &serializer.UploadSession{SavePath: "TestHandler_Token"}
- _, err := handler.Token(ctx, 10, upSession, &fsctx.FileStream{})
- asserts.NoError(err)
-
- file, _ := os.Create("TestHandler_Token")
- defer func() {
- file.Close()
- os.Remove("TestHandler_Token")
- }()
-
- _, err = handler.Token(ctx, 10, upSession, &fsctx.FileStream{})
- asserts.Error(err)
- asserts.Contains(err.Error(), "already exist")
-}
-
-func TestDriver_CancelToken(t *testing.T) {
- a := assert.New(t)
- handler := Driver{}
- a.NoError(handler.CancelToken(context.Background(), &serializer.UploadSession{}))
-}
-
-func TestDriver_List(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{}
- ctx := context.Background()
-
- // 创建测试目录结构
- for _, path := range []string{
- "test/TestDriver_List/parent.txt",
- "test/TestDriver_List/parent_folder2/sub2.txt",
- "test/TestDriver_List/parent_folder1/sub_folder/sub1.txt",
- "test/TestDriver_List/parent_folder1/sub_folder/sub2.txt",
- } {
- f, _ := util.CreatNestedFile(util.RelativePath(path))
- f.Close()
- }
-
- // 非递归列出
- {
- res, err := handler.List(ctx, "test/TestDriver_List", false)
- asserts.NoError(err)
- asserts.Len(res, 3)
- }
-
- // 递归列出
- {
- res, err := handler.List(ctx, "test/TestDriver_List", true)
- asserts.NoError(err)
- asserts.Len(res, 7)
- }
-}
diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go
index 56abbaa..74649ea 100644
--- a/pkg/filesystem/driver/onedrive/api.go
+++ b/pkg/filesystem/driver/onedrive/api.go
@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"io"
"net/http"
"net/url"
@@ -15,6 +14,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go
deleted file mode 100644
index a675548..0000000
--- a/pkg/filesystem/driver/onedrive/api_test.go
+++ /dev/null
@@ -1,1155 +0,0 @@
-package onedrive
-
-import (
- "context"
- "errors"
- "fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
- "time"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-func TestRequest(t *testing.T) {
- asserts := assert.New(t)
- client := Client{
- Policy: &model.Policy{},
- ClientID: "TestRequest",
- Credential: &Credential{
- ExpiresIn: time.Now().Add(time.Duration(100) * time.Hour).Unix(),
- AccessToken: "AccessToken",
- RefreshToken: "RefreshToken",
- },
- }
-
- // 请求发送失败
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error"),
- })
- client.Request = clientMock
- res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- asserts.Equal("error", err.Error())
- }
-
- // 无法更新凭证
- {
- client.Credential.RefreshToken = ""
- client.Credential.AccessToken = ""
- res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
- asserts.Error(err)
- asserts.Empty(res)
- client.Credential.RefreshToken = "RefreshToken"
- client.Credential.AccessToken = "AccessToken"
- }
-
- // 无法获取响应正文
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(mockReader("")),
- },
- })
- client.Request = clientMock
- res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // OneDrive返回错误
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)),
- },
- })
- client.Request = clientMock
- res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- asserts.Equal("error msg", err.Error())
- }
-
- // OneDrive返回429错误
- {
- header := http.Header{}
- header.Add("retry-after", "120")
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 429,
- Header: header,
- Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)),
- },
- })
- client.Request = clientMock
- res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- var retryErr *backoff.RetryableError
- asserts.ErrorAs(err, &retryErr)
- asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter)
- }
-
- // OneDrive返回未知响应
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- }
-}
-
-func TestFileInfo_GetSourcePath(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- fileInfo := FileInfo{
- Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg",
- ParentReference: parentReference{
- Path: "/drive/root:/123/32%201",
- },
- }
- asserts.Equal("123/32 1/%e6%96%87%e4%bb%b6%e5%90%8d.jpg", fileInfo.GetSourcePath())
- }
-
- // 失败
- {
- fileInfo := FileInfo{
- Name: "123.jpg",
- ParentReference: parentReference{
- Path: "/drive/root:/123/%e6%96%87%e4%bb%b6%e5%90%8g",
- },
- }
- asserts.Equal("", fileInfo.GetSourcePath())
- }
-}
-
-func TestClient_GetRequestURL(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
-
- // 出错
- {
- client.Endpoints.EndpointURL = string([]byte{0x7f})
- asserts.Equal("", client.getRequestURL("123"))
- }
-
- // 使用DriverResource
- {
- client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
- asserts.Equal("https://graph.microsoft.com/v1.0/me/drive/123", client.getRequestURL("123"))
- }
-
- // 不使用DriverResource
- {
- client.Endpoints.EndpointURL = "https://graph.microsoft.com/v1.0"
- asserts.Equal("https://graph.microsoft.com/v1.0/123", client.getRequestURL("123", WithDriverResource(false)))
- }
-}
-
-func TestClient_GetSiteIDByURL(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
- asserts.Error(err)
- asserts.Empty(res)
-
- }
-
- // 返回未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 返回正常
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"id":"123321"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetSiteIDByURL(context.Background(), "https://cquedu.sharepoint.com")
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotEmpty(res)
- asserts.Equal("123321", res)
- }
-}
-
-func TestClient_Meta(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.Meta(context.Background(), "", "123")
- asserts.Error(err)
- asserts.Nil(res)
-
- }
-
- // 返回未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.Meta(context.Background(), "", "123")
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 返回正常
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.Meta(context.Background(), "", "123")
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.Equal("123321", res.Name)
- }
-
- // 返回正常, 使用资源id
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.Meta(context.Background(), "123321", "123")
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.Equal("123321", res.Name)
- }
-}
-
-func TestClient_CreateUploadSession(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.CreateUploadSession(context.Background(), "123.jpg")
- asserts.Error(err)
- asserts.Empty(res)
-
- }
-
- // 返回未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.CreateUploadSession(context.Background(), "123.jpg")
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 返回正常
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.CreateUploadSession(context.Background(), "123.jpg", WithConflictBehavior("fail"))
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.Equal("123321", res)
- }
-}
-
-func TestClient_GetUploadSessionStatus(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com")
- asserts.Error(err)
- asserts.Empty(res)
-
- }
-
- // 返回未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com")
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 返回正常
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetUploadSessionStatus(context.Background(), "http://dev.com")
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.Equal("123321", res.UploadURL)
- }
-}
-
-func TestClient_UploadChunk(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- cg := chunk.NewChunkGroup(&fsctx.FileStream{Size: 15}, 10, &backoff.ConstantBackoff{}, false)
-
- // 非最后分片,正常
- {
- cg.Next()
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "PUT",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"http://dev.com/2"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg)
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.Equal("http://dev.com/2", res.UploadURL)
- }
-
- // 非最后分片,异常响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "PUT",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("1234567890"), cg)
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 最后分片,正常
- {
- cg.Next()
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "PUT",
- "http://dev.com",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg)
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.Nil(res)
- }
-
- // 最后分片,失败
- {
- cache.Set("setting_chunk_retries", "1", 0)
- client.Credential.ExpiresIn = 0
- go func() {
- time.Sleep(time.Duration(2) * time.Second)
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- }()
- clientMock := ClientMock{}
- client.Request = clientMock
- res, err := client.UploadChunk(context.Background(), "http://dev.com", strings.NewReader("12345"), cg)
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-}
-
-func TestClient_Upload(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- ctx := context.Background()
- cache.Set("setting_chunk_retries", "1", 0)
- cache.Set("setting_use_temp_chunk_buffer", "false", 0)
-
- // 小文件,简单上传,失败
- {
- client.Credential.ExpiresIn = 0
- err := client.Upload(ctx, &fsctx.FileStream{
- Size: 5,
- File: io.NopCloser(strings.NewReader("12345")),
- })
- asserts.Error(err)
- }
-
- // 无法创建分片会话
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- client.Request = clientMock
- err := client.Upload(context.Background(), &fsctx.FileStream{
- Size: SmallFileSize + 1,
- File: io.NopCloser(strings.NewReader("12345")),
- })
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- }
-
- // 分片上传失败
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- clientMock.On(
- "Request",
- "PUT",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- client.Request = clientMock
- err := client.Upload(context.Background(), &fsctx.FileStream{
- Size: SmallFileSize + 1,
- File: io.NopCloser(strings.NewReader("12345")),
- })
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Contains(err.Error(), "failed to upload chunk")
- }
-
-}
-
-func TestClient_SimpleUpload(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- cache.Set("setting_chunk_retries", "1", 0)
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 返回未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "PUT",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3)
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 返回正常
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "PUT",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"name":"123321"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.SimpleUpload(context.Background(), "123.jpg", strings.NewReader("123"), 3)
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.Equal("123321", res.Name)
- }
-}
-
-func TestClient_DeleteUploadSession(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- err := client.DeleteUploadSession(context.Background(), "123.jpg")
- asserts.Error(err)
-
- }
-
- // 返回正常
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "DELETE",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 204,
- Body: ioutil.NopCloser(strings.NewReader(``)),
- },
- })
- client.Request = clientMock
- err := client.DeleteUploadSession(context.Background(), "123.jpg")
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- }
-}
-
-func TestClient_BatchDelete(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 小于20个,失败1个
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)),
- },
- })
- client.Request = clientMock
- res, err := client.BatchDelete(context.Background(), []string{"1", "2", "3", "1", "2"})
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Equal([]string{"2"}, res)
- }
-}
-
-func TestClient_Delete(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.Delete(context.Background(), []string{"1", "2", "3"})
- asserts.Error(err)
- asserts.Len(res, 3)
- }
-
- // 返回未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.Delete(context.Background(), []string{"1", "2", "3"})
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Len(res, 3)
- }
-
- // 成功2两个文件
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"responses":[{"id":"2","status":400}]}`)),
- },
- })
- client.Request = clientMock
- res, err := client.Delete(context.Background(), []string{"1", "2", "3"})
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Equal([]string{"2"}, res)
- }
-}
-
-func TestClient_ListChildren(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 根目录,请求失败,重测试
- {
- client.Credential.ExpiresIn = 0
- res, err := client.ListChildren(context.Background(), "/")
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 非根目录,未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.ListChildren(context.Background(), "/uploads")
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 非根目录,成功
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)),
- },
- })
- client.Request = clientMock
- res, err := client.ListChildren(context.Background(), "/uploads")
- asserts.NoError(err)
- asserts.Len(res, 1)
- }
-}
-
-func TestClient_GetThumbURL(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.AccessToken = "AccessToken"
-
- // 请求失败
- {
- client.Credential.ExpiresIn = 0
- res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 未知响应
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 世纪互联 成功
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- client.Endpoints.isInChina = true
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"url":"thumb"}`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1)
- asserts.NoError(err)
- asserts.Equal("thumb", res)
- }
-
- // 非世纪互联 成功
- {
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- client.Endpoints.isInChina = false
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"large":{"url":"thumb"}}]}`)),
- },
- })
- client.Request = clientMock
- res, err := client.GetThumbURL(context.Background(), "123,jpg", 1, 1)
- asserts.NoError(err)
- asserts.Equal("thumb", res)
- }
-}
-
-func TestClient_MonitorUpload(t *testing.T) {
- asserts := assert.New(t)
- client, _ := NewClient(&model.Policy{})
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
-
- // 客户端完成回调
- {
- cache.Set("setting_onedrive_monitor_timeout", "600", 0)
- cache.Set("setting_onedrive_callback_check", "20", 0)
- asserts.NotPanics(func() {
- go func() {
- time.Sleep(time.Duration(1) * time.Second)
- mq.GlobalMQ.Publish("key", mq.Message{})
- }()
- client.MonitorUpload("url", "key", "path", 10, 10)
- })
- }
-
- // 上传会话到期,仍未完成上传,创建占位符
- {
- cache.Set("setting_onedrive_monitor_timeout", "600", 0)
- cache.Set("setting_onedrive_callback_check", "20", 0)
- asserts.NotPanics(func() {
- client.MonitorUpload("url", "key", "path", 10, 0)
- })
- }
-
- fmt.Println("测试:上传已完成,未发送回调")
- // 上传已完成,未发送回调
- {
- cache.Set("setting_onedrive_monitor_timeout", "0", 0)
- cache.Set("setting_onedrive_callback_check", "0", 0)
-
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- client.Credential.AccessToken = "1"
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 404,
- Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)),
- },
- })
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 404,
- Body: ioutil.NopCloser(strings.NewReader(`{"error":{"code":"itemNotFound"}}`)),
- },
- })
- client.Request = clientMock
- cache.Set("callback_key3", "ok", 0)
-
- asserts.NotPanics(func() {
- client.MonitorUpload("url", "key3", "path", 10, 10)
- })
-
- clientMock.AssertExpectations(t)
- }
-
- fmt.Println("测试:上传仍未开始")
- // 上传仍未开始
- {
- cache.Set("setting_onedrive_monitor_timeout", "0", 0)
- cache.Set("setting_onedrive_callback_check", "0", 0)
-
- client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- client.Credential.AccessToken = "1"
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"nextExpectedRanges":["0-"]}`)),
- },
- })
- clientMock.On(
- "Request",
- "DELETE",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(``)),
- },
- })
- clientMock.On(
- "Request",
- "PUT",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{}`)),
- },
- })
- client.Request = clientMock
-
- asserts.NotPanics(func() {
- client.MonitorUpload("url", "key4", "path", 10, 10)
- })
-
- clientMock.AssertExpectations(t)
- }
-
-}
diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go
index 957af8e..89e696b 100644
--- a/pkg/filesystem/driver/onedrive/client.go
+++ b/pkg/filesystem/driver/onedrive/client.go
@@ -2,6 +2,7 @@ package onedrive
import (
"errors"
+
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
model "github.com/cloudreve/Cloudreve/v3/models"
diff --git a/pkg/filesystem/driver/onedrive/client_test.go b/pkg/filesystem/driver/onedrive/client_test.go
deleted file mode 100644
index aa3c132..0000000
--- a/pkg/filesystem/driver/onedrive/client_test.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package onedrive
-
-import (
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
-)
-
-func TestNewClient(t *testing.T) {
- asserts := assert.New(t)
- // getOAuthEndpoint失败
- {
- policy := model.Policy{
- BaseURL: string([]byte{0x7f}),
- }
- res, err := NewClient(&policy)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 成功
- {
- policy := model.Policy{}
- res, err := NewClient(&policy)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.NotNil(res.Credential)
- asserts.NotNil(res.Endpoints)
- asserts.NotNil(res.Endpoints.OAuthEndpoints)
- }
-}
diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go
deleted file mode 100644
index 2c9c2c2..0000000
--- a/pkg/filesystem/driver/onedrive/handler_test.go
+++ /dev/null
@@ -1,420 +0,0 @@
-package onedrive
-
-import (
- "context"
- "fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/jinzhu/gorm"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
- "time"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-func TestDriver_Token(t *testing.T) {
- asserts := assert.New(t)
- h, _ := NewDriver(&model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- })
- handler := h.(Driver)
-
- // 分片上传 失败
- {
- cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- handler.Client.Request = clientMock
- res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{})
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 分片上传 成功
- {
- cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
- cache.Set("setting_onedrive_monitor_timeout", "600", 0)
- cache.Set("setting_onedrive_callback_check", "20", 0)
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- handler.Client.Credential.AccessToken = "1"
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
- },
- })
- handler.Client.Request = clientMock
- go func() {
- time.Sleep(time.Duration(1) * time.Second)
- mq.GlobalMQ.Publish("TestDriver_Token", mq.Message{})
- }()
- res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{Key: "TestDriver_Token"}, &fsctx.FileStream{})
- asserts.NoError(err)
- asserts.Equal("123321", res.UploadURLs[0])
- }
-}
-
-func TestDriver_Source(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- cache.Set("setting_onedrive_source_timeout", "1800", 0)
-
- // 失败
- {
- res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 命中缓存 成功
- {
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- handler.Client.Credential.AccessToken = "1"
- cache.Set("onedrive_source_0_123.jpg", "res", 1)
- res, err := handler.Source(context.Background(), "123.jpg", 0, true, 0)
- cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
- asserts.NoError(err)
- asserts.Equal("res", res)
- }
-
- // 命中缓存 上下文存在文件 成功
- {
- file := model.File{}
- file.ID = 1
- file.UpdatedAt = time.Now()
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- handler.Client.Credential.AccessToken = "1"
- cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
- res, err := handler.Source(ctx, "123.jpg", 1, true, 0)
- cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
- asserts.NoError(err)
- asserts.Equal("res", res)
- }
-
- // 成功
- {
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
- },
- })
- handler.Client.Request = clientMock
- handler.Client.Credential.AccessToken = "1"
- res, err := handler.Source(context.Background(), "123.jpg", 1, true, 0)
- asserts.NoError(err)
- asserts.Equal("123321", res)
- }
-}
-
-func TestDriver_List(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.AccessToken = "AccessToken"
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
-
- // 非递归
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)),
- },
- })
- handler.Client.Request = clientMock
- res, err := handler.List(context.Background(), "/", false)
- asserts.NoError(err)
- asserts.Len(res, 1)
- }
-
- // 递归一次
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- "me/drive/root/children?$top=999999999",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"1","folder":{}}]}`)),
- },
- })
- clientMock.On(
- "Request",
- "GET",
- "me/drive/root:/1:/children?$top=999999999",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"2"}]}`)),
- },
- })
- handler.Client.Request = clientMock
- res, err := handler.List(context.Background(), "/", true)
- asserts.NoError(err)
- asserts.Len(res, 2)
- }
-}
-
-func TestDriver_Thumb(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- file := &model.File{PicInfo: "1,1", Model: gorm.Model{ID: 1}}
-
- // 失败
- {
- ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20})
- res, err := handler.Thumb(ctx, file)
- asserts.Error(err)
- asserts.Empty(res.URL)
- }
-
- // 上下文错误
- {
- _, err := handler.Thumb(context.Background(), file)
- asserts.Error(err)
- }
-}
-
-func TestDriver_Delete(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
-
- // 失败
- {
- _, err := handler.Delete(context.Background(), []string{"1"})
- asserts.Error(err)
- }
-
-}
-
-func TestDriver_Put(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
-
- // 失败
- {
- err := handler.Put(context.Background(), &fsctx.FileStream{})
- asserts.Error(err)
- }
-}
-
-func TestDriver_Get(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
-
- // 无法获取source
- {
- res, err := handler.Get(context.Background(), "123.txt")
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 成功
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
- },
- })
- handler.Client.Request = clientMock
- handler.Client.Credential.AccessToken = "1"
-
- driverClientMock := ClientMock{}
- driverClientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`123`)),
- },
- })
- handler.HTTPClient = driverClientMock
- res, err := handler.Get(context.Background(), "123.txt")
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- _, err = res.Seek(0, io.SeekEnd)
- asserts.NoError(err)
- content, err := ioutil.ReadAll(res)
- asserts.NoError(err)
- asserts.Equal("123", string(content))
-}
-
-func TestDriver_replaceSourceHost(t *testing.T) {
- tests := []struct {
- name string
- origin string
- cdn string
- want string
- wantErr bool
- }{
- {"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false},
- {"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false},
- {"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true},
- {"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- policy := &model.Policy{}
- policy.OptionsSerialized.OdProxy = tt.cdn
- handler := Driver{
- Policy: policy,
- }
- got, err := handler.replaceSourceHost(tt.origin)
- if (err != nil) != tt.wantErr {
- t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if got != tt.want {
- t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestDriver_CancelToken(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- AccessKey: "ak",
- SecretKey: "sk",
- BucketName: "test",
- Server: "test.com",
- },
- }
- handler.Client, _ = NewClient(&model.Policy{})
- handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
-
- // 失败
- {
- err := handler.CancelToken(context.Background(), &serializer.UploadSession{})
- asserts.Error(err)
- }
-}
diff --git a/pkg/filesystem/driver/onedrive/lock.go b/pkg/filesystem/driver/onedrive/lock.go
new file mode 100755
index 0000000..655936b
--- /dev/null
+++ b/pkg/filesystem/driver/onedrive/lock.go
@@ -0,0 +1,25 @@
+package onedrive
+
+import "sync"
+
+// CredentialLock 针对存储策略凭证的锁
+type CredentialLock interface {
+ Lock(uint)
+ Unlock(uint)
+}
+
+var GlobalMutex = mutexMap{}
+
+type mutexMap struct {
+ locks sync.Map
+}
+
+func (m *mutexMap) Lock(id uint) {
+ lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
+ lock.(*sync.Mutex).Lock()
+}
+
+func (m *mutexMap) Unlock(id uint) {
+ lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
+ lock.(*sync.Mutex).Unlock()
+}
diff --git a/pkg/filesystem/driver/onedrive/oauth_test.go b/pkg/filesystem/driver/onedrive/oauth_test.go
deleted file mode 100644
index b2525b7..0000000
--- a/pkg/filesystem/driver/onedrive/oauth_test.go
+++ /dev/null
@@ -1,386 +0,0 @@
-package onedrive
-
-import (
- "context"
- "database/sql"
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
- "io"
- "io/ioutil"
- "net/http"
- "net/url"
- "strings"
- "testing"
- "time"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestGetOAuthEndpoint(t *testing.T) {
- asserts := assert.New(t)
-
- // URL解析失败
- {
- client := Client{
- Endpoints: &Endpoints{
- OAuthURL: string([]byte{0x7f}),
- },
- }
- res := client.getOAuthEndpoint()
- asserts.Nil(res)
- }
-
- {
- testCase := []struct {
- OAuthURL string
- token string
- auth string
- isChina bool
- }{
- {
- OAuthURL: "http://login.live.com",
- token: "https://login.live.com/oauth20_token.srf",
- auth: "https://login.live.com/oauth20_authorize.srf",
- isChina: false,
- },
- {
- OAuthURL: "http://login.chinacloudapi.cn",
- token: "https://login.chinacloudapi.cn/common/oauth2/v2.0/token",
- auth: "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize",
- isChina: true,
- },
- {
- OAuthURL: "other",
- token: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
- auth: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
- isChina: false,
- },
- }
-
- for i, testCase := range testCase {
- client := Client{
- Endpoints: &Endpoints{
- OAuthURL: testCase.OAuthURL,
- },
- }
- res := client.getOAuthEndpoint()
- asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i)
- asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i)
- asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i)
- }
- }
-}
-
-func TestClient_OAuthURL(t *testing.T) {
- asserts := assert.New(t)
-
- client := Client{
- ClientID: "client_id",
- Redirect: "http://cloudreve.org/callback",
- Endpoints: &Endpoints{},
- }
- client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
- res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"}))
- asserts.NoError(err)
- query := res.Query()
- asserts.Equal("client_id", query.Get("client_id"))
- asserts.Equal("scope1 scope2", query.Get("scope"))
- asserts.Equal(client.Redirect, query.Get("redirect_uri"))
-
-}
-
-type ClientMock struct {
- testMock.Mock
-}
-
-func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
- args := m.Called(method, target, body, opts)
- return args.Get(0).(*request.Response)
-}
-
-type mockReader string
-
-func (r mockReader) Read(b []byte) (int, error) {
- return 0, errors.New("read error")
-}
-
-func TestClient_ObtainToken(t *testing.T) {
- asserts := assert.New(t)
-
- client := Client{
- Endpoints: &Endpoints{},
- ClientID: "ClientID",
- ClientSecret: "ClientSecret",
- Redirect: "Redirect",
- }
- client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
-
- // 刷新Token 成功
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)),
- },
- })
- client.Request = clientMock
-
- res, err := client.ObtainToken(context.Background())
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.NotNil(res)
- asserts.Equal("i am token", res.AccessToken)
- }
-
- // 重新获取 无法发送请求
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error"),
- })
- client.Request = clientMock
-
- res, err := client.ObtainToken(context.Background(), WithCode("code"))
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 刷新Token 无法获取响应正文
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(mockReader("")),
- },
- })
- client.Request = clientMock
-
- res, err := client.ObtainToken(context.Background())
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- asserts.Equal("read error", err.Error())
- }
-
- // 刷新Token OneDrive返回错误
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)),
- },
- })
- client.Request = clientMock
-
- res, err := client.ObtainToken(context.Background())
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- asserts.Equal("", err.Error())
- }
-
- // 刷新Token OneDrive未知响应
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`???`)),
- },
- })
- client.Request = clientMock
-
- res, err := client.ObtainToken(context.Background())
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Nil(res)
- }
-}
-
-func TestClient_UpdateCredential(t *testing.T) {
- asserts := assert.New(t)
- client := Client{
- Policy: &model.Policy{Model: gorm.Model{ID: 257}},
- Endpoints: &Endpoints{},
- ClientID: "TestClient_UpdateCredential",
- ClientSecret: "ClientSecret",
- Redirect: "Redirect",
- Credential: &Credential{},
- }
- client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
-
- // 无有效的RefreshToken
- {
- err := client.UpdateCredential(context.Background(), false)
- asserts.Equal(ErrInvalidRefreshToken, err)
- client.Credential = nil
- err = client.UpdateCredential(context.Background(), false)
- asserts.Equal(ErrInvalidRefreshToken, err)
- }
-
- // 成功
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)),
- },
- })
- client.Request = clientMock
- client.Credential = &Credential{
- RefreshToken: "old_refresh_token",
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := client.UpdateCredential(context.Background(), false)
- clientMock.AssertExpectations(t)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential")
- asserts.True(ok)
- cacheCredential := cacheRes.(Credential)
- asserts.Equal("new_refresh_token", cacheCredential.RefreshToken)
- asserts.Equal("i am token", cacheCredential.AccessToken)
- }
-
- // OneDrive返回错误
- {
- cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_")
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- client.Endpoints.OAuthEndpoints.token.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 400,
- Body: ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)),
- },
- })
- client.Request = clientMock
- client.Credential = &Credential{
- RefreshToken: "old_refresh_token",
- }
- err := client.UpdateCredential(context.Background(), false)
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- }
-
- // 从缓存中获取
- {
- cache.Set("onedrive_TestClient_UpdateCredential", Credential{
- ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(),
- AccessToken: "AccessToken",
- RefreshToken: "RefreshToken",
- }, 0)
- client.Credential = &Credential{
- RefreshToken: "old_refresh_token",
- }
- err := client.UpdateCredential(context.Background(), false)
- asserts.NoError(err)
- asserts.Equal("AccessToken", client.Credential.AccessToken)
- asserts.Equal("RefreshToken", client.Credential.RefreshToken)
- }
-
- // 无需重新获取
- {
- client.Credential = &Credential{
- RefreshToken: "old_refresh_token",
- AccessToken: "AccessToken2",
- ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(),
- }
- err := client.UpdateCredential(context.Background(), false)
- asserts.NoError(err)
- asserts.Equal("AccessToken2", client.Credential.AccessToken)
- }
-
- // slave failed
- {
- mockController := &controllermock.SlaveControllerMock{}
- mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error"))
- client.ClusterController = mockController
- err := client.UpdateCredential(context.Background(), true)
- asserts.Error(err)
- }
-
- // slave success
- {
- mockController := &controllermock.SlaveControllerMock{}
- mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil)
- client.ClusterController = mockController
- err := client.UpdateCredential(context.Background(), true)
- asserts.NoError(err)
- asserts.Equal("AccessToken3", client.Credential.AccessToken)
- }
-}
diff --git a/pkg/filesystem/driver/oss/handler.go b/pkg/filesystem/driver/oss/handler.go
index 2ae50a3..ccccbd2 100644
--- a/pkg/filesystem/driver/oss/handler.go
+++ b/pkg/filesystem/driver/oss/handler.go
@@ -172,14 +172,24 @@ func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([
if err != nil {
continue
}
- res = append(res, response.Object{
- Name: path.Base(object.Key),
- Source: object.Key,
- RelativePath: filepath.ToSlash(rel),
- Size: uint64(object.Size),
- IsDir: false,
- LastModify: object.LastModified,
- })
+ if strings.HasSuffix(object.Key, "/") {
+ res = append(res, response.Object{
+ Name: path.Base(object.Key),
+ RelativePath: filepath.ToSlash(rel),
+ Size: 0,
+ IsDir: true,
+ LastModify: time.Now(),
+ })
+ } else {
+ res = append(res, response.Object{
+ Name: path.Base(object.Key),
+ Source: object.Key,
+ RelativePath: filepath.ToSlash(rel),
+ Size: uint64(object.Size),
+ IsDir: false,
+ LastModify: object.LastModified,
+ })
+ }
}
return res, nil
diff --git a/pkg/filesystem/driver/remote/client_test.go b/pkg/filesystem/driver/remote/client_test.go
deleted file mode 100644
index c195521..0000000
--- a/pkg/filesystem/driver/remote/client_test.go
+++ /dev/null
@@ -1,262 +0,0 @@
-package remote
-
-import (
- "context"
- "errors"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
-)
-
-func TestNewClient(t *testing.T) {
- a := assert.New(t)
- policy := &model.Policy{}
-
- // 无法解析服务端url
- {
- policy.Server = string([]byte{0x7f})
- c, err := NewClient(policy)
- a.Error(err)
- a.Nil(c)
- }
-
- // 成功
- {
- policy.Server = ""
- c, err := NewClient(policy)
- a.NoError(err)
- a.NotNil(c)
- }
-}
-
-func TestRemoteClient_Upload(t *testing.T) {
- a := assert.New(t)
- c, _ := NewClient(&model.Policy{})
-
- // 无法创建上传会话
- {
- clientMock := requestmock.RequestMock{}
- c.(*remoteClient).httpClient = &clientMock
- clientMock.On(
- "Request",
- "PUT",
- "upload",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error"),
- })
- err := c.Upload(context.Background(), &fsctx.FileStream{})
- a.Error(err)
- a.Contains(err.Error(), "error")
- clientMock.AssertExpectations(t)
- }
-
- // 分片上传失败,成功删除上传会话
- {
- cache.Set("setting_chunk_retries", "1", 0)
- clientMock := requestmock.RequestMock{}
- c.(*remoteClient).httpClient = &clientMock
- clientMock.On(
- "Request",
- "PUT",
- "upload",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error"),
- })
- clientMock.On(
- "Request",
- "DELETE",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- err := c.Upload(context.Background(), &fsctx.FileStream{})
- a.Error(err)
- a.Contains(err.Error(), "error")
- clientMock.AssertExpectations(t)
- }
-
- // 分片上传失败,无法删除上传会话
- {
- cache.Set("setting_chunk_retries", "1", 0)
- clientMock := requestmock.RequestMock{}
- c.(*remoteClient).httpClient = &clientMock
- clientMock.On(
- "Request",
- "PUT",
- "upload",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error"),
- })
- clientMock.On(
- "Request",
- "DELETE",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: errors.New("error2"),
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- err := c.Upload(context.Background(), &fsctx.FileStream{})
- a.Error(err)
- a.Contains(err.Error(), "error")
- clientMock.AssertExpectations(t)
- }
-
- // 成功
- {
- cache.Set("setting_chunk_retries", "1", 0)
- clientMock := requestmock.RequestMock{}
- c.(*remoteClient).httpClient = &clientMock
- clientMock.On(
- "Request",
- "PUT",
- "upload",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- err := c.Upload(context.Background(), &fsctx.FileStream{})
- a.NoError(err)
- clientMock.AssertExpectations(t)
- }
-}
-
-func TestRemoteClient_CreateUploadSessionFailed(t *testing.T) {
- a := assert.New(t)
- c, _ := NewClient(&model.Policy{})
-
- clientMock := requestmock.RequestMock{}
- c.(*remoteClient).httpClient = &clientMock
- clientMock.On(
- "Request",
- "PUT",
- "upload",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)),
- },
- })
- err := c.Upload(context.Background(), &fsctx.FileStream{})
- a.Error(err)
- a.Contains(err.Error(), "error")
- clientMock.AssertExpectations(t)
-}
-
-func TestRemoteClient_UploadChunkFailed(t *testing.T) {
- a := assert.New(t)
- c, _ := NewClient(&model.Policy{})
-
- clientMock := requestmock.RequestMock{}
- c.(*remoteClient).httpClient = &clientMock
- clientMock.On(
- "Request",
- "POST",
- testMock.Anything,
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":500,"msg":"error"}`)),
- },
- })
- err := c.(*remoteClient).uploadChunk(context.Background(), "", 0, strings.NewReader(""), false, 0)
- a.Error(err)
- a.Contains(err.Error(), "error")
- clientMock.AssertExpectations(t)
-}
-
-func TestRemoteClient_GetUploadURL(t *testing.T) {
- a := assert.New(t)
- c, _ := NewClient(&model.Policy{})
-
- // url 解析失败
- {
- c.(*remoteClient).policy.Server = string([]byte{0x7f})
- res, sign, err := c.GetUploadURL(0, "")
- a.Error(err)
- a.Empty(res)
- a.Empty(sign)
- }
-
- // 成功
- {
- c.(*remoteClient).policy.Server = ""
- res, sign, err := c.GetUploadURL(0, "")
- a.NoError(err)
- a.NotEmpty(res)
- a.NotEmpty(sign)
- }
-}
diff --git a/pkg/filesystem/driver/remote/handler_test.go b/pkg/filesystem/driver/remote/handler_test.go
deleted file mode 100644
index 4f6f239..0000000
--- a/pkg/filesystem/driver/remote/handler_test.go
+++ /dev/null
@@ -1,460 +0,0 @@
-package remote
-
-import (
- "context"
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/remoteclientmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-func TestNewDriver(t *testing.T) {
- a := assert.New(t)
-
- // remoteClient 初始化失败
- {
- d, err := NewDriver(&model.Policy{Server: string([]byte{0x7f})})
- a.Error(err)
- a.Nil(d)
- }
-
- // 成功
- {
- d, err := NewDriver(&model.Policy{})
- a.NoError(err)
- a.NotNil(d)
- }
-}
-
-func TestHandler_Source(t *testing.T) {
- asserts := assert.New(t)
- auth.General = auth.HMACAuth{SecretKey: []byte("test")}
-
- // 无法获取上下文
- {
- handler := Driver{
- Policy: &model.Policy{Server: "/"},
- AuthInstance: auth.HMACAuth{},
- }
- ctx := context.Background()
- res, err := handler.Source(ctx, "", 0, true, 0)
- asserts.NoError(err)
- asserts.NotEmpty(res)
- }
-
- // 成功
- {
- handler := Driver{
- Policy: &model.Policy{Server: "/"},
- AuthInstance: auth.HMACAuth{},
- }
- file := model.File{
- SourceName: "1.txt",
- }
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
- res, err := handler.Source(ctx, "", 10, true, 0)
- asserts.NoError(err)
- asserts.Contains(res, "api/v3/slave/download/0")
- }
-
- // 成功 自定义CDN
- {
- handler := Driver{
- Policy: &model.Policy{Server: "/", BaseURL: "https://cqu.edu.cn"},
- AuthInstance: auth.HMACAuth{},
- }
- file := model.File{
- SourceName: "1.txt",
- }
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
- res, err := handler.Source(ctx, "", 10, true, 0)
- asserts.NoError(err)
- asserts.Contains(res, "api/v3/slave/download/0")
- asserts.Contains(res, "https://cqu.edu.cn")
- }
-
- // 解析失败 自定义CDN
- {
- handler := Driver{
- Policy: &model.Policy{Server: "/", BaseURL: string([]byte{0x7f})},
- AuthInstance: auth.HMACAuth{},
- }
- file := model.File{
- SourceName: "1.txt",
- }
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
- res, err := handler.Source(ctx, "", 10, true, 0)
- asserts.Error(err)
- asserts.Empty(res)
- }
-
- // 成功 预览
- {
- handler := Driver{
- Policy: &model.Policy{Server: "/"},
- AuthInstance: auth.HMACAuth{},
- }
- file := model.File{
- SourceName: "1.txt",
- }
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
- res, err := handler.Source(ctx, "", 10, false, 0)
- asserts.NoError(err)
- asserts.Contains(res, "api/v3/slave/source/0")
- }
-}
-
-type ClientMock struct {
- testMock.Mock
-}
-
-func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
- args := m.Called(method, target, body, opts)
- return args.Get(0).(*request.Response)
-}
-
-func TestHandler_Delete(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- SecretKey: "test",
- Server: "http://test.com",
- },
- AuthInstance: auth.HMACAuth{},
- }
- ctx := context.Background()
- cache.Set("setting_slave_api_timeout", "60", 0)
-
- // 成功
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test.com/api/v3/slave/delete",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- handler.Client = clientMock
- failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"})
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.Len(failed, 0)
-
- }
-
- // 结果解析失败
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test.com/api/v3/slave/delete",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":203}`)),
- },
- })
- handler.Client = clientMock
- failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"})
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Len(failed, 2)
- }
-
- // 一个失败
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test.com/api/v3/slave/delete",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":203,"data":"{\"files\":[\"1\"]}"}`)),
- },
- })
- handler.Client = clientMock
- failed, err := handler.Delete(ctx, []string{"/test1.txt", "test2.txt"})
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Len(failed, 1)
- }
-}
-
-func TestDriver_List(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- SecretKey: "test",
- Server: "http://test.com",
- },
- AuthInstance: auth.HMACAuth{},
- }
- ctx := context.Background()
- cache.Set("setting_slave_api_timeout", "60", 0)
-
- // 成功
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test.com/api/v3/slave/list",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"data":"[{}]"}`)),
- },
- })
- handler.Client = clientMock
- res, err := handler.List(ctx, "/", true)
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- asserts.Len(res, 1)
-
- }
-
- // 响应解析失败
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test.com/api/v3/slave/list",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0,"data":"233"}`)),
- },
- })
- handler.Client = clientMock
- res, err := handler.List(ctx, "/", true)
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Len(res, 0)
- }
-
- // 从机返回错误
- {
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test.com/api/v3/slave/list",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":203}`)),
- },
- })
- handler.Client = clientMock
- res, err := handler.List(ctx, "/", true)
- clientMock.AssertExpectations(t)
- asserts.Error(err)
- asserts.Len(res, 0)
- }
-}
-
-func TestHandler_Get(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- SecretKey: "test",
- Server: "http://test.com",
- },
- AuthInstance: auth.HMACAuth{},
- }
- ctx := context.Background()
-
- // 成功
- {
- ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{})
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- nil,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- handler.Client = clientMock
- resp, err := handler.Get(ctx, "/test.txt")
- clientMock.AssertExpectations(t)
- asserts.NotNil(resp)
- asserts.NoError(err)
- }
-
- // 请求失败
- {
- ctx = context.WithValue(ctx, fsctx.UserCtx, model.User{})
- clientMock := ClientMock{}
- clientMock.On(
- "Request",
- "GET",
- testMock.Anything,
- nil,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 404,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- handler.Client = clientMock
- resp, err := handler.Get(ctx, "/test.txt")
- clientMock.AssertExpectations(t)
- asserts.Nil(resp)
- asserts.Error(err)
- }
-}
-
-func TestHandler_Put(t *testing.T) {
- a := assert.New(t)
- handler, _ := NewDriver(&model.Policy{
- Type: "remote",
- SecretKey: "test",
- Server: "http://test.com",
- })
- clientMock := &remoteclientmock.RemoteClientMock{}
- handler.uploadClient = clientMock
- clientMock.On("Upload", testMock.Anything, testMock.Anything).Return(errors.New("error"))
- a.Error(handler.Put(context.Background(), &fsctx.FileStream{}))
- clientMock.AssertExpectations(t)
-}
-
-func TestHandler_Thumb(t *testing.T) {
- asserts := assert.New(t)
- handler := Driver{
- Policy: &model.Policy{
- Type: "remote",
- SecretKey: "test",
- Server: "http://test.com",
- OptionsSerialized: model.PolicyOption{
- ThumbExts: []string{"txt"},
- },
- },
- AuthInstance: auth.HMACAuth{},
- }
- file := &model.File{
- Name: "1.txt",
- SourceName: "1.txt",
- }
- ctx := context.Background()
- asserts.NoError(cache.Set("setting_preview_timeout", "60", 0))
-
- // no error
- {
- resp, err := handler.Thumb(ctx, file)
- asserts.NoError(err)
- asserts.True(resp.Redirect)
- }
-
- // ext not support
- {
- file.Name = "1.jpg"
- resp, err := handler.Thumb(ctx, file)
- asserts.ErrorIs(err, driver.ErrorThumbNotSupported)
- asserts.Nil(resp)
- }
-}
-
-func TestHandler_Token(t *testing.T) {
- a := assert.New(t)
- handler, _ := NewDriver(&model.Policy{})
-
- // 无法创建上传会话
- {
- clientMock := &remoteclientmock.RemoteClientMock{}
- handler.uploadClient = clientMock
- clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(errors.New("error"))
- res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{})
- a.Error(err)
- a.Contains(err.Error(), "error")
- a.Nil(res)
- clientMock.AssertExpectations(t)
- }
-
- // 无法创建上传地址
- {
- clientMock := &remoteclientmock.RemoteClientMock{}
- handler.uploadClient = clientMock
- clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(nil)
- clientMock.On("GetUploadURL", int64(10), "").Return("", "", errors.New("error"))
- res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{})
- a.Error(err)
- a.Contains(err.Error(), "error")
- a.Nil(res)
- clientMock.AssertExpectations(t)
- }
-
- // 成功
- {
- clientMock := &remoteclientmock.RemoteClientMock{}
- handler.uploadClient = clientMock
- clientMock.On("CreateUploadSession", testMock.Anything, testMock.Anything, int64(10), false).Return(nil)
- clientMock.On("GetUploadURL", int64(10), "").Return("1", "2", nil)
- res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{})
- a.NoError(err)
- a.NotNil(res)
- a.Equal("1", res.UploadURLs[0])
- a.Equal("2", res.Credential)
- clientMock.AssertExpectations(t)
- }
-}
-
-func TestDriver_CancelToken(t *testing.T) {
- a := assert.New(t)
- handler, _ := NewDriver(&model.Policy{})
-
- clientMock := &remoteclientmock.RemoteClientMock{}
- handler.uploadClient = clientMock
- clientMock.On("DeleteUploadSession", testMock.Anything, "key").Return(errors.New("error"))
- err := handler.CancelToken(context.Background(), &serializer.UploadSession{Key: "key"})
- a.Error(err)
- a.Contains(err.Error(), "error")
- clientMock.AssertExpectations(t)
-}
diff --git a/pkg/filesystem/errors.go b/pkg/filesystem/errors.go
index d267038..5c3f231 100644
--- a/pkg/filesystem/errors.go
+++ b/pkg/filesystem/errors.go
@@ -2,7 +2,6 @@ package filesystem
import (
"errors"
-
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
diff --git a/pkg/filesystem/file_test.go b/pkg/filesystem/file_test.go
deleted file mode 100644
index 66f3444..0000000
--- a/pkg/filesystem/file_test.go
+++ /dev/null
@@ -1,669 +0,0 @@
-package filesystem
-
-import (
- "context"
- "errors"
- "os"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFileSystem_AddFile(t *testing.T) {
- asserts := assert.New(t)
- file := fsctx.FileStream{
- Size: 5,
- Name: "1.png",
- SavePath: "/Uploads/1_sad.png",
- }
- folder := model.Folder{
- Model: gorm.Model{
- ID: 1,
- },
- }
- fs := FileSystem{
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- Policy: model.Policy{
- Type: "cos",
- Model: gorm.Model{
- ID: 1,
- },
- },
- },
- Policy: &model.Policy{Type: "cos"},
- }
-
- _, err := fs.AddFile(context.Background(), &folder, &file)
-
- asserts.Error(err)
-
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- f, err := fs.AddFile(context.Background(), &folder, &file)
-
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("/Uploads/1_sad.png", f.SourceName)
-
- // 前置钩子执行失败
- {
- hookExecuted := false
- fs.Use("BeforeAddFile", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- hookExecuted = true
- return errors.New("error")
- })
- f, err := fs.AddFile(context.Background(), &folder, &file)
- asserts.Error(err)
- asserts.Nil(f)
- asserts.True(hookExecuted)
- }
-
- // 后置钩子执行失败
- {
- hookExecuted := false
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- fs.Hooks = map[string][]Hook{}
- fs.Use("AfterValidateFailed", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- hookExecuted = true
- return errors.New("error")
- })
- f, err := fs.AddFile(context.Background(), &folder, &file)
- asserts.Error(err)
- asserts.Nil(f)
- asserts.True(hookExecuted)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFileSystem_GetContent(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- Policy: model.Policy{
- Model: gorm.Model{
- ID: 1,
- },
- },
- },
- }
-
- // 文件不存在
- rs, err := fs.GetContent(ctx, 1)
- asserts.Equal(ErrObjectNotExist, err)
- asserts.Nil(rs)
- fs.CleanTargets()
-
- // 未知存储策略
- file, err := os.Create(util.RelativePath("TestFileSystem_GetContent.txt"))
- asserts.NoError(err)
- _ = file.Close()
-
- cache.Deletes([]string{"1"}, "policy_")
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent.txt", 1))
- mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "unknown"))
-
- rs, err = fs.GetContent(ctx, 1)
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- fs.CleanTargets()
-
- // 打开文件失败
- cache.Deletes([]string{"1"}, "policy_")
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id"}).AddRow(1, "TestFileSystem_GetContent2.txt", 1))
- mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type", "source_name"}).AddRow(1, "local", "not exist"))
-
- rs, err = fs.GetContent(ctx, 1)
- asserts.Equal(serializer.CodeIOFailed, err.(serializer.AppError).Code)
- asserts.NoError(mock.ExpectationsWereMet())
- fs.CleanTargets()
-
- // 打开成功
- cache.Deletes([]string{"1"}, "policy_")
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "source_name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetContent.txt", 1, "TestFileSystem_GetContent.txt"))
- mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local"))
-
- rs, err = fs.GetContent(ctx, 1)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestFileSystem_GetDownloadContent(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- Policy: model.Policy{
- Model: gorm.Model{
- ID: 599,
- },
- },
- },
- }
- file, err := os.Create(util.RelativePath("TestFileSystem_GetDownloadContent.txt"))
- asserts.NoError(err)
- _ = file.Close()
-
- cache.Deletes([]string{"599"}, "policy_")
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt"))
- mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local"))
-
- // 无限速
- cache.Deletes([]string{"599"}, "policy_")
- _, err = fs.GetDownloadContent(ctx, 1)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- fs.CleanTargets()
-
- // 有限速
- cache.Deletes([]string{"599"}, "policy_")
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id", "source_name"}).AddRow(1, "TestFileSystem_GetDownloadContent.txt", 599, "TestFileSystem_GetDownloadContent.txt"))
- mock.ExpectQuery("SELECT(.+)poli(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(1, "local"))
-
- fs.User.Group.SpeedLimit = 1
- _, err = fs.GetDownloadContent(ctx, 1)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestFileSystem_GroupFileByPolicy(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- files := []model.File{
- model.File{
- PolicyID: 1,
- Name: "1_1.txt",
- },
- model.File{
- PolicyID: 2,
- Name: "2_1.txt",
- },
- model.File{
- PolicyID: 3,
- Name: "3_1.txt",
- },
- model.File{
- PolicyID: 2,
- Name: "2_2.txt",
- },
- model.File{
- PolicyID: 1,
- Name: "1_2.txt",
- },
- }
- fs := FileSystem{}
- policyGroup := fs.GroupFileByPolicy(ctx, files)
- asserts.Equal(map[uint][]*model.File{
- 1: {&files[0], &files[4]},
- 2: {&files[1], &files[3]},
- 3: {&files[2]},
- }, policyGroup)
-}
-
-func TestFileSystem_deleteGroupedFile(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{}
- files := []model.File{
- {
- PolicyID: 1,
- Name: "1_1.txt",
- SourceName: "1_1.txt",
- Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"},
- },
- {
- PolicyID: 2,
- Name: "2_1.txt",
- SourceName: "2_1.txt",
- Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"},
- },
- {
- PolicyID: 3,
- Name: "3_1.txt",
- SourceName: "3_1.txt",
- Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"},
- },
- {
- PolicyID: 2,
- Name: "2_2.txt",
- SourceName: "2_2.txt",
- Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"},
- },
- {
- PolicyID: 1,
- Name: "1_2.txt",
- SourceName: "1_2.txt",
- Policy: model.Policy{Model: gorm.Model{ID: 1}, Type: "local"},
- },
- }
-
- // 全部不存在
- {
- failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
- asserts.Equal(map[uint][]string{
- 1: {},
- 2: {},
- 3: {},
- }, failed)
- }
- // 部分不存在
- {
- file, err := os.Create(util.RelativePath("1_1.txt"))
- asserts.NoError(err)
- _ = file.Close()
- failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
- asserts.Equal(map[uint][]string{
- 1: {},
- 2: {},
- 3: {},
- }, failed)
- }
- // 部分失败,包含整组未知存储策略导致的失败
- {
- file, err := os.Create(util.RelativePath("1_1.txt"))
- asserts.NoError(err)
- _ = file.Close()
-
- files[1].Policy.Type = "unknown"
- files[3].Policy.Type = "unknown"
- failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
- asserts.Equal(map[uint][]string{
- 1: {},
- 2: {"2_1.txt", "2_2.txt"},
- 3: {},
- }, failed)
- }
- // 包含上传会话文件
- {
- sessionID := "session"
- cache.Set(UploadSessionCachePrefix+sessionID, serializer.UploadSession{Key: sessionID}, 0)
- files[1].Policy.Type = "local"
- files[3].Policy.Type = "local"
- files[0].UploadSessionID = &sessionID
- failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
- asserts.Equal(map[uint][]string{
- 1: {},
- 2: {},
- 3: {},
- }, failed)
- _, ok := cache.Get(UploadSessionCachePrefix + sessionID)
- asserts.False(ok)
- }
-
- // 包含缩略图
- {
- files[0].MetadataSerialized = map[string]string{
- model.ThumbSidecarMetadataKey: "1",
- }
- failed := fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
- asserts.Equal(map[uint][]string{
- 1: {},
- 2: {},
- 3: {},
- }, failed)
- }
-}
-
-func TestFileSystem_GetSource(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- auth.General = auth.HMACAuth{SecretKey: []byte("123")}
-
- // 正常
- {
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
- // 清空缓存
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(2, 35, "1.txt"),
- )
- // 查找上传策略
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}).
- AddRow(35, "local", true),
- )
-
- sourceURL, err := fs.GetSource(ctx, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotEmpty(sourceURL)
- fs.CleanTargets()
- }
-
- // 文件不存在
- {
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
- // 清空缓存
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}),
- )
-
- sourceURL, err := fs.GetSource(ctx, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(ErrObjectNotExist.Code, err.(serializer.AppError).Code)
- asserts.Empty(sourceURL)
- fs.CleanTargets()
- }
-
- // 未知上传策略
- {
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
- // 清空缓存
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(2, 36, "1.txt"),
- )
- // 查找上传策略
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}).
- AddRow(36, "?", true),
- )
-
- sourceURL, err := fs.GetSource(ctx, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Empty(sourceURL)
- fs.CleanTargets()
- }
-
- // 不允许获取外链
- {
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
- // 清空缓存
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
- AddRow(2, 37, "1.txt"),
- )
- // 查找上传策略
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}).
- AddRow(37, "local", false),
- )
-
- sourceURL, err := fs.GetSource(ctx, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(serializer.CodePolicyNotAllowed, err.(serializer.AppError).Code)
- asserts.Empty(sourceURL)
- fs.CleanTargets()
- }
-}
-
-func TestFileSystem_GetDownloadURL(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- }
- auth.General = auth.HMACAuth{SecretKey: []byte("123")}
-
- // 正常
- {
- err := cache.Deletes([]string{"35"}, "policy_")
- cache.Set("setting_download_timeout", "20", 0)
- cache.Set("setting_siteURL", "https://cloudreve.org", 0)
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35))
- // 查找上传策略
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}).
- AddRow(35, "local", true),
- )
- // 相关设置
- downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotEmpty(downloadURL)
- fs.CleanTargets()
- }
-
- // 文件不存在
- {
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- err = cache.Deletes([]string{"35"}, "policy_")
- err = cache.Deletes([]string{"download_timeout"}, "setting_")
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}))
-
- downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Empty(downloadURL)
- fs.CleanTargets()
- }
-
- // 未知存储策略
- {
- err := cache.Deletes([]string{"siteURL"}, "setting_")
- err = cache.Deletes([]string{"35"}, "policy_")
- err = cache.Deletes([]string{"download_timeout"}, "setting_")
- asserts.NoError(err)
- // 查找文件
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policy_id"}).AddRow(1, "1.txt", 35))
- // 查找上传策略
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "type", "is_origin_link_enable"}).
- AddRow(35, "unknown", true),
- )
-
- downloadURL, err := fs.GetDownloadURL(ctx, 1, "download_timeout")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Empty(downloadURL)
- fs.CleanTargets()
- }
-}
-
-func TestFileSystem_GetPhysicalFileContent(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{},
- }
-
- // 文件不存在
- {
- rs, err := fs.GetPhysicalFileContent(ctx, "not_exist.txt")
- asserts.Error(err)
- asserts.Nil(rs)
- }
-
- // 成功
- {
- testFile, err := os.Create(util.RelativePath("GetPhysicalFileContent.txt"))
- asserts.NoError(err)
- asserts.NoError(testFile.Close())
-
- rs, err := fs.GetPhysicalFileContent(ctx, "GetPhysicalFileContent.txt")
- asserts.NoError(err)
- asserts.NoError(rs.Close())
- asserts.NotNil(rs)
- }
-}
-
-func TestFileSystem_Preview(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
-
- // 文件不存在
- {
- fs := FileSystem{
- User: &model.User{},
- }
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- resp, err := fs.Preview(ctx, 1, false)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(resp)
- }
-
- // 直接返回文件内容,找不到文件
- {
- fs := FileSystem{
- User: &model.User{},
- }
- fs.FileTarget = []model.File{
- {
- SourceName: "tests/no.txt",
- PolicyID: 1,
- Policy: model.Policy{
- Model: gorm.Model{ID: 1},
- Type: "local",
- },
- },
- }
- resp, err := fs.Preview(ctx, 1, false)
- asserts.Error(err)
- asserts.Nil(resp)
- }
-
- // 直接返回文件内容
- {
- fs := FileSystem{
- User: &model.User{},
- }
- fs.FileTarget = []model.File{
- {
- SourceName: "tests/file1.txt",
- PolicyID: 1,
- Policy: model.Policy{
- Model: gorm.Model{ID: 1},
- Type: "local",
- },
- },
- }
- resp, err := fs.Preview(ctx, 1, false)
- asserts.Error(err)
- asserts.Nil(resp)
- }
-
- // 需要重定向,成功
- {
- fs := FileSystem{
- User: &model.User{},
- }
- fs.FileTarget = []model.File{
- {
- SourceName: "tests/file1.txt",
- PolicyID: 1,
- Policy: model.Policy{
- Model: gorm.Model{ID: 1},
- Type: "remote",
- },
- },
- }
- asserts.NoError(cache.Set("setting_preview_timeout", "233", 0))
- resp, err := fs.Preview(ctx, 1, false)
- asserts.NoError(err)
- asserts.NotNil(resp)
- asserts.True(resp.Redirect)
- }
-
- // 文本文件,大小超出限制
- {
- fs := FileSystem{
- User: &model.User{},
- }
- fs.FileTarget = []model.File{
- {
- SourceName: "tests/file1.txt",
- PolicyID: 1,
- Policy: model.Policy{
- Model: gorm.Model{ID: 1},
- Type: "remote",
- },
- Size: 11,
- },
- }
- asserts.NoError(cache.Set("setting_maxEditSize", "10", 0))
- resp, err := fs.Preview(ctx, 1, true)
- asserts.Equal(ErrFileSizeTooBig, err)
- asserts.Nil(resp)
- }
-}
-
-func TestFileSystem_ResetFileIDIfNotExist(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.WithValue(context.Background(), fsctx.LimitParentCtx, &model.Folder{Model: gorm.Model{ID: 1}})
- fs := FileSystem{
- FileTarget: []model.File{
- {
- FolderID: 2,
- },
- },
- }
- asserts.Equal(ErrObjectNotExist, fs.resetFileIDIfNotExist(ctx, 1))
-}
-
-func TestFileSystem_Search(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := &FileSystem{
- User: &model.User{},
- }
- fs.User.ID = 1
-
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, "k1", "k2").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- res, err := fs.Search(ctx, "k1", "k2")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Len(res, 1)
-}
diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go
index 1e14fa8..d892ed0 100644
--- a/pkg/filesystem/filesystem.go
+++ b/pkg/filesystem/filesystem.go
@@ -3,6 +3,10 @@ package filesystem
import (
"errors"
"fmt"
+ "net/http"
+ "net/url"
+ "sync"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
@@ -22,9 +26,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
cossdk "github.com/tencentyun/cos-go-sdk-v5"
- "net/http"
- "net/url"
- "sync"
)
// FSPool 文件系统资源池
@@ -92,7 +93,7 @@ func (fs *FileSystem) reset() {
func NewFileSystem(user *model.User) (*FileSystem, error) {
fs := getEmptyFS()
fs.User = user
- fs.Policy = &fs.User.Policy
+ fs.Policy = user.GetPolicyID(nil)
// 分配存储策略适配器
err := fs.DispatchHandler()
@@ -122,70 +123,57 @@ func NewAnonymousFileSystem() (*FileSystem, error) {
// DispatchHandler 根据存储策略分配文件适配器
func (fs *FileSystem) DispatchHandler() error {
- if fs.Policy == nil {
- return errors.New("未设置存储策略")
+ handler, err := getNewPolicyHandler(fs.Policy)
+ fs.Handler = handler
+
+ return err
+}
+
+// getNewPolicyHandler 根据存储策略类型字段获取处理器
+func getNewPolicyHandler(policy *model.Policy) (driver.Handler, error) {
+ if policy == nil {
+ return nil, ErrUnknownPolicyType
}
- policyType := fs.Policy.Type
- currentPolicy := fs.Policy
- switch policyType {
+ switch policy.Type {
case "mock", "anonymous":
- return nil
+ return nil, nil
case "local":
- fs.Handler = local.Driver{
- Policy: currentPolicy,
- }
- return nil
+ return local.Driver{
+ Policy: policy,
+ }, nil
case "remote":
- handler, err := remote.NewDriver(currentPolicy)
- if err != nil {
- return err
- }
-
- fs.Handler = handler
+ return remote.NewDriver(policy)
case "qiniu":
- fs.Handler = qiniu.NewDriver(currentPolicy)
- return nil
+ return qiniu.NewDriver(policy), nil
case "oss":
- handler, err := oss.NewDriver(currentPolicy)
- fs.Handler = handler
- return err
+ return oss.NewDriver(policy)
case "upyun":
- fs.Handler = upyun.Driver{
- Policy: currentPolicy,
- }
- return nil
+ return upyun.Driver{
+ Policy: policy,
+ }, nil
case "onedrive":
- var odErr error
- fs.Handler, odErr = onedrive.NewDriver(currentPolicy)
- return odErr
+ return onedrive.NewDriver(policy)
case "cos":
- u, _ := url.Parse(currentPolicy.Server)
+ u, _ := url.Parse(policy.Server)
b := &cossdk.BaseURL{BucketURL: u}
- fs.Handler = cos.Driver{
- Policy: currentPolicy,
+ return cos.Driver{
+ Policy: policy,
Client: cossdk.NewClient(b, &http.Client{
Transport: &cossdk.AuthorizationTransport{
- SecretID: currentPolicy.AccessKey,
- SecretKey: currentPolicy.SecretKey,
+ SecretID: policy.AccessKey,
+ SecretKey: policy.SecretKey,
},
}),
HTTPClient: request.NewClient(),
- }
- return nil
+ }, nil
case "s3":
- handler, err := s3.NewDriver(currentPolicy)
- fs.Handler = handler
- return err
+ return s3.NewDriver(policy)
case "googledrive":
- handler, err := googledrive.NewDriver(currentPolicy)
- fs.Handler = handler
- return err
+ return googledrive.NewDriver(policy)
default:
- return ErrUnknownPolicyType
+ return nil, ErrUnknownPolicyType
}
-
- return nil
}
// NewFileSystemFromContext 从gin.Context创建文件系统
@@ -290,3 +278,18 @@ func (fs *FileSystem) CleanTargets() {
fs.FileTarget = fs.FileTarget[:0]
fs.DirTarget = fs.DirTarget[:0]
}
+
+// SetPolicyFromPath 根据给定路径尝试设定偏好存储策略
+func (fs *FileSystem) SetPolicyFromPath(filePath string) error {
+ _, parent := fs.getClosedParent(filePath)
+ // 尝试获取并重设存储策略
+ fs.Policy = fs.User.GetPolicyID(parent)
+ return fs.DispatchHandler()
+}
+
+// SetPolicyFromPreference 尝试设定偏好存储策略
+func (fs *FileSystem) SetPolicyFromPreference(preference uint) error {
+ // 尝试获取并重设存储策略
+ fs.Policy = fs.User.GetPolicyByPreference(preference)
+ return fs.DispatchHandler()
+}
diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go
deleted file mode 100644
index 8b7aae3..0000000
--- a/pkg/filesystem/filesystem_test.go
+++ /dev/null
@@ -1,299 +0,0 @@
-package filesystem
-
-import (
- "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "net/http/httptest"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/assert"
-
- "testing"
-)
-
-func TestNewFileSystem(t *testing.T) {
- asserts := assert.New(t)
- user := model.User{
- Policy: model.Policy{
- Type: "local",
- },
- }
-
- // 本地 成功
- fs, err := NewFileSystem(&user)
- asserts.NoError(err)
- asserts.NotNil(fs.Handler)
- asserts.IsType(local.Driver{}, fs.Handler)
- // 远程
- user.Policy.Type = "remote"
- fs, err = NewFileSystem(&user)
- asserts.NoError(err)
- asserts.NotNil(fs.Handler)
- asserts.IsType(&remote.Driver{}, fs.Handler)
-
- user.Policy.Type = "unknown"
- fs, err = NewFileSystem(&user)
- asserts.Error(err)
-}
-
-func TestNewFileSystemFromContext(t *testing.T) {
- asserts := assert.New(t)
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Set("user", &model.User{
- Policy: model.Policy{
- Type: "local",
- },
- })
- fs, err := NewFileSystemFromContext(c)
- asserts.NotNil(fs)
- asserts.NoError(err)
-
- c, _ = gin.CreateTestContext(httptest.NewRecorder())
- fs, err = NewFileSystemFromContext(c)
- asserts.Nil(fs)
- asserts.Error(err)
-}
-
-func TestDispatchHandler(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{
- User: &model.User{},
- Policy: &model.Policy{
- Type: "local",
- },
- }
-
- // 未指定,使用用户默认
- err := fs.DispatchHandler()
- asserts.NoError(err)
- asserts.IsType(local.Driver{}, fs.Handler)
-
- // 已指定,发生错误
- fs.Policy = &model.Policy{Type: "unknown"}
- err = fs.DispatchHandler()
- asserts.Error(err)
-
- fs.Policy = &model.Policy{Type: "mock"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "local"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "remote"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "qiniu"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "oss", Server: "https://s.com", BucketName: "1234"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "upyun"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "onedrive"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "cos"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-
- fs.Policy = &model.Policy{Type: "s3"}
- err = fs.DispatchHandler()
- asserts.NoError(err)
-}
-
-func TestNewFileSystemFromCallback(t *testing.T) {
- asserts := assert.New(t)
-
- // 用户上下文不存在
- {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- fs, err := NewFileSystemFromCallback(c)
- asserts.Nil(fs)
- asserts.Error(err)
- }
-
- // 找不到回调会话
- {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Set("user", &model.User{
- Policy: model.Policy{
- Type: "local",
- },
- })
- fs, err := NewFileSystemFromCallback(c)
- asserts.Nil(fs)
- asserts.Error(err)
- }
-
- // 成功
- {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Set("user", &model.User{
- Policy: model.Policy{
- Type: "local",
- },
- })
- c.Set(UploadSessionCtx, &serializer.UploadSession{Policy: model.Policy{Type: "local"}})
- fs, err := NewFileSystemFromCallback(c)
- asserts.NotNil(fs)
- asserts.NoError(err)
- }
-
-}
-
-func TestFileSystem_SetTargetFileByIDs(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- fs := &FileSystem{}
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt"))
- err := fs.SetTargetFileByIDs([]uint{1, 2})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(fs.FileTarget, 1)
- asserts.NoError(err)
- }
-
- // 未找到
- {
- fs := &FileSystem{}
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, 2).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- err := fs.SetTargetFileByIDs([]uint{1, 2})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(fs.FileTarget, 0)
- asserts.Error(err)
- }
-}
-
-func TestFileSystem_CleanTargets(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{
- FileTarget: []model.File{{}, {}},
- DirTarget: []model.Folder{{}, {}},
- }
-
- fs.CleanTargets()
- asserts.Len(fs.FileTarget, 0)
- asserts.Len(fs.DirTarget, 0)
-}
-
-func TestNewAnonymousFileSystem(t *testing.T) {
- asserts := assert.New(t)
-
- // 正常
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"}).AddRow(3, "游客", "[]"))
- fs, err := NewAnonymousFileSystem()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.Equal("游客", fs.User.Group.Name)
- }
-
- // 游客用户组不存在
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "policies"}))
- fs, err := NewAnonymousFileSystem()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(fs)
- }
-
- // 从机
- {
- conf.SystemConfig.Mode = "slave"
- fs, err := NewAnonymousFileSystem()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotNil(fs)
- asserts.NotNil(fs.Handler)
- }
-}
-
-func TestFileSystem_Recycle(t *testing.T) {
- fs := &FileSystem{
- User: &model.User{},
- Policy: &model.Policy{},
- FileTarget: []model.File{model.File{}},
- DirTarget: []model.Folder{model.Folder{}},
- Hooks: map[string][]Hook{"AfterUpload": []Hook{GenericAfterUpdate}},
- }
- fs.Recycle()
- newFS := getEmptyFS()
- if fs != newFS {
- t.Error("指针不一致")
- }
-}
-
-func TestFileSystem_SetTargetByInterface(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{}
-
- // 目录
- {
- asserts.NoError(fs.SetTargetByInterface(&model.Folder{}))
- asserts.Len(fs.DirTarget, 1)
- asserts.Len(fs.FileTarget, 0)
- }
-
- // 文件
- {
- asserts.NoError(fs.SetTargetByInterface(&model.File{}))
- asserts.Len(fs.DirTarget, 1)
- asserts.Len(fs.FileTarget, 1)
- }
-}
-
-func TestFileSystem_SwitchToSlaveHandler(t *testing.T) {
- a := assert.New(t)
- fs := FileSystem{
- User: &model.User{},
- }
- mockNode := &cluster.MasterNode{
- Model: &model.Node{},
- }
- fs.SwitchToSlaveHandler(mockNode)
- a.IsType(&slaveinmaster.Driver{}, fs.Handler)
-}
-
-func TestFileSystem_SwitchToShadowHandler(t *testing.T) {
- a := assert.New(t)
- fs := FileSystem{
- User: &model.User{},
- Policy: &model.Policy{},
- }
- mockNode := &cluster.MasterNode{
- Model: &model.Node{},
- }
-
- // local to remote
- {
- fs.Policy.Type = "local"
- fs.SwitchToShadowHandler(mockNode, "", "")
- a.IsType(&masterinslave.Driver{}, fs.Handler)
- }
-
- // onedrive
- {
- fs.Policy.Type = "onedrive"
- fs.SwitchToShadowHandler(mockNode, "", "")
- a.IsType(&masterinslave.Driver{}, fs.Handler)
- }
-}
diff --git a/pkg/filesystem/fsctx/stream_test.go b/pkg/filesystem/fsctx/stream_test.go
deleted file mode 100644
index 1ef6e1f..0000000
--- a/pkg/filesystem/fsctx/stream_test.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package fsctx
-
-import (
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
- "io"
- "io/ioutil"
- "os"
- "strings"
- "testing"
-)
-
-func TestFileStream_Read(t *testing.T) {
- asserts := assert.New(t)
- file := FileStream{
- File: ioutil.NopCloser(strings.NewReader("123")),
- }
- var p = make([]byte, 3)
- {
- n, err := file.Read(p)
- asserts.Equal(3, n)
- asserts.NoError(err)
- }
-}
-
-func TestFileStream_Close(t *testing.T) {
- asserts := assert.New(t)
- {
- file := FileStream{
- File: ioutil.NopCloser(strings.NewReader("123")),
- }
- err := file.Close()
- asserts.NoError(err)
- }
-
- {
- file := FileStream{}
- err := file.Close()
- asserts.NoError(err)
- }
-}
-
-func TestFileStream_Seek(t *testing.T) {
- asserts := assert.New(t)
- f, _ := os.CreateTemp("", "*")
- defer func() {
- f.Close()
- os.Remove(f.Name())
- }()
- {
- file := FileStream{
- File: f,
- Seeker: f,
- }
- res, err := file.Seek(0, io.SeekStart)
- asserts.NoError(err)
- asserts.EqualValues(0, res)
- }
-
- {
- file := FileStream{}
- res, err := file.Seek(0, io.SeekStart)
- asserts.Error(err)
- asserts.EqualValues(0, res)
- }
-}
-
-func TestFileStream_Info(t *testing.T) {
- a := assert.New(t)
- file := FileStream{}
- a.NotNil(file.Info())
-
- file.SetSize(10)
- a.EqualValues(10, file.Info().Size)
-
- file.SetModel(&model.File{})
- a.NotNil(file.Info().Model)
-}
diff --git a/pkg/filesystem/hooks.go b/pkg/filesystem/hooks.go
index a2f9ed5..2a4c5d5 100644
--- a/pkg/filesystem/hooks.go
+++ b/pkg/filesystem/hooks.go
@@ -2,6 +2,12 @@ package filesystem
import (
"context"
+ "io/ioutil"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
@@ -9,11 +15,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
- "io/ioutil"
- "net/http"
- "strconv"
- "strings"
- "time"
)
// Hook 钩子函数
@@ -222,6 +223,21 @@ func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.Fi
return nil
}
+// HookGenerateThumb 生成缩略图
+// func HookGenerateThumb(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
+// // 异步尝试生成缩略图
+// fileMode := fileHeader.Info().Model.(*model.File)
+// if fs.Policy.IsThumbGenerateNeeded() {
+// fs.recycleLock.Lock()
+// go func() {
+// defer fs.recycleLock.Unlock()
+// _, _ = fs.Handler.Delete(ctx, []string{fileMode.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
+// fs.GenerateThumbnail(ctx, fileMode)
+// }()
+// }
+// return nil
+// }
+
// HookClearFileHeaderSize 将FileHeader大小设定为0
func HookClearFileHeaderSize(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileHeader.SetSize(0)
diff --git a/pkg/filesystem/hooks_test.go b/pkg/filesystem/hooks_test.go
deleted file mode 100644
index cc660ce..0000000
--- a/pkg/filesystem/hooks_test.go
+++ /dev/null
@@ -1,708 +0,0 @@
-package filesystem
-
-import (
- "context"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-func TestGenericBeforeUpload(t *testing.T) {
- asserts := assert.New(t)
- file := &fsctx.FileStream{
- Size: 5,
- Name: "1.txt",
- }
- ctx := context.Background()
- cache.Set("pack_size_0", uint64(0), 0)
- fs := FileSystem{
- User: &model.User{
- Storage: 0,
- Group: model.Group{
- MaxStorage: 11,
- },
- },
- Policy: &model.Policy{
- MaxSize: 4,
- OptionsSerialized: model.PolicyOption{
- FileType: []string{"txt"},
- },
- },
- }
-
- asserts.Error(HookValidateFile(ctx, &fs, file))
-
- file.Size = 1
- file.Name = "1"
- asserts.Error(HookValidateFile(ctx, &fs, file))
-
- file.Name = "1.txt"
- asserts.NoError(HookValidateFile(ctx, &fs, file))
-
- file.Name = "1.t/xt"
- asserts.Error(HookValidateFile(ctx, &fs, file))
-}
-
-func TestGenericAfterUploadCanceled(t *testing.T) {
- asserts := assert.New(t)
- file := &fsctx.FileStream{
- Size: 5,
- Name: "TestGenericAfterUploadCanceled",
- SavePath: "TestGenericAfterUploadCanceled",
- }
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{},
- }
-
- // 成功
- {
- mockHandler := &FileHeaderMock{}
- fs.Handler = mockHandler
- mockHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, nil)
- err := HookDeleteTempFile(ctx, &fs, file)
- asserts.NoError(err)
- mockHandler.AssertExpectations(t)
- }
-
- // 失败
- {
- mockHandler := &FileHeaderMock{}
- fs.Handler = mockHandler
- mockHandler.On("Delete", testMock.Anything, testMock.Anything).Return([]string{}, errors.New(""))
- err := HookDeleteTempFile(ctx, &fs, file)
- asserts.NoError(err)
- mockHandler.AssertExpectations(t)
- }
-
-}
-
-func TestGenericAfterUpload(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- },
- Policy: &model.Policy{},
- }
-
- ctx := context.Background()
- file := &fsctx.FileStream{
- VirtualPath: "/我的文件",
- Name: "test.txt",
- }
-
- // 正常
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files").
- WithArgs(1, "我的文件").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("我的文件", 1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found"))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- err := GenericAfterUpload(ctx, &fs, file)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 文件已存在
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files").
- WithArgs(1, "我的文件").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("我的文件", 1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(
- mock.NewRows([]string{"name"}).AddRow("test.txt"),
- )
- err = GenericAfterUpload(ctx, &fs, file)
- asserts.Equal(ErrFileExisted, err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 文件已存在, 且为上传占位符
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files").
- WithArgs(1, "我的文件").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("我的文件", 1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnRows(
- mock.NewRows([]string{"name", "upload_session_id"}).AddRow("test.txt", "1"),
- )
- err = GenericAfterUpload(ctx, &fs, file)
- asserts.Equal(ErrFileUploadSessionExisted, err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 插入失败
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files").
- WithArgs(1, "我的文件").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("我的文件", 1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
-
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found"))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)files(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
-
- err = GenericAfterUpload(ctx, &fs, file)
- asserts.Equal(ErrInsertFileRecord, err)
- asserts.NoError(mock.ExpectationsWereMet())
-
-}
-
-func TestFileSystem_Use(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{}
-
- hook := func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
- return nil
- }
-
- // 添加一个
- fs.Use("BeforeUpload", hook)
- asserts.Len(fs.Hooks["BeforeUpload"], 1)
-
- // 添加一个
- fs.Use("BeforeUpload", hook)
- asserts.Len(fs.Hooks["BeforeUpload"], 2)
-
- // 不存在
- fs.Use("BeforeUpload2333", hook)
-
- asserts.NotPanics(func() {
- for _, hookName := range []string{
- "AfterUpload",
- "AfterValidateFailed",
- "AfterUploadCanceled",
- "BeforeFileDownload",
- } {
- fs.Use(hookName, hook)
- }
- })
-
-}
-
-func TestFileSystem_Trigger(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{
- User: &model.User{},
- }
- ctx := context.Background()
-
- hook := func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
- fs.User.Storage++
- return nil
- }
-
- // 一个
- fs.Use("BeforeUpload", hook)
- err := fs.Trigger(ctx, "BeforeUpload", nil)
- asserts.NoError(err)
- asserts.Equal(uint64(1), fs.User.Storage)
-
- // 多个
- fs.Use("BeforeUpload", hook)
- fs.Use("BeforeUpload", hook)
- err = fs.Trigger(ctx, "BeforeUpload", nil)
- asserts.NoError(err)
- asserts.Equal(uint64(4), fs.User.Storage)
-
- // 多个,有失败
- fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- return errors.New("error")
- })
- fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- asserts.Fail("following hooks executed")
- return nil
- })
- err = fs.Trigger(ctx, "BeforeUpload", nil)
- asserts.Error(err)
-}
-
-func TestHookValidateCapacity(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("pack_size_1", uint64(0), 0)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{ID: 1},
- Storage: 0,
- Group: model.Group{
- MaxStorage: 11,
- },
- }}
- ctx := context.Background()
- file := &fsctx.FileStream{Size: 11}
- {
- err := HookValidateCapacity(ctx, fs, file)
- asserts.NoError(err)
- }
- {
- file.Size = 12
- err := HookValidateCapacity(ctx, fs, file)
- asserts.Error(err)
- }
-}
-
-func TestHookValidateCapacityDiff(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Group: model.Group{
- MaxStorage: 11,
- },
- }}
- file := model.File{Size: 10}
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
-
- // 无需操作
- {
- a.NoError(HookValidateCapacityDiff(ctx, fs, &fsctx.FileStream{Size: 10}))
- }
-
- // 需要验证
- {
- a.Error(HookValidateCapacityDiff(ctx, fs, &fsctx.FileStream{Size: 12}))
- }
-
-}
-
-func TestHookResetPolicy(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{ID: 1},
- }}
-
- // 成功
- {
- file := model.File{PolicyID: 2}
- cache.Deletes([]string{"2"}, "policy_")
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(2, "local"))
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
- err := HookResetPolicy(ctx, fs, nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 上下文文件不存在
- {
- cache.Deletes([]string{"2"}, "policy_")
- ctx := context.Background()
- err := HookResetPolicy(ctx, fs, nil)
- asserts.Error(err)
- }
-}
-
-func TestHookCleanFileContent(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{ID: 1},
- }}
-
- file := &fsctx.FileStream{SavePath: "123/123"}
- handlerMock := FileHeaderMock{}
- handlerMock.On("Put", testMock.Anything, testMock.Anything).Return(errors.New("error"))
- fs.Handler = handlerMock
- err := HookCleanFileContent(context.Background(), fs, file)
- asserts.Error(err)
- handlerMock.AssertExpectations(t)
-}
-
-func TestHookClearFileSize(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{ID: 1},
- }}
-
- // 成功
- {
- ctx := context.WithValue(
- context.Background(),
- fsctx.FileModelCtx,
- model.File{Model: gorm.Model{ID: 1}, Size: 10},
- )
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").
- WithArgs("", 0, sqlmock.AnyArg(), 1, 10).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)users(.+)").
- WithArgs(10, sqlmock.AnyArg()).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := HookClearFileSize(ctx, fs, nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 上下文对象不存在
- {
- ctx := context.Background()
- err := HookClearFileSize(ctx, fs, nil)
- asserts.Error(err)
- }
-
-}
-
-func TestHookUpdateSourceName(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{ID: 1},
- }}
-
- // 成功
- {
- originFile := model.File{
- Model: gorm.Model{ID: 1},
- SourceName: "new.txt",
- }
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile)
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WithArgs("", "new.txt", sqlmock.AnyArg(), 1).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := HookUpdateSourceName(ctx, fs, nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 上下文错误
- {
- ctx := context.Background()
- err := HookUpdateSourceName(ctx, fs, nil)
- asserts.Error(err)
- }
-}
-
-func TestGenericAfterUpdate(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{ID: 1},
- }}
-
- // 成功 是图像文件
- {
- originFile := model.File{
- Model: gorm.Model{ID: 1},
- PicInfo: "1,1",
- }
- newFile := &fsctx.FileStream{Size: 10}
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile)
-
- handlerMock := FileHeaderMock{}
- handlerMock.On("Delete", testMock.Anything, []string{"._thumb"}).Return([]string{}, nil)
- fs.Handler = handlerMock
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").
- WithArgs("", 10, sqlmock.AnyArg(), 1, 0).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)users(.+)").
- WithArgs(10, sqlmock.AnyArg()).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- err := GenericAfterUpdate(ctx, fs, newFile)
-
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 原始文件上下文不存在
- {
- newFile := &fsctx.FileStream{Size: 10}
- ctx := context.Background()
- err := GenericAfterUpdate(ctx, fs, newFile)
- asserts.Error(err)
- }
-
- // 无法更新数据库容量
- // 成功 是图像文件
- {
- originFile := model.File{
- Model: gorm.Model{ID: 1},
- PicInfo: "1,1",
- }
- newFile := &fsctx.FileStream{Size: 10}
- ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, originFile)
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").
- WithArgs("", 10, sqlmock.AnyArg(), 1, 0).
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
-
- err := GenericAfterUpdate(ctx, fs, newFile)
-
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-}
-
-func TestSlaveAfterUpload(t *testing.T) {
- asserts := assert.New(t)
- conf.SystemConfig.Mode = "slave"
- fs, err := NewAnonymousFileSystem()
- conf.SystemConfig.Mode = "master"
- asserts.NoError(err)
-
- // 成功
- {
- clientMock := requestmock.RequestMock{}
- clientMock.On(
- "Request",
- "POST",
- "http://test/callbakc",
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: nil,
- Response: &http.Response{
- StatusCode: 200,
- Body: ioutil.NopCloser(strings.NewReader(`{"code":0}`)),
- },
- })
- request.GeneralClient = clientMock
- file := &fsctx.FileStream{
- Size: 10,
- VirtualPath: "/my",
- Name: "test.txt",
- SavePath: "/not_exist",
- }
- err := SlaveAfterUpload(&serializer.UploadSession{Callback: "http://test/callbakc"})(context.Background(), fs, file)
- clientMock.AssertExpectations(t)
- asserts.NoError(err)
- }
-
- // 跳过回调
- {
- file := &fsctx.FileStream{
- Size: 10,
- VirtualPath: "/my",
- Name: "test.txt",
- SavePath: "/not_exist",
- }
- err := SlaveAfterUpload(&serializer.UploadSession{})(context.Background(), fs, file)
- asserts.NoError(err)
- }
-}
-
-func TestFileSystem_CleanHooks(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{
- User: &model.User{
- Model: gorm.Model{ID: 1},
- },
- Hooks: map[string][]Hook{
- "hook1": []Hook{},
- "hook2": []Hook{},
- "hook3": []Hook{},
- },
- }
-
- // 清理一个
- {
- fs.CleanHooks("hook2")
- asserts.Len(fs.Hooks, 2)
- asserts.Contains(fs.Hooks, "hook1")
- asserts.Contains(fs.Hooks, "hook3")
- }
-
- // 清理全部
- {
- fs.CleanHooks("")
- asserts.Len(fs.Hooks, 0)
- }
-}
-
-func TestHookCancelContext(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{}
- ctx, cancel := context.WithCancel(context.Background())
-
- // empty ctx
- {
- asserts.NoError(HookCancelContext(ctx, fs, nil))
- select {
- case <-ctx.Done():
- t.Errorf("Channel should not be closed")
- default:
-
- }
- }
-
- // with cancel ctx
- {
- ctx = context.WithValue(ctx, fsctx.CancelFuncCtx, cancel)
- asserts.NoError(HookCancelContext(ctx, fs, nil))
- _, ok := <-ctx.Done()
- asserts.False(ok)
- }
-}
-
-func TestHookClearFileHeaderSize(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{Size: 10}
- a.NoError(HookClearFileHeaderSize(context.Background(), fs, file))
- a.EqualValues(0, file.Size)
-}
-
-func TestHookTruncateFileTo(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{}
- a.NoError(HookTruncateFileTo(0)(context.Background(), fs, file))
-
- fs.Handler = local.Driver{}
- a.Error(HookTruncateFileTo(0)(context.Background(), fs, file))
-}
-
-func TestHookChunkUploaded(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{
- AppendStart: 10,
- Size: 10,
- Model: &model.File{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 20, sqlmock.AnyArg(), 1, 0).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)users(.+)").
- WithArgs(20, sqlmock.AnyArg()).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(HookChunkUploaded(context.Background(), fs, file))
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestHookChunkUploadFailed(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{
- AppendStart: 10,
- Size: 10,
- Model: &model.File{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WithArgs("", 10, sqlmock.AnyArg(), 1, 0).WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)users(.+)").
- WithArgs(10, sqlmock.AnyArg()).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(HookChunkUploadFailed(context.Background(), fs, file))
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestHookPopPlaceholderToFile(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{
- Model: &model.File{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(HookPopPlaceholderToFile("1,1")(context.Background(), fs, file))
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestHookPopPlaceholderToFileBySuffix(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{
- Policy: &model.Policy{Type: "cos"},
- }
- file := &fsctx.FileStream{
- Name: "1.png",
- Model: &model.File{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- a.NoError(HookPopPlaceholderToFile("")(context.Background(), fs, file))
- a.NoError(mock.ExpectationsWereMet())
-}
-
-func TestHookDeleteUploadSession(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{
- Model: &model.File{
- Model: gorm.Model{ID: 1},
- },
- }
-
- cache.Set(UploadSessionCachePrefix+"TestHookDeleteUploadSession", "", 0)
- a.NoError(HookDeleteUploadSession("TestHookDeleteUploadSession")(context.Background(), fs, file))
- _, ok := cache.Get(UploadSessionCachePrefix + "TestHookDeleteUploadSession")
- a.False(ok)
-}
-func TestNewWebdavAfterUploadHook(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{}
- file := &fsctx.FileStream{
- Model: &model.File{
- Model: gorm.Model{ID: 1},
- },
- }
-
- req, _ := http.NewRequest("get", "http://localhost", nil)
- req.Header.Add("X-Oc-Mtime", "1681521402")
- req.Header.Add("OC-Checksum", "checksum")
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := NewWebdavAfterUploadHook(req)(context.Background(), fs, file)
- a.NoError(err)
- a.NoError(mock.ExpectationsWereMet())
-
-}
diff --git a/pkg/filesystem/image_test.go b/pkg/filesystem/image_test.go
deleted file mode 100644
index 4180858..0000000
--- a/pkg/filesystem/image_test.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package filesystem
-
-import (
- "context"
- "errors"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/thumbmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/thumb"
- testMock "github.com/stretchr/testify/mock"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestFileSystem_GetThumb(t *testing.T) {
- a := assert.New(t)
- fs := &FileSystem{User: &model.User{}}
-
- // file not found
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
- res, err := fs.GetThumb(context.Background(), 1)
- a.ErrorIs(err, ErrObjectNotExist)
- a.Nil(res)
- a.NoError(mock.ExpectationsWereMet())
- }
-
- // thumb not exist
- {
- fs.SetTargetFile(&[]model.File{{
- MetadataSerialized: map[string]string{
- model.ThumbStatusMetadataKey: model.ThumbStatusNotAvailable,
- },
- Policy: model.Policy{Type: "mock"},
- }})
- fs.FileTarget[0].Policy.ID = 1
-
- res, err := fs.GetThumb(context.Background(), 1)
- a.ErrorIs(err, ErrObjectNotExist)
- a.Nil(res)
- }
-
- // thumb not initialized, also failed to generate
- {
- fs.CleanTargets()
- fs.SetTargetFile(&[]model.File{{
- Policy: model.Policy{Type: "mock"},
- Size: 31457281,
- }})
- testHandller2 := new(FileHeaderMock)
- testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist)
- fs.Handler = testHandller2
- fs.FileTarget[0].Policy.ID = 1
- res, err := fs.GetThumb(context.Background(), 1)
- a.Contains(err.Error(), "file too large")
- a.Nil(res.Content)
- }
-
- // thumb not initialized, failed to get source
- {
- fs.CleanTargets()
- fs.SetTargetFile(&[]model.File{{
- Policy: model.Policy{Type: "mock"},
- }})
- testHandller2 := new(FileHeaderMock)
- testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist)
- testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, errors.New("error"))
- fs.Handler = testHandller2
- fs.FileTarget[0].Policy.ID = 1
- res, err := fs.GetThumb(context.Background(), 1)
- a.Contains(err.Error(), "error")
- a.Nil(res.Content)
- }
-
- // thumb not initialized, no available generators
- {
- thumb.Generators = []thumb.Generator{}
- fs.CleanTargets()
- fs.SetTargetFile(&[]model.File{{
- Policy: model.Policy{Type: "local"},
- }})
- testHandller2 := new(FileHeaderMock)
- testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist)
- testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, nil)
- fs.Handler = testHandller2
- fs.FileTarget[0].Policy.ID = 1
- res, err := fs.GetThumb(context.Background(), 1)
- a.ErrorIs(err, thumb.ErrNotAvailable)
- a.Nil(res)
- }
-
- // thumb not initialized, thumb generated but cannot be open
- {
- mockGenerator := &thumbmock.GeneratorMock{}
- thumb.Generators = []thumb.Generator{mockGenerator}
- fs.CleanTargets()
- fs.SetTargetFile(&[]model.File{{
- Policy: model.Policy{Type: "mock"},
- }})
- cache.Set("setting_thumb_vips_enabled", "1", 0)
- testHandller2 := new(FileHeaderMock)
- testHandller2.On("Thumb", testMock.Anything, &fs.FileTarget[0]).Return(&response.ContentResponse{}, driver.ErrorThumbNotExist)
- testHandller2.On("Get", testMock.Anything, "").Return(MockRSC{}, nil)
- mockGenerator.On("Generate", testMock.Anything, testMock.Anything, testMock.Anything, testMock.Anything, testMock.Anything).
- Return(&thumb.Result{Path: "not_exit_thumb"}, nil)
-
- fs.Handler = testHandller2
- fs.FileTarget[0].Policy.ID = 1
- res, err := fs.GetThumb(context.Background(), 1)
- a.Contains(err.Error(), "failed to open temp thumb")
- a.Nil(res.Content)
- testHandller2.AssertExpectations(t)
- mockGenerator.AssertExpectations(t)
- }
-}
-
-func TestFileSystem_ThumbWorker(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.NotPanics(func() {
- getThumbWorker().addWorker()
- getThumbWorker().releaseWorker()
- })
-}
diff --git a/pkg/filesystem/manage_test.go b/pkg/filesystem/manage_test.go
deleted file mode 100644
index 1f2cc1a..0000000
--- a/pkg/filesystem/manage_test.go
+++ /dev/null
@@ -1,848 +0,0 @@
-package filesystem
-
-import (
- "context"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- "os"
- "testing"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
- testMock "github.com/stretchr/testify/mock"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFileSystem_ListPhysical(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- },
- Policy: &model.Policy{Type: "mock"},
- }
- ctx := context.Background()
-
- // 未知存储策略
- {
- fs.Policy.Type = "unknown"
- res, err := fs.ListPhysical(ctx, "/")
- asserts.Equal(ErrUnknownPolicyType, err)
- asserts.Empty(res)
- fs.Policy.Type = "mock"
- }
-
- // 无法列取目录
- {
- testHandler := new(FileHeaderMock)
- testHandler.On("List", testMock.Anything, "/", testMock.Anything).Return([]response.Object{}, errors.New("error"))
- fs.Handler = testHandler
- res, err := fs.ListPhysical(ctx, "/")
- asserts.EqualError(err, "error")
- asserts.Empty(res)
- }
-
- // 成功
- {
- testHandler := new(FileHeaderMock)
- testHandler.On("List", testMock.Anything, "/", testMock.Anything).Return(
- []response.Object{{IsDir: true, Name: "1"}, {IsDir: false, Name: "2"}},
- nil,
- )
- fs.Handler = testHandler
- res, err := fs.ListPhysical(ctx, "/")
- asserts.NoError(err)
- asserts.Len(res, 1)
- asserts.Equal("1", res[0].Name)
- }
-}
-
-func TestFileSystem_List(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
- ctx := context.Background()
-
- // 成功,子目录包含文件和路径,不使用路径处理钩子
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1))
- // folder
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "folder").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(5, "folder", 1))
-
- mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2"))
- mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt"))
- objects, err := fs.List(ctx, "/folder", nil)
- asserts.Len(objects, 4)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 成功,子目录包含文件和路径,不使用路径处理钩子,包含分享key
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1))
- // folder
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "folder").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(5, "folder", 1))
-
- mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2"))
- mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt"))
- ctxWithKey := context.WithValue(ctx, fsctx.ShareKeyCtx, "share")
- objects, err = fs.List(ctxWithKey, "/folder", nil)
- asserts.Len(objects, 4)
- asserts.Equal("share", objects[3].Key)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 成功,子目录包含文件和路径,使用路径处理钩子
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1))
- // folder
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "folder").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1))
-
- mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"}).AddRow(6, "sub_folder1", "/folder").AddRow(7, "sub_folder2", "/folder"))
- mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder"))
- objects, err = fs.List(ctx, "/folder", func(s string) string {
- return "prefix" + s
- })
- asserts.Len(objects, 4)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- for _, value := range objects {
- asserts.Contains(value.Path, "prefix/")
- }
-
- // 成功,子目录包含路径,使用路径处理钩子
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1))
- // folder
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "folder").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1))
-
- mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"}))
- mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder"))
- objects, err = fs.List(ctx, "/folder", func(s string) string {
- return "prefix" + s
- })
- asserts.Len(objects, 2)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- for _, value := range objects {
- asserts.Contains(value.Path, "prefix/")
- }
-
- // 成功,子目录下为空,使用路径处理钩子
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1))
- // folder
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "folder").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(2, "folder", 1))
-
- mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"}))
- mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}))
- objects, err = fs.List(ctx, "/folder", func(s string) string {
- return "prefix" + s
- })
- asserts.Len(objects, 0)
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 成功,子目录路径不存在
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}).AddRow(1, "/", 1))
- // folder
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "folder").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name", "owner_id"}))
-
- objects, err = fs.List(ctx, "/folder", func(s string) string {
- return "prefix" + s
- })
- asserts.Len(objects, 0)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestFileSystem_CreateDirectory(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
- ctx := context.Background()
-
- // 目录名非法
- _, err := fs.CreateDirectory(ctx, "/ad/a+?")
- asserts.Equal(ErrIllegalObjectName, err)
-
- // 存在同名文件
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "ad").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
-
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "ab"))
- _, err = fs.CreateDirectory(ctx, "/ad/ab")
- asserts.Equal(ErrFileExisted, err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 存在同名目录,直接返回
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "ad").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
-
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- // ab
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("ab", 2, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1))
- res, err := fs.CreateDirectory(ctx, "/ad/ab")
- asserts.NoError(err)
- asserts.EqualValues(3, res.ID)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 成功创建
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "ad").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
-
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("ab", 2, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- _, err = fs.CreateDirectory(ctx, "/ad/ab")
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 成功创建, 递归创建父目录
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "ad").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- // 创建ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("ad", 1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- // 创建ab
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("ab", 2, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- _, err = fs.CreateDirectory(ctx, "/ad/ab")
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 底层创建失败
- // 成功创建
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "ad").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- // 创建ad
- mock.ExpectQuery("SELECT(.+)").
- WithArgs("ad", 1, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(2, 1)).WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- mock.ExpectQuery("SELECT(.+)").
- WillReturnError(errors.New("error"))
- _, err = fs.CreateDirectory(ctx, "/ad/ab")
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 直接创建根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- _, err = fs.CreateDirectory(ctx, "/")
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-
- // 直接创建根目录, 重设根目录
- fs.Root = &model.Folder{}
- _, err = fs.CreateDirectory(ctx, "/")
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestFileSystem_ListDeleteFiles(t *testing.T) {
- conf.DatabaseConfig.Type = "mysql"
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt").AddRow(2, "2.txt"))
- err := fs.ListDeleteFiles(context.Background(), []uint{1})
- asserts.NoError(err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
- err := fs.ListDeleteFiles(context.Background(), []uint{1})
- asserts.Error(err)
- asserts.Equal(serializer.CodeDBError, err.(serializer.AppError).Code)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFileSystem_ListDeleteDirs(t *testing.T) {
- conf.DatabaseConfig.Type = "mysql"
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "parent_id"}).
- AddRow(1, 0).
- AddRow(2, 0).
- AddRow(3, 0),
- )
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1, 2, 3).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(4, "1.txt").
- AddRow(5, "2.txt").
- AddRow(6, "3.txt"),
- )
- err := fs.ListDeleteDirs(context.Background(), []uint{1})
- asserts.NoError(err)
- asserts.Len(fs.FileTarget, 3)
- asserts.Len(fs.DirTarget, 3)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 成功,忽略根目录
- {
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "parent_id"}).
- AddRow(1, 0).
- AddRow(2, nil).
- AddRow(3, 0),
- )
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(1, 3).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name"}).
- AddRow(4, "1.txt").
- AddRow(5, "2.txt").
- AddRow(6, "3.txt"),
- )
- fs.CleanTargets()
- err := fs.ListDeleteDirs(context.Background(), []uint{1})
- asserts.NoError(err)
- asserts.Len(fs.FileTarget, 3)
- asserts.Len(fs.DirTarget, 2)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 检索文件发生错误
- {
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "parent_id"}).
- AddRow(1, 0).
- AddRow(2, 0).
- AddRow(3, 0),
- )
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3).
- WillReturnError(errors.New("error"))
- fs.CleanTargets()
- err := fs.ListDeleteDirs(context.Background(), []uint{1})
- asserts.Error(err)
- asserts.Len(fs.DirTarget, 3)
- asserts.NoError(mock.ExpectationsWereMet())
- }
- // 检索目录发生错误
- {
- mock.ExpectQuery("SELECT(.+)").
- WillReturnError(errors.New("error"))
- err := fs.ListDeleteDirs(context.Background(), []uint{1})
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFileSystem_Delete(t *testing.T) {
- conf.DatabaseConfig.Type = "mysql"
- asserts := assert.New(t)
- cache.Set("pack_size_1", uint64(0), 0)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 0,
- },
- Storage: 3,
- Group: model.Group{MaxStorage: 3},
- }}
- ctx := context.Background()
-
- //全部未成功,强制
- {
- fs.CleanTargets()
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "parent_id"}).
- AddRow(1, 0).
- AddRow(2, 0).
- AddRow(3, 0),
- )
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).
- AddRow(4, "1.txt", "1.txt", 365, 1),
- )
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2))
- // 两次查询软连接
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- // 查询上传策略
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(365, "local"))
- // 删除文件记录
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 1))
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 1))
- mock.ExpectCommit()
- // 删除对应分享
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)shares").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
- // 删除目录
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
- // 删除对应分享
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)shares").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
-
- fs.FileTarget = []model.File{}
- fs.DirTarget = []model.Folder{}
- err := fs.Delete(ctx, []uint{1}, []uint{1}, true, false)
- asserts.NoError(err)
- }
- //全部成功
- {
- fs.CleanTargets()
- file, err := os.Create(util.RelativePath("1.txt"))
- file2, err := os.Create(util.RelativePath("2.txt"))
- file.Close()
- file2.Close()
- asserts.NoError(err)
- mock.ExpectQuery("SELECT(.+)").
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "parent_id"}).
- AddRow(1, 0).
- AddRow(2, 0).
- AddRow(3, 0),
- )
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 2, 3).
- WillReturnRows(
- sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).
- AddRow(4, "1.txt", "1.txt", 602, 1),
- )
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2))
- // 两次查询软连接
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
- // 查询上传策略
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(602, "local"))
- // 删除文件记录
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 1))
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 1))
- mock.ExpectCommit()
- // 删除对应分享
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)shares").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
- // 删除目录
- mock.ExpectBegin()
- mock.ExpectExec("DELETE(.+)").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
- // 删除对应分享
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)shares").
- WillReturnResult(sqlmock.NewResult(0, 3))
- mock.ExpectCommit()
-
- fs.FileTarget = []model.File{}
- fs.DirTarget = []model.Folder{}
- err = fs.Delete(ctx, []uint{1}, []uint{1}, false, false)
- asserts.NoError(err)
- }
-
-}
-
-func TestFileSystem_Copy(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("pack_size_1", uint64(0), 0)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- Storage: 3,
- Group: model.Group{MaxStorage: 3},
- }}
- ctx := context.Background()
-
- // 目录不存在
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(
- sqlmock.NewRows([]string{"name"}),
- )
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(
- sqlmock.NewRows([]string{"name"}),
- )
- err := fs.Copy(ctx, []uint{}, []uint{}, "/src", "/dst")
- asserts.Equal(ErrPathNotExist, err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 复制目录出错
- {
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "dst").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "src").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
-
- err := fs.Copy(ctx, []uint{1}, []uint{}, "/src", "/dst")
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
-}
-
-func TestFileSystem_Move(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("pack_size_1", uint64(0), 0)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- Storage: 3,
- Group: model.Group{MaxStorage: 3},
- }}
- ctx := context.Background()
-
- // 目录不存在
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(
- sqlmock.NewRows([]string{"name"}),
- )
- err := fs.Move(ctx, []uint{}, []uint{}, "/src", "/dst")
- asserts.Equal(ErrPathNotExist, err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 移动目录出错
- {
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "dst").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "src").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- err := fs.Move(ctx, []uint{1}, []uint{}, "/src", "/dst")
- asserts.Error(err)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-}
-
-func TestFileSystem_Rename(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- },
- Policy: &model.Policy{},
- }
- ctx := context.Background()
-
- // 重命名文件 成功
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text"))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").
- WithArgs(sqlmock.AnyArg(), "new.txt", sqlmock.AnyArg(), 10).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 重命名文件 不存在
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(ErrPathNotExist, err)
- }
-
- // 重命名文件 失败
- {
- mock.ExpectQuery("SELECT(.+)files(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old.text"))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)files(.+)SET(.+)").
- WithArgs(sqlmock.AnyArg(), "new.txt", sqlmock.AnyArg(), 10).
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := fs.Rename(ctx, []uint{}, []uint{10}, "new.txt")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(ErrFileExisted, err)
- }
-
- // 重命名目录 成功
- {
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old"))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
- WithArgs("new", 10).
- WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- err := fs.Rename(ctx, []uint{10}, []uint{}, "new")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- }
-
- // 重命名目录 不存在
- {
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- err := fs.Rename(ctx, []uint{10}, []uint{}, "new")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(ErrPathNotExist, err)
- }
-
- // 重命名目录 失败
- {
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "old"))
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)folders(.+)SET(.+)").
- WithArgs("new", 10).
- WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- err := fs.Rename(ctx, []uint{10}, []uint{}, "new")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(ErrFileExisted, err)
- }
-
- // 未选中任何对象
- {
- err := fs.Rename(ctx, []uint{}, []uint{}, "new")
- asserts.Error(err)
- asserts.Equal(ErrPathNotExist, err)
- }
-
- // 新名字是目录,不合法
- {
- err := fs.Rename(ctx, []uint{10}, []uint{}, "ne/w")
- asserts.Error(err)
- asserts.Equal(ErrIllegalObjectName, err)
- }
-
- // 新名字是文件,不合法
- {
- err := fs.Rename(ctx, []uint{}, []uint{10}, "ne/w")
- asserts.Error(err)
- asserts.Equal(ErrIllegalObjectName, err)
- }
-
- // 新名字是文件,扩展名不合法
- {
- fs.Policy.OptionsSerialized.FileType = []string{"txt"}
- err := fs.Rename(ctx, []uint{}, []uint{10}, "1.jpg")
- asserts.Error(err)
- asserts.Equal(ErrIllegalObjectName, err)
- }
-
- // 新名字是目录,不应该检测扩展名
- {
- fs.Policy.OptionsSerialized.FileType = []string{"txt"}
- mock.ExpectQuery("SELECT(.+)folders(.+)").
- WithArgs(10, 1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- err := fs.Rename(ctx, []uint{10}, []uint{}, "new")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Equal(ErrPathNotExist, err)
- }
-}
-
-func TestFileSystem_SaveTo(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
- ctx := context.Background()
-
- // 单文件 失败
- {
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
- fs.SetTargetFile(&[]model.File{{Name: "test.txt"}})
- err := fs.SaveTo(ctx, "/")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
- // 目录 成功
- {
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
- fs.SetTargetDir(&[]model.Folder{{Name: "folder"}})
- err := fs.SaveTo(ctx, "/")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
- // 父目录不存在
- {
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- fs.SetTargetDir(&[]model.Folder{{Name: "folder"}})
- err := fs.SaveTo(ctx, "/")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- }
-}
diff --git a/pkg/filesystem/path.go b/pkg/filesystem/path.go
index b0637aa..056c0c8 100644
--- a/pkg/filesystem/path.go
+++ b/pkg/filesystem/path.go
@@ -15,13 +15,20 @@ import (
// IsPathExist 返回给定目录是否存在
// 如果存在就返回目录
func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) {
+ tracedEnd, currentFolder := fs.getClosedParent(path)
+ if tracedEnd {
+ return true, currentFolder
+ }
+ return false, nil
+}
+
+func (fs *FileSystem) getClosedParent(path string) (bool, *model.Folder) {
pathList := util.SplitPath(path)
if len(pathList) == 0 {
return false, nil
}
// 递归步入目录
- // TODO:测试新增
var currentFolder *model.Folder
// 如果已设定跟目录对象,则从给定目录向下遍历
@@ -42,10 +49,12 @@ func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) {
return false, nil
}
} else {
- currentFolder, err = currentFolder.GetChild(folderName)
+ nextFolder, err := currentFolder.GetChild(folderName)
if err != nil {
- return false, nil
+ return false, currentFolder
}
+
+ currentFolder = nextFolder
}
}
diff --git a/pkg/filesystem/path_test.go b/pkg/filesystem/path_test.go
deleted file mode 100644
index e4065a4..0000000
--- a/pkg/filesystem/path_test.go
+++ /dev/null
@@ -1,172 +0,0 @@
-package filesystem
-
-import (
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFileSystem_IsFileExist(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
-
- // 存在
- {
- path := "/1.txt"
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, "1.txt").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1.txt"))
- exist, file := fs.IsFileExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.True(exist)
- asserts.Equal(uint(1), file.ID)
- }
-
- // 文件不存在
- {
- path := "/1.txt"
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectQuery("SELECT(.+)").WithArgs(1, "1.txt").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
- exist, _ := fs.IsFileExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(exist)
- }
-
- // 父目录不存在
- {
- path := "/1.txt"
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id"}))
- exist, _ := fs.IsFileExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(exist)
- }
-}
-
-func TestFileSystem_IsPathExist(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
-
- // 查询根目录
- {
- path := "/"
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- exist, folder := fs.IsPathExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.True(exist)
- asserts.Equal(uint(1), folder.ID)
- }
-
- // 深层路径
- {
- path := "/1/2/3"
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "1").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- // 2
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1, "2").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1))
- // 3
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(3, 1, "3").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(4, 1))
- exist, folder := fs.IsPathExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.True(exist)
- asserts.Equal(uint(4), folder.ID)
- }
-
- // 深层路径 重设根目录为/1
- {
- path := "/2/3"
- fs.Root = &model.Folder{Name: "1", Model: gorm.Model{ID: 2}, OwnerID: 1}
- // 2
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1, "2").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1))
- // 3
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(3, 1, "3").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(4, 1))
- exist, folder := fs.IsPathExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.True(exist)
- asserts.Equal(uint(4), folder.ID)
- fs.Root = nil
- }
-
- // 深层 不存在
- {
- path := "/1/2/3"
- // 根目录
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- // 1
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, 1, "1").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(2, 1))
- // 2
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(2, 1, "2").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(3, 1))
- // 3
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(3, 1, "3").
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}))
- exist, folder := fs.IsPathExist(path)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(exist)
- asserts.Nil(folder)
- }
-
-}
-
-func TestFileSystem_IsChildFileExist(t *testing.T) {
- asserts := assert.New(t)
- fs := &FileSystem{User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- }}
- folder := model.Folder{
- Model: gorm.Model{ID: 1},
- Name: "123",
- Position: "/",
- }
-
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1, "321").
- WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(2, "321"))
- exist, childFile := fs.IsChildFileExist(&folder, "321")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.True(exist)
- asserts.Equal("/123", childFile.Position)
-}
diff --git a/pkg/filesystem/relocate.go b/pkg/filesystem/relocate.go
new file mode 100755
index 0000000..673c61b
--- /dev/null
+++ b/pkg/filesystem/relocate.go
@@ -0,0 +1,102 @@
+package filesystem
+
+import (
+ "context"
+
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+)
+
+/* ================
+ 存储策略迁移
+ ================
+*/
+
+// Relocate 将目标文件转移到当前存储策略下
+func (fs *FileSystem) Relocate(ctx context.Context, files []model.File, policy *model.Policy) error {
+ // 重设存储策略为要转移的目的策略
+ fs.Policy = policy
+ if err := fs.DispatchHandler(); err != nil {
+ return err
+ }
+
+ // 将目前文件根据存储策略分组
+ fileGroup := fs.GroupFileByPolicy(ctx, files)
+
+ // 按照存储策略分组处理每个文件
+ for _, fileList := range fileGroup {
+ // 如果存储策略一样,则跳过
+ if fileList[0].GetPolicy().ID == fs.Policy.ID {
+ util.Log().Debug("Skip relocating %d file(s), since they are already in desired policy.",
+ len(fileList))
+ continue
+ }
+
+ // 获取当前存储策略的处理器
+ currentPolicy, _ := model.GetPolicyByID(fileList[0].PolicyID)
+ currentHandler, err := getNewPolicyHandler(¤tPolicy)
+ if err != nil {
+ return err
+ }
+
+ // 记录转移完毕需要删除的文件
+ toBeDeleted := make([]model.File, 0, len(fileList))
+
+ // 循环处理每一个文件
+ // for id, r := 0, len(fileList); id < r; id++ {
+ for id, _ := range fileList {
+ // 验证文件是否符合新存储策略的规定
+ if err := HookValidateFile(ctx, fs, fileList[id]); err != nil {
+ util.Log().Debug("File %q failed to pass validators in new policy %q, skipping.",
+ fileList[id].Name, err)
+ continue
+ }
+
+ // 为文件生成新存储策略下的物理路径
+ savePath := fs.GenerateSavePath(ctx, fileList[id])
+
+ // 获取原始文件
+ src, err := currentHandler.Get(ctx, fileList[id].SourceName)
+ if err != nil {
+ util.Log().Debug("Failed to get file %q: %s, skipping.",
+ fileList[id].Name, err)
+ continue
+ }
+
+ // 转存到新的存储策略
+ if err := fs.Handler.Put(ctx, &fsctx.FileStream{
+ File: src,
+ SavePath: savePath,
+ Size: fileList[id].Size,
+ }); err != nil {
+ util.Log().Debug("Failed to migrate file %q: %s, skipping.",
+ fileList[id].Name, err)
+ continue
+ }
+
+ toBeDeleted = append(toBeDeleted, *fileList[id])
+
+ // 更新文件信息
+ fileList[id].Relocate(savePath, fs.Policy.ID)
+ }
+
+ // 排除带有软链接的文件
+ toBeDeletedClean, err := model.RemoveFilesWithSoftLinks(toBeDeleted)
+ if err != nil {
+ util.Log().Warning("Failed to check soft links: %s", err)
+ }
+
+ deleteSourceNames := make([]string, 0, len(toBeDeleted))
+ for i := 0; i < len(toBeDeletedClean); i++ {
+ deleteSourceNames = append(deleteSourceNames, toBeDeletedClean[i].SourceName)
+ }
+
+ // 删除原始策略中的文件
+ if _, err := currentHandler.Delete(ctx, deleteSourceNames); err != nil {
+ util.Log().Warning("Cannot delete files in origin policy after relocating: %s", err)
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/filesystem/upload.go b/pkg/filesystem/upload.go
index 08dde53..bb8844d 100644
--- a/pkg/filesystem/upload.go
+++ b/pkg/filesystem/upload.go
@@ -194,17 +194,15 @@ func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileS
// UploadFromStream 从文件流上传文件
func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error {
+ // 给文件系统分配钩子
+ fs.Lock.Lock()
if resetPolicy {
- // 重设存储策略
- fs.Policy = &fs.User.Policy
- err := fs.DispatchHandler()
+ err := fs.SetPolicyFromPath(file.VirtualPath)
if err != nil {
return err
}
}
- // 给文件系统分配钩子
- fs.Lock.Lock()
if fs.Hooks == nil {
fs.Use("BeforeUpload", HookValidateFile)
fs.Use("BeforeUpload", HookValidateCapacity)
diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go
deleted file mode 100644
index 61dad9f..0000000
--- a/pkg/filesystem/upload_test.go
+++ /dev/null
@@ -1,263 +0,0 @@
-package filesystem
-
-import (
- "context"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
- "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "io/ioutil"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-)
-
-type FileHeaderMock struct {
- testMock.Mock
-}
-
-func (m FileHeaderMock) Put(ctx context.Context, file fsctx.FileHeader) error {
- args := m.Called(ctx, file)
- return args.Error(0)
-}
-
-func (m FileHeaderMock) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
- args := m.Called(ctx, ttl, uploadSession, file)
- return args.Get(0).(*serializer.UploadCredential), args.Error(1)
-}
-
-func (m FileHeaderMock) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
- args := m.Called(ctx, uploadSession)
- return args.Error(0)
-}
-
-func (m FileHeaderMock) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
- args := m.Called(ctx, path, recursive)
- return args.Get(0).([]response.Object), args.Error(1)
-}
-
-func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) {
- args := m.Called(ctx, path)
- return args.Get(0).(response.RSCloser), args.Error(1)
-}
-
-func (m FileHeaderMock) Delete(ctx context.Context, files []string) ([]string, error) {
- args := m.Called(ctx, files)
- return args.Get(0).([]string), args.Error(1)
-}
-
-func (m FileHeaderMock) Thumb(ctx context.Context, files *model.File) (*response.ContentResponse, error) {
- args := m.Called(ctx, files)
- return args.Get(0).(*response.ContentResponse), args.Error(1)
-}
-
-func (m FileHeaderMock) Source(ctx context.Context, path string, expires int64, isDownload bool, speed int) (string, error) {
- args := m.Called(ctx, path, expires, isDownload, speed)
- return args.Get(0).(string), args.Error(1)
-}
-
-func TestFileSystem_Upload(t *testing.T) {
- asserts := assert.New(t)
-
- // 正常
- testHandler := new(FileHeaderMock)
- testHandler.On("Put", testMock.Anything, testMock.Anything, testMock.Anything).Return(nil)
- fs := &FileSystem{
- Handler: testHandler,
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- },
- Policy: &model.Policy{
- AutoRename: false,
- DirNameRule: "{path}",
- },
- }
- ctx, cancel := context.WithCancel(context.Background())
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- ctx = context.WithValue(ctx, fsctx.GinCtx, c)
- cancel()
- file := &fsctx.FileStream{
- Size: 5,
- VirtualPath: "/",
- Name: "1.txt",
- }
- err := fs.Upload(ctx, file)
- asserts.NoError(err)
-
- // 正常,上下文已指定源文件
- testHandler = new(FileHeaderMock)
- testHandler.On("Put", testMock.Anything, testMock.Anything).Return(nil)
- fs = &FileSystem{
- Handler: testHandler,
- User: &model.User{
- Model: gorm.Model{
- ID: 1,
- },
- },
- Policy: &model.Policy{
- AutoRename: false,
- DirNameRule: "{path}",
- },
- }
- ctx, cancel = context.WithCancel(context.Background())
- c, _ = gin.CreateTestContext(httptest.NewRecorder())
- c.Request, _ = http.NewRequest("POST", "/", nil)
- ctx = context.WithValue(ctx, fsctx.GinCtx, c)
- ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{SourceName: "123/123.txt"})
- cancel()
- file = &fsctx.FileStream{
- Size: 5,
- VirtualPath: "/",
- Name: "1.txt",
- File: ioutil.NopCloser(strings.NewReader("")),
- }
- err = fs.Upload(ctx, file)
- asserts.NoError(err)
-
- // BeforeUpload 返回错误
- fs.Use("BeforeUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- return errors.New("error")
- })
- err = fs.Upload(ctx, file)
- asserts.Error(err)
- fs.Hooks["BeforeUpload"] = nil
- testHandler.AssertExpectations(t)
-
- // 上传文件失败
- testHandler2 := new(FileHeaderMock)
- testHandler2.On("Put", testMock.Anything, testMock.Anything).Return(errors.New("error"))
- fs.Handler = testHandler2
- err = fs.Upload(ctx, file)
- asserts.Error(err)
- testHandler2.AssertExpectations(t)
-
- // AfterUpload失败
- testHandler3 := new(FileHeaderMock)
- testHandler3.On("Put", testMock.Anything, testMock.Anything).Return(nil)
- fs.Handler = testHandler3
- fs.Use("AfterUpload", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- return errors.New("error")
- })
- fs.Use("AfterValidateFailed", func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
- return errors.New("error")
- })
- err = fs.Upload(ctx, file)
- asserts.Error(err)
- testHandler2.AssertExpectations(t)
-
-}
-
-func TestFileSystem_GetUploadToken(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{
- User: &model.User{Model: gorm.Model{ID: 1}},
- Policy: &model.Policy{},
- }
- ctx := context.Background()
-
- // 成功
- {
- cache.SetSettings(map[string]string{
- "upload_session_timeout": "10",
- }, "setting_")
- testHandler := new(FileHeaderMock)
- testHandler.On("Token", testMock.Anything, int64(10), testMock.Anything, testMock.Anything).Return(&serializer.UploadCredential{Credential: "test"}, nil)
- fs.Handler = testHandler
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found"))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- res, err := fs.CreateUploadSession(ctx, &fsctx.FileStream{
- Size: 0,
- Name: "file",
- VirtualPath: "/",
- })
- asserts.NoError(mock.ExpectationsWereMet())
- testHandler.AssertExpectations(t)
- asserts.NoError(err)
- asserts.Equal("test", res.Credential)
- }
-
- // 无法获取上传凭证
- {
- cache.SetSettings(map[string]string{
- "upload_credential_timeout": "10",
- "upload_session_timeout": "10",
- }, "setting_")
- testHandler := new(FileHeaderMock)
- testHandler.On("Token", testMock.Anything, int64(10), testMock.Anything, testMock.Anything).Return(&serializer.UploadCredential{}, errors.New("error"))
- fs.Handler = testHandler
- mock.ExpectQuery("SELECT(.+)").
- WithArgs(1).
- WillReturnRows(sqlmock.NewRows([]string{"id", "owner_id"}).AddRow(1, 1))
- mock.ExpectQuery("SELECT(.+)files(.+)").WillReturnError(errors.New("not found"))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectExec("UPDATE(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- _, err := fs.CreateUploadSession(ctx, &fsctx.FileStream{
- Size: 0,
- Name: "file",
- VirtualPath: "/",
- })
- asserts.NoError(mock.ExpectationsWereMet())
- testHandler.AssertExpectations(t)
- asserts.Error(err)
- }
-}
-
-func TestFileSystem_UploadFromStream(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{
- User: &model.User{
- Model: gorm.Model{ID: 1},
- Policy: model.Policy{Type: "mock"},
- },
- Policy: &model.Policy{Type: "mock"},
- }
- ctx := context.Background()
-
- err := fs.UploadFromStream(ctx, &fsctx.FileStream{
- File: ioutil.NopCloser(strings.NewReader("123")),
- }, true)
- asserts.Error(err)
-}
-
-func TestFileSystem_UploadFromPath(t *testing.T) {
- asserts := assert.New(t)
- fs := FileSystem{
- User: &model.User{
- Model: gorm.Model{ID: 1},
- Policy: model.Policy{Type: "mock"},
- },
- Policy: &model.Policy{Type: "mock"},
- }
- ctx := context.Background()
-
- // 文件不存在
- {
- err := fs.UploadFromPath(ctx, "test/not_exist", "/", fsctx.Overwrite)
- asserts.Error(err)
- }
-
- // 文存在,上传失败
- {
- err := fs.UploadFromPath(ctx, "tests/test.zip", "/", fsctx.Overwrite)
- asserts.Error(err)
- }
-}
diff --git a/pkg/filesystem/validator_test.go b/pkg/filesystem/validator_test.go
deleted file mode 100644
index 8f685f2..0000000
--- a/pkg/filesystem/validator_test.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package filesystem
-
-import (
- "context"
- "database/sql"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestFileSystem_ValidateLegalName(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{}
- asserts.True(fs.ValidateLegalName(ctx, "1.txt"))
- asserts.True(fs.ValidateLegalName(ctx, "1-1.txt"))
- asserts.True(fs.ValidateLegalName(ctx, "1?1.txt"))
- asserts.False(fs.ValidateLegalName(ctx, "1:1.txt"))
- asserts.False(fs.ValidateLegalName(ctx, "../11.txt"))
- asserts.False(fs.ValidateLegalName(ctx, "/11.txt"))
- asserts.False(fs.ValidateLegalName(ctx, "\\11.txt"))
- asserts.False(fs.ValidateLegalName(ctx, ""))
- asserts.False(fs.ValidateLegalName(ctx, "1.tx t "))
- asserts.True(fs.ValidateLegalName(ctx, "1.tx t"))
-}
-
-func TestFileSystem_ValidateCapacity(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- cache.Set("pack_size_0", uint64(0), 0)
- fs := FileSystem{
- User: &model.User{
- Storage: 10,
- Group: model.Group{
- MaxStorage: 11,
- },
- },
- }
-
- asserts.True(fs.ValidateCapacity(ctx, 1))
- asserts.Equal(uint64(11), fs.User.Storage)
-
- fs.User.Storage = 5
- asserts.False(fs.ValidateCapacity(ctx, 10))
- asserts.Equal(uint64(5), fs.User.Storage)
-}
-
-func TestFileSystem_ValidateFileSize(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{},
- Policy: &model.Policy{
- MaxSize: 10,
- },
- }
-
- asserts.True(fs.ValidateFileSize(ctx, 5))
- asserts.True(fs.ValidateFileSize(ctx, 10))
- asserts.False(fs.ValidateFileSize(ctx, 11))
-
- // 无限制
- fs.Policy.MaxSize = 0
- asserts.True(fs.ValidateFileSize(ctx, 11))
-}
-
-func TestFileSystem_ValidateExtension(t *testing.T) {
- asserts := assert.New(t)
- ctx := context.Background()
- fs := FileSystem{
- User: &model.User{},
- Policy: &model.Policy{
- OptionsSerialized: model.PolicyOption{
- FileType: nil,
- },
- },
- }
-
- asserts.True(fs.ValidateExtension(ctx, "1"))
- asserts.True(fs.ValidateExtension(ctx, "1.txt"))
-
- fs.Policy.OptionsSerialized.FileType = []string{}
- asserts.True(fs.ValidateExtension(ctx, "1"))
- asserts.True(fs.ValidateExtension(ctx, "1.txt"))
-
- fs.Policy.OptionsSerialized.FileType = []string{"txt", "jpg"}
- asserts.False(fs.ValidateExtension(ctx, "1"))
- asserts.False(fs.ValidateExtension(ctx, "1.jpg.png"))
- asserts.True(fs.ValidateExtension(ctx, "1.txt"))
- asserts.True(fs.ValidateExtension(ctx, "1.png.jpg"))
- asserts.True(fs.ValidateExtension(ctx, "1.png.jpG"))
- asserts.False(fs.ValidateExtension(ctx, "1.png"))
-}
diff --git a/pkg/hashid/hash_test.go b/pkg/hashid/hash_test.go
deleted file mode 100644
index 5471d9e..0000000
--- a/pkg/hashid/hash_test.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package hashid
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestHashEncode(t *testing.T) {
- asserts := assert.New(t)
-
- {
- res, err := HashEncode([]int{1, 2, 3})
- asserts.NoError(err)
- asserts.NotEmpty(res)
- }
-
- {
- res, err := HashEncode([]int{})
- asserts.Error(err)
- asserts.Empty(res)
- }
-
-}
-
-func TestHashID(t *testing.T) {
- asserts := assert.New(t)
-
- {
- res := HashID(1, ShareID)
- asserts.NotEmpty(res)
- }
-}
-
-func TestHashDecode(t *testing.T) {
- asserts := assert.New(t)
-
- // 正常
- {
- res, _ := HashEncode([]int{1, 2, 3})
- decodeRes, err := HashDecode(res)
- asserts.NoError(err)
- asserts.Equal([]int{1, 2, 3}, decodeRes)
- }
-
- // 出错
- {
- decodeRes, err := HashDecode("233")
- asserts.Error(err)
- asserts.Len(decodeRes, 0)
- }
-}
-
-func TestDecodeHashID(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- uid, err := DecodeHashID(HashID(1, ShareID), ShareID)
- asserts.NoError(err)
- asserts.EqualValues(1, uid)
- }
-
- // 类型不匹配
- {
- uid, err := DecodeHashID(HashID(1, ShareID), UserID)
- asserts.Error(err)
- asserts.EqualValues(0, uid)
- }
-}
diff --git a/pkg/mocks/remoteclientmock/mock.go b/pkg/mocks/remoteclientmock/mock.go
index 303b673..f036541 100644
--- a/pkg/mocks/remoteclientmock/mock.go
+++ b/pkg/mocks/remoteclientmock/mock.go
@@ -2,6 +2,7 @@ package remoteclientmock
import (
"context"
+
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/mock"
diff --git a/pkg/mq/mq_test.go b/pkg/mq/mq_test.go
deleted file mode 100644
index 9acdd3f..0000000
--- a/pkg/mq/mq_test.go
+++ /dev/null
@@ -1,149 +0,0 @@
-package mq
-
-import (
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
- "github.com/stretchr/testify/assert"
- "sync"
- "testing"
- "time"
-)
-
-func TestPublishAndSubscribe(t *testing.T) {
- t.Parallel()
- asserts := assert.New(t)
- mq := NewMQ()
-
- // No subscriber
- {
- asserts.NotPanics(func() {
- mq.Publish("No subscriber", Message{})
- })
- }
-
- // One channel subscriber
- {
- topic := "One channel subscriber"
- msg := Message{TriggeredBy: "Tester"}
- notifier := mq.Subscribe(topic, 0)
- mq.Publish(topic, msg)
- wg := sync.WaitGroup{}
- wg.Add(1)
- go func() {
- wg.Done()
- msgRecv := <-notifier
- asserts.Equal(msg, msgRecv)
- }()
- wg.Wait()
- }
-
- // two channel subscriber
- {
- topic := "two channel subscriber"
- msg := Message{TriggeredBy: "Tester"}
- notifier := mq.Subscribe(topic, 0)
- notifier2 := mq.Subscribe(topic, 0)
- mq.Publish(topic, msg)
- wg := sync.WaitGroup{}
- wg.Add(2)
- go func() {
- wg.Done()
- msgRecv := <-notifier
- asserts.Equal(msg, msgRecv)
- }()
- go func() {
- wg.Done()
- msgRecv := <-notifier2
- asserts.Equal(msg, msgRecv)
- }()
- wg.Wait()
- }
-
- // two channel subscriber, one timeout
- {
- topic := "two channel subscriber, one timeout"
- msg := Message{TriggeredBy: "Tester"}
- mq.Subscribe(topic, 0)
- notifier2 := mq.Subscribe(topic, 0)
- mq.Publish(topic, msg)
- wg := sync.WaitGroup{}
- wg.Add(1)
- go func() {
- wg.Done()
- msgRecv := <-notifier2
- asserts.Equal(msg, msgRecv)
- }()
- wg.Wait()
- }
-
- // two channel subscriber, one unsubscribe
- {
- topic := "two channel subscriber, one unsubscribe"
- msg := Message{TriggeredBy: "Tester"}
- mq.Subscribe(topic, 0)
- notifier2 := mq.Subscribe(topic, 0)
- notifier := mq.Subscribe(topic, 0)
- mq.Unsubscribe(topic, notifier)
- mq.Publish(topic, msg)
- wg := sync.WaitGroup{}
- wg.Add(1)
- go func() {
- wg.Done()
- msgRecv := <-notifier2
- asserts.Equal(msg, msgRecv)
- }()
- wg.Wait()
-
- select {
- case <-notifier:
- t.Error()
- default:
- }
- }
-}
-
-func TestAria2Interface(t *testing.T) {
- t.Parallel()
- asserts := assert.New(t)
- mq := NewMQ()
- var (
- OnDownloadStart int
- OnDownloadPause int
- OnDownloadStop int
- OnDownloadComplete int
- OnDownloadError int
- )
- l := sync.Mutex{}
-
- mq.SubscribeCallback("TestAria2Interface", func(message Message) {
- asserts.Equal("TestAria2Interface", message.TriggeredBy)
- l.Lock()
- defer l.Unlock()
- switch message.Event {
- case "1":
- OnDownloadStart++
- case "2":
- OnDownloadPause++
- case "5":
- OnDownloadStop++
- case "4":
- OnDownloadComplete++
- case "3":
- OnDownloadError++
- }
- })
-
- mq.OnDownloadStart([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
- mq.OnDownloadPause([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
- mq.OnDownloadStop([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
- mq.OnDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
- mq.OnDownloadError([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
- mq.OnBtDownloadComplete([]rpc.Event{{"TestAria2Interface"}, {"TestAria2Interface"}})
-
- time.Sleep(time.Duration(500) * time.Millisecond)
-
- asserts.Equal(2, OnDownloadStart)
- asserts.Equal(2, OnDownloadPause)
- asserts.Equal(2, OnDownloadStop)
- asserts.Equal(4, OnDownloadComplete)
- asserts.Equal(2, OnDownloadError)
-}
diff --git a/pkg/payment/alipay.go b/pkg/payment/alipay.go
new file mode 100755
index 0000000..a08f45e
--- /dev/null
+++ b/pkg/payment/alipay.go
@@ -0,0 +1,43 @@
+package payment
+
+import (
+ "fmt"
+ "net/url"
+
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ alipay "github.com/smartwalle/alipay/v3"
+)
+
+// Alipay 支付宝当面付支付处理
+type Alipay struct {
+ Client *alipay.Client
+}
+
+// Create 创建订单
+func (pay *Alipay) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
+ gateway, _ := url.Parse("/api/v3/callback/alipay")
+ var p = alipay.TradePreCreate{
+ Trade: alipay.Trade{
+ NotifyURL: model.GetSiteURL().ResolveReference(gateway).String(),
+ Subject: order.Name,
+ OutTradeNo: order.OrderNo,
+ TotalAmount: fmt.Sprintf("%.2f", float64(order.Price*order.Num)/100),
+ },
+ }
+
+ if _, err := order.Create(); err != nil {
+ return nil, ErrInsertOrder.WithError(err)
+ }
+
+ res, err := pay.Client.TradePreCreate(p)
+ if err != nil {
+ return nil, ErrIssueOrder.WithError(err)
+ }
+
+ return &OrderCreateRes{
+ Payment: true,
+ QRCode: res.QRCode,
+ ID: order.OrderNo,
+ }, nil
+}
diff --git a/pkg/payment/custom.go b/pkg/payment/custom.go
new file mode 100755
index 0000000..23c3748
--- /dev/null
+++ b/pkg/payment/custom.go
@@ -0,0 +1,93 @@
+package payment
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/auth"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/gofrs/uuid"
+ "github.com/qiniu/go-sdk/v7/sms/bytes"
+ "net/http"
+ "net/url"
+)
+
+// Custom payment client
+type Custom struct {
+ client request.Client
+ endpoint string
+ authClient auth.Auth
+}
+
+const (
+ paymentTTL = 3600 * 24 // 24h
+ CallbackSessionPrefix = "custom_payment_callback_"
+)
+
+func newCustomClient(endpoint, secret string) *Custom {
+ authClient := auth.HMACAuth{
+ SecretKey: []byte(secret),
+ }
+ return &Custom{
+ endpoint: endpoint,
+ authClient: auth.General,
+ client: request.NewClient(
+ request.WithCredential(authClient, paymentTTL),
+ request.WithMasterMeta(),
+ ),
+ }
+}
+
+// Request body from Cloudreve to create a new payment
+type NewCustomOrderRequest struct {
+ Name string `json:"name"` // Order name
+ OrderNo string `json:"order_no"` // Order number
+ NotifyURL string `json:"notify_url"` // Payment callback url
+ Amount int64 `json:"amount"` // Order total amount
+}
+
+// Create a new payment
+func (pay *Custom) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
+ callbackID := uuid.Must(uuid.NewV4())
+ gateway, _ := url.Parse(fmt.Sprintf("/api/v3/callback/custom/%s/%s", order.OrderNo, callbackID))
+ callback, err := auth.SignURI(pay.authClient, model.GetSiteURL().ResolveReference(gateway).String(), paymentTTL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to sign callback url: %w", err)
+ }
+
+ cache.Set(CallbackSessionPrefix+callbackID.String(), order.OrderNo, paymentTTL)
+
+ body := &NewCustomOrderRequest{
+ Name: order.Name,
+ OrderNo: order.OrderNo,
+ NotifyURL: callback.String(),
+ Amount: int64(order.Price * order.Num),
+ }
+ bodyJson, err := json.Marshal(body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to encode body: %w", err)
+ }
+
+ res, err := pay.client.Request("POST", pay.endpoint, bytes.NewReader(bodyJson)).
+ CheckHTTPResponse(http.StatusOK).DecodeResponse()
+ if err != nil {
+ return nil, fmt.Errorf("failed to request payment gateway: %w", err)
+ }
+
+ if res.Code != 0 {
+ return nil, errors.New(res.Error)
+ }
+
+ if _, err := order.Create(); err != nil {
+ return nil, ErrInsertOrder.WithError(err)
+ }
+
+ return &OrderCreateRes{
+ Payment: true,
+ QRCode: res.Data.(string),
+ ID: order.OrderNo,
+ }, nil
+}
diff --git a/pkg/payment/order.go b/pkg/payment/order.go
new file mode 100755
index 0000000..c161411
--- /dev/null
+++ b/pkg/payment/order.go
@@ -0,0 +1,171 @@
+package payment
+
+import (
+ "fmt"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/iGoogle-ink/gopay/wechat/v3"
+ "github.com/qingwg/payjs"
+ "github.com/smartwalle/alipay/v3"
+ "math/rand"
+ "net/url"
+ "time"
+)
+
+var (
+ // ErrUnknownPaymentMethod 未知支付方式
+ ErrUnknownPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "Unknown payment method", nil)
+ // ErrUnsupportedPaymentMethod 未知支付方式
+ ErrUnsupportedPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "This order cannot be paid with this method", nil)
+ // ErrInsertOrder 无法插入订单记录
+ ErrInsertOrder = serializer.NewError(serializer.CodeDBError, "Failed to insert order record", nil)
+ // ErrScoreNotEnough 积分不足
+ ErrScoreNotEnough = serializer.NewError(serializer.CodeInsufficientCredit, "", nil)
+ // ErrCreateStoragePack 无法创建容量包
+ ErrCreateStoragePack = serializer.NewError(serializer.CodeDBError, "Failed to create storage pack record", nil)
+ // ErrGroupConflict 用户组冲突
+ ErrGroupConflict = serializer.NewError(serializer.CodeGroupConflict, "", nil)
+ // ErrGroupInvalid 用户组冲突
+ ErrGroupInvalid = serializer.NewError(serializer.CodeGroupInvalid, "", nil)
+ // ErrAdminFulfillGroup 管理员无法购买用户组
+ ErrAdminFulfillGroup = serializer.NewError(serializer.CodeFulfillAdminGroup, "", nil)
+ // ErrUpgradeGroup 用户组冲突
+ ErrUpgradeGroup = serializer.NewError(serializer.CodeDBError, "Failed to update user's group", nil)
+ // ErrUInitPayment 无法初始化支付实例
+ ErrUInitPayment = serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize payment client", nil)
+ // ErrIssueOrder 订单接口请求失败
+ ErrIssueOrder = serializer.NewError(serializer.CodeInternalSetting, "Failed to create order", nil)
+ // ErrOrderNotFound 订单不存在
+ ErrOrderNotFound = serializer.NewError(serializer.CodeNotFound, "", nil)
+)
+
+// Pay 支付处理接口
+type Pay interface {
+ Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error)
+}
+
+// OrderCreateRes 订单创建结果
+type OrderCreateRes struct {
+ Payment bool `json:"payment"` // 是否需要支付
+ ID string `json:"id,omitempty"` // 订单号
+ QRCode string `json:"qr_code,omitempty"` // 支付二维码指向的地址
+}
+
+// NewPaymentInstance 获取新的支付实例
+func NewPaymentInstance(method string) (Pay, error) {
+ switch method {
+ case "score":
+ return &ScorePayment{}, nil
+ case "alipay":
+ options := model.GetSettingByNames("alipay_enabled", "appid", "appkey", "shopid")
+ if options["alipay_enabled"] != "1" {
+ return nil, ErrUnknownPaymentMethod
+ }
+
+ // 初始化支付宝客户端
+ var client, err = alipay.New(options["appid"], options["appkey"], true)
+ if err != nil {
+ return nil, ErrUInitPayment.WithError(err)
+ }
+
+ // 加载支付宝公钥
+ err = client.LoadAliPayPublicKey(options["shopid"])
+ if err != nil {
+ return nil, ErrUInitPayment.WithError(err)
+ }
+
+ return &Alipay{Client: client}, nil
+ case "payjs":
+ options := model.GetSettingByNames("payjs_enabled", "payjs_secret", "payjs_id")
+ if options["payjs_enabled"] != "1" {
+ return nil, ErrUnknownPaymentMethod
+ }
+
+ callback, _ := url.Parse("/api/v3/callback/payjs")
+ payjsConfig := &payjs.Config{
+ Key: options["payjs_secret"],
+ MchID: options["payjs_id"],
+ NotifyUrl: model.GetSiteURL().ResolveReference(callback).String(),
+ }
+
+ return &PayJSClient{Client: payjs.New(payjsConfig)}, nil
+ case "wechat":
+ options := model.GetSettingByNames("wechat_enabled", "wechat_appid", "wechat_mchid", "wechat_serial_no", "wechat_api_key", "wechat_pk_content")
+ if options["wechat_enabled"] != "1" {
+ return nil, ErrUnknownPaymentMethod
+ }
+ client, err := wechat.NewClientV3(options["wechat_appid"], options["wechat_mchid"], options["wechat_serial_no"], options["wechat_api_key"], options["wechat_pk_content"])
+ if err != nil {
+ return nil, ErrUInitPayment.WithError(err)
+ }
+
+ return &Wechat{Client: client, ApiV3Key: options["wechat_api_key"]}, nil
+ case "custom":
+ options := model.GetSettingByNames("custom_payment_enabled", "custom_payment_endpoint", "custom_payment_secret")
+ if !model.IsTrueVal(options["custom_payment_enabled"]) {
+ return nil, ErrUnknownPaymentMethod
+ }
+
+ return newCustomClient(options["custom_payment_endpoint"], options["custom_payment_secret"]), nil
+ default:
+ return nil, ErrUnknownPaymentMethod
+ }
+}
+
+// NewOrder 创建新订单
+func NewOrder(pack *serializer.PackProduct, group *serializer.GroupProducts, num int, method string, user *model.User) (*OrderCreateRes, error) {
+ // 获取支付实例
+ pay, err := NewPaymentInstance(method)
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ orderType int
+ productID int64
+ title string
+ price int
+ )
+ if pack != nil {
+ orderType = model.PackOrderType
+ productID = pack.ID
+ title = pack.Name
+ price = pack.Price
+ } else if group != nil {
+ if err := checkGroupUpgrade(user, group); err != nil {
+ return nil, err
+ }
+
+ orderType = model.GroupOrderType
+ productID = group.ID
+ title = group.Name
+ price = group.Price
+ } else {
+ orderType = model.ScoreOrderType
+ productID = 0
+ title = fmt.Sprintf("%d 积分", num)
+ price = model.GetIntSetting("score_price", 1)
+ }
+
+ // 创建订单记录
+ order := &model.Order{
+ UserID: user.ID,
+ OrderNo: orderID(),
+ Type: orderType,
+ Method: method,
+ ProductID: productID,
+ Num: num,
+ Name: fmt.Sprintf("%s - %s", model.GetSettingByName("siteName"), title),
+ Price: price,
+ Status: model.OrderUnpaid,
+ }
+
+ return pay.Create(order, pack, group, user)
+}
+
+func orderID() string {
+ return fmt.Sprintf("%s%d",
+ time.Now().Format("20060102150405"),
+ 100000+rand.Intn(900000),
+ )
+}
diff --git a/pkg/payment/payjs.go b/pkg/payment/payjs.go
new file mode 100755
index 0000000..a82ce7b
--- /dev/null
+++ b/pkg/payment/payjs.go
@@ -0,0 +1,31 @@
+package payment
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/qingwg/payjs"
+)
+
+// PayJSClient PayJS支付处理
+type PayJSClient struct {
+ Client *payjs.PayJS
+}
+
+// Create 创建订单
+func (pay *PayJSClient) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
+ if _, err := order.Create(); err != nil {
+ return nil, ErrInsertOrder.WithError(err)
+ }
+
+ PayNative := pay.Client.GetNative()
+ res, err := PayNative.Create(int64(order.Price*order.Num), order.Name, order.OrderNo, "", "")
+ if err != nil {
+ return nil, ErrIssueOrder.WithError(err)
+ }
+
+ return &OrderCreateRes{
+ Payment: true,
+ QRCode: res.CodeUrl,
+ ID: order.OrderNo,
+ }, nil
+}
diff --git a/pkg/payment/purchase.go b/pkg/payment/purchase.go
new file mode 100755
index 0000000..1c558a6
--- /dev/null
+++ b/pkg/payment/purchase.go
@@ -0,0 +1,137 @@
+package payment
+
+import (
+ "encoding/json"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "strconv"
+ "time"
+)
+
+// GivePack 创建容量包
+func GivePack(user *model.User, packInfo *serializer.PackProduct, num int) error {
+ timeNow := time.Now()
+ expires := timeNow.Add(time.Duration(packInfo.Time*int64(num)) * time.Second)
+ pack := model.StoragePack{
+ Name: packInfo.Name,
+ UserID: user.ID,
+ ActiveTime: &timeNow,
+ ExpiredTime: &expires,
+ Size: packInfo.Size,
+ }
+ if _, err := pack.Create(); err != nil {
+ return ErrCreateStoragePack.WithError(err)
+ }
+ cache.Deletes([]string{strconv.FormatUint(uint64(user.ID), 10)}, "pack_size_")
+ return nil
+}
+
+func checkGroupUpgrade(user *model.User, groupInfo *serializer.GroupProducts) error {
+ if user.Group.ID == 1 {
+ return ErrAdminFulfillGroup
+ }
+
+ // 检查用户是否已有未过期用户
+ if user.PreviousGroupID != 0 && user.GroupID != groupInfo.GroupID {
+ return ErrGroupConflict
+ }
+
+ // 用户组不能相同
+ if user.GroupID == groupInfo.GroupID && user.PreviousGroupID == 0 {
+ return ErrGroupInvalid
+ }
+
+ return nil
+}
+
+// GiveGroup 升级用户组
+func GiveGroup(user *model.User, groupInfo *serializer.GroupProducts, num int) error {
+ if err := checkGroupUpgrade(user, groupInfo); err != nil {
+ return err
+ }
+
+ timeNow := time.Now()
+ expires := timeNow.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second)
+ if user.PreviousGroupID != 0 {
+ expires = user.GroupExpires.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second)
+ }
+
+ if err := user.UpgradeGroup(groupInfo.GroupID, &expires); err != nil {
+ return ErrUpgradeGroup.WithError(err)
+ }
+
+ return nil
+}
+
+// GiveScore 积分充值
+func GiveScore(user *model.User, num int) error {
+ user.AddScore(num)
+ return nil
+}
+
+// GiveProduct “发货”
+func GiveProduct(user *model.User, pack *serializer.PackProduct, group *serializer.GroupProducts, num int) error {
+ if pack != nil {
+ return GivePack(user, pack, num)
+ } else if group != nil {
+ return GiveGroup(user, group, num)
+ } else {
+ return GiveScore(user, num)
+ }
+}
+
+// OrderPaid 订单已支付处理
+func OrderPaid(orderNo string) error {
+ order, err := model.GetOrderByNo(orderNo)
+ if err != nil || order.Status == model.OrderPaid {
+ return ErrOrderNotFound.WithError(err)
+ }
+
+ // 更新订单状态为 已支付
+ order.UpdateStatus(model.OrderPaid)
+
+ user, err := model.GetActiveUserByID(order.UserID)
+ if err != nil {
+ return serializer.NewError(serializer.CodeUserNotFound, "", err)
+ }
+
+ // 查询商品
+ options := model.GetSettingByNames("pack_data", "group_sell_data")
+
+ var (
+ packs []serializer.PackProduct
+ groups []serializer.GroupProducts
+ )
+ if err := json.Unmarshal([]byte(options["pack_data"]), &packs); err != nil {
+ return err
+ }
+ if err := json.Unmarshal([]byte(options["group_sell_data"]), &groups); err != nil {
+ return err
+ }
+
+ // 查找要购买的商品
+ var (
+ pack *serializer.PackProduct
+ group *serializer.GroupProducts
+ )
+ if order.Type == model.GroupOrderType {
+ for _, v := range groups {
+ if v.ID == order.ProductID {
+ group = &v
+ break
+ }
+ }
+ } else if order.Type == model.PackOrderType {
+ for _, v := range packs {
+ if v.ID == order.ProductID {
+ pack = &v
+ break
+ }
+ }
+ }
+
+ // "发货"
+ return GiveProduct(&user, pack, group, order.Num)
+
+}
diff --git a/pkg/payment/score.go b/pkg/payment/score.go
new file mode 100755
index 0000000..351c848
--- /dev/null
+++ b/pkg/payment/score.go
@@ -0,0 +1,45 @@
+package payment
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+)
+
+// ScorePayment 积分支付处理
+type ScorePayment struct {
+}
+
+// Create 创建新订单
+func (pay *ScorePayment) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
+ if pack != nil {
+ order.Price = pack.Score
+ } else {
+ order.Price = group.Score
+ }
+
+ // 检查此订单是否可用积分支付
+ if order.Price == 0 {
+ return nil, ErrUnsupportedPaymentMethod
+ }
+
+ // 扣除用户积分
+ if !user.PayScore(order.Price * order.Num) {
+ return nil, ErrScoreNotEnough
+ }
+
+ // 商品“发货”
+ if err := GiveProduct(user, pack, group, order.Num); err != nil {
+ user.AddScore(order.Price * order.Num)
+ return nil, err
+ }
+
+ // 创建订单记录
+ order.Status = model.OrderPaid
+ if _, err := order.Create(); err != nil {
+ return nil, ErrInsertOrder.WithError(err)
+ }
+
+ return &OrderCreateRes{
+ Payment: false,
+ }, nil
+}
diff --git a/pkg/payment/wechat.go b/pkg/payment/wechat.go
new file mode 100755
index 0000000..c9025c5
--- /dev/null
+++ b/pkg/payment/wechat.go
@@ -0,0 +1,88 @@
+package payment
+
+import (
+ "errors"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/iGoogle-ink/gopay"
+ "github.com/iGoogle-ink/gopay/wechat/v3"
+ "net/url"
+ "time"
+)
+
+// Wechat 微信扫码支付接口
+type Wechat struct {
+ Client *wechat.ClientV3
+ ApiV3Key string
+}
+
+// Create 创建订单
+func (pay *Wechat) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
+ gateway, _ := url.Parse("/api/v3/callback/wechat")
+ bm := make(gopay.BodyMap)
+ bm.
+ Set("description", order.Name).
+ Set("out_trade_no", order.OrderNo).
+ Set("notify_url", model.GetSiteURL().ResolveReference(gateway).String()).
+ SetBodyMap("amount", func(bm gopay.BodyMap) {
+ bm.Set("total", int64(order.Price*order.Num)).
+ Set("currency", "CNY")
+ })
+
+ wxRsp, err := pay.Client.V3TransactionNative(bm)
+ if err != nil {
+ return nil, ErrIssueOrder.WithError(err)
+ }
+
+ if wxRsp.Code == wechat.Success {
+ if _, err := order.Create(); err != nil {
+ return nil, ErrInsertOrder.WithError(err)
+ }
+
+ return &OrderCreateRes{
+ Payment: true,
+ QRCode: wxRsp.Response.CodeUrl,
+ ID: order.OrderNo,
+ }, nil
+ }
+
+ return nil, ErrIssueOrder.WithError(errors.New(wxRsp.Error))
+}
+
+// GetPlatformCert 获取微信平台证书
+func (pay *Wechat) GetPlatformCert() string {
+ if cert, ok := cache.Get("wechat_platform_cert"); ok {
+ return cert.(string)
+ }
+
+ res, err := pay.Client.GetPlatformCerts()
+ if err == nil {
+ // 使用反馈证书中启用时间较晚的
+ var (
+ currentLatest *time.Time
+ currentCert string
+ )
+ for _, cert := range res.Certs {
+ effectiveTime, err := time.Parse("2006-01-02T15:04:05-0700", cert.EffectiveTime)
+ if err != nil {
+ if currentLatest == nil {
+ currentLatest = &effectiveTime
+ currentCert = cert.PublicKey
+ continue
+ }
+ if currentLatest.Before(effectiveTime) {
+ currentLatest = &effectiveTime
+ currentCert = cert.PublicKey
+ }
+ }
+ }
+
+ cache.Set("wechat_platform_cert", currentCert, 3600*10)
+ return currentCert
+ }
+
+ util.Log().Debug("Failed to get Wechat Pay platform certificate: %s", err)
+ return ""
+}
diff --git a/pkg/qq/connect.go b/pkg/qq/connect.go
new file mode 100755
index 0000000..b8cb72f
--- /dev/null
+++ b/pkg/qq/connect.go
@@ -0,0 +1,211 @@
+package qq
+
+import (
+ "crypto/md5"
+ "encoding/json"
+ "errors"
+ "fmt"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/gofrs/uuid"
+ "net/url"
+ "strings"
+)
+
+// LoginPage 登陆页面描述
+type LoginPage struct {
+ URL string
+ SecretKey string
+}
+
+// UserCredentials 登陆成功后的凭证
+type UserCredentials struct {
+ OpenID string
+ AccessToken string
+}
+
+// UserInfo 用户信息
+type UserInfo struct {
+ Nick string
+ Avatar string
+}
+
+var (
+ // ErrNotEnabled 未开启登录功能
+ ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "QQ Login is not enabled", nil)
+ // ErrObtainAccessToken 无法获取AccessToken
+ ErrObtainAccessToken = serializer.NewError(serializer.CodeNotSet, "Cannot obtain AccessToken", nil)
+ // ErrObtainOpenID 无法获取OpenID
+ ErrObtainOpenID = serializer.NewError(serializer.CodeNotSet, "Cannot obtain OpenID", nil)
+ //ErrDecodeResponse 无法解析服务端响应
+ ErrDecodeResponse = serializer.NewError(serializer.CodeNotSet, "Cannot parse serverside response", nil)
+)
+
+// NewLoginRequest 新建登录会话
+func NewLoginRequest() (*LoginPage, error) {
+ // 获取相关设定
+ options := model.GetSettingByNames("qq_login", "qq_login_id")
+ if options["qq_login"] == "0" {
+ return nil, ErrNotEnabled
+ }
+
+ // 生成唯一ID
+ u2, err := uuid.NewV4()
+ if err != nil {
+ return nil, err
+ }
+ secret := fmt.Sprintf("%x", md5.Sum(u2.Bytes()))
+
+ // 生成登录地址
+ loginURL, _ := url.Parse("https://graph.qq.com/oauth2.0/authorize?response_type=code")
+ queries := loginURL.Query()
+ queries.Add("client_id", options["qq_login_id"])
+ queries.Add("redirect_uri", getCallbackURL())
+ queries.Add("state", secret)
+ loginURL.RawQuery = queries.Encode()
+
+ return &LoginPage{
+ URL: loginURL.String(),
+ SecretKey: secret,
+ }, nil
+}
+
+func getCallbackURL() string {
+ //return "https://drive.aoaoao.me/Callback/QQ"
+ // 生成回调地址
+ gateway, _ := url.Parse("/login/qq")
+ callback := model.GetSiteURL().ResolveReference(gateway).String()
+
+ return callback
+}
+
+func getAccessTokenURL(code string) string {
+ // 获取相关设定
+ options := model.GetSettingByNames("qq_login_id", "qq_login_key")
+
+ api, _ := url.Parse("https://graph.qq.com/oauth2.0/token?grant_type=authorization_code")
+ queries := api.Query()
+ queries.Add("client_id", options["qq_login_id"])
+ queries.Add("redirect_uri", getCallbackURL())
+ queries.Add("client_secret", options["qq_login_key"])
+ queries.Add("code", code)
+ api.RawQuery = queries.Encode()
+
+ return api.String()
+}
+
+func getUserInfoURL(openid, ak string) string {
+ // 获取相关设定
+ options := model.GetSettingByNames("qq_login_id", "qq_login_key")
+
+ api, _ := url.Parse("https://graph.qq.com/user/get_user_info")
+ queries := api.Query()
+ queries.Add("oauth_consumer_key", options["qq_login_id"])
+ queries.Add("openid", openid)
+ queries.Add("access_token", ak)
+ api.RawQuery = queries.Encode()
+
+ return api.String()
+}
+
+func getResponse(body string) (map[string]interface{}, error) {
+ var res map[string]interface{}
+
+ if !strings.Contains(body, "callback") {
+ return res, nil
+ }
+
+ body = strings.TrimPrefix(body, "callback(")
+ body = strings.TrimSuffix(body, ");\n")
+
+ err := json.Unmarshal([]byte(body), &res)
+
+ return res, err
+}
+
+// Callback 处理回调,返回openid和access key
+func Callback(code string) (*UserCredentials, error) {
+ // 获取相关设定
+ options := model.GetSettingByNames("qq_login")
+ if options["qq_login"] == "0" {
+ return nil, ErrNotEnabled
+ }
+
+ api := getAccessTokenURL(code)
+
+ // 获取AccessToken
+ client := request.NewClient()
+ res := client.Request("GET", api, nil)
+ resp, err := res.GetResponse()
+ if err != nil {
+ return nil, ErrObtainAccessToken.WithError(err)
+ }
+
+ // 如果服务端返回错误
+ errResp, err := getResponse(resp)
+ if msg, ok := errResp["error_description"]; err == nil && ok {
+ return nil, ErrObtainAccessToken.WithError(errors.New(msg.(string)))
+ }
+
+ // 获取AccessToken
+ vals, err := url.ParseQuery(resp)
+ if err != nil {
+ return nil, ErrDecodeResponse.WithError(err)
+ }
+ accessToken := vals.Get("access_token")
+
+ // 用 AccessToken 换取OpenID
+ res = client.Request("GET", "https://graph.qq.com/oauth2.0/me?access_token="+accessToken, nil)
+ resp, err = res.GetResponse()
+ if err != nil {
+ return nil, ErrObtainOpenID.WithError(err)
+ }
+
+ // 解析服务端响应
+ errResp, err = getResponse(resp)
+ if msg, ok := errResp["error_description"]; err == nil && ok {
+ return nil, ErrObtainOpenID.WithError(errors.New(msg.(string)))
+ }
+
+ if openid, ok := errResp["openid"]; ok {
+ return &UserCredentials{
+ OpenID: openid.(string),
+ AccessToken: accessToken,
+ }, nil
+ }
+
+ return nil, ErrDecodeResponse
+}
+
+// GetUserInfo 使用凭证获取用户信息
+func GetUserInfo(credential *UserCredentials) (*UserInfo, error) {
+ api := getUserInfoURL(credential.OpenID, credential.AccessToken)
+
+ // 获取用户信息
+ client := request.NewClient()
+ res := client.Request("GET", api, nil)
+ resp, err := res.GetResponse()
+ if err != nil {
+ return nil, ErrObtainAccessToken.WithError(err)
+ }
+
+ var resSerialized map[string]interface{}
+ if err := json.Unmarshal([]byte(resp), &resSerialized); err != nil {
+ return nil, ErrDecodeResponse.WithError(err)
+ }
+
+ // 如果服务端返回错误
+ if msg, ok := resSerialized["msg"]; ok && msg.(string) != "" {
+ return nil, ErrObtainAccessToken.WithError(errors.New(msg.(string)))
+ }
+
+ if avatar, ok := resSerialized["figureurl_qq_2"]; ok {
+ return &UserInfo{
+ Nick: resSerialized["nickname"].(string),
+ Avatar: avatar.(string),
+ }, nil
+ }
+
+ return nil, ErrDecodeResponse
+}
diff --git a/pkg/recaptcha/recaptcha.go b/pkg/recaptcha/recaptcha.go
index 75354bd..e360898 100644
--- a/pkg/recaptcha/recaptcha.go
+++ b/pkg/recaptcha/recaptcha.go
@@ -67,7 +67,8 @@ type ReCAPTCHA struct {
}
// NewReCAPTCHA new ReCAPTCHA instance if version is set to V2 uses recatpcha v2 API, get your secret from https://www.google.com/recaptcha/admin
-// if version is set to V2 uses recatpcha v2 API, get your secret from https://g.co/recaptcha/v3
+//
+// if version is set to V2 uses recatpcha v2 API, get your secret from https://g.co/recaptcha/v3
func NewReCAPTCHA(ReCAPTCHASecret string, version VERSION, timeout time.Duration) (ReCAPTCHA, error) {
if ReCAPTCHASecret == "" {
return ReCAPTCHA{}, fmt.Errorf("recaptcha secret cannot be blank")
diff --git a/pkg/request/request_test.go b/pkg/request/request_test.go
deleted file mode 100644
index e54831e..0000000
--- a/pkg/request/request_test.go
+++ /dev/null
@@ -1,278 +0,0 @@
-package request
-
-import (
- "context"
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "io"
- "io/ioutil"
- "net/http"
- "strings"
- "testing"
- "time"
-
- "github.com/cloudreve/Cloudreve/v3/pkg/auth"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-type ClientMock struct {
- testMock.Mock
-}
-
-func (m ClientMock) Request(method, target string, body io.Reader, opts ...Option) *Response {
- args := m.Called(method, target, body, opts)
- return args.Get(0).(*Response)
-}
-
-func TestWithTimeout(t *testing.T) {
- asserts := assert.New(t)
- options := newDefaultOption()
- WithTimeout(time.Duration(5) * time.Second).apply(options)
- asserts.Equal(time.Duration(5)*time.Second, options.timeout)
-}
-
-func TestWithHeader(t *testing.T) {
- asserts := assert.New(t)
- options := newDefaultOption()
- WithHeader(map[string][]string{"Origin": []string{"123"}}).apply(options)
- asserts.Equal(http.Header{"Origin": []string{"123"}}, options.header)
-}
-
-func TestWithContentLength(t *testing.T) {
- asserts := assert.New(t)
- options := newDefaultOption()
- WithContentLength(10).apply(options)
- asserts.EqualValues(10, options.contentLength)
-}
-
-func TestWithContext(t *testing.T) {
- asserts := assert.New(t)
- options := newDefaultOption()
- WithContext(context.Background()).apply(options)
- asserts.NotNil(options.ctx)
-}
-
-func TestHTTPClient_Request(t *testing.T) {
- asserts := assert.New(t)
- client := NewClient(WithSlaveMeta("test"))
-
- // 正常
- {
- resp := client.Request(
- "POST",
- "/test",
- strings.NewReader(""),
- WithContentLength(0),
- WithEndpoint("http://cloudreveisnotexist.com"),
- WithTimeout(time.Duration(1)*time.Microsecond),
- WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10),
- WithoutHeader([]string{"origin", "origin"}),
- )
- asserts.Error(resp.Err)
- asserts.Nil(resp.Response)
- }
-
- // 正常 带有ctx
- {
- resp := client.Request(
- "GET",
- "http://cloudreveisnotexist.com",
- strings.NewReader(""),
- WithTimeout(time.Duration(1)*time.Microsecond),
- WithCredential(auth.HMACAuth{SecretKey: []byte("123")}, 10),
- WithContext(context.Background()),
- WithoutHeader([]string{"s s", "s s"}),
- WithMasterMeta(),
- )
- asserts.Error(resp.Err)
- asserts.Nil(resp.Response)
- }
-
-}
-
-func TestResponse_GetResponse(t *testing.T) {
- asserts := assert.New(t)
-
- // 直接返回错误
- {
- resp := Response{
- Err: errors.New("error"),
- }
- content, err := resp.GetResponse()
- asserts.Empty(content)
- asserts.Error(err)
- }
-
- // 正常
- {
- resp := Response{
- Response: &http.Response{Body: ioutil.NopCloser(strings.NewReader("123"))},
- }
- content, err := resp.GetResponse()
- asserts.Equal("123", content)
- asserts.NoError(err)
- }
-}
-
-func TestResponse_CheckHTTPResponse(t *testing.T) {
- asserts := assert.New(t)
-
- // 直接返回错误
- {
- resp := Response{
- Err: errors.New("error"),
- }
- res := resp.CheckHTTPResponse(200)
- asserts.Error(res.Err)
- }
-
- // 404错误
- {
- resp := Response{
- Response: &http.Response{StatusCode: 404},
- }
- res := resp.CheckHTTPResponse(200)
- asserts.Error(res.Err)
- }
-
- // 通过
- {
- resp := Response{
- Response: &http.Response{StatusCode: 200},
- }
- res := resp.CheckHTTPResponse(200)
- asserts.NoError(res.Err)
- }
-}
-
-func TestResponse_GetRSCloser(t *testing.T) {
- asserts := assert.New(t)
-
- // 直接返回错误
- {
- resp := Response{
- Err: errors.New("error"),
- }
- res, err := resp.GetRSCloser()
- asserts.Error(err)
- asserts.Nil(res)
- }
-
- // 正常
- {
- resp := Response{
- Response: &http.Response{ContentLength: 3, Body: ioutil.NopCloser(strings.NewReader("123"))},
- }
- res, err := resp.GetRSCloser()
- asserts.NoError(err)
- content, err := ioutil.ReadAll(res)
- asserts.NoError(err)
- asserts.Equal("123", string(content))
- offset, err := res.Seek(0, 0)
- asserts.NoError(err)
- asserts.Equal(int64(0), offset)
- offset, err = res.Seek(0, 2)
- asserts.NoError(err)
- asserts.Equal(int64(3), offset)
- _, err = res.Seek(1, 2)
- asserts.Error(err)
- asserts.NoError(res.Close())
- }
-
-}
-
-func TestResponse_DecodeResponse(t *testing.T) {
- asserts := assert.New(t)
-
- // 直接返回错误
- {
- resp := Response{Err: errors.New("error")}
- response, err := resp.DecodeResponse()
- asserts.Error(err)
- asserts.Nil(response)
- }
-
- // 无法解析响应
- {
- resp := Response{
- Response: &http.Response{
- Body: ioutil.NopCloser(strings.NewReader("test")),
- },
- }
- response, err := resp.DecodeResponse()
- asserts.Error(err)
- asserts.Nil(response)
- }
-
- // 成功
- {
- resp := Response{
- Response: &http.Response{
- Body: ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
- },
- }
- response, err := resp.DecodeResponse()
- asserts.NoError(err)
- asserts.NotNil(response)
- asserts.Equal(0, response.Code)
- }
-}
-
-func TestNopRSCloser_SetFirstFakeChunk(t *testing.T) {
- asserts := assert.New(t)
- rsc := NopRSCloser{
- status: &rscStatus{},
- }
- rsc.SetFirstFakeChunk()
- asserts.True(rsc.status.IgnoreFirst)
-
- rsc.SetContentLength(20)
- asserts.EqualValues(20, rsc.status.Size)
-}
-
-func TestBlackHole(t *testing.T) {
- a := assert.New(t)
- cache.Set("setting_reset_after_upload_failed", "true", 0)
- a.NotPanics(func() {
- BlackHole(strings.NewReader("TestBlackHole"))
- })
-}
-
-func TestHTTPClient_TPSLimit(t *testing.T) {
- a := assert.New(t)
- client := NewClient()
-
- finished := make(chan struct{})
- go func() {
- client.Request(
- "POST",
- "/test",
- strings.NewReader(""),
- WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1),
- )
- close(finished)
- }()
- select {
- case <-finished:
- case <-time.After(10 * time.Second):
- a.Fail("Request should be finished instantly.")
- }
-
- finished = make(chan struct{})
- go func() {
- client.Request(
- "POST",
- "/test",
- strings.NewReader(""),
- WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1),
- )
- close(finished)
- }()
- select {
- case <-finished:
- case <-time.After(2 * time.Second):
- a.Fail("Request should be finished in 1 second.")
- }
-
-}
diff --git a/pkg/request/tpslimiter_test.go b/pkg/request/tpslimiter_test.go
deleted file mode 100644
index daec236..0000000
--- a/pkg/request/tpslimiter_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package request
-
-import (
- "context"
- "github.com/stretchr/testify/assert"
- "testing"
- "time"
-)
-
-func TestLimit(t *testing.T) {
- a := assert.New(t)
- l := NewTPSLimiter()
- finished := make(chan struct{})
- go func() {
- l.Limit(context.Background(), "token", 1, 1)
- close(finished)
- }()
- select {
- case <-finished:
- case <-time.After(10 * time.Second):
- a.Fail("Limit should be finished instantly.")
- }
-
- finished = make(chan struct{})
- go func() {
- l.Limit(context.Background(), "token", 1, 1)
- close(finished)
- }()
- select {
- case <-finished:
- case <-time.After(2 * time.Second):
- a.Fail("Limit should be finished in 1 second.")
- }
-
- finished = make(chan struct{})
- go func() {
- l.Limit(context.Background(), "token", 10, 1)
- close(finished)
- }()
- select {
- case <-finished:
- case <-time.After(1 * time.Second):
- a.Fail("Limit should be finished instantly.")
- }
-
-}
diff --git a/pkg/serializer/aria2_test.go b/pkg/serializer/aria2_test.go
deleted file mode 100644
index 1f3ca61..0000000
--- a/pkg/serializer/aria2_test.go
+++ /dev/null
@@ -1,95 +0,0 @@
-package serializer
-
-import (
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestBuildFinishedListResponse(t *testing.T) {
- asserts := assert.New(t)
- tasks := []model.Download{
- {
- StatusInfo: rpc.StatusInfo{
- Files: []rpc.FileInfo{
- {
- Path: "/file/name.txt",
- },
- },
- },
- Task: &model.Task{
- Model: gorm.Model{},
- Error: "error",
- },
- },
- {
- StatusInfo: rpc.StatusInfo{
- Files: []rpc.FileInfo{
- {
- Path: "/file/name1.txt",
- },
- {
- Path: "/file/name2.txt",
- },
- },
- },
- },
- }
- tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt"
- res := BuildFinishedListResponse(tasks).Data.([]FinishedListResponse)
- asserts.Len(res, 2)
- asserts.Equal("name.txt", res[1].Name)
- asserts.Equal("name.txt", res[0].Name)
- asserts.Equal("name.txt", res[0].Files[0].Path)
- asserts.Equal("name1.txt", res[1].Files[0].Path)
- asserts.Equal("name2.txt", res[1].Files[1].Path)
- asserts.EqualValues(0, res[0].TaskStatus)
- asserts.Equal("error", res[0].TaskError)
-}
-
-func TestBuildDownloadingResponse(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("setting_aria2_interval", "10", 0)
- tasks := []model.Download{
- {
- StatusInfo: rpc.StatusInfo{
- Files: []rpc.FileInfo{
- {
- Path: "/file/name.txt",
- },
- },
- },
- Task: &model.Task{
- Model: gorm.Model{},
- Error: "error",
- },
- },
- {
- StatusInfo: rpc.StatusInfo{
- Files: []rpc.FileInfo{
- {
- Path: "/file/name1.txt",
- },
- {
- Path: "/file/name2.txt",
- },
- },
- },
- },
- }
- tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt"
- tasks[1].ID = 1
-
- res := BuildDownloadingResponse(tasks, map[uint]int{1: 5}).Data.([]DownloadListResponse)
- asserts.Len(res, 2)
- asserts.Equal("name1.txt", res[1].Name)
- asserts.Equal(5, res[1].UpdateInterval)
- asserts.Equal("name.txt", res[0].Name)
- asserts.Equal("name.txt", res[0].Info.Files[0].Path)
- asserts.Equal("name1.txt", res[1].Info.Files[0].Path)
- asserts.Equal("name2.txt", res[1].Info.Files[1].Path)
-}
diff --git a/pkg/serializer/auth_test.go b/pkg/serializer/auth_test.go
deleted file mode 100644
index 96b6b9b..0000000
--- a/pkg/serializer/auth_test.go
+++ /dev/null
@@ -1,13 +0,0 @@
-package serializer
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestNewRequestSignString(t *testing.T) {
- asserts := assert.New(t)
-
- sign := NewRequestSignString("1", "2", "3")
- asserts.NotEmpty(sign)
-}
diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go
index 326c0d8..7ddbc59 100644
--- a/pkg/serializer/error.go
+++ b/pkg/serializer/error.go
@@ -78,6 +78,8 @@ const (
CodeInvalidChunkIndex = 40012
// CodeInvalidContentLength 无效的正文长度
CodeInvalidContentLength = 40013
+ // CodePhoneRequired 未绑定手机
+ CodePhoneRequired = 40010
// CodeBatchSourceSize 超出批量获取外链限制
CodeBatchSourceSize = 40014
// CodeBatchAria2Size 超出最大 Aria2 任务数量限制
@@ -112,6 +114,8 @@ const (
CodeInvalidTempLink = 40029
// CodeTempLinkExpired 临时链接过期
CodeTempLinkExpired = 40030
+ // CodeEmailProviderBaned 邮箱后缀被禁用
+ CodeEmailProviderBaned = 40031
// CodeEmailExisted 邮箱已被使用
CodeEmailExisted = 40032
// CodeEmailSent 邮箱已重新发送
@@ -192,6 +196,8 @@ const (
CodeDisabledSharePreview = 40070
// 签名无效
CodeInvalidSign = 40071
+ // 管理员无法购买用户组
+ CodeFulfillAdminGroup = 40072
// CodeDBError 数据库操作失败
CodeDBError = 50001
// CodeEncryptError 加密失败
diff --git a/pkg/serializer/error_test.go b/pkg/serializer/error_test.go
deleted file mode 100644
index d02fd5d..0000000
--- a/pkg/serializer/error_test.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package serializer
-
-import (
- "errors"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestNewError(t *testing.T) {
- a := assert.New(t)
- err := NewError(400, "Bad Request", errors.New("error"))
- a.Error(err)
- a.EqualValues(400, err.Code)
-
- err.WithError(errors.New("error2"))
- a.Equal("error2", err.RawError.Error())
- a.Equal("Bad Request", err.Error())
-
- resp := &Response{
- Code: 400,
- Msg: "Bad Request",
- Error: "error",
- }
- err = NewErrorFromResponse(resp)
- a.Error(err)
-}
-
-func TestDBErr(t *testing.T) {
- a := assert.New(t)
- resp := DBErr("", nil)
- a.NotEmpty(resp.Msg)
-
- resp = ParamErr("", nil)
- a.NotEmpty(resp.Msg)
-}
-
-func TestErr(t *testing.T) {
- a := assert.New(t)
- err := NewError(400, "Bad Request", errors.New("error"))
- resp := Err(400, "", err)
- a.Equal("Bad Request", resp.Msg)
-}
diff --git a/pkg/serializer/explorer_test.go b/pkg/serializer/explorer_test.go
deleted file mode 100644
index 00c9efc..0000000
--- a/pkg/serializer/explorer_test.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package serializer
-
-import (
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestBuildObjectList(t *testing.T) {
- a := assert.New(t)
- res := BuildObjectList(1, []Object{{}, {}}, &model.Policy{})
- a.NotEmpty(res.Parent)
- a.NotNil(res.Policy)
- a.Len(res.Objects, 2)
-}
diff --git a/pkg/serializer/response_test.go b/pkg/serializer/response_test.go
deleted file mode 100644
index 70c8899..0000000
--- a/pkg/serializer/response_test.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package serializer
-
-import (
- "encoding/json"
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestNewResponseWithGobData(t *testing.T) {
- a := assert.New(t)
- type args struct {
- data interface{}
- }
-
- res := NewResponseWithGobData(args{})
- a.Equal(CodeInternalSetting, res.Code)
-
- res = NewResponseWithGobData("TestNewResponseWithGobData")
- a.Equal(0, res.Code)
- a.NotEmpty(res.Data)
-}
-
-func TestResponse_GobDecode(t *testing.T) {
- a := assert.New(t)
- res := NewResponseWithGobData("TestResponse_GobDecode")
- jsonContent, err := json.Marshal(res)
- a.NoError(err)
- resDecoded := &Response{}
- a.NoError(json.Unmarshal(jsonContent, resDecoded))
- var target string
- resDecoded.GobDecode(&target)
- a.Equal("TestResponse_GobDecode", target)
-}
diff --git a/pkg/serializer/setting.go b/pkg/serializer/setting.go
index 7e4ce00..df6fc2c 100644
--- a/pkg/serializer/setting.go
+++ b/pkg/serializer/setting.go
@@ -1,8 +1,9 @@
package serializer
import (
- model "github.com/cloudreve/Cloudreve/v3/models"
"time"
+
+ model "github.com/cloudreve/Cloudreve/v3/models"
)
// SiteConfig 站点全局设置序列
@@ -12,18 +13,25 @@ type SiteConfig struct {
RegCaptcha bool `json:"regCaptcha"`
ForgetCaptcha bool `json:"forgetCaptcha"`
EmailActive bool `json:"emailActive"`
+ QQLogin bool `json:"QQLogin"`
Themes string `json:"themes"`
DefaultTheme string `json:"defaultTheme"`
+ ScoreEnabled bool `json:"score_enabled"`
+ ShareScoreRate string `json:"share_score_rate"`
HomepageViewMethod string `json:"home_view_method"`
ShareViewMethod string `json:"share_view_method"`
Authn bool `json:"authn"`
User User `json:"user"`
ReCaptchaKey string `json:"captcha_ReCaptchaKey"`
+ SiteNotice string `json:"site_notice"`
CaptchaType string `json:"captcha_type"`
TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"`
RegisterEnabled bool `json:"registerEnabled"`
+ ReportEnabled bool `json:"report_enabled"`
AppPromotion bool `json:"app_promotion"`
WopiExts []string `json:"wopi_exts"`
+ AppFeedbackLink string `json:"app_feedback"`
+ AppForumLink string `json:"app_forum"`
}
type task struct {
@@ -75,18 +83,31 @@ func BuildSiteConfig(settings map[string]string, user *model.User, wopiExts []st
RegCaptcha: model.IsTrueVal(checkSettingValue(settings, "reg_captcha")),
ForgetCaptcha: model.IsTrueVal(checkSettingValue(settings, "forget_captcha")),
EmailActive: model.IsTrueVal(checkSettingValue(settings, "email_active")),
+ QQLogin: model.IsTrueVal(checkSettingValue(settings, "qq_login")),
Themes: checkSettingValue(settings, "themes"),
DefaultTheme: checkSettingValue(settings, "defaultTheme"),
+ ScoreEnabled: model.IsTrueVal(checkSettingValue(settings, "score_enabled")),
+ ShareScoreRate: checkSettingValue(settings, "share_score_rate"),
HomepageViewMethod: checkSettingValue(settings, "home_view_method"),
ShareViewMethod: checkSettingValue(settings, "share_view_method"),
Authn: model.IsTrueVal(checkSettingValue(settings, "authn_enabled")),
User: userRes,
+ SiteNotice: checkSettingValue(settings, "siteNotice"),
ReCaptchaKey: checkSettingValue(settings, "captcha_ReCaptchaKey"),
CaptchaType: checkSettingValue(settings, "captcha_type"),
TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"),
RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")),
+ ReportEnabled: model.IsTrueVal(checkSettingValue(settings, "report_enabled")),
AppPromotion: model.IsTrueVal(checkSettingValue(settings, "show_app_promotion")),
+ AppFeedbackLink: checkSettingValue(settings, "app_feedback_link"),
+ AppForumLink: checkSettingValue(settings, "app_forum_link"),
WopiExts: wopiExts,
}}
return res
}
+
+// VolResponse VOL query response
+type VolResponse struct {
+ Signature string `json:"signature"`
+ Content string `json:"content"`
+}
diff --git a/pkg/serializer/setting_test.go b/pkg/serializer/setting_test.go
deleted file mode 100644
index 680edb6..0000000
--- a/pkg/serializer/setting_test.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package serializer
-
-import (
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestCheckSettingValue(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.Equal("", checkSettingValue(map[string]string{}, "key"))
- asserts.Equal("123", checkSettingValue(map[string]string{"key": "123"}, "key"))
-}
-
-func TestBuildSiteConfig(t *testing.T) {
- asserts := assert.New(t)
-
- res := BuildSiteConfig(map[string]string{"not exist": ""}, &model.User{}, nil)
- asserts.Equal("", res.Data.(SiteConfig).SiteName)
-
- res = BuildSiteConfig(map[string]string{"siteName": "123"}, &model.User{}, nil)
- asserts.Equal("123", res.Data.(SiteConfig).SiteName)
-
- // 非空用户
- res = BuildSiteConfig(map[string]string{"qq_login": "1"}, &model.User{
- Model: gorm.Model{
- ID: 5,
- },
- }, nil)
- asserts.Len(res.Data.(SiteConfig).User.ID, 4)
-}
-
-func TestBuildTaskList(t *testing.T) {
- asserts := assert.New(t)
- tasks := []model.Task{{}}
-
- res := BuildTaskList(tasks, 1)
- asserts.NotNil(res)
-}
diff --git a/pkg/serializer/share.go b/pkg/serializer/share.go
index 94da4c6..ade604d 100644
--- a/pkg/serializer/share.go
+++ b/pkg/serializer/share.go
@@ -12,6 +12,7 @@ type Share struct {
Key string `json:"key"`
Locked bool `json:"locked"`
IsDir bool `json:"is_dir"`
+ Score int `json:"score"`
CreateDate time.Time `json:"create_date,omitempty"`
Downloads int `json:"downloads"`
Views int `json:"views"`
@@ -36,6 +37,7 @@ type shareSource struct {
type myShareItem struct {
Key string `json:"key"`
IsDir bool `json:"is_dir"`
+ Score int `json:"score"`
Password string `json:"password"`
CreateDate time.Time `json:"create_date,omitempty"`
Downloads int `json:"downloads"`
@@ -54,6 +56,7 @@ func BuildShareList(shares []model.Share, total int) Response {
item := myShareItem{
Key: hashid.HashID(shares[i].ID, hashid.ShareID),
IsDir: shares[i].IsDir,
+ Score: shares[i].Score,
Password: shares[i].Password,
CreateDate: shares[i].CreatedAt,
Downloads: shares[i].Downloads,
@@ -99,6 +102,7 @@ func BuildShareResponse(share *model.Share, unlocked bool) Share {
Nick: creator.Nick,
GroupName: creator.Group.Name,
},
+ Score: share.Score,
CreateDate: share.CreatedAt,
}
diff --git a/pkg/serializer/share_test.go b/pkg/serializer/share_test.go
deleted file mode 100644
index 72feb0c..0000000
--- a/pkg/serializer/share_test.go
+++ /dev/null
@@ -1,85 +0,0 @@
-package serializer
-
-import (
- "testing"
- "time"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestBuildShareList(t *testing.T) {
- asserts := assert.New(t)
- timeNow := time.Now()
-
- shares := []model.Share{
- {
- Expires: &timeNow,
- File: model.File{
- Model: gorm.Model{ID: 1},
- },
- },
- {
- Folder: model.Folder{
- Model: gorm.Model{ID: 1},
- },
- },
- }
-
- res := BuildShareList(shares, 2)
- asserts.Equal(0, res.Code)
-}
-
-func TestBuildShareResponse(t *testing.T) {
- asserts := assert.New(t)
-
- // 未解锁
- {
- share := &model.Share{
- User: model.User{Model: gorm.Model{ID: 1}},
- Downloads: 1,
- }
- res := BuildShareResponse(share, false)
- asserts.EqualValues(0, res.Downloads)
- asserts.True(res.Locked)
- asserts.NotNil(res.Creator)
- }
-
- // 已解锁,非目录
- {
- expires := time.Now().Add(time.Duration(10) * time.Second)
- share := &model.Share{
- User: model.User{Model: gorm.Model{ID: 1}},
- Downloads: 1,
- Expires: &expires,
- File: model.File{
- Model: gorm.Model{ID: 1},
- },
- }
- res := BuildShareResponse(share, true)
- asserts.EqualValues(1, res.Downloads)
- asserts.False(res.Locked)
- asserts.NotEmpty(res.Expire)
- asserts.NotNil(res.Creator)
- }
-
- // 已解锁,是目录
- {
- expires := time.Now().Add(time.Duration(10) * time.Second)
- share := &model.Share{
- User: model.User{Model: gorm.Model{ID: 1}},
- Downloads: 1,
- Expires: &expires,
- Folder: model.Folder{
- Model: gorm.Model{ID: 1},
- },
- IsDir: true,
- }
- res := BuildShareResponse(share, true)
- asserts.EqualValues(1, res.Downloads)
- asserts.False(res.Locked)
- asserts.NotEmpty(res.Expire)
- asserts.NotNil(res.Creator)
- }
-}
diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go
deleted file mode 100644
index 46b5d2d..0000000
--- a/pkg/serializer/slave_test.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package serializer
-
-import (
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
-)
-
-func TestSlaveTransferReq_Hash(t *testing.T) {
- a := assert.New(t)
- s1 := &SlaveTransferReq{
- Src: "1",
- Policy: &model.Policy{},
- }
- s2 := &SlaveTransferReq{
- Src: "2",
- Policy: &model.Policy{},
- }
- a.NotEqual(s1.Hash("1"), s2.Hash("1"))
-}
diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go
index 142f424..e5f67d8 100644
--- a/pkg/serializer/user.go
+++ b/pkg/serializer/user.go
@@ -2,11 +2,11 @@ package serializer
import (
"fmt"
+ "time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/duo-labs/webauthn/webauthn"
- "time"
)
// CheckLogin 检查登录
@@ -17,6 +17,14 @@ func CheckLogin() Response {
}
}
+// PhoneRequired 需要绑定手机
+func PhoneRequired() Response {
+ return Response{
+ Code: CodePhoneRequired,
+ Msg: "此功能需要绑定手机后使用",
+ }
+}
+
// User 用户序列化器
type User struct {
ID string `json:"id"`
@@ -26,6 +34,7 @@ type User struct {
Avatar string `json:"avatar"`
CreatedAt time.Time `json:"created_at"`
PreferredTheme string `json:"preferred_theme"`
+ Score int `json:"score"`
Anonymous bool `json:"anonymous"`
Group group `json:"group"`
Tags []tag `json:"tags"`
@@ -37,10 +46,13 @@ type group struct {
AllowShare bool `json:"allowShare"`
AllowRemoteDownload bool `json:"allowRemoteDownload"`
AllowArchiveDownload bool `json:"allowArchiveDownload"`
+ ShareFreeEnabled bool `json:"shareFree"`
ShareDownload bool `json:"shareDownload"`
CompressEnabled bool `json:"compress"`
WebDAVEnabled bool `json:"webdav"`
+ RelocateEnabled bool `json:"relocate"`
SourceBatchSize int `json:"sourceBatch"`
+ SelectNode bool `json:"selectNode"`
AdvanceDelete bool `json:"advanceDelete"`
AllowWebDAVProxy bool `json:"allowWebDAVProxy"`
}
@@ -91,6 +103,7 @@ func BuildUser(user model.User) User {
Avatar: user.Avatar,
CreatedAt: user.CreatedAt,
PreferredTheme: user.OptionsSerialized.PreferredTheme,
+ Score: user.Score,
Anonymous: user.IsAnonymous(),
Group: group{
ID: user.GroupID,
@@ -98,11 +111,14 @@ func BuildUser(user model.User) User {
AllowShare: user.Group.ShareEnabled,
AllowRemoteDownload: user.Group.OptionsSerialized.Aria2,
AllowArchiveDownload: user.Group.OptionsSerialized.ArchiveDownload,
+ ShareFreeEnabled: user.Group.OptionsSerialized.ShareFree,
ShareDownload: user.Group.OptionsSerialized.ShareDownload,
CompressEnabled: user.Group.OptionsSerialized.ArchiveTask,
WebDAVEnabled: user.Group.WebDAVEnabled,
AllowWebDAVProxy: user.Group.OptionsSerialized.WebDAVProxy,
+ RelocateEnabled: user.Group.OptionsSerialized.Relocate,
SourceBatchSize: user.Group.OptionsSerialized.SourceBatchSize,
+ SelectNode: user.Group.OptionsSerialized.SelectNode,
AdvanceDelete: user.Group.OptionsSerialized.AdvanceDelete,
},
Tags: buildTagRes(tags),
@@ -118,7 +134,7 @@ func BuildUserResponse(user model.User) Response {
// BuildUserStorageResponse 序列化用户存储概况响应
func BuildUserStorageResponse(user model.User) Response {
- total := user.Group.MaxStorage
+ total := user.Group.MaxStorage + user.GetAvailablePackSize()
storageResp := storage{
Used: user.Storage,
Free: total - user.Storage,
diff --git a/pkg/serializer/user_test.go b/pkg/serializer/user_test.go
deleted file mode 100644
index 2942186..0000000
--- a/pkg/serializer/user_test.go
+++ /dev/null
@@ -1,116 +0,0 @@
-package serializer
-
-import (
- "database/sql"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/duo-labs/webauthn/webauthn"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestBuildUser(t *testing.T) {
- asserts := assert.New(t)
- user := model.User{}
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- res := BuildUser(user)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(res)
-
-}
-
-func TestBuildUserResponse(t *testing.T) {
- asserts := assert.New(t)
- user := model.User{}
- res := BuildUserResponse(user)
- asserts.NotNil(res)
-}
-
-func TestBuildUserStorageResponse(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("pack_size_0", uint64(0), 0)
-
- {
- user := model.User{
- Storage: 0,
- Group: model.Group{MaxStorage: 10},
- }
- res := BuildUserStorageResponse(user)
- asserts.Equal(uint64(0), res.Data.(storage).Used)
- asserts.Equal(uint64(10), res.Data.(storage).Total)
- asserts.Equal(uint64(10), res.Data.(storage).Free)
- }
- {
- user := model.User{
- Storage: 6,
- Group: model.Group{MaxStorage: 10},
- }
- res := BuildUserStorageResponse(user)
- asserts.Equal(uint64(6), res.Data.(storage).Used)
- asserts.Equal(uint64(10), res.Data.(storage).Total)
- asserts.Equal(uint64(4), res.Data.(storage).Free)
- }
- {
- user := model.User{
- Storage: 20,
- Group: model.Group{MaxStorage: 10},
- }
- res := BuildUserStorageResponse(user)
- asserts.Equal(uint64(20), res.Data.(storage).Used)
- asserts.Equal(uint64(10), res.Data.(storage).Total)
- asserts.Equal(uint64(0), res.Data.(storage).Free)
- }
- {
- user := model.User{
- Storage: 6,
- Group: model.Group{MaxStorage: 10},
- }
- res := BuildUserStorageResponse(user)
- asserts.Equal(uint64(6), res.Data.(storage).Used)
- asserts.Equal(uint64(10), res.Data.(storage).Total)
- asserts.Equal(uint64(4), res.Data.(storage).Free)
- }
-}
-
-func TestBuildTagRes(t *testing.T) {
- asserts := assert.New(t)
- tags := []model.Tag{
- {
- Type: 0,
- Expression: "exp",
- },
- {
- Type: 1,
- Expression: "exp",
- },
- }
- res := buildTagRes(tags)
- asserts.Len(res, 2)
- asserts.Equal("", res[0].Expression)
- asserts.Equal("exp", res[1].Expression)
-}
-
-func TestBuildWebAuthnList(t *testing.T) {
- asserts := assert.New(t)
- credentials := []webauthn.Credential{{}}
- res := BuildWebAuthnList(credentials)
- asserts.Len(res, 1)
-}
diff --git a/pkg/serializer/vas.go b/pkg/serializer/vas.go
new file mode 100755
index 0000000..2da604a
--- /dev/null
+++ b/pkg/serializer/vas.go
@@ -0,0 +1,158 @@
+package serializer
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "time"
+)
+
+type quota struct {
+ Base uint64 `json:"base"`
+ Pack uint64 `json:"pack"`
+ Used uint64 `json:"used"`
+ Total uint64 `json:"total"`
+ Packs []storagePacks `json:"packs"`
+}
+
+type storagePacks struct {
+ Name string `json:"name"`
+ Size uint64 `json:"size"`
+ ActivateDate time.Time `json:"activate_date"`
+ Expiration int `json:"expiration"`
+ ExpirationDate time.Time `json:"expiration_date"`
+}
+
+// MountedFolders 已挂载的目录
+type MountedFolders struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ PolicyName string `json:"policy_name"`
+}
+
+type policyOptions struct {
+ Name string `json:"name"`
+ ID string `json:"id"`
+}
+
+type nodeOptions struct {
+ Name string `json:"name"`
+ ID uint `json:"id"`
+}
+
+// BuildPolicySettingRes 构建存储策略选项选择
+func BuildPolicySettingRes(policies []model.Policy) Response {
+ options := make([]policyOptions, 0, len(policies))
+ for _, policy := range policies {
+ options = append(options, policyOptions{
+ Name: policy.Name,
+ ID: hashid.HashID(policy.ID, hashid.PolicyID),
+ })
+ }
+
+ return Response{
+ Data: options,
+ }
+}
+
+// BuildMountedFolderRes 构建已挂载目录响应,list为当前用户可用存储策略ID
+func BuildMountedFolderRes(folders []model.Folder, list []uint) []MountedFolders {
+ res := make([]MountedFolders, 0, len(folders))
+ for _, folder := range folders {
+ single := MountedFolders{
+ ID: hashid.HashID(folder.ID, hashid.FolderID),
+ Name: folder.Name,
+ PolicyName: "[Invalid Policy]",
+ }
+ if policy, err := model.GetPolicyByID(folder.PolicyID); err == nil && util.ContainsUint(list, policy.ID) {
+ single.PolicyName = policy.Name
+ }
+
+ res = append(res, single)
+ }
+
+ return res
+}
+
+// BuildUserQuotaResponse 序列化用户存储配额概况响应
+func BuildUserQuotaResponse(user *model.User, packs []model.StoragePack) Response {
+ packSize := user.GetAvailablePackSize()
+ res := quota{
+ Base: user.Group.MaxStorage,
+ Pack: packSize,
+ Used: user.Storage,
+ Total: packSize + user.Group.MaxStorage,
+ Packs: make([]storagePacks, 0, len(packs)),
+ }
+ for _, pack := range packs {
+ res.Packs = append(res.Packs, storagePacks{
+ Name: pack.Name,
+ Size: pack.Size,
+ ActivateDate: *pack.ActiveTime,
+ Expiration: int(pack.ExpiredTime.Sub(*pack.ActiveTime).Seconds()),
+ ExpirationDate: *pack.ExpiredTime,
+ })
+ }
+
+ return Response{
+ Data: res,
+ }
+}
+
+// PackProduct 容量包商品
+type PackProduct struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Size uint64 `json:"size"`
+ Time int64 `json:"time"`
+ Price int `json:"price"`
+ Score int `json:"score"`
+}
+
+// GroupProducts 用户组商品
+type GroupProducts struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ GroupID uint `json:"group_id"`
+ Time int64 `json:"time"`
+ Price int `json:"price"`
+ Score int `json:"score"`
+ Des []string `json:"des"`
+ Highlight bool `json:"highlight"`
+}
+
+// BuildProductResponse 构建增值服务商品响应
+func BuildProductResponse(groups []GroupProducts, packs []PackProduct,
+ wechat, alipay, payjs, custom bool, customName string, scorePrice int) Response {
+ // 隐藏响应中的用户组ID
+ for i := 0; i < len(groups); i++ {
+ groups[i].GroupID = 0
+ }
+ return Response{
+ Data: map[string]interface{}{
+ "packs": packs,
+ "groups": groups,
+ "alipay": alipay,
+ "wechat": wechat,
+ "payjs": payjs,
+ "custom": custom,
+ "custom_name": customName,
+ "score_price": scorePrice,
+ },
+ }
+}
+
+// BuildNodeOptionRes 构建可用节点列表响应
+func BuildNodeOptionRes(nodes []*model.Node) Response {
+ options := make([]nodeOptions, 0, len(nodes))
+ for _, node := range nodes {
+ options = append(options, nodeOptions{
+ Name: node.Name,
+ ID: node.ID,
+ })
+ }
+
+ return Response{
+ Data: options,
+ }
+}
diff --git a/pkg/task/compress.go b/pkg/task/compress.go
index 5e20a36..7f5025f 100644
--- a/pkg/task/compress.go
+++ b/pkg/task/compress.go
@@ -122,7 +122,7 @@ func (job *CompressTask) Do() {
job.zipPath = zipFilePath
zipFile.Close()
- util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFilePath)
+ util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFile)
job.TaskModel.SetProgress(TransferringProgress)
// 上传文件
@@ -155,7 +155,7 @@ func NewCompressTask(user *model.User, dst string, dirs, files []uint) (Job, err
return newTask, nil
}
-// NewCompressTaskFromModel 从数据库记录中恢复压缩任务
+// NewRelocateTaskFromModel 从数据库记录中恢复迁移任务
func NewCompressTaskFromModel(task *model.Task) (Job, error) {
user, err := model.GetActiveUserByID(task.UserID)
if err != nil {
diff --git a/pkg/task/compress_test.go b/pkg/task/compress_test.go
deleted file mode 100644
index 34b282d..0000000
--- a/pkg/task/compress_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package task
-
-import (
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestCompressTask_Props(t *testing.T) {
- asserts := assert.New(t)
- task := &CompressTask{
- User: &model.User{},
- }
- asserts.NotEmpty(task.Props())
- asserts.Equal(CompressTaskType, task.Type())
- asserts.EqualValues(0, task.Creator())
- asserts.Nil(task.Model())
-}
-
-func TestCompressTask_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- task := &CompressTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.SetStatus(3)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestCompressTask_SetError(t *testing.T) {
- asserts := assert.New(t)
- task := &CompressTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- zipPath: "test/TestCompressTask_SetError",
- }
- zipFile, _ := util.CreatNestedFile("test/TestCompressTask_SetError")
- zipFile.Close()
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- task.SetErrorMsg("error")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.False(util.Exists("test/TestCompressTask_SetError"))
- asserts.Equal("error", task.GetError().Msg)
-}
-
-func TestCompressTask_Do(t *testing.T) {
- asserts := assert.New(t)
- task := &CompressTask{
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
-
- // 无法创建文件系统
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "unknown",
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- }
-
- // 压缩出错
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "mock",
- },
- }
- task.TaskProps.Dirs = []uint{1}
- // 更新进度
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- // 查找目录
- mock.ExpectQuery("SELECT(.+)").WillReturnError(errors.New("error"))
- // 更新错误
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- }
-
- // 上传出错
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "mock",
- MaxSize: 1,
- },
- }
- task.TaskProps.Dirs = []uint{1}
- cache.Set("setting_temp_path", "test", 0)
- // 更新进度
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- // 查找目录
- mock.ExpectQuery("SELECT(.+)folders").
- WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- // 查找文件
- mock.ExpectQuery("SELECT(.+)files").
- WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 查找子文件
- mock.ExpectQuery("SELECT(.+)files").
- WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 查找子目录
- mock.ExpectQuery("SELECT(.+)folders").
- WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 更新错误
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- asserts.True(util.IsEmpty(util.RelativePath("test/compress")))
- }
-}
-
-func TestNewCompressTask(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- job, err := NewCompressTask(&model.User{}, "/", []uint{12}, []uint{})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(job)
- asserts.NoError(err)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- job, err := NewCompressTask(&model.User{}, "/", []uint{12}, []uint{})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
-}
-
-func TestNewCompressTaskFromModel(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewCompressTaskFromModel(&model.Task{Props: "{}"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotNil(job)
- }
-
- // JSON解析失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewCompressTaskFromModel(&model.Task{Props: ""})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(job)
- }
-}
diff --git a/pkg/task/decompress_test.go b/pkg/task/decompress_test.go
deleted file mode 100644
index 75b7cfe..0000000
--- a/pkg/task/decompress_test.go
+++ /dev/null
@@ -1,140 +0,0 @@
-package task
-
-import (
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestDecompressTask_Props(t *testing.T) {
- asserts := assert.New(t)
- task := &DecompressTask{
- User: &model.User{},
- }
- asserts.NotEmpty(task.Props())
- asserts.Equal(DecompressTaskType, task.Type())
- asserts.EqualValues(0, task.Creator())
- asserts.Nil(task.Model())
-}
-
-func TestDecompressTask_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- task := &DecompressTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.SetStatus(3)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestDecompressTask_SetError(t *testing.T) {
- asserts := assert.New(t)
- task := &DecompressTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- zipPath: "test/TestCompressTask_SetError",
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- task.SetErrorMsg("error", nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("error", task.GetError().Msg)
-}
-
-func TestDecompressTask_Do(t *testing.T) {
- asserts := assert.New(t)
- task := &DecompressTask{
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
-
- // 无法创建文件系统
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "unknown",
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- }
-
- // 压缩文件不存在
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "mock",
- },
- }
- task.TaskProps.Src = "test"
- task.Do()
- asserts.NotEmpty(task.GetError().Msg)
- }
-}
-
-func TestNewDecompressTask(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- job, err := NewDecompressTask(&model.User{}, "/", "/", "utf-8")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(job)
- asserts.NoError(err)
- }
-
- // 失败
- {
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- job, err := NewDecompressTask(&model.User{}, "/", "/", "utf-8")
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
-}
-
-func TestNewDecompressTaskFromModel(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewDecompressTaskFromModel(&model.Task{Props: "{}"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotNil(job)
- }
-
- // JSON解析失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewDecompressTaskFromModel(&model.Task{Props: ""})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(job)
- }
-}
diff --git a/pkg/task/import.go b/pkg/task/import.go
index 607b4d1..2b5134b 100644
--- a/pkg/task/import.go
+++ b/pkg/task/import.go
@@ -86,7 +86,6 @@ func (job *ImportTask) Do() {
}
// 创建文件系统
- job.User.Policy = policy
fs, err := filesystem.NewFileSystem(job.User)
if err != nil {
job.SetErrorMsg(err.Error(), nil)
diff --git a/pkg/task/import_test.go b/pkg/task/import_test.go
deleted file mode 100644
index a17123d..0000000
--- a/pkg/task/import_test.go
+++ /dev/null
@@ -1,246 +0,0 @@
-package task
-
-import (
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestImportTask_Props(t *testing.T) {
- asserts := assert.New(t)
- task := &ImportTask{
- User: &model.User{},
- }
- asserts.NotEmpty(task.Props())
- asserts.Equal(ImportTaskType, task.Type())
- asserts.EqualValues(0, task.Creator())
- asserts.Nil(task.Model())
-}
-
-func TestImportTask_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- task := &ImportTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.SetStatus(3)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestImportTask_SetError(t *testing.T) {
- asserts := assert.New(t)
- task := &ImportTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- task.SetErrorMsg("error", nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("error", task.GetError().Msg)
-}
-
-func TestImportTask_Do(t *testing.T) {
- asserts := assert.New(t)
- task := &ImportTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- TaskProps: ImportProps{
- PolicyID: 63,
- Src: "",
- Recursive: false,
- Dst: "",
- },
- }
-
- // 存储策略不存在
- {
- cache.Deletes([]string{"63"}, "policy_")
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 设定失败状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.Err.Error)
- task.Err = nil
- }
-
- // 无法分配 Filesystem
- {
- cache.Deletes([]string{"63"}, "policy_")
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "unknown"))
- // 设定失败状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.Err.Msg)
- task.Err = nil
- }
-
- // 成功列取,但是文件为空
- {
- cache.Deletes([]string{"63"}, "policy_")
- task.TaskProps.Src = "TestImportTask_Do/empty"
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local"))
- // 设定listing状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- // 设定inserting状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(task.Err)
- task.Err = nil
- }
-
- // 创建测试文件
- f, _ := util.CreatNestedFile(util.RelativePath("tests/TestImportTask_Do/test.txt"))
- f.Close()
-
- // 成功列取,包含一个文件一个目录,父目录创建失败
- {
- cache.Deletes([]string{"63"}, "policy_")
- task.TaskProps.Src = "tests"
- task.TaskProps.Dst = "/"
- task.TaskProps.Recursive = true
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local"))
- // 设定listing状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- // 设定inserting状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- // 查找父目录,但是不存在
- mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 仍然不存在
- mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 创建文件时查找父目录,仍然不存在
- mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}))
-
- task.Do()
-
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(task.Err)
- task.Err = nil
- }
-
- // 成功列取,包含一个文件一个目录, 全部操作成功
- {
- cache.Deletes([]string{"63"}, "policy_")
- task.TaskProps.Src = "tests"
- task.TaskProps.Dst = "/"
- task.TaskProps.Recursive = true
- mock.ExpectQuery("SELECT(.+)policies(.+)").
- WillReturnRows(sqlmock.NewRows([]string{"id", "type"}).AddRow(63, "local"))
- // 设定listing状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- // 设定inserting状态
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- // 查找父目录,存在
- mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- // 查找同名文件,不存在
- mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- // 创建目录
- mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id"}))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)folders(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
- // 插入文件记录
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)files(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectExec("UPDATE(.+)users(.+)storage(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
- mock.ExpectCommit()
-
- task.Do()
-
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(task.Err)
- task.Err = nil
- }
-}
-
-func TestNewImportTask(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- job, err := NewImportTask(1, 1, "/", "/", false)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(job)
- asserts.NoError(err)
- }
-
- // 失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- job, err := NewImportTask(1, 1, "/", "/", false)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
-}
-
-func TestNewImportTaskFromModel(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewImportTaskFromModel(&model.Task{Props: "{}"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotNil(job)
- }
-
- // JSON解析失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewImportTaskFromModel(&model.Task{Props: "?"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(job)
- }
-}
diff --git a/pkg/task/job.go b/pkg/task/job.go
index d480492..ad77c6b 100644
--- a/pkg/task/job.go
+++ b/pkg/task/job.go
@@ -15,6 +15,8 @@ const (
TransferTaskType
// ImportTaskType 导入任务
ImportTaskType
+ // RelocateTaskType 存储策略迁移任务
+ RelocateTaskType
// RecycleTaskType 回收任务
RecycleTaskType
)
@@ -115,6 +117,8 @@ func GetJobFromModel(task *model.Task) (Job, error) {
return NewTransferTaskFromModel(task)
case ImportTaskType:
return NewImportTaskFromModel(task)
+ case RelocateTaskType:
+ return NewRelocateTaskFromModel(task)
case RecycleTaskType:
return NewRecycleTaskFromModel(task)
default:
diff --git a/pkg/task/job_test.go b/pkg/task/job_test.go
deleted file mode 100644
index 737f5b7..0000000
--- a/pkg/task/job_test.go
+++ /dev/null
@@ -1,118 +0,0 @@
-package task
-
-import (
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
-)
-
-func TestRecord(t *testing.T) {
- asserts := assert.New(t)
- job := &TransferTask{
- User: &model.User{Policy: model.Policy{Type: "unknown"}},
- }
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- _, err := Record(job)
- asserts.NoError(err)
-}
-
-type taskPoolMock struct {
- testMock.Mock
-}
-
-func (t taskPoolMock) Add(num int) {
- t.Called(num)
-}
-
-func (t taskPoolMock) Submit(job Job) {
- t.Called(job)
-}
-
-func TestResume(t *testing.T) {
- asserts := assert.New(t)
- mockPool := taskPoolMock{}
-
- // 没有任务
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}))
- Resume(mockPool)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 有任务, 类型未知
- {
- mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(233))
- Resume(mockPool)
- asserts.NoError(mock.ExpectationsWereMet())
- }
-
- // 有任务
- {
- mockPool.On("Submit", testMock.Anything)
- mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type", "props"}).AddRow(CompressTaskType, "{}"))
- mock.ExpectQuery("SELECT(.+)users").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectQuery("SELECT(.+)policies").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- Resume(mockPool)
- asserts.NoError(mock.ExpectationsWereMet())
- mockPool.AssertExpectations(t)
- }
-}
-
-func TestGetJobFromModel(t *testing.T) {
- asserts := assert.New(t)
-
- // CompressTaskType
- {
- task := &model.Task{
- Status: 0,
- Type: CompressTaskType,
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
- job, err := GetJobFromModel(task)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
- // DecompressTaskType
- {
- task := &model.Task{
- Status: 0,
- Type: DecompressTaskType,
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
- job, err := GetJobFromModel(task)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
- // TransferTaskType
- {
- task := &model.Task{
- Status: 0,
- Type: TransferTaskType,
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
- job, err := GetJobFromModel(task)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
- // RecycleTaskType
- {
- task := &model.Task{
- Status: 0,
- Type: RecycleTaskType,
- }
- mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error"))
- job, err := GetJobFromModel(task)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
-}
diff --git a/pkg/task/pool_test.go b/pkg/task/pool_test.go
deleted file mode 100644
index fbe4134..0000000
--- a/pkg/task/pool_test.go
+++ /dev/null
@@ -1,52 +0,0 @@
-package task
-
-import (
- "database/sql"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestInit(t *testing.T) {
- asserts := assert.New(t)
- cache.Set("setting_max_worker_num", "10", 0)
- mock.ExpectQuery("SELECT(.+)").WithArgs(Queued, Processing).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(-1))
- Init()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Len(TaskPoll.(*AsyncPool).idleWorker, 10)
-}
-
-func TestPool_Submit(t *testing.T) {
- asserts := assert.New(t)
- pool := &AsyncPool{
- idleWorker: make(chan int, 1),
- }
- pool.Add(1)
- job := &MockJob{
- DoFunc: func() {
-
- },
- }
- asserts.NotPanics(func() {
- pool.Submit(job)
- })
-}
diff --git a/pkg/task/recycle_test.go b/pkg/task/recycle_test.go
deleted file mode 100644
index 0092a30..0000000
--- a/pkg/task/recycle_test.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package task
-
-import (
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestRecycleTask_Props(t *testing.T) {
- asserts := assert.New(t)
- task := &RecycleTask{
- User: &model.User{},
- }
- asserts.NotEmpty(task.Props())
- asserts.Equal(RecycleTaskType, task.Type())
- asserts.EqualValues(0, task.Creator())
- asserts.Nil(task.Model())
-}
-
-func TestRecycleTask_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- task := &RecycleTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.SetStatus(3)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestRecycleTask_SetError(t *testing.T) {
- asserts := assert.New(t)
- task := &RecycleTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- task.SetErrorMsg("error", nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("error", task.GetError().Msg)
-}
-
-func TestNewRecycleTask(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- job, err := NewRecycleTask(&model.Download{
- Model: gorm.Model{ID: 1},
- GID: "test_g_id",
- Parent: "/",
- UserID: 1,
- NodeID: 1,
- })
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(job)
- asserts.NoError(err)
- }
-
- // 失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- job, err := NewRecycleTask(&model.Download{
- Model: gorm.Model{ID: 1},
- GID: "test_g_id",
- Parent: "test/not_exist",
- UserID: 1,
- NodeID: 1,
- })
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
-}
-
-func TestNewRecycleTaskFromModel(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewRecycleTaskFromModel(&model.Task{Props: "{}"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotNil(job)
- }
-
- // JSON解析失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewRecycleTaskFromModel(&model.Task{Props: "?"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(job)
- }
-}
diff --git a/pkg/task/relocate.go b/pkg/task/relocate.go
new file mode 100755
index 0000000..65666ed
--- /dev/null
+++ b/pkg/task/relocate.go
@@ -0,0 +1,176 @@
+package task
+
+import (
+ "context"
+ "encoding/json"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+)
+
+// RelocateTask 存储策略迁移任务
+type RelocateTask struct {
+ User *model.User
+ TaskModel *model.Task
+ TaskProps RelocateProps
+ Err *JobError
+}
+
+// RelocateProps 存储策略迁移任务属性
+type RelocateProps struct {
+ Dirs []uint `json:"dirs"`
+ Files []uint `json:"files"`
+ DstPolicyID uint `json:"dst_policy_id"`
+}
+
+// Props 获取任务属性
+func (job *RelocateTask) Props() string {
+ res, _ := json.Marshal(job.TaskProps)
+ return string(res)
+}
+
+// Type 获取任务状态
+func (job *RelocateTask) Type() int {
+ return RelocateTaskType
+}
+
+// Creator 获取创建者ID
+func (job *RelocateTask) Creator() uint {
+ return job.User.ID
+}
+
+// Model 获取任务的数据库模型
+func (job *RelocateTask) Model() *model.Task {
+ return job.TaskModel
+}
+
+// SetStatus 设定状态
+func (job *RelocateTask) SetStatus(status int) {
+ job.TaskModel.SetStatus(status)
+}
+
+// SetError 设定任务失败信息
+func (job *RelocateTask) SetError(err *JobError) {
+ job.Err = err
+ res, _ := json.Marshal(job.Err)
+ job.TaskModel.SetError(string(res))
+}
+
+// SetErrorMsg 设定任务失败信息
+func (job *RelocateTask) SetErrorMsg(msg string) {
+ job.SetError(&JobError{Msg: msg})
+}
+
+// GetError 返回任务失败信息
+func (job *RelocateTask) GetError() *JobError {
+ return job.Err
+}
+
+// Do 开始执行任务
+func (job *RelocateTask) Do() {
+ // 创建文件系统
+ fs, err := filesystem.NewFileSystem(job.User)
+ if err != nil {
+ job.SetErrorMsg(err.Error())
+ return
+ }
+
+ job.TaskModel.SetProgress(ListingProgress)
+ util.Log().Debug("Start migration task.")
+
+ // ----------------------------
+ // 索引出所有待迁移的文件
+ // ----------------------------
+ targetFiles := make([]model.File, 0, len(job.TaskProps.Files))
+
+ // 索引用户选择的单独的文件
+ outerFiles, err := model.GetFilesByIDs(job.TaskProps.Files, job.User.ID)
+ if err != nil {
+ job.SetError(&JobError{
+ Msg: "Failed to index files.",
+ Error: err.Error(),
+ })
+ return
+ }
+ targetFiles = append(targetFiles, outerFiles...)
+
+ // 索引用户选择目录下的所有递归子文件
+ subFolders, err := model.GetRecursiveChildFolder(job.TaskProps.Dirs, job.User.ID, true)
+ if err != nil {
+ job.SetError(&JobError{
+ Msg: "Failed to index child folders.",
+ Error: err.Error(),
+ })
+ return
+ }
+
+ subFiles, err := model.GetChildFilesOfFolders(&subFolders)
+ if err != nil {
+ job.SetError(&JobError{
+ Msg: "Failed to index child files.",
+ Error: err.Error(),
+ })
+ return
+ }
+ targetFiles = append(targetFiles, subFiles...)
+
+ // 查找目标存储策略
+ policy, err := model.GetPolicyByID(job.TaskProps.DstPolicyID)
+ if err != nil {
+ job.SetError(&JobError{
+ Msg: "Invalid policy.",
+ Error: err.Error(),
+ })
+ return
+ }
+
+ // 开始转移文件
+ job.TaskModel.SetProgress(TransferringProgress)
+ ctx := context.Background()
+ err = fs.Relocate(ctx, targetFiles, &policy)
+ if err != nil {
+ job.SetErrorMsg(err.Error())
+ return
+ }
+
+ return
+}
+
+// NewRelocateTask 新建转移任务
+func NewRelocateTask(user *model.User, dstPolicyID uint, dirs, files []uint) (Job, error) {
+ newTask := &RelocateTask{
+ User: user,
+ TaskProps: RelocateProps{
+ Dirs: dirs,
+ Files: files,
+ DstPolicyID: dstPolicyID,
+ },
+ }
+
+ record, err := Record(newTask)
+ if err != nil {
+ return nil, err
+ }
+ newTask.TaskModel = record
+
+ return newTask, nil
+}
+
+// NewCompressTaskFromModel 从数据库记录中恢复压缩任务
+func NewRelocateTaskFromModel(task *model.Task) (Job, error) {
+ user, err := model.GetActiveUserByID(task.UserID)
+ if err != nil {
+ return nil, err
+ }
+ newTask := &RelocateTask{
+ User: &user,
+ TaskModel: task,
+ }
+
+ err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps)
+ if err != nil {
+ return nil, err
+ }
+
+ return newTask, nil
+}
diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go
index 54bba47..135c809 100644
--- a/pkg/task/tranfer.go
+++ b/pkg/task/tranfer.go
@@ -93,6 +93,7 @@ func (job *TransferTask) Do() {
job.SetErrorMsg(err.Error(), nil)
return
}
+ defer fs.Recycle()
successCount := 0
errorList := make([]string, 0, len(job.TaskProps.Src))
@@ -115,6 +116,7 @@ func (job *TransferTask) Do() {
}
// 切换为从机节点处理上传
+ fs.SetPolicyFromPath(path.Dir(dst))
fs.SwitchToSlaveHandler(node)
err = fs.UploadFromStream(context.Background(), &fsctx.FileStream{
File: nil,
diff --git a/pkg/task/transfer_test.go b/pkg/task/transfer_test.go
deleted file mode 100644
index 612a453..0000000
--- a/pkg/task/transfer_test.go
+++ /dev/null
@@ -1,170 +0,0 @@
-package task
-
-import (
- "errors"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestTransferTask_Props(t *testing.T) {
- asserts := assert.New(t)
- task := &TransferTask{
- User: &model.User{},
- }
- asserts.NotEmpty(task.Props())
- asserts.Equal(TransferTaskType, task.Type())
- asserts.EqualValues(0, task.Creator())
- asserts.Nil(task.Model())
-}
-
-func TestTransferTask_SetStatus(t *testing.T) {
- asserts := assert.New(t)
- task := &TransferTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- task.SetStatus(3)
- asserts.NoError(mock.ExpectationsWereMet())
-}
-
-func TestTransferTask_SetError(t *testing.T) {
- asserts := assert.New(t)
- task := &TransferTask{
- User: &model.User{},
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
-
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
-
- task.SetErrorMsg("error", nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Equal("error", task.GetError().Msg)
-}
-
-func TestTransferTask_Do(t *testing.T) {
- asserts := assert.New(t)
- task := &TransferTask{
- TaskModel: &model.Task{
- Model: gorm.Model{ID: 1},
- },
- }
-
- // 无法创建文件系统
- {
- task.TaskProps.Parent = "test/not_exist"
- task.User = &model.User{
- Policy: model.Policy{
- Type: "unknown",
- },
- }
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- }
-
- // 上传出错
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "mock",
- },
- }
- task.TaskProps.Src = []string{"test/not_exist"}
- task.TaskProps.Parent = "test/not_exist"
- // 更新错误
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- }
-
- // 替换目录前缀
- {
- task.User = &model.User{
- Policy: model.Policy{
- Type: "mock",
- },
- }
- task.TaskProps.Src = []string{"test/not_exist"}
- task.TaskProps.Parent = "test/not_exist"
- task.TaskProps.TrimPath = true
- // 更新错误
- mock.ExpectBegin()
- mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1,
- 1))
- mock.ExpectCommit()
- task.Do()
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotEmpty(task.GetError().Msg)
- }
-}
-
-func TestNewTransferTask(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
- mock.ExpectCommit()
- job, err := NewTransferTask(1, []string{}, "/", "/", false, 0, nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NotNil(job)
- asserts.NoError(err)
- }
-
- // 失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- mock.ExpectBegin()
- mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error"))
- mock.ExpectRollback()
- job, err := NewTransferTask(1, []string{}, "/", "/", false, 0, nil)
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Nil(job)
- asserts.Error(err)
- }
-}
-
-func TestNewTransferTaskFromModel(t *testing.T) {
- asserts := assert.New(t)
-
- // 成功
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewTransferTaskFromModel(&model.Task{Props: "{}"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.NoError(err)
- asserts.NotNil(job)
- }
-
- // JSON解析失败
- {
- mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
- job, err := NewTransferTaskFromModel(&model.Task{Props: "?"})
- asserts.NoError(mock.ExpectationsWereMet())
- asserts.Error(err)
- asserts.Nil(job)
- }
-}
diff --git a/pkg/task/worker_test.go b/pkg/task/worker_test.go
deleted file mode 100644
index 64c6551..0000000
--- a/pkg/task/worker_test.go
+++ /dev/null
@@ -1,81 +0,0 @@
-package task
-
-import (
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/stretchr/testify/assert"
-)
-
-type MockJob struct {
- Err *JobError
- Status int
- DoFunc func()
-}
-
-func (job *MockJob) Type() int {
- panic("implement me")
-}
-
-func (job *MockJob) Creator() uint {
- panic("implement me")
-}
-
-func (job *MockJob) Props() string {
- panic("implement me")
-}
-
-func (job *MockJob) Model() *model.Task {
- panic("implement me")
-}
-
-func (job *MockJob) SetStatus(status int) {
- job.Status = status
-}
-
-func (job *MockJob) Do() {
- job.DoFunc()
-}
-
-func (job *MockJob) SetError(*JobError) {
-}
-
-func (job *MockJob) GetError() *JobError {
- return job.Err
-}
-
-func TestGeneralWorker_Do(t *testing.T) {
- asserts := assert.New(t)
- worker := &GeneralWorker{}
- job := &MockJob{}
-
- // 正常
- {
- job.DoFunc = func() {
- }
- worker.Do(job)
- asserts.Equal(Complete, job.Status)
- }
-
- // 有错误
- {
- job.DoFunc = func() {
- }
- job.Status = Queued
- job.Err = &JobError{Msg: "error"}
- worker.Do(job)
- asserts.Equal(Error, job.Status)
- }
-
- // 有致命错误
- {
- job.DoFunc = func() {
- panic("mock fatal error")
- }
- job.Status = Queued
- job.Err = nil
- worker.Do(job)
- asserts.Equal(Error, job.Status)
- }
-
-}
diff --git a/pkg/thumb/builtin.go b/pkg/thumb/builtin.go
index 206d046..fadac6e 100644
--- a/pkg/thumb/builtin.go
+++ b/pkg/thumb/builtin.go
@@ -14,6 +14,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
+
//"github.com/nfnt/resize"
"golang.org/x/image/draw"
)
diff --git a/pkg/util/common_test.go b/pkg/util/common_test.go
deleted file mode 100644
index ae4c47f..0000000
--- a/pkg/util/common_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package util
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestRandStringRunes(t *testing.T) {
- asserts := assert.New(t)
-
- // 0 长度字符
- randStr := RandStringRunes(0)
- asserts.Len(randStr, 0)
-
- // 16 长度字符
- randStr = RandStringRunes(16)
- asserts.Len(randStr, 16)
-
- // 32 长度字符
- randStr = RandStringRunes(32)
- asserts.Len(randStr, 32)
-
- //相同长度字符
- sameLenStr1 := RandStringRunes(32)
- sameLenStr2 := RandStringRunes(32)
- asserts.NotEqual(sameLenStr1, sameLenStr2)
-}
-
-func TestContainsUint(t *testing.T) {
- asserts := assert.New(t)
- asserts.True(ContainsUint([]uint{0, 2, 3, 65, 4}, 65))
- asserts.True(ContainsUint([]uint{65}, 65))
- asserts.False(ContainsUint([]uint{65}, 6))
-}
-
-func TestContainsString(t *testing.T) {
- asserts := assert.New(t)
- asserts.True(ContainsString([]string{"", "1"}, ""))
- asserts.True(ContainsString([]string{"", "1"}, "1"))
- asserts.False(ContainsString([]string{"", "1"}, " "))
-}
-
-func TestReplace(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.Equal("origin", Replace(map[string]string{
- "123": "321",
- }, "origin"))
-
- asserts.Equal("321origin321", Replace(map[string]string{
- "123": "321",
- }, "123origin123"))
- asserts.Equal("321new321", Replace(map[string]string{
- "123": "321",
- "origin": "new",
- }, "123origin123"))
-}
-
-func TestBuildRegexp(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.Equal("^/dir/", BuildRegexp([]string{"/dir"}, "^", "/", "|"))
- asserts.Equal("^/dir/|^/dir/di\\*r/", BuildRegexp([]string{"/dir", "/dir/di*r"}, "^", "/", "|"))
-}
-
-func TestBuildConcat(t *testing.T) {
- asserts := assert.New(t)
- asserts.Equal("CONCAT(1,2)", BuildConcat("1", "2", "mysql"))
- asserts.Equal("1||2", BuildConcat("1", "2", "sqlite"))
-}
-
-func TestSliceDifference(t *testing.T) {
- asserts := assert.New(t)
-
- {
- s1 := []string{"1", "2", "3", "4"}
- s2 := []string{"2", "4"}
- asserts.Equal([]string{"1", "3"}, SliceDifference(s1, s2))
- }
-
- {
- s2 := []string{"1", "2", "3", "4"}
- s1 := []string{"2", "4"}
- asserts.Equal([]string{}, SliceDifference(s1, s2))
- }
-
- {
- s1 := []string{"1", "2", "3", "4"}
- s2 := []string{"1", "2", "3", "4"}
- asserts.Equal([]string{}, SliceDifference(s1, s2))
- }
-
- {
- s1 := []string{"1", "2", "3", "4"}
- s2 := []string{}
- asserts.Equal([]string{"1", "2", "3", "4"}, SliceDifference(s1, s2))
- }
-}
diff --git a/pkg/util/io_test.go b/pkg/util/io_test.go
deleted file mode 100644
index 755d203..0000000
--- a/pkg/util/io_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package util
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestExists(t *testing.T) {
- asserts := assert.New(t)
- asserts.True(Exists("io_test.go"))
- asserts.False(Exists("io_test.js"))
-}
-
-func TestCreatNestedFile(t *testing.T) {
- asserts := assert.New(t)
-
- // 父目录不存在
- {
- file, err := CreatNestedFile("test/nest.txt")
- asserts.NoError(err)
- asserts.NoError(file.Close())
- asserts.FileExists("test/nest.txt")
- }
-
- // 父目录存在
- {
- file, err := CreatNestedFile("test/direct.txt")
- asserts.NoError(err)
- asserts.NoError(file.Close())
- asserts.FileExists("test/direct.txt")
- }
-}
-
-func TestIsEmpty(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.False(IsEmpty(""))
- asserts.False(IsEmpty("not_exist"))
-}
diff --git a/pkg/util/logger_test.go b/pkg/util/logger_test.go
deleted file mode 100644
index 5f1352d..0000000
--- a/pkg/util/logger_test.go
+++ /dev/null
@@ -1,103 +0,0 @@
-// +build !race
-
-package util
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestBuildLogger(t *testing.T) {
- asserts := assert.New(t)
- asserts.NotPanics(func() {
- BuildLogger("error")
- })
- asserts.NotPanics(func() {
- BuildLogger("warning")
- })
- asserts.NotPanics(func() {
- BuildLogger("info")
- })
- asserts.NotPanics(func() {
- BuildLogger("?")
- })
- asserts.NotPanics(func() {
- BuildLogger("debug")
- })
-}
-
-func TestLog(t *testing.T) {
- asserts := assert.New(t)
- asserts.NotNil(Log())
- GloablLogger = nil
- asserts.NotNil(Log())
-}
-
-func TestLogger_Debug(t *testing.T) {
- asserts := assert.New(t)
- l := Logger{
- level: LevelDebug,
- }
- asserts.NotPanics(func() {
- l.Debug("123")
- })
- l.level = LevelError
- asserts.NotPanics(func() {
- l.Debug("123")
- })
-}
-
-func TestLogger_Info(t *testing.T) {
- asserts := assert.New(t)
- l := Logger{
- level: LevelDebug,
- }
- asserts.NotPanics(func() {
- l.Info("123")
- })
- l.level = LevelError
- asserts.NotPanics(func() {
- l.Info("123")
- })
-}
-func TestLogger_Warning(t *testing.T) {
- asserts := assert.New(t)
- l := Logger{
- level: LevelDebug,
- }
- asserts.NotPanics(func() {
- l.Warning("123")
- })
- l.level = LevelError
- asserts.NotPanics(func() {
- l.Warning("123")
- })
-}
-
-func TestLogger_Error(t *testing.T) {
- asserts := assert.New(t)
- l := Logger{
- level: LevelDebug,
- }
- asserts.NotPanics(func() {
- l.Error("123")
- })
- l.level = -1
- asserts.NotPanics(func() {
- l.Error("123")
- })
-}
-
-func TestLogger_Panic(t *testing.T) {
- asserts := assert.New(t)
- l := Logger{
- level: LevelDebug,
- }
- asserts.Panics(func() {
- l.Panic("123")
- })
- l.level = -1
- asserts.NotPanics(func() {
- l.Error("123")
- })
-}
diff --git a/pkg/util/path.go b/pkg/util/path.go
index 2dd8aef..ff51d57 100644
--- a/pkg/util/path.go
+++ b/pkg/util/path.go
@@ -56,4 +56,3 @@ func RelativePath(name string) string {
e, _ := os.Executable()
return filepath.Join(filepath.Dir(e), name)
}
-
diff --git a/pkg/util/path_test.go b/pkg/util/path_test.go
deleted file mode 100644
index a417a9c..0000000
--- a/pkg/util/path_test.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package util
-
-import (
- "github.com/stretchr/testify/assert"
- "testing"
-)
-
-func TestDotPathToStandardPath(t *testing.T) {
- asserts := assert.New(t)
-
- asserts.Equal("/", DotPathToStandardPath(""))
- asserts.Equal("/目录", DotPathToStandardPath("目录"))
- asserts.Equal("/目录/目录2", DotPathToStandardPath("目录,目录2"))
-}
-
-func TestFillSlash(t *testing.T) {
- asserts := assert.New(t)
- asserts.Equal("/", FillSlash("/"))
- asserts.Equal("/", FillSlash(""))
- asserts.Equal("/123/", FillSlash("/123"))
-}
-
-func TestRemoveSlash(t *testing.T) {
- asserts := assert.New(t)
- asserts.Equal("/", RemoveSlash("/"))
- asserts.Equal("/123/1236", RemoveSlash("/123/1236"))
- asserts.Equal("/123/1236", RemoveSlash("/123/1236/"))
-}
-
-func TestSplitPath(t *testing.T) {
- asserts := assert.New(t)
- asserts.Equal([]string{}, SplitPath(""))
- asserts.Equal([]string{}, SplitPath("1"))
- asserts.Equal([]string{"/"}, SplitPath("/"))
- asserts.Equal([]string{"/", "123", "321"}, SplitPath("/123/321"))
-}
diff --git a/pkg/util/ztool.go b/pkg/util/ztool.go
new file mode 100644
index 0000000..9733b6f
--- /dev/null
+++ b/pkg/util/ztool.go
@@ -0,0 +1,35 @@
+package util
+
+import "strings"
+
+type (
+ // 可取长度类型
+ LenAble interface{ string | []any | chan any }
+)
+
+// 计算切片元素总长度
+/*
+ 传入字符串切片, 返回其所有元素长度之和
+ e.g. LenArray({`ele3`,`ele2`,`ele1`}) => 12
+*/
+func LenArray[T LenAble](a []T) int {
+ var o int
+ for i, r := 0, len(a); i < r; i++ {
+ o += len(a[i])
+ }
+ return o
+}
+
+// 字符串快速拼接
+/*
+ 传入多个字符串参数, 返回拼接后的结果
+ e.g: StrConcat("str1", "str2", "str3") => "str1str2str3"
+*/
+func StrConcat(a ...string) string {
+ var b strings.Builder
+ b.Grow(LenArray(a))
+ for i, r := 0, len(a); i < r; i++ {
+ b.WriteString(a[i])
+ }
+ return b.String()
+}
diff --git a/pkg/vol/vol.go b/pkg/vol/vol.go
new file mode 100755
index 0000000..58dae92
--- /dev/null
+++ b/pkg/vol/vol.go
@@ -0,0 +1,39 @@
+package vol
+
+import (
+ "fmt"
+ "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ "net/http"
+)
+
+var ClientSecret = ""
+
+const CRMSite = "https://pro.cloudreve.org/crm/api/vol/"
+
+type Client interface {
+ // Sync VOL from CRM, return content (base64 encoded) and signature.
+ Sync() (string, string, error)
+}
+
+type VolClient struct {
+ secret string
+ client request.Client
+}
+
+func New(secret string) Client {
+ return &VolClient{secret: secret, client: request.NewClient()}
+}
+
+func (c *VolClient) Sync() (string, string, error) {
+ res, err := c.client.Request("GET", CRMSite+c.secret, nil).CheckHTTPResponse(http.StatusOK).DecodeResponse()
+ if err != nil {
+ return "", "", fmt.Errorf("failed to get VOL from CRM: %w", err)
+ }
+
+ if res.Code != 0 {
+ return "", "", fmt.Errorf("CRM return error: %s", res.Msg)
+ }
+
+ vol := res.Data.(map[string]interface{})
+ return vol["content"].(string), vol["signature"].(string), nil
+}
diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go
index 5ab9906..78cc9f4 100644
--- a/pkg/webdav/webdav.go
+++ b/pkg/webdav/webdav.go
@@ -414,6 +414,9 @@ func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request, fs *filesyst
ctx = context.WithValue(ctx, fsctx.FileModelCtx, *originFile)
fileData.Mode |= fsctx.Overwrite
} else {
+ // 尝试获取并重设存储策略
+ fs.SetPolicyFromPath(filePath)
+
// 给文件系统分配钩子
fs.Use("BeforeUpload", filesystem.HookValidateFile)
fs.Use("BeforeUpload", filesystem.HookValidateCapacity)
diff --git a/pkg/wopi/discovery_test.go b/pkg/wopi/discovery_test.go
deleted file mode 100644
index 8092384..0000000
--- a/pkg/wopi/discovery_test.go
+++ /dev/null
@@ -1,129 +0,0 @@
-package wopi
-
-import (
- "errors"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "io"
- "net/http"
- "net/url"
- "strings"
- "testing"
-)
-
-func TestClient_AvailableExts(t *testing.T) {
- a := assert.New(t)
- endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery")
- client := &client{
- cache: cache.NewMemoStore(),
- config: config{
- discoveryEndpoint: endpoint,
- },
- }
-
- // Discovery failed
- {
- expectedErr := errors.New("error")
- mockHttp := &requestmock.RequestMock{}
- client.http = mockHttp
- mockHttp.On(
- "Request",
- "GET",
- endpoint.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: expectedErr,
- })
- res := client.AvailableExts()
- a.Empty(res)
- mockHttp.AssertExpectations(t)
- }
-
- // pass
- {
- client.discovery = &WopiDiscovery{}
- client.actions = map[string]map[string]Action{
- ".doc": {
- string(ActionPreviewFallback): Action{},
- },
- ".ppt": {},
- ".xls": {
- "not_supported": Action{},
- },
- }
- res := client.AvailableExts()
- a.Len(res, 1)
- a.Equal("doc", res[0])
- }
-}
-
-func TestClient_RefreshDiscovery(t *testing.T) {
- a := assert.New(t)
- endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery")
- client := &client{
- cache: cache.NewMemoStore(),
- config: config{
- discoveryEndpoint: endpoint,
- },
- }
-
- // cache hit
- {
- client.cache.Set(DiscoverResponseCacheKey, WopiDiscovery{Text: "test"}, 0)
- a.NoError(client.checkDiscovery())
- a.Equal("test", client.discovery.Text)
- client.discovery = &WopiDiscovery{}
- client.cache.Delete([]string{DiscoverResponseCacheKey}, "")
- }
-
- // malformed xml
- {
- mockHttp := &requestmock.RequestMock{}
- client.http = mockHttp
- mockHttp.On(
- "Request",
- "GET",
- endpoint.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: io.NopCloser(strings.NewReader(`{"code":203}`)),
- },
- })
- res := client.refreshDiscovery()
- a.ErrorContains(res, "failed to parse")
- mockHttp.AssertExpectations(t)
- }
-
- // all pass
- {
- testResponse := `
-`
- mockHttp := &requestmock.RequestMock{}
- client.http = mockHttp
- mockHttp.On(
- "Request",
- "GET",
- endpoint.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Response: &http.Response{
- StatusCode: 200,
- Body: io.NopCloser(strings.NewReader(testResponse)),
- },
- })
- res := client.refreshDiscovery()
- a.NoError(res, res)
- a.NotEmpty(client.actions[".docx"])
- a.NotEmpty(client.actions[".docx"][string(ActionPreview)])
- a.NotEmpty(client.actions[".docx"][string(ActionEdit)])
- mockHttp.AssertExpectations(t)
- }
-}
diff --git a/pkg/wopi/wopi.go b/pkg/wopi/wopi.go
index 2938de0..7a7b296 100644
--- a/pkg/wopi/wopi.go
+++ b/pkg/wopi/wopi.go
@@ -3,17 +3,18 @@ package wopi
import (
"errors"
"fmt"
+ "net/url"
+ "path"
+ "strings"
+ "sync"
+ "time"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
- "net/url"
- "path"
- "strings"
- "sync"
- "time"
)
type Client interface {
diff --git a/pkg/wopi/wopi_test.go b/pkg/wopi/wopi_test.go
deleted file mode 100644
index 78c4bcc..0000000
--- a/pkg/wopi/wopi_test.go
+++ /dev/null
@@ -1,184 +0,0 @@
-package wopi
-
-import (
- "database/sql"
- "errors"
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/cloudreve/Cloudreve/v3/pkg/cache"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/cachemock"
- "github.com/cloudreve/Cloudreve/v3/pkg/mocks/requestmock"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
- testMock "github.com/stretchr/testify/mock"
- "net/url"
- "testing"
-)
-
-var mock sqlmock.Sqlmock
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
- model.DB, _ = gorm.Open("mysql", db)
- defer db.Close()
- m.Run()
-}
-
-func TestNewSession(t *testing.T) {
- a := assert.New(t)
- endpoint, _ := url.Parse("http://localhost:8001/hosting/discovery")
- client := &client{
- cache: cache.NewMemoStore(),
- config: config{
- discoveryEndpoint: endpoint,
- },
- }
-
- // Discovery failed
- {
- expectedErr := errors.New("error")
- mockHttp := &requestmock.RequestMock{}
- client.http = mockHttp
- mockHttp.On(
- "Request",
- "GET",
- endpoint.String(),
- testMock.Anything,
- testMock.Anything,
- ).Return(&request.Response{
- Err: expectedErr,
- })
- res, err := client.NewSession(0, &model.File{}, ActionPreview)
- a.Nil(res)
- a.ErrorIs(err, expectedErr)
- mockHttp.AssertExpectations(t)
- }
-
- // not supported ext
- {
- client.discovery = &WopiDiscovery{}
- client.actions = make(map[string]map[string]Action)
- res, err := client.NewSession(0, &model.File{}, ActionPreview)
- a.Nil(res)
- a.ErrorIs(err, ErrActionNotSupported)
- }
-
- // preferred action not supported
- {
- client.discovery = &WopiDiscovery{}
- client.actions = map[string]map[string]Action{
- ".doc": {},
- }
- res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionPreview)
- a.Nil(res)
- a.ErrorIs(err, ErrActionNotSupported)
- }
-
- // src url cannot be parsed
- {
- client.discovery = &WopiDiscovery{}
- client.actions = map[string]map[string]Action{
- ".doc": {
- string(ActionPreviewFallback): Action{
- Urlsrc: string([]byte{0x7f}),
- },
- },
- }
- res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit)
- a.Nil(res)
- a.ErrorContains(err, "invalid control character in URL")
- }
-
- // all pass - default placeholder
- {
- client.discovery = &WopiDiscovery{}
- client.actions = map[string]map[string]Action{
- ".doc": {
- string(ActionPreviewFallback): Action{
- Urlsrc: "https://doc.com/doc",
- },
- },
- }
- res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit)
- a.NotNil(res)
- a.NoError(err)
- resUrl := res.ActionURL.String()
- a.Contains(resUrl, wopiSrcParamDefault)
- }
-
- // all pass - with placeholders
- {
- client.discovery = &WopiDiscovery{}
- client.actions = map[string]map[string]Action{
- ".doc": {
- string(ActionPreviewFallback): Action{
- Urlsrc: "https://doc.com/doc?origin=preserved&",
- },
- },
- }
- res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit)
- a.NotNil(res)
- a.NoError(err)
- resUrl := res.ActionURL.String()
- a.Contains(resUrl, "origin=preserved")
- a.Contains(resUrl, "dc=lng")
- a.Contains(resUrl, "src=")
- a.NotContains(resUrl, "notsuported")
- }
-
- // cache operation failed
- {
- mockCache := &cachemock.CacheClientMock{}
- expectedErr := errors.New("error")
- client.cache = mockCache
- client.discovery = &WopiDiscovery{}
- client.actions = map[string]map[string]Action{
- ".doc": {
- string(ActionPreviewFallback): Action{
- Urlsrc: "https://doc.com/doc",
- },
- },
- }
- mockCache.On("Set", testMock.Anything, testMock.Anything, testMock.Anything).Return(expectedErr)
- res, err := client.NewSession(0, &model.File{Name: "1.doc"}, ActionEdit)
- a.Nil(res)
- a.ErrorIs(err, expectedErr)
- }
-}
-
-func TestInit(t *testing.T) {
- a := assert.New(t)
-
- // not enabled
- {
- a.Nil(Default)
- Default = &client{}
- Init()
- a.Nil(Default)
- }
-
- // throw error
- {
- a.Nil(Default)
- cache.Set("setting_wopi_enabled", "1", 0)
- cache.Set("setting_wopi_endpoint", string([]byte{0x7f}), 0)
- Init()
- a.Nil(Default)
- }
-
- // all pass
- {
- a.Nil(Default)
- cache.Set("setting_wopi_enabled", "1", 0)
- cache.Set("setting_wopi_endpoint", "", 0)
- Init()
- a.NotNil(Default)
- }
-}
diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go
index 26d917e..e9f4dac 100644
--- a/routers/controllers/admin.go
+++ b/routers/controllers/admin.go
@@ -1,14 +1,15 @@
package controllers
import (
- "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
- "github.com/cloudreve/Cloudreve/v3/pkg/mq"
- "io"
+ // "io"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
- "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ "github.com/cloudreve/Cloudreve/v3/pkg/mq"
+
+ // "github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/cloudreve/Cloudreve/v3/service/admin"
@@ -27,17 +28,17 @@ func AdminSummary(c *gin.Context) {
}
// AdminNews 获取社区新闻
-func AdminNews(c *gin.Context) {
- tag := "announcements"
- if c.Query("tag") != "" {
- tag = c.Query("tag")
- }
- r := request.NewClient()
- res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3A"+tag+"&sort=-startTime&page%5Blimit%5D=10", nil)
- if res.Err == nil {
- io.Copy(c.Writer, res.Response.Body)
- }
-}
+// func AdminNews(c *gin.Context) {
+// tag := "announcements"
+// if c.Query("tag") != "" {
+// tag = c.Query("tag")
+// }
+// r := request.NewClient()
+// res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3A"+tag+"&sort=-startTime&page%5Blimit%5D=10", nil)
+// if res.Err == nil {
+// io.Copy(c.Writer, res.Response.Body)
+// }
+// }
// AdminChangeSetting 获取站点设定项
func AdminChangeSetting(c *gin.Context) {
@@ -109,6 +110,39 @@ func AdminTestThumbGenerator(c *gin.Context) {
}
}
+// AdminListRedeems 列出激活码
+func AdminListRedeems(c *gin.Context) {
+ var service admin.AdminListService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Redeems()
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// AdminGenerateRedeems 生成激活码
+func AdminGenerateRedeems(c *gin.Context) {
+ var service admin.GenerateRedeemsService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Generate()
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// AdminDeleteRedeem 删除激活码
+func AdminDeleteRedeem(c *gin.Context) {
+ var service admin.SingleIDService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.DeleteRedeem()
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// AdminTestAria2 测试aria2连接
func AdminTestAria2(c *gin.Context) {
var service admin.Aria2TestService
@@ -389,6 +423,28 @@ func AdminDeleteShare(c *gin.Context) {
}
}
+// AdminListOrder 列出订单
+func AdminListOrder(c *gin.Context) {
+ var service admin.AdminListService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Orders()
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// AdminDeleteOrder 批量删除订单
+func AdminDeleteOrder(c *gin.Context) {
+ var service admin.OrderBatchService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Delete(c)
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// AdminListDownload 列出离线下载任务
func AdminListDownload(c *gin.Context) {
var service admin.AdminListService
@@ -455,6 +511,28 @@ func AdminListFolders(c *gin.Context) {
}
}
+// AdminListReport 列出未处理举报
+func AdminListReport(c *gin.Context) {
+ var service admin.AdminListService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Reports()
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// AdminDeleteTask 批量删除举报
+func AdminDeleteReport(c *gin.Context) {
+ var service admin.ReportBatchService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Delete()
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// AdminListNodes 列出从机节点
func AdminListNodes(c *gin.Context) {
var service admin.AdminListService
@@ -509,3 +587,10 @@ func AdminGetNode(c *gin.Context) {
c.JSON(200, ErrorResponse(err))
}
}
+
+// AdminSyncVol 同步VOL授权
+func AdminSyncVol(c *gin.Context) {
+ var service admin.VolService
+ res := service.Sync()
+ c.JSON(200, res)
+}
diff --git a/routers/controllers/file.go b/routers/controllers/file.go
index 0e7c206..e32061b 100644
--- a/routers/controllers/file.go
+++ b/routers/controllers/file.go
@@ -3,10 +3,10 @@ package controllers
import (
"context"
"fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"net/http"
model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/service/explorer"
@@ -51,6 +51,17 @@ func Compress(c *gin.Context) {
}
}
+// Relocate 创建文件转移任务
+func Relocate(c *gin.Context) {
+ var service explorer.ItemRelocateService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.CreateRelocateTask(c)
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// Decompress 创建文件解压缩任务
func Decompress(c *gin.Context) {
var service explorer.ItemDecompressService
@@ -135,6 +146,7 @@ func AnonymousPermLink(c *gin.Context) {
}
+// GetSource 获取文件的外链地址
func GetSource(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
@@ -304,52 +316,6 @@ func FileUpload(c *gin.Context) {
} else {
c.JSON(200, ErrorResponse(err))
}
-
- //fileData := fsctx.FileStream{
- // MIMEType: c.Request.Header.Get("Content-Type"),
- // File: c.Request.Body,
- // Size: fileSize,
- // Name: fileName,
- // VirtualPath: filePath,
- // Mode: fsctx.Create,
- //}
- //
- //// 创建文件系统
- //fs, err := filesystem.NewFileSystemFromContext(c)
- //if err != nil {
- // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err))
- // return
- //}
- //
- //// 非可用策略时拒绝上传
- //if !fs.Policy.IsTransitUpload(fileSize) {
- // request.BlackHole(c.Request.Body)
- // c.JSON(200, serializer.Err(serializer.CodePolicyNotAllowed, "当前存储策略无法使用", nil))
- // return
- //}
- //
- //// 给文件系统分配钩子
- //fs.Use("BeforeUpload", filesystem.HookValidateFile)
- //fs.Use("BeforeUpload", filesystem.HookValidateCapacity)
- //fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile)
- //fs.Use("AfterUploadCanceled", filesystem.HookGiveBackCapacity)
- //fs.Use("AfterUpload", filesystem.GenericAfterUpload)
- //fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
- //fs.Use("AfterValidateFailed", filesystem.HookGiveBackCapacity)
- //fs.Use("AfterUploadFailed", filesystem.HookGiveBackCapacity)
- //
- //// 执行上传
- //ctx = context.WithValue(ctx, fsctx.ValidateCapacityOnceCtx, &sync.Once{})
- //uploadCtx := context.WithValue(ctx, fsctx.GinCtx, c)
- //err = fs.Upload(uploadCtx, &fileData)
- //if err != nil {
- // c.JSON(200, serializer.Err(serializer.CodeUploadFailed, err.Error(), err))
- // return
- //}
- //
- //c.JSON(200, serializer.Response{
- // Code: 0,
- //})
}
// DeleteUploadSession 删除上传会话
diff --git a/routers/controllers/share.go b/routers/controllers/share.go
index 8795c1e..b9bcb0a 100644
--- a/routers/controllers/share.go
+++ b/routers/controllers/share.go
@@ -173,6 +173,17 @@ func GetShareDocPreview(c *gin.Context) {
}
}
+// SaveShare 转存他人分享
+func SaveShare(c *gin.Context) {
+ var service share.Service
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.SaveToMyFile(c)
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// ListSharedFolder 列出分享的目录下的对象
func ListSharedFolder(c *gin.Context) {
var service share.Service
@@ -235,3 +246,14 @@ func GetUserShare(c *gin.Context) {
c.JSON(200, ErrorResponse(err))
}
}
+
+// ReportShare 举报分享
+func ReportShare(c *gin.Context) {
+ var service share.ShareReportService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Report(c)
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
diff --git a/routers/controllers/site.go b/routers/controllers/site.go
index c4a3508..141deed 100644
--- a/routers/controllers/site.go
+++ b/routers/controllers/site.go
@@ -14,13 +14,17 @@ import (
func SiteConfig(c *gin.Context) {
siteConfig := model.GetSettingByNames(
"siteName",
+ "siteNotice",
"login_captcha",
+ "qq_login",
"reg_captcha",
"email_active",
"forget_captcha",
- "email_active",
+ // "email_active",
"themes",
"defaultTheme",
+ "score_enabled",
+ "share_score_rate",
"home_view_method",
"share_view_method",
"authn_enabled",
@@ -28,7 +32,10 @@ func SiteConfig(c *gin.Context) {
"captcha_type",
"captcha_TCaptcha_CaptchaAppId",
"register_enabled",
+ "report_enabled",
"show_app_promotion",
+ "app_forum_link",
+ "app_feedback_link",
)
var wopiExts []string
@@ -49,8 +56,8 @@ func SiteConfig(c *gin.Context) {
// Ping 状态检查页面
func Ping(c *gin.Context) {
version := conf.BackendVersion
- if conf.IsPro == "true" {
- version += "-pro"
+ if conf.IsPlus == "true" {
+ version += "-plus"
}
c.JSON(200, serializer.Response{
@@ -139,3 +146,22 @@ func Manifest(c *gin.Context) {
"background_color": options["pwa_background_color"],
})
}
+
+// GetVolSecret 获取 VOL 密钥
+func GetVolSecret(c *gin.Context) {
+ vol := model.GetSettingByNames("vol_content", "vol_signature")
+ if vol["vol_signature"] == "" {
+ c.JSON(200, serializer.Response{
+ Code: serializer.CodeNotFound,
+ })
+
+ return
+ }
+
+ c.JSON(200, serializer.Response{
+ Data: serializer.VolResponse{
+ Signature: vol["vol_signature"],
+ Content: vol["vol_content"],
+ },
+ })
+}
diff --git a/routers/controllers/user.go b/routers/controllers/user.go
index 5d6301e..648fd1b 100644
--- a/routers/controllers/user.go
+++ b/routers/controllers/user.go
@@ -6,6 +6,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/authn"
+ "github.com/cloudreve/Cloudreve/v3/pkg/qq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
@@ -213,6 +214,24 @@ func UserActivate(c *gin.Context) {
}
}
+// UserQQLogin 初始化QQ登录
+func UserQQLogin(c *gin.Context) {
+ // 新建绑定
+ res, err := qq.NewLoginRequest()
+ if err != nil {
+ c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法使用QQ登录", err))
+ return
+ }
+
+ // 设定QQ登录会话Secret
+ util.SetSession(c, map[string]interface{}{"qq_login_secret": res.SecretKey})
+
+ c.JSON(200, serializer.Response{
+ Data: res.URL,
+ })
+
+}
+
// UserSignOut 用户退出登录
func UserSignOut(c *gin.Context) {
util.DeleteSession(c, "user_id")
@@ -233,6 +252,28 @@ func UserStorage(c *gin.Context) {
c.JSON(200, res)
}
+// UserAvailablePolicies 用户存储策略设置
+func UserAvailablePolicies(c *gin.Context) {
+ var service user.SettingService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Policy(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// UserAvailableNodes 用户可选节点
+func UserAvailableNodes(c *gin.Context) {
+ var service user.SettingService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Nodes(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// UserTasks 获取任务队列
func UserTasks(c *gin.Context) {
var service user.SettingListService
@@ -339,6 +380,12 @@ func UpdateOption(c *gin.Context) {
switch service.Option {
case "nick":
subService = &user.ChangerNick{}
+ case "vip":
+ subService = &user.VIPUnsubscribe{}
+ case "qq":
+ subService = &user.QQBind{}
+ case "policy":
+ subService = &user.PolicyChange{}
case "homepage":
subService = &user.HomePage{}
case "password":
diff --git a/routers/controllers/vas.go b/routers/controllers/vas.go
new file mode 100755
index 0000000..9177941
--- /dev/null
+++ b/routers/controllers/vas.go
@@ -0,0 +1,214 @@
+package controllers
+
+import (
+ "github.com/cloudreve/Cloudreve/v3/pkg/cache"
+ "github.com/cloudreve/Cloudreve/v3/pkg/payment"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/cloudreve/Cloudreve/v3/service/vas"
+ "github.com/gin-gonic/gin"
+ "github.com/iGoogle-ink/gopay"
+ "github.com/iGoogle-ink/gopay/wechat/v3"
+ "github.com/qingwg/payjs/notify"
+ "github.com/smartwalle/alipay/v3"
+ "net/http"
+)
+
+// GetQuota 获取容量配额信息
+func GetQuota(c *gin.Context) {
+ var service vas.GeneralVASService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Quota(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// GetProduct 获取商品信息
+func GetProduct(c *gin.Context) {
+ var service vas.GeneralVASService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Products(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// NewOrder 新建支付订单
+func NewOrder(c *gin.Context) {
+ var service vas.CreateOrderService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Create(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// OrderStatus 查询订单状态
+func OrderStatus(c *gin.Context) {
+ var service vas.OrderService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Status(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// GetRedeemInfo 获取兑换码信息
+func GetRedeemInfo(c *gin.Context) {
+ var service vas.RedeemService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Query(c)
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// DoRedeem 获取兑换码信息
+func DoRedeem(c *gin.Context) {
+ var service vas.RedeemService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Redeem(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// AlipayCallback 支付宝回调
+func AlipayCallback(c *gin.Context) {
+ pay, err := payment.NewPaymentInstance("alipay")
+ if err != nil {
+ util.Log().Debug("[Alipay callback] Failed to create alipay client, %s", err)
+ c.Status(400)
+ return
+ }
+
+ res, err := pay.(*payment.Alipay).Client.GetTradeNotification(c.Request)
+ if err != nil {
+ util.Log().Debug("[Alipay callback] Failed to validate callback request, %s", err)
+ c.Status(403)
+ return
+ }
+
+ if res != nil && res.TradeStatus == "TRADE_SUCCESS" {
+ // 支付成功
+ if err := payment.OrderPaid(res.OutTradeNo); err != nil {
+ util.Log().Debug("[Alipay callback] Failed to process payment, %s", err)
+ }
+ }
+
+ // 确认收到通知消息
+ alipay.AckNotification(c.Writer)
+}
+
+// WechatCallback 微信扫码支付回调
+func WechatCallback(c *gin.Context) {
+ pay, err := payment.NewPaymentInstance("wechat")
+ if err != nil {
+ util.Log().Debug("[Wechat pay callback] Failed to create alipay client, %s", err)
+ c.JSON(500, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: "Failed to create alipay client"})
+ return
+ }
+
+ notifyReq, err := wechat.V3ParseNotify(c.Request)
+ if err != nil {
+ util.Log().Debug("[Wechat pay callback] Failed to parse callback content, %s", err)
+ c.JSON(500, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: "Failed to parse callback content"})
+ return
+ }
+
+ err = notifyReq.VerifySign(pay.(*payment.Wechat).GetPlatformCert())
+ if err != nil {
+ util.Log().Debug("[Wechat pay callback] Failed to verify callback signature, %s", err)
+ c.JSON(403, &wechat.V3NotifyRsp{Code: gopay.FAIL, Message: "Failed to verify callback signature"})
+ return
+ }
+
+ // 解密回调正文
+ result, err := notifyReq.DecryptCipherText(pay.(*payment.Wechat).ApiV3Key)
+ if result != nil && result.TradeState == "SUCCESS" {
+ // 支付成功
+ if err := payment.OrderPaid(result.OutTradeNo); err != nil {
+ util.Log().Debug("[Wechat pay callback] Failed to process payment, %s", err)
+ }
+ }
+
+ // 确认收到通知消息
+ c.JSON(http.StatusOK, &wechat.V3NotifyRsp{Code: gopay.SUCCESS, Message: "Success"})
+}
+
+// PayJSCallback PayJS回调
+func PayJSCallback(c *gin.Context) {
+ pay, err := payment.NewPaymentInstance("payjs")
+ if err != nil {
+ util.Log().Debug("[PayJS callback] Failed to initialize payment client, %s", err)
+ c.Status(400)
+ return
+ }
+
+ payNotify := pay.(*payment.PayJSClient).Client.GetNotify(c.Request, c.Writer)
+
+ //设置接收消息的处理方法
+ payNotify.SetMessageHandler(func(msg notify.Message) {
+ if err := payment.OrderPaid(msg.OutTradeNo); err != nil {
+ util.Log().Debug("[PayJS callback] Failed to process payment, %s", err)
+ }
+ })
+
+ //处理消息接收以及回复
+ err = payNotify.Serve()
+ if err != nil {
+ util.Log().Debug("[PayJS callback] Failed to process payment, %s", err)
+ return
+ }
+
+ //发送回复的消息
+ payNotify.SendResponseMsg()
+
+}
+
+// QQCallback QQ互联回调
+func QQCallback(c *gin.Context) {
+ var service vas.QQCallbackService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Callback(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// CustomCallback PayJS回调
+func CustomCallback(c *gin.Context) {
+ orderNo := c.Param("orderno")
+ sessionID := c.Param("id")
+ sessionRaw, exist := cache.Get(payment.CallbackSessionPrefix + sessionID)
+ if !exist {
+ util.Log().Debug("[Custom callback] Failed to process payment, session not found")
+ c.JSON(200, serializer.Err(serializer.CodeNotFound, "session not found", nil))
+ return
+ }
+
+ expectedID := sessionRaw.(string)
+ if expectedID != orderNo {
+ util.Log().Debug("[Custom callback] Failed to process payment, session mismatch")
+ c.JSON(200, serializer.Err(serializer.CodeInternalSetting, "session mismatch", nil))
+ return
+ }
+
+ cache.Deletes([]string{sessionID}, payment.CallbackSessionPrefix)
+
+ if err := payment.OrderPaid(orderNo); err != nil {
+ c.JSON(200, serializer.Err(serializer.CodeInternalSetting, "failed to fulfill payment", err))
+ util.Log().Debug("[Custom callback] Failed to process payment, %s", err)
+ return
+ }
+
+ c.JSON(200, serializer.Response{})
+}
diff --git a/routers/controllers/webdav.go b/routers/controllers/webdav.go
index 0453ada..e3baf91 100644
--- a/routers/controllers/webdav.go
+++ b/routers/controllers/webdav.go
@@ -2,6 +2,9 @@ package controllers
import (
"context"
+ "net/http"
+ "sync"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
@@ -9,8 +12,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/webdav"
"github.com/cloudreve/Cloudreve/v3/service/setting"
"github.com/gin-gonic/gin"
- "net/http"
- "sync"
)
var handler *webdav.Handler
@@ -92,6 +93,28 @@ func UpdateWebDAVAccounts(c *gin.Context) {
}
}
+// DeleteWebDAVMounts 删除WebDAV挂载
+func DeleteWebDAVMounts(c *gin.Context) {
+ var service setting.WebDAVListService
+ if err := c.ShouldBindUri(&service); err == nil {
+ res := service.Unmount(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
+// UpdateWebDAVAccountsReadonly 更改WebDAV账户只读性
+func UpdateWebDAVAccountsReadonly(c *gin.Context) {
+ var service setting.WebDAVAccountUpdateReadonlyService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Update(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
+
// CreateWebDAVAccounts 创建WebDAV账户
func CreateWebDAVAccounts(c *gin.Context) {
var service setting.WebDAVAccountCreateService
@@ -102,3 +125,14 @@ func CreateWebDAVAccounts(c *gin.Context) {
c.JSON(200, ErrorResponse(err))
}
}
+
+// CreateWebDAVMounts 创建WebDAV目录挂载
+func CreateWebDAVMounts(c *gin.Context) {
+ var service setting.WebDAVMountCreateService
+ if err := c.ShouldBindJSON(&service); err == nil {
+ res := service.Create(c, CurrentUser(c))
+ c.JSON(200, res)
+ } else {
+ c.JSON(200, ErrorResponse(err))
+ }
+}
diff --git a/routers/main_test.go b/routers/main_test.go
deleted file mode 100644
index 83664bd..0000000
--- a/routers/main_test.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package routers
-
-import (
- "database/sql"
- "testing"
-
- "github.com/DATA-DOG/go-sqlmock"
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/gin-gonic/gin"
- "github.com/jinzhu/gorm"
-)
-
-var mock sqlmock.Sqlmock
-var memDB *gorm.DB
-var mockDB *gorm.DB
-
-// TestMain 初始化数据库Mock
-func TestMain(m *testing.M) {
- // 设置gin为测试模式
- gin.SetMode(gin.TestMode)
-
- // 初始化sqlmock
- var db *sql.DB
- var err error
- db, mock, err = sqlmock.New()
- if err != nil {
- panic("An error was not expected when opening a stub database connection")
- }
-
- // 初始话内存数据库
- model.Init()
- memDB = model.DB
-
- mockDB, _ = gorm.Open("mysql", db)
- model.DB = memDB
- defer db.Close()
-
- m.Run()
-}
-
-func switchToMemDB() {
- model.DB = memDB
-}
-
-func switchToMockDB() {
- model.DB = mockDB
-}
diff --git a/routers/router.go b/routers/router.go
index aa8e903..f2a187b 100644
--- a/routers/router.go
+++ b/routers/router.go
@@ -1,6 +1,8 @@
package routers
import (
+ // "github.com/abslant/gzip"
+ "github.com/cloudreve/Cloudreve/v3/bootstrap"
"github.com/cloudreve/Cloudreve/v3/middleware"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
@@ -117,11 +119,14 @@ func InitCORS(router *gin.Engine) {
// InitMasterRouter 初始化主机模式路由
func InitMasterRouter() *gin.Engine {
r := gin.Default()
+ bootstrap.InitCustomRoute(r.Group("/api/v3"))
+ // bootstrap.InitCustomRoute(r.Group("/custom"))
/*
静态资源
*/
r.Use(gzip.Gzip(gzip.DefaultCompression, gzip.WithExcludedPaths([]string{"/api/"})))
+ // r.Use(gzip.GzipHandler())
r.Use(middleware.FrontendFileHandler())
r.GET("manifest.json", controllers.Manifest)
@@ -165,6 +170,8 @@ func InitMasterRouter() *gin.Engine {
site.GET("captcha", controllers.Captcha)
// 站点全局配置
site.GET("config", middleware.CSRFInit(), controllers.SiteConfig)
+ // VOL 密钥
+ site.GET("vol", controllers.GetVolSecret)
}
// 用户相关路由
@@ -190,6 +197,8 @@ func InitMasterRouter() *gin.Engine {
middleware.HashID(hashid.UserID),
controllers.UserActivate,
)
+ // 初始化QQ登录
+ user.POST("qq", controllers.UserQQLogin)
// WebAuthn登陆初始化
user.GET("authn/:username",
middleware.IsFunctionEnabled("authn_enabled"),
@@ -267,6 +276,32 @@ func InitMasterRouter() *gin.Engine {
// 回调接口
callback := v3.Group("callback")
{
+ // QQ互联回调
+ callback.POST(
+ "qq",
+ controllers.QQCallback,
+ )
+ // PAYJS回调
+ callback.POST(
+ "payjs",
+ controllers.PayJSCallback,
+ )
+ // 支付宝回调
+ callback.POST(
+ "alipay",
+ controllers.AlipayCallback,
+ )
+ // 微信扫码支付回调
+ callback.POST(
+ "wechat",
+ controllers.WechatCallback,
+ )
+ // Custom payment callback
+ callback.GET(
+ "custom/:orderno/:id",
+ middleware.SignRequired(auth.General),
+ controllers.CustomCallback,
+ )
// 远程策略上传回调
callback.POST(
"remote/:sessionID/:key",
@@ -392,8 +427,14 @@ func InitMasterRouter() *gin.Engine {
middleware.ShareCanPreview(),
controllers.ShareThumb,
)
+ // 举报分享
+ share.POST("report/:id",
+ middleware.IsFunctionEnabled("report_enabled"),
+ middleware.CheckShareUnlocked(),
+ controllers.ReportShare,
+ )
// 搜索公共分享
- v3.Group("share").GET("search", controllers.SearchShare)
+ v3.Group("share").GET("search", middleware.AuthRequired(), controllers.SearchShare)
}
wopi := v3.Group(
@@ -422,7 +463,7 @@ func InitMasterRouter() *gin.Engine {
// 获取站点概况
admin.GET("summary", controllers.AdminSummary)
// 获取社区新闻
- admin.GET("news", controllers.AdminNews)
+ // admin.GET("news", controllers.AdminNews)
// 更改设置
admin.PATCH("setting", controllers.AdminChangeSetting)
// 获取设置
@@ -447,6 +488,22 @@ func InitMasterRouter() *gin.Engine {
aria2.POST("test", controllers.AdminTestAria2)
}
+ vol := admin.Group("vol")
+ {
+ vol.GET("sync", controllers.AdminSyncVol)
+ }
+
+ // 兑换码相关
+ redeem := admin.Group("redeem")
+ {
+ // 列出激活码
+ redeem.POST("list", controllers.AdminListRedeems)
+ // 生成激活码
+ redeem.POST("", controllers.AdminGenerateRedeems)
+ // 删除激活码
+ redeem.DELETE(":id", controllers.AdminDeleteRedeem)
+ }
+
// 存储策略管理
policy := admin.Group("policy")
{
@@ -525,6 +582,14 @@ func InitMasterRouter() *gin.Engine {
share.POST("delete", controllers.AdminDeleteShare)
}
+ order := admin.Group("order")
+ {
+ // 列出订单
+ order.POST("list", controllers.AdminListOrder)
+ // 删除
+ order.POST("delete", controllers.AdminDeleteOrder)
+ }
+
download := admin.Group("download")
{
// 列出任务
@@ -543,6 +608,14 @@ func InitMasterRouter() *gin.Engine {
task.POST("import", controllers.AdminCreateImportTask)
}
+ report := admin.Group("report")
+ {
+ // 列出未处理举报
+ report.POST("list", controllers.AdminListReport)
+ // 删除
+ report.POST("delete", controllers.AdminDeleteReport)
+ }
+
node := admin.Group("node")
{
// 列出从机节点
@@ -585,6 +658,10 @@ func InitMasterRouter() *gin.Engine {
// 用户设置
setting := user.Group("setting")
{
+ // 获取用户可选存储策略
+ setting.GET("policies", controllers.UserAvailablePolicies)
+ // 获取用户可选节点
+ setting.GET("nodes", controllers.UserAvailableNodes)
// 任务队列
setting.GET("tasks", controllers.UserTasks)
// 获取当前用户设定
@@ -601,7 +678,7 @@ func InitMasterRouter() *gin.Engine {
}
// 文件
- file := auth.Group("file", middleware.HashID(hashid.FileID))
+ file := auth.Group("file", middleware.PhoneRequired(), middleware.HashID(hashid.FileID))
{
// 上传
upload := file.Group("upload")
@@ -637,12 +714,14 @@ func InitMasterRouter() *gin.Engine {
file.POST("compress", controllers.Compress)
// 创建文件解压缩任务
file.POST("decompress", controllers.Decompress)
- // 创建文件解压缩任务
+ // 创建文件转移任务
+ file.POST("relocate", controllers.Relocate)
+ // 搜索文件
file.GET("search/:type/:keywords", controllers.SearchFile)
}
// 离线下载任务
- aria2 := auth.Group("aria2")
+ aria2 := auth.Group("aria2", middleware.PhoneRequired())
{
// 创建URL下载任务
aria2.POST("url", controllers.AddAria2URL)
@@ -659,7 +738,7 @@ func InitMasterRouter() *gin.Engine {
}
// 目录
- directory := auth.Group("directory")
+ directory := auth.Group("directory", middleware.PhoneRequired())
{
// 创建目录
directory.PUT("", controllers.CreateDirectory)
@@ -668,7 +747,7 @@ func InitMasterRouter() *gin.Engine {
}
// 对象,文件和目录的抽象
- object := auth.Group("object")
+ object := auth.Group("object", middleware.PhoneRequired())
{
// 删除对象
object.DELETE("", controllers.Delete)
@@ -683,12 +762,19 @@ func InitMasterRouter() *gin.Engine {
}
// 分享
- share := auth.Group("share")
+ share := auth.Group("share", middleware.PhoneRequired())
{
// 创建新分享
share.POST("", controllers.CreateShare)
// 列出我的分享
share.GET("", controllers.ListShare)
+ // 转存他人分享
+ share.POST("save/:id",
+ middleware.ShareAvailable(),
+ middleware.CheckShareUnlocked(),
+ middleware.BeforeShareDownload(),
+ controllers.SaveShare,
+ )
// 更新分享属性
share.PATCH(":id",
middleware.ShareAvailable(),
@@ -712,8 +798,25 @@ func InitMasterRouter() *gin.Engine {
tag.DELETE(":id", middleware.HashID(hashid.TagID), controllers.DeleteTag)
}
+ // 增值服务相关
+ vas := auth.Group("vas", middleware.PhoneRequired())
+ {
+ // 获取容量包及配额信息
+ vas.GET("pack", controllers.GetQuota)
+ // 获取商品信息,同时返回支付信息
+ vas.GET("product", controllers.GetProduct)
+ // 新建支付订单
+ vas.POST("order", controllers.NewOrder)
+ // 查询订单状态
+ vas.GET("order/:id", controllers.OrderStatus)
+ // 获取兑换码信息
+ vas.GET("redeem/:code", controllers.GetRedeemInfo)
+ // 执行兑换
+ vas.POST("redeem/:code", controllers.DoRedeem)
+ }
+
// WebDAV管理相关
- webdav := auth.Group("webdav")
+ webdav := auth.Group("webdav", middleware.PhoneRequired())
{
// 获取账号信息
webdav.GET("accounts", controllers.GetWebDAVAccounts)
@@ -721,6 +824,13 @@ func InitMasterRouter() *gin.Engine {
webdav.POST("accounts", controllers.CreateWebDAVAccounts)
// 删除账号
webdav.DELETE("accounts/:id", controllers.DeleteWebDAVAccounts)
+ // 删除目录挂载
+ webdav.DELETE("mount/:id",
+ middleware.HashID(hashid.FolderID),
+ controllers.DeleteWebDAVMounts,
+ )
+ // 创建目录挂载
+ webdav.POST("mount", controllers.CreateWebDAVMounts)
// 更新账号可读性和是否使用代理服务
webdav.PATCH("accounts", controllers.UpdateWebDAVAccounts)
}
diff --git a/routers/router_test.go b/routers/router_test.go
deleted file mode 100644
index 2476de6..0000000
--- a/routers/router_test.go
+++ /dev/null
@@ -1,251 +0,0 @@
-package routers
-
-import (
- "github.com/cloudreve/Cloudreve/v3/pkg/conf"
- "net/http"
- "net/http/httptest"
- "testing"
-
- model "github.com/cloudreve/Cloudreve/v3/models"
- "github.com/jinzhu/gorm"
- "github.com/stretchr/testify/assert"
-)
-
-func TestPing(t *testing.T) {
- asserts := assert.New(t)
- router := InitMasterRouter()
-
- w := httptest.NewRecorder()
- req, _ := http.NewRequest("GET", "/api/v3/site/ping", nil)
- router.ServeHTTP(w, req)
-
- assert.Equal(t, 200, w.Code)
- asserts.Contains(w.Body.String(), conf.BackendVersion)
-}
-
-func TestCaptcha(t *testing.T) {
- asserts := assert.New(t)
- router := InitMasterRouter()
- w := httptest.NewRecorder()
-
- req, _ := http.NewRequest(
- "GET",
- "/api/v3/site/captcha",
- nil,
- )
-
- router.ServeHTTP(w, req)
-
- asserts.Equal(200, w.Code)
- asserts.Contains(w.Body.String(), "base64")
-}
-
-//func TestUserSession(t *testing.T) {
-// mutex.Lock()
-// defer mutex.Unlock()
-// switchToMockDB()
-// asserts := assert.New(t)
-// router := InitMasterRouter()
-// w := httptest.NewRecorder()
-//
-// // 创建测试用验证码
-// var configD = base64Captcha.ConfigDigit{
-// Height: 80,
-// Width: 240,
-// MaxSkew: 0.7,
-// DotCount: 80,
-// CaptchaLen: 1,
-// }
-// idKeyD, _ := base64Captcha.GenerateCaptcha("", configD)
-// middleware.ContextMock = map[string]interface{}{
-// "captchaID": idKeyD,
-// }
-//
-// testCases := []struct {
-// settingRows *sqlmock.Rows
-// userRows *sqlmock.Rows
-// policyRows *sqlmock.Rows
-// reqBody string
-// expected interface{}
-// }{
-// // 登录信息正确,不需要验证码
-// {
-// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}).
-// AddRow("login_captcha", "0", "login"),
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"),
-// expected: serializer.BuildUserResponse(model.User{
-// Email: "admin@cloudreve.org",
-// Nick: "admin",
-// Policy: model.Policy{
-// Type: "local",
-// OptionsSerialized: model.PolicyOption{FileType: []string{}},
-// },
-// }),
-// },
-// // 登录信息正确,需要验证码,验证码错误
-// {
-// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}).
-// AddRow("login_captcha", "1", "login"),
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"),
-// expected: serializer.ParamErr("验证码错误", nil),
-// },
-// // 邮箱正确密码错误
-// {
-// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}).
-// AddRow("login_captcha", "0", "login"),
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"),
-// expected: serializer.Err(401, "用户邮箱或密码错误", nil),
-// },
-// //邮箱格式不正确
-// {
-// reqBody: `{"userName":"admin@cloudreve","captchaCode":"captchaCode","Password":"admin123"}`,
-// expected: serializer.Err(40001, "邮箱格式不正确", errors.New("Key: 'UserLoginService.UserName' Error:Field validation for 'UserName' failed on the 'email' tag")),
-// },
-// // 用户被Ban
-// {
-// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}).
-// AddRow("login_captcha", "0", "login"),
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options", "status"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}", model.Baned),
-// expected: serializer.Err(403, "该账号已被封禁", nil),
-// },
-// // 用户未激活
-// {
-// settingRows: sqlmock.NewRows([]string{"name", "value", "type"}).
-// AddRow("login_captcha", "0", "login"),
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options", "status"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}", model.NotActivicated),
-// expected: serializer.Err(403, "该账号未激活", nil),
-// },
-// }
-//
-// for k, testCase := range testCases {
-// if testCase.settingRows != nil {
-// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.settingRows)
-// }
-// if testCase.userRows != nil {
-// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows)
-// }
-// if testCase.policyRows != nil {
-// mock.ExpectQuery("^SELECT \\* FROM `(.+)` WHERE `(.+)`\\.`deleted_at` IS NULL AND \\(\\(`policies`.`id` = 1\\)\\)(.+)$").WillReturnRows(testCase.policyRows)
-// }
-// req, _ := http.NewRequest(
-// "POST",
-// "/api/v3/user/session",
-// bytes.NewReader([]byte(testCase.reqBody)),
-// )
-// router.ServeHTTP(w, req)
-//
-// asserts.Equal(200, w.Code)
-// expectedJSON, _ := json.Marshal(testCase.expected)
-// asserts.JSONEq(string(expectedJSON), w.Body.String(), "测试用例:%d", k)
-//
-// w.Body.Reset()
-// asserts.NoError(mock.ExpectationsWereMet())
-// model.ClearCache()
-// }
-//
-//}
-//
-//func TestSessionAuthCheck(t *testing.T) {
-// mutex.Lock()
-// defer mutex.Unlock()
-// switchToMockDB()
-// asserts := assert.New(t)
-// router := InitMasterRouter()
-// w := httptest.NewRecorder()
-//
-// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(sqlmock.NewRows([]string{"email", "nick", "password", "options"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"))
-// expectedUser, _ := model.GetUserByID(1)
-//
-// testCases := []struct {
-// userRows *sqlmock.Rows
-// sessionMock map[string]interface{}
-// contextMock map[string]interface{}
-// expected interface{}
-// }{
-// // 未登录
-// {
-// expected: serializer.CheckLogin(),
-// },
-// // 登录正常
-// {
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}).
-// AddRow("admin@cloudreve.org", "admin", "CKLmDKa1C9SD64vU:76adadd4fd4bad86959155f6f7bc8993c94e7adf", "{}"),
-// sessionMock: map[string]interface{}{"user_id": 1},
-// expected: serializer.BuildUserResponse(expectedUser),
-// },
-// // UID不存在
-// {
-// userRows: sqlmock.NewRows([]string{"email", "nick", "password", "options"}),
-// sessionMock: map[string]interface{}{"user_id": -1},
-// expected: serializer.CheckLogin(),
-// },
-// }
-//
-// for _, testCase := range testCases {
-// req, _ := http.NewRequest(
-// "GET",
-// "/api/v3/user/me",
-// nil,
-// )
-// if testCase.userRows != nil {
-// mock.ExpectQuery("^SELECT (.+)").WillReturnRows(testCase.userRows)
-// }
-// middleware.ContextMock = testCase.contextMock
-// middleware.SessionMock = testCase.sessionMock
-// router.ServeHTTP(w, req)
-// expectedJSON, _ := json.Marshal(testCase.expected)
-//
-// asserts.Equal(200, w.Code)
-// asserts.JSONEq(string(expectedJSON), w.Body.String())
-// asserts.NoError(mock.ExpectationsWereMet())
-//
-// w.Body.Reset()
-// }
-//
-//}
-
-func TestSiteConfigRoute(t *testing.T) {
- switchToMemDB()
- asserts := assert.New(t)
- router := InitMasterRouter()
- w := httptest.NewRecorder()
-
- req, _ := http.NewRequest(
- "GET",
- "/api/v3/site/config",
- nil,
- )
- router.ServeHTTP(w, req)
- asserts.Equal(200, w.Code)
- asserts.Contains(w.Body.String(), "Cloudreve")
-
- w.Body.Reset()
-
- // 消除无效值
- model.DB.Model(&model.Setting{
- Model: gorm.Model{
- ID: 2,
- },
- }).UpdateColumn("name", "siteName_b")
-
- req, _ = http.NewRequest(
- "GET",
- "/api/v3/site/config",
- nil,
- )
- router.ServeHTTP(w, req)
- asserts.Equal(200, w.Code)
- asserts.Contains(w.Body.String(), "\"title\"")
-
- model.DB.Model(&model.Setting{
- Model: gorm.Model{
- ID: 2,
- },
- }).UpdateColumn("name", "siteName")
-}
diff --git a/service/admin/aria2.go b/service/admin/aria2.go
index 6a2b77d..0bdd641 100644
--- a/service/admin/aria2.go
+++ b/service/admin/aria2.go
@@ -3,10 +3,10 @@ package admin
import (
"bytes"
"encoding/json"
- model "github.com/cloudreve/Cloudreve/v3/models"
"net/url"
"time"
+ model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
diff --git a/service/admin/order.go b/service/admin/order.go
new file mode 100755
index 0000000..4556c0f
--- /dev/null
+++ b/service/admin/order.go
@@ -0,0 +1,75 @@
+package admin
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/gin-gonic/gin"
+ "strings"
+)
+
+// OrderBatchService 订单批量操作服务
+type OrderBatchService struct {
+ ID []uint `json:"id" binding:"min=1"`
+}
+
+// Delete 删除订单
+func (service *OrderBatchService) Delete(c *gin.Context) serializer.Response {
+ if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Order{}).Error; err != nil {
+ return serializer.DBErr("Failed to delete order records.", err)
+ }
+ return serializer.Response{}
+}
+
+// Orders 列出订单
+func (service *AdminListService) Orders() serializer.Response {
+ var res []model.Order
+ total := 0
+
+ tx := model.DB.Model(&model.Order{})
+ if service.OrderBy != "" {
+ tx = tx.Order(service.OrderBy)
+ }
+
+ for k, v := range service.Conditions {
+ tx = tx.Where(k+" = ?", v)
+ }
+
+ if len(service.Searches) > 0 {
+ search := ""
+ for k, v := range service.Searches {
+ search += k + " like '%" + v + "%' OR "
+ }
+ search = strings.TrimSuffix(search, " OR ")
+ tx = tx.Where(search)
+ }
+
+ // 计算总数用于分页
+ tx.Count(&total)
+
+ // 查询记录
+ tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
+
+ // 查询对应用户,同时计算HashID
+ users := make(map[uint]model.User)
+ for _, file := range res {
+ users[file.UserID] = model.User{}
+ }
+
+ userIDs := make([]uint, 0, len(users))
+ for k := range users {
+ userIDs = append(userIDs, k)
+ }
+
+ var userList []model.User
+ model.DB.Where("id in (?)", userIDs).Find(&userList)
+
+ for _, v := range userList {
+ users[v.ID] = v
+ }
+
+ return serializer.Response{Data: map[string]interface{}{
+ "total": total,
+ "items": res,
+ "users": users,
+ }}
+}
diff --git a/service/admin/policy.go b/service/admin/policy.go
index 478203a..d0b01e0 100644
--- a/service/admin/policy.go
+++ b/service/admin/policy.go
@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
"net/http"
"net/url"
"os"
@@ -14,6 +13,8 @@ import (
"strings"
"time"
+ "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
@@ -80,7 +81,10 @@ func (service *PolicyService) Delete() serializer.Response {
// 检查用户组使用
var groups []model.Group
model.DB.Model(&model.Group{}).Where(
- "policies like ?",
+ "policies like ? OR policies like ? OR policies like ? OR policies like ?",
+ fmt.Sprintf("[%d,%%", service.ID),
+ fmt.Sprintf("%%,%d]", service.ID),
+ fmt.Sprintf("%%,%d,%%", service.ID),
fmt.Sprintf("%%[%d]%%", service.ID),
).Find(&groups)
@@ -185,7 +189,6 @@ func (service *PolicyService) AddCORS() serializer.Response {
},
}),
}
-
if err := handler.CORS(); err != nil {
return serializer.Err(serializer.CodeAddCORS, "", err)
}
@@ -227,8 +230,8 @@ func (service *SlavePingService) Test() serializer.Response {
}
version := conf.BackendVersion
- if conf.IsPro == "true" {
- version += "-pro"
+ if conf.IsPlus == "true" {
+ version += "-plus"
}
if res.Data.(string) != version {
return serializer.Err(serializer.CodeVersionMismatch, "Master: "+res.Data.(string)+", Slave: "+version, nil)
diff --git a/service/admin/report.go b/service/admin/report.go
new file mode 100755
index 0000000..b30d7af
--- /dev/null
+++ b/service/admin/report.go
@@ -0,0 +1,72 @@
+package admin
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+)
+
+// ReportBatchService 任务批量操作服务
+type ReportBatchService struct {
+ ID []uint `json:"id" binding:"min=1"`
+}
+
+// Reports 批量删除举报
+func (service *ReportBatchService) Delete() serializer.Response {
+ if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Report{}).Error; err != nil {
+ return serializer.DBErr("Failed to change report status", err)
+ }
+ return serializer.Response{}
+}
+
+// Reports 列出待处理举报
+func (service *AdminListService) Reports() serializer.Response {
+ var res []model.Report
+ total := 0
+
+ tx := model.DB.Model(&model.Report{})
+ if service.OrderBy != "" {
+ tx = tx.Order(service.OrderBy)
+ }
+
+ for k, v := range service.Conditions {
+ tx = tx.Where(k+" = ?", v)
+ }
+
+ // 计算总数用于分页
+ tx.Count(&total)
+
+ // 查询记录
+ tx.Set("gorm:auto_preload", true).Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
+
+ // 计算分享的 HashID
+ hashIDs := make(map[uint]string, len(res))
+ for _, report := range res {
+ hashIDs[report.Share.ID] = hashid.HashID(report.Share.ID, hashid.ShareID)
+ }
+
+ // 查询对应用户
+ users := make(map[uint]model.User)
+ for _, report := range res {
+ users[report.Share.UserID] = model.User{}
+ }
+
+ userIDs := make([]uint, 0, len(users))
+ for k := range users {
+ userIDs = append(userIDs, k)
+ }
+
+ var userList []model.User
+ model.DB.Where("id in (?)", userIDs).Find(&userList)
+
+ for _, v := range userList {
+ users[v.ID] = v
+ }
+
+ return serializer.Response{Data: map[string]interface{}{
+ "total": total,
+ "items": res,
+ "users": users,
+ "ids": hashIDs,
+ }}
+}
diff --git a/service/admin/site.go b/service/admin/site.go
index 69aa2d8..f4e83ad 100644
--- a/service/admin/site.go
+++ b/service/admin/site.go
@@ -10,6 +10,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
+ "github.com/cloudreve/Cloudreve/v3/pkg/vol"
"github.com/gin-gonic/gin"
)
@@ -89,7 +90,7 @@ func (service *NoParamService) Summary() serializer.Response {
"backend": conf.BackendVersion,
"db": conf.RequiredDBVersion,
"commit": conf.LastCommit,
- "is_pro": conf.IsPro,
+ "is_plus": conf.IsPlus,
}
if res, ok := cache.Get("admin_summary"); ok {
@@ -163,3 +164,36 @@ func (s *ThumbGeneratorTestService) Test(c *gin.Context) serializer.Response {
Data: version,
}
}
+
+// VOL 授权管理服务
+type VolService struct {
+}
+
+// Sync 同步 VOL 授权
+func (s *VolService) Sync() serializer.Response {
+ volClient := vol.New(vol.ClientSecret)
+ content, signature, err := volClient.Sync()
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, err.Error(), err)
+ }
+
+ subService := &BatchSettingChangeService{
+ Options: []SettingChangeService{
+ {
+ Key: "vol_content",
+ Value: content,
+ },
+ {
+ Key: "vol_signature",
+ Value: signature,
+ },
+ },
+ }
+
+ res := subService.Change()
+ if res.Code != 0 {
+ return res
+ }
+
+ return serializer.Response{Data: content}
+}
diff --git a/service/admin/user.go b/service/admin/user.go
index eb76ac9..984a799 100644
--- a/service/admin/user.go
+++ b/service/admin/user.go
@@ -72,6 +72,12 @@ func (service *UserBatchService) Delete() serializer.Response {
model.DB.Where("user_id = ?", uid).Delete(&model.Download{})
model.DB.Where("user_id = ?", uid).Delete(&model.Task{})
+ // 删除订单记录
+ model.DB.Where("user_id = ?", uid).Delete(&model.Order{})
+
+ // 删除容量包
+ model.DB.Where("user_id = ?", uid).Delete(&model.StoragePack{})
+
// 删除标签
model.DB.Where("user_id = ?", uid).Delete(&model.Tag{})
@@ -109,6 +115,7 @@ func (service *AddUserService) Add() serializer.Response {
user.Email = service.User.Email
user.GroupID = service.User.GroupID
user.Status = service.User.Status
+ user.Score = service.User.Score
user.TwoFactor = service.User.TwoFactor
// 检查愚蠢操作
diff --git a/service/admin/vas.go b/service/admin/vas.go
new file mode 100755
index 0000000..20e3a12
--- /dev/null
+++ b/service/admin/vas.go
@@ -0,0 +1,98 @@
+package admin
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/gofrs/uuid"
+)
+
+// GenerateRedeemsService 兑换码生成服务
+type GenerateRedeemsService struct {
+ Num int `json:"num" binding:"required,min=1,max=100"`
+ ID int64 `json:"id"`
+ Time int `json:"time" binding:"required,min=1"`
+ Type int `json:"type" binding:"min=0,max=2"`
+}
+
+// SingleIDService 单ID服务
+type SingleIDService struct {
+ ID uint `uri:"id" binding:"required"`
+}
+
+// DeleteRedeem 删除兑换码
+func (service *SingleIDService) DeleteRedeem() serializer.Response {
+ if err := model.DB.Where("id = ?", service.ID).Delete(&model.Redeem{}).Error; err != nil {
+ return serializer.DBErr("Failed to delete gift code record.", err)
+ }
+
+ return serializer.Response{}
+}
+
+// Generate 生成兑换码
+func (service *GenerateRedeemsService) Generate() serializer.Response {
+ res := make([]string, service.Num)
+ redeem := model.Redeem{}
+
+ // 开始事务
+ tx := model.DB.Begin()
+ if err := tx.Error; err != nil {
+ return serializer.DBErr("Cannot start transaction", err)
+ }
+
+ // 创建每个兑换码
+ for i := 0; i < service.Num; i++ {
+ redeem.Model.ID = 0
+ redeem.Num = service.Time
+ redeem.Type = service.Type
+ redeem.ProductID = service.ID
+ redeem.Used = false
+
+ // 生成唯一兑换码
+ u2, err := uuid.NewV4()
+ if err != nil {
+ tx.Rollback()
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to generate UUID", err)
+ }
+
+ redeem.Code = u2.String()
+ if err := tx.Create(&redeem).Error; err != nil {
+ tx.Rollback()
+ return serializer.DBErr("Failed to insert gift code record", err)
+ }
+
+ res[i] = redeem.Code
+ }
+
+ if err := tx.Commit().Error; err != nil {
+ return serializer.DBErr("Failed to insert gift code record", err)
+ }
+
+ return serializer.Response{Data: res}
+
+}
+
+// Redeems 列出激活码
+func (service *AdminListService) Redeems() serializer.Response {
+ var res []model.Redeem
+ total := 0
+
+ tx := model.DB.Model(&model.Redeem{})
+ if service.OrderBy != "" {
+ tx = tx.Order(service.OrderBy)
+ }
+
+ for k, v := range service.Conditions {
+ tx = tx.Where("? = ?", k, v)
+ }
+
+ // 计算总数用于分页
+ tx.Count(&total)
+
+ // 查询记录
+ tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res)
+
+ return serializer.Response{Data: map[string]interface{}{
+ "total": total,
+ "items": res,
+ }}
+}
diff --git a/service/aria2/add.go b/service/aria2/add.go
index 816c57b..0a47dd0 100644
--- a/service/aria2/add.go
+++ b/service/aria2/add.go
@@ -60,8 +60,9 @@ func (service *BatchAddURLService) Add(c *gin.Context, taskType int) serializer.
// AddURLService 添加URL离线下载服务
type AddURLService struct {
- URL string `json:"url" binding:"required"`
- Dst string `json:"dst" binding:"required,min=1"`
+ URL string `json:"url" binding:"required"`
+ Dst string `json:"dst" binding:"required,min=1"`
+ PreferredNode uint `json:"preferred_node"`
}
// Add 主机创建新的链接离线下载任务
@@ -92,6 +93,10 @@ func (service *AddURLService) Add(c *gin.Context, fs *filesystem.FileSystem, tas
return serializer.Err(serializer.CodeBatchAria2Size, "", nil)
}
+ if service.PreferredNode > 0 && !fs.User.Group.OptionsSerialized.SelectNode {
+ return serializer.Err(serializer.CodeGroupNotAllowed, "not allowed to select nodes", nil)
+ }
+
// 创建任务
task := &model.Download{
Status: common.Ready,
@@ -105,7 +110,8 @@ func (service *AddURLService) Add(c *gin.Context, fs *filesystem.FileSystem, tas
lb := aria2.GetLoadBalancer()
// 获取 Aria2 实例
- err, node := cluster.Default.BalanceNodeByFeature("aria2", lb)
+ err, node := cluster.Default.BalanceNodeByFeature("aria2", lb, fs.User.Group.OptionsSerialized.AvailableNodes,
+ service.PreferredNode)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "Failed to get Aria2 instance", err)
}
diff --git a/service/callback/upload.go b/service/callback/upload.go
index 0dd7924..c91979b 100644
--- a/service/callback/upload.go
+++ b/service/callback/upload.go
@@ -239,7 +239,7 @@ func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response {
return ProcessCallback(service, c)
}
-// PreProcess 对从机客户端回调进行预处理验证
+// PreProcess 对OneDrive客户端回调进行预处理验证
func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromCallback(c)
diff --git a/service/explorer/directory.go b/service/explorer/directory.go
index cd03999..0116fdb 100644
--- a/service/explorer/directory.go
+++ b/service/explorer/directory.go
@@ -37,6 +37,11 @@ func (service *DirectoryService) ListDirectory(c *gin.Context) serializer.Respon
parentID = fs.DirTarget[0].ID
}
+ // 获取目录的存储策略
+ if err := fs.SetPolicyFromPath(service.Path); err != nil {
+ return serializer.Err(serializer.CodePolicyNotExist, "", err)
+ }
+
return serializer.Response{
Code: 0,
Data: serializer.BuildObjectList(parentID, objects, fs.Policy),
diff --git a/service/explorer/file.go b/service/explorer/file.go
index 1c9d870..44aa202 100644
--- a/service/explorer/file.go
+++ b/service/explorer/file.go
@@ -5,7 +5,6 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
- "github.com/cloudreve/Cloudreve/v3/pkg/util"
"io/ioutil"
"net/http"
"net/url"
@@ -18,6 +17,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/gin-gonic/gin"
)
@@ -56,6 +56,9 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response {
}
defer fs.Recycle()
+ baseDir := path.Dir(service.Path)
+ fs.SetPolicyFromPath(baseDir)
+
// 上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -68,7 +71,7 @@ func (service *SingleFileService) Create(c *gin.Context) serializer.Response {
err = fs.Upload(ctx, &fsctx.FileStream{
File: ioutil.NopCloser(strings.NewReader("")),
Size: 0,
- VirtualPath: path.Dir(service.Path),
+ VirtualPath: baseDir,
Name: path.Base(service.Path),
})
if err != nil {
@@ -197,7 +200,7 @@ func (service *FileIDService) CreateDocPreviewSession(ctx context.Context, c *gi
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
- return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
+ return serializer.Err(serializer.CodeCreateFSError, "", err)
}
defer fs.Recycle()
diff --git a/service/explorer/objects.go b/service/explorer/objects.go
index 1c3c45a..fb4db07 100644
--- a/service/explorer/objects.go
+++ b/service/explorer/objects.go
@@ -55,6 +55,12 @@ type ItemCompressService struct {
Name string `json:"name" binding:"required,min=1,max=255"`
}
+// ItemRelocateService 文件转移任务服务
+type ItemRelocateService struct {
+ Src ItemIDService `json:"src"`
+ DstPolicyID string `json:"dst_policy_id" binding:"required"`
+}
+
// ItemDecompressService 文件解压缩任务服务
type ItemDecompressService struct {
Src string `json:"src"`
@@ -234,6 +240,57 @@ func (service *ItemCompressService) CreateCompressTask(c *gin.Context) serialize
}
+// CreateRelocateTask 创建文件转移任务
+func (service *ItemRelocateService) CreateRelocateTask(c *gin.Context) serializer.Response {
+ userCtx, _ := c.Get("user")
+ user := userCtx.(*model.User)
+
+ // 取得存储策略的ID
+ rawID, err := hashid.DecodeHashID(service.DstPolicyID, hashid.PolicyID)
+ if err != nil {
+ return serializer.Err(serializer.CodePolicyNotExist, "", err)
+ }
+
+ // 检查用户组权限
+ if !user.Group.OptionsSerialized.Relocate {
+ return serializer.Err(serializer.CodeGroupNotAllowed, "", nil)
+ }
+
+ // 用户是否可以使用目的存储策略
+ if !util.ContainsUint(user.Group.PolicyList, rawID) {
+ return serializer.ParamErr("Storage policy is not available", nil)
+ }
+
+ // 查找存储策略
+ if _, err := model.GetPolicyByID(rawID); err != nil {
+ return serializer.ParamErr("Storage policy is not available", nil)
+ }
+
+ // 查找是否有正在进行中的转存任务
+ var tasks []model.Task
+ model.DB.Where("status in (?) and user_id = ? and type = ?",
+ []int{task.Queued, task.Processing}, user.ID,
+ task.RelocateTaskType).Find(&tasks)
+ if len(tasks) > 0 {
+ return serializer.Response{
+ Code: serializer.CodeConflict,
+ Msg: "There's ongoing relocate task, please wait for the previous task to finish",
+ }
+ }
+
+ IDRaw := service.Src.Raw()
+
+ // 创建任务
+ job, err := task.NewRelocateTask(user, rawID, IDRaw.Dirs,
+ IDRaw.Items)
+ if err != nil {
+ return serializer.Err(serializer.CodeCreateTaskError, "", err)
+ }
+ task.TaskPoll.Submit(job)
+
+ return serializer.Response{}
+}
+
// Archive 创建归档
func (service *ItemIDService) Archive(ctx context.Context, c *gin.Context) serializer.Response {
// 创建文件系统
@@ -270,7 +327,7 @@ func (service *ItemIDService) Delete(ctx context.Context, c *gin.Context) serial
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
- return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
+ return serializer.Err(serializer.CodeCreateFSError, "", err)
}
defer fs.Recycle()
@@ -351,7 +408,7 @@ func (service *ItemRenameService) Rename(ctx context.Context, c *gin.Context) se
// 创建文件系统
fs, err := filesystem.NewFileSystemFromContext(c)
if err != nil {
- return serializer.Err(serializer.CodeCreateFSError, "", err)
+ return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
}
defer fs.Recycle()
@@ -415,6 +472,10 @@ func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Cont
return serializer.DBErr("Failed to query folder records", err)
}
+ policy := user.GetPolicyID(&folder[0])
+ if folder[0].PolicyID > 0 {
+ props.Policy = policy.Name
+ }
props.CreatedAt = folder[0].CreatedAt
props.UpdatedAt = folder[0].UpdatedAt
@@ -423,6 +484,7 @@ func (service *ItemPropertyService) GetProperty(ctx context.Context, c *gin.Cont
res := cacheRes.(serializer.ObjectProps)
res.CreatedAt = props.CreatedAt
res.UpdatedAt = props.UpdatedAt
+ res.Policy = props.Policy
return serializer.Response{Data: res}
}
diff --git a/service/explorer/upload.go b/service/explorer/upload.go
index 0c26c26..470691a 100644
--- a/service/explorer/upload.go
+++ b/service/explorer/upload.go
@@ -3,6 +3,11 @@ package explorer
import (
"context"
"fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+ "time"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
@@ -13,10 +18,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
- "io/ioutil"
- "strconv"
- "strings"
- "time"
)
// CreateUploadSessionService 获取上传凭证服务
@@ -43,8 +44,13 @@ func (service *CreateUploadSessionService) Create(ctx context.Context, c *gin.Co
return serializer.Err(serializer.CodePolicyNotExist, "", err)
}
+ // 分配并检查存储策略
+ if err := fs.SetPolicyFromPreference(rawID); err != nil {
+ return serializer.Err(serializer.CodePolicyNotAllowed, "", err)
+ }
+
if fs.Policy.ID != rawID {
- return serializer.Err(serializer.CodePolicyNotAllowed, "存储策略发生变化,请刷新文件列表并重新添加此任务", nil)
+ return serializer.Err(serializer.CodePolicyChanged, "", nil)
}
file := &fsctx.FileStream{
diff --git a/service/node/fabric.go b/service/node/fabric.go
index deb2184..d35a0a9 100644
--- a/service/node/fabric.go
+++ b/service/node/fabric.go
@@ -2,6 +2,7 @@ package node
import (
"encoding/gob"
+
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
diff --git a/service/setting/webdav.go b/service/setting/webdav.go
index 1f81751..78bdc4a 100644
--- a/service/setting/webdav.go
+++ b/service/setting/webdav.go
@@ -2,6 +2,8 @@ package setting
import (
model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
+ "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
@@ -29,12 +31,71 @@ type WebDAVAccountUpdateService struct {
UseProxy *bool `json:"use_proxy" binding:"required_without=Readonly"`
}
+// WebDAVAccountUpdateReadonlyService WebDAV 修改只读性服务
+type WebDAVAccountUpdateReadonlyService struct {
+ ID uint `json:"id" binding:"required,min=1"`
+ Readonly bool `json:"readonly"`
+}
+
// WebDAVMountCreateService WebDAV 挂载创建服务
type WebDAVMountCreateService struct {
Path string `json:"path" binding:"required,min=1,max=65535"`
Policy string `json:"policy" binding:"required,min=1"`
}
+// Create 创建目录挂载
+func (service *WebDAVMountCreateService) Create(c *gin.Context, user *model.User) serializer.Response {
+ // 创建文件系统
+ fs, err := filesystem.NewFileSystem(user)
+ if err != nil {
+ return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err)
+ }
+ defer fs.Recycle()
+
+ // 检索要挂载的目录
+ exist, folder := fs.IsPathExist(service.Path)
+ if !exist {
+ return serializer.Err(serializer.CodeParentNotExist, "", err)
+ }
+
+ // 检索要挂载的存储策略
+ policyID, err := hashid.DecodeHashID(service.Policy, hashid.PolicyID)
+ if err != nil {
+ return serializer.Err(serializer.CodePolicyNotExist, "", err)
+ }
+
+ // 检查存储策略是否可用
+ if policy, err := model.GetPolicyByID(policyID); err != nil || !util.ContainsUint(user.Group.PolicyList, policy.ID) {
+ return serializer.Err(serializer.CodePolicyNotAllowed, "", err)
+ }
+
+ // 挂载
+ if err := folder.Mount(policyID); err != nil {
+ return serializer.Err(serializer.CodeDBError, "Failed to update folder record", err)
+ }
+
+ return serializer.Response{
+ Data: map[string]interface{}{
+ "id": hashid.HashID(folder.ID, hashid.FolderID),
+ },
+ }
+}
+
+// Unmount 取消目录挂载
+func (service *WebDAVListService) Unmount(c *gin.Context, user *model.User) serializer.Response {
+ folderID, _ := c.Get("object_id")
+ folder, err := model.GetFoldersByIDs([]uint{folderID.(uint)}, user.ID)
+ if err != nil || len(folder) == 0 {
+ return serializer.Err(serializer.CodeParentNotExist, "", err)
+ }
+
+ if err := folder[0].Mount(0); err != nil {
+ return serializer.DBErr("Failed to update folder record", err)
+ }
+
+ return serializer.Response{}
+}
+
// Create 创建WebDAV账户
func (service *WebDAVAccountCreateService) Create(c *gin.Context, user *model.User) serializer.Response {
account := model.Webdav{
@@ -45,7 +106,7 @@ func (service *WebDAVAccountCreateService) Create(c *gin.Context, user *model.Us
}
if _, err := account.Create(); err != nil {
- return serializer.Err(serializer.CodeDBError, "创建失败", err)
+ return serializer.DBErr("Failed to create account record", err)
}
return serializer.Response{
@@ -76,11 +137,23 @@ func (service *WebDAVAccountUpdateService) Update(c *gin.Context, user *model.Us
return serializer.Response{Data: updates}
}
+// Update 修改WebDAV账户的只读性
+func (service *WebDAVAccountUpdateReadonlyService) Update(c *gin.Context, user *model.User) serializer.Response {
+ model.UpdateWebDAVAccountReadonlyByID(service.ID, user.ID, service.Readonly)
+ return serializer.Response{Data: map[string]bool{
+ "readonly": service.Readonly,
+ }}
+}
+
// Accounts 列出WebDAV账号
func (service *WebDAVListService) Accounts(c *gin.Context, user *model.User) serializer.Response {
accounts := model.ListWebDAVAccounts(user.ID)
+ // 查找挂载了存储策略的目录
+ folders := model.GetMountedFolders(user.ID)
+
return serializer.Response{Data: map[string]interface{}{
"accounts": accounts,
+ "folders": serializer.BuildMountedFolderRes(folders, user.Group.PolicyList),
}}
}
diff --git a/service/share/manage.go b/service/share/manage.go
index 9daccdb..af13695 100644
--- a/service/share/manage.go
+++ b/service/share/manage.go
@@ -17,6 +17,7 @@ type ShareCreateService struct {
Password string `json:"password" binding:"max=255"`
RemainDownloads int `json:"downloads"`
Expire int `json:"expire"`
+ Score int `json:"score" binding:"gte=0"`
Preview bool `json:"preview"`
}
@@ -117,6 +118,7 @@ func (service *ShareCreateService) Create(c *gin.Context) serializer.Response {
IsDir: service.IsDir,
UserID: user.ID,
SourceID: sourceID,
+ Score: service.Score,
RemainDownloads: -1,
PreviewEnabled: service.Preview,
SourceName: sourceName,
diff --git a/service/share/visit.go b/service/share/visit.go
index caad060..aaadc9f 100644
--- a/service/share/visit.go
+++ b/service/share/visit.go
@@ -48,6 +48,28 @@ type ShareListService struct {
Keywords string `form:"keywords"`
}
+// ShareReportService 举报分享
+type ShareReportService struct {
+ Reason int `json:"reason" binding:"gte=0,lte=4"`
+ Des string `json:"des"`
+}
+
+// Get 获取给定用户的分享
+func (service *ShareReportService) Report(c *gin.Context) serializer.Response {
+ // 取得分享ID
+ shareID, _ := c.Get("share")
+
+ report := &model.Report{
+ ShareID: shareID.(*model.Share).ID,
+ Reason: service.Reason,
+ Description: service.Des,
+ }
+ if err := report.Create(); err != nil {
+ return serializer.DBErr("Failed to create report record", err)
+ }
+ return serializer.Response{}
+}
+
// Get 获取给定用户的分享
func (service *ShareUserGetService) Get(c *gin.Context) serializer.Response {
// 取得用户
@@ -118,6 +140,8 @@ func (service *ShareListService) List(c *gin.Context, user *model.User) serializ
func (service *ShareGetService) Get(c *gin.Context) serializer.Response {
shareCtx, _ := c.Get("share")
share := shareCtx.(*model.Share)
+ userCtx, _ := c.Get("user")
+ user := userCtx.(*model.User)
// 是否已解锁
unlocked := true
@@ -137,6 +161,11 @@ func (service *ShareGetService) Get(c *gin.Context) serializer.Response {
share.Viewed()
}
+ // 如果已经下载过或者是自己的分享,不需要付积分
+ if share.UserID == user.ID || share.WasDownloadedBy(user, c) {
+ share.Score = 0
+ }
+
return serializer.Response{
Code: 0,
Data: serializer.BuildShareResponse(share, unlocked),
@@ -224,6 +253,39 @@ func (service *Service) CreateDocPreviewSession(c *gin.Context) serializer.Respo
return subService.CreateDocPreviewSession(ctx, c, false)
}
+// SaveToMyFile 将此分享转存到自己的网盘
+func (service *Service) SaveToMyFile(c *gin.Context) serializer.Response {
+ shareCtx, _ := c.Get("share")
+ share := shareCtx.(*model.Share)
+ userCtx, _ := c.Get("user")
+ user := userCtx.(*model.User)
+
+ // 不能转存自己的文件
+ if share.UserID == user.ID {
+ return serializer.Err(serializer.CodeSaveOwnShare, "", nil)
+ }
+
+ // 创建文件系统
+ fs, err := filesystem.NewFileSystem(user)
+ if err != nil {
+ return serializer.Err(serializer.CodeCreateFSError, "", err)
+ }
+ defer fs.Recycle()
+
+ // 重设文件系统处理目标为源文件
+ err = fs.SetTargetByInterface(share.Source())
+ if err != nil {
+ return serializer.Err(serializer.CodeFileNotFound, "", err)
+ }
+
+ err = fs.SaveTo(context.Background(), service.Path)
+ if err != nil {
+ return serializer.Err(serializer.CodeNotSet, err.Error(), err)
+ }
+
+ return serializer.Response{}
+}
+
// List 列出分享的目录下的对象
func (service *Service) List(c *gin.Context) serializer.Response {
shareCtx, _ := c.Get("share")
@@ -378,11 +440,11 @@ func (service *SearchService) Search(c *gin.Context) serializer.Response {
share := shareCtx.(*model.Share)
if !share.IsDir {
- return serializer.ParamErr("此分享无法列目录", nil)
+ return serializer.ParamErr("This is not a shared folder", nil)
}
if service.Path != "" && !path.IsAbs(service.Path) {
- return serializer.ParamErr("路径无效", nil)
+ return serializer.ParamErr("Invalid path", nil)
}
// 创建文件系统
@@ -402,7 +464,7 @@ func (service *SearchService) Search(c *gin.Context) serializer.Response {
if service.Path != "" {
ok, parent := fs.IsPathExist(service.Path)
if !ok {
- return serializer.Err(serializer.CodeParentNotExist, "Cannot find parent folder", nil)
+ return serializer.Err(serializer.CodeParentNotExist, "", nil)
}
fs.Root = parent
diff --git a/service/user/login.go b/service/user/login.go
index 22649dd..a22ce23 100644
--- a/service/user/login.go
+++ b/service/user/login.go
@@ -2,17 +2,19 @@ package user
import (
"fmt"
- model "github.com/cloudreve/Cloudreve/v3/models"
+ "net/url"
+
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
+ "github.com/gofrs/uuid"
+
+ model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
- "github.com/gofrs/uuid"
"github.com/pquerna/otp/totp"
- "net/url"
)
// UserLoginService 管理用户登录的服务
diff --git a/service/user/register.go b/service/user/register.go
index 35e8253..6dccb7b 100644
--- a/service/user/register.go
+++ b/service/user/register.go
@@ -9,6 +9,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
@@ -22,7 +23,22 @@ type UserRegisterService struct {
// Register 新用户注册
func (service *UserRegisterService) Register(c *gin.Context) serializer.Response {
// 相关设定
- options := model.GetSettingByNames("email_active")
+ options := model.GetSettingByNames("email_active", "reg_captcha", "mail_domain_filter", "mail_domain_filter_list")
+
+ // 检查是否在邮件域黑名单里
+ if options["mail_domain_filter"] != "0" {
+ filterList := strings.Split(options["mail_domain_filter_list"], ",")
+ emailSplit := strings.Split(service.UserName, "@")
+ emailDomain := emailSplit[len(emailSplit)-1]
+ inList := util.ContainsString(filterList, emailDomain)
+ domainErr := serializer.Err(serializer.CodeEmailProviderBaned, "Email provider banned", nil)
+ if options["mail_domain_filter"] == "1" && !inList {
+ return domainErr
+ }
+ if options["mail_domain_filter"] == "2" && inList {
+ return domainErr
+ }
+ }
// 相关设定
isEmailRequired := model.IsTrueVal(options["email_active"])
diff --git a/service/user/setting.go b/service/user/setting.go
index 8d7f619..1135697 100644
--- a/service/user/setting.go
+++ b/service/user/setting.go
@@ -8,12 +8,17 @@ import (
"os"
"path/filepath"
"strings"
+ "time"
model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
+ "github.com/cloudreve/Cloudreve/v3/pkg/hashid"
+ "github.com/cloudreve/Cloudreve/v3/pkg/qq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
+ "github.com/samber/lo"
)
// SettingService 通用设置服务
@@ -45,6 +50,14 @@ type ChangerNick struct {
Nick string `json:"nick" binding:"required,min=1,max=255"`
}
+// VIPUnsubscribe 用户组解约服务
+type VIPUnsubscribe struct {
+}
+
+// QQBind QQ互联服务
+type QQBind struct {
+}
+
// PolicyChange 更改存储策略
type PolicyChange struct {
ID string `json:"id" binding:"required"`
@@ -163,6 +176,77 @@ func (service *HomePage) Update(c *gin.Context, user *model.User) serializer.Res
return serializer.Response{}
}
+// Update 更改用户偏好的存储策略
+func (service *PolicyChange) Update(c *gin.Context, user *model.User) serializer.Response {
+ // 取得存储策略的ID
+ rawID, err := hashid.DecodeHashID(service.ID, hashid.PolicyID)
+ if err != nil {
+ return serializer.Err(serializer.CodePolicyNotExist, "", err)
+ }
+
+ // 用户是否可以切换到此存储策略
+ if !util.ContainsUint(user.Group.PolicyList, rawID) {
+ return serializer.Err(serializer.CodePolicyNotAllowed, "", nil)
+ }
+
+ // 查找存储策略
+ if _, err := model.GetPolicyByID(rawID); err != nil {
+ return serializer.Err(serializer.CodePolicyNotAllowed, "", nil)
+ }
+
+ // 切换存储策略
+ user.OptionsSerialized.PreferredPolicy = rawID
+ if err := user.UpdateOptions(); err != nil {
+ return serializer.DBErr("Failed to update user preferences", err)
+ }
+
+ return serializer.Response{}
+}
+
+// Update 绑定或解绑QQ
+func (service *QQBind) Update(c *gin.Context, user *model.User) serializer.Response {
+ // 解除绑定
+ if user.OpenID != "" {
+ // 只通过QQ登录的用户无法解除绑定
+ if strings.HasSuffix(user.Email, "@login.qq.com") {
+ return serializer.Err(serializer.CodeNoPermissionErr, "This user cannot be unlinked", nil)
+ }
+
+ if err := user.Update(map[string]interface{}{"open_id": ""}); err != nil {
+ return serializer.DBErr("Failed to update user open id", err)
+ }
+ return serializer.Response{
+ Data: "",
+ }
+ }
+
+ // 新建绑定
+ res, err := qq.NewLoginRequest()
+ if err != nil {
+ return serializer.Err(serializer.CodeNotSet, "Failed to start QQ login request", err)
+ }
+
+ // 设定QQ登录会话Secret
+ util.SetSession(c, map[string]interface{}{"qq_login_secret": res.SecretKey})
+
+ return serializer.Response{
+ Data: res.URL,
+ }
+}
+
+// Update 用户组解约
+func (service *VIPUnsubscribe) Update(c *gin.Context, user *model.User) serializer.Response {
+ if user.GroupExpires != nil {
+ timeNow := time.Now()
+ if time.Now().Before(*user.GroupExpires) {
+ if err := user.Update(map[string]interface{}{"group_expires": &timeNow}); err != nil {
+ return serializer.DBErr("Failed to update user", err)
+ }
+ }
+ }
+ return serializer.Response{}
+}
+
// Update 更改昵称
func (service *ChangerNick) Update(c *gin.Context, user *model.User) serializer.Response {
if err := user.Update(map[string]interface{}{"nick": service.Nick}); err != nil {
@@ -241,16 +325,72 @@ func (service *SettingListService) ListTasks(c *gin.Context, user *model.User) s
return serializer.BuildTaskList(tasks, total)
}
+// Policy 获取用户存储策略设置
+func (service *SettingService) Policy(c *gin.Context, user *model.User) serializer.Response {
+ // 取得用户可用存储策略
+ available := make([]model.Policy, 0, len(user.Group.PolicyList))
+ for _, id := range user.Group.PolicyList {
+ if policy, err := model.GetPolicyByID(id); err == nil {
+ available = append(available, policy)
+ }
+ }
+
+ return serializer.BuildPolicySettingRes(available)
+}
+
+// Nodes 获取用户可选节点
+func (service *SettingService) Nodes(c *gin.Context, user *model.User) serializer.Response {
+ if !user.Group.OptionsSerialized.SelectNode {
+ return serializer.Err(serializer.CodeGroupNotAllowed, "", nil)
+ }
+
+ availableNodesID := user.Group.OptionsSerialized.AvailableNodes
+
+ // All nodes available
+ if len(availableNodesID) == 0 {
+ nodes, err := model.GetNodesByStatus(model.NodeActive)
+ if err != nil {
+ return serializer.DBErr("Failed to list nodes", err)
+ }
+
+ availableNodesID = lo.Map[model.Node, uint](nodes, func(node model.Node, index int) uint {
+ return node.ID
+ })
+ }
+
+ // 取得用户可用存储策略
+ available := lo.FilterMap[uint, *model.Node](availableNodesID,
+ func(id uint, index int) (*model.Node, bool) {
+ if node := cluster.Default.GetNodeByID(id); node != nil {
+ return node.DBModel(), node.IsActive() && node.IsFeatureEnabled("aria2")
+ }
+
+ return nil, false
+ })
+
+ return serializer.BuildNodeOptionRes(available)
+}
+
// Settings 获取用户设定
func (service *SettingService) Settings(c *gin.Context, user *model.User) serializer.Response {
+ // 用户组有效期
+ var groupExpires *time.Time
+ if user.GroupExpires != nil {
+ if expires := user.GroupExpires.Unix() - time.Now().Unix(); expires > 0 {
+ groupExpires = user.GroupExpires
+ }
+ }
+
return serializer.Response{
Data: map[string]interface{}{
- "uid": user.ID,
- "homepage": !user.OptionsSerialized.ProfileOff,
- "two_factor": user.TwoFactor != "",
- "prefer_theme": user.OptionsSerialized.PreferredTheme,
- "themes": model.GetSettingByName("themes"),
- "authn": serializer.BuildWebAuthnList(user.WebAuthnCredentials()),
+ "uid": user.ID,
+ "qq": user.OpenID != "",
+ "homepage": !user.OptionsSerialized.ProfileOff,
+ "two_factor": user.TwoFactor != "",
+ "prefer_theme": user.OptionsSerialized.PreferredTheme,
+ "themes": model.GetSettingByName("themes"),
+ "group_expires": groupExpires,
+ "authn": serializer.BuildWebAuthnList(user.WebAuthnCredentials()),
},
}
}
diff --git a/service/vas/purchase.go b/service/vas/purchase.go
new file mode 100755
index 0000000..3755b13
--- /dev/null
+++ b/service/vas/purchase.go
@@ -0,0 +1,235 @@
+package vas
+
+import (
+ "encoding/json"
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/payment"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/gin-gonic/gin"
+)
+
+// CreateOrderService 创建订单服务
+type CreateOrderService struct {
+ Action string `json:"action" binding:"required,eq=group|eq=pack|eq=score"`
+ Method string `json:"method" binding:"required,eq=alipay|eq=score|eq=payjs|eq=wechat|eq=custom"`
+ ID int64 `json:"id" binding:"required"`
+ Num int `json:"num" binding:"required,min=1"`
+}
+
+// RedeemService 兑换服务
+type RedeemService struct {
+ Code string `uri:"code" binding:"required,max=64"`
+}
+
+// OrderService 订单查询
+type OrderService struct {
+ ID string `uri:"id" binding:"required"`
+}
+
+// Status 查询订单状态
+func (service *OrderService) Status(c *gin.Context, user *model.User) serializer.Response {
+ order, _ := model.GetOrderByNo(service.ID)
+ if order == nil || order.UserID != user.ID {
+ return serializer.Err(serializer.CodeNotFound, "", nil)
+ }
+
+ return serializer.Response{Data: order.Status}
+}
+
+// Redeem 开始兑换
+func (service *RedeemService) Redeem(c *gin.Context, user *model.User) serializer.Response {
+ redeem, err := model.GetAvailableRedeem(service.Code)
+ if err != nil {
+ return serializer.Err(serializer.CodeInvalidGiftCode, "", err)
+ }
+
+ // 取得当前商品信息
+ packs, groups, err := decodeProductInfo()
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product settings", err)
+ }
+
+ // 查找要购买的商品
+ var (
+ pack *serializer.PackProduct
+ group *serializer.GroupProducts
+ )
+ if redeem.Type == model.GroupOrderType {
+ for _, v := range groups {
+ if v.ID == redeem.ProductID {
+ group = &v
+ break
+ }
+ }
+
+ if group == nil {
+ return serializer.Err(serializer.CodeNotFound, "", err)
+ }
+
+ } else if redeem.Type == model.PackOrderType {
+ for _, v := range packs {
+ if v.ID == redeem.ProductID {
+ pack = &v
+ break
+ }
+ }
+
+ if pack == nil {
+ return serializer.Err(serializer.CodeNotFound, "", err)
+ }
+
+ }
+
+ err = payment.GiveProduct(user, pack, group, redeem.Num)
+ if err != nil {
+ return serializer.Err(serializer.CodeNotSet, "Redeem failed", err)
+ }
+
+ redeem.Use()
+
+ return serializer.Response{}
+
+}
+
+// Query 检查兑换码信息
+func (service *RedeemService) Query(c *gin.Context) serializer.Response {
+ redeem, err := model.GetAvailableRedeem(service.Code)
+ if err != nil {
+ return serializer.Err(serializer.CodeInvalidGiftCode, "", err)
+ }
+
+ var (
+ name = "积分"
+ productTime int64
+ )
+ if redeem.Type != model.ScoreOrderType {
+ packs, groups, err := decodeProductInfo()
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product settings", err)
+ }
+ if redeem.Type == model.GroupOrderType {
+ for _, v := range groups {
+ if v.ID == redeem.ProductID {
+ name = v.Name
+ productTime = v.Time
+ break
+ }
+ }
+ } else {
+ for _, v := range packs {
+ if v.ID == redeem.ProductID {
+ name = v.Name
+ productTime = v.Time
+ break
+ }
+ }
+ }
+
+ if name == "积分" {
+ return serializer.Err(serializer.CodeNotFound, "", err)
+ }
+
+ }
+
+ return serializer.Response{
+ Data: struct {
+ Name string `json:"name"`
+ Type int `json:"type"`
+ Num int `json:"num"`
+ Time int64 `json:"time"`
+ }{
+ name, redeem.Type, redeem.Num, productTime,
+ },
+ }
+}
+
+// Create 创建新订单
+func (service *CreateOrderService) Create(c *gin.Context, user *model.User) serializer.Response {
+ // 取得当前商品信息
+ packs, groups, err := decodeProductInfo()
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product list", err)
+ }
+
+ // 查找要购买的商品
+ var (
+ pack *serializer.PackProduct
+ group *serializer.GroupProducts
+ )
+ if service.Action == "group" {
+ for _, v := range groups {
+ if v.ID == service.ID {
+ group = &v
+ break
+ }
+ }
+ } else if service.Action == "pack" {
+ for _, v := range packs {
+ if v.ID == service.ID {
+ pack = &v
+ break
+ }
+ }
+ }
+
+ // 购买积分
+ if pack == nil && group == nil {
+ if service.Method == "score" {
+ return serializer.ParamErr("Payment method not supported", nil)
+ }
+ }
+
+ // 创建订单
+ res, err := payment.NewOrder(pack, group, service.Num, service.Method, user)
+ if err != nil {
+ return serializer.Err(serializer.CodeNotSet, err.Error(), err)
+ }
+
+ return serializer.Response{Data: res}
+
+}
+
+// Products 获取商品信息
+func (service *GeneralVASService) Products(c *gin.Context, user *model.User) serializer.Response {
+ options := model.GetSettingByNames(
+ "wechat_enabled",
+ "alipay_enabled",
+ "payjs_enabled",
+ "payjs_enabled",
+ "custom_payment_enabled",
+ "custom_payment_name",
+ )
+ scorePrice := model.GetIntSetting("score_price", 0)
+ packs, groups, err := decodeProductInfo()
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to parse product list", err)
+ }
+
+ return serializer.BuildProductResponse(
+ groups,
+ packs,
+ model.IsTrueVal(options["wechat_enabled"]),
+ model.IsTrueVal(options["alipay_enabled"]),
+ model.IsTrueVal(options["payjs_enabled"]),
+ model.IsTrueVal(options["custom_payment_enabled"]),
+ options["custom_payment_name"],
+ scorePrice,
+ )
+}
+
+func decodeProductInfo() ([]serializer.PackProduct, []serializer.GroupProducts, error) {
+ options := model.GetSettingByNames("pack_data", "group_sell_data", "alipay_enabled", "payjs_enabled")
+
+ var (
+ packs []serializer.PackProduct
+ groups []serializer.GroupProducts
+ )
+ if err := json.Unmarshal([]byte(options["pack_data"]), &packs); err != nil {
+ return nil, nil, err
+ }
+ if err := json.Unmarshal([]byte(options["group_sell_data"]), &groups); err != nil {
+ return nil, nil, err
+ }
+
+ return packs, groups, nil
+}
diff --git a/service/vas/qq.go b/service/vas/qq.go
new file mode 100755
index 0000000..8d8795e
--- /dev/null
+++ b/service/vas/qq.go
@@ -0,0 +1,113 @@
+package vas
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/qq"
+ "github.com/cloudreve/Cloudreve/v3/pkg/request"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/cloudreve/Cloudreve/v3/pkg/thumb"
+ "github.com/cloudreve/Cloudreve/v3/pkg/util"
+ "github.com/gin-gonic/gin"
+)
+
+// QQCallbackService QQ互联回调处理服务
+type QQCallbackService struct {
+ Code string `json:"code" binding:"required"`
+ State string `json:"state" binding:"required"`
+}
+
+// Callback 处理QQ互联回调
+func (service *QQCallbackService) Callback(c *gin.Context, user *model.User) serializer.Response {
+
+ state := util.GetSession(c, "qq_login_secret")
+ if stateStr, ok := state.(string); !ok || stateStr != service.State {
+ return serializer.Err(serializer.CodeSignExpired, "", nil)
+ }
+ util.DeleteSession(c, "qq_login_secret")
+
+ // 获取OpenID
+ credential, err := qq.Callback(service.Code)
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to get session status", err)
+ }
+
+ // 如果已登录,则绑定已有用户
+ if user != nil {
+
+ if user.OpenID != "" {
+ return serializer.Err(serializer.CodeQQBindConflict, "", nil)
+ }
+
+ // OpenID 是否重复
+ if _, err := model.GetActiveUserByOpenID(credential.OpenID); err == nil {
+ return serializer.Err(serializer.CodeQQBindOtherAccount, "", nil)
+ }
+
+ if err := user.Update(map[string]interface{}{"open_id": credential.OpenID}); err != nil {
+ return serializer.DBErr("Failed to update user open id", err)
+ }
+ return serializer.Response{
+ Data: "/setting",
+ }
+
+ }
+
+ // 未登录,尝试查找用户
+ if expectedUser, err := model.GetActiveUserByOpenID(credential.OpenID); err == nil {
+ // 用户绑定了此QQ,设定为登录状态
+ util.SetSession(c, map[string]interface{}{
+ "user_id": expectedUser.ID,
+ })
+ res := serializer.BuildUserResponse(expectedUser)
+ res.Code = 203
+ return res
+
+ }
+
+ // 无匹配用户,创建新用户
+ if !model.IsTrueVal(model.GetSettingByName("qq_direct_login")) {
+ return serializer.Err(serializer.CodeQQNotLinked, "", nil)
+ }
+
+ // 获取用户信息
+ userInfo, err := qq.GetUserInfo(credential)
+ if err != nil {
+ return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch user info", err)
+ }
+
+ // 生成邮箱地址
+ fakeEmail := util.RandStringRunes(16) + "@login.qq.com"
+
+ // 创建用户
+ defaultGroup := model.GetIntSetting("default_group", 2)
+
+ newUser := model.NewUser()
+ newUser.Email = fakeEmail
+ newUser.Nick = userInfo.Nick
+ newUser.SetPassword("")
+ newUser.Status = model.Active
+ newUser.GroupID = uint(defaultGroup)
+ newUser.OpenID = credential.OpenID
+ newUser.Avatar = "file"
+
+ // 创建用户
+ if err := model.DB.Create(&newUser).Error; err != nil {
+ return serializer.Err(serializer.CodeEmailExisted, "", err)
+ }
+
+ // 下载头像
+ r := request.NewClient()
+ rawAvatar := r.Request("GET", userInfo.Avatar, nil)
+ if avatar, err := thumb.NewThumbFromFile(rawAvatar.Response.Body, "avatar.jpg"); err == nil {
+ avatar.CreateAvatar(newUser.ID)
+ }
+
+ // 登录
+ util.SetSession(c, map[string]interface{}{"user_id": newUser.ID})
+
+ newUser, _ = model.GetActiveUserByID(newUser.ID)
+
+ res := serializer.BuildUserResponse(newUser)
+ res.Code = 203
+ return res
+}
diff --git a/service/vas/quota.go b/service/vas/quota.go
new file mode 100755
index 0000000..23c1019
--- /dev/null
+++ b/service/vas/quota.go
@@ -0,0 +1,17 @@
+package vas
+
+import (
+ model "github.com/cloudreve/Cloudreve/v3/models"
+ "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
+ "github.com/gin-gonic/gin"
+)
+
+// GeneralVASService 通用增值服务
+type GeneralVASService struct {
+}
+
+// Quota 获取容量配额信息
+func (service *GeneralVASService) Quota(c *gin.Context, user *model.User) serializer.Response {
+ packs := user.GetAvailableStoragePacks()
+ return serializer.BuildUserQuotaResponse(user, packs)
+}