Test: signRequired middleware
This commit is contained in:
parent
297b507ca7
commit
9f26c0c8ab
6 changed files with 82 additions and 12 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
@ -76,3 +77,16 @@ func TestAuthRequired(t *testing.T) {
|
|||
AuthRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
}
|
||||
|
||||
func TestSignRequired(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
SignRequiredFunc := SignRequired()
|
||||
|
||||
// 鉴权失败
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
}
|
||||
|
|
|
@ -54,7 +54,6 @@ func (folder *Folder) GetChildFiles() ([]File, error) {
|
|||
|
||||
// GetFilesByIDs 根据文件ID批量获取文件,
|
||||
// UID为0表示忽略用户,只根据文件ID检索
|
||||
// TODO 测试
|
||||
func GetFilesByIDs(ids []uint, uid uint) ([]File, error) {
|
||||
var files []File
|
||||
var result *gorm.DB
|
||||
|
|
|
@ -106,6 +106,17 @@ func TestGetFilesByIDs(t *testing.T) {
|
|||
asserts.NoError(err)
|
||||
asserts.Len(folders, 1)
|
||||
}
|
||||
|
||||
// 忽略UID查找
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)").
|
||||
WithArgs(1, 2, 3).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "1"))
|
||||
folders, err := GetFilesByIDs([]uint{1, 2, 3}, 0)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.NoError(err)
|
||||
asserts.Len(folders, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChildFilesOfFolders(t *testing.T) {
|
||||
|
|
|
@ -22,17 +22,18 @@ type Auth interface {
|
|||
Check(body string, sign string) error
|
||||
}
|
||||
|
||||
// SignURI 对URI进行签名
|
||||
// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||
// TODO 测试
|
||||
func SignURI(uri string, expires int64) (*url.URL, error) {
|
||||
// 生成签名
|
||||
sign := General.Sign(uri, expires)
|
||||
|
||||
// 将签名加到URI中
|
||||
base, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := General.Sign(base.Path, expires)
|
||||
|
||||
// 将签名加到URI中
|
||||
queries := base.Query()
|
||||
queries.Set("sign", sign)
|
||||
base.RawQuery = queries.Encode()
|
||||
|
@ -47,9 +48,8 @@ func CheckURI(url *url.URL) error {
|
|||
sign := queries.Get("sign")
|
||||
queries.Del("sign")
|
||||
url.RawQuery = queries.Encode()
|
||||
requestURI := url.RequestURI()
|
||||
|
||||
return General.Check(requestURI, sign)
|
||||
return General.Check(url.Path, sign)
|
||||
}
|
||||
|
||||
// Init 初始化通用鉴权器
|
||||
|
|
48
pkg/auth/auth_test.go
Normal file
48
pkg/auth/auth_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSignURI(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 成功
|
||||
{
|
||||
sign, err := SignURI("/api/v3/something?id=1", 0)
|
||||
asserts.NoError(err)
|
||||
queries := sign.Query()
|
||||
asserts.Equal("1", queries.Get("id"))
|
||||
asserts.NotEmpty(queries.Get("sign"))
|
||||
}
|
||||
|
||||
// URI解码失败
|
||||
{
|
||||
sign, err := SignURI("://dg.;'f]gh./'", 0)
|
||||
asserts.Error(err)
|
||||
asserts.Nil(sign)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckURI(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
|
||||
// 成功
|
||||
{
|
||||
sign, err := SignURI("/api/ok?if=sdf&fd=go", time.Now().Unix()+10)
|
||||
asserts.NoError(err)
|
||||
asserts.NoError(CheckURI(sign))
|
||||
}
|
||||
|
||||
// 过期
|
||||
{
|
||||
sign, err := SignURI("/api/ok?if=sdf&fd=go", time.Now().Unix()-1)
|
||||
asserts.NoError(err)
|
||||
asserts.Error(CheckURI(sign))
|
||||
}
|
||||
}
|
6
pkg/cache/driver.go
vendored
6
pkg/cache/driver.go
vendored
|
@ -6,15 +6,13 @@ import (
|
|||
)
|
||||
|
||||
// Store 缓存存储器
|
||||
var Store Driver
|
||||
var Store Driver = NewMemoStore()
|
||||
|
||||
// Init 初始化缓存
|
||||
func Init() {
|
||||
//Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0")
|
||||
//return
|
||||
if conf.RedisConfig.Server == "" || gin.Mode() == gin.TestMode {
|
||||
Store = NewMemoStore()
|
||||
} else {
|
||||
if conf.RedisConfig.Server != "" && gin.Mode() == gin.TestMode {
|
||||
Store = NewRedisStore(
|
||||
10,
|
||||
"tcp",
|
||||
|
|
Loading…
Reference in a new issue