Feat: slave aria2 status event callback / salve RPC auth

This commit is contained in:
HFO4 2021-08-29 20:31:37 +08:00
parent cf2960a092
commit 870df708bf
14 changed files with 92 additions and 31 deletions

View file

@ -22,16 +22,14 @@ import (
)
// SignRequired 验证请求签名
func SignRequired() gin.HandlerFunc {
func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
return func(c *gin.Context) {
var err error
switch c.Request.Method {
case "PUT", "POST":
err = auth.CheckRequest(auth.General, c.Request)
// TODO 生产环境去掉下一行
//err = nil
case "PUT", "POST", "PATCH":
err = auth.CheckRequest(authInstance, c.Request)
default:
err = auth.CheckURI(auth.General, c.Request.URL)
err = auth.CheckURI(authInstance, c.Request.URL)
}
if err != nil {

View file

@ -87,11 +87,10 @@ func TestAuthRequired(t *testing.T) {
func TestSignRequired(t *testing.T) {
asserts := assert.New(t)
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request, _ = http.NewRequest("GET", "/test", nil)
SignRequiredFunc := SignRequired()
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
// 鉴权失败
SignRequiredFunc(c)

View file

@ -1,16 +1,18 @@
package middleware
import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
"strconv"
)
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
func MasterMetadata() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("MasterSiteID", c.GetHeader("X-Site-ID"))
c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur"))
c.Set("MasterSiteID", c.GetHeader("X-Site-Id"))
c.Set("MasterSiteURL", c.GetHeader("X-Site-Url"))
c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version"))
c.Next()
}
@ -37,3 +39,24 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
c.Abort()
}
}
func SlaveRPCSignRequired() gin.HandlerFunc {
return func(c *gin.Context) {
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
if err != nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
if slaveNode == nil {
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
c.Abort()
return
}
SignRequired(slaveNode.GetAuthInstance())(c)
}
}

View file

@ -3,5 +3,6 @@ package balancer
import "errors"
var (
ErrInputNotSlice = errors.New("Input value is not silice")
ErrInputNotSlice = errors.New("Input value is not silice")
ErrNoAvaliableNode = errors.New("No nodes avaliable")
)

View file

@ -16,6 +16,10 @@ func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
return ErrInputNotSlice, nil
}
if v.Len() == 0 {
return ErrNoAvaliableNode, nil
}
next := r.NextIndex(v.Len())
return nil, v.Index(next).Interface()
}

View file

@ -6,6 +6,7 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/url"
@ -75,6 +76,13 @@ func (node *MasterNode) IsFeatureEnabled(feature string) bool {
}
}
func (node *MasterNode) GetAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
}

View file

@ -3,6 +3,7 @@ package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
@ -33,6 +34,9 @@ type Node interface {
// Returns if current node is master node
IsMater() bool
// Get auth instance used to check RPC call from slave to master
GetAuthInstance() auth.Auth
}
// Create new node from DB model

View file

@ -119,7 +119,11 @@ func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer)
defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok {
err, res := lb.NextPeer(nodes)
return err, res.(Node)
if err == nil {
return nil, res.(Node)
}
return err, nil
}
return ErrFeatureNotExist, nil

View file

