diff --git a/pkg/aria2/caller.go b/pkg/aria2/caller.go index f3d58e3..70e0bea 100644 --- a/pkg/aria2/caller.go +++ b/pkg/aria2/caller.go @@ -8,8 +8,8 @@ import ( "time" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -34,7 +34,7 @@ func (client *RPCService) Init(server, secret string, timeout int, options map[s Options: options, } caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, - common.EventNotifier) + mq.GlobalMQ) client.Caller = caller return err } diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index 3eaa746..211ccc5 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -105,6 +105,3 @@ func GetStatus(status string) int { return Unknown } } - -// EventNotifier 任务状态更新通知处理器 -var EventNotifier = &Notifier{} diff --git a/pkg/aria2/common/notification.go b/pkg/aria2/common/notification.go deleted file mode 100644 index 3804bea..0000000 --- a/pkg/aria2/common/notification.go +++ /dev/null @@ -1,84 +0,0 @@ -package common - -import ( - "sync" - - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" -) - -// Notifier aria2事件通知处理 -type Notifier struct { - Subscribes sync.Map -} - -type CallbackFunc func(StatusEvent) - -// Subscribe 订阅事件通知 -func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) { - notifier.Subscribes.Store(gid, target) -} - -// Subscribe 订阅事件通知回调 -func (notifier *Notifier) SubscribeCallback(callback CallbackFunc, gid string) { - notifier.Subscribes.Store(gid, callback) -} - -// Unsubscribe 取消订阅事件通知 -func (notifier *Notifier) Unsubscribe(gid string) { - notifier.Subscribes.Delete(gid) -} - -// Notify 发送通知 -func (notifier *Notifier) Notify(events []rpc.Event, status int) { - for _, event := range events { - if target, ok := notifier.Subscribes.Load(event.Gid); ok { - msg := StatusEvent{ - GID: event.Gid, - Status: status, - } - - if callback, ok := target.(CallbackFunc); ok { - go callback(msg) - } else { - target.(chan StatusEvent) <- msg - } - - } - } -} - -// OnDownloadStart 下载开始 -func (notifier *Notifier) OnDownloadStart(events []rpc.Event) { - notifier.Notify(events, Downloading) -} - -// OnDownloadPause 下载暂停 -func (notifier *Notifier) OnDownloadPause(events []rpc.Event) { - notifier.Notify(events, Paused) -} - -// OnDownloadStop 下载停止 -func (notifier *Notifier) OnDownloadStop(events []rpc.Event) { - notifier.Notify(events, Canceled) -} - -// OnDownloadComplete 下载完成 -func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) { - notifier.Notify(events, Complete) -} - -// OnDownloadError 下载出错 -func (notifier *Notifier) OnDownloadError(events []rpc.Event) { - notifier.Notify(events, Error) -} - -// OnBtDownloadComplete BT下载完成 -func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) { - notifier.Notify(events, Complete) -} - -// StatusEvent 状态改变事件 -type StatusEvent struct { - GID string - Status int -} diff --git a/pkg/aria2/monitor/monitor.go b/pkg/aria2/monitor/monitor.go index 08d3d9f..115bdf1 100644 --- a/pkg/aria2/monitor/monitor.go +++ b/pkg/aria2/monitor/monitor.go @@ -16,6 +16,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -25,7 +26,7 @@ type Monitor struct { Task *model.Download Interval time.Duration - notifier chan common.StatusEvent + notifier <-chan mq.Message node cluster.Node retried int } @@ -36,20 +37,20 @@ var MAX_RETRY = 10 func NewMonitor(task *model.Download) { monitor := &Monitor{ Task: task, - notifier: make(chan common.StatusEvent), + notifier: make(chan mq.Message), node: cluster.Default.GetNodeByID(task.GetNodeID()), } if monitor.node != nil { monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second go monitor.Loop() - common.EventNotifier.Subscribe(monitor.notifier, monitor.Task.GID) + monitor.notifier = mq.GlobalMQ.Subscribe(monitor.Task.GID, 0) } } // Loop 开启监控循环 func (monitor *Monitor) Loop() { - defer common.EventNotifier.Unsubscribe(monitor.Task.GID) + defer mq.GlobalMQ.Unsubscribe(monitor.Task.GID, monitor.notifier) fmt.Println(cluster.Default) // 首次循环立即更新 diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index 28f1a3f..f28ba11 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -7,6 +7,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "net/url" @@ -168,7 +169,7 @@ func (r *rpcService) Init() error { Options: globalOptions, } timeout := r.parent.Model.Aria2OptionsSerialized.Timeout - caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, common.EventNotifier) + caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ) r.Caller = caller r.Initialized = err == nil diff --git a/pkg/mq/mq.go b/pkg/mq/mq.go index 0aa8f36..e7a8a34 100644 --- a/pkg/mq/mq.go +++ b/pkg/mq/mq.go @@ -51,7 +51,7 @@ func NewMQ() MQ { func init() { gob.Register(Message{}) - gob.Register(rpc.Event{}) + gob.Register([]rpc.Event{}) } type inMemoryMQ struct { diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index 8f7f7b9..e062ef0 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -1,12 +1,14 @@ package slave import ( + "bytes" "encoding/gob" "fmt" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" @@ -26,7 +28,7 @@ type Controller interface { GetAria2Instance(string) (common.Aria2, error) // Send event change message to master node - SendAria2Notification(string, common.StatusEvent) error + SendNotification(string, string, mq.Message) error // Submit async task into task pool SubmitTask(string, task.Job, string) error @@ -105,18 +107,23 @@ func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) { return nil, ErrMasterNotFound } -func (c *slaveController) SendAria2Notification(id string, msg common.StatusEvent) error { +func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error { c.lock.RLock() if node, ok := c.masters[id]; ok { c.lock.RUnlock() - apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status)) + apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/notification/%s", subject)) + body := bytes.Buffer{} + enc := gob.NewEncoder(&body) + if err := enc.Encode(&msg); err != nil { + return err + } res, err := c.client.Request( - "PATCH", + "PUT", node.url.ResolveReference(apiPath).String(), - nil, + &body, request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.slaveID)}}), request.WithCredential(node.instance.MasterAuthInstance(), int64(node.ttl)), ).CheckHTTPResponse(200).DecodeResponse() diff --git a/routers/controllers/aria2.go b/routers/controllers/aria2.go index 0259e4a..25a8fb0 100644 --- a/routers/controllers/aria2.go +++ b/routers/controllers/aria2.go @@ -95,14 +95,3 @@ func ListFinished(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } - -// TaskUpdate 被动更新任务状态 -func TaskUpdate(c *gin.Context) { - var service aria2.DownloadTaskService - if err := c.ShouldBindUri(&service); err == nil { - res := service.Notify() - c.JSON(200, res) - } else { - c.JSON(200, ErrorResponse(err)) - } -} diff --git a/routers/controllers/slave.go b/routers/controllers/slave.go index c766fa2..cc64bf6 100644 --- a/routers/controllers/slave.go +++ b/routers/controllers/slave.go @@ -243,3 +243,14 @@ func SlaveCreateTransferTask(c *gin.Context) { c.JSON(200, ErrorResponse(err)) } } + +// SlaveNotificationPush 处理从机发送的消息推送 +func SlaveNotificationPush(c *gin.Context) { + var service node.SlaveNotificationService + if err := c.ShouldBindUri(&service); err == nil { + res := service.HandleSlaveNotificationPush(c) + c.JSON(200, res) + } else { + c.JSON(200, ErrorResponse(err)) + } +} diff --git a/routers/router.go b/routers/router.go index f1ed77b..92e4133 100644 --- a/routers/router.go +++ b/routers/router.go @@ -203,7 +203,7 @@ func InitMasterRouter() *gin.Engine { slave := v3.Group("slave") slave.Use(middleware.SlaveRPCSignRequired()) { - slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate) + slave.PUT("notification/:subject", controllers.SlaveNotificationPush) } // 回调接口 diff --git a/service/aria2/add.go b/service/aria2/add.go index bada929..26b6baa 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -7,6 +7,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -89,11 +90,11 @@ func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response // 创建事件通知回调 siteID, _ := c.Get("MasterSiteID") - common.EventNotifier.SubscribeCallback(func(event common.StatusEvent) { - if err := slave.DefaultController.SendAria2Notification(siteID.(string), event); err != nil { + mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) { + if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err) } - }, gid) + }) return serializer.Response{Data: gid} } diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 020053e..b8877a1 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -3,7 +3,6 @@ package aria2 import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" - "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/gin-gonic/gin" @@ -16,8 +15,7 @@ type SelectFileService struct { // DownloadTaskService 下载任务管理服务 type DownloadTaskService struct { - GID string `uri:"gid" binding:"required"` - Status int `uri:"status"` + GID string `uri:"gid" binding:"required"` } // DownloadListService 下载列表服务 @@ -67,12 +65,6 @@ func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { return serializer.Response{} } -// Notify 转发通知任务更新 -func (service *DownloadTaskService) Notify() serializer.Response { - common.EventNotifier.Notify([]rpc.Event{{service.GID}}, service.Status) - return serializer.Response{} -} - // Select 选取要下载的文件 func (service *SelectFileService) Select(c *gin.Context) serializer.Response { userCtx, _ := c.Get("user") diff --git a/service/node/fabric.go b/service/node/fabric.go index 959e046..6904bb0 100644 --- a/service/node/fabric.go +++ b/service/node/fabric.go @@ -1,10 +1,17 @@ package node import ( + "encoding/gob" + "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/slave" + "github.com/gin-gonic/gin" ) +type SlaveNotificationService struct { + Subject string `uri:"subject" binding:"required"` +} + func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { res, err := slave.DefaultController.HandleHeartBeat(req) if err != nil { @@ -16,3 +23,15 @@ func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { Data: res, } } + +// HandleSlaveNotificationPush 转发从机的消息通知到本机消息队列 +func (s *SlaveNotificationService) HandleSlaveNotificationPush(c *gin.Context) serializer.Response { + var msg mq.Message + dec := gob.NewDecoder(c.Request.Body) + if err := dec.Decode(&msg); err != nil { + return serializer.ParamErr("无法解析通知消息", err) + } + + mq.GlobalMQ.Publish(s.Subject, msg) + return serializer.Response{} +}