Feat: aria2 callback to master node / cancel or select task to slave node
This commit is contained in:
parent
4c2505032f
commit
34003a36d0
14 changed files with 263 additions and 63 deletions
|
@ -1,6 +1,8 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -13,3 +15,25 @@ func MasterMetadata() gin.HandlerFunc {
|
|||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
|
||||
func UseSlaveAria2Instance() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("MasterAria2Instance", caller)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,11 +11,18 @@ type Notifier struct {
|
|||
Subscribes sync.Map
|
||||
}
|
||||
|
||||
type CallbackFunc func(StatusEvent)
|
||||
|
||||
// Subscribe 订阅事件通知
|
||||
func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) {
|
||||
notifier.Subscribes.Store(gid, target)
|
||||
}
|
||||
|
||||
// Subscribe 订阅事件通知回调
|
||||
func (notifier *Notifier) SubscribeCallback(callback CallbackFunc, gid string) {
|
||||
notifier.Subscribes.Store(gid, callback)
|
||||
}
|
||||
|
||||
// Unsubscribe 取消订阅事件通知
|
||||
func (notifier *Notifier) Unsubscribe(gid string) {
|
||||
notifier.Subscribes.Delete(gid)
|
||||
|
@ -25,10 +32,17 @@ func (notifier *Notifier) Unsubscribe(gid string) {
|
|||
func (notifier *Notifier) Notify(events []rpc.Event, status int) {
|
||||
for _, event := range events {
|
||||
if target, ok := notifier.Subscribes.Load(event.Gid); ok {
|
||||
target.(chan StatusEvent) <- StatusEvent{
|
||||
msg := StatusEvent{
|
||||
GID: event.Gid,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if callback, ok := target.(CallbackFunc); ok {
|
||||
go callback(msg)
|
||||
} else {
|
||||
target.(chan StatusEvent) <- msg
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,6 +42,7 @@ func NewMonitor(task *model.Download) {
|
|||
if monitor.node != nil {
|
||||
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
|
||||
go monitor.Loop()
|
||||
|
||||
common.EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
{
|
||||
MAX_RETRY = 1
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
|
||||
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
|
||||
file.Close()
|
||||
aria2.Instance = testInstance
|
||||
|
@ -91,7 +91,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 磁力链下载重定向
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
|
||||
FollowedBy: []string{"1"},
|
||||
}, nil)
|
||||
monitor.Task.ID = 1
|
||||
|
@ -108,7 +108,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 无法更新任务信息
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil)
|
||||
monitor.Task.ID = 1
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
|
@ -122,7 +122,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 返回未知状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
|
@ -135,7 +135,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 返回被取消状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
|
@ -151,7 +151,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 返回活跃状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
|
@ -164,7 +164,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 返回错误状态
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
|
@ -177,7 +177,7 @@ func TestMonitor_Update(t *testing.T) {
|
|||
// 返回完成
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
|
||||
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
aria2.mock.ExpectCommit()
|
||||
|
@ -221,7 +221,7 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) {
|
|||
// 更新成功,大小改变,需要校验,校验失败
|
||||
{
|
||||
testInstance := new(InstanceMock)
|
||||
testInstance.On("Cancel", testMock.Anything).Return(nil)
|
||||
testInstance.On("SlaveCancel", testMock.Anything).Return(nil)
|
||||
aria2.Instance = testInstance
|
||||
aria2.mock.ExpectBegin()
|
||||
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
|
|
@ -107,6 +107,10 @@ func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
|||
return &node.aria2RPC
|
||||
}
|
||||
|
||||
func (node *MasterNode) IsMater() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *rpcService) Init() error {
|
||||
r.parent.lock.Lock()
|
||||
defer r.parent.lock.Unlock()
|
||||
|
|
|
@ -9,20 +9,30 @@ import (
|
|||
type Node interface {
|
||||
// Init a node from database model
|
||||
Init(node *model.Node)
|
||||
|
||||
// Check if given feature is enabled
|
||||
IsFeatureEnabled(feature string) bool
|
||||
|
||||
// Subscribe node status change to a callback function
|
||||
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||
|
||||
// Ping the node
|
||||
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||
|
||||
// Returns if the node is active
|
||||
IsActive() bool
|
||||
|
||||
// Get instances for aria2 calls
|
||||
GetAria2Instance() common.Aria2
|
||||
|
||||
// Returns unique id of this node
|
||||
ID() uint
|
||||
|
||||
// Kill node and recycle resources
|
||||
Kill()
|
||||
|
||||
// Returns if current node is master node
|
||||
IsMater() bool
|
||||
}
|
||||
|
||||
// Create new node from DB model
|
||||
|
|
|
@ -234,12 +234,17 @@ loop:
|
|||
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||
return &serializer.NodePingReq{
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) IsMater() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
@ -307,11 +312,44 @@ func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
|
|||
}
|
||||
|
||||
func (s *slaveCaller) Cancel(task *model.Download) error {
|
||||
panic("implement me")
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "cancel")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Select(task *model.Download, files []int) error {
|
||||
panic("implement me")
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
Files: files,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "select")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) GetConfig() model.Aria2Option {
|
||||
|
|
|
@ -15,10 +15,11 @@ type ListRequest struct {
|
|||
|
||||
// NodePingReq 从机节点Ping请求
|
||||
type NodePingReq struct {
|
||||
SiteURL string `json:"site_url"`
|
||||
SiteID string `json:"site_id"`
|
||||
IsUpdate bool `json:"is_update"`
|
||||
Node *model.Node `json:"node"`
|
||||
SiteURL string `json:"site_url"`
|
||||
SiteID string `json:"site_id"`
|
||||
IsUpdate bool `json:"is_update"`
|
||||
CredentialTTL int `json:"credential_ttl"`
|
||||
Node *model.Node `json:"node"`
|
||||
}
|
||||
|
||||
// NodePingResp 从机节点Ping响应
|
||||
|
@ -29,5 +30,5 @@ type NodePingResp struct {
|
|||
type SlaveAria2Call struct {
|
||||
Task *model.Download `json:"task"`
|
||||
GroupOptions map[string]interface{} `json:"group_options"`
|
||||
Files []uint `json:"files"`
|
||||
Files []int `json:"files"`
|
||||
}
|
||||
|
|
|
@ -2,12 +2,15 @@ package slave
|
|||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -20,10 +23,14 @@ type Controller interface {
|
|||
|
||||
// Get Aria2 instance by master node id
|
||||
GetAria2Instance(string) (common.Aria2, error)
|
||||
|
||||
// Send event change message to master node
|
||||
SendAria2Notification(string, common.StatusEvent) error
|
||||
}
|
||||
|
||||
type slaveController struct {
|
||||
masters map[string]masterInfo
|
||||
client request.Client
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
|
@ -32,6 +39,7 @@ type masterInfo struct {
|
|||
slaveID uint
|
||||
id string
|
||||
authClient auth.Auth
|
||||
ttl int
|
||||
// used to invoke aria2 rpc calls
|
||||
instance cluster.Node
|
||||
}
|
||||
|
@ -39,6 +47,7 @@ type masterInfo struct {
|
|||
func Init() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]masterInfo),
|
||||
client: request.HTTPClient{},
|
||||
}
|
||||
gob.Register(rpc.StatusInfo{})
|
||||
}
|
||||
|
@ -63,6 +72,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
|
|||
authClient: auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
},
|
||||
ttl: req.CredentialTTL,
|
||||
instance: cluster.NewNodeFromDBModel(&model.Node{
|
||||
Type: model.MasterNodeType,
|
||||
Aria2Enabled: req.Node.Aria2Enabled,
|
||||
|
@ -84,3 +94,31 @@ func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
|||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
func (c *slaveController) SendAria2Notification(id string, msg common.StatusEvent) error {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
res, err := c.client.Request(
|
||||
"PATCH",
|
||||
fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status),
|
||||
nil,
|
||||
request.WithHeader(http.Header{"X-Node-ID": []string{fmt.Sprintf("%d", node.slaveID)}}),
|
||||
request.WithCredential(node.authClient, int64(node.ttl)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
|
|
@ -95,3 +95,14 @@ func ListFinished(c *gin.Context) {
|
|||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// TaskUpdate 被动更新任务状态
|
||||
func TaskUpdate(c *gin.Context) {
|
||||
var service aria2.DownloadTaskService
|
||||
if err := c.ShouldBindQuery(&service); err == nil {
|
||||
res := service.Notify()
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -200,11 +200,33 @@ func SlaveAria2Create(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
// SlaveAria2Status 查询 Aria2 任务状态
|
||||
// SlaveAria2Status 查询从机 Aria2 任务状态
|
||||
func SlaveAria2Status(c *gin.Context) {
|
||||
var service serializer.SlaveAria2Call
|
||||
if err := c.ShouldBindJSON(&service); err == nil {
|
||||
res := aria2.Status(c, &service)
|
||||
res := aria2.SlaveStatus(c, &service)
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// SlaveCancelAria2Task 取消从机离线下载任务
|
||||
func SlaveCancelAria2Task(c *gin.Context) {
|
||||
var service serializer.SlaveAria2Call
|
||||
if err := c.ShouldBindJSON(&service); err == nil {
|
||||
res := aria2.SlaveCancel(c, &service)
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
}
|
||||
}
|
||||
|
||||
// SlaveSelectTask 从机选取离线下载文件
|
||||
func SlaveSelectTask(c *gin.Context) {
|
||||
var service serializer.SlaveAria2Call
|
||||
if err := c.ShouldBindJSON(&service); err == nil {
|
||||
res := aria2.SlaveSelect(c, &service)
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
|
|
|
@ -56,11 +56,16 @@ func InitSlaveRouter() *gin.Engine {
|
|||
|
||||
// 离线下载
|
||||
aria2 := v3.Group("aria2")
|
||||
aria2.Use(middleware.UseSlaveAria2Instance())
|
||||
{
|
||||
// 创建离线下载任务
|
||||
aria2.POST("task", controllers.SlaveAria2Create)
|
||||
// 创建离线下载任务
|
||||
// 获取任务状态
|
||||
aria2.POST("status", controllers.SlaveAria2Status)
|
||||
// 取消离线下载任务
|
||||
aria2.POST("cancel", controllers.SlaveCancelAria2Task)
|
||||
// 选取任务文件
|
||||
aria2.POST("select", controllers.SlaveSelectTask)
|
||||
}
|
||||
}
|
||||
return r
|
||||
|
@ -187,6 +192,12 @@ func InitMasterRouter() *gin.Engine {
|
|||
}
|
||||
}
|
||||
|
||||
// 从机的 RPC 通信
|
||||
slave := v3.Group("slave")
|
||||
{
|
||||
slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate)
|
||||
}
|
||||
|
||||
// 回调接口
|
||||
callback := v3.Group("callback")
|
||||
{
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -78,22 +79,21 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
|
|||
|
||||
// Add 从机创建新的链接离线下载任务
|
||||
func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)
|
||||
}
|
||||
caller, _ := c.Get("MasterAria2Instance")
|
||||
|
||||
// 创建任务
|
||||
gid, err := caller.CreateTask(service.Task, service.GroupOptions)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
|
||||
}
|
||||
|
||||
// TODO: 创建监控
|
||||
return serializer.Response{Data: gid}
|
||||
// 创建任务
|
||||
gid, err := caller.(common.Aria2).CreateTask(service.Task, service.GroupOptions)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
|
||||
}
|
||||
|
||||
return serializer.ParamErr("未知的主机节点ID", nil)
|
||||
// 创建事件通知回调
|
||||
siteID, _ := c.Get("MasterSiteID")
|
||||
common.EventNotifier.SubscribeCallback(func(event common.StatusEvent) {
|
||||
if err := slave.DefaultController.SendAria2Notification(siteID.(string), event); err != nil {
|
||||
util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err)
|
||||
}
|
||||
}, gid)
|
||||
|
||||
return serializer.Response{Data: gid}
|
||||
}
|
||||
|
|
|
@ -2,10 +2,10 @@ package aria2
|
|||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,8 @@ type SelectFileService struct {
|
|||
|
||||
// DownloadTaskService 下载任务管理服务
|
||||
type DownloadTaskService struct {
|
||||
GID string `uri:"gid" binding:"required"`
|
||||
GID string `uri:"gid" binding:"required"`
|
||||
Status int `uri:"gid"`
|
||||
}
|
||||
|
||||
// DownloadListService 下载列表服务
|
||||
|
@ -58,15 +59,20 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response {
|
|||
}
|
||||
|
||||
// 取消任务
|
||||
aria2.Lock.RLock()
|
||||
defer aria2.Lock.RUnlock()
|
||||
if err := aria2.Instance.Cancel(download); err != nil {
|
||||
node := cluster.Default.GetNodeByID(download.GetNodeID())
|
||||
if err := node.GetAria2Instance().Cancel(download); err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
}
|
||||
|
||||
// Notify 转发通知任务更新
|
||||
func (service *DownloadTaskService) Notify() serializer.Response {
|
||||
common.EventNotifier.Notify([]rpc.Event{{service.GID}}, service.Status)
|
||||
return serializer.Response{}
|
||||
}
|
||||
|
||||
// Select 选取要下载的文件
|
||||
func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
|
||||
userCtx, _ := c.Get("user")
|
||||
|
@ -83,9 +89,8 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
|
|||
}
|
||||
|
||||
// 选取下载
|
||||
aria2.Lock.RLock()
|
||||
defer aria2.Lock.RUnlock()
|
||||
if err := aria2.Instance.Select(download, service.Indexes); err != nil {
|
||||
node := cluster.Default.GetNodeByID(download.GetNodeID())
|
||||
if err := node.GetAria2Instance().Select(download, service.Indexes); err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
|
||||
}
|
||||
|
||||
|
@ -93,23 +98,44 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
|
|||
|
||||
}
|
||||
|
||||
// Status 从机查询离线任务状态
|
||||
func Status(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)
|
||||
}
|
||||
// SlaveStatus 从机查询离线任务状态
|
||||
func SlaveStatus(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||
caller, _ := c.Get("MasterAria2Instance")
|
||||
|
||||
// 查询任务
|
||||
status, err := caller.Status(service.Task)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err)
|
||||
}
|
||||
|
||||
return serializer.NewResponseWithGobData(status)
|
||||
// 查询任务
|
||||
status, err := caller.(common.Aria2).Status(service.Task)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err)
|
||||
}
|
||||
|
||||
return serializer.ParamErr("未知的主机节点ID", nil)
|
||||
return serializer.NewResponseWithGobData(status)
|
||||
|
||||
}
|
||||
|
||||
// SlaveCancel 取消从机离线下载任务
|
||||
func SlaveCancel(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||
caller, _ := c.Get("MasterAria2Instance")
|
||||
|
||||
// 查询任务
|
||||
err := caller.(common.Aria2).Cancel(service.Task)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "任务取消失败", err)
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
|
||||
}
|
||||
|
||||
// SlaveSelect 从机选取离线下载任务文件
|
||||
func SlaveSelect(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||
caller, _ := c.Get("MasterAria2Instance")
|
||||
|
||||
// 查询任务
|
||||
err := caller.(common.Aria2).Select(service.Task, service.Files)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "任务选取失败", err)
|
||||
}
|
||||
|
||||
return serializer.Response{}
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue