Feat: aria2 callback to master node / cancel or select task to slave node

This commit is contained in:
HFO4 2021-08-21 11:08:29 +08:00
parent 4c2505032f
commit 34003a36d0
14 changed files with 263 additions and 63 deletions

View file

@ -1,6 +1,8 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
)
@ -13,3 +15,25 @@ func MasterMetadata() gin.HandlerFunc {
c.Next()
}
}
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
func UseSlaveAria2Instance() gin.HandlerFunc {
return func(c *gin.Context) {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
if err != nil {
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
c.Abort()
return
}
c.Set("MasterAria2Instance", caller)
c.Next()
return
}
c.JSON(200, serializer.ParamErr("未知的主机节点ID", nil))
c.Abort()
}
}

View file

@ -11,11 +11,18 @@ type Notifier struct {
Subscribes sync.Map
}
type CallbackFunc func(StatusEvent)
// Subscribe 订阅事件通知
func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) {
notifier.Subscribes.Store(gid, target)
}
// Subscribe 订阅事件通知回调
func (notifier *Notifier) SubscribeCallback(callback CallbackFunc, gid string) {
notifier.Subscribes.Store(gid, callback)
}
// Unsubscribe 取消订阅事件通知
func (notifier *Notifier) Unsubscribe(gid string) {
notifier.Subscribes.Delete(gid)
@ -25,10 +32,17 @@ func (notifier *Notifier) Unsubscribe(gid string) {
func (notifier *Notifier) Notify(events []rpc.Event, status int) {
for _, event := range events {
if target, ok := notifier.Subscribes.Load(event.Gid); ok {
target.(chan StatusEvent) <- StatusEvent{
msg := StatusEvent{
GID: event.Gid,
Status: status,
}
if callback, ok := target.(CallbackFunc); ok {
go callback(msg)
} else {
target.(chan StatusEvent) <- msg
}
}
}
}

View file

@ -42,6 +42,7 @@ func NewMonitor(task *model.Download) {
if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
go monitor.Loop()
common.EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID)
}
}

View file

@ -78,7 +78,7 @@ func TestMonitor_Update(t *testing.T) {
{
MAX_RETRY = 1
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, errors.New("error"))
file, _ := util.CreatNestedFile("TestMonitor_Update/1")
file.Close()
aria2.Instance = testInstance
@ -91,7 +91,7 @@ func TestMonitor_Update(t *testing.T) {
// 磁力链下载重定向
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{
FollowedBy: []string{"1"},
}, nil)
monitor.Task.ID = 1
@ -108,7 +108,7 @@ func TestMonitor_Update(t *testing.T) {
// 无法更新任务信息
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{}, nil)
monitor.Task.ID = 1
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
@ -122,7 +122,7 @@ func TestMonitor_Update(t *testing.T) {
// 返回未知状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "?"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
@ -135,7 +135,7 @@ func TestMonitor_Update(t *testing.T) {
// 返回被取消状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "removed"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
@ -151,7 +151,7 @@ func TestMonitor_Update(t *testing.T) {
// 返回活跃状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "active"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
@ -164,7 +164,7 @@ func TestMonitor_Update(t *testing.T) {
// 返回错误状态
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "error"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
@ -177,7 +177,7 @@ func TestMonitor_Update(t *testing.T) {
// 返回完成
{
testInstance := new(InstanceMock)
testInstance.On("Status", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
testInstance.On("SlaveStatus", testMock.Anything).Return(rpc.StatusInfo{Status: "complete"}, nil)
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
aria2.mock.ExpectCommit()
@ -221,7 +221,7 @@ func TestMonitor_UpdateTaskInfo(t *testing.T) {
// 更新成功,大小改变,需要校验,校验失败
{
testInstance := new(InstanceMock)
testInstance.On("Cancel", testMock.Anything).Return(nil)
testInstance.On("SlaveCancel", testMock.Anything).Return(nil)
aria2.Instance = testInstance
aria2.mock.ExpectBegin()
aria2.mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))

View file

@ -107,6 +107,10 @@ func (node *MasterNode) GetAria2Instance() common.Aria2 {
return &node.aria2RPC
}
func (node *MasterNode) IsMater() bool {
return true
}
func (r *rpcService) Init() error {
r.parent.lock.Lock()
defer r.parent.lock.Unlock()

View file

@ -9,20 +9,30 @@ import (
type Node interface {
// Init a node from database model
Init(node *model.Node)
// Check if given feature is enabled
IsFeatureEnabled(feature string) bool
// Subscribe node status change to a callback function
SubscribeStatusChange(callback func(isActive bool, id uint))
// Ping the node
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
// Returns if the node is active
IsActive() bool
// Get instances for aria2 calls
GetAria2Instance() common.Aria2
// Returns unique id of this node
ID() uint
// Kill node and recycle resources
Kill()
// Returns if current node is master node
IsMater() bool
}
// Create new node from DB model

View file

@ -234,12 +234,17 @@ loop:
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
}
}
func (node *SlaveNode) IsMater() bool {
return false
}
func (s *slaveCaller) Init() error {
return nil
}
@ -307,11 +312,44 @@ func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
}
func (s *slaveCaller) Cancel(task *model.Download) error {
panic("implement me")
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "cancel")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) Select(task *model.Download, files []int) error {
panic("implement me")
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
Files: files,
}
res, err := s.SendAria2Call(req, "select")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) GetConfig() model.Aria2Option {

View file

@ -15,10 +15,11 @@ type ListRequest struct {
// NodePingReq 从机节点Ping请求
type NodePingReq struct {
SiteURL string `json:"site_url"`
SiteID string `json:"site_id"`
IsUpdate bool `json:"is_update"`
Node *model.Node `json:"node"`
SiteURL string `json:"site_url"`
SiteID string `json:"site_id"`
IsUpdate bool `json:"is_update"`
CredentialTTL int `json:"credential_ttl"`
Node *model.Node `json:"node"`
}
// NodePingResp 从机节点Ping响应
@ -29,5 +30,5 @@ type NodePingResp struct {
type SlaveAria2Call struct {
Task *model.Download `json:"task"`
GroupOptions map[string]interface{} `json:"group_options"`
Files []uint `json:"files"`
Files []int `json:"files"`
}

View file

@ -2,12 +2,15 @@ package slave
import (
"encoding/gob"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/http"
"sync"
)
@ -20,10 +23,14 @@ type Controller interface {
// Get Aria2 instance by master node id
GetAria2Instance(string) (common.Aria2, error)
// Send event change message to master node
SendAria2Notification(string, common.StatusEvent) error
}
type slaveController struct {
masters map[string]masterInfo
client request.Client
lock sync.RWMutex
}
@ -32,6 +39,7 @@ type masterInfo struct {
slaveID uint
id string
authClient auth.Auth
ttl int
// used to invoke aria2 rpc calls
instance cluster.Node
}
@ -39,6 +47,7 @@ type masterInfo struct {
func Init() {
DefaultController = &slaveController{
masters: make(map[string]masterInfo),
client: request.HTTPClient{},
}
gob.Register(rpc.StatusInfo{})
}
@ -63,6 +72,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
authClient: auth.HMACAuth{
SecretKey: []byte(req.Node.MasterKey),
},
ttl: req.CredentialTTL,
instance: cluster.NewNodeFromDBModel(&model.Node{
Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled,
@ -84,3 +94,31 @@ func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
return nil, ErrMasterNotFound
}
func (c *slaveController) SendAria2Notification(id string, msg common.StatusEvent) error {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
res, err := c.client.Request(
"PATCH",
fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status),
nil,
request.WithHeader(http.Header{"X-Node-ID": []string{fmt.Sprintf("%d", node.slaveID)}}),
request.WithCredential(node.authClient, int64(node.ttl)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
c.lock.RUnlock()
return ErrMasterNotFound
}

View file

@ -95,3 +95,14 @@ func ListFinished(c *gin.Context) {
c.JSON(200, ErrorResponse(err))
}
}
// TaskUpdate 被动更新任务状态
func TaskUpdate(c *gin.Context) {
var service aria2.DownloadTaskService
if err := c.ShouldBindQuery(&service); err == nil {
res := service.Notify()
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}

View file

@ -200,11 +200,33 @@ func SlaveAria2Create(c *gin.Context) {
}
}
// SlaveAria2Status 查询 Aria2 任务状态
// SlaveAria2Status 查询从机 Aria2 任务状态
func SlaveAria2Status(c *gin.Context) {
var service serializer.SlaveAria2Call
if err := c.ShouldBindJSON(&service); err == nil {
res := aria2.Status(c, &service)
res := aria2.SlaveStatus(c, &service)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// SlaveCancelAria2Task 取消从机离线下载任务
func SlaveCancelAria2Task(c *gin.Context) {
var service serializer.SlaveAria2Call
if err := c.ShouldBindJSON(&service); err == nil {
res := aria2.SlaveCancel(c, &service)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))
}
}
// SlaveSelectTask 从机选取离线下载文件
func SlaveSelectTask(c *gin.Context) {
var service serializer.SlaveAria2Call
if err := c.ShouldBindJSON(&service); err == nil {
res := aria2.SlaveSelect(c, &service)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))

View file

@ -56,11 +56,16 @@ func InitSlaveRouter() *gin.Engine {
// 离线下载
aria2 := v3.Group("aria2")
aria2.Use(middleware.UseSlaveAria2Instance())
{
// 创建离线下载任务
aria2.POST("task", controllers.SlaveAria2Create)
// 创建离线下载任务
// 获取任务状态
aria2.POST("status", controllers.SlaveAria2Status)
// 取消离线下载任务
aria2.POST("cancel", controllers.SlaveCancelAria2Task)
// 选取任务文件
aria2.POST("select", controllers.SlaveSelectTask)
}
}
return r
@ -187,6 +192,12 @@ func InitMasterRouter() *gin.Engine {
}
}
// 从机的 RPC 通信
slave := v3.Group("slave")
{
slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate)
}
// 回调接口
callback := v3.Group("callback")
{

View file

@ -9,6 +9,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
@ -78,22 +79,21 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
// Add 从机创建新的链接离线下载任务
func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
if err != nil {
return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)
}
caller, _ := c.Get("MasterAria2Instance")
// 创建任务
gid, err := caller.CreateTask(service.Task, service.GroupOptions)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
}
// TODO: 创建监控
return serializer.Response{Data: gid}
// 创建任务
gid, err := caller.(common.Aria2).CreateTask(service.Task, service.GroupOptions)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
}
return serializer.ParamErr("未知的主机节点ID", nil)
// 创建事件通知回调
siteID, _ := c.Get("MasterSiteID")
common.EventNotifier.SubscribeCallback(func(event common.StatusEvent) {
if err := slave.DefaultController.SendAria2Notification(siteID.(string), event); err != nil {
util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err)
}
}, gid)
return serializer.Response{Data: gid}
}

View file

@ -2,10 +2,10 @@ package aria2
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
)
@ -16,7 +16,8 @@ type SelectFileService struct {
// DownloadTaskService 下载任务管理服务
type DownloadTaskService struct {
GID string `uri:"gid" binding:"required"`
GID string `uri:"gid" binding:"required"`
Status int `uri:"gid"`
}
// DownloadListService 下载列表服务
@ -58,15 +59,20 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response {
}
// 取消任务
aria2.Lock.RLock()
defer aria2.Lock.RUnlock()
if err := aria2.Instance.Cancel(download); err != nil {
node := cluster.Default.GetNodeByID(download.GetNodeID())
if err := node.GetAria2Instance().Cancel(download); err != nil {
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
}
return serializer.Response{}
}
// Notify 转发通知任务更新
func (service *DownloadTaskService) Notify() serializer.Response {
common.EventNotifier.Notify([]rpc.Event{{service.GID}}, service.Status)
return serializer.Response{}
}
// Select 选取要下载的文件
func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
userCtx, _ := c.Get("user")
@ -83,9 +89,8 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
}
// 选取下载
aria2.Lock.RLock()
defer aria2.Lock.RUnlock()
if err := aria2.Instance.Select(download, service.Indexes); err != nil {
node := cluster.Default.GetNodeByID(download.GetNodeID())
if err := node.GetAria2Instance().Select(download, service.Indexes); err != nil {
return serializer.Err(serializer.CodeNotSet, "操作失败", err)
}
@ -93,23 +98,44 @@ func (service *SelectFileService) Select(c *gin.Context) serializer.Response {
}
// Status 从机查询离线任务状态
func Status(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
if err != nil {
return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)
}
// SlaveStatus 从机查询离线任务状态
func SlaveStatus(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
caller, _ := c.Get("MasterAria2Instance")
// 查询任务
status, err := caller.Status(service.Task)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err)
}
return serializer.NewResponseWithGobData(status)
// 查询任务
status, err := caller.(common.Aria2).Status(service.Task)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "离线下载任务查询失败", err)
}
return serializer.ParamErr("未知的主机节点ID", nil)
return serializer.NewResponseWithGobData(status)
}
// SlaveCancel 取消从机离线下载任务
func SlaveCancel(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
caller, _ := c.Get("MasterAria2Instance")
// 查询任务
err := caller.(common.Aria2).Cancel(service.Task)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "任务取消失败", err)
}
return serializer.Response{}
}
// SlaveSelect 从机选取离线下载任务文件
func SlaveSelect(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
caller, _ := c.Get("MasterAria2Instance")
// 查询任务
err := caller.(common.Aria2).Select(service.Task, service.Files)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "任务选取失败", err)
}
return serializer.Response{}
}