generate.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. // +build codegen
  2. package main
  3. import (
  4. "bytes"
  5. "encoding/json"
  6. "fmt"
  7. "net/url"
  8. "os"
  9. "os/exec"
  10. "reflect"
  11. "regexp"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "text/template"
  16. "github.com/aws/aws-sdk-go/private/model/api"
  17. "github.com/aws/aws-sdk-go/private/util"
  18. )
  19. // TestSuiteTypeInput input test
  20. // TestSuiteTypeInput output test
  21. const (
  22. TestSuiteTypeInput = iota
  23. TestSuiteTypeOutput
  24. )
// testSuite is one suite decoded from a protocol test JSON file (the file is
// decoded as a []testSuite by generateTestSuite).
type testSuite struct {
	// API is the service model the suite's operations and shapes belong to.
	// NOTE(review): it is embedded, so JSON keys matching api.API's exported
	// fields are presumably decoded into it — confirm against api.API's tags.
	*api.API
	// Description is the human-readable suite name; TestSuite derives the
	// title used in generated test function names from it.
	Description string
	// Cases are the individual test cases of the suite.
	Cases []testCase
	// Type is TestSuiteTypeInput or TestSuiteTypeOutput, assigned by
	// generateTestSuite via getType.
	Type uint
	// title is Description CamelCased with non-word characters removed;
	// computed by TestSuite.
	title string
}
// testCase is a single case within a testSuite.
type testCase struct {
	// TestSuite links back to the owning suite; testSuite.TestSuite sets it
	// before rendering the case.
	TestSuite *testSuite
	// Given is the operation model exercised by the case.
	Given *api.Operation
	// Params holds the input parameters for input tests, rendered into a Go
	// literal by api.ParamsStructFromJSON.
	Params interface{} `json:",omitempty"`
	// Data holds the expected output for output tests (JSON key "result").
	Data interface{} `json:"result,omitempty"`
	// InputTest is the expected serialized request (JSON key "serialized").
	InputTest testExpectation `json:"serialized"`
	// OutputTest is the canned HTTP response (JSON key "response").
	OutputTest testExpectation `json:"response"`
}
  40. type testExpectation struct {
  41. Body string
  42. URI string
  43. Headers map[string]string
  44. StatusCode uint `json:"status_code"`
  45. }
