Feat: migration DB support custom upgrade scripts

This commit is contained in:
HFO4 2021-11-22 19:53:42 +08:00
parent 96b84bb5e5
commit 9fc08292a0
11 changed files with 196 additions and 14 deletions

View file

@ -2,6 +2,7 @@ package bootstrap
import ( import (
model "github.com/cloudreve/Cloudreve/v3/models" model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/models/scripts"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/aria2"
"github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
@ -27,6 +28,12 @@ func Init(path string) {
mode string mode string
factory func() factory func()
}{ }{
{
"both",
func() {
scripts.Init()
},
},
{ {
"both", "both",
func() { func() {

View file

@ -2,14 +2,14 @@ package bootstrap
import ( import (
"context" "context"
"github.com/cloudreve/Cloudreve/v3/models/scripts" "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
) )
func RunScript(name string) { func RunScript(name string) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
if err := scripts.RunDBScript(name, ctx); err != nil { if err := invoker.RunDBScript(name, ctx); err != nil {
util.Log().Error("数据库脚本执行失败: %s", err) util.Log().Error("数据库脚本执行失败: %s", err)
return return
} }

View file

@ -1,12 +1,17 @@
package model package model
import ( import (
"context"
"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
"github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/hashicorp/go-version"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"sort"
"strings"
) )
// 是否需要迁移 // 是否需要迁移
@ -54,6 +59,9 @@ func migration() {
// 向设置数据表添加初始设置 // 向设置数据表添加初始设置
addDefaultSettings() addDefaultSettings()
// 执行数据库升级脚本
execUpgradeScripts()
util.Log().Info("数据库初始化结束") util.Log().Info("数据库初始化结束")
} }
@ -290,3 +298,17 @@ func addDefaultNode() {
} }
} }
} }
func execUpgradeScripts() {
s := invoker.ListPrefix("UpgradeTo")
versions := make([]*version.Version, len(s))
for i, raw := range s {
v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo"))
versions[i] = v
}
sort.Sort(version.Collection(versions))
for i := 0; i < len(versions); i++ {
invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background())
}
}

9
models/scripts/init.go Normal file
View file

@ -0,0 +1,9 @@
package scripts
import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker"
func Init() {
invoker.Register("ResetAdminPassword", ResetAdminPassword(0))
invoker.Register("CalibrateUserStorage", UserStorageCalibration(0))
invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0))
}

View file

