Feat: migration DB support custom upgrade scripts
This commit is contained in:
parent
96b84bb5e5
commit
9fc08292a0
11 changed files with 196 additions and 14 deletions
|
@ -2,6 +2,7 @@ package bootstrap
|
|||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
|
@ -27,6 +28,12 @@ func Init(path string) {
|
|||
mode string
|
||||
factory func()
|
||||
}{
|
||||
{
|
||||
"both",
|
||||
func() {
|
||||
scripts.Init()
|
||||
},
|
||||
},
|
||||
{
|
||||
"both",
|
||||
func() {
|
||||
|
|
|
@ -2,14 +2,14 @@ package bootstrap
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
func RunScript(name string) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if err := scripts.RunDBScript(name, ctx); err != nil {
|
||||
if err := invoker.RunDBScript(name, ctx); err != nil {
|
||||
util.Log().Error("数据库脚本执行失败: %s", err)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/fatih/color"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/jinzhu/gorm"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 是否需要迁移
|
||||
|
@ -54,6 +59,9 @@ func migration() {
|
|||
// 向设置数据表添加初始设置
|
||||
addDefaultSettings()
|
||||
|
||||
// 执行数据库升级脚本
|
||||
execUpgradeScripts()
|
||||
|
||||
util.Log().Info("数据库初始化结束")
|
||||
|
||||
}
|
||||
|
@ -290,3 +298,17 @@ func addDefaultNode() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func execUpgradeScripts() {
|
||||
s := invoker.ListPrefix("UpgradeTo")
|
||||
versions := make([]*version.Version, len(s))
|
||||
for i, raw := range s {
|
||||
v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo"))
|
||||
versions[i] = v
|
||||
}
|
||||
sort.Sort(version.Collection(versions))
|
||||
|
||||
for i := 0; i < len(versions); i++ {
|
||||
invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background())
|
||||
}
|
||||
}
|
||||
|
|
9
models/scripts/init.go
Normal file
9
models/scripts/init.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package scripts
|
||||
|
||||
import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
|
||||
|
||||
func Init() {
|
||||
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
|
||||
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
|
||||
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
package scripts
|
||||
package invoker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
type DBScript interface {
|
||||
|
@ -13,6 +14,7 @@ var availableScripts = make(map[string]DBScript)
|
|||
|
||||
func RunDBScript(name string, ctx context.Context) error {
|
||||
if script, ok := availableScripts[name]; ok {
|
||||
util.Log().Info("开始执行数据库脚本 [%s]", name)
|
||||
script.Run(ctx)
|
||||
return nil
|
||||
}
|
||||
|
@ -20,6 +22,16 @@ func RunDBScript(name string, ctx context.Context) error {
|
|||
return fmt.Errorf("数据库脚本 [%s] 不存在", name)
|
||||
}
|
||||
|
||||
func register(name string, script DBScript) {
|
||||
func Register(name string, script DBScript) {
|
||||
availableScripts[name] = script
|
||||
}
|
||||
|
||||
func ListPrefix(prefix string) []string {
|
||||
var scripts []string
|
||||
for name := range availableScripts {
|
||||
if name[:len(prefix)] == prefix {
|
||||
scripts = append(scripts, name)
|
||||
}
|
||||
}
|
||||
return scripts
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package scripts
|
||||
package invoker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -35,7 +35,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
func TestRunDBScript(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
register("test", TestScript(0))
|
||||
Register("test", TestScript(0))
|
||||
|
||||
// 不存在
|
||||
{
|
||||
|
@ -47,3 +47,14 @@ func TestRunDBScript(t *testing.T) {
|
|||
asserts.NoError(RunDBScript("test", context.Background()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListPrefix(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
Register("U1", TestScript(0))
|
||||
Register("U2", TestScript(0))
|
||||
Register("U3", TestScript(0))
|
||||
Register("P1", TestScript(0))
|
||||
|
||||
res := ListPrefix("U")
|
||||
asserts.Len(res, 3)
|
||||
}
|
|
@ -9,10 +9,6 @@ import (
|
|||
|
||||
type ResetAdminPassword int
|
||||
|
||||
func init() {
|
||||
register("ResetAdminPassword", ResetAdminPassword(0))
|
||||
}
|
||||
|
||||
// Run 运行脚本从社区版升级至 Pro 版
|
||||
func (script ResetAdminPassword) Run(ctx context.Context) {
|
||||
// 查找用户
|
||||
|
|
|
@ -8,10 +8,6 @@ import (
|
|||
|
||||
type UserStorageCalibration int
|
||||
|
||||
func init() {
|
||||
register("CalibrateUserStorage", UserStorageCalibration(0))
|
||||
}
|
||||
|
||||
type storageResult struct {
|
||||
Total uint64
|
||||
}
|
||||
|
|
|
@ -2,11 +2,31 @@ package scripts
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
var mockDB *gorm.DB
|
||||
|
||||
// TestMain 初始化数据库Mock
|
||||
func TestMain(m *testing.M) {
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
model.DB, _ = gorm.Open("mysql", db)
|
||||
mockDB = model.DB
|
||||
defer db.Close()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestUserStorageCalibration_Run(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
script := UserStorageCalibration(0)
|
||||
|
|
43
models/scripts/upgrade.go
Normal file
43
models/scripts/upgrade.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type UpgradeTo340 int
|
||||
|
||||
// Run upgrade from older version to 3.4.0
|
||||
func (script UpgradeTo340) Run(ctx context.Context) {
|
||||
// 取回老版本 aria2 设定
|
||||
old := model.GetSettingByType([]string{"aria2"})
|
||||
if len(old) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 写入到新版本的节点设定
|
||||
n, err := model.GetNodeByID(1)
|
||||
if err != nil {
|
||||
util.Log().Error("找不到主机节点, %s", err)
|
||||
}
|
||||
|
||||
n.Aria2Enabled = old["aria2_rpcurl"] != ""
|
||||
n.Aria2OptionsSerialized.Options = old["aria2_options"]
|
||||
n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"]
|
||||
|
||||
interval, err := strconv.Atoi(old["aria2_interval"])
|
||||
if err != nil {
|
||||
interval = 10
|
||||
}
|
||||
n.Aria2OptionsSerialized.Interval = interval
|
||||
n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"]
|
||||
n.Aria2OptionsSerialized.Token = old["aria2_token"]
|
||||
if err := model.DB.Save(&n).Error; err != nil {
|
||||
util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err)
|
||||
} else {
|
||||
model.DB.Where("type = ?", "aria2").Delete(model.Setting{})
|
||||
util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式")
|
||||
}
|
||||
}
|
66
models/scripts/upgrade_test.go
Normal file
66
models/scripts/upgrade_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package scripts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUpgradeTo340_Run(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
script := UpgradeTo340(0)
|
||||
|
||||
// skip
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// node not found
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// success
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
|
||||
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
|
||||
AddRow("aria2_interval", "expected_aria2_interval").
|
||||
AddRow("aria2_temp_path", "expected_aria2_temp_path").
|
||||
AddRow("aria2_token", "expected_aria2_token").
|
||||
AddRow("aria2_options", "{}"))
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// failed
|
||||
{
|
||||
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
|
||||
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
|
||||
AddRow("aria2_interval", "expected_aria2_interval").
|
||||
AddRow("aria2_temp_path", "expected_aria2_temp_path").
|
||||
AddRow("aria2_token", "expected_aria2_token").
|
||||
AddRow("aria2_options", "{}"))
|
||||
|
||||
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
|
||||
mock.ExpectRollback()
|
||||
script.Run(context.Background())
|
||||
a.NoError(mock.ExpectationsWereMet())
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue