diff --git a/pkg/filesystem/hook_test.go b/pkg/filesystem/hook_test.go new file mode 100644 index 0000000..3d22e43 --- /dev/null +++ b/pkg/filesystem/hook_test.go @@ -0,0 +1,39 @@ +package filesystem + +import ( + "context" + model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/filesystem/local" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGenericBeforeUpload(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + file := local.FileData{ + Size: 5, + Name: "1.txt", + } + fs := FileSystem{ + User: &model.User{ + Storage: 0, + Group: model.Group{ + MaxStorage: 11, + }, + Policy: model.Policy{ + MaxSize: 4, + OptionsSerialized: model.PolicyOption{ + FileType: []string{"txt"}, + }, + }, + }, + } + + asserts.Error(GenericBeforeUpload(ctx, &fs, file)) + file.Size = 1 + file.Name = "1" + asserts.Error(GenericBeforeUpload(ctx, &fs, file)) + file.Name = "1.txt" + asserts.NoError(GenericBeforeUpload(ctx, &fs, file)) +} diff --git a/pkg/filesystem/local/handler.go b/pkg/filesystem/local/handler.go index 9a0b9d6..8d1dbcb 100644 --- a/pkg/filesystem/local/handler.go +++ b/pkg/filesystem/local/handler.go @@ -2,7 +2,6 @@ package local import ( "context" - "fmt" "github.com/HFO4/cloudreve/pkg/util" "io" "os" @@ -19,7 +18,6 @@ func (handler Handler) Put(ctx context.Context, file io.ReadCloser, dst string) // 如果目标目录不存在,创建 basePath := filepath.Dir(dst) if !util.Exists(basePath) { - fmt.Println("创建", basePath) err := os.MkdirAll(basePath, 0666) if err != nil { return err diff --git a/pkg/filesystem/local/handller_test.go b/pkg/filesystem/local/handller_test.go new file mode 100644 index 0000000..ba1a9d1 --- /dev/null +++ b/pkg/filesystem/local/handller_test.go @@ -0,0 +1,44 @@ +package local + +import ( + "context" + "github.com/HFO4/cloudreve/pkg/util" + "github.com/stretchr/testify/assert" + "io" + "io/ioutil" + "strings" + "testing" +) + +func TestHandler_Put(t *testing.T) { + asserts := assert.New(t) + handler := Handler{} + ctx := context.Background() + + testCases := []struct { + file io.ReadCloser + dst string + err bool + }{ + { + file: ioutil.NopCloser(strings.NewReader("test input file")), + dst: "test/test/txt", + err: false, + }, + { + file: ioutil.NopCloser(strings.NewReader("test input file")), + dst: "notexist:/S.TXT", + err: true, + }, + } + + for _, testCase := range testCases { + err := handler.Put(ctx, testCase.file, testCase.dst) + if testCase.err { + asserts.Error(err) + } else { + asserts.NoError(err) + asserts.True(util.Exists(testCase.dst)) + } + } +} diff --git a/pkg/filesystem/validator.go b/pkg/filesystem/validator.go index 8d58652..06f4c05 100644 --- a/pkg/filesystem/validator.go +++ b/pkg/filesystem/validator.go @@ -27,7 +27,6 @@ func (fs *FileSystem) ValidateExtension(ctx context.Context, fileName string) bo } ext := filepath.Ext(fileName) - // 无扩展名时 if len(ext) == 0 { return false diff --git a/pkg/filesystem/validator_test.go b/pkg/filesystem/validator_test.go new file mode 100644 index 0000000..ed28181 --- /dev/null +++ b/pkg/filesystem/validator_test.go @@ -0,0 +1,90 @@ +package filesystem + +import ( + "context" + "database/sql" + "github.com/DATA-DOG/go-sqlmock" + model "github.com/HFO4/cloudreve/models" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "testing" +) + +var mock sqlmock.Sqlmock + +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + model.DB, _ = gorm.Open("mysql", db) + defer db.Close() + m.Run() +} + +func TestFileSystem_ValidateCapacity(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{ + Storage: 10, + Group: model.Group{ + MaxStorage: 11, + }, + }, + } + + asserts.True(fs.ValidateCapacity(ctx, 1)) + asserts.Equal(uint64(11), fs.User.Storage) + + fs.User.Storage = 5 + asserts.False(fs.ValidateCapacity(ctx, 10)) + asserts.Equal(uint64(5), fs.User.Storage) +} + +func TestFileSystem_ValidateFileSize(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{ + Policy: model.Policy{ + MaxSize: 10, + }, + }, + } + + asserts.True(fs.ValidateFileSize(ctx, 5)) + asserts.True(fs.ValidateFileSize(ctx, 10)) + asserts.False(fs.ValidateFileSize(ctx, 11)) +} + +func TestFileSystem_ValidateExtension(t *testing.T) { + asserts := assert.New(t) + ctx := context.Background() + fs := FileSystem{ + User: &model.User{ + Policy: model.Policy{ + OptionsSerialized: model.PolicyOption{ + FileType: nil, + }, + }, + }, + } + + asserts.True(fs.ValidateExtension(ctx, "1")) + asserts.True(fs.ValidateExtension(ctx, "1.txt")) + + fs.User.Policy.OptionsSerialized.FileType = []string{} + asserts.True(fs.ValidateExtension(ctx, "1")) + asserts.True(fs.ValidateExtension(ctx, "1.txt")) + + fs.User.Policy.OptionsSerialized.FileType = []string{"txt", "jpg"} + asserts.False(fs.ValidateExtension(ctx, "1")) + asserts.False(fs.ValidateExtension(ctx, "1.jpg.png")) + asserts.True(fs.ValidateExtension(ctx, "1.txt")) + asserts.True(fs.ValidateExtension(ctx, "1.png.jpg")) + asserts.False(fs.ValidateExtension(ctx, "1.png")) +}