Feat: call slave aria2 rpc method from master

This commit is contained in:
HFO4 2021-08-21 11:06:53 +08:00
parent 8c2affaa12
commit 32b88e989d
13 changed files with 207 additions and 32 deletions

15
middleware/cluster.go Normal file
View file

@ -0,0 +1,15 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
func MasterMetadata() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("MasterSiteID", c.GetHeader("X-Site-ID"))
c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur"))
c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version"))
c.Next()
}
}

View file

@ -5,6 +5,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/fatih/color"
"github.com/gofrs/uuid"
"github.com/jinzhu/gorm"
)
@ -74,6 +75,8 @@ func addDefaultPolicy() {
}
func addDefaultSettings() {
siteID, _ := uuid.NewV4()
defaultSettings := []Setting{
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
@ -84,6 +87,7 @@ func addDefaultSettings() {
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
{Name: "siteScript", Value: ``, Type: "basic"},
{Name: "siteID", Value: siteID.String(), Type: "basic"},
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},

View file

@ -6,6 +6,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"time"
@ -81,6 +82,7 @@ func getSignContent(r *http.Request) (rawSignString string) {
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
}
}
sort.Strings(signedHeader)
// 读取所有待签名Header
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))

View file

@ -93,12 +93,13 @@ func (node *MasterNode) Kill() {
// GetAria2Instance 获取主机Aria2实例
func (node *MasterNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
defer node.lock.RUnlock()
if !node.Model.Aria2Enabled {
return &common.DummyAria2{}
}
node.lock.RLock()
defer node.lock.RUnlock()
if !node.aria2RPC.Initialized {
return &common.DummyAria2{}
}

View file

@ -2,13 +2,14 @@ package cluster
import (
"encoding/json"
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"net/url"
"path"
"strings"
@ -19,20 +20,26 @@ import (
type SlaveNode struct {
Model *model.Node
AuthInstance auth.Auth
Client request.Client
Active bool
caller slaveCaller
callback func(bool, uint)
close chan bool
lock sync.RWMutex
}
type slaveCaller struct {
parent *SlaveNode
Client request.Client
}
// Init 初始化节点
func (node *SlaveNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}
node.Client = request.HTTPClient{}
node.caller.Client = request.HTTPClient{}
node.caller.parent = node
node.Active = true
if node.close != nil {
node.close <- true
@ -44,7 +51,12 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
@ -67,10 +79,12 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := node.Client.Request(
resp, err := node.caller.Client.Request(
"POST",
node.getAPIUrl("heartbeat"),
bodyReader,
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(node.AuthInstance, int64(signTTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
@ -79,7 +93,7 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe
// 处理列取结果
if resp.Code != 0 {
return nil, errors.New(resp.Error)
return nil, serializer.NewErrorFromResponse(resp)
}
var res serializer.NodePingResp
@ -96,6 +110,9 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe
// IsActive 返回节点是否在线
func (node *SlaveNode) IsActive() bool {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Active
}
@ -111,7 +128,14 @@ func (node *SlaveNode) Kill() {
// GetAria2Instance 获取从机Aria2实例
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
return nil
node.lock.RLock()
defer node.lock.RUnlock()
if !node.Model.Aria2Enabled {
return &common.DummyAria2{}
}
return &node.caller
}
func (node *SlaveNode) ID() uint {
@ -210,8 +234,79 @@ loop:
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{
IsUpdate: isUpdate,
MasterURL: model.GetSiteURL().String(),
Node: node.Model,
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
}
}
func (s *slaveCaller) Init() error {
return nil
}
// SendAria2Call send remote aria2 call to slave node
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
reqReader, err := getAria2RequestBody(body)
if err != nil {
return nil, err
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
return s.Client.Request(
"POST",
s.parent.getAPIUrl("aria2/"+scope),
reqReader,
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(s.parent.AuthInstance, int64(signTTL)),
).CheckHTTPResponse(200).DecodeResponse()
}
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
GroupOptions: options,
}
res, err := s.SendAria2Call(req, "task")
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), err
}
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
panic("implement me")
}
func (s *slaveCaller) Cancel(task *model.Download) error {
panic("implement me")
}
func (s *slaveCaller) Select(task *model.Download, files []int) error {
panic("implement me")
}
func (s *slaveCaller) GetConfig() model.Aria2Option {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
return s.parent.Model.Aria2OptionsSerialized
}
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
reqBodyEncoded, err := json.Marshal(body)
if err != nil {
return nil, err
}
return strings.NewReader(string(reqBodyEncoded)), nil
}

View file

@ -154,6 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
if options.masterMeta {
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
req.Header.Add("X-Site-ID", model.GetSettingByName("siteID"))
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
}

View file

@ -1,6 +1,9 @@
package serializer
import "github.com/gin-gonic/gin"
import (
"errors"
"github.com/gin-gonic/gin"
)
// Response 基础序列化器
type Response struct {
@ -17,7 +20,7 @@ type AppError struct {
RawError error
}
// NewError 返回新的错误对象 todo:测试 还有下面的
// NewError 返回新的错误对象
func NewError(code int, msg string, err error) AppError {
return AppError{
Code: code,
@ -26,6 +29,15 @@ func NewError(code int, msg string, err error) AppError {
}
}
// NewErrorFromResponse 从 serializer.Response 构建错误
func NewErrorFromResponse(resp *Response) AppError {
return AppError{
Code: resp.Code,
Msg: resp.Msg,
RawError: errors.New(resp.Error),
}
}
// WithError 将应用error携带标准库中的error
func (err *AppError) WithError(raw error) AppError {
err.RawError = raw
@ -66,6 +78,8 @@ const (
CodeGroupNotAllowed = 40007
// CodeAdminRequired 非管理用户组
CodeAdminRequired = 40008
// CodeMasterNotFound 主机节点未注册
CodeMasterNotFound = 40009
// CodeDBError 数据库操作失败
CodeDBError = 50001
// CodeEncryptError 加密失败

View file

@ -15,11 +15,19 @@ type ListRequest struct {
// NodePingReq 从机节点Ping请求
type NodePingReq struct {
MasterURL string `json:"master_url"`
IsUpdate bool `json:"is_update"`
Node *model.Node `json:"node"`
SiteURL string `json:"site_url"`
SiteID string `json:"site_id"`
IsUpdate bool `json:"is_update"`
Node *model.Node `json:"node"`
}
// NodePingResp 从机节点Ping响应
type NodePingResp struct {
}
// SlaveAria2Call 从机有关Aria2的请求正文
type SlaveAria2Call struct {
Task *model.Download `json:"task"`
GroupOptions map[string]interface{} `json:"group_options"`
Files []uint `json:"files"`
}

7
pkg/slave/errors.go Normal file
View file

@ -0,0 +1,7 @@
package slave
import "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
var (
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
)

View file

@ -2,6 +2,7 @@ package slave
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -14,6 +15,9 @@ var DefaultController Controller
type Controller interface {
// Handle heartbeat sent from master
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
// Get Aria2 instance by master node id
GetAria2Instance(string) (common.Aria2, error)
}
type slaveController struct {
@ -24,7 +28,7 @@ type slaveController struct {
// info of master node
type masterInfo struct {
slaveID uint
url string
id string
authClient auth.Auth
// used to invoke aria2 rpc calls
instance cluster.Node
@ -43,16 +47,16 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
req.Node.AfterFind()
// close old node if exist
origin, ok := c.masters[req.MasterURL]
origin, ok := c.masters[req.SiteID]
if (ok && req.IsUpdate) || !ok {
if ok {
origin.instance.Kill()
}
c.masters[req.MasterURL] = masterInfo{
c.masters[req.SiteID] = masterInfo{
slaveID: req.Node.ID,
url: req.MasterURL,
id: req.SiteID,
authClient: auth.HMACAuth{
SecretKey: []byte(req.Node.MasterKey),
},
@ -66,3 +70,14 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
return serializer.NodePingResp{}, nil
}
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return node.instance.GetAria2Instance(), nil
}
return nil, ErrMasterNotFound
}

View file

@ -191,9 +191,9 @@ func SlaveHeartbeat(c *gin.Context) {
// SlaveAria2Create 创建 Aria2 任务
func SlaveAria2Create(c *gin.Context) {
var service aria2.SlaveAria2Call
var service serializer.SlaveAria2Call
if err := c.ShouldBindJSON(&service); err == nil {
res := service.Add(c)
res := aria2.Add(c, &service)
c.JSON(200, res)
} else {
c.JSON(200, ErrorResponse(err))

View file

@ -30,6 +30,8 @@ func InitSlaveRouter() *gin.Engine {
v3 := r.Group("/api/v3/slave")
// 鉴权中间件
v3.Use(middleware.SignRequired())
// 主机信息解析
v3.Use(middleware.MasterMetadata())
/*
路由
@ -55,7 +57,7 @@ func InitSlaveRouter() *gin.Engine {
// 离线下载
aria2 := v3.Group("aria2")
{
aria2.POST("task", controllers.SlaveList)
aria2.POST("task", controllers.SlaveAria2Create)
}
}
return r

View file

@ -8,6 +8,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
"github.com/gin-gonic/gin"
)
@ -17,13 +18,6 @@ type AddURLService struct {
Dst string `json:"dst" binding:"required,min=1"`
}
// SlaveAria2Call 从机有关Aria2的请求正文
type SlaveAria2Call struct {
Task *model.Download `json:"task"`
GroupOptions map[string]interface{} `json:"group_options"`
Files []uint `json:"files"`
}
// Add 主机创建新的链接离线下载任务
func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response {
// 创建文件系统
@ -83,6 +77,23 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
}
// Add 从机创建新的链接离线下载任务
func (service *SlaveAria2Call) Add(c *gin.Context) serializer.Response {
return serializer.Response{}
func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
if siteID, exist := c.Get("MasterSiteID"); exist {
// 获取对应主机节点的从机Aria2实例
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
if err != nil {
return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)
}
// 创建任务
gid, err := caller.CreateTask(service.Task, service.GroupOptions)
if err != nil {
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
}
// TODO: 创建监控
return serializer.Response{Data: gid}
}
return serializer.ParamErr("未知的主机节点ID", nil)
}