Feat: init request client with global options
This commit is contained in:
parent
3b47e314e9
commit
23d1839b29
13 changed files with 136 additions and 110 deletions
|
@ -38,7 +38,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
|
|||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}
|
||||
node.caller.Client = request.HTTPClient{}
|
||||
node.caller.Client = request.NewClient()
|
||||
node.caller.parent = node
|
||||
node.Active = true
|
||||
if node.close != nil {
|
||||
|
|
|
@ -55,7 +55,7 @@ func NewClient(policy *model.Policy) (*Client, error) {
|
|||
ClientID: policy.BucketName,
|
||||
ClientSecret: policy.SecretKey,
|
||||
Redirect: policy.OptionsSerialized.OdRedirect,
|
||||
Request: request.HTTPClient{},
|
||||
Request: request.NewClient(),
|
||||
}
|
||||
|
||||
if client.Endpoints.DriverResource == "" {
|
||||
|
|
|
@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) {
|
|||
}
|
||||
|
||||
// 获取公钥
|
||||
client := request.HTTPClient{}
|
||||
client := request.NewClient()
|
||||
body, err := client.Request("GET", string(pubURL), nil).
|
||||
CheckHTTPResponse(200).
|
||||
GetResponse()
|
||||
|
|
|
@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) {
|
|||
BucketName: "test",
|
||||
Server: "oss-cn-shanghai.aliyuncs.com",
|
||||
},
|
||||
HTTPClient: request.HTTPClient{},
|
||||
HTTPClient: request.NewClient(),
|
||||
}
|
||||
cache.Set("setting_preview_timeout", "3600", 0)
|
||||
|
||||
|
|
|
@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
|
|||
}
|
||||
|
||||
// 获取文件数据流
|
||||
client := request.HTTPClient{}
|
||||
client := request.NewClient()
|
||||
resp, err := client.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
|
|
|
@ -3,5 +3,6 @@ package slave
|
|||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("This method of shadowed policy is not implemented")
|
||||
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
|
||||
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
|
||||
)
|
||||
|
|
|
@ -5,7 +5,9 @@ import (
|
|||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"io"
|
||||
"net/url"
|
||||
|
@ -16,6 +18,7 @@ type Driver struct {
|
|||
node cluster.Node
|
||||
handler driver.Handler
|
||||
policy *model.Policy
|
||||
client request.Client
|
||||
}
|
||||
|
||||
// NewDriver 返回新的从机指派处理器
|
||||
|
@ -24,12 +27,16 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy)
|
|||
node: node,
|
||||
handler: handler,
|
||||
policy: policy,
|
||||
client: request.NewClient(request.WithMasterMeta()),
|
||||
}
|
||||
}
|
||||
|
||||
func (d Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
|
||||
realBase, ok := ctx.Value(fsctx.SlaveSrcPath).(string)
|
||||
if !ok {
|
||||
return ErrSlaveSrcPathNotExist
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
|
|
|
@ -153,7 +153,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
|||
case "remote":
|
||||
fs.Handler = remote.Driver{
|
||||
Policy: currentPolicy,
|
||||
Client: request.HTTPClient{},
|
||||
Client: request.NewClient(),
|
||||
AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)},
|
||||
}
|
||||
return nil
|
||||
|
@ -165,7 +165,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
|||
case "oss":
|
||||
fs.Handler = oss.Driver{
|
||||
Policy: currentPolicy,
|
||||
HTTPClient: request.HTTPClient{},
|
||||
HTTPClient: request.NewClient(),
|
||||
}
|
||||
return nil
|
||||
case "upyun":
|
||||
|
@ -178,7 +178,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
|||
fs.Handler = onedrive.Driver{
|
||||
Policy: currentPolicy,
|
||||
Client: client,
|
||||
HTTPClient: request.HTTPClient{},
|
||||
HTTPClient: request.NewClient(),
|
||||
}
|
||||
return err
|
||||
case "cos":
|
||||
|
@ -192,7 +192,7 @@ func (fs *FileSystem) DispatchHandler() error {
|
|||
SecretKey: currentPolicy.SecretKey,
|
||||
},
|
||||
}),
|
||||
HTTPClient: request.HTTPClient{},
|
||||
HTTPClient: request.NewClient(),
|
||||
}
|
||||
return nil
|
||||
case "s3":
|
||||
|
|
92
pkg/request/options.go
Normal file
92
pkg/request/options.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option 发送请求的额外设置
|
||||
type Option interface {
|
||||
apply(*options)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
timeout time.Duration
|
||||
header http.Header
|
||||
sign auth.Auth
|
||||
signTTL int64
|
||||
ctx context.Context
|
||||
contentLength int64
|
||||
masterMeta bool
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
header: http.Header{},
|
||||
timeout: time.Duration(30) * time.Second,
|
||||
contentLength: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout 设置请求超时
|
||||
func WithTimeout(t time.Duration) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.timeout = t
|
||||
})
|
||||
}
|
||||
|
||||
// WithContext 设置请求上下文
|
||||
func WithContext(c context.Context) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.ctx = c
|
||||
})
|
||||
}
|
||||
|
||||
// WithCredential 对请求进行签名
|
||||
func WithCredential(instance auth.Auth, ttl int64) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.sign = instance
|
||||
o.signTTL = ttl
|
||||
})
|
||||
}
|
||||
|
||||
// WithHeader 设置请求Header
|
||||
func WithHeader(header http.Header) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
for k, v := range header {
|
||||
o.header[k] = v
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// WithoutHeader 设置清除请求Header
|
||||
func WithoutHeader(header []string) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
for _, v := range header {
|
||||
delete(o.header, v)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
// WithContentLength 设置请求大小
|
||||
func WithContentLength(s int64) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.contentLength = s
|
||||
})
|
||||
}
|
||||
|
||||
// WithMasterMeta 请求时携带主机信息
|
||||
func WithMasterMeta() Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.masterMeta = true
|
||||
})
|
||||
}
|
|
@ -1,14 +1,12 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
|
@ -33,105 +31,33 @@ type Client interface {
|
|||
|
||||
// HTTPClient 实现 Client 接口
|
||||
type HTTPClient struct {
|
||||
options *options
|
||||
}
|
||||
|
||||
// Option 发送请求的额外设置
|
||||
type Option interface {
|
||||
apply(*options)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
timeout time.Duration
|
||||
header http.Header
|
||||
sign auth.Auth
|
||||
signTTL int64
|
||||
ctx context.Context
|
||||
contentLength int64
|
||||
masterMeta bool
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
header: http.Header{},
|
||||
timeout: time.Duration(30) * time.Second,
|
||||
contentLength: -1,
|
||||
func NewClient(opts ...Option) Client {
|
||||
client := &HTTPClient{
|
||||
options: newDefaultOption(),
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout 设置请求超时
|
||||
func WithTimeout(t time.Duration) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.timeout = t
|
||||
})
|
||||
}
|
||||
for _, o := range opts {
|
||||
o.apply(client.options)
|
||||
}
|
||||
|
||||
// WithContext 设置请求上下文
|
||||
func WithContext(c context.Context) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.ctx = c
|
||||
})
|
||||
}
|
||||
|
||||
// WithCredential 对请求进行签名
|
||||
func WithCredential(instance auth.Auth, ttl int64) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.sign = instance
|
||||
o.signTTL = ttl
|
||||
})
|
||||
}
|
||||
|
||||
// WithHeader 设置请求Header
|
||||
func WithHeader(header http.Header) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
for k, v := range header {
|
||||
o.header[k] = v
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// WithoutHeader 设置清除请求Header
|
||||
func WithoutHeader(header []string) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
for _, v := range header {
|
||||
delete(o.header, v)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
// WithContentLength 设置请求大小
|
||||
func WithContentLength(s int64) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.contentLength = s
|
||||
})
|
||||
}
|
||||
|
||||
// WithMasterMeta 请求时携带主机信息
|
||||
func WithMasterMeta() Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.masterMeta = true
|
||||
})
|
||||
return client
|
||||
}
|
||||
|
||||
// Request 发送HTTP请求
|
||||
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
|
||||
// 应用额外设置
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
o.apply(c.options)
|
||||
}
|
||||
|
||||
// 创建请求客户端
|
||||
client := &http.Client{Timeout: options.timeout}
|
||||
client := &http.Client{Timeout: c.options.timeout}
|
||||
|
||||
// size为0时将body设为nil
|
||||
if options.contentLength == 0 {
|
||||
if c.options.contentLength == 0 {
|
||||
body = nil
|
||||
}
|
||||
|
||||
|
@ -140,8 +66,8 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
|
|||
req *http.Request
|
||||
err error
|
||||
)
|
||||
if options.ctx != nil {
|
||||
req, err = http.NewRequestWithContext(options.ctx, method, target, body)
|
||||
if c.options.ctx != nil {
|
||||
req, err = http.NewRequestWithContext(c.options.ctx, method, target, body)
|
||||
} else {
|
||||
req, err = http.NewRequest(method, target, body)
|
||||
}
|
||||
|
@ -150,21 +76,21 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
|
|||
}
|
||||
|
||||
// 添加请求相关设置
|
||||
req.Header = options.header
|
||||
req.Header = c.options.header
|
||||
|
||||
if options.masterMeta {
|
||||
if c.options.masterMeta {
|
||||
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
|
||||
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
|
||||
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
|
||||
}
|
||||
|
||||
if options.contentLength != -1 {
|
||||
req.ContentLength = options.contentLength
|
||||
if c.options.contentLength != -1 {
|
||||
req.ContentLength = c.options.contentLength
|
||||
}
|
||||
|
||||
// 签名请求
|
||||
if options.sign != nil {
|
||||
auth.SignRequest(options.sign, req, options.signTTL)
|
||||
if c.options.sign != nil {
|
||||
auth.SignRequest(c.options.sign, req, c.options.signTTL)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
|
|
|
@ -47,7 +47,7 @@ type masterInfo struct {
|
|||
func Init() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]masterInfo),
|
||||
client: request.HTTPClient{},
|
||||
client: request.NewClient(),
|
||||
}
|
||||
gob.Register(rpc.StatusInfo{})
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ func AdminSummary(c *gin.Context) {
|
|||
|
||||
// AdminNews 获取社区新闻
|
||||
func AdminNews(c *gin.Context) {
|
||||
r := request.HTTPClient{}
|
||||
r := request.NewClient()
|
||||
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil)
|
||||
if res.Err == nil {
|
||||
io.Copy(c.Writer, res.Response.Body)
|
||||
|
|
|
@ -151,7 +151,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
|
|||
case "oss":
|
||||
handler := oss.Driver{
|
||||
Policy: &policy,
|
||||
HTTPClient: request.HTTPClient{},
|
||||
HTTPClient: request.NewClient(),
|
||||
}
|
||||
if err := handler.CORS(); err != nil {
|
||||
return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err)
|
||||
|
@ -161,7 +161,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
|
|||
b := &cossdk.BaseURL{BucketURL: u}
|
||||
handler := cos.Driver{
|
||||
Policy: &policy,
|
||||
HTTPClient: request.HTTPClient{},
|
||||
HTTPClient: request.NewClient(),
|
||||
Client: cossdk.NewClient(b, &http.Client{
|
||||
Transport: &cossdk.AuthorizationTransport{
|
||||
SecretID: policy.AccessKey,
|
||||
|
@ -195,7 +195,7 @@ func (service *SlavePingService) Test() serializer.Response {
|
|||
|
||||
controller, _ := url.Parse("/api/v3/site/ping")
|
||||
|
||||
r := request.HTTPClient{}
|
||||
r := request.NewClient()
|
||||
res, err := r.Request(
|
||||
"GET",
|
||||
master.ResolveReference(controller).String(),
|
||||
|
@ -229,7 +229,7 @@ func (service *SlaveTestService) Test() serializer.Response {
|
|||
}
|
||||
bodyByte, _ := json.Marshal(body)
|
||||
|
||||
r := request.HTTPClient{}
|
||||
r := request.NewClient()
|
||||
res, err := r.Request(
|
||||
"POST",
|
||||
slave.ResolveReference(controller).String(),
|
||||
|
|
Loading…
Add table
Reference in a new issue