feat: tps limit for OneDrive policy

This commit is contained in:
HFO4 2022-06-09 16:11:36 +08:00
parent 4859ea6ee5
commit f083d52e17
9 changed files with 170 additions and 13 deletions

2
assets

@ -1 +1 @@
Subproject commit c0f8a7ef6ddd335b697347dce56271c3d3d8c215 Subproject commit 41f585a6f8c8f99ed4b2e279555d6b4dcdf957bc

View file

@ -61,6 +61,10 @@ type PolicyOption struct {
ChunkSize uint64 `json:"chunk_size,omitempty"` ChunkSize uint64 `json:"chunk_size,omitempty"`
// 分片上传时是否需要预留空间 // 分片上传时是否需要预留空间
PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"` PlaceholderWithSize bool `json:"placeholder_with_size,omitempty"`
// 每秒对存储端的 API 请求上限
TPSLimit float64 `json:"tps_limit,omitempty"`
// 每秒 API 请求爆发上限
TPSLimitBurst int `json:"tps_limit_burst,omitempty"`
} }
// thumbSuffix 支持缩略图处理的文件扩展名 // thumbSuffix 支持缩略图处理的文件扩展名

View file

@ -52,6 +52,9 @@ func TestUserStorageCalibration_Run(t *testing.T) {
mock.ExpectQuery("SELECT(.+)files(.+)"). mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs(1). WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10)) WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background()) script.Run(context.Background())
asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(mock.ExpectationsWereMet())
} }

View file

@ -544,6 +544,11 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
"Content-Type": {"application/json"}, "Content-Type": {"application/json"},
}), }),
request.WithContext(ctx), request.WithContext(ctx),
request.WithTPSLimit(
fmt.Sprintf("policy_%d", client.Policy.ID),
client.Policy.OptionsSerialized.TPSLimit,
client.Policy.OptionsSerialized.TPSLimitBurst,
),
) )
// 发送请求 // 发送请求

View file

@ -15,15 +15,18 @@ type Option interface {
} }
type options struct { type options struct {
timeout time.Duration timeout time.Duration
header http.Header header http.Header
sign auth.Auth sign auth.Auth
signTTL int64 signTTL int64
ctx context.Context ctx context.Context
contentLength int64 contentLength int64
masterMeta bool masterMeta bool
endpoint *url.URL endpoint *url.URL
slaveNodeID string slaveNodeID string
tpsLimiterToken string
tps float64
tpsBurst int
} }
type optionFunc func(*options) type optionFunc func(*options)
@ -37,6 +40,7 @@ func newDefaultOption() *options {
header: http.Header{}, header: http.Header{},
timeout: time.Duration(30) * time.Second, timeout: time.Duration(30) * time.Second,
contentLength: -1, contentLength: -1,
ctx: context.Background(),
} }
} }
@ -113,3 +117,15 @@ func WithEndpoint(endpoint string) Option {
o.endpoint = endpointURL o.endpoint = endpointURL
}) })
} }
// WithTPSLimit 请求时使用全局流量限制
func WithTPSLimit(token string, tps float64, burst int) Option {
return optionFunc(func(o *options) {
o.tpsLimiterToken = token
o.tps = tps
if burst < 1 {
burst = 1
}
o.tpsBurst = burst
})
}

View file

@ -34,13 +34,15 @@ type Client interface {
// HTTPClient 实现 Client 接口 // HTTPClient 实现 Client 接口
type HTTPClient struct { type HTTPClient struct {
mu sync.Mutex mu sync.Mutex
options *options options *options
tpsLimiter TPSLimiter
} }
func NewClient(opts ...Option) Client { func NewClient(opts ...Option) Client {
client := &HTTPClient{ client := &HTTPClient{
options: newDefaultOption(), options: newDefaultOption(),
tpsLimiter: globalTPSLimiter,
} }
for _, o := range opts { for _, o := range opts {
@ -126,6 +128,10 @@ func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Opti
} }
} }
if options.tps > 0 {
c.tpsLimiter.Limit(options.ctx, options.tpsLimiterToken, options.tps, options.tpsBurst)
}
// 发送请求 // 发送请求
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {

View file

@ -238,3 +238,41 @@ func TestBlackHole(t *testing.T) {
BlackHole(strings.NewReader("TestBlackHole")) BlackHole(strings.NewReader("TestBlackHole"))
}) })
} }
func TestHTTPClient_TPSLimit(t *testing.T) {
a := assert.New(t)
client := NewClient()
finished := make(chan struct{})
go func() {
client.Request(
"POST",
"/test",
strings.NewReader(""),
WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1),
)
close(finished)
}()
select {
case <-finished:
case <-time.After(10 * time.Second):
a.Fail("Request should be finished instantly.")
}
finished = make(chan struct{})
go func() {
client.Request(
"POST",
"/test",
strings.NewReader(""),
WithTPSLimit("TestHTTPClient_TPSLimit", 1, 1),
)
close(finished)
}()
select {
case <-finished:
case <-time.After(2 * time.Second):
a.Fail("Request should be finished in 1 second.")
}
}

39
pkg/request/tpslimiter.go Normal file
View file

@ -0,0 +1,39 @@
package request
import (
"context"
"golang.org/x/time/rate"
"sync"
)
var globalTPSLimiter = NewTPSLimiter()
type TPSLimiter interface {
Limit(ctx context.Context, token string, tps float64, burst int)
}
func NewTPSLimiter() TPSLimiter {
return &multipleBucketLimiter{
buckets: make(map[string]*rate.Limiter),
}
}
// multipleBucketLimiter implements TPSLimiter with multiple bucket support.
type multipleBucketLimiter struct {
mu sync.Mutex
buckets map[string]*rate.Limiter
}
// Limit finds the given bucket, if bucket not exist or limit is changed,
// a new bucket will be generated.
func (m *multipleBucketLimiter) Limit(ctx context.Context, token string, tps float64, burst int) {
m.mu.Lock()
bucket, ok := m.buckets[token]
if !ok || float64(bucket.Limit()) != tps || bucket.Burst() != burst {
bucket = rate.NewLimiter(rate.Limit(tps), burst)
m.buckets[token] = bucket
}
m.mu.Unlock()
bucket.Wait(ctx)
}

View file

@ -0,0 +1,46 @@
package request
import (
"context"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestLimit(t *testing.T) {
a := assert.New(t)
l := NewTPSLimiter()
finished := make(chan struct{})
go func() {
l.Limit(context.Background(), "token", 1, 1)
close(finished)
}()
select {
case <-finished:
case <-time.After(10 * time.Second):
a.Fail("Limit should be finished instantly.")
}
finished = make(chan struct{})
go func() {
l.Limit(context.Background(), "token", 1, 1)
close(finished)
}()
select {
case <-finished:
case <-time.After(2 * time.Second):
a.Fail("Limit should be finished in 1 second.")
}
finished = make(chan struct{})
go func() {
l.Limit(context.Background(), "token", 10, 1)
close(finished)
}()
select {
case <-finished:
case <-time.After(1 * time.Second):
a.Fail("Limit should be finished instantly.")
}
}