diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d6d4f47..ddc7f09 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,10 +7,10 @@ jobs: name: Build runs-on: ubuntu-18.04 steps: - - name: Set up Go 1.17 + - name: Set up Go 1.18 uses: actions/setup-go@v2 with: - go-version: "1.17" + go-version: "1.18" id: go - name: Check out code into the Go module directory diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 10337c6..7217dd9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,10 +12,10 @@ jobs: name: Test runs-on: ubuntu-18.04 steps: - - name: Set up Go 1.17 + - name: Set up Go 1.18 uses: actions/setup-go@v2 with: - go-version: "1.17" + go-version: "1.18" id: go - name: Check out code into the Go module directory diff --git a/.travis.yml b/.travis.yml index 546dc3e..6f3891e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: go go: - - 1.17.x + - 1.18.x node_js: "12.16.3" git: depth: 1 diff --git a/Dockerfile b/Dockerfile index c657018..d3d5caa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,42 +1,47 @@ -FROM golang:1.17-alpine as cloudreve_builder +# the frontend builder +# cloudreve need node.js 16* to build frontend, +# separate build step and custom image tag will resolve this +FROM node:16-alpine as cloudreve_frontend_builder +RUN apk update \ + && apk add --no-cache wget curl git yarn zip bash \ + && git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git /cloudreve_frontend + +# build frontend assets using build script, make sure all the steps just follow the regular release +WORKDIR /cloudreve_frontend +ENV GENERATE_SOURCEMAP false +RUN chmod +x ./build.sh && ./build.sh -a + + +# the backend builder +# cloudreve backend needs golang 1.18* to build +FROM golang:1.18-alpine as cloudreve_backend_builder # install dependencies and build tools -RUN apk update && apk add --no-cache wget curl git yarn build-base gcc abuild binutils binutils-doc gcc-doc zip +RUN apk update \ + # install dependencies and build tools + && apk add --no-cache wget curl git build-base gcc abuild binutils binutils-doc gcc-doc zip bash \ + && git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git /cloudreve_backend -WORKDIR /cloudreve_builder -RUN git clone --recurse-submodules https://github.com/cloudreve/Cloudreve.git - -# build frontend -WORKDIR /cloudreve_builder/Cloudreve/assets -ENV GENERATE_SOURCEMAP false - -RUN yarn install --network-timeout 1000000 -RUN yarn run build - -# build backend -WORKDIR /cloudreve_builder/Cloudreve -RUN zip -r - assets/build >assets.zip -RUN tag_name=$(git describe --tags) \ - && export COMMIT_SHA=$(git rev-parse --short HEAD) \ - && go build -a -o cloudreve -ldflags " -X 'github.com/HFO4/cloudreve/pkg/conf.BackendVersion=$tag_name' -X 'github.com/HFO4/cloudreve/pkg/conf.LastCommit=$COMMIT_SHA'" +WORKDIR /cloudreve_backend +COPY --from=cloudreve_frontend_builder /cloudreve_frontend/assets.zip ./ +RUN chmod +x ./build.sh && ./build.sh -c -# build final image +# TODO: merge the frontend build and backend build into a single one image +# the final published image FROM alpine:latest WORKDIR /cloudreve +COPY --from=cloudreve_backend_builder /cloudreve_backend/cloudreve ./cloudreve -RUN apk update && apk add --no-cache tzdata - -# we using the `Asia/Shanghai` timezone by default, you can do modification at your will -RUN cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ - && echo "Asia/Shanghai" > /etc/timezone - -COPY --from=cloudreve_builder /cloudreve_builder/Cloudreve/cloudreve ./ - -# prepare permissions and aria2 dir -RUN chmod +x ./cloudreve && mkdir -p /data/aria2 && chmod -R 766 /data/aria2 +RUN apk update \ + && apk add --no-cache tzdata \ + && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ + && echo "Asia/Shanghai" > /etc/timezone \ + && chmod +x ./cloudreve \ + && mkdir -p /data/aria2 \ + && chmod -R 766 /data/aria2 EXPOSE 5212 VOLUME ["/cloudreve/uploads", "/cloudreve/avatar", "/data"] diff --git a/README.md b/README.md index b951750..d23c9ec 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ chmod +x ./cloudreve ## :gear: 构建 -自行构建前需要拥有 `Go >= 1.17`、`node.js`、`yarn`、`zip` 等必要依赖。 +自行构建前需要拥有 `Go >= 1.18`、`node.js`、`yarn`、`zip` 等必要依赖。 #### 克隆代码 diff --git a/assets b/assets index 02d9320..2bf915a 160000 --- a/assets +++ b/assets @@ -1 +1 @@ -Subproject commit 02d93206cc5b943c34b5f5ac86c23dd96f5ef603 +Subproject commit 2bf915a33d58fc78c9c13ffc64685219c28a4732 diff --git a/assets.zip b/assets.zip index e69de29..15cb0ec 100644 Binary files a/assets.zip and b/assets.zip differ diff --git a/bootstrap/embed.go b/bootstrap/embed.go new file mode 100644 index 0000000..71f7567 --- /dev/null +++ b/bootstrap/embed.go @@ -0,0 +1,432 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package embed provides access to files embedded in the running Go program. +// +// Go source files that import "embed" can use the //go:embed directive +// to initialize a variable of type string, []byte, or FS with the contents of +// files read from the package directory or subdirectories at compile time. +// +// For example, here are three ways to embed a file named hello.txt +// and then print its contents at run time. +// +// Embedding one file into a string: +// +// import _ "embed" +// +// //go:embed hello.txt +// var s string +// print(s) +// +// Embedding one file into a slice of bytes: +// +// import _ "embed" +// +// //go:embed hello.txt +// var b []byte +// print(string(b)) +// +// Embedded one or more files into a file system: +// +// import "embed" +// +// //go:embed hello.txt +// var f embed.FS +// data, _ := f.ReadFile("hello.txt") +// print(string(data)) +// +// # Directives +// +// A //go:embed directive above a variable declaration specifies which files to embed, +// using one or more path.Match patterns. +// +// The directive must immediately precede a line containing the declaration of a single variable. +// Only blank lines and ‘//’ line comments are permitted between the directive and the declaration. +// +// The type of the variable must be a string type, or a slice of a byte type, +// or FS (or an alias of FS). +// +// For example: +// +// package server +// +// import "embed" +// +// // content holds our static web server content. +// //go:embed image/* template/* +// //go:embed html/index.html +// var content embed.FS +// +// The Go build system will recognize the directives and arrange for the declared variable +// (in the example above, content) to be populated with the matching files from the file system. +// +// The //go:embed directive accepts multiple space-separated patterns for +// brevity, but it can also be repeated, to avoid very long lines when there are +// many patterns. The patterns are interpreted relative to the package directory +// containing the source file. The path separator is a forward slash, even on +// Windows systems. Patterns may not contain ‘.’ or ‘..’ or empty path elements, +// nor may they begin or end with a slash. To match everything in the current +// directory, use ‘*’ instead of ‘.’. To allow for naming files with spaces in +// their names, patterns can be written as Go double-quoted or back-quoted +// string literals. +// +// If a pattern names a directory, all files in the subtree rooted at that directory are +// embedded (recursively), except that files with names beginning with ‘.’ or ‘_’ +// are excluded. So the variable in the above example is almost equivalent to: +// +// // content is our static web server content. +// //go:embed image template html/index.html +// var content embed.FS +// +// The difference is that ‘image/*’ embeds ‘image/.tempfile’ while ‘image’ does not. +// Neither embeds ‘image/dir/.tempfile’. +// +// If a pattern begins with the prefix ‘all:’, then the rule for walking directories is changed +// to include those files beginning with ‘.’ or ‘_’. For example, ‘all:image’ embeds +// both ‘image/.tempfile’ and ‘image/dir/.tempfile’. +// +// The //go:embed directive can be used with both exported and unexported variables, +// depending on whether the package wants to make the data available to other packages. +// It can only be used with variables at package scope, not with local variables. +// +// Patterns must not match files outside the package's module, such as ‘.git/*’ or symbolic links. +// Patterns must not match files whose names include the special punctuation characters " * < > ? ` ' | / \ and :. +// Matches for empty directories are ignored. After that, each pattern in a //go:embed line +// must match at least one file or non-empty directory. +// +// If any patterns are invalid or have invalid matches, the build will fail. +// +// # Strings and Bytes +// +// The //go:embed line for a variable of type string or []byte can have only a single pattern, +// and that pattern can match only a single file. The string or []byte is initialized with +// the contents of that file. +// +// The //go:embed directive requires importing "embed", even when using a string or []byte. +// In source files that don't refer to embed.FS, use a blank import (import _ "embed"). +// +// # File Systems +// +// For embedding a single file, a variable of type string or []byte is often best. +// The FS type enables embedding a tree of files, such as a directory of static +// web server content, as in the example above. +// +// FS implements the io/fs package's FS interface, so it can be used with any package that +// understands file systems, including net/http, text/template, and html/template. +// +// For example, given the content variable in the example above, we can write: +// +// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(content)))) +// +// template.ParseFS(content, "*.tmpl") +// +// # Tools +// +// To support tools that analyze Go packages, the patterns found in //go:embed lines +// are available in “go list” output. See the EmbedPatterns, TestEmbedPatterns, +// and XTestEmbedPatterns fields in the “go help list” output. +package bootstrap + +import ( + "errors" + "io" + "io/fs" + "time" +) + +// An FS is a read-only collection of files, usually initialized with a //go:embed directive. +// When declared without a //go:embed directive, an FS is an empty file system. +// +// An FS is a read-only value, so it is safe to use from multiple goroutines +// simultaneously and also safe to assign values of type FS to each other. +// +// FS implements fs.FS, so it can be used with any package that understands +// file system interfaces, including net/http, text/template, and html/template. +// +// See the package documentation for more details about initializing an FS. +type FS struct { + // The compiler knows the layout of this struct. + // See cmd/compile/internal/staticdata's WriteEmbed. + // + // The files list is sorted by name but not by simple string comparison. + // Instead, each file's name takes the form "dir/elem" or "dir/elem/". + // The optional trailing slash indicates that the file is itself a directory. + // The files list is sorted first by dir (if dir is missing, it is taken to be ".") + // and then by base, so this list of files: + // + // p + // q/ + // q/r + // q/s/ + // q/s/t + // q/s/u + // q/v + // w + // + // is actually sorted as: + // + // p # dir=. elem=p + // q/ # dir=. elem=q + // w/ # dir=. elem=w + // q/r # dir=q elem=r + // q/s/ # dir=q elem=s + // q/v # dir=q elem=v + // q/s/t # dir=q/s elem=t + // q/s/u # dir=q/s elem=u + // + // This order brings directory contents together in contiguous sections + // of the list, allowing a directory read to use binary search to find + // the relevant sequence of entries. + files *[]file +} + +// split splits the name into dir and elem as described in the +// comment in the FS struct above. isDir reports whether the +// final trailing slash was present, indicating that name is a directory. +func split(name string) (dir, elem string, isDir bool) { + if name[len(name)-1] == '/' { + isDir = true + name = name[:len(name)-1] + } + i := len(name) - 1 + for i >= 0 && name[i] != '/' { + i-- + } + if i < 0 { + return ".", name, isDir + } + return name[:i], name[i+1:], isDir +} + +// trimSlash trims a trailing slash from name, if present, +// returning the possibly shortened name. +func trimSlash(name string) string { + if len(name) > 0 && name[len(name)-1] == '/' { + return name[:len(name)-1] + } + return name +} + +var ( + _ fs.ReadDirFS = FS{} + _ fs.ReadFileFS = FS{} +) + +// A file is a single file in the FS. +// It implements fs.FileInfo and fs.DirEntry. +type file struct { + // The compiler knows the layout of this struct. + // See cmd/compile/internal/staticdata's WriteEmbed. + name string + data string + hash [16]byte // truncated SHA256 hash +} + +var ( + _ fs.FileInfo = (*file)(nil) + _ fs.DirEntry = (*file)(nil) +) + +func (f *file) Name() string { _, elem, _ := split(f.name); return elem } +func (f *file) Size() int64 { return int64(len(f.data)) } +func (f *file) ModTime() time.Time { return time.Time{} } +func (f *file) IsDir() bool { _, _, isDir := split(f.name); return isDir } +func (f *file) Sys() any { return nil } +func (f *file) Type() fs.FileMode { return f.Mode().Type() } +func (f *file) Info() (fs.FileInfo, error) { return f, nil } + +func (f *file) Mode() fs.FileMode { + if f.IsDir() { + return fs.ModeDir | 0555 + } + return 0444 +} + +// dotFile is a file for the root directory, +// which is omitted from the files list in a FS. +var dotFile = &file{name: "./"} + +// lookup returns the named file, or nil if it is not present. +func (f FS) lookup(name string) *file { + if !fs.ValidPath(name) { + // The compiler should never emit a file with an invalid name, + // so this check is not strictly necessary (if name is invalid, + // we shouldn't find a match below), but it's a good backstop anyway. + return nil + } + if name == "." { + return dotFile + } + if f.files == nil { + return nil + } + + // Binary search to find where name would be in the list, + // and then check if name is at that position. + dir, elem, _ := split(name) + files := *f.files + i := sortSearch(len(files), func(i int) bool { + idir, ielem, _ := split(files[i].name) + return idir > dir || idir == dir && ielem >= elem + }) + if i < len(files) && trimSlash(files[i].name) == name { + return &files[i] + } + return nil +} + +// readDir returns the list of files corresponding to the directory dir. +func (f FS) readDir(dir string) []file { + if f.files == nil { + return nil + } + // Binary search to find where dir starts and ends in the list + // and then return that slice of the list. + files := *f.files + i := sortSearch(len(files), func(i int) bool { + idir, _, _ := split(files[i].name) + return idir >= dir + }) + j := sortSearch(len(files), func(j int) bool { + jdir, _, _ := split(files[j].name) + return jdir > dir + }) + return files[i:j] +} + +// Open opens the named file for reading and returns it as an fs.File. +// +// The returned file implements io.Seeker when the file is not a directory. +func (f FS) Open(name string) (fs.File, error) { + file := f.lookup(name) + if file == nil { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} + } + if file.IsDir() { + return &openDir{file, f.readDir(name), 0}, nil + } + return &openFile{file, 0}, nil +} + +// ReadDir reads and returns the entire named directory. +func (f FS) ReadDir(name string) ([]fs.DirEntry, error) { + file, err := f.Open(name) + if err != nil { + return nil, err + } + dir, ok := file.(*openDir) + if !ok { + return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("not a directory")} + } + list := make([]fs.DirEntry, len(dir.files)) + for i := range list { + list[i] = &dir.files[i] + } + return list, nil +} + +// ReadFile reads and returns the content of the named file. +func (f FS) ReadFile(name string) ([]byte, error) { + file, err := f.Open(name) + if err != nil { + return nil, err + } + ofile, ok := file.(*openFile) + if !ok { + return nil, &fs.PathError{Op: "read", Path: name, Err: errors.New("is a directory")} + } + return []byte(ofile.f.data), nil +} + +// An openFile is a regular file open for reading. +type openFile struct { + f *file // the file itself + offset int64 // current read offset +} + +var ( + _ io.Seeker = (*openFile)(nil) +) + +func (f *openFile) Close() error { return nil } +func (f *openFile) Stat() (fs.FileInfo, error) { return f.f, nil } + +func (f *openFile) Read(b []byte) (int, error) { + if f.offset >= int64(len(f.f.data)) { + return 0, io.EOF + } + if f.offset < 0 { + return 0, &fs.PathError{Op: "read", Path: f.f.name, Err: fs.ErrInvalid} + } + n := copy(b, f.f.data[f.offset:]) + f.offset += int64(n) + return n, nil +} + +func (f *openFile) Seek(offset int64, whence int) (int64, error) { + switch whence { + case 0: + // offset += 0 + case 1: + offset += f.offset + case 2: + offset += int64(len(f.f.data)) + } + if offset < 0 || offset > int64(len(f.f.data)) { + return 0, &fs.PathError{Op: "seek", Path: f.f.name, Err: fs.ErrInvalid} + } + f.offset = offset + return offset, nil +} + +// An openDir is a directory open for reading. +type openDir struct { + f *file // the directory file itself + files []file // the directory contents + offset int // the read offset, an index into the files slice +} + +func (d *openDir) Close() error { return nil } +func (d *openDir) Stat() (fs.FileInfo, error) { return d.f, nil } + +func (d *openDir) Read([]byte) (int, error) { + return 0, &fs.PathError{Op: "read", Path: d.f.name, Err: errors.New("is a directory")} +} + +func (d *openDir) ReadDir(count int) ([]fs.DirEntry, error) { + n := len(d.files) - d.offset + if n == 0 { + if count <= 0 { + return nil, nil + } + return nil, io.EOF + } + if count > 0 && n > count { + n = count + } + list := make([]fs.DirEntry, n) + for i := range list { + list[i] = &d.files[d.offset+i] + } + d.offset += n + return list, nil +} + +// sortSearch is like sort.Search, avoiding an import. +func sortSearch(n int, f func(int) bool) int { + // Define f(-1) == false and f(n) == true. + // Invariant: f(i-1) == false, f(j) == true. + i, j := 0, n + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + if !f(h) { + i = h + 1 // preserves f(i-1) == false + } else { + j = h // preserves f(j) == true + } + } + // i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i. + return i +} diff --git a/bootstrap/fs.go b/bootstrap/fs.go new file mode 100644 index 0000000..a82396c --- /dev/null +++ b/bootstrap/fs.go @@ -0,0 +1,75 @@ +package bootstrap + +import ( + "archive/zip" + "crypto/sha256" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "github.com/pkg/errors" + "io" + "io/fs" + "sort" + "strings" +) + +func NewFS(zipContent string) fs.FS { + zipReader, err := zip.NewReader(strings.NewReader(zipContent), int64(len(zipContent))) + if err != nil { + util.Log().Panic("Static resource is not a valid zip file: %s", err) + } + + var files []file + err = fs.WalkDir(zipReader, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return errors.Errorf("无法获取[%s]的信息, %s, 跳过...", path, err) + } + + if path == "." { + return nil + } + + var f file + if d.IsDir() { + f.name = path + "/" + } else { + f.name = path + + rc, err := zipReader.Open(path) + if err != nil { + return errors.Errorf("无法打开文件[%s], %s, 跳过...", path, err) + } + defer rc.Close() + + data, err := io.ReadAll(rc) + if err != nil { + return errors.Errorf("无法读取文件[%s], %s, 跳过...", path, err) + } + + f.data = string(data) + + hash := sha256.Sum256(data) + for i := range f.hash { + f.hash[i] = ^hash[i] + } + } + files = append(files, f) + return nil + }) + if err != nil { + util.Log().Panic("初始化静态资源失败: %s", err) + } + + sort.Slice(files, func(i, j int) bool { + fi, fj := files[i], files[j] + di, ei, _ := split(fi.name) + dj, ej, _ := split(fj.name) + + if di != dj { + return di < dj + } + return ei < ej + }) + + var embedFS FS + embedFS.files = &files + return embedFS +} diff --git a/build.sh b/build.sh index 8acac1a..afd4b4b 100755 --- a/build.sh +++ b/build.sh @@ -32,11 +32,15 @@ buildAssets() { yarn run build cd build cd $REPO + + # please keep in mind that if this final output binary `assets.zip` name changed, please go and update the `Dockerfile` as well zip -r - assets/build >assets.zip } buildBinary() { cd $REPO + + # same as assets, if this final output binary `cloudreve` name changed, please go and update the `Dockerfile` go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'" } diff --git a/go.mod b/go.mod index 4b23b29..068d0b3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/cloudreve/Cloudreve/v3 -go 1.17 +go 1.18 require ( github.com/DATA-DOG/go-sqlmock v1.3.3 @@ -100,6 +100,7 @@ require ( github.com/mattn/go-colorable v0.1.4 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-runewidth v0.0.12 // indirect + github.com/mattn/go-sqlite3 v1.14.7 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/mitchellh/mapstructure v1.1.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/main.go b/main.go index b691f6e..d309b8f 100644 --- a/main.go +++ b/main.go @@ -4,13 +4,11 @@ import ( "context" _ "embed" "flag" - "io" "io/fs" "net" "net/http" "os" "os/signal" - "strings" "syscall" "time" @@ -19,8 +17,6 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/routers" - - "github.com/mholt/archiver/v4" ) var ( @@ -35,15 +31,12 @@ var staticZip string var staticFS fs.FS func init() { - flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "配置文件路径") - flag.BoolVar(&isEject, "eject", false, "导出内置静态资源") - flag.StringVar(&scriptName, "database-script", "", "运行内置数据库助手脚本") + flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "Path to the config file.") + flag.BoolVar(&isEject, "eject", false, "Eject all embedded static files.") + flag.StringVar(&scriptName, "database-script", "", "Name of database util script.") flag.Parse() - staticFS = archiver.ArchiveFS{ - Stream: io.NewSectionReader(strings.NewReader(staticZip), 0, int64(len(staticZip))), - Format: archiver.Zip{}, - } + staticFS = bootstrap.NewFS(staticZip) bootstrap.Init(confPath, staticFS) } @@ -71,7 +64,7 @@ func main() { signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) go func() { sig := <-sigChan - util.Log().Info("收到信号 %s,开始关闭 server", sig) + util.Log().Info("Signal %s received, shutting down server...", sig) ctx := context.Background() if conf.SystemConfig.GracePeriod != 0 { var cancel context.CancelFunc @@ -81,16 +74,16 @@ func main() { err := server.Shutdown(ctx) if err != nil { - util.Log().Error("关闭 server 错误, %s", err) + util.Log().Error("Failed to shutdown server: %s", err) } }() // 如果启用了SSL if conf.SSLConfig.CertPath != "" { - util.Log().Info("开始监听 %s", conf.SSLConfig.Listen) + util.Log().Info("Listening to %q", conf.SSLConfig.Listen) server.Addr = conf.SSLConfig.Listen if err := server.ListenAndServeTLS(conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil { - util.Log().Error("无法监听[%s],%s", conf.SSLConfig.Listen, err) + util.Log().Error("Failed to listen to %q: %s", conf.SSLConfig.Listen, err) return } } @@ -100,23 +93,23 @@ func main() { // delete socket file before listening if _, err := os.Stat(conf.UnixConfig.Listen); err == nil { if err = os.Remove(conf.UnixConfig.Listen); err != nil { - util.Log().Error("删除 socket 文件错误, %s", err) + util.Log().Error("Failed to delete socket file: %s", err) return } } api.TrustedPlatform = conf.UnixConfig.ProxyHeader - util.Log().Info("开始监听 %s", conf.UnixConfig.Listen) + util.Log().Info("Listening to %q", conf.UnixConfig.Listen) if err := RunUnix(server); err != nil { - util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err) + util.Log().Error("Failed to listen to %q: %s", conf.UnixConfig.Listen, err) } return } - util.Log().Info("开始监听 %s", conf.SystemConfig.Listen) + util.Log().Info("Listening to %q", conf.SystemConfig.Listen) server.Addr = conf.SystemConfig.Listen if err := server.ListenAndServe(); err != nil { - util.Log().Error("无法监听[%s],%s", conf.SystemConfig.Listen, err) + util.Log().Error("Failed to listen to %q: %s", conf.SystemConfig.Listen, err) } } @@ -125,8 +118,21 @@ func RunUnix(server *http.Server) error { if err != nil { return err } + defer listener.Close() defer os.Remove(conf.UnixConfig.Listen) + if conf.UnixConfig.Perm > 0 { + err = os.Chmod(conf.UnixConfig.Listen, os.FileMode(conf.UnixConfig.Perm)) + if err != nil { + util.Log().Warning( + "Failed to set permission to %q for socket file %q: %s", + conf.UnixConfig.Perm, + conf.UnixConfig.Listen, + err, + ) + } + } + return server.Serve(listener) } diff --git a/middleware/common.go b/middleware/common.go index 812dccb..cfc6747 100644 --- a/middleware/common.go +++ b/middleware/common.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/hashid" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" @@ -45,3 +46,17 @@ func CacheControl() gin.HandlerFunc { c.Header("Cache-Control", "private, no-cache") } } + +func Sandbox() gin.HandlerFunc { + return func(c *gin.Context) { + c.Header("Content-Security-Policy", "sandbox") + } +} + +// StaticResourceCache 使用静态资源缓存策略 +func StaticResourceCache() gin.HandlerFunc { + return func(c *gin.Context) { + c.Header("Cache-Control", fmt.Sprintf("public, max-age=%d", model.GetIntSetting("public_resource_maxage", 86400))) + + } +} diff --git a/middleware/common_test.go b/middleware/common_test.go index 000687b..1ab839a 100644 --- a/middleware/common_test.go +++ b/middleware/common_test.go @@ -85,3 +85,21 @@ func TestCacheControl(t *testing.T) { TestFunc(c) a.Contains(c.Writer.Header().Get("Cache-Control"), "no-cache") } + +func TestSandbox(t *testing.T) { + a := assert.New(t) + TestFunc := Sandbox() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + TestFunc(c) + a.Contains(c.Writer.Header().Get("Content-Security-Policy"), "sandbox") +} + +func TestStaticResourceCache(t *testing.T) { + a := assert.New(t) + TestFunc := StaticResourceCache() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + TestFunc(c) + a.Contains(c.Writer.Header().Get("Cache-Control"), "public, max-age") +} diff --git a/middleware/file.go b/middleware/file.go new file mode 100644 index 0000000..995637e --- /dev/null +++ b/middleware/file.go @@ -0,0 +1,30 @@ +package middleware + +import ( + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/gin-gonic/gin" +) + +// ValidateSourceLink validates if the perm source link is a valid redirect link +func ValidateSourceLink() gin.HandlerFunc { + return func(c *gin.Context) { + linkID, ok := c.Get("object_id") + if !ok { + c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil)) + c.Abort() + return + } + + sourceLink, err := model.GetSourceLinkByID(linkID) + if err != nil || sourceLink.File.ID == 0 || sourceLink.File.Name != c.Param("name") { + c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil)) + c.Abort() + return + } + + sourceLink.Downloaded() + c.Set("source_link", sourceLink) + c.Next() + } +} diff --git a/middleware/file_test.go b/middleware/file_test.go new file mode 100644 index 0000000..5ca4014 --- /dev/null +++ b/middleware/file_test.go @@ -0,0 +1,57 @@ +package middleware + +import ( + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "net/http/httptest" + "testing" +) + +func TestValidateSourceLink(t *testing.T) { + a := assert.New(t) + rec := httptest.NewRecorder() + testFunc := ValidateSourceLink() + + // ID 不存在 + { + c, _ := gin.CreateTestContext(rec) + testFunc(c) + a.True(c.IsAborted()) + } + + // SourceLink 不存在 + { + c, _ := gin.CreateTestContext(rec) + c.Set("object_id", 1) + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) + testFunc(c) + a.True(c.IsAborted()) + a.NoError(mock.ExpectationsWereMet()) + } + + // 原文件不存在 + { + c, _ := gin.CreateTestContext(rec) + c.Set("object_id", 1) + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(0).WillReturnRows(sqlmock.NewRows([]string{"id"})) + testFunc(c) + a.True(c.IsAborted()) + a.NoError(mock.ExpectationsWereMet()) + } + + // 成功 + { + c, _ := gin.CreateTestContext(rec) + c.Set("object_id", 1) + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2)) + mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)source_links").WillReturnResult(sqlmock.NewResult(1, 1)) + testFunc(c) + a.False(c.IsAborted()) + a.NoError(mock.ExpectationsWereMet()) + } + +} diff --git a/middleware/frontend.go b/middleware/frontend.go index 95e4609..f07d9b6 100644 --- a/middleware/frontend.go +++ b/middleware/frontend.go @@ -39,7 +39,11 @@ func FrontendFileHandler() gin.HandlerFunc { path := c.Request.URL.Path // API 跳过 - if strings.HasPrefix(path, "/api") || strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || path == "/manifest.json" { + if strings.HasPrefix(path, "/api") || + strings.HasPrefix(path, "/custom") || + strings.HasPrefix(path, "/dav") || + strings.HasPrefix(path, "/f") || + path == "/manifest.json" { c.Next() return } diff --git a/middleware/session.go b/middleware/session.go index 18f0d15..77825ae 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -46,7 +46,7 @@ func Session(secret string) gin.HandlerFunc { // Also set Secure: true if using SSL, you should though Store.Options(sessions.Options{ HttpOnly: true, - MaxAge: 7 * 86400, + MaxAge: 60 * 86400, Path: "/", SameSite: sameSiteMode, Secure: conf.CORSConfig.Secure, diff --git a/models/defaults.go b/models/defaults.go index a37ecac..3090016 100644 --- a/models/defaults.go +++ b/models/defaults.go @@ -113,4 +113,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti {Name: "pwa_theme_color", Value: "#000000", Type: "pwa"}, {Name: "pwa_background_color", Value: "#ffffff", Type: "pwa"}, {Name: "office_preview_service", Value: "https://view.officeapps.live.com/op/view.aspx?src={$src}", Type: "preview"}, + {Name: "show_app_promotion", Value: "1", Type: "mobile"}, + {Name: "public_resource_maxage", Value: "86400", Type: "timeout"}, } diff --git a/models/download.go b/models/download.go index 87f4533..dce50f3 100644 --- a/models/download.go +++ b/models/download.go @@ -32,6 +32,7 @@ type Download struct { // 数据库忽略字段 StatusInfo rpc.StatusInfo `gorm:"-"` Task *Task `gorm:"-"` + NodeName string `gorm:"-"` } // AfterFind 找到下载任务后的钩子,处理Status结构 diff --git a/models/file.go b/models/file.go index 9cf9ab8..161bbbb 100644 --- a/models/file.go +++ b/models/file.go @@ -4,6 +4,7 @@ import ( "encoding/gob" "encoding/json" "errors" + "fmt" "path" "time" @@ -191,14 +192,15 @@ func RemoveFilesWithSoftLinks(files []File) ([]File, error) { } // 查询软链接的文件 - var filesWithSoftLinks []File - tx := DB - for _, value := range files { - tx = tx.Or("source_name = ? and policy_id = ? and id != ?", value.SourceName, value.PolicyID, value.ID) - } - result := tx.Find(&filesWithSoftLinks) - if result.Error != nil { - return nil, result.Error + filesWithSoftLinks := make([]File, 0) + for _, file := range files { + var softLinkFile File + res := DB. + Where("source_name = ? and policy_id = ? and id != ?", file.SourceName, file.PolicyID, file.ID). + First(&softLinkFile) + if res.Error == nil { + filesWithSoftLinks = append(filesWithSoftLinks, softLinkFile) + } } // 过滤具有软连接的文件 @@ -338,6 +340,25 @@ func (file *File) CanCopy() bool { return file.UploadSessionID == nil } +// CreateOrGetSourceLink creates a SourceLink model. If the given model exists, the existing +// model will be returned. +func (file *File) CreateOrGetSourceLink() (*SourceLink, error) { + res := &SourceLink{} + err := DB.Set("gorm:auto_preload", true).Where("file_id = ?", file.ID).Find(&res).Error + if err == nil && res.ID > 0 { + return res, nil + } + + res.FileID = file.ID + res.Name = file.Name + if err := DB.Save(res).Error; err != nil { + return nil, fmt.Errorf("failed to insert SourceLink: %w", err) + } + + res.File = *file + return res, nil +} + /* 实现 webdav.FileInfo 接口 */ diff --git a/models/file_test.go b/models/file_test.go index 9563521..5f6826c 100644 --- a/models/file_test.go +++ b/models/file_test.go @@ -285,30 +285,34 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) { }, } + // 传入空文件列表 + { + file, err := RemoveFilesWithSoftLinks([]File{}) + asserts.NoError(err) + asserts.Empty(file) + } + // 全都没有 { mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1, "2.txt", 24, 2). + WithArgs("1.txt", 23, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs("2.txt", 24, 2). WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) file, err := RemoveFilesWithSoftLinks(files) asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(err) asserts.Equal(files, file) } - // 查询出错 - { - mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1, "2.txt", 24, 2). - WillReturnError(errors.New("error")) - file, err := RemoveFilesWithSoftLinks(files) - asserts.NoError(mock.ExpectationsWereMet()) - asserts.Error(err) - asserts.Nil(file) - } + // 第二个是软链 { mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1, "2.txt", 24, 2). + WithArgs("1.txt", 23, 1). + WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs("2.txt", 24, 2). WillReturnRows( sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). AddRow(3, 24, "2.txt"), @@ -318,14 +322,18 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) { asserts.NoError(err) asserts.Equal(files[:1], file) } + // 第一个是软链 { mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1, "2.txt", 24, 2). + WithArgs("1.txt", 23, 1). WillReturnRows( sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). AddRow(3, 23, "1.txt"), ) + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs("2.txt", 24, 2). + WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) file, err := RemoveFilesWithSoftLinks(files) asserts.NoError(mock.ExpectationsWereMet()) asserts.NoError(err) @@ -334,11 +342,16 @@ func TestRemoveFilesWithSoftLinks(t *testing.T) { // 全部是软链 { mock.ExpectQuery("SELECT(.+)files(.+)"). - WithArgs("1.txt", 23, 1, "2.txt", 24, 2). + WithArgs("1.txt", 23, 1). WillReturnRows( sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). - AddRow(3, 24, "2.txt"). - AddRow(4, 23, "1.txt"), + AddRow(3, 23, "1.txt"), + ) + mock.ExpectQuery("SELECT(.+)files(.+)"). + WithArgs("2.txt", 24, 2). + WillReturnRows( + sqlmock.NewRows([]string{"id", "policy_id", "source_name"}). + AddRow(3, 24, "2.txt"), ) file, err := RemoveFilesWithSoftLinks(files) asserts.NoError(mock.ExpectationsWereMet()) @@ -598,3 +611,44 @@ func TestGetFilesByKeywords(t *testing.T) { asserts.Len(res, 1) } } + +func TestFile_CreateOrGetSourceLink(t *testing.T) { + a := assert.New(t) + file := &File{} + file.ID = 1 + + // 已存在,返回老的 SourceLink + { + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + res, err := file.CreateOrGetSourceLink() + a.NoError(err) + a.EqualValues(2, res.ID) + a.NoError(mock.ExpectationsWereMet()) + } + + // 不存在,插入失败 + { + expectedErr := errors.New("error") + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnError(expectedErr) + mock.ExpectRollback() + res, err := file.CreateOrGetSourceLink() + a.Nil(res) + a.ErrorIs(err, expectedErr) + a.NoError(mock.ExpectationsWereMet()) + } + + // 成功 + { + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id"})) + mock.ExpectBegin() + mock.ExpectExec("INSERT(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + res, err := file.CreateOrGetSourceLink() + a.NoError(err) + a.EqualValues(2, res.ID) + a.EqualValues(file.ID, res.File.ID) + a.NoError(mock.ExpectationsWereMet()) + } +} diff --git a/models/folder.go b/models/folder.go index 16a0a34..ebc1069 100644 --- a/models/folder.go +++ b/models/folder.go @@ -224,7 +224,7 @@ func (folder *Folder) CopyFolderTo(folderID uint, dstFolder *Folder) (size uint6 } else if IDCache, ok := newIDCache[*folder.ParentID]; ok { newID = IDCache } else { - util.Log().Warning("Failed to get parent folder %q", folder.ParentID) + util.Log().Warning("Failed to get parent folder %q", *folder.ParentID) return size, errors.New("Failed to get parent folder") } diff --git a/models/group.go b/models/group.go index 78f7bfd..490fc38 100644 --- a/models/group.go +++ b/models/group.go @@ -23,16 +23,17 @@ type Group struct { // GroupOption 用户组其他配置 type GroupOption struct { - ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载 - ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩 - CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小 - DecompressSize uint64 `json:"decompress_size,omitempty"` - OneTimeDownload bool `json:"one_time_download,omitempty"` - ShareDownload bool `json:"share_download,omitempty"` - Aria2 bool `json:"aria2,omitempty"` // 离线下载 - Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 - SourceBatchSize int `json:"source_batch,omitempty"` - Aria2BatchSize int `json:"aria2_batch,omitempty"` + ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载 + ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩 + CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小 + DecompressSize uint64 `json:"decompress_size,omitempty"` + OneTimeDownload bool `json:"one_time_download,omitempty"` + ShareDownload bool `json:"share_download,omitempty"` + Aria2 bool `json:"aria2,omitempty"` // 离线下载 + Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 + SourceBatchSize int `json:"source_batch,omitempty"` + RedirectedSource bool `json:"redirected_source,omitempty"` + Aria2BatchSize int `json:"aria2_batch,omitempty"` } // GetGroupByID 用ID获取用户组 @@ -66,7 +67,7 @@ func (group *Group) BeforeSave() (err error) { return err } -//SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段 +// SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段 // TODO 完善测试 func (group *Group) SerializePolicyList() (err error) { policies, err := json.Marshal(&group.PolicyList) diff --git a/models/migration.go b/models/migration.go index 63ae60d..17a08ce 100644 --- a/models/migration.go +++ b/models/migration.go @@ -19,7 +19,7 @@ func needMigration() bool { return DB.Where("name = ?", "db_version_"+conf.RequiredDBVersion).First(&setting).Error != nil } -//执行数据迁移 +// 执行数据迁移 func migration() { // 确认是否需要执行迁移 if !needMigration() { @@ -41,7 +41,7 @@ func migration() { } DB.AutoMigrate(&User{}, &Setting{}, &Group{}, &Policy{}, &Folder{}, &File{}, &Share{}, - &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}) + &Task{}, &Download{}, &Tag{}, &Webdav{}, &Node{}, &SourceLink{}) // 创建初始存储策略 addDefaultPolicy() @@ -104,12 +104,13 @@ func addDefaultGroups() { ShareEnabled: true, WebDAVEnabled: true, OptionsSerialized: GroupOption{ - ArchiveDownload: true, - ArchiveTask: true, - ShareDownload: true, - Aria2: true, - SourceBatchSize: 1000, - Aria2BatchSize: 50, + ArchiveDownload: true, + ArchiveTask: true, + ShareDownload: true, + Aria2: true, + SourceBatchSize: 1000, + Aria2BatchSize: 50, + RedirectedSource: true, }, } if err := DB.Create(&defaultAdminGroup).Error; err != nil { @@ -128,9 +129,10 @@ func addDefaultGroups() { ShareEnabled: true, WebDAVEnabled: true, OptionsSerialized: GroupOption{ - ShareDownload: true, - SourceBatchSize: 10, - Aria2BatchSize: 1, + ShareDownload: true, + SourceBatchSize: 10, + Aria2BatchSize: 1, + RedirectedSource: true, }, } if err := DB.Create(&defaultAdminGroup).Error; err != nil { diff --git a/models/scripts/invoker/invoker.go b/models/scripts/invoker/invoker.go index e9a0b05..b55b1e9 100644 --- a/models/scripts/invoker/invoker.go +++ b/models/scripts/invoker/invoker.go @@ -15,12 +15,12 @@ var availableScripts = make(map[string]DBScript) func RunDBScript(name string, ctx context.Context) error { if script, ok := availableScripts[name]; ok { - util.Log().Info("开始执行数据库脚本 [%s]", name) + util.Log().Info("Start executing database script %q.", name) script.Run(ctx) return nil } - return fmt.Errorf("数据库脚本 [%s] 不存在", name) + return fmt.Errorf("Database script %q not exist.", name) } func Register(name string, script DBScript) { diff --git a/models/scripts/reset.go b/models/scripts/reset.go index 88ee25d..1f6bf08 100644 --- a/models/scripts/reset.go +++ b/models/scripts/reset.go @@ -14,7 +14,7 @@ func (script ResetAdminPassword) Run(ctx context.Context) { // 查找用户 user, err := model.GetUserByID(1) if err != nil { - util.Log().Panic("初始管理员用户不存在, %s", err) + util.Log().Panic("Initial admin user not exist: %s", err) } // 生成密码 @@ -23,9 +23,9 @@ func (script ResetAdminPassword) Run(ctx context.Context) { // 更改为新密码 user.SetPassword(password) if err := user.Update(map[string]interface{}{"password": user.Password}); err != nil { - util.Log().Panic("密码更改失败, %s", err) + util.Log().Panic("Failed to update password: %s", err) } c := color.New(color.FgWhite).Add(color.BgBlack).Add(color.Bold) - util.Log().Info("初始管理员密码已更改为:" + c.Sprint(password)) + util.Log().Info("Initial admin user password changed to:" + c.Sprint(password)) } diff --git a/models/scripts/storage.go b/models/scripts/storage.go index 99dcb0b..0d436b9 100644 --- a/models/scripts/storage.go +++ b/models/scripts/storage.go @@ -25,7 +25,7 @@ func (script UserStorageCalibration) Run(ctx context.Context) { model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total) // 更新用户的容量 if user.Storage != total.Total { - util.Log().Info("将用户 [%s] 的容量由 %d 校准为 %d", user.Email, + util.Log().Info("Calibrate used storage for user %q, from %d to %d.", user.Email, user.Storage, total.Total) } model.DB.Model(&user).Update("storage", total.Total) diff --git a/models/source_link.go b/models/source_link.go new file mode 100644 index 0000000..49dfea2 --- /dev/null +++ b/models/source_link.go @@ -0,0 +1,47 @@ +package model + +import ( + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/hashid" + "github.com/jinzhu/gorm" + "net/url" +) + +// SourceLink represent a shared file source link +type SourceLink struct { + gorm.Model + FileID uint // corresponding file ID + Name string // name of the file while creating the source link, for annotation + Downloads int // 下载数 + + // 关联模型 + File File `gorm:"save_associations:false:false"` +} + +// Link gets the URL of a SourceLink +func (s *SourceLink) Link() (string, error) { + baseURL := GetSiteURL() + linkPath, err := url.Parse(fmt.Sprintf("/f/%s/%s", hashid.HashID(s.ID, hashid.SourceLinkID), s.File.Name)) + if err != nil { + return "", err + } + return baseURL.ResolveReference(linkPath).String(), nil +} + +// GetTasksByID queries source link based on ID +func GetSourceLinkByID(id interface{}) (*SourceLink, error) { + link := &SourceLink{} + result := DB.Where("id = ?", id).First(link) + files, _ := GetFilesByIDs([]uint{link.FileID}, 0) + if len(files) > 0 { + link.File = files[0] + } + + return link, result.Error +} + +// Viewed 增加访问次数 +func (s *SourceLink) Downloaded() { + s.Downloads++ + DB.Model(s).UpdateColumn("downloads", gorm.Expr("downloads + ?", 1)) +} diff --git a/models/source_link_test.go b/models/source_link_test.go new file mode 100644 index 0000000..d84dc62 --- /dev/null +++ b/models/source_link_test.go @@ -0,0 +1,52 @@ +package model + +import ( + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSourceLink_Link(t *testing.T) { + a := assert.New(t) + s := &SourceLink{} + s.ID = 1 + + // 失败 + { + s.File.Name = string([]byte{0x7f}) + res, err := s.Link() + a.Error(err) + a.Empty(res) + } + + // 成功 + { + s.File.Name = "filename" + res, err := s.Link() + a.NoError(err) + a.Contains(res, s.Name) + } +} + +func TestGetSourceLinkByID(t *testing.T) { + a := assert.New(t) + mock.ExpectQuery("SELECT(.+)source_links(.+)").WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "file_id"}).AddRow(1, 2)) + mock.ExpectQuery("SELECT(.+)files(.+)").WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + + res, err := GetSourceLinkByID(1) + a.NoError(err) + a.NotNil(res) + a.EqualValues(2, res.File.ID) + a.NoError(mock.ExpectationsWereMet()) +} + +func TestSourceLink_Downloaded(t *testing.T) { + a := assert.New(t) + s := &SourceLink{} + s.ID = 1 + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)source_links(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + s.Downloaded() + a.NoError(mock.ExpectationsWereMet()) +} diff --git a/pkg/aria2/common/common.go b/pkg/aria2/common/common.go index fd32615..ae5e6b0 100644 --- a/pkg/aria2/common/common.go +++ b/pkg/aria2/common/common.go @@ -52,7 +52,7 @@ const ( var ( // ErrNotEnabled 功能未开启错误 - ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "", nil) + ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil) // ErrUserNotFound 未找到下载任务创建者 ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil) ) diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 2ef84be..96cc4bb 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -38,6 +38,7 @@ type ssl struct { type unix struct { Listen string ProxyHeader string `validate:"required_with=Listen"` + Perm uint32 } // slave 作为slave存储端配置 diff --git a/pkg/conf/version.go b/pkg/conf/version.go index d92c524..fda173f 100644 --- a/pkg/conf/version.go +++ b/pkg/conf/version.go @@ -1,13 +1,13 @@ package conf // BackendVersion 当前后端版本号 -var BackendVersion = "3.5.3" +var BackendVersion = "3.6.0" // RequiredDBVersion 与当前版本匹配的数据库版本 -var RequiredDBVersion = "3.5.2" +var RequiredDBVersion = "3.6.0" // RequiredStaticVersion 与当前版本匹配的静态资源版本 -var RequiredStaticVersion = "3.5.3" +var RequiredStaticVersion = "3.6.0" // IsPro 是否为Pro版本 var IsPro = "false" diff --git a/pkg/filesystem/chunk/backoff/backoff.go b/pkg/filesystem/chunk/backoff/backoff.go index d15b975..95cb1b5 100644 --- a/pkg/filesystem/chunk/backoff/backoff.go +++ b/pkg/filesystem/chunk/backoff/backoff.go @@ -1,14 +1,22 @@ package backoff -import "time" +import ( + "errors" + "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "net/http" + "strconv" + "time" +) // Backoff used for retry sleep backoff type Backoff interface { - Next() bool + Next(err error) bool Reset() } -// ConstantBackoff implements Backoff interface with constant sleep time +// ConstantBackoff implements Backoff interface with constant sleep time. If the error +// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration. type ConstantBackoff struct { Sleep time.Duration Max int @@ -16,16 +24,51 @@ type ConstantBackoff struct { tried int } -func (c *ConstantBackoff) Next() bool { +func (c *ConstantBackoff) Next(err error) bool { c.tried++ if c.tried > c.Max { return false } - time.Sleep(c.Sleep) + var e *RetryableError + if errors.As(err, &e) && e.RetryAfter > 0 { + util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter) + time.Sleep(e.RetryAfter) + } else { + time.Sleep(c.Sleep) + } + return true } func (c *ConstantBackoff) Reset() { c.tried = 0 } + +type RetryableError struct { + Err error + RetryAfter time.Duration +} + +// NewRetryableErrorFromHeader constructs a new RetryableError from http response header +// and existing error. +func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError { + retryAfter := header.Get("retry-after") + if retryAfter == "" { + retryAfter = "0" + } + + res := &RetryableError{ + Err: err, + } + + if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil { + res.RetryAfter = time.Duration(retryAfterSecond) * time.Second + } + + return res +} + +func (e *RetryableError) Error() string { + return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err) +} diff --git a/pkg/filesystem/chunk/backoff/backoff_test.go b/pkg/filesystem/chunk/backoff/backoff_test.go index 6419c71..0fda534 100644 --- a/pkg/filesystem/chunk/backoff/backoff_test.go +++ b/pkg/filesystem/chunk/backoff/backoff_test.go @@ -1,7 +1,9 @@ package backoff import ( + "errors" "github.com/stretchr/testify/assert" + "net/http" "testing" "time" ) @@ -9,14 +11,51 @@ import ( func TestConstantBackoff_Next(t *testing.T) { a := assert.New(t) - b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} - a.True(b.Next()) - a.True(b.Next()) - a.True(b.Next()) - a.False(b.Next()) - b.Reset() - a.True(b.Next()) - a.True(b.Next()) - a.True(b.Next()) - a.False(b.Next()) + // General error + { + err := errors.New("error") + b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + b.Reset() + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + } + + // Retryable error + { + err := &RetryableError{RetryAfter: time.Duration(1)} + b := &ConstantBackoff{Sleep: time.Duration(0), Max: 3} + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + b.Reset() + a.True(b.Next(err)) + a.True(b.Next(err)) + a.True(b.Next(err)) + a.False(b.Next(err)) + } + +} + +func TestNewRetryableErrorFromHeader(t *testing.T) { + a := assert.New(t) + // no retry-after header + { + err := NewRetryableErrorFromHeader(nil, http.Header{}) + a.Empty(err.RetryAfter) + } + + // with retry-after header + { + header := http.Header{} + header.Add("retry-after", "120") + err := NewRetryableErrorFromHeader(nil, header) + a.EqualValues(time.Duration(120)*time.Second, err.RetryAfter) + } } diff --git a/pkg/filesystem/chunk/chunk.go b/pkg/filesystem/chunk/chunk.go index 24e50a1..cf790f6 100644 --- a/pkg/filesystem/chunk/chunk.go +++ b/pkg/filesystem/chunk/chunk.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" + "github.com/cloudreve/Cloudreve/v3/pkg/request" "github.com/cloudreve/Cloudreve/v3/pkg/util" "io" "os" @@ -66,7 +67,7 @@ func (c *ChunkGroup) TempAvailable() bool { // Process a chunk with retry logic func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { - reader := io.LimitReader(c.file, int64(c.chunkSize)) + reader := io.LimitReader(c.file, c.Length()) // If useBuffer is enabled, tee the reader to a temp file if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() { @@ -90,13 +91,17 @@ func (c *ChunkGroup) Process(processor ChunkProcessFunc) error { } util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name()) - reader = c.bufferTemp + reader = io.NopCloser(c.bufferTemp) } } err := processor(c, reader) if err != nil { - if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next() { + if c.enableRetryBuffer { + request.BlackHole(reader) + } + + if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) { if c.file.Seekable() { if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil { return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err) diff --git a/pkg/filesystem/driver/local/handler_test.go b/pkg/filesystem/driver/local/handler_test.go index 9167e82..0dfe818 100644 --- a/pkg/filesystem/driver/local/handler_test.go +++ b/pkg/filesystem/driver/local/handler_test.go @@ -36,7 +36,7 @@ func TestHandler_Put(t *testing.T) { {&fsctx.FileStream{ SavePath: "TestHandler_Put.txt", File: io.NopCloser(strings.NewReader("")), - }, "物理同名文件已存在或不可用"}, + }, "file with the same name existed or unavailable"}, {&fsctx.FileStream{ SavePath: "inner/TestHandler_Put.txt", File: io.NopCloser(strings.NewReader("")), @@ -51,7 +51,7 @@ func TestHandler_Put(t *testing.T) { Mode: fsctx.Append | fsctx.Overwrite, SavePath: "inner/TestHandler_Put.txt", File: io.NopCloser(strings.NewReader("123")), - }, "未上传完成的文件分片与预期大小不一致"}, + }, "size of unfinished uploaded chunks is not as expected"}, {&fsctx.FileStream{ Mode: fsctx.Append | fsctx.Overwrite, SavePath: "inner/TestHandler_Put.txt", diff --git a/pkg/filesystem/driver/onedrive/api.go b/pkg/filesystem/driver/onedrive/api.go index 1e41b72..2ec1663 100644 --- a/pkg/filesystem/driver/onedrive/api.go +++ b/pkg/filesystem/driver/onedrive/api.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "io" - "io/ioutil" "net/http" "net/url" "path" @@ -37,24 +36,18 @@ const ( // GetSourcePath 获取文件的绝对路径 func (info *FileInfo) GetSourcePath() string { - res, err := url.PathUnescape( - strings.TrimPrefix( - path.Join( - strings.TrimPrefix(info.ParentReference.Path, "/drive/root:"), - info.Name, - ), - "/", - ), - ) + res, err := url.PathUnescape(info.ParentReference.Path) if err != nil { return "" } - return res -} -// Error 实现error接口 -func (err RespError) Error() string { - return err.APIError.Message + return strings.TrimPrefix( + path.Join( + strings.TrimPrefix(res, "/drive/root:"), + info.Name, + ), + "/", + ) } func (client *Client) getRequestURL(api string, opts ...Option) string { @@ -531,7 +524,7 @@ func sysError(err error) *RespError { }} } -func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, *RespError) { +func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) { // 获取凭证 err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave") if err != nil { @@ -580,15 +573,21 @@ func (client *Client) request(ctx context.Context, method string, url string, bo util.Log().Debug("Onedrive returns unknown response: %s", respBody) return "", sysError(decodeErr) } + + if res.Response.StatusCode == 429 { + util.Log().Warning("OneDrive request is throttled.") + return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header) + } + return "", &errResp } return respBody, nil } -func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, *RespError) { +func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) { // 发送请求 - bodyReader := ioutil.NopCloser(strings.NewReader(body)) + bodyReader := io.NopCloser(strings.NewReader(body)) return client.request(ctx, method, url, bodyReader, request.WithContentLength(int64(len(body))), ) diff --git a/pkg/filesystem/driver/onedrive/api_test.go b/pkg/filesystem/driver/onedrive/api_test.go index fb6393d..a675548 100644 --- a/pkg/filesystem/driver/onedrive/api_test.go +++ b/pkg/filesystem/driver/onedrive/api_test.go @@ -112,6 +112,35 @@ func TestRequest(t *testing.T) { asserts.Equal("error msg", err.Error()) } + // OneDrive返回429错误 + { + header := http.Header{} + header.Add("retry-after", "120") + clientMock := ClientMock{} + clientMock.On( + "Request", + "POST", + "http://dev.com", + testMock.Anything, + testMock.Anything, + ).Return(&request.Response{ + Err: nil, + Response: &http.Response{ + StatusCode: 429, + Header: header, + Body: ioutil.NopCloser(strings.NewReader(`{"error":{"message":"error msg"}}`)), + }, + }) + client.Request = clientMock + res, err := client.request(context.Background(), "POST", "http://dev.com", strings.NewReader("")) + clientMock.AssertExpectations(t) + asserts.Error(err) + asserts.Empty(res) + var retryErr *backoff.RetryableError + asserts.ErrorAs(err, &retryErr) + asserts.EqualValues(time.Duration(120)*time.Second, retryErr.RetryAfter) + } + // OneDrive返回未知响应 { clientMock := ClientMock{} @@ -144,18 +173,18 @@ func TestFileInfo_GetSourcePath(t *testing.T) { fileInfo := FileInfo{ Name: "%e6%96%87%e4%bb%b6%e5%90%8d.jpg", ParentReference: parentReference{ - Path: "/drive/root:/123/321", + Path: "/drive/root:/123/32%201", }, } - asserts.Equal("123/321/文件名.jpg", fileInfo.GetSourcePath()) + asserts.Equal("123/32 1/%e6%96%87%e4%bb%b6%e5%90%8d.jpg", fileInfo.GetSourcePath()) } // 失败 { fileInfo := FileInfo{ - Name: "%e6%96%87%e4%bb%b6%e5%90%8g.jpg", + Name: "123.jpg", ParentReference: parentReference{ - Path: "/drive/root:/123/321", + Path: "/drive/root:/123/%e6%96%87%e4%bb%b6%e5%90%8g", }, } asserts.Equal("", fileInfo.GetSourcePath()) diff --git a/pkg/filesystem/driver/onedrive/handler.go b/pkg/filesystem/driver/onedrive/handler.go index 98b2ba7..389ede2 100644 --- a/pkg/filesystem/driver/onedrive/handler.go +++ b/pkg/filesystem/driver/onedrive/handler.go @@ -11,7 +11,6 @@ import ( "time" model "github.com/cloudreve/Cloudreve/v3/models" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver" "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" @@ -171,19 +170,6 @@ func (handler Driver) Source( cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path) if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok { cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID) - // 如果是永久链接,则返回签名后的中转外链 - if ttl == 0 { - signedURI, err := auth.SignURI( - auth.General, - fmt.Sprintf("/api/v3/file/source/%d/%s", file.ID, file.Name), - ttl, - ) - if err != nil { - return "", err - } - return baseURL.ResolveReference(signedURI).String(), nil - } - } // 尝试从缓存中查找 diff --git a/pkg/filesystem/driver/onedrive/handler_test.go b/pkg/filesystem/driver/onedrive/handler_test.go index 7700e7a..c63be86 100644 --- a/pkg/filesystem/driver/onedrive/handler_test.go +++ b/pkg/filesystem/driver/onedrive/handler_test.go @@ -3,7 +3,6 @@ package onedrive import ( "context" "fmt" - "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/mq" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" "github.com/jinzhu/gorm" @@ -161,21 +160,6 @@ func TestDriver_Source(t *testing.T) { asserts.NoError(err) asserts.Equal("123321", res) } - - // 成功 永久直链 - { - file := model.File{} - file.ID = 1 - file.Name = "123.jpg" - file.UpdatedAt = time.Now() - ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file) - handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix() - auth.General = auth.HMACAuth{} - handler.Client.Credential.AccessToken = "1" - res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0) - asserts.NoError(err) - asserts.Contains(res, "/api/v3/file/source/1/123.jpg?sign") - } } func TestDriver_List(t *testing.T) { diff --git a/pkg/filesystem/driver/onedrive/types.go b/pkg/filesystem/driver/onedrive/types.go index 2a4307f..2a2ea4c 100644 --- a/pkg/filesystem/driver/onedrive/types.go +++ b/pkg/filesystem/driver/onedrive/types.go @@ -133,3 +133,8 @@ type Site struct { func init() { gob.Register(Credential{}) } + +// Error 实现error接口 +func (err RespError) Error() string { + return err.APIError.Message +} diff --git a/pkg/filesystem/driver/s3/handler.go b/pkg/filesystem/driver/s3/handler.go index a6d17bf..cc2d1de 100644 --- a/pkg/filesystem/driver/s3/handler.go +++ b/pkg/filesystem/driver/s3/handler.go @@ -398,8 +398,8 @@ func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *seri // Meta 获取文件信息 func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) { - res, err := handler.svc.GetObject( - &s3.GetObjectInput{ + res, err := handler.svc.HeadObject( + &s3.HeadObjectInput{ Bucket: &handler.Policy.BucketName, Key: &path, }) diff --git a/pkg/filesystem/errors.go b/pkg/filesystem/errors.go index 303d0d5..d267038 100644 --- a/pkg/filesystem/errors.go +++ b/pkg/filesystem/errors.go @@ -8,17 +8,17 @@ import ( var ( ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil) - ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "", nil) - ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "", nil) - ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "", nil) - ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "", nil) + ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil) + ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil) + ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil) + ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil) ErrClientCanceled = errors.New("Client canceled operation") - ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "", nil) + ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil) ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil) - ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "", nil) - ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "", nil) - ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "", nil) - ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "", nil) + ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil) + ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil) + ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil) + ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil) ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil) ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil) ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil) diff --git a/pkg/filesystem/manage_test.go b/pkg/filesystem/manage_test.go index 1f018bd..2ec0aec 100644 --- a/pkg/filesystem/manage_test.go +++ b/pkg/filesystem/manage_test.go @@ -472,6 +472,9 @@ func TestFileSystem_Delete(t *testing.T) { AddRow(4, "1.txt", "1.txt", 365, 1), ) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 365, 2)) + // 两次查询软连接 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) mock.ExpectQuery("SELECT(.+)files(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) // 查询上传策略 @@ -527,6 +530,9 @@ func TestFileSystem_Delete(t *testing.T) { AddRow(4, "1.txt", "1.txt", 602, 1), ) mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "name", "source_name", "policy_id", "size"}).AddRow(1, "2.txt", "2.txt", 602, 2)) + // 两次查询软连接 + mock.ExpectQuery("SELECT(.+)files(.+)"). + WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) mock.ExpectQuery("SELECT(.+)files(.+)"). WillReturnRows(sqlmock.NewRows([]string{"id", "policy_id", "source_name"})) // 查询上传策略 diff --git a/pkg/hashid/hash.go b/pkg/hashid/hash.go index 942c953..ffe5944 100644 --- a/pkg/hashid/hash.go +++ b/pkg/hashid/hash.go @@ -15,11 +15,12 @@ const ( FolderID // 目录ID TagID // 标签ID PolicyID // 存储策略ID + SourceLinkID ) var ( // ErrTypeNotMatch ID类型不匹配 - ErrTypeNotMatch = errors.New("ID类型不匹配") + ErrTypeNotMatch = errors.New("mismatched ID type.") ) // HashEncode 对给定数据计算HashID diff --git a/pkg/request/options.go b/pkg/request/options.go index dc0391e..63bc8dd 100644 --- a/pkg/request/options.go +++ b/pkg/request/options.go @@ -44,6 +44,12 @@ func newDefaultOption() *options { } } +func (o *options) clone() options { + newOptions := *o + newOptions.header = o.header.Clone() + return newOptions +} + // WithTimeout 设置请求超时 func WithTimeout(t time.Duration) Option { return optionFunc(func(o *options) { diff --git a/pkg/request/request.go b/pkg/request/request.go index 6ee78bc..2947085 100644 --- a/pkg/request/request.go +++ b/pkg/request/request.go @@ -56,7 +56,7 @@ func NewClient(opts ...Option) Client { func (c *HTTPClient) Request(method, target string, body io.Reader, opts ...Option) *Response { // 应用额外设置 c.mu.Lock() - options := *c.options + options := c.options.clone() c.mu.Unlock() for _, o := range opts { o.apply(&options) @@ -179,7 +179,7 @@ func (resp *Response) DecodeResponse() (*serializer.Response, error) { var res serializer.Response err = json.Unmarshal([]byte(respString), &res) if err != nil { - util.Log().Debug("无法解析回调服务端响应:%s", string(respString)) + util.Log().Debug("Failed to parse response: %s", string(respString)) return nil, err } return &res, nil @@ -251,7 +251,7 @@ func (instance NopRSCloser) Seek(offset int64, whence int) (int64, error) { return instance.status.Size, nil } } - return 0, errors.New("未实现") + return 0, errors.New("not implemented") } diff --git a/pkg/serializer/aria2.go b/pkg/serializer/aria2.go index 1d6d3c6..890b2b9 100644 --- a/pkg/serializer/aria2.go +++ b/pkg/serializer/aria2.go @@ -19,6 +19,7 @@ type DownloadListResponse struct { Downloaded uint64 `json:"downloaded"` Speed int `json:"speed"` Info rpc.StatusInfo `json:"info"` + NodeName string `json:"node"` } // FinishedListResponse 已完成任务条目 @@ -34,6 +35,7 @@ type FinishedListResponse struct { TaskError string `json:"task_error"` CreateTime time.Time `json:"create"` UpdateTime time.Time `json:"update"` + NodeName string `json:"node"` } // BuildFinishedListResponse 构建已完成任务条目 @@ -62,6 +64,7 @@ func BuildFinishedListResponse(tasks []model.Download) Response { TaskStatus: -1, UpdateTime: tasks[i].UpdatedAt, CreateTime: tasks[i].CreatedAt, + NodeName: tasks[i].NodeName, } if tasks[i].Task != nil { @@ -106,6 +109,7 @@ func BuildDownloadingResponse(tasks []model.Download, intervals map[uint]int) Re Downloaded: tasks[i].DownloadedSize, Speed: tasks[i].Speed, Info: tasks[i].StatusInfo, + NodeName: tasks[i].NodeName, }) } diff --git a/pkg/serializer/error.go b/pkg/serializer/error.go index b18d6de..326c0d8 100644 --- a/pkg/serializer/error.go +++ b/pkg/serializer/error.go @@ -221,7 +221,7 @@ const ( // DBErr 数据库操作失败 func DBErr(msg string, err error) Response { if msg == "" { - msg = "数据库操作失败" + msg = "Database operation failed." } return Err(CodeDBError, msg, err) } @@ -229,7 +229,7 @@ func DBErr(msg string, err error) Response { // ParamErr 各种参数错误 func ParamErr(msg string, err error) Response { if msg == "" { - msg = "参数错误" + msg = "Invalid parameters." } return Err(CodeParamErr, msg, err) } diff --git a/pkg/serializer/response.go b/pkg/serializer/response.go index 91aae47..ecfaec2 100644 --- a/pkg/serializer/response.go +++ b/pkg/serializer/response.go @@ -19,7 +19,7 @@ func NewResponseWithGobData(data interface{}) Response { var w bytes.Buffer encoder := gob.NewEncoder(&w) if err := encoder.Encode(data); err != nil { - return Err(CodeInternalSetting, "无法编码返回结果", err) + return Err(CodeInternalSetting, "Failed to encode response content", err) } return Response{Data: w.Bytes()} diff --git a/pkg/serializer/setting.go b/pkg/serializer/setting.go index d3b2395..2c1a345 100644 --- a/pkg/serializer/setting.go +++ b/pkg/serializer/setting.go @@ -22,6 +22,7 @@ type SiteConfig struct { CaptchaType string `json:"captcha_type"` TCaptchaCaptchaAppId string `json:"tcaptcha_captcha_app_id"` RegisterEnabled bool `json:"registerEnabled"` + AppPromotion bool `json:"app_promotion"` } type task struct { @@ -83,6 +84,7 @@ func BuildSiteConfig(settings map[string]string, user *model.User) Response { CaptchaType: checkSettingValue(settings, "captcha_type"), TCaptchaCaptchaAppId: checkSettingValue(settings, "captcha_TCaptcha_CaptchaAppId"), RegisterEnabled: model.IsTrueVal(checkSettingValue(settings, "register_enabled")), + AppPromotion: model.IsTrueVal(checkSettingValue(settings, "show_app_promotion")), }} return res } diff --git a/pkg/serializer/slave_test.go b/pkg/serializer/slave_test.go index add3a63..46b5d2d 100644 --- a/pkg/serializer/slave_test.go +++ b/pkg/serializer/slave_test.go @@ -19,14 +19,3 @@ func TestSlaveTransferReq_Hash(t *testing.T) { } a.NotEqual(s1.Hash("1"), s2.Hash("1")) } - -func TestSlaveRecycleReq_Hash(t *testing.T) { - a := assert.New(t) - s1 := &SlaveRecycleReq{ - Path: "1", - } - s2 := &SlaveRecycleReq{ - Path: "2", - } - a.NotEqual(s1.Hash("1"), s2.Hash("1")) -} diff --git a/pkg/serializer/user.go b/pkg/serializer/user.go index 2df80bd..68e9940 100644 --- a/pkg/serializer/user.go +++ b/pkg/serializer/user.go @@ -13,7 +13,7 @@ import ( func CheckLogin() Response { return Response{ Code: CodeCheckLogin, - Msg: "未登录", + Msg: "Login required", } } diff --git a/pkg/task/compress.go b/pkg/task/compress.go index 5986e26..5e20a36 100644 --- a/pkg/task/compress.go +++ b/pkg/task/compress.go @@ -69,7 +69,7 @@ func (job *CompressTask) SetError(err *JobError) { func (job *CompressTask) removeZipFile() { if job.zipPath != "" { if err := os.Remove(job.zipPath); err != nil { - util.Log().Warning("无法删除临时压缩文件 %s , %s", job.zipPath, err) + util.Log().Warning("Failed to delete temp zip file %q: %s", job.zipPath, err) } } } @@ -93,7 +93,7 @@ func (job *CompressTask) Do() { return } - util.Log().Debug("开始压缩文件") + util.Log().Debug("Starting compress file...") job.TaskModel.SetProgress(CompressingProgress) // 创建临时压缩文件 @@ -122,7 +122,7 @@ func (job *CompressTask) Do() { job.zipPath = zipFilePath zipFile.Close() - util.Log().Debug("压缩文件存放至%s,开始上传", zipFilePath) + util.Log().Debug("Compressed file saved to %q, start uploading it...", zipFilePath) job.TaskModel.SetProgress(TransferringProgress) // 上传文件 diff --git a/pkg/task/decompress.go b/pkg/task/decompress.go index 0db2ec5..9c6d88e 100644 --- a/pkg/task/decompress.go +++ b/pkg/task/decompress.go @@ -77,7 +77,7 @@ func (job *DecompressTask) Do() { // 创建文件系统 fs, err := filesystem.NewFileSystem(job.User) if err != nil { - job.SetErrorMsg("无法创建文件系统", err) + job.SetErrorMsg("Failed to create filesystem.", err) return } @@ -85,7 +85,7 @@ func (job *DecompressTask) Do() { err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst, job.TaskProps.Encoding) if err != nil { - job.SetErrorMsg("解压缩失败", err) + job.SetErrorMsg("Failed to decompress file.", err) return } diff --git a/pkg/task/errors.go b/pkg/task/errors.go index ad9df0c..f1fca16 100644 --- a/pkg/task/errors.go +++ b/pkg/task/errors.go @@ -4,5 +4,5 @@ import "errors" var ( // ErrUnknownTaskType 未知任务类型 - ErrUnknownTaskType = errors.New("未知任务类型") + ErrUnknownTaskType = errors.New("unknown task type") ) diff --git a/pkg/task/import.go b/pkg/task/import.go index efc32e9..607b4d1 100644 --- a/pkg/task/import.go +++ b/pkg/task/import.go @@ -81,7 +81,7 @@ func (job *ImportTask) Do() { // 查找存储策略 policy, err := model.GetPolicyByID(job.TaskProps.PolicyID) if err != nil { - job.SetErrorMsg("找不到存储策略", err) + job.SetErrorMsg("Policy not exist.", err) return } @@ -96,7 +96,7 @@ func (job *ImportTask) Do() { fs.Policy = &policy if err := fs.DispatchHandler(); err != nil { - job.SetErrorMsg("无法分发存储策略", err) + job.SetErrorMsg("Failed to dispatch policy.", err) return } @@ -110,7 +110,7 @@ func (job *ImportTask) Do() { true) objects, err := fs.Handler.List(ctx, job.TaskProps.Src, job.TaskProps.Recursive) if err != nil { - job.SetErrorMsg("无法列取文件", err) + job.SetErrorMsg("Failed to list files.", err) return } @@ -126,7 +126,7 @@ func (job *ImportTask) Do() { virtualPath := path.Join(job.TaskProps.Dst, object.RelativePath) folder, err := fs.CreateDirectory(coxIgnoreConflict, virtualPath) if err != nil { - util.Log().Warning("导入任务无法创建用户目录[%s], %s", virtualPath, err) + util.Log().Warning("Importing task cannot create user directory %q: %s", virtualPath, err) } else if folder.ID > 0 { pathCache[virtualPath] = folder } @@ -152,7 +152,7 @@ func (job *ImportTask) Do() { } else { folder, err := fs.CreateDirectory(context.Background(), virtualPath) if err != nil { - util.Log().Warning("导入任务无法创建用户目录[%s], %s", + util.Log().Warning("Importing task cannot create user directory %q: %s", virtualPath, err) continue } @@ -163,10 +163,10 @@ func (job *ImportTask) Do() { // 插入文件记录 _, err := fs.AddFile(context.Background(), parentFolder, &fileHeader) if err != nil { - util.Log().Warning("导入任务无法创插入文件[%s], %s", + util.Log().Warning("Importing task cannot insert user file %q: %s", object.RelativePath, err) if err == filesystem.ErrInsufficientCapacity { - job.SetErrorMsg("容量不足", err) + job.SetErrorMsg("Insufficient storage capacity.", err) return } } diff --git a/pkg/task/job.go b/pkg/task/job.go index e9d54d8..d480492 100644 --- a/pkg/task/job.go +++ b/pkg/task/job.go @@ -89,12 +89,12 @@ func Resume(p Pool) { if len(tasks) == 0 { return } - util.Log().Info("从数据库中恢复 %d 个未完成任务", len(tasks)) + util.Log().Info("Resume %d unfinished task(s) from database.", len(tasks)) for i := 0; i < len(tasks); i++ { job, err := GetJobFromModel(&tasks[i]) if err != nil { - util.Log().Warning("无法恢复任务,%s", err) + util.Log().Warning("Failed to resume task: %s", err) continue } diff --git a/pkg/task/pool.go b/pkg/task/pool.go index 53e94a5..e37f179 100644 --- a/pkg/task/pool.go +++ b/pkg/task/pool.go @@ -44,11 +44,11 @@ func (pool *AsyncPool) freeWorker() { // Submit 开始提交任务 func (pool *AsyncPool) Submit(job Job) { go func() { - util.Log().Debug("等待获取Worker") + util.Log().Debug("Waiting for Worker.") worker := pool.obtainWorker() - util.Log().Debug("获取到Worker") + util.Log().Debug("Worker obtained.") worker.Do(job) - util.Log().Debug("释放Worker") + util.Log().Debug("Worker released.") pool.freeWorker() }() } @@ -60,7 +60,7 @@ func Init() { idleWorker: make(chan int, maxWorker), } TaskPoll.Add(maxWorker) - util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker) + util.Log().Info("Initialize task queue with WorkerNum = %d", maxWorker) if conf.SystemConfig.Mode == "master" { Resume(TaskPoll) diff --git a/pkg/task/recycle.go b/pkg/task/recycle.go index 17eaf3c..60cc97f 100644 --- a/pkg/task/recycle.go +++ b/pkg/task/recycle.go @@ -73,21 +73,21 @@ func (job *RecycleTask) GetError() *JobError { func (job *RecycleTask) Do() { download, err := model.GetDownloadByGid(job.TaskProps.DownloadGID, job.User.ID) if err != nil { - util.Log().Warning("回收任务 %d 找不到下载记录", job.TaskModel.ID) - job.SetErrorMsg("无法找到下载任务", err) + util.Log().Warning("Recycle task %d cannot found download record.", job.TaskModel.ID) + job.SetErrorMsg("Cannot found download task.", err) return } nodeID := download.GetNodeID() node := cluster.Default.GetNodeByID(nodeID) if node == nil { - util.Log().Warning("回收任务 %d 找不到节点", job.TaskModel.ID) - job.SetErrorMsg("从机节点不可用", nil) + util.Log().Warning("Recycle task %d cannot found node.", job.TaskModel.ID) + job.SetErrorMsg("Invalid slave node.", nil) return } err = node.GetAria2Instance().DeleteTempFile(download) if err != nil { - util.Log().Warning("无法删除中转临时目录[%s], %s", download.Parent, err) - job.SetErrorMsg("文件回收失败", err) + util.Log().Warning("Failed to delete transfer temp folder %q: %s", download.Parent, err) + job.SetErrorMsg("Failed to recycle files.", err) return } } diff --git a/pkg/task/slavetask/transfer.go b/pkg/task/slavetask/transfer.go index 818028e..bdc5926 100644 --- a/pkg/task/slavetask/transfer.go +++ b/pkg/task/slavetask/transfer.go @@ -69,7 +69,7 @@ func (job *TransferTask) SetErrorMsg(msg string, err error) { } if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), notifyMsg); err != nil { - util.Log().Warning("无法发送转存失败通知到从机, %s", err) + util.Log().Warning("Failed to send transfer failure notification to master node: %s", err) } } @@ -82,26 +82,26 @@ func (job *TransferTask) GetError() *task.JobError { func (job *TransferTask) Do() { fs, err := filesystem.NewAnonymousFileSystem() if err != nil { - job.SetErrorMsg("无法初始化匿名文件系统", err) + job.SetErrorMsg("Failed to initialize anonymous filesystem.", err) return } fs.Policy = job.Req.Policy if err := fs.DispatchHandler(); err != nil { - job.SetErrorMsg("无法分发存储策略", err) + job.SetErrorMsg("Failed to dispatch policy.", err) return } master, err := cluster.DefaultController.GetMasterInfo(job.MasterID) if err != nil { - job.SetErrorMsg("找不到主机节点", err) + job.SetErrorMsg("Cannot found master node ID.", err) return } fs.SwitchToShadowHandler(master.Instance, master.URL.String(), master.ID) file, err := os.Open(util.RelativePath(job.Req.Src)) if err != nil { - job.SetErrorMsg("无法读取源文件", err) + job.SetErrorMsg("Failed to read source file.", err) return } @@ -110,7 +110,7 @@ func (job *TransferTask) Do() { // 获取源文件大小 fi, err := file.Stat() if err != nil { - job.SetErrorMsg("无法获取源文件大小", err) + job.SetErrorMsg("Failed to get source file size.", err) return } @@ -122,7 +122,7 @@ func (job *TransferTask) Do() { Size: uint64(size), }) if err != nil { - job.SetErrorMsg("文件上传失败", err) + job.SetErrorMsg("Upload failed.", err) return } @@ -133,6 +133,6 @@ func (job *TransferTask) Do() { } if err := cluster.DefaultController.SendNotification(job.MasterID, job.Req.Hash(job.MasterID), msg); err != nil { - util.Log().Warning("无法发送转存成功通知到从机, %s", err) + util.Log().Warning("Failed to send transfer success notification to master node: %s", err) } } diff --git a/pkg/task/tranfer.go b/pkg/task/tranfer.go index f115e80..54bba47 100644 --- a/pkg/task/tranfer.go +++ b/pkg/task/tranfer.go @@ -3,6 +3,7 @@ package task import ( "context" "encoding/json" + "fmt" "path" "path/filepath" "strings" @@ -94,6 +95,7 @@ func (job *TransferTask) Do() { } successCount := 0 + errorList := make([]string, 0, len(job.TaskProps.Src)) for _, file := range job.TaskProps.Src { dst := path.Join(job.TaskProps.Dst, filepath.Base(file)) if job.TaskProps.TrimPath { @@ -109,7 +111,7 @@ func (job *TransferTask) Do() { // 获取从机节点 node := cluster.Default.GetNodeByID(job.TaskProps.NodeID) if node == nil { - job.SetErrorMsg("从机节点不可用", nil) + job.SetErrorMsg("Invalid slave node.", nil) } // 切换为从机节点处理上传 @@ -127,13 +129,17 @@ func (job *TransferTask) Do() { } if err != nil { - job.SetErrorMsg("文件转存失败", err) + errorList = append(errorList, err.Error()) } else { successCount++ job.TaskModel.SetProgress(successCount) } } + if len(errorList) > 0 { + job.SetErrorMsg("Failed to transfer one or more file(s).", fmt.Errorf(strings.Join(errorList, "\n"))) + } + } // NewTransferTask 新建中转任务 diff --git a/pkg/task/worker.go b/pkg/task/worker.go index 3e01f17..e40a3b5 100644 --- a/pkg/task/worker.go +++ b/pkg/task/worker.go @@ -16,14 +16,14 @@ type GeneralWorker struct { // Do 执行任务 func (worker *GeneralWorker) Do(job Job) { - util.Log().Debug("开始执行任务") + util.Log().Debug("Start executing task.") job.SetStatus(Processing) defer func() { // 致命错误捕获 if err := recover(); err != nil { - util.Log().Debug("任务执行出错,%s", err) - job.SetError(&JobError{Msg: "致命错误", Error: fmt.Sprintf("%s", err)}) + util.Log().Debug("Failed to execute task: %s", err) + job.SetError(&JobError{Msg: "Fatal error.", Error: fmt.Sprintf("%s", err)}) job.SetStatus(Error) } }() @@ -33,12 +33,12 @@ func (worker *GeneralWorker) Do(job Job) { // 任务执行失败 if err := job.GetError(); err != nil { - util.Log().Debug("任务执行出错") + util.Log().Debug("Failed to execute task.") job.SetStatus(Error) return } - util.Log().Debug("任务执行完成") + util.Log().Debug("Task finished.") // 执行完成 job.SetStatus(Complete) } diff --git a/pkg/thumb/image.go b/pkg/thumb/image.go index 69c73a3..cf851c3 100644 --- a/pkg/thumb/image.go +++ b/pkg/thumb/image.go @@ -45,7 +45,7 @@ func NewThumbFromFile(file io.Reader, name string) (*Thumb, error) { case "png": img, err = png.Decode(file) default: - return nil, errors.New("未知的图像类型") + return nil, errors.New("unknown image format") } if err != nil { return nil, err diff --git a/pkg/util/io.go b/pkg/util/io.go index 25b9dc9..fe3bd9a 100644 --- a/pkg/util/io.go +++ b/pkg/util/io.go @@ -22,7 +22,7 @@ func CreatNestedFile(path string) (*os.File, error) { if !Exists(basePath) { err := os.MkdirAll(basePath, 0700) if err != nil { - Log().Warning("无法创建目录,%s", err) + Log().Warning("Failed to create directory: %s", err) return nil, err } } diff --git a/routers/controllers/file.go b/routers/controllers/file.go index c660f88..8caadc2 100644 --- a/routers/controllers/file.go +++ b/routers/controllers/file.go @@ -79,8 +79,8 @@ func AnonymousGetContent(c *gin.Context) { } } -// AnonymousPermLink 文件签名后的永久链接 -func AnonymousPermLink(c *gin.Context) { +// AnonymousPermLink Deprecated 文件签名后的永久链接 +func AnonymousPermLinkDeprecated(c *gin.Context) { // 创建上下文 ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -102,6 +102,39 @@ func AnonymousPermLink(c *gin.Context) { } } +// AnonymousPermLink 文件中转后的永久直链接 +func AnonymousPermLink(c *gin.Context) { + // 创建上下文 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sourceLinkRaw, ok := c.Get("source_link") + if !ok { + c.JSON(200, serializer.Err(serializer.CodeFileNotFound, "", nil)) + return + } + + sourceLink := sourceLinkRaw.(*model.SourceLink) + + service := &explorer.FileAnonymousGetService{ + ID: sourceLink.FileID, + Name: sourceLink.File.Name, + } + + res := service.Source(ctx, c) + // 是否需要重定向 + if res.Code == -302 { + c.Redirect(302, res.Data.(string)) + return + } + + // 是否有错误发生 + if res.Code != 0 { + c.JSON(200, res) + } + +} + func GetSource(c *gin.Context) { // 创建上下文 ctx, cancel := context.WithCancel(context.Background()) diff --git a/routers/controllers/site.go b/routers/controllers/site.go index 414ebb5..d462a9d 100644 --- a/routers/controllers/site.go +++ b/routers/controllers/site.go @@ -27,6 +27,7 @@ func SiteConfig(c *gin.Context) { "captcha_type", "captcha_TCaptcha_CaptchaAppId", "register_enabled", + "show_app_promotion", ) // 如果已登录,则同时返回用户信息和标签 diff --git a/routers/router.go b/routers/router.go index 0727fe6..e6d9ba1 100644 --- a/routers/router.go +++ b/routers/router.go @@ -16,10 +16,10 @@ import ( // InitRouter 初始化路由 func InitRouter() *gin.Engine { if conf.SystemConfig.Mode == "master" { - util.Log().Info("当前运行模式:Master") + util.Log().Info("Current running mode: Master.") return InitMasterRouter() } - util.Log().Info("当前运行模式:Slave") + util.Log().Info("Current running mode: Slave.") return InitSlaveRouter() } @@ -108,7 +108,7 @@ func InitCORS(router *gin.Engine) { // slave模式下未启动跨域的警告 if conf.SystemConfig.Mode == "slave" { - util.Log().Warning("当前作为存储端(Slave)运行,但未启用跨域配置,可能会导致 Master 端无法正常上传文件") + util.Log().Warning("You are running Cloudreve as slave node, if you are using slave storage policy, please enable CORS feature in config file, otherwise file cannot be uploaded from Master site.") } } @@ -145,6 +145,15 @@ func InitMasterRouter() *gin.Engine { 路由 */ { + // Redirect file source link + source := r.Group("f") + { + source.GET(":id/:name", + middleware.HashID(hashid.SourceLinkID), + middleware.ValidateSourceLink(), + controllers.AnonymousPermLink) + } + // 全局设置相关 site := v3.Group("site") { @@ -197,6 +206,7 @@ func InitMasterRouter() *gin.Engine { // 获取用户头像 user.GET("avatar/:id/:size", middleware.HashID(hashid.UserID), + middleware.StaticResourceCache(), controllers.GetUserAvatar, ) } @@ -208,11 +218,18 @@ func InitMasterRouter() *gin.Engine { file := sign.Group("file") { // 文件外链(直接输出文件数据) - file.GET("get/:id/:name", controllers.AnonymousGetContent) + file.GET("get/:id/:name", + middleware.Sandbox(), + middleware.StaticResourceCache(), + controllers.AnonymousGetContent, + ) // 文件外链(301跳转) - file.GET("source/:id/:name", controllers.AnonymousPermLink) + file.GET("source/:id/:name", controllers.AnonymousPermLinkDeprecated) // 下载文件 - file.GET("download/:id", controllers.Download) + file.GET("download/:id", + middleware.StaticResourceCache(), + controllers.Download, + ) // 打包并下载文件 file.GET("archive/:sessionID/archive.zip", controllers.DownloadArchive) } @@ -445,7 +462,7 @@ func InitMasterRouter() *gin.Engine { // 列出文件 file.POST("list", controllers.AdminListFile) // 预览文件 - file.GET("preview/:id", controllers.AdminGetFile) + file.GET("preview/:id", middleware.Sandbox(), controllers.AdminGetFile) // 删除 file.POST("delete", controllers.AdminDeleteFile) // 列出用户或外部文件系统目录 @@ -555,9 +572,9 @@ func InitMasterRouter() *gin.Engine { // 创建文件下载会话 file.PUT("download/:id", controllers.CreateDownloadSession) // 预览文件 - file.GET("preview/:id", controllers.Preview) + file.GET("preview/:id", middleware.Sandbox(), controllers.Preview) // 获取文本文件内容 - file.GET("content/:id", controllers.PreviewText) + file.GET("content/:id", middleware.Sandbox(), controllers.PreviewText) // 取得Office文档预览地址 file.GET("doc/:id", controllers.GetDocPreview) // 获取缩略图 diff --git a/service/admin/policy.go b/service/admin/policy.go index 3207a7c..abfc9da 100644 --- a/service/admin/policy.go +++ b/service/admin/policy.go @@ -318,12 +318,20 @@ func (service *AdminListService) Policies() serializer.Response { // 统计每个策略的文件使用 statics := make(map[uint][2]int, len(res)) + policyIds := make([]uint, 0, len(res)) for i := 0; i < len(res); i++ { + policyIds = append(policyIds, res[i].ID) + } + + rows, _ := model.DB.Model(&model.File{}).Where("policy_id in (?)", policyIds). + Select("policy_id,count(id),sum(size)").Group("policy_id").Rows() + + for rows.Next() { + policyId := uint(0) total := [2]int{} - row := model.DB.Model(&model.File{}).Where("policy_id = ?", res[i].ID). - Select("count(id),sum(size)").Row() - row.Scan(&total[0], &total[1]) - statics[res[i].ID] = total + rows.Scan(&policyId, &total[0], &total[1]) + + statics[policyId] = total } return serializer.Response{Data: map[string]interface{}{ diff --git a/service/admin/user.go b/service/admin/user.go index 32af8b3..9ade2ed 100644 --- a/service/admin/user.go +++ b/service/admin/user.go @@ -109,6 +109,7 @@ func (service *AddUserService) Add() serializer.Response { user.Email = service.User.Email user.GroupID = service.User.GroupID user.Status = service.User.Status + user.TwoFactor = service.User.TwoFactor // 检查愚蠢操作 if user.ID == 1 && user.GroupID != 1 { diff --git a/service/aria2/manage.go b/service/aria2/manage.go index 115a440..35ccdff 100644 --- a/service/aria2/manage.go +++ b/service/aria2/manage.go @@ -27,6 +27,13 @@ type DownloadListService struct { func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { // 查找下载记录 downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Error, common.Complete, common.Canceled, common.Unknown) + for key, download := range downloads { + node := cluster.Default.GetNodeByID(download.GetNodeID()) + if node != nil { + downloads[key].NodeName = node.DBModel().Name + } + } + return serializer.BuildFinishedListResponse(downloads) } @@ -35,12 +42,17 @@ func (service *DownloadListService) Downloading(c *gin.Context, user *model.User // 查找下载记录 downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, common.Downloading, common.Seeding, common.Paused, common.Ready) intervals := make(map[uint]int) - for _, download := range downloads { + for key, download := range downloads { if _, ok := intervals[download.ID]; !ok { if node := cluster.Default.GetNodeByID(download.GetNodeID()); node != nil { intervals[download.ID] = node.DBModel().Aria2OptionsSerialized.Interval } } + + node := cluster.Default.GetNodeByID(download.GetNodeID()) + if node != nil { + downloads[key].NodeName = node.DBModel().Name + } } return serializer.BuildDownloadingResponse(downloads, intervals) diff --git a/service/callback/upload.go b/service/callback/upload.go index 25390bc..0dd7924 100644 --- a/service/callback/upload.go +++ b/service/callback/upload.go @@ -175,7 +175,7 @@ func (service *OneDriveCallback) PreProcess(c *gin.Context) serializer.Response // SharePoint 会对 Office 文档增加 meta data 导致文件大小不一致,这里增加 1 MB 宽容 // See: https://github.com/OneDrive/onedrive-api-docs/issues/935 - if strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) { + if (strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.com") || strings.Contains(fs.Policy.OptionsSerialized.OdDriver, "sharepoint.cn")) && isSizeCheckFailed && (info.Size > uploadSession.Size) && (info.Size-uploadSession.Size <= 1048576) { isSizeCheckFailed = false } @@ -239,7 +239,7 @@ func (service *S3Callback) PreProcess(c *gin.Context) serializer.Response { return ProcessCallback(service, c) } -// PreProcess 对OneDrive客户端回调进行预处理验证 +// PreProcess 对从机客户端回调进行预处理验证 func (service *UploadCallbackService) PreProcess(c *gin.Context) serializer.Response { // 创建文件系统 fs, err := filesystem.NewFileSystemFromCallback(c) diff --git a/service/explorer/file.go b/service/explorer/file.go index aea0fbf..246d4a7 100644 --- a/service/explorer/file.go +++ b/service/explorer/file.go @@ -178,12 +178,13 @@ func (service *FileAnonymousGetService) Source(ctx context.Context, c *gin.Conte } // 获取文件流 - res, err := fs.SignURL(ctx, &fs.FileTarget[0], - int64(model.GetIntSetting("preview_timeout", 60)), false) + ttl := int64(model.GetIntSetting("preview_timeout", 60)) + res, err := fs.SignURL(ctx, &fs.FileTarget[0], ttl, false) if err != nil { return serializer.Err(serializer.CodeNotSet, err.Error(), err) } + c.Header("Cache-Control", fmt.Sprintf("max-age=%d", ttl)) return serializer.Response{ Code: -302, Data: res, @@ -442,24 +443,48 @@ func (s *ItemIDService) Sources(ctx context.Context, c *gin.Context) serializer. } res := make([]serializer.Sources, 0, len(s.Raw().Items)) - for _, id := range s.Raw().Items { - fs.FileTarget = []model.File{} - sourceURL, err := fs.GetSource(ctx, id) - if len(fs.FileTarget) > 0 { - current := serializer.Sources{ - URL: sourceURL, - Name: fs.FileTarget[0].Name, - Parent: fs.FileTarget[0].FolderID, - } + files, err := model.GetFilesByIDs(s.Raw().Items, fs.User.ID) + if err != nil || len(files) == 0 { + return serializer.Err(serializer.CodeFileNotFound, "", err) + } + getSourceFunc := func(file model.File) (string, error) { + fs.FileTarget = []model.File{file} + return fs.GetSource(ctx, file.ID) + } + + // Create redirected source link if needed + if fs.User.Group.OptionsSerialized.RedirectedSource { + getSourceFunc = func(file model.File) (string, error) { + source, err := file.CreateOrGetSourceLink() if err != nil { - current.Error = err.Error() + return "", err } - res = append(res, current) + sourceLinkURL, err := source.Link() + if err != nil { + return "", err + } + + return sourceLinkURL, nil } } + for _, file := range files { + sourceURL, err := getSourceFunc(file) + current := serializer.Sources{ + URL: sourceURL, + Name: file.Name, + Parent: file.FolderID, + } + + if err != nil { + current.Error = err.Error() + } + + res = append(res, current) + } + return serializer.Response{ Code: 0, Data: res,