Merge branch 'master' into patch-samesite

This commit is contained in:
AaronLiu 2022-12-16 13:58:54 +08:00 committed by GitHub
commit 4ad9649300
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
74 changed files with 1320 additions and 291 deletions

View file

@ -7,10 +7,10 @@ jobs:
name: Build
runs-on: ubuntu-18.04
steps:
- name: Set up Go 1.17
- name: Set up Go 1.18
uses: actions/setup-go@v2
with:
go-version: "1.17"
go-version: "1.18"
id: go
- name: Check out code into the Go module directory

View file

@ -12,10 +12,10 @@ jobs:
name: Test
runs-on: ubuntu-18.04
steps:
- name: Set up Go 1.17
- name: Set up Go 1.18
uses: actions/setup-go@v2
with:
go-version: "1.17"
go-version: "1.18"
id: go
- name: Check out code into the Go module directory

View file

@ -1,6 +1,6 @@
language: go
go:
- 1.17.x
- 1.18.x
node_js: "12.16.3"
git:
depth: 1

View file

@ -1,42 +1,47 @@
FROM golang:1.17-alpine as cloudreve_builder
# the frontend builder
# cloudreve need node.js 16* to build frontend,
# separate build step and custom image tag will resolve this
FROM node:16-alpine as cloudreve_frontend_builder
RUN apk update \
&& apk add --no-cache wget curl git yarn zip bash \
&& git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git /cloudreve_frontend
# build frontend assets using build script, make sure all the steps just follow the regular release
WORKDIR /cloudreve_frontend
ENV GENERATE_SOURCEMAP false
RUN chmod +x ./build.sh && ./build.sh -a
# the backend builder
# cloudreve backend needs golang 1.18* to build
FROM golang:1.18-alpine as cloudreve_backend_builder
# install dependencies and build tools
RUN apk update && apk add --no-cache wget curl git yarn build-base gcc abuild binutils binutils-doc gcc-doc zip
RUN apk update \
# install dependencies and build tools
&& apk add --no-cache wget curl git build-base gcc abuild binutils binutils-doc gcc-doc zip bash \
&& git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git /cloudreve_backend
WORKDIR /cloudreve_builder
RUN git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git
# build frontend
WORKDIR /cloudreve_builder/Cloudreve/assets
ENV GENERATE_SOURCEMAP false
RUN yarn install --network-timeout 1000000
RUN yarn run build
# build backend
WORKDIR /cloudreve_builder/Cloudreve
RUN zip -r - assets/build >assets.zip
RUN tag_name=$(git describe --tags) \
&& export COMMIT_SHA=$(git rev-parse --short HEAD) \
&& go build -a -o cloudreve -ldflags " -X 'github.com/HFO4/cloudreve/pkg/conf.BackendVersion=$tag_name' -X 'github.com/HFO4/cloudreve/pkg/conf.LastCommit=$COMMIT_SHA'"
WORKDIR /cloudreve_backend
COPY --from=cloudreve_frontend_builder /cloudreve_frontend/assets.zip ./
RUN chmod +x ./build.sh && ./build.sh -c
# build final image
# TODO: merge the frontend build and backend build into a single one image
# the final published image
FROM alpine:latest
WORKDIR /cloudreve
COPY --from=cloudreve_backend_builder /cloudreve_backend/cloudreve ./cloudreve
RUN apk update && apk add --no-cache tzdata
# we using the `Asia/Shanghai` timezone by default, you can do modification at your will
RUN cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone
COPY --from=cloudreve_builder /cloudreve_builder/Cloudreve/cloudreve ./
# prepare permissions and aria2 dir
RUN chmod +x ./cloudreve && mkdir -p /data/aria2 && chmod -R 766 /data/aria2
RUN apk update \
&& apk add --no-cache tzdata \
&& cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& echo "Asia/Shanghai" > /etc/timezone \
&& chmod +x ./cloudreve \
&& mkdir -p /data/aria2 \
&& chmod -R 766 /data/aria2
EXPOSE 5212
VOLUME ["/cloudreve/uploads", "/cloudreve/avatar", "/data"]

View file

@ -71,7 +71,7 @@ chmod +x ./cloudreve
## :gear: 构建
自行构建前需要拥有 `Go >= 1.17`、`node.js``yarn``zip` 等必要依赖。
自行构建前需要拥有 `Go >= 1.18`、`node.js``yarn``zip` 等必要依赖。
#### 克隆代码

2
assets

@ -1 +1 @@
Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603
Subproject commit 2bf915a33d58fc78c9c13ffc64685219c28a4732

Binary file not shown.

432
bootstrap/embed.go Normal file
View file

@ -0,0 +1,432 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package embed provides access to files embedded in the running Go program.
//
// Go source files that import "embed" can use the //go:embed directive
// to initialize a variable of type string, []byte, or FS with the contents of
// files read from the package directory or subdirectories at compile time.
//
// For example, here are three ways to embed a file named hello.txt
// and then print its contents at run time.
//
// Embedding one file into a string:
//
// import _ "embed"
//
// //go:embed hello.txt
// var s string
// print(s)
//
// Embedding one file into a slice of bytes:
//
// import _ "embed"
//
// //go:embed hello.txt
// var b []byte
// print(string(b))
//
// Embedded one or more files into a file system:
//
// import "embed"
//
// //go:embed hello.txt
// var f embed.FS
// data, _ := f.ReadFile("hello.txt")
// print(string(data))
//
// # Directives
//
// A //go:embed directive above a variable declaration specifies which files to embed,
// using one or more path.Match patterns.
//
// The directive must immediately precede a line containing the declaration of a single variable.
// Only blank lines and // line comments are permitted between the directive and the declaration.
//
// The type of the variable must be a string type, or a slice of a byte type,
// or FS (or an alias of FS).
//
// For example:
//
// package server
//
// import "embed"
//
// // content holds our static web server content.
// //go:embed image/* template/*
// //go:embed html/index.html
// var content embed.FS
//
// The Go build system will recognize the directives and arrange for the declared variable
// (in the example above, content) to be populated with the matching files from the file system.
//
// The //go:embed directive accepts multiple space-separated patterns for
// brevity, but it can also be repeated, to avoid very long lines when there are
// many patterns. The patterns are interpreted relative to the package directory
// containing the source file. The path separator is a forward slash, even on
// Windows systems. Patterns may not contain . or .. or empty path elements,
// nor may they begin or end with a slash. To match everything in the current
// directory, use * instead of .. To allow for naming files with spaces in
// their names, patterns can be written as Go double-quoted or back-quoted
// string literals.
//
// If a pattern names a directory, all files in the subtree rooted at that directory are
// embedded (recursively), except that files with names beginning with . or _
// are excluded. So the variable in the above example is almost equivalent to:
//
// // content is our static web server content.
// //go:embed image template html/index.html
// var content embed.FS
//
// The difference is that image/* embeds image/.tempfile while image does not.
// Neither embeds image/dir/.tempfile.
//
// If a pattern begins with the prefix all:, then the rule for walking directories is changed
// to include those files beginning with . or _. For example, all:image embeds
// both image/.tempfile and image/dir/.tempfile.
//
// The //go:embed directive can be used with both exported and unexported variables,
// depending on whether the package wants to make the data available to other packages.
// It can only be used with variables at package scope, not with local variables.
//
// Patterns must not match files outside the package's module, such as .git/* or symbolic links.
// Patterns must not match files whose names include the special punctuation characters " * < > ? ` ' | / \ and :.
// Matches for empty directories are ignored. After that, each pattern in a //go:embed line
// must match at least one file or non-empty directory.
//
// If any patterns are invalid or have invalid matches, the build will fail.
//
// # Strings and Bytes
//
// The //go:embed line for a variable of type string or []byte can have only a single pattern,
// and that pattern can match only a single file. The string or []byte is initialized with
// the contents of that file.
//
// The //go:embed directive requires importing "embed", even when using a string or []byte.
// In source files that don't refer to embed.FS, use a blank import (import _ "embed").
//
// # File Systems
//
// For embedding a single file, a variable of type string or []byte is often best.
// The FS type enables embedding a tree of files, such as a directory of static
// web server content, as in the example above.
//
// FS implements the io/fs package's FS interface, so it can be used with any package that
// understands file systems, including net/http, text/template, and html/template.
//
// For example, given the content variable in the example above, we can write:
//
// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(content))))
//
// template.ParseFS(content, "*.tmpl")
//
// # Tools
//
// To support tools that analyze Go packages, the patterns found in //go:embed lines
// are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns,
// and XTestEmbedPatterns fields in the “go help list” output.
package bootstrap
import (
"errors"
"io"
"io/fs"
"time"
)
// An FS is a read-only collection of files, usually initialized with a //go:embed directive.
// When declared without a //go:embed directive, an FS is an empty file system.
//
// An FS is a read-only value, so it is safe to use from multiple goroutines
// simultaneously and also safe to assign values of type FS to each other.
//
// FS implements fs.FS, so it can be used with any package that understands
// file system interfaces, including net/http, text/template, and html/template.
//
// See the package documentation for more details about initializing an FS.
type FS struct {
// The compiler knows the layout of this struct.
// See cmd/compile/internal/staticdata's WriteEmbed.
//
// The files list is sorted by name but not by simple string comparison.
// Instead, each file's name takes the form "dir/elem" or "dir/elem/".
// The optional trailing slash indicates that the file is itself a directory.
// The files list is sorted first by dir (if dir is missing, it is taken to be ".")
// and then by base, so this list of files:
//
// p
// q/
// q/r
// q/s/
// q/s/t
// q/s/u
// q/v
// w
//
// is actually sorted as:
//
// p # dir=. elem=p
// q/ # dir=. elem=q
// w/ # dir=. elem=w
// q/r # dir=q elem=r
// q/s/ # dir=q elem=s
// q/v # dir=q elem=v
// q/s/t # dir=q/s elem=t
// q/s/u # dir=q/s elem=u
//
// This order brings directory contents together in contiguous sections
// of the list, allowing a directory read to use binary search to find
// the relevant sequence of entries.
files *[]file
}
// split splits the name into dir and elem as described in the
// comment in the FS struct above. isDir reports whether the
// final trailing slash was present, indicating that name is a directory.
func split(name string) (dir, elem string, isDir bool) {
if name[len(name)-1] == '/' {
isDir = true
name = name[:len(name)-1]
}
i := len(name) - 1
for i >= 0 && name[i] != '/' {
i--
}
if i < 0 {
return ".", name, isDir
}
return name[:i], name[i+1:], isDir
}
// trimSlash trims a trailing slash from name, if present,
// returning the possibly shortened name.
func trimSlash(name string) string {
if len(name) > 0 && name[len(name)-1] == '/' {
return name[:len(name)-1]
}
return name
}
var (
_ fs.ReadDirFS = FS{}
_ fs.ReadFileFS = FS{}
)
// A file is a single file in the FS.
// It implements fs.FileInfo and fs.DirEntry.
type file struct {
// The compiler knows the layout of this struct.
// See cmd/compile/internal/staticdata's WriteEmbed.
name string
data string
hash [16]byte // truncated SHA256 hash
}
var (
_ fs.FileInfo = (*file)(nil)
_ fs.DirEntry = (*file)(nil)
)
func (f *file) Name() string { _, elem, _ := split(f.name); return elem }
func (f *file) Size() int64 { return int64(len(f.data)) }
func (f *file) ModTime() time.Time { return time.Time{} }
func (f *file) IsDir() bool { _, _, isDir := split(f.name); return isDir }
func (f *file) Sys() any { return nil }
func (f *file) Type() fs.FileMode { return f.Mode().Type() }
func (f *file) Info() (fs.FileInfo, error) { return f, nil }
func (f *file) Mode() fs.FileMode {
if f.IsDir() {
return fs.ModeDir | 0555
}
return 0444
}
// dotFile is a file for the root directory,
// which is omitted from the files list in a FS.
var dotFile = &file{name: "./"}
// lookup returns the named file, or nil if it is not present.
func (f FS) lookup(name string) *file {
if !fs.ValidPath(name) {
// The compiler should never emit a file with an invalid name,
// so this check is not strictly necessary (if name is invalid,
// we shouldn't find a match below), but it's a good backstop anyway.
return nil
}
if name == "." {
return dotFile
}
if f.files == nil {
return nil
}
// Binary search to find where name would be in the list,
// and then check if name is at that position.
dir, elem, _ := split(name)
files := *f.files
i := sortSearch(len(files), func(i int) bool {
idir, ielem, _ := split(files[i].name)
return idir > dir || idir == dir && ielem >= elem
})
if i < len(files) && trimSlash(files[i].name) == name {
return &files[i]
}
return nil
}
// readDir returns the list of files corresponding to the directory dir.
func (f FS) readDir(dir string) []file {
if f.files == nil {
return nil
}
// Binary search to find where dir starts and ends in the list
// and then return that slice of the list.
files := *f.files
i := sortSearch(len(files), func(i int) bool {
idir, _, _ := split(files[i].name)
return idir >= dir
})
j := sortSearch(len(files), func(j int) bool {
jdir, _, _ := split(files[j].name)
return jdir > dir
})
return files[i:j]
}
// Open opens the named file for reading and returns it as an fs.File.
//
// The returned file implements io.Seeker when the file is not a directory.
func (f FS) Open(name string) (fs.File, error) {
file := f.lookup(name)
if file == nil {
return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist}
}
if file.IsDir() {
return &openDir{file, f.readDir(name), 0}, nil
}
return &openFile{file, 0}, nil
}
// ReadDir reads and returns the entire named directory.
func (f FS) ReadDir(name string) ([]fs.DirEntry, error) {
file, err := f.Open(name)
if err != nil {
return nil, err
}
dir, ok := file.(*openDir)
if !ok {
return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("not a directory")}
}
list := make([]fs.DirEntry, len(dir.files))
for i := range list {
list[i] = &dir.files[i]
}
return list, nil
}
// ReadFile reads and returns the content of the named file.
func (f FS) ReadFile(name string) ([]byte, error) {
file, err := f.Open(name)
if err != nil {
return nil, err
}
ofile, ok := file.(*openFile)
if !ok {
return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("is a directory")}
}
return []byte(ofile.f.data), nil
}
// An openFile is a regular file open for reading.
type openFile struct {
f *file // the file itself
offset int64 // current read offset
}
var (
_ io.Seeker = (*openFile)(nil)
)
func (f *openFile) Close() error { return nil }
func (f *openFile) Stat() (fs.FileInfo, error) { return f.f, nil }
func (f *openFile) Read(b []byte) (int, error) {
if f.offset >= int64(len(f.f.data)) {
return 0, io.EOF
}
if f.offset < 0 {
return 0, &fs.PathError{Op: "read", Path: f.f.name, Err: fs.ErrInvalid}
}
n := copy(b, f.f.data[f.offset:])
f.offset += int64(n)
return n, nil
}
func (f *openFile) Seek(offset int64, whence int) (int64, error) {
switch whence {
case 0:
// offset += 0
case 1:
offset += f.offset
case 2:
offset += int64(len(f.f.data))
}
if offset < 0 || offset > int64(len(f.f.data)) {
return 0, &fs.PathError{Op: "seek", Path: f.f.name, Err: fs.ErrInvalid}
}
f.offset = offset
return offset, nil
}
// An openDir is a directory open for reading.
type openDir struct {
f *file // the directory file itself
files []file // the directory contents
offset int // the read offset, an index into the files slice
}
func (d *openDir) Close() error { return nil }
func (d *openDir) Stat() (fs.FileInfo, error) { return d.f, nil }
func (d *openDir) Read([]byte) (int, error) {
return 0, &fs.PathError{Op: "read", Path: d.f.name, Err: errors.New("is a directory")}
}
func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.files) - d.offset
if n == 0 {
if count <= 0 {
return nil, nil
}
return nil, io.EOF
}
if count > 0 && n > count {
n = count
}
list := make([]fs.DirEntry, n)
for i := range list {
list[i] = &d.files[d.offset+i]
}
d.offset += n
return list, nil
}
// sortSearch is like sort.Search, avoiding an import.
func sortSearch(n int, f func(int) bool) int {
// Define f(-1) == false and f(n) == true.
// Invariant: f(i-1) == false, f(j) == true.
i, j := 0, n
for i < j {
h := int(uint(i+j) >> 1) // avoid overflow when computing h
// i ≤ h < j
if !f(h) {
i = h + 1 // preserves f(i-1) == false
} else {
j = h // preserves f(j) == true
}
}
// i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i.
return i
}

75
bootstrap/fs.go Normal file
View file

@ -0,0 +1,75 @@
package bootstrap
import (
"archive/zip"
"crypto/sha256"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/pkg/errors"
"io"
"io/fs"
"sort"
"strings"
)
func NewFS(zipContent string) fs.FS {
zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent)))
if err != nil {
util.Log().Panic("Static resource is not a valid zip file: %s", err)
}
var files []file
err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return errors.Errorf("无法获取[%s]的信息, %s, 跳过...", path, err)
}
if path == "." {
return nil
}
var f file
if d.IsDir() {
f.name = path + "/"
} else {
f.name = path
rc, err := zipReader.Open(path)
if err != nil {
return errors.Errorf("无法打开文件[%s], %s, 跳过...", path, err)
}
defer rc.Close()
data, err := io.ReadAll(rc)
if err != nil {
return errors.Errorf("无法读取文件[%s], %s, 跳过...", path, err)
}
f.data = string(data)
hash := sha256.Sum256(data)
for i := range f.hash {
f.hash[i] = ^hash[i]
}
}
files = append(files, f)
return nil
})
if err != nil {
util.Log().Panic("初始化静态资源失败: %s", err)
}
sort.Slice(files, func(i, j int) bool {
fi, fj := files[i], files[j]
di, ei, _ := split(fi.name)
dj, ej, _ := split(fj.name)
if di != dj {
return di < dj
}
return ei < ej
})
var embedFS FS
embedFS.files = &files
return embedFS
}

View file

@ -32,11 +32,15 @@ buildAssets() {
yarn run build
cd build
cd $REPO
# please keep in mind that if this final output binary `assets.zip` name changed, please go and update the `Dockerfile` as well
zip -r - assets/build >assets.zip
}
buildBinary() {
cd $REPO
# same as assets, if this final output binary `cloudreve` name changed, please go and update the `Dockerfile`
go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'"
}

3
go.mod
View file

@ -1,6 +1,6 @@
module github.com/cloudreve/Cloudreve/v3
go 1.17
go 1.18
require (
github.com/DATA-DOG/go-sqlmock v1.3.3
@ -100,6 +100,7 @@ require (
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/mattn/go-sqlite3 v1.14.7 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect

46
main.go
View file

@ -4,13 +4,11 @@ import (
"context"
_ "embed"
"flag"
"io"
"io/fs"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
@ -19,8 +17,6 @@ import (
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/cloudreve/Cloudreve/v3/routers"
"github.com/mholt/archiver/v4"
)
var (
@ -35,15 +31,12 @@ var staticZip string
var staticFS fs.FS
func init() {
flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "配置文件路径")
flag.BoolVar(&isEject, "eject", false, "导出内置静态资源")
flag.StringVar(&scriptName, "database-script", "", "运行内置数据库助手脚本")
flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "Path to the config file.")
flag.BoolVar(&isEject, "eject", false, "Eject all embedded static files.")
flag.StringVar(&scriptName, "database-script", "", "Name of database util script.")
flag.Parse()
staticFS = archiver.ArchiveFS{
Stream: io.NewSectionReader(strings.NewReader(staticZip), 0, int64(len(staticZip))),
Format: archiver.Zip{},
}
staticFS = bootstrap.NewFS(staticZip)
bootstrap.Init(confPath, staticFS)
}
@ -71,7 +64,7 @@ func main() {
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
go func() {
sig := <-sigChan
util.Log().Info("收到信号 %s开始关闭 server", sig)
util.Log().Info("Signal %s received, shutting down server...", sig)
ctx := context.Background()
if conf.SystemConfig.GracePeriod != 0 {
var cancel context.CancelFunc
@ -81,16 +74,16 @@ func main() {
err := server.Shutdown(ctx)
if err != nil {
util.Log().Error("关闭 server 错误, %s", err)
util.Log().Error("Failed to shutdown server: %s", err)
}
}()
// 如果启用了SSL
if conf.SSLConfig.CertPath != "" {
util.Log().Info("开始监听 %s", conf.SSLConfig.Listen)
util.Log().Info("Listening to %q", conf.SSLConfig.Listen)
server.Addr = conf.SSLConfig.Listen
if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SSLConfig.Listen, err)
util.Log().Error("Failed to listen to %q: %s", conf.SSLConfig.Listen, err)
return
}
}
@ -100,23 +93,23 @@ func main() {
// delete socket file before listening
if _, err := os.Stat(conf.UnixConfig.Listen); err == nil {
if err = os.Remove(conf.UnixConfig.Listen); err != nil {
util.Log().Error("删除 socket 文件错误, %s", err)
util.Log().Error("Failed to delete socket file: %s", err)
return
}
}
api.TrustedPlatform = conf.UnixConfig.ProxyHeader
util.Log().Info("开始监听 %s", conf.UnixConfig.Listen)
util.Log().Info("Listening to %q", conf.UnixConfig.Listen)
if err := RunUnix(server); err != nil {
util.Log().Error("无法监听[%s]%s", conf.UnixConfig.Listen, err)
util.Log().Error("Failed to listen to %q: %s", conf.UnixConfig.Listen, err)
}
return
}
util.Log().Info("开始监听 %s", conf.SystemConfig.Listen)
util.Log().Info("Listening to %q", conf.SystemConfig.Listen)
server.Addr = conf.SystemConfig.Listen
if err := server.ListenAndServe(); err != nil {
util.Log().Error("无法监听[%s]%s", conf.SystemConfig.Listen, err)
util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err)
}
}
@ -125,8 +118,21 @@ func RunUnix(server *http.Server) error {
if err != nil {
return err
}
defer listener.Close()
defer os.Remove(conf.UnixConfig.Listen)
if conf.UnixConfig.Perm > 0 {
err = os.Chmod(conf.UnixConfig.Listen, os.FileMode(conf.UnixConfig.Perm))
if err != nil {
util.Log().Warning(
"Failed to set permission to %q for socket file %q: %s",
conf.UnixConfig.Perm,
conf.UnixConfig.Listen,
err,
)
}
}
return server.Serve(listener)
}

View file

@ -1,6 +1,7 @@
package middleware
import (
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
@ -45,3 +46,17 @@ func CacheControl() gin.HandlerFunc {
c.Header("Cache-Control", "private, no-cache")
}
}
func Sandbox() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Content-Security-Policy", "sandbox")
}
}
// StaticResourceCache 使用静态资源缓存策略
func StaticResourceCache() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400)))
}
}

View file

@ -85,3 +85,21 @@ func TestCacheControl(t *testing.T) {
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache")
}
func TestSandbox(t *testing.T) {
a := assert.New(t)
TestFunc := Sandbox()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox")
}
func TestStaticResourceCache(t *testing.T) {
a := assert.New(t)
TestFunc := StaticResourceCache()
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
TestFunc(c)
a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age")
}

30
middleware/file.go Normal file
View file

@ -0,0 +1,30 @@
package middleware
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
)
// ValidateSourceLink validates if the perm source link is a valid redirect link
func ValidateSourceLink() gin.HandlerFunc {
return func(c *gin.Context) {
linkID, ok := c.Get("object_id")
if !ok {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
c.Abort()
return
}
sourceLink, err := model.GetSourceLinkByID(linkID)
if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
c.Abort()
return
}
sourceLink.Downloaded()
c.Set("source_link", sourceLink)
c.Next()
}
}

57
middleware/file_test.go Normal file
View file

@ -0,0 +1,57 @@
package middleware
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"testing"
)
func TestValidateSourceLink(t *testing.T) {
a := assert.New(t)
rec := httptest.NewRecorder()
testFunc := ValidateSourceLink()
// ID 不存在
{
c, _ := gin.CreateTestContext(rec)
testFunc(c)
a.True(c.IsAborted())
}
// SourceLink 不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 原文件不存在
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"}))
testFunc(c)
a.True(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
c, _ := gin.CreateTestContext(rec)
c.Set("object_id", 1)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1))
testFunc(c)
a.False(c.IsAborted())
a.NoError(mock.ExpectationsWereMet())
}
}

View file

@ -39,7 +39,11 @@ func FrontendFileHandler() gin.HandlerFunc {
path := c.Request.URL.Path
// API 跳过
if strings.HasPrefix(path, "/api") || strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || path == "/manifest.json" {
if strings.HasPrefix(path, "/api") ||
strings.HasPrefix(path, "/custom") ||
strings.HasPrefix(path, "/dav") ||
strings.HasPrefix(path, "/f") ||
path == "/manifest.json" {
c.Next()
return
}

View file

@ -46,7 +46,7 @@ func Session(secret string) gin.HandlerFunc {
// Also set Secure: true if using SSL, you should though
Store.Options(sessions.Options{
HttpOnly: true,
MaxAge: 7 * 86400,
MaxAge: 60 * 86400,
Path: "/",
SameSite: sameSiteMode,
Secure: conf.CORSConfig.Secure,

View file

@ -113,4 +113,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
{Name: "pwa_theme_color", Value: "#000000", Type: "pwa"},
{Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"},
{Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"},
{Name: "show_app_promotion", Value: "1", Type: "mobile"},
{Name: "public_resource_maxage", Value: "86400", Type: "timeout"},
}

View file

@ -32,6 +32,7 @@ type Download struct {
// 数据库忽略字段
StatusInfo rpc.StatusInfo `gorm:"-"`
Task *Task `gorm:"-"`
NodeName string `gorm:"-"`
}
// AfterFind 找到下载任务后的钩子处理Status结构

View file

@ -4,6 +4,7 @@ import (
"encoding/gob"
"encoding/json"
"errors"
"fmt"
"path"
"time"
@ -191,14 +192,15 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) {
}
// 查询软链接的文件
var filesWithSoftLinks []File
tx := DB
for _, value := range files {
tx = tx.Or("source_name = ? and policy_id = ? and id != ?", value.SourceName, value.PolicyID, value.ID)
}
result := tx.Find(&filesWithSoftLinks)
if result.Error != nil {
return nil, result.Error
filesWithSoftLinks := make([]File, 0)
for _, file := range files {
var softLinkFile File
res := DB.
Where("source_name = ? and policy_id = ? and id != ?", file.SourceName, file.PolicyID, file.ID).
First(&softLinkFile)
if res.Error == nil {
filesWithSoftLinks = append(filesWithSoftLinks, softLinkFile)
}
}
// 过滤具有软连接的文件
@ -338,6 +340,25 @@ func (file *File) CanCopy() bool {
return file.UploadSessionID == nil
}
// CreateOrGetSourceLink creates a SourceLink model. If the given model exists, the existing
// model will be returned.
func (file *File) CreateOrGetSourceLink() (*SourceLink, error) {
res := &SourceLink{}
err := DB.Set("gorm:auto_preload", true).Where("file_id = ?", file.ID).Find(&res).Error
if err == nil && res.ID > 0 {
return res, nil
}
res.FileID = file.ID
res.Name = file.Name
if err := DB.Save(res).Error; err != nil {
return nil, fmt.Errorf("failed to insert SourceLink: %w", err)
}
res.File = *file
return res, nil
}
/*
实现 webdav.FileInfo 接口
*/

View file

@ -285,30 +285,34 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
},
}
// 传入空文件列表
{
file, err := RemoveFilesWithSoftLinks([]File{})
asserts.NoError(err)
asserts.Empty(file)
}
// 全都没有
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
asserts.Equal(files, file)
}
// 查询出错
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WillReturnError(errors.New("error"))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.Error(err)
asserts.Nil(file)
}
// 第二个是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WithArgs("1.txt", 23, 1).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"),
@ -318,14 +322,18 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
asserts.NoError(err)
asserts.Equal(files[:1], file)
}
// 第一个是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WithArgs("1.txt", 23, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
asserts.NoError(err)
@ -334,11 +342,16 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) {
// 全部是软链
{
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("1.txt", 23, 1, "2.txt", 24, 2).
WithArgs("1.txt", 23, 1).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt").
AddRow(4, 23, "1.txt"),
AddRow(3, 23, "1.txt"),
)
mock.ExpectQuery("SELECT(.+)files(.+)").
WithArgs("2.txt", 24, 2).
WillReturnRows(
sqlmock.NewRows([]string{"id", "policy_id", "source_name"}).
AddRow(3, 24, "2.txt"),
)
file, err := RemoveFilesWithSoftLinks(files)
asserts.NoError(mock.ExpectationsWereMet())
@ -598,3 +611,44 @@ func TestGetFilesByKeywords(t *testing.T) {
asserts.Len(res, 1)
}
}
func TestFile_CreateOrGetSourceLink(t *testing.T) {
a := assert.New(t)
file := &File{}
file.ID = 1
// 已存在,返回老的 SourceLink
{
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
res, err := file.CreateOrGetSourceLink()
a.NoError(err)
a.EqualValues(2, res.ID)
a.NoError(mock.ExpectationsWereMet())
}
// 不存在,插入失败
{
expectedErr := errors.New("error")
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr)
mock.ExpectRollback()
res, err := file.CreateOrGetSourceLink()
a.Nil(res)
a.ErrorIs(err, expectedErr)
a.NoError(mock.ExpectationsWereMet())
}
// 成功
{
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}))
mock.ExpectBegin()
mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1))
mock.ExpectCommit()
res, err := file.CreateOrGetSourceLink()
a.NoError(err)
a.EqualValues(2, res.ID)
a.EqualValues(file.ID, res.File.ID)
a.NoError(mock.ExpectationsWereMet())
}
}

