From c7e47293db28a75d193378d8ab6a328dd3201e4c Mon Sep 17 00:00:00 2001 From: HFO4 <912394456@qq.com> Date: Wed, 13 Nov 2019 16:28:14 +0800 Subject: [PATCH] Add: model - group / Test: group --- models/group.go | 37 +++++++++++++++++++++++++ models/group_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++ models/migration.go | 46 +++++++++++++++++++++++++++++-- models/user.go | 38 ++++++++++++++----------- models/user_test.go | 18 ++++++++++-- service/user/login.go | 3 ++ 6 files changed, 185 insertions(+), 21 deletions(-) create mode 100644 models/group.go create mode 100644 models/group_test.go diff --git a/models/group.go b/models/group.go new file mode 100644 index 0000000..2338c16 --- /dev/null +++ b/models/group.go @@ -0,0 +1,37 @@ +package model + +import ( + "encoding/json" + "github.com/jinzhu/gorm" +) + +// Group 用户组模型 +type Group struct { + gorm.Model + Name string + Policies string + MaxStorage uint64 + SpeedLimit int + ShareEnabled bool + RangeTransferEnabled bool + WebDAVEnabled bool + Aria2Option string + Color string + + // 数据库忽略字段 + PolicyList []int `gorm:"-"` +} + +// GetGroupByID 用ID获取用户组 +func GetGroupByID(ID interface{}) (Group, error) { + var group Group + result := DB.First(&group, ID) + return group, result.Error +} + +// AfterFind 找到用户组后的钩子,处理Policy列表 +func (group *Group) AfterFind() (err error) { + // 解析用户设置到OptionsSerialized + err = json.Unmarshal([]byte(group.Policies), &group.PolicyList) + return err +} diff --git a/models/group_test.go b/models/group_test.go new file mode 100644 index 0000000..9e28c19 --- /dev/null +++ b/models/group_test.go @@ -0,0 +1,64 @@ +package model + +import ( + "github.com/DATA-DOG/go-sqlmock" + "github.com/jinzhu/gorm" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGetGroupByID(t *testing.T) { + asserts := assert.New(t) + + //找到用户组时 + groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). + AddRow(1, "管理员", "[1]") + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) + + group, err := GetGroupByID(1) + asserts.NoError(err) + asserts.Equal(Group{ + Model: gorm.Model{ + ID: 1, + }, + Name: "管理员", + Policies: "[1]", + PolicyList: []int{1}, + }, group) + + //未找到用户时 + mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) + group, err = GetGroupByID(1) + asserts.Error(err) + asserts.Equal(Group{}, group) +} + +func TestGroup_AfterFind(t *testing.T) { + asserts := assert.New(t) + + testCase := Group{ + Model: gorm.Model{ + ID: 1, + }, + Name: "管理员", + Policies: "[1]", + } + err := testCase.AfterFind() + asserts.NoError(err) + asserts.Equal(testCase.PolicyList, []int{1}) + + testCase.Policies = "[1,2,3,4,5]" + err = testCase.AfterFind() + asserts.NoError(err) + asserts.Equal(testCase.PolicyList, []int{1, 2, 3, 4, 5}) + + testCase.Policies = "[1,2,3,4,5" + err = testCase.AfterFind() + asserts.Error(err) + + testCase.Policies = "[]" + err = testCase.AfterFind() + asserts.NoError(err) + asserts.Equal(testCase.PolicyList, []int{}) +} diff --git a/models/migration.go b/models/migration.go index fad5022..401e231 100644 --- a/models/migration.go +++ b/models/migration.go @@ -25,7 +25,10 @@ func migration() { util.Log().Info("开始进行数据库自动迁移...") // 自动迁移模式 - DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}, &Setting{}) + DB.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}, &Setting{}, &Group{}) + + // 创建初始用户组 + addDefaultGroups() // 创建初始管理员账户 addDefaultUser() @@ -121,6 +124,45 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti } } +func addDefaultGroups() { + _, err := GetGroupByID(1) + // 未找到初始管理组时,则创建 + if gorm.IsRecordNotFoundError(err) { + defaultAdminGroup := Group{ + Name: "管理员", + Policies: "[1]", + MaxStorage: 1 * 1024 * 1024 * 1024, + ShareEnabled: true, + Color: "danger", + RangeTransferEnabled: true, + WebDAVEnabled: true, + Aria2Option: "0,0,0", + } + if err := DB.Create(&defaultAdminGroup).Error; err != nil { + util.Log().Panic("无法创建管理用户组, ", err) + } + } + + err = nil + _, err = GetGroupByID(2) + // 未找到初始注册会员时,则创建 + if gorm.IsRecordNotFoundError(err) { + defaultAdminGroup := Group{ + Name: "注册会员", + Policies: "[1]", + MaxStorage: 1 * 1024 * 1024 * 1024, + ShareEnabled: true, + Color: "danger", + RangeTransferEnabled: true, + WebDAVEnabled: true, + Aria2Option: "0,0,0", + } + if err := DB.Create(&defaultAdminGroup).Error; err != nil { + util.Log().Panic("无法创建初始注册会员用户组, ", err) + } + } +} + func addDefaultUser() { _, err := GetUserByID(1) @@ -131,7 +173,7 @@ func addDefaultUser() { defaultUser.Email = "admin@cloudreve.org" defaultUser.Nick = "admin" defaultUser.Status = Active - defaultUser.Group = 1 + defaultUser.GroupID = 1 defaultUser.PrimaryGroup = 1 err := defaultUser.SetPassword("admin") if err != nil { diff --git a/models/user.go b/models/user.go index 5d8a1bf..22e47cd 100644 --- a/models/user.go +++ b/models/user.go @@ -22,21 +22,27 @@ const ( // User 用户模型 type User struct { + // 表字段 gorm.Model - Email string `gorm:"type:varchar(100);unique_index"` - Nick string `gorm:"size:50"` - Password string `json:"-"` - Status int - Group int - PrimaryGroup int - ActivationKey string `json:"-"` - Storage int64 - LastNotify *time.Time - OpenID string `json:"-"` - TwoFactor string `json:"-"` - Delay int - Avatar string - Options string `json:"-",gorm:"size:4096"` + Email string `gorm:"type:varchar(100);unique_index"` + Nick string `gorm:"size:50"` + Password string `json:"-"` + Status int + GroupID uint + PrimaryGroup int + ActivationKey string `json:"-"` + Storage uint64 + LastNotify *time.Time + OpenID string `json:"-"` + TwoFactor string `json:"-"` + Delay int + Avatar string + Options string `json:"-",gorm:"size:4096"` + + // 关联模型 + Group Group + + // 数据库忽略字段 OptionsSerialized UserOption `gorm:"-"` } @@ -49,14 +55,14 @@ type UserOption struct { // GetUserByID 用ID获取用户 func GetUserByID(ID interface{}) (User, error) { var user User - result := DB.First(&user, ID) + result := DB.Set("gorm:auto_preload", true).First(&user, ID) return user, result.Error } // GetUserByEmail 用Email获取用户 func GetUserByEmail(email string) (User, error) { var user User - result := DB.Where("email = ?", email).First(&user) + result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user) return user, result.Error } diff --git a/models/user_test.go b/models/user_test.go index 5d7f295..afc7b63 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -13,10 +13,13 @@ func TestGetUserByID(t *testing.T) { asserts := assert.New(t) //找到用户时 - rows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options"}). - AddRow(1, nil, "admin@cloudreve.org", "{}") + userRows := sqlmock.NewRows([]string{"id", "deleted_at", "email", "options", "group_id"}). + AddRow(1, nil, "admin@cloudreve.org", "{}", 1) + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(userRows) - mock.ExpectQuery("^SELECT (.+)").WillReturnRows(rows) + groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). + AddRow(1, "管理员", "[1]") + mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) user, err := GetUserByID(1) asserts.NoError(err) @@ -27,6 +30,15 @@ func TestGetUserByID(t *testing.T) { }, Email: "admin@cloudreve.org", Options: "{}", + GroupID: 1, + Group: Group{ + Model: gorm.Model{ + ID: 1, + }, + Name: "管理员", + Policies: "[1]", + PolicyList: []int{1}, + }, }, user) //未找到用户时 diff --git a/service/user/login.go b/service/user/login.go index e1fe73e..1c6f8b9 100644 --- a/service/user/login.go +++ b/service/user/login.go @@ -4,6 +4,7 @@ import ( "cloudreve/models" "cloudreve/pkg/serializer" "cloudreve/pkg/util" + "fmt" "github.com/gin-gonic/gin" ) @@ -47,6 +48,8 @@ func (service *UserLoginService) Login(c *gin.Context) serializer.Response { "user_id": expectedUser.ID, }) + fmt.Println(expectedUser) + return serializer.BuildUserResponse(expectedUser) }