diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 358e391..d8250e8 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -78,7 +78,7 @@ func getSignContent(r *http.Request) (rawSignString string) { // 决定要签名的header var signedHeader []string for k, _ := range r.Header { - if strings.HasPrefix(k, "X-") { + if strings.HasPrefix(k, "X-") && k != "X-Filename" { signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k))) } } diff --git a/pkg/filesystem/driver/remote/handler.go b/pkg/filesystem/driver/remote/handler.go index 0d3ebb4..f17ed3d 100644 --- a/pkg/filesystem/driver/remote/handler.go +++ b/pkg/filesystem/driver/remote/handler.go @@ -170,14 +170,14 @@ func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, s handler.Policy.GetUploadURL(), file, request.WithHeader(map[string][]string{ - "Authorization": {credential.Token}, - "X-Policy": {credential.Policy}, - "X-FileName": {fileName}, - "X-Overwrite": {overwrite}, + "X-Policy": {credential.Policy}, + "X-FileName": {fileName}, + "X-Overwrite": {overwrite}, }), request.WithContentLength(int64(size)), request.WithTimeout(time.Duration(0)), request.WithMasterMeta(), + request.WithCredential(handler.AuthInstance, int64(credentialTTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { return err diff --git a/pkg/filesystem/driver/shadow/masterinslave/handler.go b/pkg/filesystem/driver/shadow/masterinslave/handler.go index 00b119c..485a9b2 100644 --- a/pkg/filesystem/driver/shadow/masterinslave/handler.go +++ b/pkg/filesystem/driver/shadow/masterinslave/handler.go @@ -3,6 +3,7 @@ package masterinslave import ( "context" model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -12,17 +13,17 @@ import ( // Driver 影子存储策略,用于在从机端上传文件 type Driver struct { - masterID string - handler driver.Handler - policy *model.Policy + master cluster.Node + handler driver.Handler + policy *model.Policy } // NewDriver 返回新的处理器 -func NewDriver(masterID string, handler driver.Handler, policy *model.Policy) driver.Handler { +func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler { return &Driver{ - masterID: masterID, - handler: handler, - policy: policy, + master: master, + handler: handler, + policy: policy, } } diff --git a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go index e9f18ad..9d13247 100644 --- a/pkg/filesystem/driver/shadow/slaveinmaster/handler.go +++ b/pkg/filesystem/driver/shadow/slaveinmaster/handler.go @@ -97,7 +97,7 @@ func (d *Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size u } func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) { - return nil, ErrNotImplemented + return d.handler.Delete(ctx, files) } func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 2546d24..98c21dd 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -248,8 +248,19 @@ func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) { } // SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器 -func (fs *FileSystem) SwitchToShadowHandler(masterID string) { - fs.Handler = masterinslave.NewDriver(masterID, fs.Handler, fs.Policy) +func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL string) { + // 交换主从存储策略 + if fs.Policy.Type == "remote" { + fs.Policy.Type = "local" + fs.DispatchHandler() + } else if fs.Policy.Type == "local" { + fs.Policy.Type = "remote" + fs.Policy.Server = masterURL + fs.Policy.SecretKey = master.DBModel().MasterKey + fs.DispatchHandler() + } + + fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy) } // SetTargetFile 设置当前处理的目标文件 diff --git a/pkg/request/request.go b/pkg/request/request.go index 77d84b6..7480790 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -19,7 +19,7 @@ import ( ) // GeneralClient 通用 HTTP Client -var GeneralClient Client = HTTPClient{} +var GeneralClient Client = NewClient() // Response 请求的响应或错误信息 type Response struct { diff --git a/pkg/request/slave.go b/pkg/request/slave.go index 0bd1ca3..2948250 100644 --- a/pkg/request/slave.go +++ b/pkg/request/slave.go @@ -11,6 +11,7 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/serializer" ) +// TODO: move to slave pkg // RemoteCallback 发送远程存储策略上传回调请求 func RemoteCallback(url string, body serializer.UploadCallback) error { callbackBody, err := json.Marshal(struct { diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index e062ef0..6fcae5c 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -24,7 +24,7 @@ type Controller interface { // Handle heartbeat sent from master HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error) - // Get Aria2 instance by master node id + // Get Aria2 Instance by master node ID GetAria2Instance(string) (common.Aria2, error) // Send event change message to master node @@ -32,28 +32,32 @@ type Controller interface { // Submit async task into task pool SubmitTask(string, task.Job, string) error + + // Get master node info + GetMasterInfo(string) (*MasterInfo, error) } type slaveController struct { - masters map[string]masterInfo + masters map[string]MasterInfo client request.Client lock sync.RWMutex } // info of master node -type masterInfo struct { - slaveID uint - id string - ttl int - url *url.URL - jobTracker map[string]bool +type MasterInfo struct { + SlaveID uint + ID string + TTL int + URL *url.URL // used to invoke aria2 rpc calls - instance cluster.Node + Instance cluster.Node + + jobTracker map[string]bool } func Init() { DefaultController = &slaveController{ - masters: make(map[string]masterInfo), + masters: make(map[string]MasterInfo), client: request.NewClient(), } gob.Register(rpc.StatusInfo{}) @@ -70,7 +74,7 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ if (ok && req.IsUpdate) || !ok { if ok { - origin.instance.Kill() + origin.Instance.Kill() } masterUrl, err := url.Parse(req.SiteURL) @@ -78,13 +82,13 @@ func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializ return serializer.NodePingResp{}, err } - c.masters[req.SiteID] = masterInfo{ - slaveID: req.Node.ID, - id: req.SiteID, - url: masterUrl, - ttl: req.CredentialTTL, + c.masters[req.SiteID] = MasterInfo{ + SlaveID: req.Node.ID, + ID: req.SiteID, + URL: masterUrl, + TTL: req.CredentialTTL, jobTracker: make(map[string]bool), - instance: cluster.NewNodeFromDBModel(&model.Node{ + Instance: cluster.NewNodeFromDBModel(&model.Node{ MasterKey: req.Node.MasterKey, Type: model.MasterNodeType, Aria2Enabled: req.Node.Aria2Enabled, @@ -101,7 +105,7 @@ func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) { defer c.lock.RUnlock() if node, ok := c.masters[id]; ok { - return node.instance.GetAria2Instance(), nil + return node.Instance.GetAria2Instance(), nil } return nil, ErrMasterNotFound @@ -122,10 +126,10 @@ func (c *slaveController) SendNotification(id, subject string, msg mq.Message) e res, err := c.client.Request( "PUT", - node.url.ResolveReference(apiPath).String(), + node.URL.ResolveReference(apiPath).String(), &body, - request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.slaveID)}}), - request.WithCredential(node.instance.MasterAuthInstance(), int64(node.ttl)), + request.WithHeader(http.Header{"X-Node-Id": []string{fmt.Sprintf("%d", node.SlaveID)}}), + request.WithCredential(node.Instance.MasterAuthInstance(), int64(node.TTL)), ).CheckHTTPResponse(200).DecodeResponse() if err != nil { return err @@ -159,3 +163,15 @@ func (c *slaveController) SubmitTask(id string, job task.Job, hash string) error return ErrMasterNotFound } + +// GetMasterInfo 获取主机节点信息 +func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + if node, ok := c.masters[id]; ok { + return &node, nil + } + + return nil, ErrMasterNotFound +} diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index c7d6da6..8db10f9 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -91,7 +91,13 @@ func (job *TransferTask) Do() { return } - fs.SwitchToShadowHandler(job.MasterID) + master, err := slave.DefaultController.GetMasterInfo(job.MasterID) + if err != nil { + job.SetErrorMsg("找不到主机节点", err) + return + } + + fs.SwitchToShadowHandler(master.Instance, master.URL.String()) ctx := context.WithValue(context.Background(), fsctx.DisableOverwrite, true) file, err := os.Open(util.RelativePath(job.Req.Src)) if err != nil { diff --git a/routers/router.go b/routers/router.go index 92e4133..9110bdb 100644 --- a/routers/router.go +++ b/routers/router.go @@ -204,6 +204,8 @@ func InitMasterRouter() *gin.Engine { slave.Use(middleware.SlaveRPCSignRequired()) { slave.PUT("notification/:subject", controllers.SlaveNotificationPush) + // 上传 + slave.POST("upload", controllers.SlaveUpload) } // 回调接口