diff --git a/bootstrap/init.go b/bootstrap/init.go index 6b43adb..60ebcb8 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -9,7 +9,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/crontab" "github.com/cloudreve/Cloudreve/v3/pkg/email" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/gin-gonic/gin" ) @@ -78,7 +77,7 @@ func Init(path string) { { "slave", func() { - slave.Init() + cluster.InitController() }, }, { diff --git a/middleware/auth.go b/middleware/auth.go index 135dd8c..3e7cbe7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -37,6 +37,7 @@ func SignRequired(authInstance auth.Auth) gin.HandlerFunc { c.Abort() return } + c.Next() } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index a17602d..84d229e 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -90,15 +90,27 @@ func TestSignRequired(t *testing.T) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request, _ = http.NewRequest("GET", "/test", nil) - SignRequiredFunc := SignRequired(auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))}) + authInstance := auth.HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} + SignRequiredFunc := SignRequired(authInstance) // 鉴权失败 SignRequiredFunc(c) asserts.NotNil(c) + asserts.True(c.IsAborted()) + c, _ = gin.CreateTestContext(rec) c.Request, _ = http.NewRequest("PUT", "/test", nil) SignRequiredFunc(c) asserts.NotNil(c) + asserts.True(c.IsAborted()) + + // Sign verify success + c, _ = gin.CreateTestContext(rec) + c.Request, _ = http.NewRequest("PUT", "/test", nil) + c.Request = auth.SignRequest(authInstance, c.Request, 0) + SignRequiredFunc(c) + asserts.NotNil(c) + asserts.False(c.IsAborted()) } func TestWebDAVAuth(t *testing.T) { @@ -780,8 +792,6 @@ func TestS3CallbackAuth(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"id", "group_id"}).AddRow(1, 1)) mock.ExpectQuery("SELECT(.+)groups(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "policies"}).AddRow(1, "[702]")) - mock.ExpectQuery("SELECT(.+)policies(.+)"). - WillReturnRows(sqlmock.NewRows([]string{"id", "access_key", "secret_key"}).AddRow(2, "123", "123")) c, _ := gin.CreateTestContext(rec) c.Params = []gin.Param{ {"key", "testCallBackUpyun"}, @@ -789,5 +799,6 @@ func TestS3CallbackAuth(t *testing.T) { c.Request, _ = http.NewRequest("POST", "/api/v3/callback/upyun/testCallBackUpyun", ioutil.NopCloser(strings.NewReader("1"))) AuthFunc(c) asserts.False(c.IsAborted()) + asserts.NoError(mock.ExpectationsWereMet()) } } diff --git a/middleware/cluster.go b/middleware/cluster.go index 079a4f4..d8bf979 100644 --- a/middleware/cluster.go +++ b/middleware/cluster.go @@ -3,7 +3,6 @@ package middleware import ( "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" "strconv" ) @@ -19,11 +18,11 @@ func MasterMetadata() gin.HandlerFunc { } // UseSlaveAria2Instance 从机用于获取对应主机节点的Aria2实例 -func UseSlaveAria2Instance() gin.HandlerFunc { +func UseSlaveAria2Instance(clusterController cluster.Controller) gin.HandlerFunc { return func(c *gin.Context) { if siteID, exist := c.Get("MasterSiteID"); exist { // 获取对应主机节点的从机Aria2实例 - caller, err := slave.DefaultController.GetAria2Instance(siteID.(string)) + caller, err := clusterController.GetAria2Instance(siteID.(string)) if err != nil { c.JSON(200, serializer.Err(serializer.CodeNotSet, "无法获取 Aria2 实例", err)) c.Abort() @@ -40,7 +39,7 @@ func UseSlaveAria2Instance() gin.HandlerFunc { } } -func SlaveRPCSignRequired() gin.HandlerFunc { +func SlaveRPCSignRequired(nodePool cluster.Pool) gin.HandlerFunc { return func(c *gin.Context) { nodeID, err := strconv.ParseUint(c.GetHeader("X-Node-Id"), 10, 64) if err != nil { @@ -49,7 +48,7 @@ func SlaveRPCSignRequired() gin.HandlerFunc { return } - slaveNode := cluster.Default.GetNodeByID(uint(nodeID)) + slaveNode := nodePool.GetNodeByID(uint(nodeID)) if slaveNode == nil { c.JSON(200, serializer.ParamErr("未知的主机节点ID", err)) c.Abort() diff --git a/middleware/cluster_test.go b/middleware/cluster_test.go new file mode 100644 index 0000000..2c25e29 --- /dev/null +++ b/middleware/cluster_test.go @@ -0,0 +1,80 @@ +package middleware + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" + "github.com/gin-gonic/gin" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "net/http/httptest" + "testing" +) + +func TestMasterMetadata(t *testing.T) { + a := assert.New(t) + masterMetaDataFunc := MasterMetadata() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("GET", "/", nil) + + c.Request.Header = map[string][]string{ + "X-Site-Id": {"expectedSiteID"}, + "X-Site-Url": {"expectedSiteURL"}, + "X-Cloudreve-Version": {"expectedMasterVersion"}, + } + masterMetaDataFunc(c) + siteID, _ := c.Get("MasterSiteID") + siteURL, _ := c.Get("MasterSiteURL") + siteVersion, _ := c.Get("MasterVersion") + + a.Equal("expectedSiteID", siteID.(string)) + a.Equal("expectedSiteURL", siteURL.(string)) + a.Equal("expectedMasterVersion", siteVersion.(string)) +} + +func TestSlaveRPCSignRequired(t *testing.T) { + a := assert.New(t) + np := &cluster.NodePool{} + np.Init() + slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np) + rec := httptest.NewRecorder() + + // id parse failed + { + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Request.Header.Set("X-Node-Id", "unknown") + slaveRPCSignRequiredFunc(c) + a.True(c.IsAborted()) + } + + // node id not exist + { + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Request.Header.Set("X-Node-Id", "38") + slaveRPCSignRequiredFunc(c) + a.True(c.IsAborted()) + } + + // success + { + authInstance := auth.HMACAuth{SecretKey: []byte("")} + np.Add(&model.Node{Model: gorm.Model{ + ID: 38, + }}) + + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest("POST", "/", nil) + c.Request.Header.Set("X-Node-Id", "38") + c.Request = auth.SignRequest(authInstance, c.Request, 0) + slaveRPCSignRequiredFunc(c) + a.False(c.IsAborted()) + } +} + +func TestUseSlaveAria2Instance(t *testing.T) { + a := assert.New(t) + +} diff --git a/models/user.go b/models/user.go index ecd091b..8045d27 100644 --- a/models/user.go +++ b/models/user.go @@ -35,7 +35,7 @@ type User struct { Storage uint64 TwoFactor string Avatar string - Options string `json:"-",gorm:"type:text"` + Options string `json:"-" gorm:"type:text"` Authn string `gorm:"type:text"` // 关联模型 diff --git a/pkg/slave/slave.go b/pkg/cluster/controller.go similarity index 96% rename from pkg/slave/slave.go rename to pkg/cluster/controller.go index aa457b5..d5352ee 100644 --- a/pkg/slave/slave.go +++ b/pkg/cluster/controller.go @@ -1,4 +1,4 @@ -package slave +package cluster import ( "bytes" @@ -8,7 +8,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" "github.com/cloudreve/Cloudreve/v3/pkg/auth" - "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -51,13 +50,13 @@ type MasterInfo struct { TTL int URL *url.URL // used to invoke aria2 rpc calls - Instance cluster.Node + Instance Node Client request.Client jobTracker map[string]bool } -func Init() { +func InitController() { DefaultController = &slaveController{ masters: make(map[string]MasterInfo), } @@ -95,7 +94,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ }, int64(req.CredentialTTL)), ), jobTracker: make(map[string]bool), - Instance: cluster.NewNodeFromDBModel(&model.Node{ + Instance: NewNodeFromDBModel(&model.Node{ Model: gorm.Model{ID: req.Node.ID}, MasterKey: req.Node.MasterKey, Type: model.MasterNodeType, diff --git a/pkg/cluster/errors.go b/pkg/cluster/errors.go index 9afdbef..84b2ad8 100644 --- a/pkg/cluster/errors.go +++ b/pkg/cluster/errors.go @@ -1,8 +1,12 @@ package cluster -import "errors" +import ( + "errors" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" +) var ( ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed") ErrIlegalPath = errors.New("path out of boundary of setting temp folder") + ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil) ) diff --git a/pkg/cluster/pool.go b/pkg/cluster/pool.go index 2181df4..a297649 100644 --- a/pkg/cluster/pool.go +++ b/pkg/cluster/pool.go @@ -39,14 +39,22 @@ type NodePool struct { // Init 初始化从机节点池 func Init() { - Default = &NodePool{ - featureMap: make(map[string][]Node), - } + Default = &NodePool{} + Default.Init() if err := Default.initFromDB(); err != nil { util.Log().Warning("节点池初始化失败, %s", err) } } +func (pool *NodePool) Init() { + pool.lock.Lock() + defer pool.lock.Unlock() + + pool.featureMap = make(map[string][]Node) + pool.active = make(map[uint]Node) + pool.inactive = make(map[uint]Node) +} + func (pool *NodePool) buildIndexMap() { pool.lock.Lock() for _, feature := range featureGroup { @@ -98,8 +106,6 @@ func (pool *NodePool) initFromDB() error { } pool.lock.Lock() - pool.active = make(map[uint]Node) - pool.inactive = make(map[uint]Node) for i := 0; i < len(nodes); i++ { pool.add(&nodes[i]) } diff --git a/pkg/filesystem/driver/onedrive/oauth.go b/pkg/filesystem/driver/onedrive/oauth.go index 49170fe..0d5865b 100644 --- a/pkg/filesystem/driver/onedrive/oauth.go +++ b/pkg/filesystem/driver/onedrive/oauth.go @@ -3,6 +3,7 @@ package onedrive import ( "context" "encoding/json" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "io/ioutil" "net/http" "net/url" @@ -12,7 +13,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/request" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) @@ -179,7 +179,7 @@ func (client *Client) UpdateCredential(ctx context.Context) error { // UpdateCredential 更新凭证,并检查有效期 func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { - res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) + res, err := cluster.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID) if err != nil { return err } diff --git a/pkg/slave/errors.go b/pkg/slave/errors.go deleted file mode 100644 index 2af6e13..0000000 --- a/pkg/slave/errors.go +++ /dev/null @@ -1,7 +0,0 @@ -package slave - -import "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - -var ( - ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "未知的主机节点", nil) -) diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index c312742..9231092 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -3,11 +3,11 @@ package slavetask import ( "context" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/util" "os" @@ -68,7 +68,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) { }, } - if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { + if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { util.Log().Warning("无法发送转存失败通知到从机, ", err) } } @@ -94,7 +94,7 @@ func (job *TransferTask) Do() { return } - master, err := slave.DefaultController.GetMasterInfo(job.MasterID) + master, err := cluster.DefaultController.GetMasterInfo(job.MasterID) if err != nil { job.SetErrorMsg("找不到主机节点", err) return @@ -131,7 +131,7 @@ func (job *TransferTask) Do() { Content: serializer.SlaveTransferResult{}, } - if err := slave.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { + if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { util.Log().Warning("无法发送转存成功通知到从机, ", err) } } diff --git a/routers/router.go b/routers/router.go index a7204c4..8f335c3 100644 --- a/routers/router.go +++ b/routers/router.go @@ -3,6 +3,7 @@ package routers import ( "github.com/cloudreve/Cloudreve/v3/middleware" "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/util" @@ -59,7 +60,7 @@ func InitSlaveRouter() *gin.Engine { // 离线下载 aria2 := v3.Group("aria2") - aria2.Use(middleware.UseSlaveAria2Instance()) + aria2.Use(middleware.UseSlaveAria2Instance(cluster.DefaultController)) { // 创建离线下载任务 aria2.POST("task", controllers.SlaveAria2Create) @@ -205,7 +206,7 @@ func InitMasterRouter() *gin.Engine { // 从机的 RPC 通信 slave := v3.Group("slave") - slave.Use(middleware.SlaveRPCSignRequired()) + slave.Use(middleware.SlaveRPCSignRequired(cluster.Default)) { // 事件通知 slave.PUT("notification/:subject", controllers.SlaveNotificationPush) diff --git a/service/aria2/add.go b/service/aria2/add.go index 26b6baa..73446b4 100644 --- a/service/aria2/add.go +++ b/service/aria2/add.go @@ -9,7 +9,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-gonic/gin" ) @@ -91,7 +90,7 @@ func Add(c *gin.Context, service *serializer.SlaveAria2Call) serializer.Response // 创建事件通知回调 siteID, _ := c.Get("MasterSiteID") mq.GlobalMQ.SubscribeCallback(gid, func(message mq.Message) { - if err := slave.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { + if err := cluster.DefaultController.SendNotification(siteID.(string), message.TriggeredBy, message); err != nil { util.Log().Warning("无法发送离线下载任务状态变更通知, %s", err) } }) diff --git a/service/explorer/slave.go b/service/explorer/slave.go index 8beb15b..54638ee 100644 --- a/service/explorer/slave.go +++ b/service/explorer/slave.go @@ -6,10 +6,10 @@ import ( "encoding/json" "fmt" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/cloudreve/Cloudreve/v3/pkg/task" "github.com/cloudreve/Cloudreve/v3/pkg/task/slavetask" "github.com/gin-gonic/gin" @@ -153,7 +153,7 @@ func CreateTransferTask(c *gin.Context, req *serializer.SlaveTransferReq) serial MasterID: id.(string), } - if err := slave.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { + if err := cluster.DefaultController.SubmitTask(job.MasterID, job, req.Hash(job.MasterID), func(job interface{}) { task.TaskPoll.Submit(job.(task.Job)) }); err != nil { return serializer.Err(serializer.CodeInternalSetting, "任务创建失败", err) diff --git a/service/node/fabric.go b/service/node/fabric.go index 79dfb29..63b5ecf 100644 --- a/service/node/fabric.go +++ b/service/node/fabric.go @@ -3,10 +3,10 @@ package node import ( "encoding/gob" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" - "github.com/cloudreve/Cloudreve/v3/pkg/slave" "github.com/gin-gonic/gin" ) @@ -19,7 +19,7 @@ type OneDriveCredentialService struct { } func HandleMasterHeartbeat(req *serializer.NodePingReq) serializer.Response { - res, err := slave.DefaultController.HandleHeartBeat(req) + res, err := cluster.DefaultController.HandleHeartBeat(req) if err != nil { return serializer.Err(serializer.CodeInternalSetting, "Cannot initialize slave controller", err) }