diff --git a/middleware/auth.go b/middleware/auth.go index 69233ee..135dd8c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -22,16 +22,14 @@ import ( ) // SignRequired 验证请求签名 -func SignRequired() gin.HandlerFunc { +func SignRequired(authInstance auth.Auth) gin.HandlerFunc { return func(c *gin.Context) { var err error switch c.Request.Method { - case "PUT", "POST": - err = auth.CheckRequest(auth.General, c.Request) - // TODO 生产环境去掉下一行 - //err = nil + case "PUT", "POST", "PATCH": + err = auth.CheckRequest(authInstance, c.Request) default: - err = auth.CheckURI(auth.General, c.Request.URL) + err = auth.CheckURI(authInstance, c.Request.URL) } if err != nil { diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 95ab75c..a17602d 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -87,11 +87,10 @@ func TestAuthRequired(t *testing.T) { func TestSignRequired(t *testing.T) { asserts := assert.New(t) - auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request, _ = http.NewRequest("GET", "/test", nil) - SignRequiredFunc := SignRequired() + SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}) // 鉴权失败 SignRequiredFunc(c) diff --git a/middleware/cluster.go b/middleware/cluster.go index 192fa38..8a59c36 100644 --- a/middleware/cluster.go +++ b/middleware/cluster.go @@ -1,16 +1,18 @@ package middleware import ( + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" + "strconv" ) // MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据 func MasterMetadata() gin.HandlerFunc { return func(c *gin.Context) { - c.Set("MasterSiteID", c.GetHeader("X-Site-ID")) - c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur")) + c.Set("MasterSiteID", c.GetHeader("X-Site-Id")) + c.Set("MasterSiteURL", c.GetHeader("X-Site-Url")) c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version")) c.Next() } @@ -37,3 +39,24 @@ func UseSlaveAria2Instance() gin.HandlerFunc { c.Abort() } } + +func SlaveRPCSignRequired() gin.HandlerFunc { + return func(c *gin.Context) { + nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64) + if err != nil { + c.JSON(200, serializer.ParamErr("未知的主机节点ID", err)) + c.Abort() + return + } + + slaveNode := cluster.Default.GetNodeByID(uint(nodeID)) + if slaveNode == nil { + c.JSON(200, serializer.ParamErr("未知的主机节点ID", err)) + c.Abort() + return + } + + SignRequired(slaveNode.GetAuthInstance())(c) + + } +} diff --git a/pkg/balancer/errors.go b/pkg/balancer/errors.go index 5285478..aef7b1f 100644 --- a/pkg/balancer/errors.go +++ b/pkg/balancer/errors.go @@ -3,5 +3,6 @@ package balancer import "errors" var ( - ErrInputNotSlice = errors.New("Input value is not silice") + ErrInputNotSlice = errors.New("Input value is not silice") + ErrNoAvaliableNode = errors.New("No nodes avaliable") ) diff --git a/pkg/balancer/roundrobin.go b/pkg/balancer/roundrobin.go index 26f4ccc..cf300f5 100644 --- a/pkg/balancer/roundrobin.go +++ b/pkg/balancer/roundrobin.go @@ -16,6 +16,10 @@ func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) { return ErrInputNotSlice, nil } + if v.Len() == 0 { + return ErrNoAvaliableNode, nil + } + next := r.NextIndex(v.Len()) return nil, v.Index(next).Interface() } diff --git a/pkg/cluster/master.go b/pkg/cluster/master.go index 491fb63..3d6fff1 100644 --- a/pkg/cluster/master.go +++ b/pkg/cluster/master.go @@ -6,6 +6,7 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/util" "net/url" @@ -75,6 +76,13 @@ func (node *MasterNode) IsFeatureEnabled(feature string) bool { } } +func (node *MasterNode) GetAuthInstance() auth.Auth { + node.lock.RLock() + defer node.lock.RUnlock() + + return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} +} + // SubscribeStatusChange 订阅节点状态更改 func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) { } diff --git a/pkg/cluster/node.go b/pkg/cluster/node.go index 31e637f..9f34f50 100644 --- a/pkg/cluster/node.go +++ b/pkg/cluster/node.go @@ -3,6 +3,7 @@ package cluster import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) @@ -33,6 +34,9 @@ type Node interface { // Returns if current node is master node IsMater() bool + + // Get auth instance used to check RPC call from slave to master + GetAuthInstance() auth.Auth } // Create new node from DB model diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index bc965dc..9893dcf 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -119,7 +119,11 @@ func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer) defer pool.lock.RUnlock() if nodes, ok := pool.featureMap[feature]; ok { err, res := lb.NextPeer(nodes) - return err, res.(Node) + if err == nil { + return nil, res.(Node) + } + + return err, nil } return ErrFeatureNotExist, nil diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 47c52b5..496d741 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -187,6 +187,7 @@ func (node *SlaveNode) StartPingLoop() { util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name) retry := 0 recoverMode := false + isFirstLoop := true loop: for { @@ -197,7 +198,9 @@ loop: } util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name) - res, err := node.Ping(node.getHeartbeatContent(false)) + res, err := node.Ping(node.getHeartbeatContent(isFirstLoop)) + isFirstLoop = false + if err != nil { util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err) retry++ @@ -217,6 +220,7 @@ loop: util.Log().Debug("从机节点 [%s] 复活", node.Model.Name) pingTicker = tickDuration recoverMode = false + isFirstLoop = true } util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res) @@ -234,6 +238,7 @@ loop: // getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq { return &serializer.NodePingReq{ + SiteURL: model.GetSiteURL().String(), IsUpdate: isUpdate, SiteID: model.GetSettingByName("siteID"), Node: node.Model, @@ -245,6 +250,13 @@ func (node *SlaveNode) IsMater() bool { return false } +func (node *SlaveNode) GetAuthInstance() auth.Auth { + node.lock.RLock() + defer node.lock.RUnlock() + + return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)} +} + func (s *slaveCaller) Init() error { return nil } diff --git a/pkg/request/request.go b/pkg/request/request.go index 98a1baa..2908fa6 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -154,7 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio if options.masterMeta { req.Header.Add("X-Site-Url", model.GetSiteURL().String()) - req.Header.Add("X-Site-ID", model.GetSettingByName("siteID")) + req.Header.Add("X-Site-Id", model.GetSettingByName("siteID")) req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) } diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index eac8076..4348158 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -6,11 +6,11 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "net/http" + "net/url" "sync" ) @@ -36,10 +36,10 @@ type slaveController struct { // info of master node type masterInfo struct { - slaveID uint - id string - authClient auth.Auth - ttl int + slaveID uint + id string + ttl int + url *url.URL // used to invoke aria2 rpc calls instance cluster.Node } @@ -66,14 +66,18 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ origin.instance.Kill() } + masterUrl, err := url.Parse(req.SiteURL) + if err != nil { + return serializer.NodePingResp{}, err + } + c.masters[req.SiteID] = masterInfo{ slaveID: req.Node.ID, id: req.SiteID, - authClient: auth.HMACAuth{ - SecretKey: []byte(req.Node.MasterKey), - }, - ttl: req.CredentialTTL, + url: masterUrl, + ttl: req.CredentialTTL, instance: cluster.NewNodeFromDBModel(&model.Node{ + MasterKey: req.Node.MasterKey, Type: model.MasterNodeType, Aria2Enabled: req.Node.Aria2Enabled, Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized, @@ -101,12 +105,14 @@ func (c *slaveController) SendAria2Notification(id string, msg common.StatusEven if node, ok := c.masters[id]; ok { c.lock.RUnlock() + apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status)) + res, err := c.client.Request( "PATCH", - fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status), + node.url.ResolveReference(apiPath).String(), nil, - request.WithHeader(http.Header{"X-Node-ID": []string{fmt.Sprintf("%d", node.slaveID)}}), - request.WithCredential(node.authClient, int64(node.ttl)), + request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.slaveID)}}), + request.WithCredential(node.instance.GetAuthInstance(), int64(node.ttl)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { return err diff --git a/routers/controllers/aria2.go b/routers/controllers/aria2.go index 0de7b8e..0259e4a 100644 --- a/routers/controllers/aria2.go +++ b/routers/controllers/aria2.go @@ -99,7 +99,7 @@ func ListFinished(c *gin.Context) { // TaskUpdate 被动更新任务状态 func TaskUpdate(c *gin.Context) { var service aria2.DownloadTaskService - if err := c.ShouldBindQuery(&service); err == nil { + if err := c.ShouldBindUri(&service); err == nil { res := service.Notify() c.JSON(200, res) } else { diff --git a/routers/router.go b/routers/router.go index e61e729..0735f4c 100644 --- a/routers/router.go +++ b/routers/router.go @@ -2,6 +2,7 @@ package routers import ( "github.com/cloudreve/Cloudreve/v3/middleware" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -29,7 +30,7 @@ func InitSlaveRouter() *gin.Engine { InitCORS(r) v3 := r.Group("/api/v3/slave") // 鉴权中间件 - v3.Use(middleware.SignRequired()) + v3.Use(middleware.SignRequired(auth.General)) // 主机信息解析 v3.Use(middleware.MasterMetadata()) @@ -149,7 +150,7 @@ func InitMasterRouter() *gin.Engine { user.PATCH("reset", controllers.UserReset) // 邮件激活 user.GET("activate/:id", - middleware.SignRequired(), + middleware.SignRequired(auth.General), middleware.HashID(hashid.UserID), controllers.UserActivate, ) @@ -177,7 +178,7 @@ func InitMasterRouter() *gin.Engine { // 需要携带签名验证的 sign := v3.Group("") - sign.Use(middleware.SignRequired()) + sign.Use(middleware.SignRequired(auth.General)) { file := sign.Group("file") { @@ -194,6 +195,7 @@ func InitMasterRouter() *gin.Engine { // 从机的 RPC 通信 slave := v3.Group("slave") + slave.Use(middleware.SlaveRPCSignRequired()) { slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate) } diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 508607f..020053e 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -17,7 +17,7 @@ type SelectFileService struct { // DownloadTaskService 下载任务管理服务 type DownloadTaskService struct { GID string `uri:"gid" binding:"required"` - Status int `uri:"gid"` + Status int `uri:"status"` } // DownloadListService 下载列表服务