lxg 4 년 전
커밋
93cd3940f8
5개의 변경된 파일383개의 추가작업 그리고 0개의 파일을 삭제
  1. 21 0
      LICENSE
  2. 91 0
      README.md
  3. 86 0
      config.go
  4. 100 0
      cors.go
  5. 85 0
      utils.go

+ 21 - 0
LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2016 Gin-Gonic
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 91 - 0
README.md

@@ -0,0 +1,91 @@
+# CORS gin's middleware
+
+[![Build Status](https://travis-ci.org/gin-contrib/cors.svg)](https://travis-ci.org/gin-contrib/cors)
+[![codecov](https://codecov.io/gh/gin-contrib/cors/branch/master/graph/badge.svg)](https://codecov.io/gh/gin-contrib/cors)
+[![Go Report Card](https://goreportcard.com/badge/github.com/gin-contrib/cors)](https://goreportcard.com/report/github.com/gin-contrib/cors)
+[![GoDoc](https://godoc.org/github.com/gin-contrib/cors?status.svg)](https://godoc.org/github.com/gin-contrib/cors)
+[![Join the chat at https://gitter.im/gin-gonic/gin](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/gin-gonic/gin)
+
+Gin middleware/handler to enable CORS support.
+
+## Usage
+
+### Start using it
+
+Download and install it:
+
+```sh
+$ go get github.com/gin-contrib/cors
+```
+
+Import it in your code:
+
+```go
+import "github.com/gin-contrib/cors"
+```
+
+### Canonical example:
+
+```go
+package main
+
+import (
+	"time"
+
+	"github.com/gin-contrib/cors"
+	"github.com/gin-gonic/gin"
+)
+
+func main() {
+	router := gin.Default()
+	// CORS for https://foo.com and https://github.com origins, allowing:
+	// - PUT and PATCH methods
+	// - Origin header
+	// - Credentials share
+	// - Preflight requests cached for 12 hours
+	router.Use(cors.New(cors.Config{
+		AllowOrigins:     []string{"https://foo.com"},
+		AllowMethods:     []string{"PUT", "PATCH"},
+		AllowHeaders:     []string{"Origin"},
+		ExposeHeaders:    []string{"Content-Length"},
+		AllowCredentials: true,
+		AllowOriginFunc: func(origin string) bool {
+			return origin == "https://github.com"
+		},
+		MaxAge: 12 * time.Hour,
+	}))
+	router.Run()
+}
+```
+
+### Using DefaultConfig as start point
+
+```go
+func main() {
+	router := gin.Default()
+	// - No origin allowed by default
+	// - GET,POST, PUT, HEAD methods
+	// - Credentials share disabled
+	// - Preflight requests cached for 12 hours
+	config := cors.DefaultConfig()
+	config.AllowOrigins = []string{"http://google.com"}
+	// config.AllowOrigins == []string{"http://google.com", "http://facebook.com"}
+
+	router.Use(cors.New(config))
+	router.Run()
+}
+```
+
+### Default() allows all origins
+
+```go
+func main() {
+	router := gin.Default()
+	// same as
+	// config := cors.DefaultConfig()
+	// config.AllowAllOrigins = true
+	// router.Use(cors.New(config))
+	router.Use(cors.Default())
+	router.Run()
+}
+```

+ 86 - 0
config.go

@@ -0,0 +1,86 @@
+package cors
+
+import (
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+)
+
+type cors struct {
+	allowAllOrigins  bool
+	allowCredentials bool
+	allowOriginFunc  func(string) bool
+	allowOrigins     []string
+	exposeHeaders    []string
+	normalHeaders    http.Header
+	preflightHeaders http.Header
+}
+
+func newCors(config Config) *cors {
+	if err := config.Validate(); err != nil {
+		panic(err.Error())
+	}
+	return &cors{
+		allowOriginFunc:  config.AllowOriginFunc,
+		allowAllOrigins:  config.AllowAllOrigins,
+		allowCredentials: config.AllowCredentials,
+		allowOrigins:     normalize(config.AllowOrigins),
+		normalHeaders:    generateNormalHeaders(config),
+		preflightHeaders: generatePreflightHeaders(config),
+	}
+}
+
+func (cors *cors) applyCors(c *gin.Context) {
+	origin := c.Request.Header.Get("Origin")
+	if origin == "" {
+		origin = "*"
+	}
+	// if len(origin) == 0 {
+	// 	// request is not a CORS request
+	// 	return
+	// }
+	if !cors.validateOrigin(origin) {
+		c.AbortWithStatus(http.StatusForbidden)
+		return
+	}
+
+	if c.Request.Method == "OPTIONS" {
+		cors.handlePreflight(c)
+		defer c.AbortWithStatus(200)
+	} else {
+		cors.handleNormal(c)
+	}
+
+	if !cors.allowAllOrigins {
+		c.Header("Access-Control-Allow-Origin", origin)
+	}
+}
+
+func (cors *cors) validateOrigin(origin string) bool {
+	if cors.allowAllOrigins {
+		return true
+	}
+	for _, value := range cors.allowOrigins {
+		if value == origin {
+			return true
+		}
+	}
+	if cors.allowOriginFunc != nil {
+		return cors.allowOriginFunc(origin)
+	}
+	return false
+}
+
+func (cors *cors) handlePreflight(c *gin.Context) {
+	header := c.Writer.Header()
+	for key, value := range cors.preflightHeaders {
+		header[key] = value
+	}
+}
+
+func (cors *cors) handleNormal(c *gin.Context) {
+	header := c.Writer.Header()
+	for key, value := range cors.normalHeaders {
+		header[key] = value
+	}
+}

+ 100 - 0
cors.go

@@ -0,0 +1,100 @@
+package cors
+
+import (
+	"errors"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+// Config represents all available options for the middleware.
+type Config struct {
+	AllowAllOrigins bool
+
+	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
+	// If the special "*" value is present in the list, all origins will be allowed.
+	// Default value is []
+	AllowOrigins []string
+
+	// AllowOriginFunc is a custom function to validate the origin. It take the origin
+	// as argument and returns true if allowed or false otherwise. If this option is
+	// set, the content of AllowedOrigins is ignored.
+	AllowOriginFunc func(origin string) bool
+
+	// AllowedMethods is a list of methods the client is allowed to use with
+	// cross-domain requests. Default value is simple methods (GET and POST)
+	AllowMethods []string
+
+	// AllowedHeaders is list of non simple headers the client is allowed to use with
+	// cross-domain requests.
+	AllowHeaders []string
+
+	// AllowCredentials indicates whether the request can include user credentials like
+	// cookies, HTTP authentication or client side SSL certificates.
+	AllowCredentials bool
+
+	// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
+	// API specification
+	ExposeHeaders []string
+
+	// MaxAge indicates how long (in seconds) the results of a preflight request
+	// can be cached
+	MaxAge time.Duration
+}
+
+// AddAllowMethods is allowed to add custom methods
+func (c *Config) AddAllowMethods(methods ...string) {
+	c.AllowMethods = append(c.AllowMethods, methods...)
+}
+
+// AddAllowHeaders is allowed to add custom headers
+func (c *Config) AddAllowHeaders(headers ...string) {
+	c.AllowHeaders = append(c.AllowHeaders, headers...)
+}
+
+// AddExposeHeaders is allowed to add custom expose headers
+func (c *Config) AddExposeHeaders(headers ...string) {
+	c.ExposeHeaders = append(c.ExposeHeaders, headers...)
+}
+
+// Validate is check configuration of user defined.
+func (c Config) Validate() error {
+	if c.AllowAllOrigins && (c.AllowOriginFunc != nil || len(c.AllowOrigins) > 0) {
+		return errors.New("conflict settings: all origins are allowed. AllowOriginFunc or AllowedOrigins is not needed")
+	}
+	if !c.AllowAllOrigins && c.AllowOriginFunc == nil && len(c.AllowOrigins) == 0 {
+		return errors.New("conflict settings: all origins disabled")
+	}
+	for _, origin := range c.AllowOrigins {
+		if origin != "*" && !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
+			return errors.New("bad origin: origins must either be '*' or include http:// or https://")
+		}
+	}
+	return nil
+}
+
+// DefaultConfig returns a generic default configuration mapped to localhost.
+func DefaultConfig() Config {
+	return Config{
+		AllowMethods:     []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"},
+		AllowHeaders:     []string{"Origin", "Content-Length", "Content-Type"},
+		AllowCredentials: true,
+		AllowOriginFunc:  func(origin string) bool { return true },
+		MaxAge:           12 * time.Hour,
+	}
+}
+
+// Default returns the location middleware with default configuration.
+func Default() gin.HandlerFunc {
+	config := DefaultConfig()
+	return New(config)
+}
+
+// New returns the location middleware with user-defined custom configuration.
+func New(config Config) gin.HandlerFunc {
+	cors := newCors(config)
+	return func(c *gin.Context) {
+		cors.applyCors(c)
+	}
+}

+ 85 - 0
utils.go

@@ -0,0 +1,85 @@
+package cors
+
+import (
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+)
+
+type converter func(string) string
+
+func generateNormalHeaders(c Config) http.Header {
+	headers := make(http.Header)
+	if c.AllowCredentials {
+		headers.Set("Access-Control-Allow-Credentials", "true")
+	}
+	if len(c.ExposeHeaders) > 0 {
+		exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey)
+		headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ","))
+	}
+	if c.AllowAllOrigins {
+		headers.Set("Access-Control-Allow-Origin", "*")
+	} else {
+		headers.Set("Vary", "Origin")
+	}
+	return headers
+}
+
+func generatePreflightHeaders(c Config) http.Header {
+	headers := make(http.Header)
+	if c.AllowCredentials {
+		headers.Set("Access-Control-Allow-Credentials", "true")
+	}
+	if len(c.AllowMethods) > 0 {
+		allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper)
+		value := strings.Join(allowMethods, ",")
+		headers.Set("Access-Control-Allow-Methods", value)
+	}
+	if len(c.AllowHeaders) > 0 {
+		allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey)
+		value := strings.Join(allowHeaders, ",")
+		headers.Set("Access-Control-Allow-Headers", value)
+	}
+	if c.MaxAge > time.Duration(0) {
+		value := strconv.FormatInt(int64(c.MaxAge/time.Second), 10)
+		headers.Set("Access-Control-Max-Age", value)
+	}
+	if c.AllowAllOrigins {
+		headers.Set("Access-Control-Allow-Origin", "*")
+	} else {
+		// Always set Vary headers
+		// see https://github.com/rs/cors/issues/10,
+		// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
+
+		headers.Add("Vary", "Origin")
+		headers.Add("Vary", "Access-Control-Request-Method")
+		headers.Add("Vary", "Access-Control-Request-Headers")
+	}
+	return headers
+}
+
+func normalize(values []string) []string {
+	if values == nil {
+		return nil
+	}
+	distinctMap := make(map[string]bool, len(values))
+	normalized := make([]string, 0, len(values))
+	for _, value := range values {
+		value = strings.TrimSpace(value)
+		value = strings.ToLower(value)
+		if _, seen := distinctMap[value]; !seen {
+			normalized = append(normalized, value)
+			distinctMap[value] = true
+		}
+	}
+	return normalized
+}
+
+func convert(s []string, c converter) []string {
+	var out []string
+	for _, i := range s {
+		out = append(out, c(i))
+	}
+	return out
+}