Refactor: move slave pkg inside of cluster
Test: middleware for node communication
This commit is contained in:
parent
eaa0f6be91
commit
e41ec9defa
16 changed files with 135 additions and 43 deletions
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/crontab"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
@ -78,7 +77,7 @@ func Init(path string) {
|
|||
{
|
||||
"slave",
|
||||
func() {
|
||||
slave.Init()
|
||||
cluster.InitController()
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -37,6 +37,7 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc {
|
|||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,15 +90,27 @@ func TestSignRequired(t *testing.T) {
|
|||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("GET", "/test", nil)
|
||||
SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))})
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}
|
||||
SignRequiredFunc := SignRequired(authInstance)
|
||||
|
||||
// 鉴权失败
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.True(c.IsAborted())
|
||||
|
||||
// Sign verify success
|
||||
c, _ = gin.CreateTestContext(rec)
|
||||
c.Request, _ = http.NewRequest("PUT", "/test", nil)
|
||||
c.Request = auth.SignRequest(authInstance, c.Request, 0)
|
||||
SignRequiredFunc(c)
|
||||
asserts.NotNil(c)
|
||||
asserts.False(c.IsAborted())
|
||||
}
|
||||
|
||||
func TestWebDAVAuth(t *testing.T) {
|
||||
|
@ -780,8 +792,6 @@ func TestS3CallbackAuth(t *testing.T) {
|
|||
WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1))
|
||||
mock.ExpectQuery("SELECT(.+)groups(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]"))
|
||||
mock.ExpectQuery("SELECT(.+)policies(.+)").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123"))
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Params = []gin.Param{
|
||||
{"key", "testCallBackUpyun"},
|
||||
|
@ -789,5 +799,6 @@ func TestS3CallbackAuth(t *testing.T) {
|
|||
c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1")))
|
||||
AuthFunc(c)
|
||||
asserts.False(c.IsAborted())
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package middleware
|
|||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
)
|
||||
|
@ -19,11 +18,11 @@ func MasterMetadata() gin.HandlerFunc {
|
|||
}
|
||||
|
||||
// UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例
|
||||
func UseSlaveAria2Instance() gin.HandlerFunc {
|
||||
func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if siteID, exist := c.Get("MasterSiteID"); exist {
|
||||
// 获取对应主机节点的从机Aria2实例
|
||||
caller, err := slave.DefaultController.GetAria2Instance(siteID.(string))
|
||||
caller, err := clusterController.GetAria2Instance(siteID.(string))
|
||||
if err != nil {
|
||||
c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err))
|
||||
c.Abort()
|
||||
|
@ -40,7 +39,7 @@ func UseSlaveAria2Instance() gin.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func SlaveRPCSignRequired() gin.HandlerFunc {
|
||||
func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64)
|
||||
if err != nil {
|
||||
|
@ -49,7 +48,7 @@ func SlaveRPCSignRequired() gin.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
slaveNode := cluster.Default.GetNodeByID(uint(nodeID))
|
||||
slaveNode := nodePool.GetNodeByID(uint(nodeID))
|
||||
if slaveNode == nil {
|
||||
c.JSON(200, serializer.ParamErr("未知的主机节点ID", err))
|
||||
c.Abort()
|
||||
|
|
80
middleware/cluster_test.go
Normal file
80
middleware/cluster_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMasterMetadata(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
masterMetaDataFunc := MasterMetadata()
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
c.Request.Header = map[string][]string{
|
||||
"X-Site-Id": {"expectedSiteID"},
|
||||
"X-Site-Url": {"expectedSiteURL"},
|
||||
"X-Cloudreve-Version": {"expectedMasterVersion"},
|
||||
}
|
||||
masterMetaDataFunc(c)
|
||||
siteID, _ := c.Get("MasterSiteID")
|
||||
siteURL, _ := c.Get("MasterSiteURL")
|
||||
siteVersion, _ := c.Get("MasterVersion")
|
||||
|
||||
a.Equal("expectedSiteID", siteID.(string))
|
||||
a.Equal("expectedSiteURL", siteURL.(string))
|
||||
a.Equal("expectedMasterVersion", siteVersion.(string))
|
||||
}
|
||||
|
||||
func TestSlaveRPCSignRequired(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
np := &cluster.NodePool{}
|
||||
np.Init()
|
||||
slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// id parse failed
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("X-Node-Id", "unknown")
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// node id not exist
|
||||
{
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("X-Node-Id", "38")
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.True(c.IsAborted())
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
authInstance := auth.HMACAuth{SecretKey: []byte("")}
|
||||
np.Add(&model.Node{Model: gorm.Model{
|
||||
ID: 38,
|
||||
}})
|
||||
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
c.Request.Header.Set("X-Node-Id", "38")
|
||||
c.Request = auth.SignRequest(authInstance, c.Request, 0)
|
||||
slaveRPCSignRequiredFunc(c)
|
||||
a.False(c.IsAborted())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseSlaveAria2Instance(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
}
|
|
@ -35,7 +35,7 @@ type User struct {
|
|||
Storage uint64
|
||||
TwoFactor string
|
||||
Avatar string
|
||||
Options string `json:"-",gorm:"type:text"`
|
||||
Options string `json:"-" gorm:"type:text"`
|
||||
Authn string `gorm:"type:text"`
|
||||
|
||||
// 关联模型
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package slave
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
|
@ -51,13 +50,13 @@ type MasterInfo struct {
|
|||
TTL int
|
||||
URL *url.URL
|
||||
// used to invoke aria2 rpc calls
|
||||
Instance cluster.Node
|
||||
Instance Node
|
||||
Client request.Client
|
||||
|
||||
jobTracker map[string]bool
|
||||
}
|
||||
|
||||
func Init() {
|
||||
func InitController() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
|
@ -95,7 +94,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
|
|||
}, int64(req.CredentialTTL)),
|
||||
),
|
||||
jobTracker: make(map[string]bool),
|
||||
Instance: cluster.NewNodeFromDBModel(&model.Node{
|
||||
Instance: NewNodeFromDBModel(&model.Node{
|
||||
Model: gorm.Model{ID: req.Node.ID},
|
||||
MasterKey: req.Node.MasterKey,
|
||||
Type: model.MasterNodeType,
|
|
@ -1,8 +1,12 @@
|
|||
package cluster
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
|
||||
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
|
||||
)
|
||||
|
|
|
@ -39,14 +39,22 @@ type NodePool struct {
|
|||
|
||||
// Init 初始化从机节点池
|
||||
func Init() {
|
||||
Default = &NodePool{
|
||||
featureMap: make(map[string][]Node),
|
||||
}
|
||||
Default = &NodePool{}
|
||||
Default.Init()
|
||||
if err := Default.initFromDB(); err != nil {
|
||||
util.Log().Warning("节点池初始化失败, %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (pool *NodePool) Init() {
|
||||
pool.lock.Lock()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
pool.featureMap = make(map[string][]Node)
|
||||
pool.active = make(map[uint]Node)
|
||||
pool.inactive = make(map[uint]Node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) buildIndexMap() {
|
||||
pool.lock.Lock()
|
||||
for _, feature := range featureGroup {
|
||||
|
@ -98,8 +106,6 @@ func (pool *NodePool) initFromDB() error {
|
|||
}
|
||||
|
||||
pool.lock.Lock()
|
||||
pool.active = make(map[uint]Node)
|
||||
pool.inactive = make(map[uint]Node)
|
||||
for i := 0; i < len(nodes); i++ {
|
||||
pool.add(&nodes[i])
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package onedrive
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -12,7 +13,6 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
|
@ -179,7 +179,7 @@ func (client *Client) UpdateCredential(ctx context.Context) error {
|
|||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
|
||||
res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
|
||||
res, err := cluster.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
package slave
|
||||
|
||||
import "github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
|
||||
var (
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil)
|
||||
)
|
|
@ -3,11 +3,11 @@ package slavetask
|
|||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"os"
|
||||
|
@ -68,7 +68,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) {
|
|||
},
|
||||
}
|
||||
|
||||
if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
|
||||
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
|
||||
util.Log().Warning("无法发送转存失败通知到从机, ", err)
|
||||
}
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ func (job *TransferTask) Do() {
|
|||
return
|
||||
}
|
||||
|
||||
master, err := slave.DefaultController.GetMasterInfo(job.MasterID)
|
||||
master, err := cluster.DefaultController.GetMasterInfo(job.MasterID)
|
||||
if err != nil {
|
||||
job.SetErrorMsg("找不到主机节点", err)
|
||||
return
|
||||
|
@ -131,7 +131,7 @@ func (job *TransferTask) Do() {
|
|||
Content: serializer.SlaveTransferResult{},
|
||||
}
|
||||
|
||||
if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
|
||||
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
|
||||
util.Log().Warning("无法发送转存成功通知到从机, ", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package routers
|
|||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/middleware"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
|
@ -59,7 +60,7 @@ func InitSlaveRouter() *gin.Engine {
|
|||
|
||||
// 离线下载
|
||||
aria2 := v3.Group("aria2")
|
||||
aria2.Use(middleware.UseSlaveAria2Instance())
|
||||
aria2.Use(middleware.UseSlaveAria2Instance(cluster.DefaultController))
|
||||
{
|
||||
// 创建离线下载任务
|
||||
aria2.POST("task", controllers.SlaveAria2Create)
|
||||
|
@ -205,7 +206,7 @@ func InitMasterRouter() *gin.Engine {
|
|||
|
||||
// 从机的 RPC 通信
|
||||
slave := v3.Group("slave")
|
||||
slave.Use(middleware.SlaveRPCSignRequired())
|
||||
slave.Use(middleware.SlaveRPCSignRequired(cluster.Default))
|
||||
{
|
||||
// 事件通知
|
||||
slave.PUT("notification/:subject", controllers.SlaveNotificationPush)
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
@ -91,7 +90,7 @@ func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response
|
|||
// 创建事件通知回调
|
||||
siteID, _ := c.Get("MasterSiteID")
|
||||
mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) {
|
||||
if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil {
|
||||
if err := cluster.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil {
|
||||
util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -6,10 +6,10 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -153,7 +153,7 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial
|
|||
MasterID: id.(string),
|
||||
}
|
||||
|
||||
if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) {
|
||||
if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) {
|
||||
task.TaskPoll.Submit(job.(task.Job))
|
||||
}); err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err)
|
||||
|
|
|
@ -3,10 +3,10 @@ package node
|
|||
import (
|
||||
"encoding/gob"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
@ -19,7 +19,7 @@ type OneDriveCredentialService struct {
|
|||
}
|
||||
|
||||
func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response {
|
||||
res, err := slave.DefaultController.HandleHeartBeat(req)
|
||||
res, err := cluster.DefaultController.HandleHeartBeat(req)
|
||||
if err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue