feat: support session persistence

This commit is contained in:
vvisionnn 2023-03-09 15:45:58 +08:00
parent f172220825
commit 3ded139e32
4 changed files with 246 additions and 8 deletions

5
go.mod
View file

@ -20,6 +20,8 @@ require (
github.com/gofrs/uuid v4.0.0+incompatible
github.com/gomodule/redigo v2.0.0+incompatible
github.com/google/go-querystring v1.0.0
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.1
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/go-version v1.3.0
github.com/jinzhu/gorm v1.9.11
@ -83,8 +85,6 @@ require (
github.com/google/go-cmp v0.5.9 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
@ -115,7 +115,6 @@ require (
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.24.0 // indirect
github.com/prometheus/procfs v0.6.0 // indirect
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect

2
go.sum
View file

@ -750,8 +750,6 @@ github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdk
github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs=
github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w=
github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs=
github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=

View file

@ -4,17 +4,18 @@ import (
"net/http"
"strings"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/session"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore"
"github.com/gin-contrib/sessions/redis"
"github.com/gin-gonic/gin"
)
// Store session存储
var Store memstore.Store
var Store sessions.Store
// Session 初始化session
func Session(secret string) gin.HandlerFunc {
@ -28,7 +29,7 @@ func Session(secret string) gin.HandlerFunc {
util.Log().Info("Connect to Redis server %q.", conf.RedisConfig.Server)
} else {
Store = memstore.NewStore([]byte(secret))
Store = session.NewStore(model.DB, []byte(secret))
}
sameSiteMode := http.SameSiteDefaultMode

240
pkg/session/session.go Normal file
View file

@ -0,0 +1,240 @@
package session
// TODO: unit test
import (
"encoding/base32"
"net/http"
"strings"
"time"
ginSession "github.com/gin-contrib/sessions"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/jinzhu/gorm"
)
type UserSession struct {
ID string `gorm:"unique_index"`
Data string `gorm:"text"`
CreatedAt time.Time
UpdatedAt time.Time
ExpiresAt time.Time `gorm:"index"`
}
type SessionCacher interface {
Setup()
Get(sessionID string) *UserSession
Create(val *UserSession) error
Update(val *UserSession) error
Delete(val *UserSession) error
}
type SqliteSessionCacher struct {
db *gorm.DB
}
func (c *SqliteSessionCacher) sessionTable() *gorm.DB {
return c.db.Table("user_sessions")
}
func (c *SqliteSessionCacher) Setup() {
c.sessionTable().AutoMigrate(&UserSession{})
}
func (c *SqliteSessionCacher) Get(sessionID string) *UserSession {
// after get session id, try get persisted session
persistedSession := &UserSession{}
record := c.sessionTable().
Where("id = ? AND expires_at > ?", sessionID, time.Now()).
Limit(1).
Find(persistedSession)
if record.Error != nil || record.RowsAffected == 0 {
return nil
}
return persistedSession
}
func (c *SqliteSessionCacher) Create(val *UserSession) error {
return c.sessionTable().Create(val).Error
}
func (c *SqliteSessionCacher) Update(val *UserSession) error {
return c.sessionTable().Save(val).Error
}
func (c *SqliteSessionCacher) Delete(val *UserSession) error {
return c.sessionTable().Delete(val).Error
}
type SessionStore struct {
SessionOptions *sessions.Options
Codecs []securecookie.Codec
cache SessionCacher
}
func NewSessionStore(db *gorm.DB, keyPires ...[]byte) *SessionStore {
store := SessionStore{
SessionOptions: &sessions.Options{
Path: "/",
MaxAge: 60 * 86400,
HttpOnly: true,
},
Codecs: securecookie.CodecsFromPairs(keyPires...),
cache: &SqliteSessionCacher{db: db},
}
store.setup()
store.MaxAge(store.SessionOptions.MaxAge)
return &store
}
/*
Session interface implementation
*/
func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, error) {
session := sessions.NewSession(s, name)
session.Options = s.SessionOptions
session.IsNew = true
s.MaxAge(s.SessionOptions.MaxAge)
if persistedSession := s.tryGetPersistedSessionFromCookie(r, name); persistedSession != nil {
// decode persisted session data to session.Values if found
err := securecookie.DecodeMulti(session.Name(), persistedSession.Data, &session.Values, s.Codecs...)
if err != nil {
return session, err
}
session.ID = persistedSession.ID
session.IsNew = false
return session, nil
}
return session, nil
}
func (s *SessionStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
// try get current session from cookie
persistedSession := s.tryGetPersistedSessionFromCookie(r, session.Name())
// delete session if MaxAge < 0
if session.Options.MaxAge < 0 {
if persistedSession != nil {
// delete persisted session
if err := s.cache.Delete(persistedSession); err != nil {
return err
}
}
// delete cookie
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
return nil
}
// or cretae new session / update current session
data, err := securecookie.EncodeMulti(session.Name(), session.Values, s.Codecs...)
if err != nil {
return err
}
if persistedSession == nil {
session.ID = strings.TrimRight(
base32.StdEncoding.EncodeToString(
securecookie.GenerateRandomKey(32)), "=")
// create new session
persistedSession = &UserSession{
ID: session.ID,
Data: data,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Duration(session.Options.MaxAge) * time.Second),
}
if err := s.cache.Create(persistedSession); err != nil {
return err
}
} else {
// update current session
persistedSession.Data = data
persistedSession.UpdatedAt = time.Now()
persistedSession.ExpiresAt = time.Now().Add(time.Duration(session.Options.MaxAge) * time.Second)
if err := s.cache.Update(persistedSession); err != nil {
return err
}
}
// set cookie
encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...)
if err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
return nil
}
func (s *SessionStore) Get(r *http.Request, name string) (*sessions.Session, error) {
return sessions.GetRegistry(r).Get(s, name)
}
/*
Handy tools
*/
func (st *SessionStore) setup() {
st.cache.Setup()
}
func (s *SessionStore) MaxAge(age int) {
s.SessionOptions.MaxAge = age
// Set the maxAge for each securecookie instance.
for _, codec := range s.Codecs {
if sc, ok := codec.(*securecookie.SecureCookie); ok {
sc.MaxAge(age)
}
}
}
func (s *SessionStore) tryGetPersistedSessionFromCookie(r *http.Request, name string) *UserSession {
// get cookie from request
cookie, err := r.Cookie(name)
if err != nil {
return nil
}
// decode cookie value to session id
var sessionID string
err = securecookie.DecodeMulti(name, cookie.Value, &sessionID, s.Codecs...)
if err != nil {
return nil
}
// after get session id, try get persisted session
return s.cache.Get(sessionID)
}
/*
Below is gin-session wrapper
*/
type Store interface {
ginSession.Store
}
type store struct {
*SessionStore
}
func (s *store) Options(options ginSession.Options) {
s.SessionStore.SessionOptions = options.ToGorillaOptions()
}
func NewStore(db *gorm.DB, keyPires ...[]byte) Store {
return &store{NewSessionStore(db, keyPires...)}
}