gormstore.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package sessions
  2. import (
  3. "net/http"
  4. "time"
  5. "github.com/gorilla/context"
  6. "github.com/gorilla/securecookie"
  7. gsessions "github.com/gorilla/sessions"
  8. "github.com/jinzhu/gorm"
  9. "github.com/teris-io/shortid"
  10. )
  // GormStoreOptions configures a GormStore built by NewGormStoreWithOptions.
type GormStoreOptions struct {
	// TableName is the table that holds session rows. When left empty,
	// NewGormStoreWithOptions uses "t_sessions".
	TableName string

	// SkipCreateTable, when true, stops NewGormStoreWithOptions from
	// auto-migrating the session table.
	SkipCreateTable bool
}
  16. // Store represent a gormstore
  17. type GormStore struct {
  18. db *gorm.DB
  19. opts GormStoreOptions
  20. Codecs []securecookie.Codec
  21. SessionOpts *gsessions.Options
  22. }
  // gormSession is one row of the session table: the encoded session values
// plus bookkeeping timestamps.
type gormSession struct {
	// ID is the session identifier; New looks rows up by it.
	ID string `sql:"unique_index"`
	// Data is the securecookie-encoded form of the session values.
	Data string `sql:"type:text"`

	CreatedAt time.Time
	UpdatedAt time.Time
	// ExpiresAt is indexed because New and Cleanup both filter on it.
	ExpiresAt time.Time `sql:"index"`

	// tableName is not persisted; TableName reports it so queries target the
	// configured table without an explicit db.Table(...) call.
	tableName string `sql:"-"`
}
  // contextKey is the key type for values this store keeps in the request
