Test: auth middleware for WebDAV
This commit is contained in:
parent
cf90ab5a9a
commit
fd7b6e33c8
4 changed files with 114 additions and 4 deletions
|
@ -52,6 +52,7 @@ func AuthRequired() gin.HandlerFunc {
|
|||
}
|
||||
|
||||
// WebDAVAuth 验证WebDAV登录及权限
|
||||
// TODO 测试
|
||||
func WebDAVAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// OPTIONS 请求不需要鉴权,否则Windows10下无法保存文档
|
||||
|
|
|
@ -90,3 +90,107 @@ func TestSignRequired(t *testing.T) {
|
|||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
}
|
||||
|
||||
func TestWebDAVAuth(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
rec := httptest.NewRecorder()
|
||||
AuthFunc := WebDAVAuth()
|
||||
|
||||
// options请求跳过验证
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("OPTIONS", "/test", nil)
|
||||
AuthFunc(c)
|
||||
}
|
||||
|
||||
// 请求HTTP Basic Auth
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
AuthFunc(c)
|
||||
asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"])
|
||||
}
|
||||
|
||||
// 用户名不存在
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id", "password", "email"}),
|
||||
)
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// 密码错误
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"),
|
||||
)
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
//未启用 WebDAV
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows(
|
||||
[]string{"id", "password", "email", "group_id", "options"}).
|
||||
AddRow(1,
|
||||
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
|
||||
"who@cloudreve.org",
|
||||
1,
|
||||
"{}",
|
||||
),
|
||||
)
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false))
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), http.StatusForbidden)
|
||||
}
|
||||
|
||||
//正常
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("POST", "/test", nil)
|
||||
c.Request.Header = map[string][]string{
|
||||
"Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="},
|
||||
}
|
||||
mock.ExpectQuery("SELECT(.+)users(.+)").
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows(
|
||||
[]string{"id", "password", "email", "group_id", "options"}).
|
||||
AddRow(1,
|
||||
"rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3",
|
||||
"who@cloudreve.org",
|
||||
1,
|
||||
"{}",
|
||||
),
|
||||
)
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true))
|
||||
AuthFunc(c)
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
asserts.Equal(c.Writer.Status(), 200)
|
||||
_, ok := c.Get("user")
|
||||
asserts.True(ok)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
10
pkg/cache/driver.go
vendored
10
pkg/cache/driver.go
vendored
|
@ -25,14 +25,18 @@ func Init() {
|
|||
|
||||
// Driver 键值缓存存储容器
|
||||
type Driver interface {
|
||||
// 设置值
|
||||
// 设置值,ttl为过期时间,单位为秒
|
||||
Set(key string, value interface{}, ttl int) error
|
||||
// 取值
|
||||
|
||||
// 取值,并返回是否成功
|
||||
Get(key string) (interface{}, bool)
|
||||
|
||||
// 批量取值,返回成功取值的map即不存在的值
|
||||
Gets(keys []string, prefix string) (map[string]interface{}, []string)
|
||||
// 批量设置值
|
||||
|
||||
// 批量设置值,所有的key都会加上prefix前缀
|
||||
Sets(values map[string]interface{}, prefix string) error
|
||||
|
||||
// 删除值
|
||||
Delete(keys []string, prefix string) error
|
||||
}
|
||||
|
|
|
@ -604,6 +604,7 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *fil
|
|||
mw := multistatusWriter{w: w}
|
||||
|
||||
walkFn := func(reqPath string, info FileInfo, err error) error {
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -626,7 +627,7 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *fil
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
href := path.Join(h.Prefix, strconv.FormatUint(uint64(fs.User.ID), 10), reqPath)
|
||||
href := path.Join(h.Prefix, reqPath)
|
||||
if href != "/" && info.IsDir() {
|
||||
href += "/"
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue