container.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. package restful
  2. // Copyright 2013 Ernest Micklei. All rights reserved.
  3. // Use of this source code is governed by a license
  4. // that can be found in the LICENSE file.
  5. import (
  6. "bytes"
  7. "errors"
  8. "fmt"
  9. "net/http"
  10. "os"
  11. "runtime"
  12. "strings"
  13. "sync"
  14. "github.com/emicklei/go-restful/log"
  15. )
  16. // Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
  17. // The requests are further dispatched to routes of WebServices using a RouteSelector
  18. type Container struct {
  19. webServicesLock sync.RWMutex
  20. webServices []*WebService
  21. ServeMux *http.ServeMux
  22. isRegisteredOnRoot bool
  23. containerFilters []FilterFunction
  24. doNotRecover bool // default is true
  25. recoverHandleFunc RecoverHandleFunction
  26. serviceErrorHandleFunc ServiceErrorHandleFunction
  27. router RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
  28. contentEncodingEnabled bool // default is false
  29. }
  30. // NewContainer creates a new Container using a new ServeMux and default router (RouterJSR311)
  31. func NewContainer() *Container {
  32. return &Container{
  33. webServices: []*WebService{},
  34. ServeMux: http.NewServeMux(),
  35. isRegisteredOnRoot: false,
  36. containerFilters: []FilterFunction{},
  37. doNotRecover: true,
  38. recoverHandleFunc: logStackOnRecover,
  39. serviceErrorHandleFunc: writeServiceError,
  40. router: CurlyRouter{},
  41. contentEncodingEnabled: false}
  42. }
  43. // RecoverHandleFunction declares functions that can be used to handle a panic situation.
  44. // The first argument is what recover() returns. The second must be used to communicate an error response.
  45. type RecoverHandleFunction func(interface{}, http.ResponseWriter)
  46. // RecoverHandler changes the default function (logStackOnRecover) to be called
  47. // when a panic is detected. DoNotRecover must be have its default value (=false).
  48. func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
  49. c.recoverHandleFunc = handler
  50. }
  51. // ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
  52. // The first argument is the service error, the second is the request that resulted in the error and
  53. // the third must be used to communicate an error response.
  54. type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)
  55. // ServiceErrorHandler changes the default function (writeServiceError) to be called
  56. // when a ServiceError is detected.
  57. func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
  58. c.serviceErrorHandleFunc = handler
  59. }
  60. // DoNotRecover controls whether panics will be caught to return HTTP 500.
  61. // If set to true, Route functions are responsible for handling any error situation.
  62. // Default value is true.
  63. func (c *Container) DoNotRecover(doNot bool) {
  64. c.doNotRecover = doNot
  65. }
  66. // Router changes the default Router (currently RouterJSR311)
  67. func (c *Container) Router(aRouter RouteSelector) {
  68. c.router = aRouter
  69. }
  70. // EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses.
  71. func (c *Container) EnableContentEncoding(enabled bool) {
  72. c.contentEncodingEnabled = enabled
  73. }
  74. // Add a WebService to the Container. It will detect duplicate root paths and exit in that case.
  75. func (c *Container) Add(service *WebService) *Container {
  76. c.webServicesLock.Lock()
  77. defer c.webServicesLock.Unlock()
  78. // if rootPath was not set then lazy initialize it
  79. if len(service.rootPath) == 0 {
  80. service.Path("/")
  81. }
  82. // cannot have duplicate root paths
  83. for _, each := range c.webServices {
  84. if each.RootPath() == service.RootPath() {
  85. log.Printf("[restful] WebService with duplicate root path detected:['%v']", each)
  86. os.Exit(1)
  87. }
  88. }
  89. // If not registered on root then add specific mapping
  90. if !c.isRegisteredOnRoot {
  91. c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
  92. }
  93. c.webServices = append(c.webServices, service)
  94. return c
  95. }
  96. // addHandler may set a new HandleFunc for the serveMux
  97. // this function must run inside the critical region protected by the webServicesLock.
  98. // returns true if the function was registered on root ("/")
  99. func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
  100. pattern := fixedPrefixPath(service.RootPath())
  101. // check if root path registration is needed
  102. if "/" == pattern || "" == pattern {
  103. serveMux.HandleFunc("/", c.dispatch)
  104. return true
  105. }
  106. // detect if registration already exists
  107. alreadyMapped := false
  108. for _, each := range c.webServices {
  109. if each.RootPath() == service.RootPath() {
  110. alreadyMapped = true
  111. break
  112. }
  113. }
  114. if !alreadyMapped {
  115. serveMux.HandleFunc(pattern, c.dispatch)
  116. if !strings.HasSuffix(pattern, "/") {
  117. serveMux.HandleFunc(pattern+"/", c.dispatch)
  118. }
  119. }
  120. return false
  121. }
  122. func (c *Container) Remove(ws *WebService) error {
  123. if c.ServeMux == http.DefaultServeMux {
  124. errMsg := fmt.Sprintf("[restful] cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
  125. log.Printf(errMsg)
  126. return errors.New(errMsg)
  127. }
  128. c.webServicesLock.Lock()
  129. defer c.webServicesLock.Unlock()
  130. // build a new ServeMux and re-register all WebServices
  131. newServeMux := http.NewServeMux()
  132. newServices := []*WebService{}
  133. newIsRegisteredOnRoot := false
  134. for _, each := range c.webServices {
  135. if each.rootPath != ws.rootPath {
  136. // If not registered on root then add specific mapping
  137. if !newIsRegisteredOnRoot {
  138. newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
  139. }
  140. newServices = append(newServices, each)
  141. }
  142. }
  143. c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
  144. return nil
  145. }
  146. // logStackOnRecover is the default RecoverHandleFunction and is called
  147. // when DoNotRecover is false and the recoverHandleFunc is not set for the container.
  148. // Default implementation logs the stacktrace and writes the stacktrace on the response.
  149. // This may be a security issue as it exposes sourcecode information.
  150. func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
  151. var buffer bytes.Buffer
  152. buffer.WriteString(fmt.Sprintf("[restful] recover from panic situation: - %v\r\n", panicReason))
  153. for i := 2; ; i += 1 {
  154. _, file, line, ok := runtime.Caller(i)
  155. if !ok {
  156. break
  157. }
  158. buffer.WriteString(fmt.Sprintf(" %s:%d\r\n", file, line))
  159. }
  160. log.Print(buffer.String())
  161. httpWriter.WriteHeader(http.StatusInternalServerError)
  162. httpWriter.Write(buffer.Bytes())
  163. }
  164. // writeServiceError is the default ServiceErrorHandleFunction and is called
  165. // when a ServiceError is returned during route selection. Default implementation
  166. // calls resp.WriteErrorString(err.Code, err.Message)
  167. func writeServiceError(err ServiceError, req *Request, resp *Response) {
  168. resp.WriteErrorString(err.Code, err.Message)
  169. }
  170. // Dispatch the incoming Http Request to a matching WebService.
  171. func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  172. writer := httpWriter
  173. // CompressingResponseWriter should be closed after all operations are done
  174. defer func() {
  175. if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
  176. compressWriter.Close()
  177. }
  178. }()
  179. // Instal panic recovery unless told otherwise
  180. if !c.doNotRecover { // catch all for 500 response
  181. defer func() {
  182. if r := recover(); r != nil {
  183. c.recoverHandleFunc(r, writer)
  184. return
  185. }
  186. }()
  187. }
  188. // Install closing the request body (if any)
  189. defer func() {
  190. if nil != httpRequest.Body {
  191. httpRequest.Body.Close()
  192. }
  193. }()
  194. // Detect if compression is needed
  195. // assume without compression, test for override
  196. if c.contentEncodingEnabled {
  197. doCompress, encoding := wantsCompressedResponse(httpRequest)
  198. if doCompress {
  199. var err error
  200. writer, err = NewCompressingResponseWriter(httpWriter, encoding)
  201. if err != nil {
  202. log.Print("[restful] unable to install compressor: ", err)
  203. httpWriter.WriteHeader(http.StatusInternalServerError)
  204. return
  205. }
  206. }
  207. }
  208. // Find best match Route ; err is non nil if no match was found
  209. var webService *WebService
  210. var route *Route
  211. var err error
  212. func() {
  213. c.webServicesLock.RLock()
  214. defer c.webServicesLock.RUnlock()
  215. webService, route, err = c.router.SelectRoute(
  216. c.webServices,
  217. httpRequest)
  218. }()
  219. if err != nil {
  220. // a non-200 response has already been written
  221. // run container filters anyway ; they should not touch the response...
  222. chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
  223. switch err.(type) {
  224. case ServiceError:
  225. ser := err.(ServiceError)
  226. c.serviceErrorHandleFunc(ser, req, resp)
  227. }
  228. // TODO
  229. }}
  230. chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
  231. return
  232. }
  233. wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest)
  234. // pass through filters (if any)
  235. if len(c.containerFilters)+len(webService.filters)+len(route.Filters) > 0 {
  236. // compose filter chain
  237. allFilters := []FilterFunction{}
  238. allFilters = append(allFilters, c.containerFilters...)
  239. allFilters = append(allFilters, webService.filters...)
  240. allFilters = append(allFilters, route.Filters...)
  241. chain := FilterChain{Filters: allFilters, Target: func(req *Request, resp *Response) {
  242. // handle request by route after passing all filters
  243. route.Function(wrappedRequest, wrappedResponse)
  244. }}
  245. chain.ProcessFilter(wrappedRequest, wrappedResponse)
  246. } else {
  247. // no filters, handle request by route
  248. route.Function(wrappedRequest, wrappedResponse)
  249. }
  250. }
  251. // fixedPrefixPath returns the fixed part of the partspec ; it may include template vars {}
  252. func fixedPrefixPath(pathspec string) string {
  253. varBegin := strings.Index(pathspec, "{")
  254. if -1 == varBegin {
  255. return pathspec
  256. }
  257. return pathspec[:varBegin]
  258. }
  259. // ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
  260. func (c *Container) ServeHTTP(httpwriter http.ResponseWriter, httpRequest *http.Request) {
  261. c.ServeMux.ServeHTTP(httpwriter, httpRequest)
  262. }
  263. // Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
  264. func (c *Container) Handle(pattern string, handler http.Handler) {
  265. c.ServeMux.Handle(pattern, handler)
  266. }
  267. // HandleWithFilter registers the handler for the given pattern.
  268. // Container's filter chain is applied for handler.
  269. // If a handler already exists for pattern, HandleWithFilter panics.
  270. func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
  271. f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
  272. if len(c.containerFilters) == 0 {
  273. handler.ServeHTTP(httpResponse, httpRequest)
  274. return
  275. }
  276. chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
  277. handler.ServeHTTP(httpResponse, httpRequest)
  278. }}
  279. chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
  280. }
  281. c.Handle(pattern, http.HandlerFunc(f))
  282. }
  283. // Filter appends a container FilterFunction. These are called before dispatching
  284. // a http.Request to a WebService from the container
  285. func (c *Container) Filter(filter FilterFunction) {
  286. c.containerFilters = append(c.containerFilters, filter)
  287. }
  288. // RegisteredWebServices returns the collections of added WebServices
  289. func (c *Container) RegisteredWebServices() []*WebService {
  290. c.webServicesLock.RLock()
  291. defer c.webServicesLock.RUnlock()
  292. result := make([]*WebService, len(c.webServices))
  293. for ix := range c.webServices {
  294. result[ix] = c.webServices[ix]
  295. }
  296. return result
  297. }
  298. // computeAllowedMethods returns a list of HTTP methods that are valid for a Request
  299. func (c *Container) computeAllowedMethods(req *Request) []string {
  300. // Go through all RegisteredWebServices() and all its Routes to collect the options
  301. methods := []string{}
  302. requestPath := req.Request.URL.Path
  303. for _, ws := range c.RegisteredWebServices() {
  304. matches := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
  305. if matches != nil {
  306. finalMatch := matches[len(matches)-1]
  307. for _, rt := range ws.Routes() {
  308. matches := rt.pathExpr.Matcher.FindStringSubmatch(finalMatch)
  309. if matches != nil {
  310. lastMatch := matches[len(matches)-1]
  311. if lastMatch == "" || lastMatch == "/" { // do not include if value is neither empty nor ‘/’.
  312. methods = append(methods, rt.Method)
  313. }
  314. }
  315. }
  316. }
  317. }
  318. // methods = append(methods, "OPTIONS") not sure about this
  319. return methods
  320. }
  321. // newBasicRequestResponse creates a pair of Request,Response from its http versions.
  322. // It is basic because no parameter or (produces) content-type information is given.
  323. func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
  324. resp := NewResponse(httpWriter)
  325. resp.requestAccept = httpRequest.Header.Get(HEADER_Accept)
  326. return NewRequest(httpRequest), resp
  327. }