// context. Because it is a distinct named type, its keys never compare
// equal to keys of other types, such as plain strings.
type contextKey string
  33. func (gs *gormSession) TableName() string {
  34. return gs.tableName
  35. }
  36. // New creates a new gormstore session
  37. func NewGormStore(db *gorm.DB, keyPairs ...[]byte) *GormStore {
  38. return NewGormStoreWithOptions(db, GormStoreOptions{}, keyPairs...)
  39. }
  40. // NewOptions creates a new gormstore session with options
  41. func NewGormStoreWithOptions(db *gorm.DB, opts GormStoreOptions, keyPairs ...[]byte) *GormStore {
  42. st := &GormStore{
  43. db: db,
  44. opts: opts,
  45. Codecs: securecookie.CodecsFromPairs(keyPairs...),
  46. SessionOpts: &gsessions.Options{
  47. Path: defaultPath,
  48. MaxAge: defaultMaxAge,
  49. },
  50. }
  51. if st.opts.TableName == "" {
  52. st.opts.TableName = "t_sessions"
  53. }
  54. if !st.opts.SkipCreateTable {
  55. st.db.AutoMigrate(&gormSession{tableName: st.opts.TableName})
  56. }
  57. st.Cleanup()
  58. return st
  59. }
  60. // Get returns a session for the given name after adding it to the registry.
  61. func (st *GormStore) Get(r *http.Request, name string) (*gsessions.Session, error) {
  62. return gsessions.GetRegistry(r).Get(st, name)
  63. }
  64. // New creates a session with name without adding it to the registry.
  65. func (st *GormStore) New(r *http.Request, name string) (*gsessions.Session, error) {
  66. session := gsessions.NewSession(st, name)
  67. opts := *st.SessionOpts
  68. session.Options = &opts
  69. session.IsNew = true
  70. st.MaxAge(st.SessionOpts.MaxAge)
  71. // try fetch from db if there is a cookie
  72. if cookie, err := r.Cookie(name); err == nil {
  73. session.ID = cookie.Value
  74. s := &gormSession{tableName: st.opts.TableName}
  75. if err := st.db.Where("id = ? AND expires_at > ?", session.ID, gorm.NowFunc()).First(s).Error; err != nil {
  76. return session, nil
  77. }
  78. if err := securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, st.Codecs...); err != nil {
  79. return session, nil
  80. }
  81. session.IsNew = false
  82. context.Set(r, contextKey(name), s)
  83. } else {
  84. session.ID = shortid.MustGenerate()
  85. }
  86. return session, nil
  87. }
  88. func (st *GormStore) RenewID(r *http.Request, w http.ResponseWriter, session *gsessions.Session) error {
  89. _id := session.ID
  90. session.ID = shortid.MustGenerate()
  91. st.db.Exec("UPDATE "+st.opts.TableName+" SET id=? WHERE id=?", session.ID, _id)
  92. http.SetCookie(w, gsessions.NewCookie(session.Name(), session.ID, session.Options))
  93. return nil
  94. }
  95. // Save session and set cookie header
  96. func (st *GormStore) Save(r *http.Request, w http.ResponseWriter, session *gsessions.Session) error {
  97. s, _ := context.Get(r, contextKey(session.Name())).(*gormSession)
  98. // delete if max age is < 0
  99. if session.Options.MaxAge < 0 || len(session.Values) == 0 {
  100. if s != nil {
  101. if err := st.db.Delete(s).Error; err != nil {
  102. return err
  103. }
  104. }
  105. return nil
  106. }
  107. data, err := securecookie.EncodeMulti(session.Name(), session.Values, st.Codecs...)
  108. if err != nil {
  109. return err
  110. }
  111. now := time.Now()
  112. expire := now.Add(time.Second * time.Duration(session.Options.MaxAge))
  113. if s == nil {
  114. // generate random session ID key suitable for storage in the db
  115. if session.ID == "" {
  116. session.ID = shortid.MustGenerate()
  117. }
  118. s = &gormSession{
  119. ID: session.ID,
  120. Data: data,
  121. CreatedAt: now,
  122. UpdatedAt: now,
  123. ExpiresAt: expire,
  124. tableName: st.opts.TableName,
  125. }
  126. if err := st.db.Create(s).Error; err != nil {
  127. return err
  128. }
  129. context.Set(r, contextKey(session.Name()), s)
  130. } else {
  131. s.Data = data
  132. s.UpdatedAt = now
  133. s.ExpiresAt = expire
  134. if err := st.db.Save(s).Error; err != nil {
  135. return err
  136. }
  137. }
  138. return nil
  139. }
  140. // MaxAge sets the maximum age for the store and the underlying cookie
  141. // implementation. Individual sessions can be deleted by setting
  142. // Options.MaxAge = -1 for that session.
  143. func (st *GormStore) MaxAge(age int) {
  144. st.SessionOpts.MaxAge = age
  145. for _, codec := range st.Codecs {
  146. if sc, ok := codec.(*securecookie.SecureCookie); ok {
  147. sc.MaxAge(age)
  148. }
  149. }
  150. }
  151. func (st *GormStore) Options(options Options) {
  152. st.SessionOpts = &gsessions.Options{
  153. Path: options.Path,
  154. Domain: options.Domain,
  155. MaxAge: options.MaxAge,
  156. Secure: options.Secure,
  157. HttpOnly: options.HttpOnly,
  158. }
  159. }
  160. // MaxLength restricts the maximum length of new sessions to l.
  161. // If l is 0 there is no limit to the size of a session, use with caution.
  162. // The default is 4096 (default for securecookie)
  163. func (st *GormStore) MaxLength(l int) {
  164. for _, c := range st.Codecs {
  165. if codec, ok := c.(*securecookie.SecureCookie); ok {
  166. codec.MaxLength(l)
  167. }
  168. }
  169. }
  170. // Cleanup deletes expired sessions
  171. func (st *GormStore) Cleanup() {
  172. st.db.Delete(&gormSession{tableName: st.opts.TableName}, "expires_at <= ?", gorm.NowFunc())
  173. time.AfterFunc(15*time.Second, func() {
  174. st.Cleanup()
  175. })
  176. }