// preamble is written once into the generated file, directly after its import
// block (see generateTestSuite). The blank-identifier references keep each
// listed import in use even when no generated code needs it, and init sets
// protocol.RandReader to an awstesting.ZeroReader.
// NOTE(review): presumably this makes any random values produced by protocol
// code deterministic in tests — confirm against the protocol package.
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
var _ = aws.String
var _ = fmt.Println
func init() {
protocol.RandReader = &awstesting.ZeroReader{}
}
`
  63. var reStripSpace = regexp.MustCompile(`\s(\w)`)
  64. var reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
  65. func removeImports(code string) string {
  66. return reImportRemoval.ReplaceAllString(code, "")
  67. }
  68. var extraImports = []string{
  69. "bytes",
  70. "encoding/json",
  71. "encoding/xml",
  72. "fmt",
  73. "io",
  74. "io/ioutil",
  75. "net/http",
  76. "testing",
  77. "time",
  78. "net/url",
  79. "",
  80. "github.com/aws/aws-sdk-go/awstesting",
  81. "github.com/aws/aws-sdk-go/awstesting/unit",
  82. "github.com/aws/aws-sdk-go/private/protocol",
  83. "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
  84. "github.com/aws/aws-sdk-go/private/util",
  85. "github.com/stretchr/testify/assert",
  86. }
  87. func addImports(code string) string {
  88. importNames := make([]string, len(extraImports))
  89. for i, n := range extraImports {
  90. if n != "" {
  91. importNames[i] = fmt.Sprintf("%q", n)
  92. }
  93. }
  94. str := reImportRemoval.ReplaceAllString(code, "import (\n"+strings.Join(importNames, "\n")+"$1\n)")
  95. return str
  96. }
  97. func (t *testSuite) TestSuite() string {
  98. var buf bytes.Buffer
  99. t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
  100. return strings.ToUpper(x[1:])
  101. })
  102. t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
  103. for idx, c := range t.Cases {
  104. c.TestSuite = t
  105. buf.WriteString(c.TestCase(idx) + "\n")
  106. }
  107. return buf.String()
  108. }
// tplInputTestCase renders one input (request-building) test, executed with a
// tplInputTestCaseData. The generated Test<OpName> function:
//   - creates the suite's client with endpoint "https://test";
//   - builds the operation request from ParamsString, or from nil when
//     ParamsString is empty;
//   - runs the protocol package's Build and asserts there is no error;
//   - if an expected body is set, checks it using BodyAssertions;
//   - if an expected URI is set, checks the URL with awstesting.AssertURL;
//   - asserts every expected header.
//
// NOTE(review): header keys/values and the URI are interpolated into
// double-quoted Go literals without escaping (text/template does no
// escaping), so a value containing `"` or `\` would yield invalid Go.
var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
{{ if ne .ParamsString "" }}input := {{ .ParamsString }}
req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
r := req.HTTPRequest
// build request
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req)
assert.NoError(t, req.Error)
{{ if ne .TestCase.InputTest.Body "" }}// assert body
assert.NotNil(t, r.Body)
{{ .BodyAssertions }}{{ end }}
{{ if ne .TestCase.InputTest.URI "" }}// assert URL
awstesting.AssertURL(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }}
// assert headers
{{ range $k, $v := .TestCase.InputTest.Headers }}assert.Equal(t, "{{ $v }}", r.Header.Get("{{ $k }}"))
{{ end }}
}
`))
  128. type tplInputTestCaseData struct {
  129. TestCase *testCase
  130. OpName, ParamsString string
  131. }
  132. func (t tplInputTestCaseData) BodyAssertions() string {
  133. code := &bytes.Buffer{}
  134. protocol := t.TestCase.TestSuite.API.Metadata.Protocol
  135. // Extract the body bytes
  136. switch protocol {
  137. case "rest-xml":
  138. fmt.Fprintln(code, "body := util.SortXML(r.Body)")
  139. default:
  140. fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
  141. }
  142. // Generate the body verification code
  143. expectedBody := util.Trim(t.TestCase.InputTest.Body)
  144. switch protocol {
  145. case "ec2", "query":
  146. fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
  147. expectedBody)
  148. case "rest-xml":
  149. if strings.HasPrefix(expectedBody, "<") {
  150. fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})",
  151. expectedBody, t.TestCase.Given.InputRef.ShapeName)
  152. } else {
  153. fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
  154. expectedBody)
  155. }
  156. case "json", "jsonrpc", "rest-json":
  157. if strings.HasPrefix(expectedBody, "{") {
  158. fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
  159. expectedBody)
  160. } else {
  161. fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
  162. expectedBody)
  163. }
  164. default:
  165. fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
  166. expectedBody)
  167. }
  168. return code.String()
  169. }
// tplOutputTestCase renders one output (response-unmarshaling) test, executed
// with a tplOutputTestCaseData. The generated Test<OpName> function:
//   - creates the suite's client with endpoint "https://test";
//   - builds the operation request with nil input;
//   - attaches a canned *http.Response whose body is .Body, which must already
//     be a Go string literal (TestCase passes it through %q);
//   - sets the case's response headers;
//   - runs the protocol package's UnmarshalMeta and Unmarshal and asserts
//     there is no error;
//   - emits the generated .Assertions against out.
//
// NOTE(review): the status code is hard-coded to 200; OutputTest.StatusCode
// is not used. Header keys/values are interpolated into double-quoted Go
// literals without escaping.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
buf := bytes.NewReader([]byte({{ .Body }}))
req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
{{ end }}
// unmarshal response
{{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req)
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
{{ .Assertions }}
}
`))
  188. type tplOutputTestCaseData struct {
  189. TestCase *testCase
  190. Body, OpName, Assertions string
  191. }
  192. func (i *testCase) TestCase(idx int) string {
  193. var buf bytes.Buffer
  194. opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
  195. if i.TestSuite.Type == TestSuiteTypeInput { // input test
  196. // query test should sort body as form encoded values
  197. switch i.TestSuite.API.Metadata.Protocol {
  198. case "query", "ec2":
  199. m, _ := url.ParseQuery(i.InputTest.Body)
  200. i.InputTest.Body = m.Encode()
  201. case "rest-xml":
  202. i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
  203. case "json", "rest-json":
  204. i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1)
  205. }
  206. input := tplInputTestCaseData{
  207. TestCase: i,
  208. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  209. ParamsString: api.ParamsStructFromJSON(i.Params, i.Given.InputRef.Shape, false),
  210. }
  211. if err := tplInputTestCase.Execute(&buf, input); err != nil {
  212. panic(err)
  213. }
  214. } else if i.TestSuite.Type == TestSuiteTypeOutput {
  215. output := tplOutputTestCaseData{
  216. TestCase: i,
  217. Body: fmt.Sprintf("%q", i.OutputTest.Body),
  218. OpName: strings.ToUpper(opName[0:1]) + opName[1:],
  219. Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
  220. }
  221. if err := tplOutputTestCase.Execute(&buf, output); err != nil {
  222. panic(err)
  223. }
  224. }
  225. return buf.String()
  226. }
  227. // generateTestSuite generates a protocol test suite for a given configuration
  228. // JSON protocol test file.
  229. func generateTestSuite(filename string) string {
  230. inout := "Input"
  231. if strings.Contains(filename, "output/") {
  232. inout = "Output"
  233. }
  234. var suites []testSuite
  235. f, err := os.Open(filename)
  236. if err != nil {
  237. panic(err)
  238. }
  239. err = json.NewDecoder(f).Decode(&suites)
  240. if err != nil {
  241. panic(err)
  242. }
  243. var buf bytes.Buffer
  244. buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
  245. var innerBuf bytes.Buffer
  246. innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
  247. for i, suite := range suites {
  248. svcPrefix := inout + "Service" + strconv.Itoa(i+1)
  249. suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
  250. suite.API.Operations = map[string]*api.Operation{}
  251. for idx, c := range suite.Cases {
  252. c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
  253. suite.API.Operations[c.Given.ExportedName] = c.Given
  254. }
  255. suite.Type = getType(inout)
  256. suite.API.NoInitMethods = true // don't generate init methods
  257. suite.API.NoStringerMethods = true // don't generate stringer methods
  258. suite.API.NoConstServiceNames = true // don't generate service names
  259. suite.API.Setup()
  260. suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
  261. // Sort in order for deterministic test generation
  262. names := make([]string, 0, len(suite.API.Shapes))
  263. for n := range suite.API.Shapes {
  264. names = append(names, n)
  265. }
  266. sort.Strings(names)
  267. for _, name := range names {
  268. s := suite.API.Shapes[name]
  269. s.Rename(svcPrefix + "TestShape" + name)
  270. }
  271. svcCode := addImports(suite.API.ServiceGoCode())
  272. if i == 0 {
  273. importMatch := reImportRemoval.FindStringSubmatch(svcCode)
  274. buf.WriteString(importMatch[0] + "\n\n")
  275. buf.WriteString(preamble + "\n\n")
  276. }
  277. svcCode = removeImports(svcCode)
  278. svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
  279. svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
  280. svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
  281. buf.WriteString(svcCode + "\n\n")
  282. apiCode := removeImports(suite.API.APIGoCode())
  283. apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
  284. apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
  285. apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
  286. buf.WriteString(apiCode + "\n\n")
  287. innerBuf.WriteString(suite.TestSuite() + "\n")
  288. }
  289. return buf.String() + innerBuf.String()
  290. }
  291. // findMember searches the shape for the member with the matching key name.
  292. func findMember(shape *api.Shape, key string) string {
  293. for actualKey := range shape.MemberRefs {
  294. if strings.ToLower(key) == strings.ToLower(actualKey) {
  295. return actualKey
  296. }
  297. }
  298. return ""
  299. }
  300. // GenerateAssertions builds assertions for a shape based on its type.
  301. //
  302. // The shape's recursive values also will have assertions generated for them.
  303. func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
  304. switch t := out.(type) {
  305. case map[string]interface{}:
  306. keys := util.SortedKeys(t)
  307. code := ""
  308. if shape.Type == "map" {
  309. for _, k := range keys {
  310. v := t[k]
  311. s := shape.ValueRef.Shape
  312. code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]")
  313. }
  314. } else {
  315. for _, k := range keys {
  316. v := t[k]
  317. m := findMember(shape, k)
  318. s := shape.MemberRefs[m].Shape
  319. code += GenerateAssertions(v, s, prefix+"."+m+"")
  320. }
  321. }
  322. return code
  323. case []interface{}:
  324. code := ""
  325. for i, v := range t {
  326. s := shape.MemberRef.Shape
  327. code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
  328. }
  329. return code
  330. default:
  331. switch shape.Type {
  332. case "timestamp":
  333. return fmt.Sprintf("assert.Equal(t, time.Unix(%#v, 0).UTC().String(), %s.String())\n", out, prefix)
  334. case "blob":
  335. return fmt.Sprintf("assert.Equal(t, %#v, string(%s))\n", out, prefix)
  336. case "integer", "long":
  337. return fmt.Sprintf("assert.Equal(t, int64(%#v), *%s)\n", out, prefix)
  338. default:
  339. if !reflect.ValueOf(out).IsValid() {
  340. return fmt.Sprintf("assert.Nil(t, %s)\n", prefix)
  341. }
  342. return fmt.Sprintf("assert.Equal(t, %#v, *%s)\n", out, prefix)
  343. }
  344. }
  345. }
  346. func getType(t string) uint {
  347. switch t {
  348. case "Input":
  349. return TestSuiteTypeInput
  350. case "Output":
  351. return TestSuiteTypeOutput
  352. default:
  353. panic("Invalid type for test suite")
  354. }
  355. }
  356. func main() {
  357. out := generateTestSuite(os.Args[1])
  358. if len(os.Args) == 3 {
  359. f, err := os.Create(os.Args[2])
  360. defer f.Close()
  361. if err != nil {
  362. panic(err)
  363. }
  364. f.WriteString(util.GoFmt(out))
  365. f.Close()
  366. c := exec.Command("gofmt", "-s", "-w", os.Args[2])
  367. if err := c.Run(); err != nil {
  368. panic(err)
  369. }
  370. } else {
  371. fmt.Println(out)
  372. }
  373. }