Feat: migration DB support custom upgrade scripts
This commit is contained in:
parent
96b84bb5e5
commit
9fc08292a0
11 changed files with 196 additions and 14 deletions
|
@ -2,6 +2,7 @@ package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/models/scripts"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
|
@ -27,6 +28,12 @@ func Init(path string) {
|
||||||
mode string
|
mode string
|
||||||
factory func()
|
factory func()
|
||||||
}{
|
}{
|
||||||
|
{
|
||||||
|
"both",
|
||||||
|
func() {
|
||||||
|
scripts.Init()
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"both",
|
"both",
|
||||||
func() {
|
func() {
|
||||||
|
|
|
@ -2,14 +2,14 @@ package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/cloudreve/Cloudreve/v3/models/scripts"
|
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RunScript(name string) {
|
func RunScript(name string) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := scripts.RunDBScript(name, ctx); err != nil {
|
if err := invoker.RunDBScript(name, ctx); err != nil {
|
||||||
util.Log().Error("数据库脚本执行失败: %s", err)
|
util.Log().Error("数据库脚本执行失败: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,17 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 是否需要迁移
|
// 是否需要迁移
|
||||||
|
@ -54,6 +59,9 @@ func migration() {
|
||||||
// 向设置数据表添加初始设置
|
// 向设置数据表添加初始设置
|
||||||
addDefaultSettings()
|
addDefaultSettings()
|
||||||
|
|
||||||
|
// 执行数据库升级脚本
|
||||||
|
execUpgradeScripts()
|
||||||
|
|
||||||
util.Log().Info("数据库初始化结束")
|
util.Log().Info("数据库初始化结束")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -290,3 +298,17 @@ func addDefaultNode() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func execUpgradeScripts() {
|
||||||
|
s := invoker.ListPrefix("UpgradeTo")
|
||||||
|
versions := make([]*version.Version, len(s))
|
||||||
|
for i, raw := range s {
|
||||||
|
v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo"))
|
||||||
|
versions[i] = v
|
||||||
|
}
|
||||||
|
sort.Sort(version.Collection(versions))
|
||||||
|
|
||||||
|
for i := 0; i < len(versions); i++ {
|
||||||
|
invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
9
models/scripts/init.go
Normal file
9
models/scripts/init.go
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
package scripts
|
||||||
|
|
||||||
|
import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||||
|
|
||||||
|
func Init() {
|
||||||
|
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
|
||||||
|
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
|
||||||
|
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
|
||||||
|
}
|
|
@ -1,8 +1,9 @@
|
||||||
package scripts
|
package invoker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DBScript interface {
|
type DBScript interface {
|
||||||
|
@ -13,6 +14,7 @@ var availableScripts = make(map[string]DBScript)
|
||||||
|
|
||||||
func RunDBScript(name string, ctx context.Context) error {
|
func RunDBScript(name string, ctx context.Context) error {
|
||||||
if script, ok := availableScripts[name]; ok {
|
if script, ok := availableScripts[name]; ok {
|
||||||
|
util.Log().Info("开始执行数据库脚本 [%s]", name)
|
||||||
script.Run(ctx)
|
script.Run(ctx)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -20,6 +22,16 @@ func RunDBScript(name string, ctx context.Context) error {
|
||||||
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
|
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func register(name string, script DBScript) {
|
func Register(name string, script DBScript) {
|
||||||
availableScripts[name] = script
|
availableScripts[name] = script
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ListPrefix(prefix string) []string {
|
||||||
|
var scripts []string
|
||||||
|
for name := range availableScripts {
|
||||||
|
if name[:len(prefix)] == prefix {
|
||||||
|
scripts = append(scripts, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scripts
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package scripts
|
package invoker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -35,7 +35,7 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
func TestRunDBScript(t *testing.T) {
|
func TestRunDBScript(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
register("test", TestScript(0))
|
Register("test", TestScript(0))
|
||||||
|
|
||||||
// 不存在
|
// 不存在
|
||||||
{
|
{
|
||||||
|
@ -47,3 +47,14 @@ func TestRunDBScript(t *testing.T) {
|
||||||
asserts.NoError(RunDBScript("test", context.Background()))
|
asserts.NoError(RunDBScript("test", context.Background()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListPrefix(t *testing.T) {
|
||||||
|
asserts := assert.New(t)
|
||||||
|
Register("U1", TestScript(0))
|
||||||
|
Register("U2", TestScript(0))
|
||||||
|
Register("U3", TestScript(0))
|
||||||
|
Register("P1", TestScript(0))
|
||||||
|
|
||||||
|
res := ListPrefix("U")
|
||||||
|
asserts.Len(res, 3)
|
||||||
|
}
|
|
@ -9,10 +9,6 @@ import (
|
||||||
|
|
||||||
type ResetAdminPassword int
|
type ResetAdminPassword int
|
||||||
|
|
||||||
func init() {
|
|
||||||
register("ResetAdminPassword", ResetAdminPassword(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run 运行脚本从社区版升级至 Pro 版
|
// Run 运行脚本从社区版升级至 Pro 版
|
||||||
func (script ResetAdminPassword) Run(ctx context.Context) {
|
func (script ResetAdminPassword) Run(ctx context.Context) {
|
||||||
// 查找用户
|
// 查找用户
|
||||||
|
|
|
@ -8,10 +8,6 @@ import (
|
||||||
|
|
||||||
type UserStorageCalibration int
|
type UserStorageCalibration int
|
||||||
|
|
||||||
func init() {
|
|
||||||
register("CalibrateUserStorage", UserStorageCalibration(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
type storageResult struct {
|
type storageResult struct {
|
||||||
Total uint64
|
Total uint64
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,11 +2,31 @@ package scripts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var mock sqlmock.Sqlmock
|
||||||
|
var mockDB *gorm.DB
|
||||||
|
|
||||||
|
// TestMain 初始化数据库Mock
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
var db *sql.DB
|
||||||
|
var err error
|
||||||
|
db, mock, err = sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
panic("An error was not expected when opening a stub database connection")
|
||||||
|
}
|
||||||
|
model.DB, _ = gorm.Open("mysql", db)
|
||||||
|
mockDB = model.DB
|
||||||
|
defer db.Close()
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
func TestUserStorageCalibration_Run(t *testing.T) {
|
func TestUserStorageCalibration_Run(t *testing.T) {
|
||||||
asserts := assert.New(t)
|
asserts := assert.New(t)
|
||||||
script := UserStorageCalibration(0)
|
script := UserStorageCalibration(0)
|
||||||
|
|
43
models/scripts/upgrade.go
Normal file
43
models/scripts/upgrade.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package scripts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||||
|
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UpgradeTo340 int
|
||||||
|
|
||||||
|
// Run upgrade from older version to 3.4.0
|
||||||
|
func (script UpgradeTo340) Run(ctx context.Context) {
|
||||||
|
// 取回老版本 aria2 设定
|
||||||
|
old := model.GetSettingByType([]string{"aria2"})
|
||||||
|
if len(old) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入到新版本的节点设定
|
||||||
|
n, err := model.GetNodeByID(1)
|
||||||
|
if err != nil {
|
||||||
|
util.Log().Error("找不到主机节点, %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n.Aria2Enabled = old["aria2_rpcurl"] != ""
|
||||||
|
n.Aria2OptionsSerialized.Options = old["aria2_options"]
|
||||||
|
n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"]
|
||||||
|
|
||||||
|
interval, err := strconv.Atoi(old["aria2_interval"])
|
||||||
|
if err != nil {
|
||||||
|
interval = 10
|
||||||
|
}
|
||||||
|
n.Aria2OptionsSerialized.Interval = interval
|
||||||
|
n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"]
|
||||||
|
n.Aria2OptionsSerialized.Token = old["aria2_token"]
|
||||||
|
if err := model.DB.Save(&n).Error; err != nil {
|
||||||
|
util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err)
|
||||||
|
} else {
|
||||||
|
model.DB.Where("type = ?", "aria2").Delete(model.Setting{})
|
||||||
|
util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式")
|
||||||
|
}
|
||||||
|
}
|
66
models/scripts/upgrade_test.go
Normal file
66
models/scripts/upgrade_test.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package scripts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpgradeTo340_Run(t *testing.T) {
|
||||||
|
a := assert.New(t)
|
||||||
|
script := UpgradeTo340(0)
|
||||||
|
|
||||||
|
// skip
|
||||||
|
{
|
||||||
|
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||||
|
script.Run(context.Background())
|
||||||
|
a.NoError(mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
// node not found
|
||||||
|
{
|
||||||
|
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
|
||||||
|
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||||
|
script.Run(context.Background())
|
||||||
|
a.NoError(mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
// success
|
||||||
|
{
|
||||||
|
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
|
||||||
|
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
|
||||||
|
AddRow("aria2_interval", "expected_aria2_interval").
|
||||||
|
AddRow("aria2_temp_path", "expected_aria2_temp_path").
|
||||||
|
AddRow("aria2_token", "expected_aria2_token").
|
||||||
|
AddRow("aria2_options", "{}"))
|
||||||
|
|
||||||
|
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
script.Run(context.Background())
|
||||||
|
a.NoError(mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
// failed
|
||||||
|
{
|
||||||
|
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
|
||||||
|
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
|
||||||
|
AddRow("aria2_interval", "expected_aria2_interval").
|
||||||
|
AddRow("aria2_temp_path", "expected_aria2_temp_path").
|
||||||
|
AddRow("aria2_token", "expected_aria2_token").
|
||||||
|
AddRow("aria2_options", "{}"))
|
||||||
|
|
||||||
|
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||||
|
mock.ExpectRollback()
|
||||||
|
script.Run(context.Background())
|
||||||
|
a.NoError(mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue