generate.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. package main
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "net/url"
  7. "os"
  8. "os/exec"
  9. "regexp"
  10. "sort"
  11. "strconv"
  12. "strings"
  13. "text/template"
  14. "github.com/aws/aws-sdk-go/internal/fixtures/helpers"
  15. "github.com/aws/aws-sdk-go/internal/model/api"
  16. "github.com/aws/aws-sdk-go/internal/util"
  17. "github.com/aws/aws-sdk-go/internal/util/utilassert"
  18. )
// testSuite is one element of the JSON array that generateTestSuite decodes
// from a protocol test file.
type testSuite struct {
	// API is the embedded service model. generateTestSuite rewrites parts of
	// it (service abbreviation, operations, shape names) before emitting code.
	*api.API
	// Description is the text from which TestSuite derives title.
	Description string
	// Cases are rendered in order by TestSuite, one generated test each.
	Cases []testCase
	// title is assigned by TestSuite and used by testCase.TestCase as part of
	// the generated test function name.
	title string
}
// testCase is a single case within a testSuite. TestCase renders it as an
// input test when Params is non-nil and as an output test otherwise.
type testCase struct {
	// TestSuite points back at the owning suite; testSuite.TestSuite assigns
	// it just before rendering the case.
	TestSuite *testSuite
	// Given is the operation under test; generateTestSuite assigns its
	// ExportedName and registers it on the suite's API.
	Given *api.Operation
	// Params is handed to helpers.ParamsStructFromJSON to build the input
	// value of an input test.
	Params interface{} `json:",omitempty"`
	// Data (JSON key "result") is handed to utilassert.GenerateAssertions for
	// output tests.
	Data interface{} `json:"result,omitempty"`
	// InputTest (JSON key "serialized") is the expectation checked against
	// the built request.
	InputTest testExpectation `json:"serialized"`
	// OutputTest (JSON key "response") supplies the body and headers of the
	// response that an output test unmarshals.
	OutputTest testExpectation `json:"response"`
}
// testExpectation describes either the request a case is expected to
// serialize to or the response a case is given to unmarshal.
type testExpectation struct {
	// Body is the raw payload text.
	Body string
	// URI is appended to "https://test" when asserting the request URL.
	URI string
	// Headers maps header names to values.
	Headers map[string]string
	// StatusCode is decoded from the "status_code" key. NOTE(review): the
	// templates in this file do not reference it; the output template uses a
	// fixed 200.
	StatusCode uint `json:"status_code"`
}
// preamble is Go source that generateTestSuite writes once, right after the
// import block of the first suite. Each line references a package through the
// blank identifier, so the generated file still compiles when the generated
// tests themselves do not use one of the packages added by addImports.
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
`
  51. var reStripSpace = regexp.MustCompile(`\s(\w)`)
  52. var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
  53. func removeImports(code string) string {
  54. return reImportRemoval.ReplaceAllString(code, "")
  55. }
// extraImports lists the import paths that addImports inserts at the top of
// the import block of generated service code.
var extraImports = []string{
	"bytes",
	"encoding/json",
	"encoding/xml",
	"io",
	"io/ioutil",
	"net/http",
	"testing",
	"time",
	"net/url",
	"", // left unquoted by addImports, so it becomes an empty line in the import block
	"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
	"github.com/aws/aws-sdk-go/internal/util",
	"github.com/stretchr/testify/assert",
}
  71. func addImports(code string) string {
  72. importNames := make([]string, len(extraImports))
  73. for i, n := range extraImports {
  74. if n != "" {
  75. importNames[i] = fmt.Sprintf("%q", n)
  76. }
  77. }
  78. str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)")
  79. return str
  80. }
  81. func (t *testSuite) TestSuite() string {
  82. var buf bytes.Buffer
  83. t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
  84. return strings.ToUpper(x[1:])
  85. })
  86. t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
  87. for idx, c := range t.Cases {
  88. c.TestSuite = t
  89. buf.WriteString(c.TestCase(idx) + "\n")
  90. }
  91. return buf.String()
  92. }
// tplInputTestCase is the template for a generated input test. It is executed
// by testCase.TestCase with a tplInputTestCaseData value. The generated test
// builds a request for the case's operation with the suite's protocol
// package, then asserts the body (only when an expected body is set), the URL
// (only when an expected URI is set) and each expected header.
var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(nil)
svc.Endpoint = "https://test"
input := {{ .ParamsString }}
req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input)
r := req.HTTPRequest
// build request
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req)
assert.NoError(t, req.Error)
{{ if ne .TestCase.InputTest.Body "" }}// assert body
assert.NotNil(t, r.Body)
{{ .BodyAssertions }}{{ end }}
{{ if ne .TestCase.InputTest.URI "" }}// assert URL
assert.Equal(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }}
// assert headers
{{ range $k, $v := .TestCase.InputTest.Headers }}assert.Equal(t, "{{ $v }}", r.Header.Get("{{ $k }}"))
{{ end }}
}
`))
// tplInputTestCaseData is the value tplInputTestCase is executed with.
type tplInputTestCaseData struct {
	// TestCase is the case being rendered.
	TestCase *testCase
	// OpName is the suffix of the generated test function name; ParamsString
	// is the Go source of the input value, as returned by
	// helpers.ParamsStructFromJSON.
	OpName, ParamsString string
}
  117. func (t tplInputTestCaseData) BodyAssertions() string {
  118. protocol, code := t.TestCase.TestSuite.API.Metadata.Protocol, ""
  119. switch protocol {
  120. case "rest-xml":
  121. code += "body := util.SortXML(r.Body)\n"
  122. default:
  123. code += "body, _ := ioutil.ReadAll(r.Body)\n"
  124. }
  125. code += "assert.Equal(t, util.Trim(`" + t.TestCase.InputTest.Body + "`), util.Trim(string(body)))"
  126. return code
  127. }
