diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 79fedd3..5f89f93 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "github.com/DATA-DOG/go-sqlmock" "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/auth" "github.com/HFO4/cloudreve/pkg/util" "github.com/gin-gonic/gin" "github.com/jinzhu/gorm" @@ -76,3 +77,16 @@ func TestAuthRequired(t *testing.T) { AuthRequiredFunc(c) asserts.NotNil(c) } + +func TestSignRequired(t *testing.T) { + asserts := assert.New(t) + auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("GET", "/test", nil) + SignRequiredFunc := SignRequired() + + // 鉴权失败 + SignRequiredFunc(c) + asserts.NotNil(c) +} diff --git a/models/file.go b/models/file.go index 573ff16..fe19671 100644 --- a/models/file.go +++ b/models/file.go @@ -54,7 +54,6 @@ func (folder *Folder) GetChildFiles() ([]File, error) { // GetFilesByIDs 根据文件ID批量获取文件, // UID为0表示忽略用户,只根据文件ID检索 -// TODO 测试 func GetFilesByIDs(ids []uint, uid uint) ([]File, error) { var files []File var result *gorm.DB diff --git a/models/file_test.go b/models/file_test.go index 7bfd2c4..99d5625 100644 --- a/models/file_test.go +++ b/models/file_test.go @@ -106,6 +106,17 @@ func TestGetFilesByIDs(t *testing.T) { asserts.NoError(err) asserts.Len(folders, 1) } + + // 忽略UID查找 + { + mock.ExpectQuery("SELECT(.+)"). + WithArgs(1, 2, 3). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1")) + folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.NoError(err) + asserts.Len(folders, 1) + } } func TestGetChildFilesOfFolders(t *testing.T) { diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index a000f70..11cbc27 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -22,17 +22,18 @@ type Auth interface { Check(body string, sign string) error } -// SignURI 对URI进行签名 +// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证 // TODO 测试 func SignURI(uri string, expires int64) (*url.URL, error) { - // 生成签名 - sign := General.Sign(uri, expires) - - // 将签名加到URI中 base, err := url.Parse(uri) if err != nil { return nil, err } + + // 生成签名 + sign := General.Sign(base.Path, expires) + + // 将签名加到URI中 queries := base.Query() queries.Set("sign", sign) base.RawQuery = queries.Encode() @@ -47,9 +48,8 @@ func CheckURI(url *url.URL) error { sign := queries.Get("sign") queries.Del("sign") url.RawQuery = queries.Encode() - requestURI := url.RequestURI() - return General.Check(requestURI, sign) + return General.Check(url.Path, sign) } // Init 初始化通用鉴权器 diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go new file mode 100644 index 0000000..9be8eb6 --- /dev/null +++ b/pkg/auth/auth_test.go @@ -0,0 +1,48 @@ +package auth + +import ( + "github.com/HFO4/cloudreve/pkg/util" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestSignURI(t *testing.T) { + asserts := assert.New(t) + General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + + // 成功 + { + sign, err := SignURI("/api/v3/something?id=1", 0) + asserts.NoError(err) + queries := sign.Query() + asserts.Equal("1", queries.Get("id")) + asserts.NotEmpty(queries.Get("sign")) + } + + // URI解码失败 + { + sign, err := SignURI("://dg.;'f]gh./'", 0) + asserts.Error(err) + asserts.Nil(sign) + } +} + +func TestCheckURI(t *testing.T) { + asserts := assert.New(t) + General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + + // 成功 + { + sign, err := SignURI("/api/ok?if=sdf&fd=go", time.Now().Unix()+10) + asserts.NoError(err) + asserts.NoError(CheckURI(sign)) + } + + // 过期 + { + sign, err := SignURI("/api/ok?if=sdf&fd=go", time.Now().Unix()-1) + asserts.NoError(err) + asserts.Error(CheckURI(sign)) + } +} diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 7bd7868..8a59e92 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -6,15 +6,13 @@ import ( ) // Store 缓存存储器 -var Store Driver +var Store Driver = NewMemoStore() // Init 初始化缓存 func Init() { //Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") //return - if conf.RedisConfig.Server == "" || gin.Mode() == gin.TestMode { - Store = NewMemoStore() - } else { + if conf.RedisConfig.Server != "" && gin.Mode() == gin.TestMode { Store = NewRedisStore( 10, "tcp",