Modify: add general ReaderCloserSeeker interface for handler GET method to return

This commit is contained in:
HFO4 2019-12-13 20:54:28 +08:00
parent f262caf1f5
commit 03dcd9a9e0
6 changed files with 19 additions and 25 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
model "github.com/HFO4/cloudreve/models"
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/filesystem/response"
"github.com/HFO4/cloudreve/pkg/serializer"
"github.com/HFO4/cloudreve/pkg/util"
"github.com/juju/ratelimit"
@ -18,7 +19,7 @@ import (
// 限速后的ReaderSeeker
type lrs struct {
io.ReadSeeker
response.RSCloser
r io.Reader
}
@ -27,7 +28,7 @@ func (r lrs) Read(p []byte) (int, error) {
}
// withSpeedLimit 给原有的ReadSeeker加上限速
func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker {
func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
// 如果用户组有速度限制就返回限制流速的ReaderSeeker
if fs.User.Group.SpeedLimit != 0 {
speed := fs.User.Group.SpeedLimit
@ -63,7 +64,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model
}
// GetPhysicalFileContent 根据文件物理路径获取文件流
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) {
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
// 重设上传策略
fs.Policy = &model.Policy{Type: "local"}
_ = fs.dispatchHandler()
@ -78,7 +79,7 @@ func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (
}
// GetDownloadContent 获取用于下载的文件流
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) {
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (response.RSCloser, error) {
// 获取原始文件流
rs, err := fs.GetContent(ctx, path)
if err != nil {
@ -91,7 +92,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R
}
// GetContent 获取文件内容path为虚拟路径
func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) {
func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) {
// 触发`下载前`钩子
err := fs.Trigger(ctx, fs.BeforeFileDownload)
if err != nil {

View file

@ -27,7 +27,7 @@ type Handler interface {
// 删除一个或多个文件
Delete(ctx context.Context, files []string) ([]string, error)
// 获取文件
Get(ctx context.Context, path string) (io.ReadSeeker, error)
Get(ctx context.Context, path string) (response.RSCloser, error)
// 获取缩略图
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
// 获取外链地址url

View file

@ -25,7 +25,7 @@ type Handler struct {
}
// Get 获取文件内容
func (handler Handler) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 打开文件
file, err := os.Open(path)
if err != nil {

View file

@ -10,3 +10,9 @@ type ContentResponse struct {
Content io.ReadSeeker
URL string
}
// 存储策略适配器返回的文件流有些策略需要带有Closer
type RSCloser interface {
io.ReadSeeker
io.Closer
}

View file

@ -22,9 +22,9 @@ type FileHeaderMock struct {
testMock.Mock
}
func (m FileHeaderMock) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) {
args := m.Called(ctx, path)
return args.Get(0).(io.ReadSeeker), args.Error(1)
return args.Get(0).(response.RSCloser), args.Error(1)
}
func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {

View file

@ -8,7 +8,6 @@ import (
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
"github.com/HFO4/cloudreve/pkg/serializer"
"github.com/gin-gonic/gin"
"io"
"net/http"
"time"
)
@ -45,6 +44,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
// 获取文件流
rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
defer rs.Close()
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
@ -58,11 +58,6 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
c.Header("Content-Type", "application/zip")
http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
// 检查是否需要关闭文件
if fc, ok := rs.(io.Closer); ok {
err = fc.Close()
}
return serializer.Response{
Code: 0,
}
@ -84,6 +79,7 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
// 获取文件流
rs, err := fs.GetDownloadContent(ctx, "")
defer rs.Close()
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
@ -91,11 +87,6 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
// 发送文件
http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
// 检查是否需要关闭文件
if fc, ok := rs.(io.Closer); ok {
defer fc.Close()
}
return serializer.Response{
Code: 0,
}
@ -139,6 +130,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
// 开始处理下载
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
rs, err := fs.GetDownloadContent(ctx, "")
defer rs.Close()
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
@ -154,11 +146,6 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
// 发送文件
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
// 检查是否需要关闭文件
if fc, ok := rs.(io.Closer); ok {
defer fc.Close()
}
return serializer.Response{
Code: 0,
}