Modify: add general ReaderCloserSeeker interface for handler GET method to return
This commit is contained in:
parent
f262caf1f5
commit
03dcd9a9e0
6 changed files with 19 additions and 25 deletions
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
model "github.com/HFO4/cloudreve/models"
|
model "github.com/HFO4/cloudreve/models"
|
||||||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||||
|
"github.com/HFO4/cloudreve/pkg/filesystem/response"
|
||||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||||
"github.com/HFO4/cloudreve/pkg/util"
|
"github.com/HFO4/cloudreve/pkg/util"
|
||||||
"github.com/juju/ratelimit"
|
"github.com/juju/ratelimit"
|
||||||
|
@ -18,7 +19,7 @@ import (
|
||||||
|
|
||||||
// 限速后的ReaderSeeker
|
// 限速后的ReaderSeeker
|
||||||
type lrs struct {
|
type lrs struct {
|
||||||
io.ReadSeeker
|
response.RSCloser
|
||||||
r io.Reader
|
r io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +28,7 @@ func (r lrs) Read(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// withSpeedLimit 给原有的ReadSeeker加上限速
|
// withSpeedLimit 给原有的ReadSeeker加上限速
|
||||||
func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker {
|
func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
|
||||||
// 如果用户组有速度限制,就返回限制流速的ReaderSeeker
|
// 如果用户组有速度限制,就返回限制流速的ReaderSeeker
|
||||||
if fs.User.Group.SpeedLimit != 0 {
|
if fs.User.Group.SpeedLimit != 0 {
|
||||||
speed := fs.User.Group.SpeedLimit
|
speed := fs.User.Group.SpeedLimit
|
||||||
|
@ -63,7 +64,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPhysicalFileContent 根据文件物理路径获取文件流
|
// GetPhysicalFileContent 根据文件物理路径获取文件流
|
||||||
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) {
|
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
// 重设上传策略
|
// 重设上传策略
|
||||||
fs.Policy = &model.Policy{Type: "local"}
|
fs.Policy = &model.Policy{Type: "local"}
|
||||||
_ = fs.dispatchHandler()
|
_ = fs.dispatchHandler()
|
||||||
|
@ -78,7 +79,7 @@ func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDownloadContent 获取用于下载的文件流
|
// GetDownloadContent 获取用于下载的文件流
|
||||||
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) {
|
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
// 获取原始文件流
|
// 获取原始文件流
|
||||||
rs, err := fs.GetContent(ctx, path)
|
rs, err := fs.GetContent(ctx, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -91,7 +92,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetContent 获取文件内容,path为虚拟路径
|
// GetContent 获取文件内容,path为虚拟路径
|
||||||
func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) {
|
func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
// 触发`下载前`钩子
|
// 触发`下载前`钩子
|
||||||
err := fs.Trigger(ctx, fs.BeforeFileDownload)
|
err := fs.Trigger(ctx, fs.BeforeFileDownload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -27,7 +27,7 @@ type Handler interface {
|
||||||
// 删除一个或多个文件
|
// 删除一个或多个文件
|
||||||
Delete(ctx context.Context, files []string) ([]string, error)
|
Delete(ctx context.Context, files []string) ([]string, error)
|
||||||
// 获取文件
|
// 获取文件
|
||||||
Get(ctx context.Context, path string) (io.ReadSeeker, error)
|
Get(ctx context.Context, path string) (response.RSCloser, error)
|
||||||
// 获取缩略图
|
// 获取缩略图
|
||||||
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
|
Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
|
||||||
// 获取外链地址,url
|
// 获取外链地址,url
|
||||||
|
|
|
@ -25,7 +25,7 @@ type Handler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get 获取文件内容
|
// Get 获取文件内容
|
||||||
func (handler Handler) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
|
func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
// 打开文件
|
// 打开文件
|
||||||
file, err := os.Open(path)
|
file, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -10,3 +10,9 @@ type ContentResponse struct {
|
||||||
Content io.ReadSeeker
|
Content io.ReadSeeker
|
||||||
URL string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 存储策略适配器返回的文件流,有些策略需要带有Closer
|
||||||
|
type RSCloser interface {
|
||||||
|
io.ReadSeeker
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
|
@ -22,9 +22,9 @@ type FileHeaderMock struct {
|
||||||
testMock.Mock
|
testMock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m FileHeaderMock) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
|
func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||||
args := m.Called(ctx, path)
|
args := m.Called(ctx, path)
|
||||||
return args.Get(0).(io.ReadSeeker), args.Error(1)
|
return args.Get(0).(response.RSCloser), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
|
func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
|
||||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -45,6 +44,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
|
||||||
|
|
||||||
// 获取文件流
|
// 获取文件流
|
||||||
rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
|
rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
|
||||||
|
defer rs.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||||
}
|
}
|
||||||
|
@ -58,11 +58,6 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
|
||||||
c.Header("Content-Type", "application/zip")
|
c.Header("Content-Type", "application/zip")
|
||||||
http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
|
http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
|
||||||
|
|
||||||
// 检查是否需要关闭文件
|
|
||||||
if fc, ok := rs.(io.Closer); ok {
|
|
||||||
err = fc.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return serializer.Response{
|
return serializer.Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
}
|
}
|
||||||
|
@ -84,6 +79,7 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
|
||||||
|
|
||||||
// 获取文件流
|
// 获取文件流
|
||||||
rs, err := fs.GetDownloadContent(ctx, "")
|
rs, err := fs.GetDownloadContent(ctx, "")
|
||||||
|
defer rs.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||||
}
|
}
|
||||||
|
@ -91,11 +87,6 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
|
||||||
// 发送文件
|
// 发送文件
|
||||||
http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
|
http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
|
||||||
|
|
||||||
// 检查是否需要关闭文件
|
|
||||||
if fc, ok := rs.(io.Closer); ok {
|
|
||||||
defer fc.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return serializer.Response{
|
return serializer.Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
}
|
}
|
||||||
|
@ -139,6 +130,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
|
||||||
// 开始处理下载
|
// 开始处理下载
|
||||||
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
|
ctx = context.WithValue(ctx, fsctx.GinCtx, c)
|
||||||
rs, err := fs.GetDownloadContent(ctx, "")
|
rs, err := fs.GetDownloadContent(ctx, "")
|
||||||
|
defer rs.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
|
||||||
}
|
}
|
||||||
|
@ -154,11 +146,6 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
|
||||||
// 发送文件
|
// 发送文件
|
||||||
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
|
http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
|
||||||
|
|
||||||
// 检查是否需要关闭文件
|
|
||||||
if fc, ok := rs.(io.Closer); ok {
|
|
||||||
defer fc.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return serializer.Response{
|
return serializer.Response{
|
||||||
Code: 0,
|
Code: 0,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue