Modify: add general ReaderCloserSeeker interface for handler GET method to return
This commit is contained in:
parent
f262caf1f5
commit
03dcd9a9e0
6 changed files with 19 additions and 25 deletions
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||
"github.com/HFO4/cloudreve/pkg/filesystem/response"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/juju/ratelimit"
|
||||
|
@ -18,7 +19,7 @@ import (
|
|||
|
||||
// 限速后的ReaderSeeker
|
||||
type lrs struct {
|
||||
io.ReadSeeker
|
||||
response.RSCloser
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
|
@ -27,7 +28,7 @@ func (r lrs) Read(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
// withSpeedLimit 给原有的ReadSeeker加上限速
|
||||
func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker {
|
||||
func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
|
||||
// 如果用户组有速度限制,就返回限制流速的ReaderSeeker
|
||||
if fs.User.Group.SpeedLimit != 0 {
|
||||
speed := fs.User.Group.SpeedLimit
|
||||
|
@ -63,7 +64,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model
|
|||
}
|
||||
|
||||
// GetPhysicalFileContent 根据文件物理路径获取文件流
|
||||
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) {
|
||||
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 重设上传策略
|
||||
fs.Policy = &model.Policy{Type: "local"}
|
||||
_ = fs.dispatchHandler()
|
||||
|
@ -78,7 +79,7 @@ func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (
|
|||
}
|
||||
|
||||
// GetDownloadContent 获取用于下载的文件流
|
||||
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) {
|
||||
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 获取原始文件流
|
||||
rs, err := fs.GetContent(ctx, path)
|
||||
if err != nil {
|
||||
|
@ -91,7 +92,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R
|
|||
}
|
||||
|
||||
// GetContent 获取文件内容,path为虚拟路径
|
||||
func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) {
|
||||
func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 触发`下载前`钩子
|
||||
err := fs.Trigger(ctx, fs.BeforeFileDownload)
|
||||
if err != nil {
|
||||
|
|
|
@ -27,7 +27,7 @@ type Handler interface {
|
|||
// 删除一个或多个文件
|
||||
Delete(ctx context.Context, files []string) ([]string, error)
|
||||
// 获取文件
|
||||
Get(ctx context.Context, path string) (io.ReadSeeker, error)
|
||||
Get(ctx context.Context, path string) (response.RSCloser, error)
|
||||
// 获取缩略图
|
||||
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
|
||||
// 获取外链地址,url
|
||||
|
|
|
@ -25,7 +25,7 @@ type Handler struct {
|
|||
}
|
||||
|
||||
// Get 获取文件内容
|
||||
func (handler Handler) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
|
||||
func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 打开文件
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
|
|
|
@ -10,3 +10,9 @@ type ContentResponse struct {
|
|||
Content io.ReadSeeker
|
||||
URL string
|
||||
}
|
||||
|
||||
// 存储策略适配器返回的文件流,有些策略需要带有Closer
|
||||
type RSCloser interface {
|
||||
io.ReadSeeker
|
||||
io.Closer
|
||||
}
|
||||
|
|
|
@ -22,9 +22,9 @@ type FileHeaderMock struct {
|
|||
testMock.Mock
|
||||
}
|
||||
|
||||
func (m FileHeaderMock) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
|
||||
func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
args := m.Called(ctx, path)
|
||||
return args.Get(0).(io.ReadSeeker), args.Error(1)
|
||||
return args.Get(0).(response.RSCloser), args.Error(1)
|
||||
}
|
||||
|
||||
func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
@ -45,6 +44,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
|
|||
|
||||
// 获取文件流
|
||||
rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
|
||||
defer rs.Close()
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||
}
|
||||
|
@ -58,11 +58,6 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
|
|||
c.Header("Content-Type", "application/zip")
|
||||
http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
|
||||
|
||||
// 检查是否需要关闭文件
|
||||
if fc, ok := rs.(io.Closer); ok {
|
||||
err = fc.Close()
|
||||
}
|
||||
|
||||
return serializer.Response{
|
||||
Code: 0,
|
||||
}
|
||||
|
@ -84,6 +79,7 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
|
|||
|
||||
// 获取文件流
|
||||
rs, err := fs.GetDownloadContent(ctx, "")
|
||||
defer rs.Close()
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||
}
|
||||
|
@ -91,11 +87,6 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
|
|||
// 发送文件
|
||||
http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
|
||||
|
||||
// 检查是否需要关闭文件
|
||||
if fc, ok := rs.(io.Closer); ok {
|
||||
defer fc.Close()
|
||||
}
|
||||
|
||||
return serializer.Response{
|
||||
Code: 0,
|
||||
}
|
||||
|
@ -139,6 +130,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
|
|||
// 开始处理下载
|
||||
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
|
||||
rs, err := fs.GetDownloadContent(ctx, "")
|
||||
defer rs.Close()
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||
}
|
||||
|
@ -154,11 +146,6 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
|
|||
// 发送文件
|
||||
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
|
||||
|
||||
// 检查是否需要关闭文件
|
||||
if fc, ok := rs.(io.Closer); ok {
|
||||
defer fc.Close()
|
||||
}
|
||||
|
||||
return serializer.Response{
|
||||
Code: 0,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue