Feat: call slave aria2 rpc method from master
This commit is contained in:
parent
8c2affaa12
commit
32b88e989d
13 changed files with 207 additions and 32 deletions
15
middleware/cluster.go
Normal file
15
middleware/cluster.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MasterMetadata 解析主机节点发来请求的包含主机节点信息的元数据
|
||||
func MasterMetadata() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set("MasterSiteID", c.GetHeader("X-Site-ID"))
|
||||
c.Set("MasterSiteURL", c.GetHeader("X-Site-Ur"))
|
||||
c.Set("MasterVersion", c.GetHeader("X-Cloudreve-Version"))
|
||||
c.Next()
|
||||
}
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
|
@ -74,6 +75,8 @@ func addDefaultPolicy() {
|
|||
}
|
||||
|
||||
func addDefaultSettings() {
|
||||
siteID, _ := uuid.NewV4()
|
||||
|
||||
defaultSettings := []Setting{
|
||||
{Name: "siteURL", Value: `http://localhost`, Type: "basic"},
|
||||
{Name: "siteName", Value: `Cloudreve`, Type: "basic"},
|
||||
|
@ -84,6 +87,7 @@ func addDefaultSettings() {
|
|||
{Name: "siteDes", Value: `Cloudreve`, Type: "basic"},
|
||||
{Name: "siteTitle", Value: `平步云端`, Type: "basic"},
|
||||
{Name: "siteScript", Value: ``, Type: "basic"},
|
||||
{Name: "siteID", Value: siteID.String(), Type: "basic"},
|
||||
{Name: "fromName", Value: `Cloudreve`, Type: "mail"},
|
||||
{Name: "mail_keepalive", Value: `30`, Type: "mail"},
|
||||
{Name: "fromAdress", Value: `no-reply@acg.blue`, Type: "mail"},
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -81,6 +82,7 @@ func getSignContent(r *http.Request) (rawSignString string) {
|
|||
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
|
||||
}
|
||||
}
|
||||
sort.Strings(signedHeader)
|
||||
|
||||
// 读取所有待签名Header
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
|
||||
|
|
|
@ -93,12 +93,13 @@ func (node *MasterNode) Kill() {
|
|||
|
||||
// GetAria2Instance 获取主机Aria2实例
|
||||
func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
if !node.aria2RPC.Initialized {
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
|
|
@ -2,13 +2,14 @@ package cluster
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
@ -19,20 +20,26 @@ import (
|
|||
type SlaveNode struct {
|
||||
Model *model.Node
|
||||
AuthInstance auth.Auth
|
||||
Client request.Client
|
||||
Active bool
|
||||
|
||||
caller slaveCaller
|
||||
callback func(bool, uint)
|
||||
close chan bool
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type slaveCaller struct {
|
||||
parent *SlaveNode
|
||||
Client request.Client
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *SlaveNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}
|
||||
node.Client = request.HTTPClient{}
|
||||
node.caller.Client = request.HTTPClient{}
|
||||
node.caller.parent = node
|
||||
node.Active = true
|
||||
if node.close != nil {
|
||||
node.close <- true
|
||||
|
@ -44,7 +51,12 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
|
|||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
@ -67,10 +79,12 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe
|
|||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
|
||||
resp, err := node.Client.Request(
|
||||
resp, err := node.caller.Client.Request(
|
||||
"POST",
|
||||
node.getAPIUrl("heartbeat"),
|
||||
bodyReader,
|
||||
request.WithMasterMeta(),
|
||||
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||
request.WithCredential(node.AuthInstance, int64(signTTL)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
|
@ -79,7 +93,7 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe
|
|||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, errors.New(resp.Error)
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
var res serializer.NodePingResp
|
||||
|
@ -96,6 +110,9 @@ func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingRe
|
|||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *SlaveNode) IsActive() bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Active
|
||||
}
|
||||
|
||||
|
@ -111,7 +128,14 @@ func (node *SlaveNode) Kill() {
|
|||
|
||||
// GetAria2Instance 获取从机Aria2实例
|
||||
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
|
||||
return nil
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
return &node.caller
|
||||
}
|
||||
|
||||
func (node *SlaveNode) ID() uint {
|
||||
|
@ -210,8 +234,79 @@ loop:
|
|||
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||
return &serializer.NodePingReq{
|
||||
IsUpdate: isUpdate,
|
||||
MasterURL: model.GetSiteURL().String(),
|
||||
Node: node.Model,
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendAria2Call send remote aria2 call to slave node
|
||||
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
|
||||
reqReader, err := getAria2RequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
return s.Client.Request(
|
||||
"POST",
|
||||
s.parent.getAPIUrl("aria2/"+scope),
|
||||
reqReader,
|
||||
request.WithMasterMeta(),
|
||||
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||
request.WithCredential(s.parent.AuthInstance, int64(signTTL)),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
}
|
||||
|
||||
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
GroupOptions: options,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "task")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Cancel(task *model.Download) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Select(task *model.Download, files []int) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (s *slaveCaller) GetConfig() model.Aria2Option {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
return s.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
|
||||
reqBodyEncoded, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return strings.NewReader(string(reqBodyEncoded)), nil
|
||||
}
|
||||
|
|
|
@ -154,6 +154,7 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
|
|||
|
||||
if options.masterMeta {
|
||||
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
|
||||
req.Header.Add("X-Site-ID", model.GetSettingByName("siteID"))
|
||||
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package serializer
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 基础序列化器
|
||||
type Response struct {
|
||||
|
@ -17,7 +20,7 @@ type AppError struct {
|
|||
RawError error
|
||||
}
|
||||
|
||||
// NewError 返回新的错误对象 todo:测试 还有下面的
|
||||
// NewError 返回新的错误对象
|
||||
func NewError(code int, msg string, err error) AppError {
|
||||
return AppError{
|
||||
Code: code,
|
||||
|
@ -26,6 +29,15 @@ func NewError(code int, msg string, err error) AppError {
|
|||
}
|
||||
}
|
||||
|
||||
// NewErrorFromResponse 从 serializer.Response 构建错误
|
||||
func NewErrorFromResponse(resp *Response) AppError {
|
||||
return AppError{
|
||||
Code: resp.Code,
|
||||
Msg: resp.Msg,
|
||||
RawError: errors.New(resp.Error),
|
||||
}
|
||||
}
|
||||
|
||||
// WithError 将应用error携带标准库中的error
|
||||
func (err *AppError) WithError(raw error) AppError {
|
||||
err.RawError = raw
|
||||
|
@ -66,6 +78,8 @@ const (
|
|||
CodeGroupNotAllowed = 40007
|
||||
// CodeAdminRequired 非管理用户组
|
||||
CodeAdminRequired = 40008
|
||||
// CodeMasterNotFound 主机节点未注册
|
||||
CodeMasterNotFound = 40009
|
||||
// CodeDBError 数据库操作失败
|
||||
CodeDBError = 50001
|
||||
// CodeEncryptError 加密失败
|
||||
|
|
|
@ -15,11 +15,19 @@ type ListRequest struct {
|
|||
|
||||
// NodePingReq 从机节点Ping请求
|
||||
type NodePingReq struct {
|
||||
MasterURL string `json:"master_url"`
|
||||
IsUpdate bool `json:"is_update"`
|
||||
Node *model.Node `json:"node"`
|
||||
SiteURL string `json:"site_url"`
|
||||
SiteID string `json:"site_id"`
|
||||
IsUpdate bool `json:"is_update"`
|
||||
Node *model.Node `json:"node"`
|
||||
}
|
||||
|
||||
// NodePingResp 从机节点Ping响应
|
||||
type NodePingResp struct {
|
||||
}
|
||||
|
||||
// SlaveAria2Call 从机有关Aria2的请求正文
|
||||
type SlaveAria2Call struct {
|
||||
Task *model.Download `json:"task"`
|
||||
GroupOptions map[string]interface{} `json:"group_options"`
|
||||
Files []uint `json:"files"`
|
||||
}
|
||||
|
|
7
pkg/slave/errors.go
Normal file
7
pkg/slave/errors.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package slave
|
||||
|
||||
import "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
|
||||
var (
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
|
||||
)
|
|
@ -2,6 +2,7 @@ package slave
|
|||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
|
@ -14,6 +15,9 @@ var DefaultController Controller
|
|||
type Controller interface {
|
||||
// Handle heartbeat sent from master
|
||||
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
|
||||
|
||||
// Get Aria2 instance by master node id
|
||||
GetAria2Instance(string) (common.Aria2, error)
|
||||
}
|
||||
|
||||
type slaveController struct {
|
||||
|
@ -24,7 +28,7 @@ type slaveController struct {
|
|||
// info of master node
|
||||
type masterInfo struct {
|
||||
slaveID uint
|
||||
url string
|
||||
id string
|
||||
authClient auth.Auth
|
||||
// used to invoke aria2 rpc calls
|
||||
instance cluster.Node
|
||||
|
@ -43,16 +47,16 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
|
|||
req.Node.AfterFind()
|
||||
|
||||
// close old node if exist
|
||||
origin, ok := c.masters[req.MasterURL]
|
||||
origin, ok := c.masters[req.SiteID]
|
||||
|
||||
if (ok && req.IsUpdate) || !ok {
|
||||
if ok {
|
||||
origin.instance.Kill()
|
||||
}
|
||||
|
||||
c.masters[req.MasterURL] = masterInfo{
|
||||
c.masters[req.SiteID] = masterInfo{
|
||||
slaveID: req.Node.ID,
|
||||
url: req.MasterURL,
|
||||
id: req.SiteID,
|
||||
authClient: auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
},
|
||||
|
@ -66,3 +70,14 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
|
|||
|
||||
return serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return node.instance.GetAria2Instance(), nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
|
|
@ -191,9 +191,9 @@ func SlaveHeartbeat(c *gin.Context) {
|
|||
|
||||
// SlaveAria2Create 创建 Aria2 任务
|
||||
func SlaveAria2Create(c *gin.Context) {
|
||||
var service aria2.SlaveAria2Call
|
||||
var service serializer.SlaveAria2Call
|
||||
if err := c.ShouldBindJSON(&service); err == nil {
|
||||
res := service.Add(c)
|
||||
res := aria2.Add(c, &service)
|
||||
c.JSON(200, res)
|
||||
} else {
|
||||
c.JSON(200, ErrorResponse(err))
|
||||
|
|
|
@ -30,6 +30,8 @@ func InitSlaveRouter() *gin.Engine {
|
|||
v3 := r.Group("/api/v3/slave")
|
||||
// 鉴权中间件
|
||||
v3.Use(middleware.SignRequired())
|
||||
// 主机信息解析
|
||||
v3.Use(middleware.MasterMetadata())
|
||||
|
||||
/*
|
||||
路由
|
||||
|
@ -55,7 +57,7 @@ func InitSlaveRouter() *gin.Engine {
|
|||
// 离线下载
|
||||
aria2 := v3.Group("aria2")
|
||||
{
|
||||
aria2.POST("task", controllers.SlaveList)
|
||||
aria2.POST("task", controllers.SlaveAria2Create)
|
||||
}
|
||||
}
|
||||
return r
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -17,13 +18,6 @@ type AddURLService struct {
|
|||
Dst string `json:"dst" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// SlaveAria2Call 从机有关Aria2的请求正文
|
||||
type SlaveAria2Call struct {
|
||||
Task *model.Download `json:"task"`
|
||||
GroupOptions map[string]interface{} `json:"group_options"`
|
||||
Files []uint `json:"files"`
|
||||
}
|
||||
|
||||
// Add 主机创建新的链接离线下载任务
|
||||
func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response {
|
||||
// 创建文件系统
|
||||
|
@ -83,6 +77,23 @@ func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Respo
|
|||
}
|
||||
|
||||
// Add 从机创建新的链接离线下载任务
|
||||
func (service *SlaveAria2Call) Add(c *gin.Context) serializer.Response {
|
||||
return serializer.Response{}
|
||||
func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)
|
||||
}
|
||||
|
||||
// 创建任务
|
||||
gid, err := caller.CreateTask(service.Task, service.GroupOptions)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "无法创建离线下载任务", err)
|
||||
}
|
||||
|
||||
// TODO: 创建监控
|
||||
return serializer.Response{Data: gid}
|
||||
}
|
||||
|
||||
return serializer.ParamErr("未知的主机节点ID", nil)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue