config.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package cors
  2. import (
  3. "net/http"
  4. "github.com/gin-gonic/gin"
  5. )
  // cors holds the resolved, per-middleware CORS policy built from a Config.
// Field order is significant for any unkeyed composite literal, so it matches
// the original declaration order exactly.
type cors struct {
	// allowAllOrigins skips origin matching entirely; allowCredentials mirrors
	// Config.AllowCredentials.
	allowAllOrigins, allowCredentials bool
	// allowOriginFunc, when non-nil, is consulted for origins that are not in
	// allowOrigins.
	allowOriginFunc func(string) bool
	// allowOrigins is the exact-match allow list; exposeHeaders is not
	// populated by newCors (NOTE(review): confirm whether it is still used).
	allowOrigins, exposeHeaders []string
	// normalHeaders / preflightHeaders are copied onto responses for
	// non-OPTIONS and OPTIONS requests respectively.
	normalHeaders, preflightHeaders http.Header
}
  15. func newCors(config Config) *cors {
  16. if err := config.Validate(); err != nil {
  17. panic(err.Error())
  18. }
  19. return &cors{
  20. allowOriginFunc: config.AllowOriginFunc,
  21. allowAllOrigins: config.AllowAllOrigins,
  22. allowCredentials: config.AllowCredentials,
  23. allowOrigins: normalize(config.AllowOrigins),
  24. normalHeaders: generateNormalHeaders(config),
  25. preflightHeaders: generatePreflightHeaders(config),
  26. }
  27. }
  28. func (cors *cors) applyCors(c *gin.Context) {
  29. origin := c.Request.Header.Get("Origin")
  30. if origin == "" {
  31. origin = "*"
  32. }
  33. // if len(origin) == 0 {
  34. // // request is not a CORS request
  35. // return
  36. // }
  37. if !cors.validateOrigin(origin) {
  38. c.AbortWithStatus(http.StatusForbidden)
  39. return
  40. }
  41. if c.Request.Method == "OPTIONS" {
  42. cors.handlePreflight(c)
  43. defer c.AbortWithStatus(200)
  44. } else {
  45. cors.handleNormal(c)
  46. }
  47. if !cors.allowAllOrigins {
  48. c.Header("Access-Control-Allow-Origin", origin)
  49. }
  50. }
  51. func (cors *cors) validateOrigin(origin string) bool {
  52. if cors.allowAllOrigins {
  53. return true
  54. }
  55. for _, value := range cors.allowOrigins {
  56. if value == origin {
  57. return true
  58. }
  59. }
  60. if cors.allowOriginFunc != nil {
  61. return cors.allowOriginFunc(origin)
  62. }
  63. return false
  64. }
  65. func (cors *cors) handlePreflight(c *gin.Context) {
  66. header := c.Writer.Header()
  67. for key, value := range cors.preflightHeaders {
  68. header[key] = value
  69. }
  70. }
  71. func (cors *cors) handleNormal(c *gin.Context) {
  72. header := c.Writer.Header()
  73. for key, value := range cors.normalHeaders {
  74. header[key] = value
  75. }
  76. }