View file

@ -224,7 +224,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6
} else if IDCache, ok := newIDCache[*folder.ParentID]; ok {
newID = IDCache
} else {
util.Log().Warning("Failed to get parent folder %q", folder.ParentID)
util.Log().Warning("Failed to get parent folder %q", *folder.ParentID)
return size, errors.New("Failed to get parent folder")
}

View file

@ -23,16 +23,17 @@ type Group struct {
// GroupOption 用户组其他配置
type GroupOption struct {
ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载
ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩
CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小
DecompressSize uint64 `json:"decompress_size,omitempty"`
OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownload bool `json:"share_download,omitempty"`
Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
SourceBatchSize int `json:"source_batch,omitempty"`
Aria2BatchSize int `json:"aria2_batch,omitempty"`
ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载
ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩
CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小
DecompressSize uint64 `json:"decompress_size,omitempty"`
OneTimeDownload bool `json:"one_time_download,omitempty"`
ShareDownload bool `json:"share_download,omitempty"`
Aria2 bool `json:"aria2,omitempty"` // 离线下载
Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置
SourceBatchSize int `json:"source_batch,omitempty"`
RedirectedSource bool `json:"redirected_source,omitempty"`
Aria2BatchSize int `json:"aria2_batch,omitempty"`
}
// GetGroupByID 用ID获取用户组
@ -66,7 +67,7 @@ func (group *Group) BeforeSave() (err error) {
return err
}
//SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段
// SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段
// TODO 完善测试
func (group *Group) SerializePolicyList() (err error) {
policies, err := json.Marshal(&group.PolicyList)

View file

@ -19,7 +19,7 @@ func needMigration() bool {
return DB.Where("name = ?", "db_version_"+conf.RequiredDBVersion).First(&setting).Error != nil
}
//执行数据迁移
// 执行数据迁移
func migration() {
// 确认是否需要执行迁移
if !needMigration() {
@ -41,7 +41,7 @@ func migration() {
}
DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{},
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{})
&Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{})
// 创建初始存储策略
addDefaultPolicy()
@ -104,12 +104,13 @@ func addDefaultGroups() {
ShareEnabled: true,
WebDAVEnabled: true,
OptionsSerialized: GroupOption{
ArchiveDownload: true,
ArchiveTask: true,
ShareDownload: true,
Aria2: true,
SourceBatchSize: 1000,
Aria2BatchSize: 50,
ArchiveDownload: true,
ArchiveTask: true,
ShareDownload: true,
Aria2: true,
SourceBatchSize: 1000,
Aria2BatchSize: 50,
RedirectedSource: true,
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {
@ -128,9 +129,10 @@ func addDefaultGroups() {
ShareEnabled: true,
WebDAVEnabled: true,
OptionsSerialized: GroupOption{
ShareDownload: true,
SourceBatchSize: 10,
Aria2BatchSize: 1,
ShareDownload: true,
SourceBatchSize: 10,
Aria2BatchSize: 1,
RedirectedSource: true,
},
}
if err := DB.Create(&defaultAdminGroup).Error; err != nil {

View file

@ -15,12 +15,12 @@ var availableScripts = make(map[string]DBScript)
func RunDBScript(name string, ctx context.Context) error {
if script, ok := availableScripts[name]; ok {
util.Log().Info("开始执行数据库脚本 [%s]", name)
util.Log().Info("Start executing database script %q.", name)
script.Run(ctx)
return nil
}
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
return fmt.Errorf("Database script %q not exist.", name)
}
func Register(name string, script DBScript) {

View file

@ -14,7 +14,7 @@ func (script ResetAdminPassword) Run(ctx context.Context) {
// 查找用户
user, err := model.GetUserByID(1)
if err != nil {
util.Log().Panic("初始管理员用户不存在, %s", err)
util.Log().Panic("Initial admin user not exist: %s", err)
}
// 生成密码
@ -23,9 +23,9 @@ func (script ResetAdminPassword) Run(ctx context.Context) {
// 更改为新密码
user.SetPassword(password)
if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil {
util.Log().Panic("密码更改失败, %s", err)
util.Log().Panic("Failed to update password: %s", err)
}
c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold)
util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password))
util.Log().Info("Initial admin user password changed to:" + c.Sprint(password))
}

View file

@ -25,7 +25,7 @@ func (script UserStorageCalibration) Run(ctx context.Context) {
model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total)
// 更新用户的容量
if user.Storage != total.Total {
util.Log().Info("将用户 [%s] 的容量由 %d 校准为 %d", user.Email,
util.Log().Info("Calibrate used storage for user %q, from %d to %d.", user.Email,
user.Storage, total.Total)
}
model.DB.Model(&user).Update("storage", total.Total)

47
models/source_link.go Normal file
View file

@ -0,0 +1,47 @@
package model
import (
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/jinzhu/gorm"
"net/url"
)
// SourceLink represent a shared file source link
type SourceLink struct {
gorm.Model
FileID uint // corresponding file ID
Name string // name of the file while creating the source link, for annotation
Downloads int // 下载数
// 关联模型
File File `gorm:"save_associations:false:false"`
}
// Link gets the URL of a SourceLink
func (s *SourceLink) Link() (string, error) {
baseURL := GetSiteURL()
linkPath, err := url.Parse(fmt.Sprintf("/f/%s/%s", hashid.HashID(s.ID, hashid.SourceLinkID), s.File.Name))
if err != nil {
return "", err
}
return baseURL.ResolveReference(linkPath).String(), nil
}
// GetTasksByID queries source link based on ID
func GetSourceLinkByID(id interface{}) (*SourceLink, error) {
link := &SourceLink{}
result := DB.Where("id = ?", id).First(link)
files, _ := GetFilesByIDs([]uint{link.FileID}, 0)
if len(files) > 0 {
link.File = files[0]
}
return link, result.Error
}
// Viewed 增加访问次数
func (s *SourceLink) Downloaded() {
s.Downloads++
DB.Model(s).UpdateColumn("downloads", gorm.Expr("downloads + ?", 1))
}

View file

@ -0,0 +1,52 @@
package model
import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestSourceLink_Link(t *testing.T) {
a := assert.New(t)
s := &SourceLink{}
s.ID = 1
// 失败
{
s.File.Name = string([]byte{0x7f})
res, err := s.Link()
a.Error(err)
a.Empty(res)
}
// 成功
{
s.File.Name = "filename"
res, err := s.Link()
a.NoError(err)
a.Contains(res, s.Name)
}
}
func TestGetSourceLinkByID(t *testing.T) {
a := assert.New(t)
mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2))
mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
res, err := GetSourceLinkByID(1)
a.NoError(err)
a.NotNil(res)
a.EqualValues(2, res.File.ID)
a.NoError(mock.ExpectationsWereMet())
}
func TestSourceLink_Downloaded(t *testing.T) {
a := assert.New(t)
s := &SourceLink{}
s.ID = 1
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
s.Downloaded()
a.NoError(mock.ExpectationsWereMet())
}

View file

@ -52,7 +52,7 @@ const (
var (
// ErrNotEnabled 功能未开启错误
ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "", nil)
ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil)
// ErrUserNotFound 未找到下载任务创建者
ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil)
)

View file

@ -38,6 +38,7 @@ type ssl struct {
type unix struct {
Listen string
ProxyHeader string `validate:"required_with=Listen"`
Perm uint32
}
// slave 作为slave存储端配置

View file

@ -1,13 +1,13 @@
package conf
// BackendVersion 当前后端版本号
var BackendVersion = "3.5.3"
var BackendVersion = "3.6.0"
// RequiredDBVersion 与当前版本匹配的数据库版本
var RequiredDBVersion = "3.5.2"
var RequiredDBVersion = "3.6.0"
// RequiredStaticVersion 与当前版本匹配的静态资源版本
var RequiredStaticVersion = "3.5.3"
var RequiredStaticVersion = "3.6.0"
// IsPro 是否为Pro版本
var IsPro = "false"

View file

@ -1,14 +1,22 @@
package backoff
import "time"
import (
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/http"
"strconv"
"time"
)
// Backoff used for retry sleep backoff
type Backoff interface {
Next() bool
Next(err error) bool
Reset()
}
// ConstantBackoff implements Backoff interface with constant sleep time
// ConstantBackoff implements Backoff interface with constant sleep time. If the error
// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration.
type ConstantBackoff struct {
Sleep time.Duration
Max int
@ -16,16 +24,51 @@ type ConstantBackoff struct {
tried int
}
func (c *ConstantBackoff) Next() bool {
func (c *ConstantBackoff) Next(err error) bool {
c.tried++
if c.tried > c.Max {
return false
}
time.Sleep(c.Sleep)
var e *RetryableError
if errors.As(err, &e) && e.RetryAfter > 0 {
util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter)
time.Sleep(e.RetryAfter)
} else {
time.Sleep(c.Sleep)
}
return true
}
func (c *ConstantBackoff) Reset() {
c.tried = 0
}
type RetryableError struct {
Err error
RetryAfter time.Duration
}
// NewRetryableErrorFromHeader constructs a new RetryableError from http response header
// and existing error.
func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError {
retryAfter := header.Get("retry-after")
if retryAfter == "" {
retryAfter = "0"
}
res := &RetryableError{
Err: err,
}
if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
res.RetryAfter = time.Duration(retryAfterSecond) * time.Second
}
return res
}
func (e *RetryableError) Error() string {
return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err)
}

View file

@ -1,7 +1,9 @@
package backoff
import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"time"
)
@ -9,14 +11,51 @@ import (
func TestConstantBackoff_Next(t *testing.T) {
a := assert.New(t)
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next())
a.True(b.Next())
a.True(b.Next())
a.False(b.Next())
b.Reset()
a.True(b.Next())
a.True(b.Next())
a.True(b.Next())
a.False(b.Next())
// General error
{
err := errors.New("error")
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
// Retryable error
{
err := &RetryableError{RetryAfter: time.Duration(1)}
b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3}
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
b.Reset()
a.True(b.Next(err))
a.True(b.Next(err))
a.True(b.Next(err))
a.False(b.Next(err))
}
}
func TestNewRetryableErrorFromHeader(t *testing.T) {
a := assert.New(t)
// no retry-after header
{
err := NewRetryableErrorFromHeader(nil, http.Header{})
a.Empty(err.RetryAfter)
}
// with retry-after header
{
header := http.Header{}
header.Add("retry-after", "120")
err := NewRetryableErrorFromHeader(nil, header)
a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter)
}
}

