From 6ada16d25d8326904d5592f52f5a18dcef0571ab Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Fri, 29 Oct 2021 20:25:29 +0800 Subject: [PATCH] Fix: temp file cannot be removed when aria2 task fails --- models/node.go | 19 +++++++++++++++++-- pkg/aria2/common/common.go | 9 ++++++++- pkg/aria2/monitor/monitor.go | 11 +++++------ pkg/cluster/errors.go | 1 + pkg/cluster/master.go | 19 +++++++++++++++++++ pkg/cluster/pool.go | 2 ++ pkg/cluster/slave.go | 20 ++++++++++++++++++++ pkg/task/slavetask/transfer.go | 11 +++++++++++ pkg/task/tranfer.go | 9 +++++---- routers/controllers/slave.go | 11 +++++++++++ routers/router.go | 2 ++ service/admin/node.go | 18 ++++++++++++++++-- service/aria2/manage.go | 18 ++++++++++++++++++ 13 files changed, 135 insertions(+), 15 deletions(-) diff --git a/models/node.go b/models/node.go index 977c8b5..992a828 100644 --- a/models/node.go +++ b/models/node.go @@ -42,15 +42,22 @@ type NodeStatus int type ModelType int const ( - NodeActive = iota + NodeActive NodeStatus = iota NodeSuspend ) const ( - SlaveNodeType = iota + SlaveNodeType ModelType = iota MasterNodeType ) +// GetNodeByID 用ID获取节点 +func GetNodeByID(ID interface{}) (Node, error) { + var node Node + result := DB.First(&node, ID) + return node, result.Error +} + // GetNodesByStatus 根据给定状态获取节点 func GetNodesByStatus(status ...NodeStatus) ([]Node, error) { var nodes []Node @@ -74,3 +81,11 @@ func (node *Node) BeforeSave() (err error) { node.Aria2Options = string(optionsValue) return err } + +// SetStatus 设置节点启用状态 +func (node *Node) SetStatus(status NodeStatus) error { + node.Status = status + return DB.Model(node).Updates(map[string]interface{}{ + "status": status, + }).Error +} diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index 211ccc5..8f281d8 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -18,8 +18,10 @@ type Aria2 interface { Cancel(task *model.Download) error // 选择要下载的文件 Select(task *model.Download, files []int) error - // GetConfig 获取离线下载配置 + // 获取离线下载配置 GetConfig() model.Aria2Option + // 删除临时下载文件 + DeleteTempFile(*model.Download) error } const ( @@ -86,6 +88,11 @@ func (instance *DummyAria2) GetConfig() model.Aria2Option { return model.Aria2Option{} } +// GetConfig 返回空的 +func (instance *DummyAria2) DeleteTempFile(src *model.Download) error { + return ErrNotEnabled +} + // GetStatus 将给定的状态字符串转换为状态标识数字 func GetStatus(status string) int { switch status { diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index 6c5d53e..7a04411 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "os" "path/filepath" "strconv" "time" @@ -40,11 +39,14 @@ func NewMonitor(task *model.Download) { notifier: make(chan mq.Message), node: cluster.Default.GetNodeByID(task.GetNodeID()), } + if monitor.node != nil { monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second go monitor.Loop() monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0) + } else { + monitor.setErrorStatus(errors.New("节点不可用")) } } @@ -102,6 +104,7 @@ func (monitor *Monitor) Update() bool { if err := monitor.UpdateTaskInfo(status); err != nil { util.Log().Warning("无法更新下载任务[%s]的任务信息[%s],", monitor.Task.GID, err) monitor.setErrorStatus(err) + monitor.RemoveTempFolder() return true } @@ -228,11 +231,7 @@ func (monitor *Monitor) Error(status rpc.StatusInfo) bool { // RemoveTempFolder 清理下载临时目录 func (monitor *Monitor) RemoveTempFolder() { - err := os.RemoveAll(monitor.Task.Parent) - if err != nil { - util.Log().Warning("无法删除离线下载临时目录[%s], %s", monitor.Task.Parent, err) - } - + monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task) } // Complete 完成下载,返回是否中断监控 diff --git a/pkg/cluster/errors.go b/pkg/cluster/errors.go index 0a19f6e..9afdbef 100644 --- a/pkg/cluster/errors.go +++ b/pkg/cluster/errors.go @@ -4,4 +4,5 @@ import "errors" var ( ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed") + ErrIlegalPath = errors.New("path out of boundary of setting temp folder") ) diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index f28ba11..1f2f140 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -11,6 +11,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "net/url" + "os" "path/filepath" "strconv" "strings" @@ -18,6 +19,8 @@ import ( "time" ) +const deleteTempFileDuration = 60 * time.Second + type MasterNode struct { Model *model.Node aria2RPC rpcService @@ -242,3 +245,19 @@ func (r *rpcService) GetConfig() model.Aria2Option { return r.parent.Model.Aria2OptionsSerialized } + +func (s *rpcService) DeleteTempFile(task *model.Download) error { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + // 避免被aria2占用,异步执行删除 + go func(src string) { + time.Sleep(deleteTempFileDuration) + err := os.RemoveAll(src) + if err != nil { + util.Log().Warning("无法删除离线下载临时目录[%s], %s", src, err) + } + }(task.Parent) + + return nil +} diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index 131fcf1..4526f4a 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -125,6 +125,7 @@ func (pool *NodePool) add(node *model.Node) { func (pool *NodePool) Add(node *model.Node) { pool.lock.Lock() + defer pool.buildIndexMap() defer pool.lock.Unlock() if _, ok := pool.active[node.ID]; ok { @@ -141,6 +142,7 @@ func (pool *NodePool) Add(node *model.Node) { func (pool *NodePool) Delete(id uint) { pool.lock.Lock() + defer pool.buildIndexMap() defer pool.lock.Unlock() if node, ok := pool.active[id]; ok { diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index be92ba5..a76f238 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -375,6 +375,26 @@ func (s *slaveCaller) GetConfig() model.Aria2Option { return s.parent.Model.Aria2OptionsSerialized } +func (s *slaveCaller) DeleteTempFile(task *model.Download) error { + s.parent.lock.RLock() + defer s.parent.lock.RUnlock() + + req := &serializer.SlaveAria2Call{ + Task: task, + } + + res, err := s.SendAria2Call(req, "delete") + if err != nil { + return err + } + + if res.Code != 0 { + return serializer.NewErrorFromResponse(res) + } + + return nil +} + func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) { reqBodyEncoded, err := json.Marshal(body) if err != nil { diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 7aecd85..c312742 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -11,6 +11,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" "os" + "path/filepath" ) // TransferTask 文件中转任务 @@ -79,6 +80,8 @@ func (job *TransferTask) GetError() *task.JobError { // Do 开始执行任务 func (job *TransferTask) Do() { + defer job.Recycle() + fs, err := filesystem.NewAnonymousFileSystem() if err != nil { job.SetErrorMsg("无法初始化匿名文件系统", err) @@ -132,3 +135,11 @@ func (job *TransferTask) Do() { util.Log().Warning("无法发送转存成功通知到从机, ", err) } } + +// Recycle 回收临时文件 +func (job *TransferTask) Recycle() { + err := os.RemoveAll(filepath.Dir(job.Req.Src)) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.Req.Src, err) + } +} diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index aee96fb..5db638d 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -135,11 +135,12 @@ func (job *TransferTask) Do() { // Recycle 回收临时文件 func (job *TransferTask) Recycle() { - err := os.RemoveAll(job.TaskProps.Parent) - if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) + if job.TaskProps.NodeID == 1 { + err := os.RemoveAll(job.TaskProps.Parent) + if err != nil { + util.Log().Warning("无法删除中转临时目录[%s], %s", job.TaskProps.Parent, err) + } } - } // NewTransferTask 新建中转任务 diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index 267ae40..10c46ff 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -265,3 +265,14 @@ func SlaveGetOneDriveCredential(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveSelectTask 从机删除离线下载临时文件 +func SlaveDeleteTempFile(c *gin.Context) { + var service serializer.SlaveAria2Call + if err := c.ShouldBindJSON(&service); err == nil { + res := aria2.SlaveDeleteTemp(c, &service) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index c620b4d..230b009 100644 --- a/routers/router.go +++ b/routers/router.go @@ -69,6 +69,8 @@ func InitSlaveRouter() *gin.Engine { aria2.POST("cancel", controllers.SlaveCancelAria2Task) // 选取任务文件 aria2.POST("select", controllers.SlaveSelectTask) + // 删除任务临时文件 + aria2.POST("delete", controllers.SlaveDeleteTempFile) } // 异步任务 diff --git a/service/admin/node.go b/service/admin/node.go index 1e7b216..1561c98 100644 --- a/service/admin/node.go +++ b/service/admin/node.go @@ -72,12 +72,26 @@ func (service *AdminListService) Nodes() serializer.Response { // ToggleNodeService 开关节点服务 type ToggleNodeService struct { - ID uint `uri:"id"` - Desired int `uri:"desired"` + ID uint `uri:"id"` + Desired model.NodeStatus `uri:"desired"` } // Toggle 开关节点 func (service *ToggleNodeService) Toggle() serializer.Response { + node, err := model.GetNodeByID(service.ID) + if err != nil { + return serializer.DBErr("找不到节点", err) + } + + if err = node.SetStatus(service.Desired); err != nil { + return serializer.DBErr("无法更改节点状态", err) + } + + if service.Desired == model.NodeActive { + cluster.Default.Add(&node) + } else { + cluster.Default.Delete(node.ID) + } return serializer.Response{} } diff --git a/service/aria2/manage.go b/service/aria2/manage.go index b8877a1..f3ed47d 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -58,6 +58,10 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { // 取消任务 node := cluster.Default.GetNodeByID(download.GetNodeID()) + if node == nil { + return serializer.Err(serializer.CodeInternalSetting, "目标节点不可用", err) + } + if err := node.GetAria2Instance().Cancel(download); err != nil { return serializer.Err(serializer.CodeNotSet, "操作失败", err) } @@ -131,3 +135,17 @@ func SlaveSelect(c *gin.Context, service *serializer.SlaveAria2Call) serializer. return serializer.Response{} } + +// SlaveSelect 从机选取离线下载任务文件 +func SlaveDeleteTemp(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response { + caller, _ := c.Get("MasterAria2Instance") + + // 查询任务 + err := caller.(common.Aria2).DeleteTempFile(service.Task) + if err != nil { + return serializer.Err(serializer.CodeInternalSetting, "临时文件删除失败", err) + } + + return serializer.Response{} + +}