Cloudreve/pkg/filesystem/driver/onedrive/handler_test.go

420 lines
11 KiB
Go

package onedrive
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/assert"
testMock "github.com/stretchr/testify/mock"
)
func TestDriver_Token(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
// 无法获取文件路径
{
ctx := context.WithValue(context.Background(), fsctx.FileSizeCtx, uint64(10))
res, err := handler.Token(ctx, 10, "key", nil)
asserts.Error(err)
asserts.Equal(serializer.UploadCredential{}, res)
}
// 无法获取文件大小
{
ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123")
res, err := handler.Token(ctx, 10, "key", nil)
asserts.Error(err)
asserts.Equal(serializer.UploadCredential{}, res)
}
// 小文件成功
{
ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123")
ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(10))
res, err := handler.Token(ctx, 10, "key", nil)
asserts.NoError(err)
asserts.Equal(serializer.UploadCredential{}, res)
}
// 分片上传 失败
{
cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 400,
Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123")
ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024))
res, err := handler.Token(ctx, 10, "key", nil)
asserts.Error(err)
asserts.Equal(serializer.UploadCredential{}, res)
}
// 分片上传 成功
{
cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
cache.Set("setting_onedrive_monitor_timeout", "600", 0)
cache.Set("setting_onedrive_callback_check", "20", 0)
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1"
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
ctx := context.WithValue(context.Background(), fsctx.SavePathCtx, "/123")
ctx = context.WithValue(ctx, fsctx.FileSizeCtx, uint64(20*1024*1024))
go func() {
time.Sleep(time.Duration(1) * time.Second)
FinishCallback("key")
}()
res, err := handler.Token(ctx, 10, "key", nil)
asserts.NoError(err)
asserts.Equal("123321", res.Policy)
}
}
func TestDriver_Source(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
cache.Set("setting_onedrive_source_timeout", "1800", 0)
// 失败
{
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0)
asserts.Error(err)
asserts.Empty(res)
}
// 命中缓存 成功
{
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1"
cache.Set("onedrive_source_0_123.jpg", "res", 1)
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err)
asserts.Equal("res", res)
}
// 命中缓存 上下文存在文件 成功
{
file := model.File{}
file.ID = 1
file.UpdatedAt = time.Now()
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
handler.Client.Credential.AccessToken = "1"
cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 1, true, 0)
cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
asserts.NoError(err)
asserts.Equal("res", res)
}
// 成功
{
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0)
asserts.NoError(err)
asserts.Equal("123321", res)
}
// 成功 永久直链
{
file := model.File{}
file.ID = 1
file.Name = "123.jpg"
file.UpdatedAt = time.Now()
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
auth.General = auth.HMACAuth{}
handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0)
asserts.NoError(err)
asserts.Contains(res, "/api/v3/file/source/1/123.jpg?sign")
}
}
func TestDriver_List(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.AccessToken = "AccessToken"
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 非递归
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)),
},
})
handler.Client.Request = clientMock
res, err := handler.List(context.Background(), "/", false)
asserts.NoError(err)
asserts.Len(res, 1)
}
// 递归一次
{
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
"me/drive/root/children?$top=999999999",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"1","folder":{}}]}`)),
},
})
clientMock.On(
"Request",
"GET",
"me/drive/root:/1:/children?$top=999999999",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"2"}]}`)),
},
})
handler.Client.Request = clientMock
res, err := handler.List(context.Background(), "/", true)
asserts.NoError(err)
asserts.Len(res, 2)
}
}
func TestDriver_Thumb(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 失败
{
ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20})
ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{})
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
res, err := handler.Thumb(ctx, "123.jpg")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Empty(res.URL)
}
// 上下文错误
{
_, err := handler.Thumb(context.Background(), "123.jpg")
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}
func TestDriver_Delete(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 失败
{
_, err := handler.Delete(context.Background(), []string{"1"})
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
}
}
func TestDriver_Put(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 失败
{
err := handler.Put(context.Background(), ioutil.NopCloser(strings.NewReader("")), "dst", 0)
asserts.Error(err)
}
}
func TestDriver_Get(t *testing.T) {
asserts := assert.New(t)
handler := Driver{
Policy: &model.Policy{
AccessKey: "ak",
SecretKey: "sk",
BucketName: "test",
Server: "test.com",
},
}
handler.Client, _ = NewClient(&model.Policy{})
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
// 无法获取source
{
res, err := handler.Get(context.Background(), "123.txt")
asserts.Error(err)
asserts.Nil(res)
}
// 成功
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
clientMock := ClientMock{}
clientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
},
})
handler.Client.Request = clientMock
handler.Client.Credential.AccessToken = "1"
driverClientMock := ClientMock{}
driverClientMock.On(
"Request",
"GET",
testMock.Anything,
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`123`)),
},
})
handler.HTTPClient = driverClientMock
res, err := handler.Get(context.Background(), "123.txt")
clientMock.AssertExpectations(t)
asserts.NoError(err)
_, err = res.Seek(0, io.SeekEnd)
asserts.NoError(err)
content, err := ioutil.ReadAll(res)
asserts.NoError(err)
asserts.Equal("123", string(content))
}