From 23d1839b29925e323a5ae6250dcf51c19721ab88 Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Sun, 29 Aug 2021 20:36:40 +0800 Subject: [PATCH] Feat: init request client with global options --- pkg/cluster/slave.go | 2 +- pkg/filesystem/driver/onedrive/client.go | 2 +- pkg/filesystem/driver/oss/callback.go | 2 +- pkg/filesystem/driver/oss/handler_test.go | 2 +- pkg/filesystem/driver/s3/handler.go | 2 +- pkg/filesystem/driver/shadow/slave/errors.go | 3 +- pkg/filesystem/driver/shadow/slave/handler.go | 9 +- pkg/filesystem/filesystem.go | 8 +- pkg/request/options.go | 92 ++++++++++++++ pkg/request/request.go | 112 +++--------------- pkg/slave/slave.go | 2 +- routers/controllers/admin.go | 2 +- service/admin/policy.go | 8 +- 13 files changed, 136 insertions(+), 110 deletions(-) create mode 100644 pkg/request/options.go diff --git a/pkg/cluster/slave.go b/pkg/cluster/slave.go index 496d741..ca78f65 100644 --- a/pkg/cluster/slave.go +++ b/pkg/cluster/slave.go @@ -38,7 +38,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) { node.lock.Lock() node.Model = nodeModel node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)} - node.caller.Client = request.HTTPClient{} + node.caller.Client = request.NewClient() node.caller.parent = node node.Active = true if node.close != nil { diff --git a/pkg/filesystem/driver/onedrive/client.go b/pkg/filesystem/driver/onedrive/client.go index 101f9c3..dbbca3c 100644 --- a/pkg/filesystem/driver/onedrive/client.go +++ b/pkg/filesystem/driver/onedrive/client.go @@ -55,7 +55,7 @@ func NewClient(policy *model.Policy) (*Client, error) { ClientID: policy.BucketName, ClientSecret: policy.SecretKey, Redirect: policy.OptionsSerialized.OdRedirect, - Request: request.HTTPClient{}, + Request: request.NewClient(), } if client.Endpoints.DriverResource == "" { diff --git a/pkg/filesystem/driver/oss/callback.go b/pkg/filesystem/driver/oss/callback.go index 7ca1e23..e5b41bb 100644 --- a/pkg/filesystem/driver/oss/callback.go +++ b/pkg/filesystem/driver/oss/callback.go @@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) { } // 获取公钥 - client := request.HTTPClient{} + client := request.NewClient() body, err := client.Request("GET", string(pubURL), nil). CheckHTTPResponse(200). GetResponse() diff --git a/pkg/filesystem/driver/oss/handler_test.go b/pkg/filesystem/driver/oss/handler_test.go index 5be01f2..58401f3 100644 --- a/pkg/filesystem/driver/oss/handler_test.go +++ b/pkg/filesystem/driver/oss/handler_test.go @@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) { BucketName: "test", Server: "oss-cn-shanghai.aliyuncs.com", }, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } cache.Set("setting_preview_timeout", "3600", 0) diff --git a/pkg/filesystem/driver/s3/handler.go b/pkg/filesystem/driver/s3/handler.go index b338d8f..4502196 100644 --- a/pkg/filesystem/driver/s3/handler.go +++ b/pkg/filesystem/driver/s3/handler.go @@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, } // 获取文件数据流 - client := request.HTTPClient{} + client := request.NewClient() resp, err := client.Request( "GET", downloadURL, diff --git a/pkg/filesystem/driver/shadow/slave/errors.go b/pkg/filesystem/driver/shadow/slave/errors.go index 4a42c0f..f1dd4f1 100644 --- a/pkg/filesystem/driver/shadow/slave/errors.go +++ b/pkg/filesystem/driver/shadow/slave/errors.go @@ -3,5 +3,6 @@ package slave import "errors" var ( - ErrNotImplemented = errors.New("This method of shadowed policy is not implemented") + ErrNotImplemented = errors.New("this method of shadowed policy is not implemented") + ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node") ) diff --git a/pkg/filesystem/driver/shadow/slave/handler.go b/pkg/filesystem/driver/shadow/slave/handler.go index cd3596f..dbd8e07 100644 --- a/pkg/filesystem/driver/shadow/slave/handler.go +++ b/pkg/filesystem/driver/shadow/slave/handler.go @@ -5,7 +5,9 @@ import ( model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/cluster" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" + "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" + "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "io" "net/url" @@ -16,6 +18,7 @@ type Driver struct { node cluster.Node handler driver.Handler policy *model.Policy + client request.Client } // NewDriver 返回新的从机指派处理器 @@ -24,12 +27,16 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) node: node, handler: handler, policy: policy, + client: request.NewClient(request.WithMasterMeta()), } } func (d Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { + realBase, ok := ctx.Value(fsctx.SlaveSrcPath).(string) + if !ok { + return ErrSlaveSrcPathNotExist + } - panic("implement me") } func (d Driver) Delete(ctx context.Context, files []string) ([]string, error) { diff --git a/pkg/filesystem/filesystem.go b/pkg/filesystem/filesystem.go index 52871fb..16eff6c 100644 --- a/pkg/filesystem/filesystem.go +++ b/pkg/filesystem/filesystem.go @@ -153,7 +153,7 @@ func (fs *FileSystem) DispatchHandler() error { case "remote": fs.Handler = remote.Driver{ Policy: currentPolicy, - Client: request.HTTPClient{}, + Client: request.NewClient(), AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)}, } return nil @@ -165,7 +165,7 @@ func (fs *FileSystem) DispatchHandler() error { case "oss": fs.Handler = oss.Driver{ Policy: currentPolicy, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } return nil case "upyun": @@ -178,7 +178,7 @@ func (fs *FileSystem) DispatchHandler() error { fs.Handler = onedrive.Driver{ Policy: currentPolicy, Client: client, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } return err case "cos": @@ -192,7 +192,7 @@ func (fs *FileSystem) DispatchHandler() error { SecretKey: currentPolicy.SecretKey, }, }), - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } return nil case "s3": diff --git a/pkg/request/options.go b/pkg/request/options.go new file mode 100644 index 0000000..0e3511e --- /dev/null +++ b/pkg/request/options.go @@ -0,0 +1,92 @@ +package request + +import ( + "context" + "github.com/cloudreve/Cloudreve/v3/pkg/auth" + "net/http" + "time" +) + +// Option 发送请求的额外设置 +type Option interface { + apply(*options) +} + +type options struct { + timeout time.Duration + header http.Header + sign auth.Auth + signTTL int64 + ctx context.Context + contentLength int64 + masterMeta bool +} + +type optionFunc func(*options) + +func (f optionFunc) apply(o *options) { + f(o) +} + +func newDefaultOption() *options { + return &options{ + header: http.Header{}, + timeout: time.Duration(30) * time.Second, + contentLength: -1, + } +} + +// WithTimeout 设置请求超时 +func WithTimeout(t time.Duration) Option { + return optionFunc(func(o *options) { + o.timeout = t + }) +} + +// WithContext 设置请求上下文 +func WithContext(c context.Context) Option { + return optionFunc(func(o *options) { + o.ctx = c + }) +} + +// WithCredential 对请求进行签名 +func WithCredential(instance auth.Auth, ttl int64) Option { + return optionFunc(func(o *options) { + o.sign = instance + o.signTTL = ttl + }) +} + +// WithHeader 设置请求Header +func WithHeader(header http.Header) Option { + return optionFunc(func(o *options) { + for k, v := range header { + o.header[k] = v + } + }) +} + +// WithoutHeader 设置清除请求Header +func WithoutHeader(header []string) Option { + return optionFunc(func(o *options) { + for _, v := range header { + delete(o.header, v) + } + + }) +} + +// WithContentLength 设置请求大小 +func WithContentLength(s int64) Option { + return optionFunc(func(o *options) { + o.contentLength = s + }) +} + +// WithMasterMeta 请求时携带主机信息 +func WithMasterMeta() Option { + return optionFunc(func(o *options) { + o.masterMeta = true + }) +} diff --git a/pkg/request/request.go b/pkg/request/request.go index 2908fa6..eef19ee 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -1,14 +1,12 @@ package request import ( - "context" "encoding/json" "errors" "fmt" "io" "io/ioutil" "net/http" - "time" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" @@ -33,105 +31,33 @@ type Client interface { // HTTPClient 实现 Client 接口 type HTTPClient struct { + options *options } -// Option 发送请求的额外设置 -type Option interface { - apply(*options) -} - -type options struct { - timeout time.Duration - header http.Header - sign auth.Auth - signTTL int64 - ctx context.Context - contentLength int64 - masterMeta bool -} - -type optionFunc func(*options) - -func (f optionFunc) apply(o *options) { - f(o) -} - -func newDefaultOption() *options { - return &options{ - header: http.Header{}, - timeout: time.Duration(30) * time.Second, - contentLength: -1, +func NewClient(opts ...Option) Client { + client := &HTTPClient{ + options: newDefaultOption(), } -} -// WithTimeout 设置请求超时 -func WithTimeout(t time.Duration) Option { - return optionFunc(func(o *options) { - o.timeout = t - }) -} + for _, o := range opts { + o.apply(client.options) + } -// WithContext 设置请求上下文 -func WithContext(c context.Context) Option { - return optionFunc(func(o *options) { - o.ctx = c - }) -} - -// WithCredential 对请求进行签名 -func WithCredential(instance auth.Auth, ttl int64) Option { - return optionFunc(func(o *options) { - o.sign = instance - o.signTTL = ttl - }) -} - -// WithHeader 设置请求Header -func WithHeader(header http.Header) Option { - return optionFunc(func(o *options) { - for k, v := range header { - o.header[k] = v - } - }) -} - -// WithoutHeader 设置清除请求Header -func WithoutHeader(header []string) Option { - return optionFunc(func(o *options) { - for _, v := range header { - delete(o.header, v) - } - - }) -} - -// WithContentLength 设置请求大小 -func WithContentLength(s int64) Option { - return optionFunc(func(o *options) { - o.contentLength = s - }) -} - -// WithMasterMeta 请求时携带主机信息 -func WithMasterMeta() Option { - return optionFunc(func(o *options) { - o.masterMeta = true - }) + return client } // Request 发送HTTP请求 func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 - options := newDefaultOption() for _, o := range opts { - o.apply(options) + o.apply(c.options) } // 创建请求客户端 - client := &http.Client{Timeout: options.timeout} + client := &http.Client{Timeout: c.options.timeout} // size为0时将body设为nil - if options.contentLength == 0 { + if c.options.contentLength == 0 { body = nil } @@ -140,8 +66,8 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio req *http.Request err error ) - if options.ctx != nil { - req, err = http.NewRequestWithContext(options.ctx, method, target, body) + if c.options.ctx != nil { + req, err = http.NewRequestWithContext(c.options.ctx, method, target, body) } else { req, err = http.NewRequest(method, target, body) } @@ -150,21 +76,21 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio } // 添加请求相关设置 - req.Header = options.header + req.Header = c.options.header - if options.masterMeta { + if c.options.masterMeta { req.Header.Add("X-Site-Url", model.GetSiteURL().String()) req.Header.Add("X-Site-Id", model.GetSettingByName("siteID")) req.Header.Add("X-Cloudreve-Version", conf.BackendVersion) } - if options.contentLength != -1 { - req.ContentLength = options.contentLength + if c.options.contentLength != -1 { + req.ContentLength = c.options.contentLength } // 签名请求 - if options.sign != nil { - auth.SignRequest(options.sign, req, options.signTTL) + if c.options.sign != nil { + auth.SignRequest(c.options.sign, req, c.options.signTTL) } // 发送请求 diff --git a/pkg/slave/slave.go b/pkg/slave/slave.go index 4348158..f3cd9d6 100644 --- a/pkg/slave/slave.go +++ b/pkg/slave/slave.go @@ -47,7 +47,7 @@ type masterInfo struct { func Init() { DefaultController = &slaveController{ masters: make(map[string]masterInfo), - client: request.HTTPClient{}, + client: request.NewClient(), } gob.Register(rpc.StatusInfo{}) } diff --git a/routers/controllers/admin.go b/routers/controllers/admin.go index c0dc4c9..9d10367 100644 --- a/routers/controllers/admin.go +++ b/routers/controllers/admin.go @@ -24,7 +24,7 @@ func AdminSummary(c *gin.Context) { // AdminNews 获取社区新闻 func AdminNews(c *gin.Context) { - r := request.HTTPClient{} + r := request.NewClient() res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil) if res.Err == nil { io.Copy(c.Writer, res.Response.Body) diff --git a/service/admin/policy.go b/service/admin/policy.go index 7cd55fd..657d752 100644 --- a/service/admin/policy.go +++ b/service/admin/policy.go @@ -151,7 +151,7 @@ func (service *PolicyService) AddCORS() serializer.Response { case "oss": handler := oss.Driver{ Policy: &policy, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), } if err := handler.CORS(); err != nil { return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err) @@ -161,7 +161,7 @@ func (service *PolicyService) AddCORS() serializer.Response { b := &cossdk.BaseURL{BucketURL: u} handler := cos.Driver{ Policy: &policy, - HTTPClient: request.HTTPClient{}, + HTTPClient: request.NewClient(), Client: cossdk.NewClient(b, &http.Client{ Transport: &cossdk.AuthorizationTransport{ SecretID: policy.AccessKey, @@ -195,7 +195,7 @@ func (service *SlavePingService) Test() serializer.Response { controller, _ := url.Parse("/api/v3/site/ping") - r := request.HTTPClient{} + r := request.NewClient() res, err := r.Request( "GET", master.ResolveReference(controller).String(), @@ -229,7 +229,7 @@ func (service *SlaveTestService) Test() serializer.Response { } bodyByte, _ := json.Marshal(body) - r := request.HTTPClient{} + r := request.NewClient() res, err := r.Request( "POST", slave.ResolveReference(controller).String(),