View file

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"os"
@ -66,7 +67,7 @@ func (c *ChunkGroup) TempAvailable() bool {
// Process a chunk with retry logic
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
reader := io.LimitReader(c.file, int64(c.chunkSize))
reader := io.LimitReader(c.file, c.Length())
// If useBuffer is enabled, tee the reader to a temp file
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
@ -90,13 +91,17 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
}
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
reader = c.bufferTemp
reader = io.NopCloser(c.bufferTemp)
}
}
err := processor(c, reader)
if err != nil {
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next() {
if c.enableRetryBuffer {
request.BlackHole(reader)
}
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) {
if c.file.Seekable() {
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err)

View file

@ -36,7 +36,7 @@ func TestHandler_Put(t *testing.T) {
{&fsctx.FileStream{
SavePath: "TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")),
}, "物理同名文件已存在或不可用"},
}, "file with the same name existed or unavailable"},
{&fsctx.FileStream{
SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("")),
@ -51,7 +51,7 @@ func TestHandler_Put(t *testing.T) {
Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt",
File: io.NopCloser(strings.NewReader("123")),
}, "未上传完成的文件分片与预期大小不一致"},
}, "size of unfinished uploaded chunks is not as expected"},
{&fsctx.FileStream{
Mode: fsctx.Append | fsctx.Overwrite,
SavePath: "inner/TestHandler_Put.txt",

View file

@ -7,7 +7,6 @@ import (
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
@ -37,24 +36,18 @@ const (
// GetSourcePath 获取文件的绝对路径
func (info *FileInfo) GetSourcePath() string {
res, err := url.PathUnescape(
strings.TrimPrefix(
path.Join(
strings.TrimPrefix(info.ParentReference.Path, "/drive/root:"),
info.Name,
),
"/",
),
)
res, err := url.PathUnescape(info.ParentReference.Path)
if err != nil {
return ""
}
return res
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
return strings.TrimPrefix(
path.Join(
strings.TrimPrefix(res, "/drive/root:"),
info.Name,
),
"/",
)
}
func (client *Client) getRequestURL(api string, opts ...Option) string {
@ -531,7 +524,7 @@ func sysError(err error) *RespError {
}}
}
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, *RespError) {
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
// 获取凭证
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
if err != nil {
@ -580,15 +573,21 @@ func (client *Client) request(ctx context.Context, method string, url string, bo
util.Log().Debug("Onedrive returns unknown response: %s", respBody)
return "", sysError(decodeErr)
}
if res.Response.StatusCode == 429 {
util.Log().Warning("OneDrive request is throttled.")
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
}
return "", &errResp
}
return respBody, nil
}
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, *RespError) {
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
// 发送请求
bodyReader := ioutil.NopCloser(strings.NewReader(body))
bodyReader := io.NopCloser(strings.NewReader(body))
return client.request(ctx, method, url, bodyReader,
request.WithContentLength(int64(len(body))),
)

View file

@ -112,6 +112,35 @@ func TestRequest(t *testing.T) {
asserts.Equal("error msg", err.Error())
}
// OneDrive返回429错误
{
header := http.Header{}
header.Add("retry-after", "120")
clientMock := ClientMock{}
clientMock.On(
"Request",
"POST",
"http://dev.com",
testMock.Anything,
testMock.Anything,
).Return(&request.Response{
Err: nil,
Response: &http.Response{
StatusCode: 429,
Header: header,
Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)),
},
})
client.Request = clientMock
res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader(""))
clientMock.AssertExpectations(t)
asserts.Error(err)
asserts.Empty(res)
var retryErr *backoff.RetryableError
asserts.ErrorAs(err, &retryErr)
asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter)
}
// OneDrive返回未知响应
{
clientMock := ClientMock{}
@ -144,18 +173,18 @@ func TestFileInfo_GetSourcePath(t *testing.T) {
fileInfo := FileInfo{
Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg",
ParentReference: parentReference{
Path: "/drive/root:/123/321",
Path: "/drive/root:/123/32%201",
},
}
asserts.Equal("123/321/文件名.jpg", fileInfo.GetSourcePath())
asserts.Equal("123/32 1/%e6%96%87%e4%bb%b6%e5%90%8d.jpg", fileInfo.GetSourcePath())
}
// 失败
{
fileInfo := FileInfo{
Name: "%e6%96%87%e4%bb%b6%e5%90%8g.jpg",
Name: "123.jpg",
ParentReference: parentReference{
Path: "/drive/root:/123/321",
Path: "/drive/root:/123/%e6%96%87%e4%bb%b6%e5%90%8g",
},
}
asserts.Equal("", fileInfo.GetSourcePath())

View file

@ -11,7 +11,6 @@ import (
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
@ -171,19 +170,6 @@ func (handler Driver) Source(
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
// 如果是永久链接,则返回签名后的中转外链
if ttl == 0 {
signedURI, err := auth.SignURI(
auth.General,
fmt.Sprintf("/api/v3/file/source/%d/%s", file.ID, file.Name),
ttl,
)
if err != nil {
return "", err
}
return baseURL.ResolveReference(signedURI).String(), nil
}
}
// 尝试从缓存中查找

View file

@ -3,7 +3,6 @@ package onedrive
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
@ -161,21 +160,6 @@ func TestDriver_Source(t *testing.T) {
asserts.NoError(err)
asserts.Equal("123321", res)
}
// 成功 永久直链
{
file := model.File{}
file.ID = 1
file.Name = "123.jpg"
file.UpdatedAt = time.Now()
ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
auth.General = auth.HMACAuth{}
handler.Client.Credential.AccessToken = "1"
res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0)
asserts.NoError(err)
asserts.Contains(res, "/api/v3/file/source/1/123.jpg?sign")
}
}
func TestDriver_List(t *testing.T) {

View file

@ -133,3 +133,8 @@ type Site struct {
func init() {
gob.Register(Credential{})
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
}

View file

@ -398,8 +398,8 @@ func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *seri
// Meta 获取文件信息
func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.svc.GetObject(
&s3.GetObjectInput{
res, err := handler.svc.HeadObject(
&s3.HeadObjectInput{
Bucket: &handler.Policy.BucketName,
Key: &path,
})

View file

@ -8,17 +8,17 @@ import (
var (
ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil)
ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "", nil)
ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "", nil)
ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "", nil)
ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "", nil)
ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil)
ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil)
ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil)
ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil)
ErrClientCanceled = errors.New("Client canceled operation")
ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "", nil)
ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil)
ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil)
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "", nil)
ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "", nil)
ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "", nil)
ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "", nil)
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil)
ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil)
ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil)
ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil)
ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil)
ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil)
ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil)

