Feat: slave aria2 status event callback / salve RPC auth
This commit is contained in:
parent
cf2960a092
commit
870df708bf
14 changed files with 92 additions and 31 deletions
|
@ -22,16 +22,14 @@ import (
|
|||
)
|
||||
|
||||
// SignRequired 验证请求签名
|
||||
func SignRequired() gin.HandlerFunc {
|
||||
func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var err error
|
||||
switch c.Request.Method {
|
||||
case "PUT", "POST":
|
||||
err = auth.CheckRequest(auth.General, c.Request)
|
||||
// TODO 生产环境去掉下一行
|
||||
//err = nil
|
||||
case "PUT", "POST", "PATCH":
|
||||
err = auth.CheckRequest(authInstance, c.Request)
|
||||
default:
|
||||
err = auth.CheckURI(auth.General, c.Request.URL)
|
||||
err = auth.CheckURI(authInstance, c.Request.URL)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -87,11 +87,10 @@ func TestAuthRequired(t *testing.T) {
|
|||
|
||||
func TestSignRequired(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth.General = auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
SignRequiredFunc := SignRequired()
|
||||
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
|
||||
|
||||
// 鉴权失败
|
||||
SignRequiredFunc(c)
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
|
||||
func MasterMetadata() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set("MasterSiteID", c.GetHeader("X-Site-ID"))
|
||||
c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur"))
|
||||
c.Set("MasterSiteID", c.GetHeader("X-Site-Id"))
|
||||
c.Set("MasterSiteURL", c.GetHeader("X-Site-Url"))
|
||||
c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version"))
|
||||
c.Next()
|
||||
}
|
||||
|
@ -37,3 +39,24 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
|
|||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func SlaveRPCSignRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
|
||||
if slaveNode == nil {
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
SignRequired(slaveNode.GetAuthInstance())(c)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,5 +3,6 @@ package balancer
|
|||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInputNotSlice = errors.New("Input value is not silice")
|
||||
ErrInputNotSlice = errors.New("Input value is not silice")
|
||||
ErrNoAvaliableNode = errors.New("No nodes avaliable")
|
||||
)
|
||||
|
|
|
@ -16,6 +16,10 @@ func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
|
|||
return ErrInputNotSlice, nil
|
||||
}
|
||||
|
||||
if v.Len() == 0 {
|
||||
return ErrNoAvaliableNode, nil
|
||||
}
|
||||
|
||||
next := r.NextIndex(v.Len())
|
||||
return nil, v.Index(next).Interface()
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"net/url"
|
||||
|
@ -75,6 +76,13 @@ func (node *MasterNode) IsFeatureEnabled(feature string) bool {
|
|||
}
|
||||
}
|
||||
|
||||
func (node *MasterNode) GetAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package cluster
|
|||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
|
@ -33,6 +34,9 @@ type Node interface {
|
|||
|
||||
// Returns if current node is master node
|
||||
IsMater() bool
|
||||
|
||||
// Get auth instance used to check RPC call from slave to master
|
||||
GetAuthInstance() auth.Auth
|
||||
}
|
||||
|
||||
// Create new node from DB model
|
||||
|
|
|
@ -119,7 +119,11 @@ func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer)
|
|||
defer pool.lock.RUnlock()
|
||||
if nodes, ok := pool.featureMap[feature]; ok {
|
||||
err, res := lb.NextPeer(nodes)
|
||||
return err, res.(Node)
|
||||
if err == nil {
|
||||
return nil, res.(Node)
|
||||
}
|
||||
|
||||
return err, nil
|
||||
}
|
||||
|
||||
return ErrFeatureNotExist, nil
|
||||
|
|
|
@ -187,6 +187,7 @@ func (node *SlaveNode) StartPingLoop() {
|
|||
util.Log().Debug("从机节点 [%s] 启动心跳循环", node.Model.Name)
|
||||
retry := 0
|
||||
recoverMode := false
|
||||
isFirstLoop := true
|
||||
|
||||
loop:
|
||||
for {
|
||||
|
@ -197,7 +198,9 @@ loop:
|
|||
}
|
||||
|
||||
util.Log().Debug("从机节点 [%s] 发送Ping", node.Model.Name)
|
||||
res, err := node.Ping(node.getHeartbeatContent(false))
|
||||
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
|
||||
isFirstLoop = false
|
||||
|
||||
if err != nil {
|
||||
util.Log().Debug("Ping从机节点 [%s] 时发生错误: %s", node.Model.Name, err)
|
||||
retry++
|
||||
|
@ -217,6 +220,7 @@ loop:
|
|||
util.Log().Debug("从机节点 [%s] 复活", node.Model.Name)
|
||||
pingTicker = tickDuration
|
||||
recoverMode = false
|
||||
isFirstLoop = true
|
||||
}
|
||||
|
||||
util.Log().Debug("从机节点 [%s] 状态: %s", node.Model.Name, res)
|
||||
|
@ -234,6 +238,7 @@ loop:
|
|||
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||
return &serializer.NodePingReq{
|
||||
SiteURL: model.GetSiteURL().String(),
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
|
@ -245,6 +250,13 @@ func (node *SlaveNode) IsMater() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (node *SlaveNode) GetAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -154,7 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
|
|||
|
||||
if options.masterMeta {
|
||||
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
|
||||
req.Header.Add("X-Site-ID", model.GetSettingByName("siteID"))
|
||||
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
|
||||
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
|
||||
}
|
||||
|
||||
|
|
|
@ -6,11 +6,11 @@ import (
|
|||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -36,10 +36,10 @@ type slaveController struct {
|
|||
|
||||
// info of master node
|
||||
type masterInfo struct {
|
||||
slaveID uint
|
||||
id string
|
||||
authClient auth.Auth
|
||||
ttl int
|
||||
slaveID uint
|
||||
id string
|
||||
ttl int
|
||||
url *url.URL
|
||||
// used to invoke aria2 rpc calls
|
||||
instance cluster.Node
|
||||
}
|
||||
|
@ -66,14 +66,18 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
|
|||
origin.instance.Kill()
|
||||
}
|
||||
|
||||
masterUrl, err := url.Parse(req.SiteURL)
|
||||
if err != nil {
|
||||
return serializer.NodePingResp{}, err
|
||||
}
|
||||
|
||||
c.masters[req.SiteID] = masterInfo{
|
||||
slaveID: req.Node.ID,
|
||||
id: req.SiteID,
|
||||
authClient: auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
},
|
||||
ttl: req.CredentialTTL,
|
||||
url: masterUrl,
|
||||
ttl: req.CredentialTTL,
|
||||
instance: cluster.NewNodeFromDBModel(&model.Node{
|
||||
MasterKey: req.Node.MasterKey,
|
||||
Type: model.MasterNodeType,
|
||||
Aria2Enabled: req.Node.Aria2Enabled,
|
||||
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
|
||||
|
@ -101,12 +105,14 @@ func (c *slaveController) SendAria2Notification(id string, msg common.StatusEven
|
|||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
apiPath, _ := url.Parse(fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status))
|
||||
|
||||
res, err := c.client.Request(
|
||||
"PATCH",
|
||||
fmt.Sprintf("/api/v3/slave/aria2/%s/%d", msg.GID, msg.Status),
|
||||
node.url.ResolveReference(apiPath).String(),
|
||||
nil,
|
||||
request.WithHeader(http.Header{"X-Node-ID": []string{fmt.Sprintf("%d", node.slaveID)}}),
|
||||
request.WithCredential(node.authClient, int64(node.ttl)),
|
||||
request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.slaveID)}}),
|
||||
request.WithCredential(node.instance.GetAuthInstance(), int64(node.ttl)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -99,7 +99,7 @@ func ListFinished(c *gin.Context) {
|
|||
// TaskUpdate 被动更新任务状态
|
||||
func TaskUpdate(c *gin.Context) {
|
||||
var service aria2.DownloadTaskService
|
||||
if err := c.ShouldBindQuery(&service); err == nil {
|
||||
if err := c.ShouldBindUri(&service); err == nil {
|
||||
res := service.Notify()
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
|
|
|
@ -2,6 +2,7 @@ package routers
|
|||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/middleware"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
|
@ -29,7 +30,7 @@ func InitSlaveRouter() *gin.Engine {
|
|||
InitCORS(r)
|
||||
v3 := r.Group("/api/v3/slave")
|
||||
// 鉴权中间件
|
||||
v3.Use(middleware.SignRequired())
|
||||
v3.Use(middleware.SignRequired(auth.General))
|
||||
// 主机信息解析
|
||||
v3.Use(middleware.MasterMetadata())
|
||||
|
||||
|
@ -149,7 +150,7 @@ func InitMasterRouter() *gin.Engine {
|
|||
user.PATCH("reset", controllers.UserReset)
|
||||
// 邮件激活
|
||||
user.GET("activate/:id",
|
||||
middleware.SignRequired(),
|
||||
middleware.SignRequired(auth.General),
|
||||
middleware.HashID(hashid.UserID),
|
||||
controllers.UserActivate,
|
||||
)
|
||||
|
@ -177,7 +178,7 @@ func InitMasterRouter() *gin.Engine {
|
|||
|
||||
// 需要携带签名验证的
|
||||
sign := v3.Group("")
|
||||
sign.Use(middleware.SignRequired())
|
||||
sign.Use(middleware.SignRequired(auth.General))
|
||||
{
|
||||
file := sign.Group("file")
|
||||
{
|
||||
|
@ -194,6 +195,7 @@ func InitMasterRouter() *gin.Engine {
|
|||
|
||||
// 从机的 RPC 通信
|
||||
slave := v3.Group("slave")
|
||||
slave.Use(middleware.SlaveRPCSignRequired())
|
||||
{
|
||||
slave.PATCH("aria2/:gid/:status", controllers.TaskUpdate)
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ type SelectFileService struct {
|
|||
// DownloadTaskService 下载任务管理服务
|
||||
type DownloadTaskService struct {
|
||||
GID string `uri:"gid" binding:"required"`
|
||||
Status int `uri:"gid"`
|
||||
Status int `uri:"status"`
|
||||
}
|
||||
|
||||
// DownloadListService 下载列表服务
|
||||
|
|
Loading…
Add table
Reference in a new issue