swagger_webservice.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. package swagger
  2. import (
  3. "fmt"
  4. "github.com/emicklei/go-restful"
  5. // "github.com/emicklei/hopwatch"
  6. "net/http"
  7. "reflect"
  8. "sort"
  9. "strings"
  10. "github.com/emicklei/go-restful/log"
  11. )
  12. type SwaggerService struct {
  13. config Config
  14. apiDeclarationMap *ApiDeclarationList
  15. }
  16. func newSwaggerService(config Config) *SwaggerService {
  17. sws := &SwaggerService{
  18. config: config,
  19. apiDeclarationMap: new(ApiDeclarationList)}
  20. // Build all ApiDeclarations
  21. for _, each := range config.WebServices {
  22. rootPath := each.RootPath()
  23. // skip the api service itself
  24. if rootPath != config.ApiPath {
  25. if rootPath == "" || rootPath == "/" {
  26. // use routes
  27. for _, route := range each.Routes() {
  28. entry := staticPathFromRoute(route)
  29. _, exists := sws.apiDeclarationMap.At(entry)
  30. if !exists {
  31. sws.apiDeclarationMap.Put(entry, sws.composeDeclaration(each, entry))
  32. }
  33. }
  34. } else { // use root path
  35. sws.apiDeclarationMap.Put(each.RootPath(), sws.composeDeclaration(each, each.RootPath()))
  36. }
  37. }
  38. }
  39. // if specified then call the PostBuilderHandler
  40. if config.PostBuildHandler != nil {
  41. config.PostBuildHandler(sws.apiDeclarationMap)
  42. }
  43. return sws
  44. }
  45. // LogInfo is the function that is called when this package needs to log. It defaults to log.Printf
  46. var LogInfo = func(format string, v ...interface{}) {
  47. // use the restful package-wide logger
  48. log.Printf(format, v...)
  49. }
  50. // InstallSwaggerService add the WebService that provides the API documentation of all services
  51. // conform the Swagger documentation specifcation. (https://github.com/wordnik/swagger-core/wiki).
  52. func InstallSwaggerService(aSwaggerConfig Config) {
  53. RegisterSwaggerService(aSwaggerConfig, restful.DefaultContainer)
  54. }
  // RegisterSwaggerService adds the WebService that serves the API documentation
  // of all configured services to wsContainer. The documentation conforms to the
  // Swagger specification (https://github.com/wordnik/swagger-core/wiki).
  //
  // The resource listing is served at config.ApiPath. Each ApiDeclaration is
  // served below it, for resource paths of 1 up to 7 segments.
  //
  // Unless config.DisableCORS is set, documentation responses echo the request's
  // Origin in Access-Control-Allow-Origin.
  //
  // When config.SwaggerPath is set, the UI is served there, corrected to end
  // with "/". It uses config.StaticHandler, or, when that is nil, the folder
  // config.SwaggerFilePath. Otherwise no UI is served.
  func RegisterSwaggerService(config Config, wsContainer *restful.Container) {
  	sws := newSwaggerService(config)
  	ws := new(restful.WebService)
  	ws.Path(config.ApiPath)
  	ws.Produces(restful.MIME_JSON)
  	// CORS is on by default; DisableCORS is the opt-out. The check used to be
  	// inverted, enabling CORS only when it had been disabled.
  	if !config.DisableCORS {
  		ws.Filter(enableCORS)
  	}
  	ws.Route(ws.GET("/").To(sws.getListing))
  	// one route per supported resource path depth; see composeRootPath
  	ws.Route(ws.GET("/{a}").To(sws.getDeclarations))
  	ws.Route(ws.GET("/{a}/{b}").To(sws.getDeclarations))
  	ws.Route(ws.GET("/{a}/{b}/{c}").To(sws.getDeclarations))
  	ws.Route(ws.GET("/{a}/{b}/{c}/{d}").To(sws.getDeclarations))
  	ws.Route(ws.GET("/{a}/{b}/{c}/{d}/{e}").To(sws.getDeclarations))
  	ws.Route(ws.GET("/{a}/{b}/{c}/{d}/{e}/{f}").To(sws.getDeclarations))
  	ws.Route(ws.GET("/{a}/{b}/{c}/{d}/{e}/{f}/{g}").To(sws.getDeclarations))
  	LogInfo("[restful/swagger] listing is available at %v%v", config.WebServicesUrl, config.ApiPath)
  	wsContainer.Add(ws)

  	// withSlash returns config.SwaggerPath guaranteed to end with "/"; the UI
  	// path must end with a slash. It is only called when a UI is mounted, so
  	// the correction is only logged in that case.
  	withSlash := func() string {
  		if strings.HasSuffix(config.SwaggerPath, "/") {
  			return config.SwaggerPath
  		}
  		LogInfo("[restful/swagger] use corrected SwaggerPath ; must end with slash (/)")
  		return config.SwaggerPath + "/"
  	}

  	switch {
  	case config.StaticHandler == nil && config.SwaggerFilePath != "" && config.SwaggerPath != "":
  		swaggerPathSlash := withSlash()
  		LogInfo("[restful/swagger] %v%v is mapped to folder %v", config.WebServicesUrl, swaggerPathSlash, config.SwaggerFilePath)
  		wsContainer.Handle(swaggerPathSlash, http.StripPrefix(swaggerPathSlash, http.FileServer(http.Dir(config.SwaggerFilePath))))
  	case config.StaticHandler != nil && config.SwaggerPath != "":
  		// a custom static handler takes precedence over SwaggerFilePath
  		swaggerPathSlash := withSlash()
  		LogInfo("[restful/swagger] %v%v is mapped to custom Handler %T", config.WebServicesUrl, swaggerPathSlash, config.StaticHandler)
  		wsContainer.Handle(swaggerPathSlash, config.StaticHandler)
  	default:
  		LogInfo("[restful/swagger] Swagger(File)Path is empty ; no UI is served")
  	}
  }
  // staticPathFromRoute returns the part of the route's path that precedes its
  // first "{", with one trailing "/" removed.
  //
  // The path is returned unchanged in two cases: when it contains no "{", and
  // when the "{" sits at index 0 or 1. In the latter case the static part would
  // be empty or just "/".
  func staticPathFromRoute(r restful.Route) string {
  	cut := strings.Index(r.Path, "{")
  	if cut < 2 {
  		return r.Path
  	}
  	return strings.TrimSuffix(r.Path[:cut], "/")
  }
  114. func enableCORS(req *restful.Request, resp *restful.Response, chain *restful.FilterChain) {
  115. if origin := req.HeaderParameter(restful.HEADER_Origin); origin != "" {
  116. // prevent duplicate header
  117. if len(resp.Header().Get(restful.HEADER_AccessControlAllowOrigin)) == 0 {
  118. resp.AddHeader(restful.HEADER_AccessControlAllowOrigin, origin)
  119. }
  120. }
  121. chain.ProcessFilter(req, resp)
  122. }
  123. func (sws SwaggerService) getListing(req *restful.Request, resp *restful.Response) {
  124. listing := sws.produceListing()
  125. resp.WriteAsJson(listing)
  126. }
  127. func (sws SwaggerService) produceListing() ResourceListing {
  128. listing := ResourceListing{SwaggerVersion: swaggerVersion, ApiVersion: sws.config.ApiVersion, Info: sws.config.Info}
  129. sws.apiDeclarationMap.Do(func(k string, v ApiDeclaration) {
  130. ref := Resource{Path: k}
  131. if len(v.Apis) > 0 { // use description of first (could still be empty)
  132. ref.Description = v.Apis[0].Description
  133. }
  134. listing.Apis = append(listing.Apis, ref)
  135. })
  136. return listing
  137. }
  138. func (sws SwaggerService) getDeclarations(req *restful.Request, resp *restful.Response) {
  139. decl, ok := sws.produceDeclarations(composeRootPath(req))
  140. if !ok {
  141. resp.WriteErrorString(http.StatusNotFound, "ApiDeclaration not found")
  142. return
  143. }
  144. // unless WebServicesUrl is given
  145. if len(sws.config.WebServicesUrl) == 0 {
  146. // update base path from the actual request
  147. // TODO how to detect https? assume http for now
  148. var host string
  149. // X-Forwarded-Host or Host or Request.Host
  150. hostvalues, ok := req.Request.Header["X-Forwarded-Host"] // apache specific?
  151. if !ok || len(hostvalues) == 0 {
  152. forwarded, ok := req.Request.Header["Host"] // without reverse-proxy
  153. if !ok || len(forwarded) == 0 {
  154. // fallback to Host field
  155. host = req.Request.Host
  156. } else {
  157. host = forwarded[0]
  158. }
  159. } else {
  160. host = hostvalues[0]
  161. }
  162. // inspect Referer for the scheme (http vs https)
  163. scheme := "http"
  164. if referer := req.Request.Header["Referer"]; len(referer) > 0 {
  165. if strings.HasPrefix(referer[0], "https") {
  166. scheme = "https"
  167. }
  168. }
  169. decl.BasePath = fmt.Sprintf("%s://%s", scheme, host)
  170. }
  171. resp.WriteAsJson(decl)
  172. }
  173. func (sws SwaggerService) produceAllDeclarations() map[string]ApiDeclaration {
  174. decls := map[string]ApiDeclaration{}
  175. sws.apiDeclarationMap.Do(func(k string, v ApiDeclaration) {
  176. decls[k] = v
  177. })
  178. return decls
  179. }
  180. func (sws SwaggerService) produceDeclarations(route string) (*ApiDeclaration, bool) {
  181. decl, ok := sws.apiDeclarationMap.At(route)
  182. if !ok {
  183. return nil, false
  184. }
  185. decl.BasePath = sws.config.WebServicesUrl
  186. return &decl, true
  187. }
  // composeDeclaration builds the ApiDeclaration for the routes of ws that lie
  // under pathPrefix.
  //
  // Routes are grouped by path into one Api each. Every route becomes an
  // Operation that carries:
  //   - the WebService's path parameters, followed by the route's own parameters;
  //   - its response messages;
  //   - the models taken from its read/write samples.
  //
  // A route lies under pathPrefix only when its path equals the prefix, or
  // continues it with a new segment ("/") or a parameter ("{"). A prefix that
  // itself ends with "/" accepts any continuation. This keeps "/users" from also
  // collecting the routes of "/usersettings".
  func (sws SwaggerService) composeDeclaration(ws *restful.WebService, pathPrefix string) ApiDeclaration {
  	decl := ApiDeclaration{
  		SwaggerVersion: swaggerVersion,
  		BasePath:       sws.config.WebServicesUrl,
  		ResourcePath:   pathPrefix,
  		Models:         ModelList{},
  		ApiVersion:     ws.Version(),
  	}
  	// path parameters declared on the WebService apply to every operation
  	rootParams := []Parameter{}
  	for _, param := range ws.PathParameters() {
  		rootParams = append(rootParams, asSwaggerParameter(param.Data()))
  	}
  	// underPrefix reports whether path is pathPrefix or continues it at a
  	// segment or parameter boundary (a plain HasPrefix would match "/usersX").
  	underPrefix := func(path string) bool {
  		if !strings.HasPrefix(path, pathPrefix) {
  			return false
  		}
  		rest := path[len(pathPrefix):]
  		if rest == "" || pathPrefix == "" || strings.HasSuffix(pathPrefix, "/") {
  			return true
  		}
  		return rest[0] == '/' || rest[0] == '{'
  	}
  	// aggregate by path
  	pathToRoutes := newOrderedRouteMap()
  	for _, other := range ws.Routes() {
  		if underPrefix(other.Path) {
  			pathToRoutes.Add(other.Path, other)
  		}
  	}
  	pathToRoutes.Do(func(path string, routes []restful.Route) {
  		api := Api{Path: strings.TrimSuffix(withoutWildcard(path), "/"), Description: ws.Documentation()}
  		voidString := "void"
  		for _, route := range routes {
  			operation := Operation{
  				Method: route.Method,
  				Summary: route.Doc,
  				Notes: route.Notes,
  				// Type gets overwritten if there is a write sample
  				DataTypeFields: DataTypeFields{Type: &voidString},
  				Parameters: []Parameter{},
  				Nickname: route.Operation,
  				ResponseMessages: composeResponseMessages(route, &decl, &sws.config),
  			}
  			operation.Consumes = route.Consumes
  			operation.Produces = route.Produces
  			// shared root params first, then the route specific ones
  			operation.Parameters = append(operation.Parameters, rootParams...)
  			for _, param := range route.ParameterDocs {
  				operation.Parameters = append(operation.Parameters, asSwaggerParameter(param.Data()))
  			}
  			sws.addModelsFromRouteTo(&operation, route, &decl)
  			api.Operations = append(api.Operations, operation)
  		}
  		decl.Apis = append(decl.Apis, api)
  	})
  	return decl
  }
  238. func withoutWildcard(path string) string {
  239. if strings.HasSuffix(path, ":*}") {
  240. return path[0:len(path)-3] + "}"
  241. }
  242. return path
  243. }
  // composeResponseMessages converts route.ResponseErrors into ResponseMessages,
  // ordered by ascending status code. It returns nil when the route has no
  // ResponseErrors.
  //
  // For an error that carries a Model, the model is added to decl.Models and the
  // message references it by name. A slice or array model (or a pointer to one)
  // is referenced as "array[Name]".
  func composeResponseMessages(route restful.Route, decl *ApiDeclaration, config *Config) (messages []ResponseMessage) {
  	if route.ResponseErrors == nil {
  		return nil
  	}
  	codes := make([]int, 0, len(route.ResponseErrors))
  	for code := range route.ResponseErrors {
  		codes = append(codes, code)
  	}
  	sort.Ints(codes)
  	for _, code := range codes {
  		responseError := route.ResponseErrors[code]
  		message := ResponseMessage{Code: code, Message: responseError.Message}
  		if responseError.Model != nil {
  			collection, modelType := detectCollectionType(reflect.TypeOf(responseError.Model))
  			name := modelBuilder{}.keyFrom(modelType)
  			if collection {
  				name = "array[" + name + "]"
  			}
  			modelBuilder{Models: &decl.Models, Config: config}.addModel(modelType, "")
  			message.ResponseModel = name
  		}
  		messages = append(messages, message)
  	}
  	return messages
  }
  276. // addModelsFromRoute takes any read or write sample from the Route and creates a Swagger model from it.
  277. func (sws SwaggerService) addModelsFromRouteTo(operation *Operation, route restful.Route, decl *ApiDeclaration) {
  278. if route.ReadSample != nil {
  279. sws.addModelFromSampleTo(operation, false, route.ReadSample, &decl.Models)
  280. }
  281. if route.WriteSample != nil {
  282. sws.addModelFromSampleTo(operation, true, route.WriteSample, &decl.Models)
  283. }
  284. }
  // detectCollectionType reports whether st is a slice or array, or a pointer to
  // one. If so, it returns true and the element type. Otherwise it returns false
  // and st unchanged.
  func detectCollectionType(st reflect.Type) (bool, reflect.Type) {
  	switch st.Kind() {
  	case reflect.Slice, reflect.Array:
  		return true, st.Elem()
  	case reflect.Ptr:
  		if target := st.Elem(); target.Kind() == reflect.Slice || target.Kind() == reflect.Array {
  			return true, target.Elem()
  		}
  	}
  	return false, st
  }
  300. // addModelFromSample creates and adds (or overwrites) a Model from a sample resource
  301. func (sws SwaggerService) addModelFromSampleTo(operation *Operation, isResponse bool, sample interface{}, models *ModelList) {
  302. if isResponse {
  303. type_, items := asDataType(sample, &sws.config)
  304. operation.Type = type_
  305. operation.Items = items
  306. }
  307. modelBuilder{Models: models, Config: &sws.config}.addModelFrom(sample)
  308. }
  309. func asSwaggerParameter(param restful.ParameterData) Parameter {
  310. return Parameter{
  311. DataTypeFields: DataTypeFields{
  312. Type: &param.DataType,
  313. Format: asFormat(param.DataType, param.DataFormat),
  314. DefaultValue: Special(param.DefaultValue),
  315. },
  316. Name: param.Name,
  317. Description: param.Description,
  318. ParamType: asParamType(param.Kind),
  319. Required: param.Required}
  320. }
  321. // Between 1..7 path parameters is supported
  322. func composeRootPath(req *restful.Request) string {
  323. path := "/" + req.PathParameter("a")
  324. b := req.PathParameter("b")
  325. if b == "" {
  326. return path
  327. }
  328. path = path + "/" + b
  329. c := req.PathParameter("c")
  330. if c == "" {
  331. return path
  332. }
  333. path = path + "/" + c
  334. d := req.PathParameter("d")
  335. if d == "" {
  336. return path
  337. }
  338. path = path + "/" + d
  339. e := req.PathParameter("e")
  340. if e == "" {
  341. return path
  342. }
  343. path = path + "/" + e
  344. f := req.PathParameter("f")
  345. if f == "" {
  346. return path
  347. }
  348. path = path + "/" + f
  349. g := req.PathParameter("g")
  350. if g == "" {
  351. return path
  352. }
  353. return path + "/" + g
  354. }
  // asFormat returns the Swagger format of a parameter: the explicit dataFormat
  // when one is given, otherwise "".
  // TODO: derive a default format from dataType; it is currently unused.
  func asFormat(dataType string, dataFormat string) string {
  	return dataFormat
  }
  361. func asParamType(kind int) string {
  362. switch {
  363. case kind == restful.PathParameterKind:
  364. return "path"
  365. case kind == restful.QueryParameterKind:
  366. return "query"
  367. case kind == restful.BodyParameterKind:
  368. return "body"
  369. case kind == restful.HeaderParameterKind:
  370. return "header"
  371. case kind == restful.FormParameterKind:
  372. return "form"
  373. }
  374. return ""
  375. }
  // asDataType returns the Swagger type of a sample value.
  //
  // For a non-collection, the type is the model name of the sample's type and
  // the Item is nil. For a slice or array (or a pointer to one), the type is
  // "array" and the Item describes the element: either a JSON-schema primitive
  // type, or a reference to the element's model.
  func asDataType(any interface{}, config *Config) (*string, *Item) {
  	isCollection, st := detectCollectionType(reflect.TypeOf(any))
  	modelName := modelBuilder{}.keyFrom(st)
  	if !isCollection {
  		return &modelName, nil
  	}
  	// The models are built into a throwaway list, only to resolve the element
  	// type name.
  	mb := modelBuilder{Models: &ModelList{}, Config: config}
  	mb.addModelFrom(any)
  	elemTypeName := mb.getElementTypeName(modelName, "", st)
  	item := &Item{}
  	if mb.isPrimitiveType(elemTypeName) {
  		schemaType := mb.jsonSchemaType(elemTypeName)
  		item.Type = &schemaType
  	} else {
  		item.Ref = &elemTypeName
  	}
  	arrayType := "array"
  	return &arrayType, item
  }