sessions.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package sessions
  2. import (
  3. "log"
  4. "net/http"
  5. "regexp"
  6. "strings"
  7. "github.com/gin-gonic/gin"
  8. "github.com/gorilla/context"
  9. "github.com/gorilla/sessions"
  10. )
  11. const (
  12. DefaultKey = "github.com/penggy/sessions"
  13. errorFormat = "[sessions] ERROR! %s\n"
  14. defaultMaxAge = 60 * 60 * 24 * 30 // 30 days
  15. defaultPath = "/"
  16. )
  17. type Store interface {
  18. sessions.Store
  19. RenewID(r *http.Request, w http.ResponseWriter, gsession *sessions.Session) error
  20. Options(Options)
  21. }
  22. // Options stores configuration for a session or session store.
  23. // Fields are a subset of http.Cookie fields.
  24. type Options struct {
  25. Path string
  26. Domain string
  27. // MaxAge=0 means no 'Max-Age' attribute specified.
  28. // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
  29. // MaxAge>0 means Max-Age attribute present and given in seconds.
  30. MaxAge int
  31. Secure bool
  32. HttpOnly bool
  33. }
  34. // Wraps thinly gorilla-session methods.
  35. // Session stores the values and optional configuration for a session.
  36. type Session interface {
  37. // Get returns the session value associated to the given key.
  38. Get(key interface{}) interface{}
  39. // Set sets the session value associated to the given key.
  40. Set(key interface{}, val interface{})
  41. // Delete removes the session value associated to the given key.
  42. Delete(key interface{})
  43. // Clear deletes all values in the session.
  44. Clear()
  45. // AddFlash adds a flash message to the session.
  46. // A single variadic argument is accepted, and it is optional: it defines the flash key.
  47. // If not defined "_flash" is used by default.
  48. AddFlash(value interface{}, vars ...string)
  49. // Flashes returns a slice of flash messages from the session.
  50. // A single variadic argument is accepted, and it is optional: it defines the flash key.
  51. // If not defined "_flash" is used by default.
  52. Flashes(vars ...string) []interface{}
  53. // Options sets confuguration for a session.
  54. Options(Options)
  55. // Save saves all sessions used during the current request.
  56. Save() error
  57. RenewID() (string, error)
  58. ID() string
  59. SetMaxAge(maxAge int)
  60. Destroy()
  61. }
  62. func Sessions(name string, store Store) gin.HandlerFunc {
  63. return func(c *gin.Context) {
  64. s := &session{name, c.Request, store, nil, false, c.Writer, false}
  65. c.Set(DefaultKey, s)
  66. defer context.Clear(c.Request)
  67. defer s.Save()
  68. http.SetCookie(s.writer, sessions.NewCookie(s.name, s.ID(), s.Session().Options))
  69. c.Next()
  70. }
  71. }
  72. func GorillaSessions(name string, store Store) gin.HandlerFunc {
  73. return func(c *gin.Context) {
  74. s := &session{name, c.Request, store, nil, false, c.Writer, true}
  75. c.Set(DefaultKey, s)
  76. defer context.Clear(c.Request)
  77. c.Next()
  78. }
  79. }
// session is the concrete Session implementation installed by the middleware.
// The underlying gorilla session is loaded lazily from store on first use.
// NOTE(review): keep the field order stable — any positional composite
// literal of this type depends on it.
type session struct {
	name    string              // session name passed to store.Get; also used as the cookie name
	request *http.Request       // current request, used to load and save the session
	store   Store               // backing store
	session *sessions.Session   // cached gorilla session; nil until Session() loads it
	written bool                // true when there are unsaved changes; Save writes only if set
	writer  http.ResponseWriter // response that cookies are written to
	gorilla bool                // gorilla mode: SetMaxAge marks the session for Save instead of writing a cookie immediately
}
  89. func (s *session) Get(key interface{}) interface{} {
  90. return s.Session().Values[key]
  91. }
  92. func (s *session) Set(key interface{}, val interface{}) {
  93. s.Session().Values[key] = val
  94. s.written = true
  95. }
  96. func (s *session) Delete(key interface{}) {
  97. delete(s.Session().Values, key)
  98. s.written = true
  99. }
  100. func (s *session) Clear() {
  101. for key := range s.Session().Values {
  102. delete(s.Session().Values, key)
  103. }
  104. s.written = true
  105. }
  106. func (s *session) AddFlash(value interface{}, vars ...string) {
  107. s.Session().AddFlash(value, vars...)
  108. }
  109. func (s *session) Flashes(vars ...string) []interface{} {
  110. return s.Session().Flashes(vars...)
  111. }
  112. func (s *session) Options(options Options) {
  113. s.Session().Options = &sessions.Options{
  114. Path: options.Path,
  115. Domain: options.Domain,
  116. MaxAge: options.MaxAge,
  117. Secure: options.Secure,
  118. HttpOnly: options.HttpOnly,
  119. }
  120. }
  121. func (s *session) Save() error {
  122. if s.Written() {
  123. e := s.Session().Save(s.request, s.writer)
  124. if e == nil {
  125. s.written = false
  126. }
  127. return e
  128. }
  129. return nil
  130. }
  131. func (s *session) RenewID() (string, error) {
  132. e := s.store.RenewID(s.request, s.writer, s.Session())
  133. return s.ID(), e
  134. }
  135. func (s *session) ID() string {
  136. return s.Session().ID
  137. }
  138. func (s *session) SetMaxAge(maxAge int) {
  139. s.Session().Options.MaxAge = maxAge
  140. if s.gorilla {
  141. s.written = true
  142. } else {
  143. http.SetCookie(s.writer, sessions.NewCookie(s.name, s.Session().ID, s.Session().Options))
  144. }
  145. }
  146. func (s *session) Destroy() {
  147. s.SetMaxAge(-1)
  148. s.Clear()
  149. }
  150. func (s *session) Written() bool {
  151. return s.written
  152. }
  153. func (s *session) Session() *sessions.Session {
  154. if s.session == nil {
  155. var err error
  156. s.session, err = s.store.Get(s.request, s.name)
  157. if err != nil {
  158. log.Printf(errorFormat, err)
  159. }
  160. }
  161. return s.session
  162. }
  163. func (s *session) XHR() bool {
  164. if strings.EqualFold(s.request.Header.Get("x-requested-with"), "XMLHttpRequest") {
  165. return true
  166. }
  167. return regexp.MustCompile("\\/json$").MatchString(s.request.Header.Get("accept"))
  168. }
  169. // shortcut to get session
  170. func Default(c *gin.Context) Session {
  171. return c.MustGet(DefaultKey).(Session)
  172. }