@ -1,8 +1,9 @@
package scripts package invoker
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
) )
type DBScript interface { type DBScript interface {
@ -13,6 +14,7 @@ var availableScripts = make(map[string]DBScript)
func RunDBScript(name string, ctx context.Context) error { func RunDBScript(name string, ctx context.Context) error {
if script, ok := availableScripts[name]; ok { if script, ok := availableScripts[name]; ok {
util.Log().Info("开始执行数据库脚本 [%s]", name)
script.Run(ctx) script.Run(ctx)
return nil return nil
} }
@ -20,6 +22,16 @@ func RunDBScript(name string, ctx context.Context) error {
return fmt.Errorf("数据库脚本 [%s] 不存在", name) return fmt.Errorf("数据库脚本 [%s] 不存在", name)
} }
func register(name string, script DBScript) { func Register(name string, script DBScript) {
availableScripts[name] = script availableScripts[name] = script
} }
func ListPrefix(prefix string) []string {
var scripts []string
for name := range availableScripts {
if name[:len(prefix)] == prefix {
scripts = append(scripts, name)
}
}
return scripts
}

View file

@ -1,4 +1,4 @@
package scripts package invoker
import ( import (
"context" "context"
@ -35,7 +35,7 @@ func TestMain(m *testing.M) {
func TestRunDBScript(t *testing.T) { func TestRunDBScript(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
register("test", TestScript(0)) Register("test", TestScript(0))
// 不存在 // 不存在
{ {
@ -47,3 +47,14 @@ func TestRunDBScript(t *testing.T) {
asserts.NoError(RunDBScript("test", context.Background())) asserts.NoError(RunDBScript("test", context.Background()))
} }
} }
func TestListPrefix(t *testing.T) {
asserts := assert.New(t)
Register("U1", TestScript(0))
Register("U2", TestScript(0))
Register("U3", TestScript(0))
Register("P1", TestScript(0))
res := ListPrefix("U")
asserts.Len(res, 3)
}

View file

@ -9,10 +9,6 @@ import (
type ResetAdminPassword int type ResetAdminPassword int
func init() {
register("ResetAdminPassword", ResetAdminPassword(0))
}
// Run 运行脚本从社区版升级至 Pro 版 // Run 运行脚本从社区版升级至 Pro 版
func (script ResetAdminPassword) Run(ctx context.Context) { func (script ResetAdminPassword) Run(ctx context.Context) {
// 查找用户 // 查找用户

View file

@ -8,10 +8,6 @@ import (
type UserStorageCalibration int type UserStorageCalibration int
func init() {
register("CalibrateUserStorage", UserStorageCalibration(0))
}
type storageResult struct { type storageResult struct {
Total uint64 Total uint64
} }

View file

@ -2,11 +2,31 @@ package scripts
import ( import (
"context" "context"
"database/sql"
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
var mock sqlmock.Sqlmock
var mockDB *gorm.DB
// TestMain 初始化数据库Mock
func TestMain(m *testing.M) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()
if err != nil {
panic("An error was not expected when opening a stub database connection")
}
model.DB, _ = gorm.Open("mysql", db)
mockDB = model.DB
defer db.Close()
m.Run()
}
func TestUserStorageCalibration_Run(t *testing.T) { func TestUserStorageCalibration_Run(t *testing.T) {
asserts := assert.New(t) asserts := assert.New(t)
script := UserStorageCalibration(0) script := UserStorageCalibration(0)

43
models/scripts/upgrade.go Normal file
View file

@ -0,0 +1,43 @@
package scripts
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"strconv"
)
type UpgradeTo340 int
// Run upgrade from older version to 3.4.0
func (script UpgradeTo340) Run(ctx context.Context) {
// 取回老版本 aria2 设定
old := model.GetSettingByType([]string{"aria2"})
if len(old) == 0 {
return
}
// 写入到新版本的节点设定
n, err := model.GetNodeByID(1)
if err != nil {
util.Log().Error("找不到主机节点, %s", err)
}
n.Aria2Enabled = old["aria2_rpcurl"] != ""
n.Aria2OptionsSerialized.Options = old["aria2_options"]
n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"]
interval, err := strconv.Atoi(old["aria2_interval"])
if err != nil {
interval = 10
}
n.Aria2OptionsSerialized.Interval = interval
n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"]
n.Aria2OptionsSerialized.Token = old["aria2_token"]
if err := model.DB.Save(&n).Error; err != nil {
util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err)
} else {
model.DB.Where("type = ?", "aria2").Delete(model.Setting{})
util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式")
}
}

View file

@ -0,0 +1,66 @@
package scripts
import (
"context"
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"testing"
)
func TestUpgradeTo340_Run(t *testing.T) {
a := assert.New(t)
script := UpgradeTo340(0)
// skip
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}))
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// node not found
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}))
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// success
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
AddRow("aria2_interval", "expected_aria2_interval").
AddRow("aria2_temp_path", "expected_aria2_temp_path").
AddRow("aria2_token", "expected_aria2_token").
AddRow("aria2_options", "{}"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
// failed
{
mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}).
AddRow("aria2_rpcurl", "expected_aria2_rpcurl").
AddRow("aria2_interval", "expected_aria2_interval").
AddRow("aria2_temp_path", "expected_aria2_temp_path").
AddRow("aria2_token", "expected_aria2_token").
AddRow("aria2_options", "{}"))
mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
mock.ExpectBegin()
mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error"))
mock.ExpectRollback()
script.Run(context.Background())
a.NoError(mock.ExpectationsWereMet())
}
}