diff --git a/middleware/auth.go b/middleware/auth.go index e5da46e..ce1c102 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -52,6 +52,7 @@ func AuthRequired() gin.HandlerFunc { } // WebDAVAuth 验证WebDAV登录及权限 +// TODO 测试 func WebDAVAuth() gin.HandlerFunc { return func(c *gin.Context) { // OPTIONS 请求不需要鉴权,否则Windows10下无法保存文档 diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 5f89f93..91d43f5 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -90,3 +90,107 @@ func TestSignRequired(t *testing.T) { SignRequiredFunc(c) asserts.NotNil(c) } + +func TestWebDAVAuth(t *testing.T) { + asserts := assert.New(t) + rec := httptest.NewRecorder() + AuthFunc := WebDAVAuth() + + // options请求跳过验证 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("OPTIONS", "/test", nil) + AuthFunc(c) + } + + // 请求HTTP Basic Auth + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("POST", "/test", nil) + AuthFunc(c) + asserts.NotEmpty(c.Writer.Header()["WWW-Authenticate"]) + } + + // 用户名不存在 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("POST", "/test", nil) + c.Request.Header = map[string][]string{ + "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, + } + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "password", "email"}), + ) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(c.Writer.Status(), http.StatusUnauthorized) + } + + // 密码错误 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("POST", "/test", nil) + c.Request.Header = map[string][]string{ + "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, + } + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows( + sqlmock.NewRows([]string{"id", "password", "email", "options"}).AddRow(1, "123", "who@cloudreve.org", "{}"), + ) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(c.Writer.Status(), http.StatusUnauthorized) + } + + //未启用 WebDAV + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("POST", "/test", nil) + c.Request.Header = map[string][]string{ + "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, + } + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows( + sqlmock.NewRows( + []string{"id", "password", "email", "group_id", "options"}). + AddRow(1, + "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3", + "who@cloudreve.org", + 1, + "{}", + ), + ) + mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, false)) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(c.Writer.Status(), http.StatusForbidden) + } + + //正常 + { + c, _ := gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("POST", "/test", nil) + c.Request.Header = map[string][]string{ + "Authorization": {"Basic d2hvQGNsb3VkcmV2ZS5vcmc6YWRtaW4="}, + } + mock.ExpectQuery("SELECT(.+)users(.+)"). + WillReturnRows( + sqlmock.NewRows( + []string{"id", "password", "email", "group_id", "options"}). + AddRow(1, + "rfBd67ti3SMtYvSg:ce6dc7bca4f17f2660e18e7608686673eae0fdf3", + "who@cloudreve.org", + 1, + "{}", + ), + ) + mock.ExpectQuery("SELECT(.+)groups(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "web_dav_enabled"}).AddRow(1, true)) + AuthFunc(c) + asserts.NoError(mock.ExpectationsWereMet()) + asserts.Equal(c.Writer.Status(), 200) + _, ok := c.Get("user") + asserts.True(ok) + } + +} diff --git a/pkg/cache/driver.go b/pkg/cache/driver.go index 1e853f8..0226531 100644 --- a/pkg/cache/driver.go +++ b/pkg/cache/driver.go @@ -25,14 +25,18 @@ func Init() { // Driver 键值缓存存储容器 type Driver interface { - // 设置值 + // 设置值,ttl为过期时间,单位为秒 Set(key string, value interface{}, ttl int) error - // 取值 + + // 取值,并返回是否成功 Get(key string) (interface{}, bool) + // 批量取值,返回成功取值的map即不存在的值 Gets(keys []string, prefix string) (map[string]interface{}, []string) - // 批量设置值 + + // 批量设置值,所有的key都会加上prefix前缀 Sets(values map[string]interface{}, prefix string) error + // 删除值 Delete(keys []string, prefix string) error } diff --git a/pkg/webdav/webdav.go b/pkg/webdav/webdav.go index 08e11b1..94638fb 100644 --- a/pkg/webdav/webdav.go +++ b/pkg/webdav/webdav.go @@ -604,6 +604,7 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *fil mw := multistatusWriter{w: w} walkFn := func(reqPath string, info FileInfo, err error) error { + if err != nil { return err } @@ -626,7 +627,7 @@ func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request, fs *fil if err != nil { return err } - href := path.Join(h.Prefix, strconv.FormatUint(uint64(fs.User.ID), 10), reqPath) + href := path.Join(h.Prefix, reqPath) if href != "/" && info.IsDir() { href += "/" }