From 03dcd9a9e0558c7b003717297fd11a8dffc79f2e Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Fri, 13 Dec 2019 20:54:28 +0800 Subject: [PATCH] Modify: add general ReaderCloserSeeker interface for handler GET method to return --- pkg/filesystem/file.go | 11 ++++++----- pkg/filesystem/filesystem.go | 2 +- pkg/filesystem/local/handler.go | 2 +- pkg/filesystem/response/common.go | 6 ++++++ pkg/filesystem/upload_test.go | 4 ++-- service/explorer/file.go | 19 +++---------------- 6 files changed, 19 insertions(+), 25 deletions(-) diff --git a/pkg/filesystem/file.go b/pkg/filesystem/file.go index 4cabbc7..bd53422 100644 --- a/pkg/filesystem/file.go +++ b/pkg/filesystem/file.go @@ -4,6 +4,7 @@ import ( "context" model "github.com/HFO4/cloudreve/models" "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" + "github.com/HFO4/cloudreve/pkg/filesystem/response" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/HFO4/cloudreve/pkg/util" "github.com/juju/ratelimit" @@ -18,7 +19,7 @@ import ( // 限速后的ReaderSeeker type lrs struct { - io.ReadSeeker + response.RSCloser r io.Reader } @@ -27,7 +28,7 @@ func (r lrs) Read(p []byte) (int, error) { } // withSpeedLimit 给原有的ReadSeeker加上限速 -func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker { +func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser { // 如果用户组有速度限制,就返回限制流速的ReaderSeeker if fs.User.Group.SpeedLimit != 0 { speed := fs.User.Group.SpeedLimit @@ -63,7 +64,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model } // GetPhysicalFileContent 根据文件物理路径获取文件流 -func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) { +func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) { // 重设上传策略 fs.Policy = &model.Policy{Type: "local"} _ = fs.dispatchHandler() @@ -78,7 +79,7 @@ func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) ( } // GetDownloadContent 获取用于下载的文件流 -func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) { +func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (response.RSCloser, error) { // 获取原始文件流 rs, err := fs.GetContent(ctx, path) if err != nil { @@ -91,7 +92,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R } // GetContent 获取文件内容,path为虚拟路径 -func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) { +func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) { // 触发`下载前`钩子 err := fs.Trigger(ctx, fs.BeforeFileDownload) if err != nil { diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 8428557..08a8490 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -27,7 +27,7 @@ type Handler interface { // 删除一个或多个文件 Delete(ctx context.Context, files []string) ([]string, error) // 获取文件 - Get(ctx context.Context, path string) (io.ReadSeeker, error) + Get(ctx context.Context, path string) (response.RSCloser, error) // 获取缩略图 Thumb(ctx context.Context, path string) (*response.ContentResponse, error) // 获取外链地址,url diff --git a/pkg/filesystem/local/handler.go b/pkg/filesystem/local/handler.go index 1019c42..db5cdd5 100644 --- a/pkg/filesystem/local/handler.go +++ b/pkg/filesystem/local/handler.go @@ -25,7 +25,7 @@ type Handler struct { } // Get 获取文件内容 -func (handler Handler) Get(ctx context.Context, path string) (io.ReadSeeker, error) { +func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) { // 打开文件 file, err := os.Open(path) if err != nil { diff --git a/pkg/filesystem/response/common.go b/pkg/filesystem/response/common.go index 5ef7d54..fe95a6c 100644 --- a/pkg/filesystem/response/common.go +++ b/pkg/filesystem/response/common.go @@ -10,3 +10,9 @@ type ContentResponse struct { Content io.ReadSeeker URL string } + +// 存储策略适配器返回的文件流,有些策略需要带有Closer +type RSCloser interface { + io.ReadSeeker + io.Closer +} diff --git a/pkg/filesystem/upload_test.go b/pkg/filesystem/upload_test.go index 50fa48b..4a462ca 100644 --- a/pkg/filesystem/upload_test.go +++ b/pkg/filesystem/upload_test.go @@ -22,9 +22,9 @@ type FileHeaderMock struct { testMock.Mock } -func (m FileHeaderMock) Get(ctx context.Context, path string) (io.ReadSeeker, error) { +func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) { args := m.Called(ctx, path) - return args.Get(0).(io.ReadSeeker), args.Error(1) + return args.Get(0).(response.RSCloser), args.Error(1) } func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { diff --git a/service/explorer/file.go b/service/explorer/file.go index 970a18d..a51a6c0 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -8,7 +8,6 @@ import ( "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/serializer" "github.com/gin-gonic/gin" - "io" "net/http" "time" ) @@ -45,6 +44,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con // 获取文件流 rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string)) + defer rs.Close() if err != nil { return serializer.Err(serializer.CodeNotSet, err.Error(), err) } @@ -58,11 +58,6 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con c.Header("Content-Type", "application/zip") http.ServeContent(c.Writer, c.Request, "", time.Now(), rs) - // 检查是否需要关闭文件 - if fc, ok := rs.(io.Closer); ok { - err = fc.Close() - } - return serializer.Response{ Code: 0, } @@ -84,6 +79,7 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con // 获取文件流 rs, err := fs.GetDownloadContent(ctx, "") + defer rs.Close() if err != nil { return serializer.Err(serializer.CodeNotSet, err.Error(), err) } @@ -91,11 +87,6 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con // 发送文件 http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs) - // 检查是否需要关闭文件 - if fc, ok := rs.(io.Closer); ok { - defer fc.Close() - } - return serializer.Response{ Code: 0, } @@ -139,6 +130,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se // 开始处理下载 ctx = context.WithValue(ctx, fsctx.GinCtx, c) rs, err := fs.GetDownloadContent(ctx, "") + defer rs.Close() if err != nil { return serializer.Err(serializer.CodeNotSet, err.Error(), err) } @@ -154,11 +146,6 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se // 发送文件 http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs) - // 检查是否需要关闭文件 - if fc, ok := rs.(io.Closer); ok { - defer fc.Close() - } - return serializer.Response{ Code: 0, }