From aeca1611866318b114012805b226b68813b67890 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 2 Jan 2020 15:36:13 +0800 Subject: [PATCH] Feat: cancel archive action / request with context --- pkg/filesystem/archive.go | 41 ++++++++++++++++++++++++++++---- pkg/filesystem/errors.go | 1 + pkg/filesystem/remote/handler.go | 1 + pkg/request/request.go | 20 +++++++++++++++- 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/pkg/filesystem/archive.go b/pkg/filesystem/archive.go index 990b046..5604295 100644 --- a/pkg/filesystem/archive.go +++ b/pkg/filesystem/archive.go @@ -3,10 +3,14 @@ package filesystem import ( "archive/zip" "context" + "errors" "fmt" model "github.com/HFO4/cloudreve/models" + "github.com/HFO4/cloudreve/pkg/filesystem/fsctx" "github.com/HFO4/cloudreve/pkg/util" + "github.com/gin-gonic/gin" "io" + "os" "path/filepath" "time" ) @@ -29,6 +33,11 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( return "", ErrDBListObjects } + ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context) + if !ok { + return "", errors.New("无法获取请求上下文") + } + // 将顶级待处理对象的路径设为根路径 for i := 0; i < len(folders); i++ { folders[i].Position = "" @@ -53,19 +62,43 @@ func (fs *FileSystem) Compress(ctx context.Context, folderIDs, fileIDs []uint) ( zipWriter := zip.NewWriter(zipFile) defer zipWriter.Close() - ctx, _ = context.WithCancel(context.Background()) - // ctx = context.WithValue(ctx, fsctx.UserCtx, *fs.User) + ctx = context.WithValue(ginCtx.Request.Context(), fsctx.UserCtx, *fs.User) + // 压缩各个目录及文件 for i := 0; i < len(folders); i++ { - fs.doCompress(ctx, nil, &folders[i], zipWriter, true) + select { + case <-ginCtx.Request.Context().Done(): + // 取消压缩请求 + fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) + return "", ErrClientCanceled + default: + fs.doCompress(ctx, nil, &folders[i], zipWriter, true) + } + } for i := 0; i < len(files); i++ { - fs.doCompress(ctx, &files[i], nil, zipWriter, true) + select { + case <-ginCtx.Request.Context().Done(): + // 取消压缩请求 + fs.cancelCompress(ctx, zipWriter, zipFile, zipFilePath) + return "", ErrClientCanceled + default: + fs.doCompress(ctx, &files[i], nil, zipWriter, true) + } } return zipFilePath, nil } +// cancelCompress 取消压缩进程 +// TODO 测试 +func (fs *FileSystem) cancelCompress(ctx context.Context, zipWriter *zip.Writer, file *os.File, path string) { + util.Log().Debug("客户端取消压缩请求") + zipWriter.Close() + file.Close() + _ = os.Remove(path) +} + func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) { // 如果对象是文件 if file != nil { diff --git a/pkg/filesystem/errors.go b/pkg/filesystem/errors.go index 2cc3527..0e204e0 100644 --- a/pkg/filesystem/errors.go +++ b/pkg/filesystem/errors.go @@ -11,6 +11,7 @@ var ( ErrFileExtensionNotAllowed = errors.New("不允许上传此类型的文件") ErrInsufficientCapacity = errors.New("容量空间不足") ErrIllegalObjectName = errors.New("目标名称非法") + ErrClientCanceled = errors.New("客户端取消操作") ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "无法插入文件记录", nil) ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "同名文件已存在", nil) ErrFolderExisted = serializer.NewError(serializer.CodeObjectExist, "同名目录已存在", nil) diff --git a/pkg/filesystem/remote/handler.go b/pkg/filesystem/remote/handler.go index ded6f1a..27e772b 100644 --- a/pkg/filesystem/remote/handler.go +++ b/pkg/filesystem/remote/handler.go @@ -64,6 +64,7 @@ func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, "GET", downloadURL, nil, + request.WithContext(ctx), ).GetRSCloser() if err != nil { diff --git a/pkg/request/request.go b/pkg/request/request.go index b523103..ba3b810 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -1,6 +1,7 @@ package request import ( + "context" "errors" "fmt" "github.com/HFO4/cloudreve/pkg/auth" @@ -38,6 +39,7 @@ type options struct { header http.Header sign auth.Auth signTTL int64 + ctx context.Context } type optionFunc func(*options) @@ -60,6 +62,14 @@ func WithTimeout(t time.Duration) Option { }) } +// WithContext 设置请求上下文 +// TODO 测试 +func WithContext(c context.Context) Option { + return optionFunc(func(o *options) { + o.ctx = c + }) +} + // WithCredential 对请求进行签名 func WithCredential(instance auth.Auth, ttl int64) Option { return optionFunc(func(o *options) { @@ -87,7 +97,15 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio client := &http.Client{Timeout: options.timeout} // 创建请求 - req, err := http.NewRequest(method, target, body) + var ( + req *http.Request + err error + ) + if options.ctx != nil { + req, err = http.NewRequestWithContext(options.ctx, method, target, body) + } else { + req, err = http.NewRequest(method, target, body) + } if err != nil { return Response{Err: err} }