// tplOutputTestCase is the template for a generated output test. It is
// executed by testCase.TestCase with a tplOutputTestCaseData value. The
// generated test attaches a canned HTTP response (status 200, the case's body
// and headers) to a request for the case's operation, unmarshals it with the
// suite's protocol package and then runs the generated assertions.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(nil)
buf := bytes.NewReader([]byte({{ .Body }}))
req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
{{ end }}
// unmarshal response
{{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req)
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
{{ .Assertions }}
}
`))
// tplOutputTestCaseData is the value tplOutputTestCase is executed with.
type tplOutputTestCaseData struct {
	// TestCase is the case being rendered.
	TestCase *testCase
	// Body is the response body already formatted as a Go quoted string
	// literal; OpName is the suffix of the generated test function name;
	// Assertions is the Go source returned by utilassert.GenerateAssertions.
	Body, OpName, Assertions string
}
  150. func (i *testCase) TestCase(idx int) string {
  151. var buf bytes.Buffer
  152. opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
  153. if i.Params != nil { // input test
  154. // query test should sort body as form encoded values
  155. switch i.TestSuite.API.Metadata.Protocol {
  156. case "query", "ec2":
  157. m, _ := url.ParseQuery(i.InputTest.Body)
  158. i.InputTest.Body = m.Encode()
  159. case "rest-xml":
  160. i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
  161. case "json", "rest-json":
  162. i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1)
  163. }
  164. input := tplInputTestCaseData{
  165. TestCase: i,
  166. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  167. ParamsString: helpers.ParamsStructFromJSON(i.Params, i.Given.InputRef.Shape, false),
  168. }
  169. if err := tplInputTestCase.Execute(&buf, input); err != nil {
  170. panic(err)
  171. }
  172. } else {
  173. output := tplOutputTestCaseData{
  174. TestCase: i,
  175. Body: fmt.Sprintf("%q", i.OutputTest.Body),
  176. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  177. Assertions: utilassert.GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
  178. }
  179. if err := tplOutputTestCase.Execute(&buf, output); err != nil {
  180. panic(err)
  181. }
  182. }
  183. return buf.String()
  184. }
  185. // generateTestSuite generates a protocol test suite for a given configuration
  186. // JSON protocol test file.
  187. func generateTestSuite(filename string) string {
  188. inout := "Input"
  189. if strings.Contains(filename, "output/") {
  190. inout = "Output"
  191. }
  192. var suites []testSuite
  193. f, err := os.Open(filename)
  194. if err != nil {
  195. panic(err)
  196. }
  197. err = json.NewDecoder(f).Decode(&suites)
  198. if err != nil {
  199. panic(err)
  200. }
  201. var buf bytes.Buffer
  202. buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
  203. var innerBuf bytes.Buffer
  204. innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
  205. for i, suite := range suites {
  206. svcPrefix := inout + "Service" + strconv.Itoa(i+1)
  207. suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
  208. suite.API.Operations = map[string]*api.Operation{}
  209. for idx, c := range suite.Cases {
  210. c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
  211. suite.API.Operations[c.Given.ExportedName] = c.Given
  212. }
  213. suite.API.NoInflections = true // don't require inflections
  214. suite.API.NoInitMethods = true // don't generate init methods
  215. suite.API.NoStringerMethods = true // don't generate stringer methods
  216. suite.API.Setup()
  217. suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
  218. // Sort in order for deterministic test generation
  219. names := make([]string, 0, len(suite.API.Shapes))
  220. for n := range suite.API.Shapes {
  221. names = append(names, n)
  222. }
  223. sort.Strings(names)
  224. for _, name := range names {
  225. s := suite.API.Shapes[name]
  226. s.Rename(svcPrefix + "TestShape" + name)
  227. }
  228. svcCode := addImports(suite.API.ServiceGoCode())
  229. if i == 0 {
  230. importMatch := reImportRemoval.FindStringSubmatch(svcCode)
  231. buf.WriteString(importMatch[0] + "\n\n")
  232. buf.WriteString(preamble + "\n\n")
  233. }
  234. svcCode = removeImports(svcCode)
  235. svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
  236. buf.WriteString(svcCode + "\n\n")
  237. apiCode := removeImports(suite.API.APIGoCode())
  238. apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
  239. apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
  240. apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
  241. buf.WriteString(apiCode + "\n\n")
  242. innerBuf.WriteString(suite.TestSuite() + "\n")
  243. }
  244. return buf.String() + innerBuf.String()
  245. }
  246. func main() {
  247. out := generateTestSuite(os.Args[1])
  248. if len(os.Args) == 3 {
  249. f, err := os.Create(os.Args[2])
  250. defer f.Close()
  251. if err != nil {
  252. panic(err)
  253. }
  254. f.WriteString(util.GoFmt(out))
  255. f.Close()
  256. c := exec.Command("gofmt", "-s", "-w", os.Args[2])
  257. if err := c.Run(); err != nil {
  258. panic(err)
  259. }
  260. } else {
  261. fmt.Println(out)
  262. }
  263. }