Feat: init request client with global options

This commit is contained in:
HFO4 2021-08-29 20:36:40 +08:00
parent 3b47e314e9
commit 23d1839b29
13 changed files with 136 additions and 110 deletions
pkg
routers/controllers
service/admin

View file

@ -38,7 +38,7 @@ func (node *SlaveNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
node.AuthInstance = auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}
node.caller.Client = request.HTTPClient{}
node.caller.Client = request.NewClient()
node.caller.parent = node
node.Active = true
if node.close != nil {

View file

@ -55,7 +55,7 @@ func NewClient(policy *model.Policy) (*Client, error) {
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OdRedirect,
Request: request.HTTPClient{},
Request: request.NewClient(),
}
if client.Endpoints.DriverResource == "" {

View file

@ -42,7 +42,7 @@ func GetPublicKey(r *http.Request) ([]byte, error) {
}
// 获取公钥
client := request.HTTPClient{}
client := request.NewClient()
body, err := client.Request("GET", string(pubURL), nil).
CheckHTTPResponse(200).
GetResponse()

View file

@ -292,7 +292,7 @@ func TestDriver_Get(t *testing.T) {
BucketName: "test",
Server: "oss-cn-shanghai.aliyuncs.com",
},
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
cache.Set("setting_preview_timeout", "3600", 0)

View file

@ -172,7 +172,7 @@ func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser,
}
// 获取文件数据流
client := request.HTTPClient{}
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,

View file

@ -3,5 +3,6 @@ package slave
import "errors"
var (
ErrNotImplemented = errors.New("This method of shadowed policy is not implemented")
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
)

View file

@ -5,7 +5,9 @@ import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"io"
"net/url"
@ -16,6 +18,7 @@ type Driver struct {
node cluster.Node
handler driver.Handler
policy *model.Policy
client request.Client
}
// NewDriver 返回新的从机指派处理器
@ -24,12 +27,16 @@ func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy)
node: node,
handler: handler,
policy: policy,
client: request.NewClient(request.WithMasterMeta()),
}
}
func (d Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
realBase, ok := ctx.Value(fsctx.SlaveSrcPath).(string)
if !ok {
return ErrSlaveSrcPathNotExist
}
panic("implement me")
}
func (d Driver) Delete(ctx context.Context, files []string) ([]string, error) {

View file

@ -153,7 +153,7 @@ func (fs *FileSystem) DispatchHandler() error {
case "remote":
fs.Handler = remote.Driver{
Policy: currentPolicy,
Client: request.HTTPClient{},
Client: request.NewClient(),
AuthInstance: auth.HMACAuth{[]byte(currentPolicy.SecretKey)},
}
return nil
@ -165,7 +165,7 @@ func (fs *FileSystem) DispatchHandler() error {
case "oss":
fs.Handler = oss.Driver{
Policy: currentPolicy,
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
return nil
case "upyun":
@ -178,7 +178,7 @@ func (fs *FileSystem) DispatchHandler() error {
fs.Handler = onedrive.Driver{
Policy: currentPolicy,
Client: client,
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
return err
case "cos":
@ -192,7 +192,7 @@ func (fs *FileSystem) DispatchHandler() error {
SecretKey: currentPolicy.SecretKey,
},
}),
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
return nil
case "s3":

92
pkg/request/options.go Normal file
View file

@ -0,0 +1,92 @@
package request
import (
"context"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"net/http"
"time"
)
// Option 发送请求的额外设置
type Option interface {
apply(*options)
}
type options struct {
timeout time.Duration
header http.Header
sign auth.Auth
signTTL int64
ctx context.Context
contentLength int64
masterMeta bool
}
type optionFunc func(*options)
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
header: http.Header{},
timeout: time.Duration(30) * time.Second,
contentLength: -1,
}
}
// WithTimeout 设置请求超时
func WithTimeout(t time.Duration) Option {
return optionFunc(func(o *options) {
o.timeout = t
})
}
// WithContext 设置请求上下文
func WithContext(c context.Context) Option {
return optionFunc(func(o *options) {
o.ctx = c
})
}
// WithCredential 对请求进行签名
func WithCredential(instance auth.Auth, ttl int64) Option {
return optionFunc(func(o *options) {
o.sign = instance
o.signTTL = ttl
})
}
// WithHeader 设置请求Header
func WithHeader(header http.Header) Option {
return optionFunc(func(o *options) {
for k, v := range header {
o.header[k] = v
}
})
}
// WithoutHeader 设置清除请求Header
func WithoutHeader(header []string) Option {
return optionFunc(func(o *options) {
for _, v := range header {
delete(o.header, v)
}
})
}
// WithContentLength 设置请求大小
func WithContentLength(s int64) Option {
return optionFunc(func(o *options) {
o.contentLength = s
})
}
// WithMasterMeta 请求时携带主机信息
func WithMasterMeta() Option {
return optionFunc(func(o *options) {
o.masterMeta = true
})
}

View file

@ -1,14 +1,12 @@
package request
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
@ -33,105 +31,33 @@ type Client interface {
// HTTPClient 实现 Client 接口
type HTTPClient struct {
options *options
}
// Option 发送请求的额外设置
type Option interface {
apply(*options)
}
type options struct {
timeout time.Duration
header http.Header
sign auth.Auth
signTTL int64
ctx context.Context
contentLength int64
masterMeta bool
}
type optionFunc func(*options)
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
header: http.Header{},
timeout: time.Duration(30) * time.Second,
contentLength: -1,
func NewClient(opts ...Option) Client {
client := &HTTPClient{
options: newDefaultOption(),
}
}
// WithTimeout 设置请求超时
func WithTimeout(t time.Duration) Option {
return optionFunc(func(o *options) {
o.timeout = t
})
}
for _, o := range opts {
o.apply(client.options)
}
// WithContext 设置请求上下文
func WithContext(c context.Context) Option {
return optionFunc(func(o *options) {
o.ctx = c
})
}
// WithCredential 对请求进行签名
func WithCredential(instance auth.Auth, ttl int64) Option {
return optionFunc(func(o *options) {
o.sign = instance
o.signTTL = ttl
})
}
// WithHeader 设置请求Header
func WithHeader(header http.Header) Option {
return optionFunc(func(o *options) {
for k, v := range header {
o.header[k] = v
}
})
}
// WithoutHeader 设置清除请求Header
func WithoutHeader(header []string) Option {
return optionFunc(func(o *options) {
for _, v := range header {
delete(o.header, v)
}
})
}
// WithContentLength 设置请求大小
func WithContentLength(s int64) Option {
return optionFunc(func(o *options) {
o.contentLength = s
})
}
// WithMasterMeta 请求时携带主机信息
func WithMasterMeta() Option {
return optionFunc(func(o *options) {
o.masterMeta = true
})
return client
}
// Request 发送HTTP请求
func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
o.apply(c.options)
}
// 创建请求客户端
client := &http.Client{Timeout: options.timeout}
client := &http.Client{Timeout: c.options.timeout}
// size为0时将body设为nil
if options.contentLength == 0 {
if c.options.contentLength == 0 {
body = nil
}
@ -140,8 +66,8 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
req *http.Request
err error
)
if options.ctx != nil {
req, err = http.NewRequestWithContext(options.ctx, method, target, body)
if c.options.ctx != nil {
req, err = http.NewRequestWithContext(c.options.ctx, method, target, body)
} else {
req, err = http.NewRequest(method, target, body)
}
@ -150,21 +76,21 @@ func (c HTTPClient) Request(method, target string, body io.Reader, opts ...Optio
}
// 添加请求相关设置
req.Header = options.header
req.Header = c.options.header
if options.masterMeta {
if c.options.masterMeta {
req.Header.Add("X-Site-Url", model.GetSiteURL().String())
req.Header.Add("X-Site-Id", model.GetSettingByName("siteID"))
req.Header.Add("X-Cloudreve-Version", conf.BackendVersion)
}
if options.contentLength != -1 {
req.ContentLength = options.contentLength
if c.options.contentLength != -1 {
req.ContentLength = c.options.contentLength
}
// 签名请求
if options.sign != nil {
auth.SignRequest(options.sign, req, options.signTTL)
if c.options.sign != nil {
auth.SignRequest(c.options.sign, req, c.options.signTTL)
}
// 发送请求

View file

@ -47,7 +47,7 @@ type masterInfo struct {
func Init() {
DefaultController = &slaveController{
masters: make(map[string]masterInfo),
client: request.HTTPClient{},
client: request.NewClient(),
}
gob.Register(rpc.StatusInfo{})
}

View file

@ -24,7 +24,7 @@ func AdminSummary(c *gin.Context) {
// AdminNews 获取社区新闻
func AdminNews(c *gin.Context) {
r := request.HTTPClient{}
r := request.NewClient()
res := r.Request("GET", "https://forum.cloudreve.org/api/discussions?include=startUser%2ClastUser%2CstartPost%2Ctags&filter%5Bq%5D=%20tag%3Anotice&sort=-startTime&page%5Blimit%5D=10", nil)
if res.Err == nil {
io.Copy(c.Writer, res.Response.Body)

View file

@ -151,7 +151,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
case "oss":
handler := oss.Driver{
Policy: &policy,
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
}
if err := handler.CORS(); err != nil {
return serializer.Err(serializer.CodeInternalSetting, "跨域策略添加失败", err)
@ -161,7 +161,7 @@ func (service *PolicyService) AddCORS() serializer.Response {
b := &cossdk.BaseURL{BucketURL: u}
handler := cos.Driver{
Policy: &policy,
HTTPClient: request.HTTPClient{},
HTTPClient: request.NewClient(),
Client: cossdk.NewClient(b, &http.Client{
Transport: &cossdk.AuthorizationTransport{
SecretID: policy.AccessKey,
@ -195,7 +195,7 @@ func (service *SlavePingService) Test() serializer.Response {
controller, _ := url.Parse("/api/v3/site/ping")
r := request.HTTPClient{}
r := request.NewClient()
res, err := r.Request(
"GET",
master.ResolveReference(controller).String(),
@ -229,7 +229,7 @@ func (service *SlaveTestService) Test() serializer.Response {
}
bodyByte, _ := json.Marshal(body)
r := request.HTTPClient{}
r := request.NewClient()
res, err := r.Request(
"POST",
slave.ResolveReference(controller).String(),