diff --git a/middleware/auth.go b/middleware/auth.go index fd0f143..69233ee 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -90,7 +90,7 @@ func WebDAVAuth() gin.HandlerFunc { return } - expectedUser, err := model.GetUserByEmail(username) + expectedUser, err := model.GetActiveUserByEmail(username) if err != nil { c.Status(http.StatusUnauthorized) c.Abort() diff --git a/models/user.go b/models/user.go index c4226f0..ecd091b 100644 --- a/models/user.go +++ b/models/user.go @@ -139,6 +139,13 @@ func GetActiveUserByOpenID(openid string) (User, error) { // GetUserByEmail 用Email获取用户 func GetUserByEmail(email string) (User, error) { + var user User + result := DB.Set("gorm:auto_preload", true).Where("email = ?", email).First(&user) + return user, result.Error +} + +// GetActiveUserByEmail 用Email获取可登录用户 +func GetActiveUserByEmail(email string) (User, error) { var user User result := DB.Set("gorm:auto_preload", true).Where("status = ? and email = ?", Active, email).First(&user) return user, result.Error diff --git a/models/user_test.go b/models/user_test.go index ea346a2..5b4d375 100644 --- a/models/user_test.go +++ b/models/user_test.go @@ -352,10 +352,20 @@ func TestUser_IncreaseStorageWithoutCheck(t *testing.T) { } } -func TestGetUserByEmail(t *testing.T) { +func TestGetActiveUserByEmail(t *testing.T) { asserts := assert.New(t) mock.ExpectQuery("SELECT(.+)").WithArgs(Active, "abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) + _, err := GetActiveUserByEmail("abslant@foxmail.com") + + asserts.Error(err) + asserts.NoError(mock.ExpectationsWereMet()) +} + +func TestGetUserByEmail(t *testing.T) { + asserts := assert.New(t) + + mock.ExpectQuery("SELECT(.+)").WithArgs("abslant@foxmail.com").WillReturnRows(sqlmock.NewRows([]string{"id", "email"})) _, err := GetUserByEmail("abslant@foxmail.com") asserts.Error(err) diff --git a/pkg/filesystem/driver/local/handler_test.go b/pkg/filesystem/driver/local/handler_test.go index ed0bc2b..6d6f98d 100644 --- a/pkg/filesystem/driver/local/handler_test.go +++ b/pkg/filesystem/driver/local/handler_test.go @@ -2,13 +2,6 @@ package local import ( "context" - "io" - "io/ioutil" - "net/url" - "os" - "strings" - "testing" - model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/auth" "github.com/cloudreve/Cloudreve/v3/pkg/conf" @@ -16,6 +9,12 @@ import ( "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" + "io" + "io/ioutil" + "net/url" + "os" + "strings" + "testing" ) func TestHandler_Put(t *testing.T) { @@ -61,24 +60,34 @@ func TestHandler_Delete(t *testing.T) { asserts := assert.New(t) handler := Driver{} ctx := context.Background() + filePath := util.RelativePath("test.file") - file, err := os.Create(util.RelativePath("test.file")) + file, err := os.Create(filePath) asserts.NoError(err) _ = file.Close() list, err := handler.Delete(ctx, []string{"test.file"}) asserts.Equal([]string{}, list) asserts.NoError(err) - file, err = os.Create(util.RelativePath("test.file")) - asserts.NoError(err) + file, err = os.Create(filePath) _ = file.Close() + file, _ = os.OpenFile(filePath, os.O_RDWR, os.FileMode(0)) + asserts.NoError(err) list, err = handler.Delete(ctx, []string{"test.file", "test.notexist"}) - asserts.Equal([]string{"test.notexist"}, list) - asserts.Error(err) + file.Close() + asserts.Equal([]string{}, list) + asserts.NoError(err) list, err = handler.Delete(ctx, []string{"test.notexist"}) - asserts.Equal([]string{"test.notexist"}, list) - asserts.Error(err) + asserts.Equal([]string{}, list) + asserts.NoError(err) + + file, err = os.Create(filePath) + asserts.NoError(err) + list, err = handler.Delete(ctx, []string{"test.file"}) + _ = file.Close() + asserts.Equal([]string{}, list) + asserts.NoError(err) } func TestHandler_Get(t *testing.T) { diff --git a/routers/controllers/user.go b/routers/controllers/user.go index b710942..77a5426 100644 --- a/routers/controllers/user.go +++ b/routers/controllers/user.go @@ -18,7 +18,7 @@ import ( // StartLoginAuthn 开始注册WebAuthn登录 func StartLoginAuthn(c *gin.Context) { userName := c.Param("username") - expectedUser, err := model.GetUserByEmail(userName) + expectedUser, err := model.GetActiveUserByEmail(userName) if err != nil { c.JSON(200, serializer.Err(serializer.CodeNotFound, "用户不存在", err)) return @@ -52,7 +52,7 @@ func StartLoginAuthn(c *gin.Context) { // FinishLoginAuthn 完成注册WebAuthn登录 func FinishLoginAuthn(c *gin.Context) { userName := c.Param("username") - expectedUser, err := model.GetUserByEmail(userName) + expectedUser, err := model.GetActiveUserByEmail(userName) if err != nil { c.JSON(200, serializer.Err(serializer.CodeCredentialInvalid, "用户邮箱或密码错误", err)) return diff --git a/service/user/login.go b/service/user/login.go index acea163..4689bc4 100644 --- a/service/user/login.go +++ b/service/user/login.go @@ -94,6 +94,12 @@ func (service *UserResetEmailService) Reset(c *gin.Context) serializer.Response // 查找用户 if user, err := model.GetUserByEmail(service.UserName); err == nil { + if user.Status == model.Baned || user.Status == model.OveruseBaned { + return serializer.Err(403, "该账号已被封禁", nil) + } + if user.Status == model.NotActivicated { + return serializer.Err(403, "该账号未激活", nil) + } // 创建密码重设会话 secret := util.RandStringRunes(32) cache.Set(fmt.Sprintf("user_reset_%d", user.ID), secret, 3600) diff --git a/service/user/register.go b/service/user/register.go index 04083ad..94c5eda 100644 --- a/service/user/register.go +++ b/service/user/register.go @@ -64,10 +64,17 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response user.Status = model.NotActivicated } user.GroupID = uint(defaultGroup) - + userNotActivated := false // 创建用户 if err := model.DB.Create(&user).Error; err != nil { - return serializer.DBErr("此邮箱已被使用", err) + //检查已存在使用者是否尚未激活 + expectedUser, err := model.GetUserByEmail(service.UserName) + if expectedUser.Status == model.NotActivicated { + userNotActivated = true + user = expectedUser + } else { + return serializer.DBErr("此邮箱已被使用", err) + } } // 发送激活邮件 @@ -100,8 +107,12 @@ func (service *UserRegisterService) Register(c *gin.Context) serializer.Response if err := email.Send(user.Email, title, body); err != nil { return serializer.Err(serializer.CodeInternalSetting, "无法发送激活邮件", err) } - - return serializer.Response{Code: 203} + if userNotActivated == true { + //原本在上面要抛出的DBErr,放来这边抛出 + return serializer.DBErr("用户未激活,已重新发送激活邮件", nil) + } else { + return serializer.Response{Code: 203} + } } return serializer.Response{}