@ -187,6 +187,7 @@ func (node *SlaveNode) StartPingLoop() {
util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name)
retry := 0
recoverMode := false
isFirstLoop := true
loop:
for {
@ -197,7 +198,9 @@ loop:
}
util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name)
res, err := node.Ping(node.getHeartbeatContent(false))
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
isFirstLoop = false
if err != nil {
util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err)
retry++
@ -217,6 +220,7 @@ loop:
util.Log().Debug("从机节点 [%s] 复活", node.Model.Name)
pingTicker = tickDuration
recoverMode = false
isFirstLoop = true
}
util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res)
@ -234,6 +238,7 @@ loop:
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{
SiteURL: model.GetSiteURL().String(),
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
@ -245,6 +250,13 @@ func (node *SlaveNode) IsMater() bool {
return false
}
func (node *SlaveNode) GetAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (s *slaveCaller) Init() error {
return nil
}

View file

@ -154,7 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
if options.masterMeta {
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
req.Header.Add("X-Site-ID", model.GetSettingByName("siteID"))
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
}

View file

@ -6,11 +6,11 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"net/http"
"net/url"
"sync"
)
@ -36,10 +36,10 @@ type slaveController struct {
// info of master node
type masterInfo struct {
slaveID uint
id string
authClient auth.Auth
ttl int
slaveID uint
id string
ttl int
url *url.URL
// used to invoke aria2 rpc calls
instance cluster.Node
}
@ -66,14 +66,18 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
origin.instance.Kill()
}
masterUrl, err := url.Parse(req.SiteURL)
if err != nil {
return serializer.NodePingResp{}, err
}
c.masters[req.SiteID] = masterInfo{
slaveID: req.Node.ID,
id: req.SiteID,
authClient: auth.HMACAuth{
SecretKey: []byte(req.Node.MasterKey),
},
ttl: req.CredentialTTL,
url: masterUrl,
ttl: req.CredentialTTL,
instance: cluster.NewNodeFromDBModel(&model.Node{
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled,
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
@ -101,12 +105,14 @@ func (c *slaveController) SendAria2Notification(id string, msg common.StatusEven
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status))
res, err := c.client.Request(
"PATCH",
fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status),
node.url.ResolveReference(apiPath).String(),
nil,
request.WithHeader(http.Header{"X-Node-ID": []string{fmt.Sprintf("%d", node.slaveID)}}),
request.WithCredential(node.authClient, int64(node.ttl)),
request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.slaveID)}}),
request.WithCredential(node.instance.GetAuthInstance(), int64(node.ttl)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err

View file

@ -99,7 +99,7 @@ func ListFinished(c *gin.Context) {
// TaskUpdate 被动更新任务状态
func TaskUpdate(c *gin.Context) {
var service aria2.DownloadTaskService
if err := c.ShouldBindQuery(&service); err == nil {
if err := c.ShouldBindUri(&service); err == nil {
res := service.Notify()
c.JSON(200, res)
} else {

View file

@ -2,6 +2,7 @@ package routers
import (
"github.com/cloudreve/Cloudreve/v3/middleware"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
@ -29,7 +30,7 @@ func InitSlaveRouter() *gin.Engine {
InitCORS(r)
v3 := r.Group("/api/v3/slave")
// 鉴权中间件
v3.Use(middleware.SignRequired())
v3.Use(middleware.SignRequired(auth.General))
// 主机信息解析
v3.Use(middleware.MasterMetadata())
@ -149,7 +150,7 @@ func InitMasterRouter() *gin.Engine {
user.PATCH("reset", controllers.UserReset)
// 邮件激活
user.GET("activate/:id",
middleware.SignRequired(),
middleware.SignRequired(auth.General),
middleware.HashID(hashid.UserID),
controllers.UserActivate,
)
@ -177,7 +178,7 @@ func InitMasterRouter() *gin.Engine {
// 需要携带签名验证的
sign := v3.Group("")
sign.Use(middleware.SignRequired())
sign.Use(middleware.SignRequired(auth.General))
{
file := sign.Group("file")
{
@ -194,6 +195,7 @@ func InitMasterRouter() *gin.Engine {
// 从机的 RPC 通信
slave := v3.Group("slave")
slave.Use(middleware.SlaveRPCSignRequired())
{
slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate)
}

View file

@ -17,7 +17,7 @@ type SelectFileService struct {
// DownloadTaskService 下载任务管理服务
type DownloadTaskService struct {
GID string `uri:"gid" binding:"required"`
Status int `uri:"gid"`
Status int `uri:"status"`
}
// DownloadListService 下载列表服务