diff --git a/bootstrap/init.go b/bootstrap/init.go index 0a51835..6463164 100644 --- a/bootstrap/init.go +++ b/bootstrap/init.go @@ -2,6 +2,7 @@ package bootstrap import ( model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/models/scripts" "github.com/cloudreve/Cloudreve/v3/pkg/aria2" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/cache" @@ -27,6 +28,12 @@ func Init(path string) { mode string factory func() }{ + { + "both", + func() { + scripts.Init() + }, + }, { "both", func() { diff --git a/bootstrap/script.go b/bootstrap/script.go index 9168dfa..7db59e8 100644 --- a/bootstrap/script.go +++ b/bootstrap/script.go @@ -2,14 +2,14 @@ package bootstrap import ( "context" - "github.com/cloudreve/Cloudreve/v3/models/scripts" + "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" "github.com/cloudreve/Cloudreve/v3/pkg/util" ) func RunScript(name string) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := scripts.RunDBScript(name, ctx); err != nil { + if err := invoker.RunDBScript(name, ctx); err != nil { util.Log().Error("数据库脚本执行失败: %s", err) return } diff --git a/models/migration.go b/models/migration.go index 9d2af47..be794d8 100644 --- a/models/migration.go +++ b/models/migration.go @@ -1,12 +1,17 @@ package model import ( + "context" + "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" "github.com/cloudreve/Cloudreve/v3/pkg/cache" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/fatih/color" "github.com/gofrs/uuid" + "github.com/hashicorp/go-version" "github.com/jinzhu/gorm" + "sort" + "strings" ) // 是否需要迁移 @@ -54,6 +59,9 @@ func migration() { // 向设置数据表添加初始设置 addDefaultSettings() + // 执行数据库升级脚本 + execUpgradeScripts() + util.Log().Info("数据库初始化结束") } @@ -290,3 +298,17 @@ func addDefaultNode() { } } } + +func execUpgradeScripts() { + s := invoker.ListPrefix("UpgradeTo") + versions := make([]*version.Version, len(s)) + for i, raw := range s { + v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo")) + versions[i] = v + } + sort.Sort(version.Collection(versions)) + + for i := 0; i < len(versions); i++ { + invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background()) + } +} diff --git a/models/scripts/init.go b/models/scripts/init.go new file mode 100644 index 0000000..7c375bf --- /dev/null +++ b/models/scripts/init.go @@ -0,0 +1,9 @@ +package scripts + +import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" + +func Init() { + invoker.Register("ResetAdminPassword", ResetAdminPassword(0)) + invoker.Register("CalibrateUserStorage", UserStorageCalibration(0)) + invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0)) +} diff --git a/models/scripts/invoker.go b/models/scripts/invoker/invoker.go similarity index 50% rename from models/scripts/invoker.go rename to models/scripts/invoker/invoker.go index af0155b..adb2f97 100644 --- a/models/scripts/invoker.go +++ b/models/scripts/invoker/invoker.go @@ -1,8 +1,9 @@ -package scripts +package invoker import ( "context" "fmt" + "github.com/cloudreve/Cloudreve/v3/pkg/util" ) type DBScript interface { @@ -13,6 +14,7 @@ var availableScripts = make(map[string]DBScript) func RunDBScript(name string, ctx context.Context) error { if script, ok := availableScripts[name]; ok { + util.Log().Info("开始执行数据库脚本 [%s]", name) script.Run(ctx) return nil } @@ -20,6 +22,16 @@ func RunDBScript(name string, ctx context.Context) error { return fmt.Errorf("数据库脚本 [%s] 不存在", name) } -func register(name string, script DBScript) { +func Register(name string, script DBScript) { availableScripts[name] = script } + +func ListPrefix(prefix string) []string { + var scripts []string + for name := range availableScripts { + if name[:len(prefix)] == prefix { + scripts = append(scripts, name) + } + } + return scripts +} diff --git a/models/scripts/invoker_test.go b/models/scripts/invoker/invoker_test.go similarity index 75% rename from models/scripts/invoker_test.go rename to models/scripts/invoker/invoker_test.go index 0ca324b..73728dd 100644 --- a/models/scripts/invoker_test.go +++ b/models/scripts/invoker/invoker_test.go @@ -1,4 +1,4 @@ -package scripts +package invoker import ( "context" @@ -35,7 +35,7 @@ func TestMain(m *testing.M) { func TestRunDBScript(t *testing.T) { asserts := assert.New(t) - register("test", TestScript(0)) + Register("test", TestScript(0)) // 不存在 { @@ -47,3 +47,14 @@ func TestRunDBScript(t *testing.T) { asserts.NoError(RunDBScript("test", context.Background())) } } + +func TestListPrefix(t *testing.T) { + asserts := assert.New(t) + Register("U1", TestScript(0)) + Register("U2", TestScript(0)) + Register("U3", TestScript(0)) + Register("P1", TestScript(0)) + + res := ListPrefix("U") + asserts.Len(res, 3) +} diff --git a/models/scripts/reset.go b/models/scripts/reset.go index d5747db..88ee25d 100644 --- a/models/scripts/reset.go +++ b/models/scripts/reset.go @@ -9,10 +9,6 @@ import ( type ResetAdminPassword int -func init() { - register("ResetAdminPassword", ResetAdminPassword(0)) -} - // Run 运行脚本从社区版升级至 Pro 版 func (script ResetAdminPassword) Run(ctx context.Context) { // 查找用户 diff --git a/models/scripts/storage.go b/models/scripts/storage.go index 9e152d5..6a15567 100644 --- a/models/scripts/storage.go +++ b/models/scripts/storage.go @@ -8,10 +8,6 @@ import ( type UserStorageCalibration int -func init() { - register("CalibrateUserStorage", UserStorageCalibration(0)) -} - type storageResult struct { Total uint64 } diff --git a/models/scripts/storage_test.go b/models/scripts/storage_test.go index 7287724..da50e5b 100644 --- a/models/scripts/storage_test.go +++ b/models/scripts/storage_test.go @@ -2,11 +2,31 @@ package scripts import ( "context" + "database/sql" "github.com/DATA-DOG/go-sqlmock" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" "testing" ) +var mock sqlmock.Sqlmock +var mockDB *gorm.DB + +// TestMain 初始化数据库Mock +func TestMain(m *testing.M) { + var db *sql.DB + var err error + db, mock, err = sqlmock.New() + if err != nil { + panic("An error was not expected when opening a stub database connection") + } + model.DB, _ = gorm.Open("mysql", db) + mockDB = model.DB + defer db.Close() + m.Run() +} + func TestUserStorageCalibration_Run(t *testing.T) { asserts := assert.New(t) script := UserStorageCalibration(0) diff --git a/models/scripts/upgrade.go b/models/scripts/upgrade.go new file mode 100644 index 0000000..717a72e --- /dev/null +++ b/models/scripts/upgrade.go @@ -0,0 +1,43 @@ +package scripts + +import ( + "context" + model "github.com/cloudreve/Cloudreve/v3/models" + "github.com/cloudreve/Cloudreve/v3/pkg/util" + "strconv" +) + +type UpgradeTo340 int + +// Run upgrade from older version to 3.4.0 +func (script UpgradeTo340) Run(ctx context.Context) { + // 取回老版本 aria2 设定 + old := model.GetSettingByType([]string{"aria2"}) + if len(old) == 0 { + return + } + + // 写入到新版本的节点设定 + n, err := model.GetNodeByID(1) + if err != nil { + util.Log().Error("找不到主机节点, %s", err) + } + + n.Aria2Enabled = old["aria2_rpcurl"] != "" + n.Aria2OptionsSerialized.Options = old["aria2_options"] + n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"] + + interval, err := strconv.Atoi(old["aria2_interval"]) + if err != nil { + interval = 10 + } + n.Aria2OptionsSerialized.Interval = interval + n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"] + n.Aria2OptionsSerialized.Token = old["aria2_token"] + if err := model.DB.Save(&n).Error; err != nil { + util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err) + } else { + model.DB.Where("type = ?", "aria2").Delete(model.Setting{}) + util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式") + } +} diff --git a/models/scripts/upgrade_test.go b/models/scripts/upgrade_test.go new file mode 100644 index 0000000..8f7adba --- /dev/null +++ b/models/scripts/upgrade_test.go @@ -0,0 +1,66 @@ +package scripts + +import ( + "context" + "errors" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUpgradeTo340_Run(t *testing.T) { + a := assert.New(t) + script := UpgradeTo340(0) + + // skip + { + mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"})) + script.Run(context.Background()) + a.NoError(mock.ExpectationsWereMet()) + } + + // node not found + { + mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1")) + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"})) + script.Run(context.Background()) + a.NoError(mock.ExpectationsWereMet()) + } + + // success + { + mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). + AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). + AddRow("aria2_interval", "expected_aria2_interval"). + AddRow("aria2_temp_path", "expected_aria2_temp_path"). + AddRow("aria2_token", "expected_aria2_token"). + AddRow("aria2_options", "{}")) + + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + script.Run(context.Background()) + a.NoError(mock.ExpectationsWereMet()) + } + + // failed + { + mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). + AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). + AddRow("aria2_interval", "expected_aria2_interval"). + AddRow("aria2_temp_path", "expected_aria2_temp_path"). + AddRow("aria2_token", "expected_aria2_token"). + AddRow("aria2_options", "{}")) + + mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + mock.ExpectBegin() + mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) + mock.ExpectRollback() + script.Run(context.Background()) + a.NoError(mock.ExpectationsWereMet()) + } +}