From 5699c8a0f261f5b24a7926c759eee70050888e08 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Thu, 16 Sep 2021 21:16:24 +0800 Subject: [PATCH] Feat: slave transfer file in local policy --- pkg/filesystem/driver/remote/handler.go | 4 +++- pkg/filesystem/filesystem.go | 2 ++ pkg/request/options.go | 8 ++++++++ pkg/request/request.go | 6 +++++- pkg/slave/slave.go | 13 ++++++------- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go index f17ed3d..3f77700 100644 --- a/pkg/filesystem/driver/remote/handler.go +++ b/pkg/filesystem/driver/remote/handler.go @@ -98,7 +98,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string { // Get 获取文件内容 func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { - // 尝试获取速度限制 TODO 是否需要在这里限制? + // 尝试获取速度限制 speedLimit := 0 if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok { speedLimit = user.Group.SpeedLimit @@ -177,6 +177,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s request.WithContentLength(int64(size)), request.WithTimeout(time.Duration(0)), request.WithMasterMeta(), + request.WithSlaveMeta(handler.Policy.AccessKey), request.WithCredential(handler.AuthInstance, int64(credentialTTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { @@ -210,6 +211,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err bodyReader, request.WithCredential(handler.AuthInstance, int64(signTTL)), request.WithMasterMeta(), + request.WithSlaveMeta(handler.Policy.AccessKey), ).CheckHTTPResponse(200).GetResponse() if err != nil { return files, err diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 98c21dd..073a2e4 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -2,6 +2,7 @@ package filesystem import ( "errors" + "fmt" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave" @@ -256,6 +257,7 @@ func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL strin } else if fs.Policy.Type == "local" { fs.Policy.Type = "remote" fs.Policy.Server = masterURL + fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID()) fs.Policy.SecretKey = master.DBModel().MasterKey fs.DispatchHandler() } diff --git a/pkg/request/options.go b/pkg/request/options.go index ed913e0..d495757 100644 --- a/pkg/request/options.go +++ b/pkg/request/options.go @@ -22,6 +22,7 @@ type options struct { contentLength int64 masterMeta bool endpoint *url.URL + slaveNodeID string } type optionFunc func(*options) @@ -93,6 +94,13 @@ func WithMasterMeta() Option { }) } +// WithSlaveMeta 请求时携带从机信息 +func WithSlaveMeta(s string) Option { + return optionFunc(func(o *options) { + o.slaveNodeID = s + }) +} + // Endpoint 使用同一的请求Endpoint func WithEndpoint(endpoint string) Option { endpointURL, _ := url.Parse(endpoint) diff --git a/pkg/request/request.go b/pkg/request/request.go index 7480790..dabf48b 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -96,12 +96,16 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio } } - if options.masterMeta { + if options.masterMeta && conf.SystemConfig.Mode == "master" { req.Header.Add("X-Site-Url", model.GetSiteURL().String()) req.Header.Add("X-Site-Id", model.GetSettingByName("siteID")) req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) } + if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" { + req.Header.Add("X-Node-Id", options.slaveNodeID) + } + if options.contentLength != -1 { req.ContentLength = options.contentLength } diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index 6fcae5c..b38a0e5 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -12,7 +12,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/cloudreve/Cloudreve/v3/pkg/task" - "net/http" + "github.com/jinzhu/gorm" "net/url" "sync" ) @@ -45,10 +45,9 @@ type slaveController struct { // info of master node type MasterInfo struct { - SlaveID uint - ID string - TTL int - URL *url.URL + ID string + TTL int + URL *url.URL // used to invoke aria2 rpc calls Instance cluster.Node @@ -83,12 +82,12 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ } c.masters[req.SiteID] = MasterInfo{ - SlaveID: req.Node.ID, ID: req.SiteID, URL: masterUrl, TTL: req.CredentialTTL, jobTracker: make(map[string]bool), Instance: cluster.NewNodeFromDBModel(&model.Node{ + Model: gorm.Model{ID: req.Node.ID}, MasterKey: req.Node.MasterKey, Type: model.MasterNodeType, Aria2Enabled: req.Node.Aria2Enabled, @@ -128,7 +127,7 @@ func (c *slaveController) SendNotification(id, subject string, msg mq.Message) e "PUT", node.URL.ResolveReference(apiPath).String(), &body, - request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.SlaveID)}}), + request.WithSlaveMeta(fmt.Sprintf("%d", node.Instance.ID())), request.WithCredential(node.Instance.MasterAuthInstance(), int64(node.TTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil {