diff --git a/go.mod b/go.mod index 08478fe..709c6d1 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,8 @@ require ( github.com/gofrs/uuid v4.0.0+incompatible github.com/gomodule/redigo v2.0.0+incompatible github.com/google/go-querystring v1.0.0 + github.com/gorilla/securecookie v1.1.1 + github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-version v1.3.0 github.com/jinzhu/gorm v1.9.11 @@ -83,8 +85,6 @@ require ( github.com/google/go-cmp v0.5.9 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/context v1.1.1 // indirect - github.com/gorilla/securecookie v1.1.1 // indirect - github.com/gorilla/sessions v1.2.1 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect @@ -115,7 +115,6 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.24.0 // indirect github.com/prometheus/procfs v0.6.0 // indirect - github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 4d35a94..5f69ef6 100644 --- a/go.sum +++ b/go.sum @@ -750,8 +750,6 @@ github.com/qiniu/dyn v1.3.0/go.mod h1:E8oERcm8TtwJiZvkQPbcAh0RL8jO1G0VXJMW3FAWdk github.com/qiniu/go-sdk/v7 v7.11.1 h1:/LZ9rvFS4p6SnszhGv11FNB1+n4OZvBCwFg7opH5Ovs= github.com/qiniu/go-sdk/v7 v7.11.1/go.mod h1:btsaOc8CA3hdVloULfFdDgDc+g4f3TDZEFsDY0BLE+w= github.com/qiniu/x v1.10.5/go.mod h1:03Ni9tj+N2h2aKnAz+6N0Xfl8FwMEDRC2PAlxekASDs= -github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc= -github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 h1:leEwA4MD1ew0lNgzz6Q4G76G3AEfeci+TMggN6WuFRs= github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1/go.mod h1:JaY6n2sDr+z2WTsXkOmNRUfDy6FN0L6Nk7x06ndm4tY= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= diff --git a/middleware/session.go b/middleware/session.go index 77825ae..28d5fa9 100644 --- a/middleware/session.go +++ b/middleware/session.go @@ -4,17 +4,18 @@ import ( "net/http" "strings" + model "github.com/cloudreve/Cloudreve/v3/models" "github.com/cloudreve/Cloudreve/v3/pkg/conf" "github.com/cloudreve/Cloudreve/v3/pkg/serializer" + "github.com/cloudreve/Cloudreve/v3/pkg/session" "github.com/cloudreve/Cloudreve/v3/pkg/util" "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" "github.com/gin-contrib/sessions/redis" "github.com/gin-gonic/gin" ) // Store session存储 -var Store memstore.Store +var Store sessions.Store // Session 初始化session func Session(secret string) gin.HandlerFunc { @@ -28,7 +29,7 @@ func Session(secret string) gin.HandlerFunc { util.Log().Info("Connect to Redis server %q.", conf.RedisConfig.Server) } else { - Store = memstore.NewStore([]byte(secret)) + Store = session.NewStore(model.DB, []byte(secret)) } sameSiteMode := http.SameSiteDefaultMode diff --git a/pkg/session/session.go b/pkg/session/session.go new file mode 100644 index 0000000..637f2d1 --- /dev/null +++ b/pkg/session/session.go @@ -0,0 +1,240 @@ +package session + +// TODO: unit test + +import ( + "encoding/base32" + "net/http" + "strings" + "time" + + ginSession "github.com/gin-contrib/sessions" + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/jinzhu/gorm" +) + +type UserSession struct { + ID string `gorm:"unique_index"` + Data string `gorm:"text"` + CreatedAt time.Time + UpdatedAt time.Time + ExpiresAt time.Time `gorm:"index"` +} + +type SessionCacher interface { + Setup() + Get(sessionID string) *UserSession + Create(val *UserSession) error + Update(val *UserSession) error + Delete(val *UserSession) error +} + +type SqliteSessionCacher struct { + db *gorm.DB +} + +func (c *SqliteSessionCacher) sessionTable() *gorm.DB { + return c.db.Table("user_sessions") +} + +func (c *SqliteSessionCacher) Setup() { + c.sessionTable().AutoMigrate(&UserSession{}) +} + +func (c *SqliteSessionCacher) Get(sessionID string) *UserSession { + // after get session id, try get persisted session + persistedSession := &UserSession{} + record := c.sessionTable(). + Where("id = ? AND expires_at > ?", sessionID, time.Now()). + Limit(1). + Find(persistedSession) + + if record.Error != nil || record.RowsAffected == 0 { + return nil + } + + return persistedSession +} + +func (c *SqliteSessionCacher) Create(val *UserSession) error { + return c.sessionTable().Create(val).Error +} + +func (c *SqliteSessionCacher) Update(val *UserSession) error { + return c.sessionTable().Save(val).Error +} + +func (c *SqliteSessionCacher) Delete(val *UserSession) error { + return c.sessionTable().Delete(val).Error +} + +type SessionStore struct { + SessionOptions *sessions.Options + Codecs []securecookie.Codec + cache SessionCacher +} + +func NewSessionStore(db *gorm.DB, keyPires ...[]byte) *SessionStore { + store := SessionStore{ + SessionOptions: &sessions.Options{ + Path: "/", + MaxAge: 60 * 86400, + HttpOnly: true, + }, + Codecs: securecookie.CodecsFromPairs(keyPires...), + cache: &SqliteSessionCacher{db: db}, + } + + store.setup() + store.MaxAge(store.SessionOptions.MaxAge) + return &store +} + +/* + Session interface implementation +*/ + +func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, error) { + session := sessions.NewSession(s, name) + session.Options = s.SessionOptions + session.IsNew = true + + s.MaxAge(s.SessionOptions.MaxAge) + if persistedSession := s.tryGetPersistedSessionFromCookie(r, name); persistedSession != nil { + // decode persisted session data to session.Values if found + err := securecookie.DecodeMulti(session.Name(), persistedSession.Data, &session.Values, s.Codecs...) + if err != nil { + return session, err + } + + session.ID = persistedSession.ID + session.IsNew = false + return session, nil + } + + return session, nil +} + +func (s *SessionStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + // try get current session from cookie + persistedSession := s.tryGetPersistedSessionFromCookie(r, session.Name()) + + // delete session if MaxAge < 0 + if session.Options.MaxAge < 0 { + if persistedSession != nil { + // delete persisted session + if err := s.cache.Delete(persistedSession); err != nil { + return err + } + } + + // delete cookie + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + return nil + } + + // or cretae new session / update current session + data, err := securecookie.EncodeMulti(session.Name(), session.Values, s.Codecs...) + if err != nil { + return err + } + + if persistedSession == nil { + session.ID = strings.TrimRight( + base32.StdEncoding.EncodeToString( + securecookie.GenerateRandomKey(32)), "=") + + // create new session + persistedSession = &UserSession{ + ID: session.ID, + Data: data, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Duration(session.Options.MaxAge) * time.Second), + } + + if err := s.cache.Create(persistedSession); err != nil { + return err + } + } else { + // update current session + persistedSession.Data = data + persistedSession.UpdatedAt = time.Now() + persistedSession.ExpiresAt = time.Now().Add(time.Duration(session.Options.MaxAge) * time.Second) + + if err := s.cache.Update(persistedSession); err != nil { + return err + } + } + + // set cookie + encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...) + if err != nil { + return err + } + + http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) + return nil +} + +func (s *SessionStore) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(s, name) +} + +/* + Handy tools +*/ + +func (st *SessionStore) setup() { + st.cache.Setup() +} + +func (s *SessionStore) MaxAge(age int) { + s.SessionOptions.MaxAge = age + + // Set the maxAge for each securecookie instance. + for _, codec := range s.Codecs { + if sc, ok := codec.(*securecookie.SecureCookie); ok { + sc.MaxAge(age) + } + } +} + +func (s *SessionStore) tryGetPersistedSessionFromCookie(r *http.Request, name string) *UserSession { + // get cookie from request + cookie, err := r.Cookie(name) + if err != nil { + return nil + } + + // decode cookie value to session id + var sessionID string + err = securecookie.DecodeMulti(name, cookie.Value, &sessionID, s.Codecs...) + if err != nil { + return nil + } + + // after get session id, try get persisted session + return s.cache.Get(sessionID) +} + +/* + Below is gin-session wrapper +*/ + +type Store interface { + ginSession.Store +} + +type store struct { + *SessionStore +} + +func (s *store) Options(options ginSession.Options) { + s.SessionStore.SessionOptions = options.ToGorillaOptions() +} + +func NewStore(db *gorm.DB, keyPires ...[]byte) Store { + return &store{NewSessionStore(db, keyPires...)} +}