Feat: HMAC auth and check
This commit is contained in:
parent
4649ddbae2
commit
e871f6e421
7 changed files with 166 additions and 5 deletions
7
main.go
7
main.go
|
@ -2,25 +2,22 @@ package main
|
|||
|
||||
import (
|
||||
"github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/auth"
|
||||
"github.com/HFO4/cloudreve/pkg/authn"
|
||||
"github.com/HFO4/cloudreve/pkg/conf"
|
||||
"github.com/HFO4/cloudreve/routers"
|
||||
"github.com/gin-gonic/gin"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
conf.Init("conf/conf.ini")
|
||||
model.Init()
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
// Debug 关闭时,切换为生产模式
|
||||
if !conf.SystemConfig.Debug {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
auth.Init()
|
||||
authn.Init()
|
||||
}
|
||||
|
||||
|
|
|
@ -146,6 +146,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
|||
{Name: "aria2_rpcurl", Value: `http://127.0.0.1:6800/`, Type: "aria2"},
|
||||
{Name: "aria2_options", Value: `{"max-tries":5}`, Type: "aria2"},
|
||||
{Name: "task_queue_token", Value: ``, Type: "task"},
|
||||
{Name: "secret_key", Value: util.RandStringRunes(256), Type: "auth"},
|
||||
}
|
||||
|
||||
for _, value := range defaultSettings {
|
||||
|
|
29
pkg/auth/auth.go
Normal file
29
pkg/auth/auth.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthFailed = serializer.NewError(serializer.CodeNoRightErr, "鉴权失败", nil)
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "签名已过期", nil)
|
||||
)
|
||||
|
||||
// General 通用的认证接口
|
||||
var General Auth
|
||||
|
||||
// Auth 鉴权认证
|
||||
type Auth interface {
|
||||
// 对给定Body进行签名,expires为0表示永不过期
|
||||
Sign(body string, expires int64) string
|
||||
// 对给定Body和Sign进行检查
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
|
||||
// Init 初始化通用鉴权器
|
||||
func Init() {
|
||||
General = HMACAuth{
|
||||
SecretKey: []byte(model.GetSettingByName("secret_key")),
|
||||
}
|
||||
}
|
53
pkg/auth/hmac.go
Normal file
53
pkg/auth/hmac.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HMACAuth HMAC算法鉴权
|
||||
type HMACAuth struct {
|
||||
SecretKey []byte
|
||||
}
|
||||
|
||||
// Sign 对给定Body生成expires后失效的签名
|
||||
func (auth HMACAuth) Sign(body string, expires int64) string {
|
||||
h := hmac.New(sha256.New, auth.SecretKey)
|
||||
expireTimeStamp := strconv.FormatInt(expires, 10)
|
||||
_, err := io.WriteString(h, body+":"+expireTimeStamp)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", h.Sum(nil)) + ":" + expireTimeStamp
|
||||
}
|
||||
|
||||
// Check 对给定Body和Sign进行鉴权,包括对expires的检查
|
||||
func (auth HMACAuth) Check(body string, sign string) error {
|
||||
signSlice := strings.Split(sign, ":")
|
||||
// 如果未携带expires字段
|
||||
if signSlice[len(signSlice)-1] == "" {
|
||||
return ErrAuthFailed
|
||||
}
|
||||
|
||||
// 验证是否过期
|
||||
expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64)
|
||||
if err != nil {
|
||||
return ErrAuthFailed.WithError(err)
|
||||
}
|
||||
// 如果签名过期
|
||||
if expires < time.Now().Unix() && expires != 0 {
|
||||
return ErrExpired
|
||||
}
|
||||
|
||||
// 验证签名
|
||||
if auth.Sign(body, expires) != sign {
|
||||
return ErrAuthFailed
|
||||
}
|
||||
return nil
|
||||
}
|
74
pkg/auth/hmac_test.go
Normal file
74
pkg/auth/hmac_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
model "github.com/HFO4/cloudreve/models"
|
||||
"github.com/HFO4/cloudreve/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var mock sqlmock.Sqlmock
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// 设置gin为测试模式
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 初始化sqlmock
|
||||
var db *sql.DB
|
||||
var err error
|
||||
db, mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
panic("An error was not expected when opening a stub database connection")
|
||||
}
|
||||
|
||||
mockDB, _ := gorm.Open("mysql", db)
|
||||
model.DB = mockDB
|
||||
defer db.Close()
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestHMACAuth_Sign(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth := HMACAuth{
|
||||
SecretKey: []byte(util.RandStringRunes(256)),
|
||||
}
|
||||
|
||||
asserts.NotEmpty(auth.Sign("content", 0))
|
||||
}
|
||||
|
||||
func TestHMACAuth_Check(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
auth := HMACAuth{
|
||||
SecretKey: []byte(util.RandStringRunes(256)),
|
||||
}
|
||||
|
||||
// 正常,永不过期
|
||||
{
|
||||
sign := auth.Sign("content", 0)
|
||||
asserts.NoError(auth.Check("content", sign))
|
||||
}
|
||||
|
||||
// 过期
|
||||
{
|
||||
sign := auth.Sign("content", 1)
|
||||
asserts.Error(auth.Check("content", sign))
|
||||
}
|
||||
|
||||
// 签名格式错误
|
||||
{
|
||||
sign := auth.Sign("content", 1)
|
||||
asserts.Error(auth.Check("content", sign+":"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312"))
|
||||
Init()
|
||||
asserts.NoError(mock.ExpectationsWereMet())
|
||||
}
|
|
@ -54,6 +54,8 @@ const (
|
|||
CodeCreateFolderFailed = 40003
|
||||
// CodeObjectExist 对象已存在
|
||||
CodeObjectExist = 40004
|
||||
// CodeSignExpired 签名过期
|
||||
CodeSignExpired = 40005
|
||||
// CodeDBError 数据库操作失败
|
||||
CodeDBError = 50001
|
||||
// CodeEncryptError 加密失败
|
||||
|
|
|
@ -4,8 +4,13 @@ import (
|
|||
"math/rand"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RandStringRunes 返回随机字符串
|
||||
func RandStringRunes(n int) string {
|
||||
var letterRunes = []rune("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
|
|
Loading…
Add table
Reference in a new issue