View file

@ -472,6 +472,9 @@ func TestFileSystem_Delete(t *testing.T) {
AddRow(4, "1.txt", "1.txt", 365, 1),
)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2))
// 两次查询软连接
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
// 查询上传策略
@ -527,6 +530,9 @@ func TestFileSystem_Delete(t *testing.T) {
AddRow(4, "1.txt", "1.txt", 602, 1),
)
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2))
// 两次查询软连接
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
mock.ExpectQuery("SELECT(.+)files(.+)").
WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"}))
// 查询上传策略

View file

@ -15,11 +15,12 @@ const (
FolderID // 目录ID
TagID // 标签ID
PolicyID // 存储策略ID
SourceLinkID
)
var (
// ErrTypeNotMatch ID类型不匹配
ErrTypeNotMatch = errors.New("ID类型不匹配")
ErrTypeNotMatch = errors.New("mismatched ID type.")
)
// HashEncode 对给定数据计算HashID

View file

@ -44,6 +44,12 @@ func newDefaultOption() *options {
}
}
func (o *options) clone() options {
newOptions := *o
newOptions.header = o.header.Clone()
return newOptions
}
// WithTimeout 设置请求超时
func WithTimeout(t time.Duration) Option {
return optionFunc(func(o *options) {

View file

@ -56,7 +56,7 @@ func NewClient(opts ...Option) Client {
func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response {
// 应用额外设置
c.mu.Lock()
options := *c.options
options := c.options.clone()
c.mu.Unlock()
for _, o := range opts {
o.apply(&options)
@ -179,7 +179,7 @@ func (resp *Response) DecodeResponse() (*serializer.Response, error) {
var res serializer.Response
err = json.Unmarshal([]byte(respString), &res)
if err != nil {
util.Log().Debug("无法解析回调服务端响应:%s", string(respString))
util.Log().Debug("Failed to parse response: %s", string(respString))
return nil, err
}
return &res, nil
@ -251,7 +251,7 @@ func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) {
return instance.status.Size, nil
}
}
return 0, errors.New("未实现")
return 0, errors.New("not implemented")
}

View file

@ -19,6 +19,7 @@ type DownloadListResponse struct {
Downloaded uint64 `json:"downloaded"`
Speed int `json:"speed"`
Info rpc.StatusInfo `json:"info"`
NodeName string `json:"node"`
}
// FinishedListResponse 已完成任务条目
@ -34,6 +35,7 @@ type FinishedListResponse struct {
TaskError string `json:"task_error"`
CreateTime time.Time `json:"create"`
UpdateTime time.Time `json:"update"`
NodeName string `json:"node"`
}
// BuildFinishedListResponse 构建已完成任务条目
@ -62,6 +64,7 @@ func BuildFinishedListResponse(tasks []model.Download) Response {
TaskStatus: -1,
UpdateTime: tasks[i].UpdatedAt,
CreateTime: tasks[i].CreatedAt,
NodeName: tasks[i].NodeName,
}
if tasks[i].Task != nil {
@ -106,6 +109,7 @@ func BuildDownloadingResponse(tasks []model.Download, intervals map[uint]int) Re
Downloaded: tasks[i].DownloadedSize,
Speed: tasks[i].Speed,
Info: tasks[i].StatusInfo,
NodeName: tasks[i].NodeName,
})
}

View file

@ -221,7 +221,7 @@ const (
// DBErr 数据库操作失败
func DBErr(msg string, err error) Response {
if msg == "" {
msg = "数据库操作失败"
msg = "Database operation failed."
}
return Err(CodeDBError, msg, err)
}
@ -229,7 +229,7 @@ func DBErr(msg string, err error) Response {
// ParamErr 各种参数错误
func ParamErr(msg string, err error) Response {
if msg == "" {
msg = "参数错误"
msg = "Invalid parameters."
}
return Err(CodeParamErr, msg, err)
}

View file

@ -19,7 +19,7 @@ func NewResponseWithGobData(data interface{}) Response {
var w bytes.Buffer
encoder := gob.NewEncoder(&w)
if err := encoder.Encode(data); err != nil {
return Err(CodeInternalSetting, "无法编码返回结果", err)
return Err(CodeInternalSetting, "Failed to encode response content", err)
}
return Response{Data: w.Bytes()}

View file

@ -22,6 +22,7 @@ type SiteConfig struct {
CaptchaType string `json:"captcha_type"`
TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"`
RegisterEnabled bool `json:"registerEnabled"`
AppPromotion bool `json:"app_promotion"`
}
type task struct {
@ -83,6 +84,7 @@ func BuildSiteConfig(settings map[string]string, user *model.User) Response {
CaptchaType: checkSettingValue(settings, "captcha_type"),
TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"),
RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")),
AppPromotion: model.IsTrueVal(checkSettingValue(settings, "show_app_promotion")),
}}
return res
}

View file

@ -19,14 +19,3 @@ func TestSlaveTransferReq_Hash(t *testing.T) {
}
a.NotEqual(s1.Hash("1"), s2.Hash("1"))
}
func TestSlaveRecycleReq_Hash(t *testing.T) {
a := assert.New(t)
s1 := &SlaveRecycleReq{
Path: "1",
}
s2 := &SlaveRecycleReq{
Path: "2",
}
a.NotEqual(s1.Hash("1"), s2.Hash("1"))
}

View file

@ -13,7 +13,7 @@ import (
func CheckLogin() Response {
return Response{
Code: CodeCheckLogin,
Msg: "未登录",
Msg: "Login required",
}
}

View file

@ -69,7 +69,7 @@ func (job *CompressTask) SetError(err *JobError) {
func (job *CompressTask) removeZipFile() {
if job.zipPath != "" {
if err := os.Remove(job.zipPath); err != nil {
util.Log().Warning("无法删除临时压缩文件 %s , %s", job.zipPath, err)
util.Log().Warning("Failed to delete temp zip file %q: %s", job.zipPath, err)
}
}
}
@ -93,7 +93,7 @@ func (job *CompressTask) Do() {
return
}
util.Log().Debug("开始压缩文件")
util.Log().Debug("Starting compress file...")
job.TaskModel.SetProgress(CompressingProgress)
// 创建临时压缩文件
@ -122,7 +122,7 @@ func (job *CompressTask) Do() {
job.zipPath = zipFilePath
zipFile.Close()
util.Log().Debug("压缩文件存放至%s开始上传", zipFilePath)
util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFilePath)
job.TaskModel.SetProgress(TransferringProgress)
// 上传文件

View file

@ -77,7 +77,7 @@ func (job *DecompressTask) Do() {
// 创建文件系统
fs, err := filesystem.NewFileSystem(job.User)
if err != nil {
job.SetErrorMsg("无法创建文件系统", err)
job.SetErrorMsg("Failed to create filesystem.", err)
return
}
@ -85,7 +85,7 @@ func (job *DecompressTask) Do() {
err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst, job.TaskProps.Encoding)
if err != nil {
job.SetErrorMsg("解压缩失败", err)
job.SetErrorMsg("Failed to decompress file.", err)
return
}

View file

@ -4,5 +4,5 @@ import "errors"
var (
// ErrUnknownTaskType 未知任务类型
ErrUnknownTaskType = errors.New("未知任务类型")
ErrUnknownTaskType = errors.New("unknown task type")
)

View file

@ -81,7 +81,7 @@ func (job *ImportTask) Do() {
// 查找存储策略
policy, err := model.GetPolicyByID(job.TaskProps.PolicyID)
if err != nil {
job.SetErrorMsg("找不到存储策略", err)
job.SetErrorMsg("Policy not exist.", err)
return
}
@ -96,7 +96,7 @@ func (job *ImportTask) Do() {
fs.Policy = &policy
if err := fs.DispatchHandler(); err != nil {
job.SetErrorMsg("无法分发存储策略", err)
job.SetErrorMsg("Failed to dispatch policy.", err)
return
}
@ -110,7 +110,7 @@ func (job *ImportTask) Do() {
true)
objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive)
if err != nil {
job.SetErrorMsg("无法列取文件", err)
job.SetErrorMsg("Failed to list files.", err)
return
}
@ -126,7 +126,7 @@ func (job *ImportTask) Do() {
virtualPath := path.Join(job.TaskProps.Dst, object.RelativePath)
folder, err := fs.CreateDirectory(coxIgnoreConflict, virtualPath)
if err != nil {
util.Log().Warning("导入任务无法创建用户目录[%s], %s", virtualPath, err)
util.Log().Warning("Importing task cannot create user directory %q: %s", virtualPath, err)
} else if folder.ID > 0 {
pathCache[virtualPath] = folder
}
@ -152,7 +152,7 @@ func (job *ImportTask) Do() {
} else {
folder, err := fs.CreateDirectory(context.Background(), virtualPath)
if err != nil {
util.Log().Warning("导入任务无法创建用户目录[%s], %s",
util.Log().Warning("Importing task cannot create user directory %q: %s",
virtualPath, err)
continue
}
@ -163,10 +163,10 @@ func (job *ImportTask) Do() {
// 插入文件记录
_, err := fs.AddFile(context.Background(), parentFolder, &fileHeader)
if err != nil {
util.Log().Warning("导入任务无法创插入文件[%s], %s",
util.Log().Warning("Importing task cannot insert user file %q: %s",
object.RelativePath, err)
if err == filesystem.ErrInsufficientCapacity {
job.SetErrorMsg("容量不足", err)
job.SetErrorMsg("Insufficient storage capacity.", err)
return
}
}

View file

@ -89,12 +89,12 @@ func Resume(p Pool) {
if len(tasks) == 0 {
return
}
util.Log().Info("从数据库中恢复 %d 个未完成任务", len(tasks))
util.Log().Info("Resume %d unfinished task(s) from database.", len(tasks))
for i := 0; i < len(tasks); i++ {
job, err := GetJobFromModel(&tasks[i])
if err != nil {
util.Log().Warning("无法恢复任务,%s", err)
util.Log().Warning("Failed to resume task: %s", err)
continue
}

View file

@ -44,11 +44,11 @@ func (pool *AsyncPool) freeWorker() {
// Submit 开始提交任务
func (pool *AsyncPool) Submit(job Job) {
go func() {
util.Log().Debug("等待获取Worker")
util.Log().Debug("Waiting for Worker.")
worker := pool.obtainWorker()
util.Log().Debug("获取到Worker")
util.Log().Debug("Worker obtained.")
worker.Do(job)
util.Log().Debug("释放Worker")
util.Log().Debug("Worker released.")
pool.freeWorker()
}()
}
@ -60,7 +60,7 @@ func Init() {
idleWorker: make(chan int, maxWorker),
}
TaskPoll.Add(maxWorker)
util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker)
util.Log().Info("Initialize task queue with WorkerNum = %d", maxWorker)
if conf.SystemConfig.Mode == "master" {
Resume(TaskPoll)

View file

@ -73,21 +73,21 @@ func (job *RecycleTask) GetError() *JobError {
func (job *RecycleTask) Do() {
download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID)
if err != nil {
util.Log().Warning("回收任务 %d 找不到下载记录", job.TaskModel.ID)
job.SetErrorMsg("无法找到下载任务", err)
util.Log().Warning("Recycle task %d cannot found download record.", job.TaskModel.ID)
job.SetErrorMsg("Cannot found download task.", err)
return
}
nodeID := download.GetNodeID()
node := cluster.Default.GetNodeByID(nodeID)
if node == nil {
util.Log().Warning("回收任务 %d 找不到节点", job.TaskModel.ID)
job.SetErrorMsg("从机节点不可用", nil)
util.Log().Warning("Recycle task %d cannot found node.", job.TaskModel.ID)
job.SetErrorMsg("Invalid slave node.", nil)
return
}
err = node.GetAria2Instance().DeleteTempFile(download)
if err != nil {
util.Log().Warning("无法删除中转临时目录[%s], %s", download.Parent, err)
job.SetErrorMsg("文件回收失败", err)
util.Log().Warning("Failed to delete transfer temp folder %q: %s", download.Parent, err)
job.SetErrorMsg("Failed to recycle files.", err)
return
}
}

View file

@ -69,7 +69,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) {
}
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil {
util.Log().Warning("无法发送转存失败通知到从机, %s", err)
util.Log().Warning("Failed to send transfer failure notification to master node: %s", err)
}
}
@ -82,26 +82,26 @@ func (job *TransferTask) GetError() *task.JobError {
func (job *TransferTask) Do() {
fs, err := filesystem.NewAnonymousFileSystem()
if err != nil {
job.SetErrorMsg("无法初始化匿名文件系统", err)
job.SetErrorMsg("Failed to initialize anonymous filesystem.", err)
return
}
fs.Policy = job.Req.Policy
if err := fs.DispatchHandler(); err != nil {
job.SetErrorMsg("无法分发存储策略", err)
job.SetErrorMsg("Failed to dispatch policy.", err)
return
}
master, err := cluster.DefaultController.GetMasterInfo(job.MasterID)
if err != nil {
job.SetErrorMsg("找不到主机节点", err)
job.SetErrorMsg("Cannot found master node ID.", err)
return
}
fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID)
file, err := os.Open(util.RelativePath(job.Req.Src))
if err != nil {
job.SetErrorMsg("无法读取源文件", err)
job.SetErrorMsg("Failed to read source file.", err)
return
}
@ -110,7 +110,7 @@ func (job *TransferTask) Do() {
// 获取源文件大小
fi, err := file.Stat()
if err != nil {
job.SetErrorMsg("无法获取源文件大小", err)
job.SetErrorMsg("Failed to get source file size.", err)
return
}
@ -122,7 +122,7 @@ func (job *TransferTask) Do() {
Size: uint64(size),
})
if err != nil {
job.SetErrorMsg("文件上传失败", err)
job.SetErrorMsg("Upload failed.", err)
return
}
@ -133,6 +133,6 @@ func (job *TransferTask) Do() {
}
if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil {
util.Log().Warning("无法发送转存成功通知到从机, %s", err)
util.Log().Warning("Failed to send transfer success notification to master node: %s", err)
}
}

View file

@ -3,6 +3,7 @@ package task
import (
"context"
"encoding/json"
"fmt"
"path"
"path/filepath"
"strings"
@ -94,6 +95,7 @@ func (job *TransferTask) Do() {
}
successCount := 0
errorList := make([]string, 0, len(job.TaskProps.Src))
for _, file := range job.TaskProps.Src {
dst := path.Join(job.TaskProps.Dst, filepath.Base(file))
if job.TaskProps.TrimPath {
@ -109,7 +111,7 @@ func (job *TransferTask) Do() {
// 获取从机节点
node := cluster.Default.GetNodeByID(job.TaskProps.NodeID)
if node == nil {
job.SetErrorMsg("从机节点不可用", nil)
job.SetErrorMsg("Invalid slave node.", nil)
}
// 切换为从机节点处理上传
@ -127,13 +129,17 @@ func (job *TransferTask) Do() {
}
if err != nil {
job.SetErrorMsg("文件转存失败", err)
errorList = append(errorList, err.Error())
} else {
successCount++
job.TaskModel.SetProgress(successCount)
}
}
if len(errorList) > 0 {
job.SetErrorMsg("Failed to transfer one or more file(s).", fmt.Errorf(strings.Join(errorList, "\n")))
}
}
// NewTransferTask 新建中转任务

View file

@ -16,14 +16,14 @@ type GeneralWorker struct {
// Do 执行任务
func (worker *GeneralWorker) Do(job Job) {
util.Log().Debug("开始执行任务")
util.Log().Debug("Start executing task.")
job.SetStatus(Processing)
defer func() {
// 致命错误捕获
if err := recover(); err != nil {
util.Log().Debug("任务执行出错,%s", err)
job.SetError(&JobError{Msg: "致命错误", Error: fmt.Sprintf("%s", err)})
util.Log().Debug("Failed to execute task: %s", err)
job.SetError(&JobError{Msg: "Fatal error.", Error: fmt.Sprintf("%s", err)})
job.SetStatus(Error)
}
}()
@ -33,12 +33,12 @@ func (worker *GeneralWorker) Do(job Job) {
// 任务执行失败
if err := job.GetError(); err != nil {
util.Log().Debug("任务执行出错")
util.Log().Debug("Failed to execute task.")
job.SetStatus(Error)
return
}
util.Log().Debug("任务执行完成")
util.Log().Debug("Task finished.")
// 执行完成
job.SetStatus(Complete)
}

View file

@ -45,7 +45,7 @@ func NewThumbFromFile(file io.Reader, name string) (*Thumb, error) {
case "png":
img, err = png.Decode(file)
default:
return nil, errors.New("未知的图像类型")
return nil, errors.New("unknown image format")
}
if err != nil {
return nil, err

View file

@ -22,7 +22,7 @@ func CreatNestedFile(path string) (*os.File, error) {
if !Exists(basePath) {
err := os.MkdirAll(basePath, 0700)
if err != nil {
Log().Warning("无法创建目录,%s", err)
Log().Warning("Failed to create directory: %s", err)
return nil, err
}
}

View file

@ -79,8 +79,8 @@ func AnonymousGetContent(c *gin.Context) {
}
}
// AnonymousPermLink 文件签名后的永久链接
func AnonymousPermLink(c *gin.Context) {
// AnonymousPermLink Deprecated 文件签名后的永久链接
func AnonymousPermLinkDeprecated(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -102,6 +102,39 @@ func AnonymousPermLink(c *gin.Context) {
}
}
// AnonymousPermLink 文件中转后的永久直链接
func AnonymousPermLink(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sourceLinkRaw, ok := c.Get("source_link")
if !ok {
c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil))
return
}
sourceLink := sourceLinkRaw.(*model.SourceLink)
service := &explorer.FileAnonymousGetService{
ID: sourceLink.FileID,
Name: sourceLink.File.Name,
}
res := service.Source(ctx, c)
// 是否需要重定向
if res.Code == -302 {
c.Redirect(302, res.Data.(string))
return
}
// 是否有错误发生
if res.Code != 0 {
c.JSON(200, res)
}
}
func GetSource(c *gin.Context) {
// 创建上下文
ctx, cancel := context.WithCancel(context.Background())

View file

@ -27,6 +27,7 @@ func SiteConfig(c *gin.Context) {
"captcha_type",
"captcha_TCaptcha_CaptchaAppId",
"register_enabled",
"show_app_promotion",
)
// 如果已登录,则同时返回用户信息和标签

View file

@ -16,10 +16,10 @@ import (
// InitRouter 初始化路由
func InitRouter() *gin.Engine {
if conf.SystemConfig.Mode == "master" {
util.Log().Info("当前运行模式Master")
util.Log().Info("Current running mode: Master.")
return InitMasterRouter()
}
util.Log().Info("当前运行模式Slave")
util.Log().Info("Current running mode: Slave.")
return InitSlaveRouter()
}
@ -108,7 +108,7 @@ func InitCORS(router *gin.Engine) {
// slave模式下未启动跨域的警告
if conf.SystemConfig.Mode == "slave" {
util.Log().Warning("当前作为存储端Slave运行但未启用跨域配置可能会导致 Master 端无法正常上传文件")
util.Log().Warning("You are running Cloudreve as slave node, if you are using slave storage policy, please enable CORS feature in config file, otherwise file cannot be uploaded from Master site.")
}
}
@ -145,6 +145,15 @@ func InitMasterRouter() *gin.Engine {
路由
*/
{
// Redirect file source link
source := r.Group("f")
{
source.GET(":id/:name",
middleware.HashID(hashid.SourceLinkID),
middleware.ValidateSourceLink(),
controllers.AnonymousPermLink)
}
// 全局设置相关
site := v3.Group("site")
{
@ -197,6 +206,7 @@ func InitMasterRouter() *gin.Engine {
// 获取用户头像
user.GET("avatar/:id/:size",
middleware.HashID(hashid.UserID),
middleware.StaticResourceCache(),
controllers.GetUserAvatar,
)
}
@ -208,11 +218,18 @@ func InitMasterRouter() *gin.Engine {
file := sign.Group("file")
{
// 文件外链(直接输出文件数据)
file.GET("get/:id/:name", controllers.AnonymousGetContent)
file.GET("get/:id/:name",
middleware.Sandbox(),
middleware.StaticResourceCache(),
controllers.AnonymousGetContent,
)
// 文件外链(301跳转)
file.GET("source/:id/:name", controllers.AnonymousPermLink)
file.GET("source/:id/:name", controllers.AnonymousPermLinkDeprecated)
// 下载文件
file.GET("download/:id", controllers.Download)
file.GET("download/:id",
middleware.StaticResourceCache(),
controllers.Download,
)
// 打包并下载文件
file.GET("archive/:sessionID/archive.zip", controllers.DownloadArchive)
}
@ -445,7 +462,7 @@ func InitMasterRouter() *gin.Engine {
// 列出文件
file.POST("list", controllers.AdminListFile)
// 预览文件
file.GET("preview/:id", controllers.AdminGetFile)
file.GET("preview/:id", middleware.Sandbox(), controllers.AdminGetFile)
// 删除
file.POST("delete", controllers.AdminDeleteFile)
// 列出用户或外部文件系统目录
@ -555,9 +572,9 @@ func InitMasterRouter() *gin.Engine {
// 创建文件下载会话
file.PUT("download/:id", controllers.CreateDownloadSession)
// 预览文件
file.GET("preview/:id", controllers.Preview)
file.GET("preview/:id", middleware.Sandbox(), controllers.Preview)
// 获取文本文件内容
file.GET("content/:id", controllers.PreviewText)
file.GET("content/:id", middleware.Sandbox(), controllers.PreviewText)
// 取得Office文档预览地址
file.GET("doc/:id", controllers.GetDocPreview)
// 获取缩略图

View file

@ -318,12 +318,20 @@ func (service *AdminListService) Policies() serializer.Response {
// 统计每个策略的文件使用
statics := make(map[uint][2]int, len(res))
policyIds := make([]uint, 0, len(res))
for i := 0; i < len(res); i++ {
policyIds = append(policyIds, res[i].ID)
}
rows, _ := model.DB.Model(&model.File{}).Where("policy_id in (?)", policyIds).
Select("policy_id,count(id),sum(size)").Group("policy_id").Rows()
for rows.Next() {
policyId := uint(0)
total := [2]int{}
row := model.DB.Model(&model.File{}).Where("policy_id = ?", res[i].ID).
Select("count(id),sum(size)").Row()
row.Scan(&total[0], &total[1])
statics[res[i].ID] = total
rows.Scan(&policyId, &total[0], &total[1])
statics[policyId] = total
}
return serializer.Response{Data: map[string]interface{}{

View file

@ -109,6 +109,7 @@ func (service *AddUserService) Add() serializer.Response {
user.Email = service.User.Email
user.GroupID = service.User.GroupID
user.Status = service.User.Status
user.TwoFactor = service.User.TwoFactor
// 检查愚蠢操作
if user.ID == 1 && user.GroupID != 1 {

View file

@ -27,6 +27,13 @@ type DownloadListService struct {
func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response {
// 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown)
for key, download := range downloads {
node := cluster.Default.GetNodeByID(download.GetNodeID())
if node != nil {
downloads[key].NodeName = node.DBModel().Name
}
}
return serializer.BuildFinishedListResponse(downloads)
}
@ -35,12 +42,17 @@ func (service *DownloadListService) Downloading(c *gin.Context, user *model.User
// 查找下载记录
downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready)
intervals := make(map[uint]int)
for _, download := range downloads {
for key, download := range downloads {
if _, ok := intervals[download.ID]; !ok {
if node := cluster.Default.GetNodeByID(download.GetNodeID()); node != nil {
intervals[download.ID] = node.DBModel().Aria2OptionsSerialized.Interval
}
}
node := cluster.Default.GetNodeByID(download.GetNodeID())
if node != nil {
downloads[key].NodeName = node.DBModel().Name
}
}
return serializer.BuildDownloadingResponse(downloads, intervals)

View file

@ -175,7 +175,7 @@ func (service *OneDriveCallback) PreProcess(c *gin.Context) serializer.Response
// SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容
// See: https://github.com/OneDrive/onedrive-api-docs/issues/935
if strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) {
if (strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") || strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.cn")) && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) {
isSizeCheckFailed = false
}
@ -239,7 +239,7 @@ func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response {
return ProcessCallback(service, c)
}
// PreProcess 对OneDrive客户端回调进行预处理验证
// PreProcess 对从机客户端回调进行预处理验证
func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response {
// 创建文件系统
fs, err := filesystem.NewFileSystemFromCallback(c)

View file

@ -178,12 +178,13 @@ func (service *FileAnonymousGetService) Source(ctx context.Context, c *gin.Conte
}
// 获取文件流
res, err := fs.SignURL(ctx, &fs.FileTarget[0],
int64(model.GetIntSetting("preview_timeout", 60)), false)
ttl := int64(model.GetIntSetting("preview_timeout", 60))
res, err := fs.SignURL(ctx, &fs.FileTarget[0], ttl, false)
if err != nil {
return serializer.Err(serializer.CodeNotSet, err.Error(), err)
}
c.Header("Cache-Control", fmt.Sprintf("max-age=%d", ttl))
return serializer.Response{
Code: -302,
Data: res,
@ -442,24 +443,48 @@ func (s *ItemIDService) Sources(ctx context.Context, c *gin.Context) serializer.
}
res := make([]serializer.Sources, 0, len(s.Raw().Items))
for _, id := range s.Raw().Items {
fs.FileTarget = []model.File{}
sourceURL, err := fs.GetSource(ctx, id)
if len(fs.FileTarget) > 0 {
current := serializer.Sources{
URL: sourceURL,
Name: fs.FileTarget[0].Name,
Parent: fs.FileTarget[0].FolderID,
}
files, err := model.GetFilesByIDs(s.Raw().Items, fs.User.ID)
if err != nil || len(files) == 0 {
return serializer.Err(serializer.CodeFileNotFound, "", err)
}
getSourceFunc := func(file model.File) (string, error) {
fs.FileTarget = []model.File{file}
return fs.GetSource(ctx, file.ID)
}
// Create redirected source link if needed
if fs.User.Group.OptionsSerialized.RedirectedSource {
getSourceFunc = func(file model.File) (string, error) {
source, err := file.CreateOrGetSourceLink()
if err != nil {
current.Error = err.Error()
return "", err
}
res = append(res, current)
sourceLinkURL, err := source.Link()
if err != nil {
return "", err
}
return sourceLinkURL, nil
}
}
for _, file := range files {
sourceURL, err := getSourceFunc(file)
current := serializer.Sources{
URL: sourceURL,
Name: file.Name,
Parent: file.FolderID,
}
if err != nil {
current.Error = err.Error()
}
res = append(res, current)
}
return serializer.Response{
Code: 0,
Data: res,