Feat: slave transfer file in local policy

This commit is contained in:
HFO4 2021-09-16 21:16:24 +08:00
parent f73abd021b
commit 5699c8a0f2
5 changed files with 24 additions and 9 deletions

View file

@ -98,7 +98,7 @@ func (handler Driver) getAPIUrl(scope string, routes ...string) string {
// Get 获取文件内容
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 尝试获取速度限制 TODO 是否需要在这里限制?
// 尝试获取速度限制
speedLimit := 0
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
speedLimit = user.Group.SpeedLimit
@ -177,6 +177,7 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s
request.WithContentLength(int64(size)),
request.WithTimeout(time.Duration(0)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
request.WithCredential(handler.AuthInstance, int64(credentialTTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
@ -210,6 +211,7 @@ func (handler Driver) Delete(ctx context.Context, files []string) ([]string, err
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
).CheckHTTPResponse(200).GetResponse()
if err != nil {
return files, err

View file

@ -2,6 +2,7 @@ package filesystem
import (
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
@ -256,6 +257,7 @@ func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL strin
} else if fs.Policy.Type == "local" {
fs.Policy.Type = "remote"
fs.Policy.Server = masterURL
fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID())
fs.Policy.SecretKey = master.DBModel().MasterKey
fs.DispatchHandler()
}

View file

@ -22,6 +22,7 @@ type options struct {
contentLength int64
masterMeta bool
endpoint *url.URL
slaveNodeID string
}
type optionFunc func(*options)
@ -93,6 +94,13 @@ func WithMasterMeta() Option {
})
}
// WithSlaveMeta 请求时携带从机信息
func WithSlaveMeta(s string) Option {
return optionFunc(func(o *options) {
o.slaveNodeID = s
})
}
// Endpoint 使用同一的请求Endpoint
func WithEndpoint(endpoint string) Option {
endpointURL, _ := url.Parse(endpoint)

View file

@ -96,12 +96,16 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
}
}
if options.masterMeta {
if options.masterMeta && conf.SystemConfig.Mode == "master" {
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
}
if options.slaveNodeID != "" && conf.SystemConfig.Mode == "slave" {
req.Header.Add("X-Node-Id", options.slaveNodeID)
}
if options.contentLength != -1 {
req.ContentLength = options.contentLength
}

View file

@ -12,7 +12,7 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"net/http"
"github.com/jinzhu/gorm"
"net/url"
"sync"
)
@ -45,10 +45,9 @@ type slaveController struct {
// info of master node
type MasterInfo struct {
SlaveID uint
ID string
TTL int
URL *url.URL
ID string
TTL int
URL *url.URL
// used to invoke aria2 rpc calls
Instance cluster.Node
@ -83,12 +82,12 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ
}
c.masters[req.SiteID] = MasterInfo{
SlaveID: req.Node.ID,
ID: req.SiteID,
URL: masterUrl,
TTL: req.CredentialTTL,
jobTracker: make(map[string]bool),
Instance: cluster.NewNodeFromDBModel(&model.Node{
Model: gorm.Model{ID: req.Node.ID},
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled,
@ -128,7 +127,7 @@ func (c *slaveController) SendNotification(id, subject string, msg mq.Message) e
"PUT",
node.URL.ResolveReference(apiPath).String(),
&body,
request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.SlaveID)}}),
request.WithSlaveMeta(fmt.Sprintf("%d", node.Instance.ID())),
request.WithCredential(node.Instance.MasterAuthInstance(), int64(node.TTL)),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {