diff --git a/models/file.go b/models/file.go index d29960f..b23842d 100644 --- a/models/file.go +++ b/models/file.go @@ -35,7 +35,7 @@ func GetFileByPathAndName(path string, name string, uid uint) (File, error) { return file, result.Error } -// GetChildFile 查找目录下子文件 TODO:test +// GetChildFile 查找目录下子文件 func (folder *Folder) GetChildFile() ([]File, error) { var files []File result := DB.Where("folder_id = ?", folder.ID).Find(&files) diff --git a/models/file_test.go b/models/file_test.go index 2213989..2c1be0b 100644 --- a/models/file_test.go +++ b/models/file_test.go @@ -3,6 +3,7 @@ package model import ( "errors" "github.com/DATA-DOG/go-sqlmock" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" ) @@ -41,3 +42,27 @@ func TestFile_Create(t *testing.T) { asserts.NoError(mock.ExpectationsWereMet()) } + +func TestFolder_GetChildFile(t *testing.T) { + asserts := assert.New(t) + folder := &Folder{ + Model: gorm.Model{ + ID: 1, + }, + } + + // 找不到 + mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnError(errors.New("error")) + files, err := folder.GetChildFile() + asserts.Error(err) + asserts.Len(files, 0) + asserts.NoError(mock.ExpectationsWereMet()) + + // 找到了 + mock.ExpectQuery("SELECT(.+)folder_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2)) + files, err = folder.GetChildFile() + asserts.NoError(err) + asserts.Len(files, 2) + asserts.NoError(mock.ExpectationsWereMet()) + +} diff --git a/models/folder.go b/models/folder.go index d539edb..28c56a7 100644 --- a/models/folder.go +++ b/models/folder.go @@ -16,7 +16,7 @@ type Folder struct { PositionAbsolute string `gorm:"size:65536"` } -// Create 创建目录 TODO:test +// Create 创建目录 func (folder *Folder) Create() (uint, error) { if err := DB.Create(folder).Error; err != nil { util.Log().Warning("无法插入目录记录, %s", err) @@ -32,7 +32,7 @@ func GetFolderByPath(path string, uid uint) (Folder, error) { return folder, result.Error } -// GetChildFolder 查找子目录 TODO:test +// GetChildFolder 查找子目录 func (folder *Folder) GetChildFolder() ([]Folder, error) { var folders []Folder result := DB.Where("parent_id = ?", folder.ID).Find(&folders) diff --git a/models/folder_test.go b/models/folder_test.go index efd08b2..45b1e59 100644 --- a/models/folder_test.go +++ b/models/folder_test.go @@ -1,7 +1,9 @@ package model import ( + "errors" "github.com/DATA-DOG/go-sqlmock" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" ) @@ -22,3 +24,51 @@ func TestGetFolderByPath(t *testing.T) { asserts.Error(err) asserts.NoError(mock.ExpectationsWereMet()) } + +func TestFolder_Create(t *testing.T) { + asserts := assert.New(t) + folder := &Folder{ + Name: "new folder", + } + + // 插入成功 + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(5, 1)) + mock.ExpectCommit() + fid, err := folder.Create() + asserts.NoError(err) + asserts.Equal(uint(5), fid) + asserts.NoError(mock.ExpectationsWereMet()) + + // 插入失败 + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + fid, err = folder.Create() + asserts.Error(err) + asserts.Equal(uint(0), fid) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestFolder_GetChildFolder(t *testing.T) { + asserts := assert.New(t) + folder := &Folder{ + Model: gorm.Model{ + ID: 1, + }, + } + + // 找不到 + mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnError(errors.New("error")) + files, err := folder.GetChildFolder() + asserts.Error(err) + asserts.Len(files, 0) + asserts.NoError(mock.ExpectationsWereMet()) + + // 找到了 + mock.ExpectQuery("SELECT(.+)parent_id(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"name", "id"}).AddRow("1.txt", 1).AddRow("2.txt", 2)) + files, err = folder.GetChildFolder() + asserts.NoError(err) + asserts.Len(files, 2) + asserts.NoError(mock.ExpectationsWereMet()) +} diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index b9b1604..598f323 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -2,6 +2,7 @@ package filesystem import ( "context" + "errors" "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem/local" "github.com/gin-gonic/gin" @@ -78,7 +79,10 @@ func NewFileSystem(user *model.User) (*FileSystem, error) { // NewFileSystemFromContext 从gin.Context创建文件系统 // TODO:test func NewFileSystemFromContext(c *gin.Context) (*FileSystem, error) { - user, _ := c.Get("user") + user, exist := c.Get("user") + if !exist { + return nil, errors.New("无法找到用户") + } fs, err := NewFileSystem(user.(*model.User)) return fs, err } diff --git a/pkg/filesystem/filesystem_test.go b/pkg/filesystem/filesystem_test.go index 3a9a4bb..fdd123d 100644 --- a/pkg/filesystem/filesystem_test.go +++ b/pkg/filesystem/filesystem_test.go @@ -2,7 +2,9 @@ package filesystem import ( model "github.com/HFO4/cloudreve/models" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "net/http/httptest" "testing" ) @@ -23,3 +25,21 @@ func TestNewFileSystem(t *testing.T) { fs, err = NewFileSystem(&user) asserts.Error(err) } + +func TestNewFileSystemFromContext(t *testing.T) { + asserts := assert.New(t) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Set("user", &model.User{ + Policy: model.Policy{ + Type: "local", + }, + }) + fs, err := NewFileSystemFromContext(c) + asserts.NotNil(fs) + asserts.NoError(err) + + c, _ = gin.CreateTestContext(httptest.NewRecorder()) + fs, err = NewFileSystemFromContext(c) + asserts.Nil(fs) + asserts.Error(err) +} diff --git a/pkg/filesystem/path.go b/pkg/filesystem/path.go index cd73f73..0e9cdfb 100644 --- a/pkg/filesystem/path.go +++ b/pkg/filesystem/path.go @@ -19,9 +19,13 @@ type Object struct { Pic string `json:"pic"` Size uint64 `json:"size"` Type string `json:"type"` + Date string `json:"date"` } -// List 列出路径下的内容 +// List 列出路径下的内容, +// pathProcessor为最终对象路径的处理钩子。 +// 有些情况下(如在分享页面列对象)时, +// 路径需要截取掉被分享目录路径之前的部分。 func (fs *FileSystem) List(ctx context.Context, path string, pathProcessor func(string) string) ([]Object, error) { // 获取父目录 isExist, folder := fs.IsPathExist(path) @@ -58,6 +62,7 @@ func (fs *FileSystem) List(ctx context.Context, path string, pathProcessor func( Pic: "", Size: 0, Type: "dir", + Date: folder.CreatedAt.Format("2006-01-02 15:04:05"), }) } @@ -77,13 +82,14 @@ func (fs *FileSystem) List(ctx context.Context, path string, pathProcessor func( Pic: file.PicInfo, Size: file.Size, Type: "file", + Date: file.CreatedAt.Format("2006-01-02 15:04:05"), }) } return objects, nil } -// CreateDirectory 在`base`路径下创建名为`dir`的目录 TODO: test +// CreateDirectory 根据给定的完整创建目录,不会递归创建 func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) error { // 获取要创建目录的父路径和目录名 fullPath = path.Clean(fullPath) diff --git a/pkg/filesystem/path_test.go b/pkg/filesystem/path_test.go index 8e83f41..58aa521 100644 --- a/pkg/filesystem/path_test.go +++ b/pkg/filesystem/path_test.go @@ -1,6 +1,7 @@ package filesystem import ( + "context" "github.com/DATA-DOG/go-sqlmock" model "github.com/HFO4/cloudreve/models" "github.com/jinzhu/gorm" @@ -58,3 +59,107 @@ func TestFileSystem_IsPathExist(t *testing.T) { asserts.False(testResult) asserts.NoError(mock.ExpectationsWereMet()) } + +func TestFileSystem_List(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{User: &model.User{ + Model: gorm.Model{ + ID: 1, + }, + }} + ctx := context.Background() + + // 成功,子目录包含文件和路径,不使用路径处理钩子 + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "folder")) + mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_folder1").AddRow(7, "sub_folder2")) + mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(6, "sub_file1.txt").AddRow(7, "sub_file2.txt")) + objects, err := fs.List(ctx, "/folder", nil) + asserts.Len(objects, 4) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + + // 成功,子目录包含文件和路径,使用路径处理钩子 + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "folder")) + mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"}).AddRow(6, "sub_folder1", "/folder").AddRow(7, "sub_folder2", "/folder")) + mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder")) + objects, err = fs.List(ctx, "/folder", func(s string) string { + return "prefix" + s + }) + asserts.Len(objects, 4) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + for _, value := range objects { + asserts.Contains(value.Path, "prefix/") + } + + // 成功,子目录包含路径,使用路径处理钩子 + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "folder")) + mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"})) + mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"}).AddRow(6, "sub_file1.txt", "/folder").AddRow(7, "sub_file2.txt", "/folder")) + objects, err = fs.List(ctx, "/folder", func(s string) string { + return "prefix" + s + }) + asserts.Len(objects, 2) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + for _, value := range objects { + asserts.Contains(value.Path, "prefix/") + } + + // 成功,子目录下为空,使用路径处理钩子 + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(5, "folder")) + mock.ExpectQuery("SELECT(.+)folder(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "position"})) + mock.ExpectQuery("SELECT(.+)file(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "dir"})) + objects, err = fs.List(ctx, "/folder", func(s string) string { + return "prefix" + s + }) + asserts.Len(objects, 0) + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) + + // 成功,子目录路径不存在 + mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + + objects, err = fs.List(ctx, "/folder", func(s string) string { + return "prefix" + s + }) + asserts.Len(objects, 0) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestFileSystem_CreateDirectory(t *testing.T) { + asserts := assert.New(t) + fs := &FileSystem{User: &model.User{ + Model: gorm.Model{ + ID: 1, + }, + }} + ctx := context.Background() + + // 目录名非法 + err := fs.CreateDirectory(ctx, "/ad/a+?") + asserts.Equal(ErrIllegalObjectName, err) + + // 父目录不存在 + mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + err = fs.CreateDirectory(ctx, "/ad/ab") + asserts.Equal(ErrPathNotExist, err) + asserts.NoError(mock.ExpectationsWereMet()) + + // 存在同名文件 + mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "ab")) + mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "ab")) + err = fs.CreateDirectory(ctx, "/ad/ab") + asserts.Equal(ErrFileExisted, err) + asserts.NoError(mock.ExpectationsWereMet()) + + // 成功创建 + mock.ExpectQuery("SELECT(.+)folders").WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "ab")) + mock.ExpectQuery("SELECT(.+)files").WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + err = fs.CreateDirectory(ctx, "/ad/ab") + asserts.NoError(err) + asserts.NoError(mock.ExpectationsWereMet()) +} diff --git a/routers/controllers/directory.go b/routers/controllers/directory.go index efd8702..9486840 100644 --- a/routers/controllers/directory.go +++ b/routers/controllers/directory.go @@ -19,7 +19,7 @@ func CreateDirectory(c *gin.Context) { // ListDirectory 列出目录下内容 func ListDirectory(c *gin.Context) { var service explorer.DirectoryService - if err := c.ShouldBindJSON(&service); err == nil { + if err := c.ShouldBindQuery(&service); err == nil { res := service.ListDirectory(c) c.JSON(200, res) } else { diff --git a/routers/router_test.go b/routers/router_test.go index 749c5a5..f48dc8b 100644 --- a/routers/router_test.go +++ b/routers/router_test.go @@ -258,3 +258,41 @@ func TestSiteConfigRoute(t *testing.T) { }, }).UpdateColumn("name", "siteName") } + +func TestListDirectoryRoute(t *testing.T) { + switchToMemDB() + asserts := assert.New(t) + router := InitRouter() + w := httptest.NewRecorder() + + // 成功 + req, _ := http.NewRequest( + "GET", + "/api/v3/directory?path=/", + nil, + ) + middleware.SessionMock = map[string]interface{}{"user_id": 1} + router.ServeHTTP(w, req) + asserts.Equal(200, w.Code) + resJSON := &serializer.Response{} + err := json.Unmarshal(w.Body.Bytes(), resJSON) + asserts.NoError(err) + asserts.Equal(0, resJSON.Code) + + w.Body.Reset() + + // 缺少参数 + req, _ = http.NewRequest( + "GET", + "/api/v3/directory", + nil, + ) + middleware.SessionMock = map[string]interface{}{"user_id": 1} + router.ServeHTTP(w, req) + asserts.Equal(200, w.Code) + resJSON = &serializer.Response{} + err = json.Unmarshal(w.Body.Bytes(), resJSON) + asserts.NoError(err) + asserts.NotEqual(0, resJSON.Code) + +}