123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432 |
- // +build codegen
- package main
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "net/url"
- "os"
- "os/exec"
- "reflect"
- "regexp"
- "sort"
- "strconv"
- "strings"
- "text/template"
- "github.com/aws/aws-sdk-go/private/model/api"
- "github.com/aws/aws-sdk-go/private/util"
- )
// Suite types select which template a testSuite's cases are rendered with.
const (
	// TestSuiteTypeInput marks a suite of input tests, which check how a
	// request is built/serialized.
	TestSuiteTypeInput = iota
	// TestSuiteTypeOutput marks a suite of output tests, which check how a
	// canned response is unmarshaled.
	TestSuiteTypeOutput
)
// testSuite is one suite decoded from a protocol test fixture file; the file
// holds a JSON array of these.
//
// The embedded *api.API is the service model the suite's cases run against.
// Its promoted methods (e.g. ProtocolPackage) are used while generating code.
type testSuite struct {
	*api.API
	// Description is the human-readable suite description; TestSuite derives
	// title from it.
	Description string
	// Cases holds the suite's individual test cases.
	Cases []testCase
	// Type is TestSuiteTypeInput or TestSuiteTypeOutput. generateTestSuite
	// assigns it from the fixture path before rendering.
	Type uint
	// title is Description turned into an identifier fragment (whitespace-
	// separated words joined with upper-cased initials, non-word characters
	// removed); it becomes part of every generated test name.
	title string
}
// testCase is a single case within a testSuite, decoded from the fixture.
type testCase struct {
	// TestSuite points back at the owning suite; testSuite.TestSuite sets it
	// before the case is rendered.
	TestSuite *testSuite
	// Given is the operation under test. generateTestSuite overwrites its
	// ExportedName with a name that is unique per suite.
	Given *api.Operation
	// Params are the operation's input parameters for input tests; they are
	// turned into Go source by api.ParamsStructFromJSON.
	Params interface{} `json:",omitempty"`
	// Data is the expected unmarshaled output for output tests (JSON key
	// "result"); GenerateAssertions turns it into assertions.
	Data interface{} `json:"result,omitempty"`
	// InputTest is the expected serialized request (JSON key "serialized").
	InputTest testExpectation `json:"serialized"`
	// OutputTest is the canned HTTP response to unmarshal (JSON key "response").
	OutputTest testExpectation `json:"response"`
}
// testExpectation describes an HTTP message in a fixture: the expected request
// for input tests, or the canned response for output tests.
type testExpectation struct {
	// Body is the message body.
	Body string
	// URI is the part of the expected request URL after the host; input tests
	// compare the full URL "https://test"+URI.
	URI string
	// Headers maps header names to their expected (or canned) values.
	Headers map[string]string
	// StatusCode is decoded from "status_code".
	// NOTE(review): nothing in this generator reads it; the output template
	// always uses status 200.
	StatusCode uint `json:"status_code"`
}
// preamble is written once into every generated file, directly after its
// import block (see generateTestSuite).
//
// The blank-identifier declarations reference one symbol from each imported
// package so the generated file still compiles when none of its tests happen
// to use that package. The init function replaces protocol.RandReader with an
// awstesting.ZeroReader.
// NOTE(review): presumably this makes randomly generated request values
// deterministic in the generated tests — confirm.
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
var _ = aws.String
var _ = fmt.Println
func init() {
protocol.RandReader = &awstesting.ZeroReader{}
}
`
var (
	// reStripSpace matches one whitespace character followed by a word
	// character; TestSuite uses it to CamelCase suite descriptions.
	reStripSpace = regexp.MustCompile(`\s(\w)`)

	// reImportRemoval matches a parenthesized import block. Group 1 captures
	// the block's contents; the match is non-greedy, so it ends at the first
	// ")" after the opening parenthesis.
	reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
)

// removeImports returns code with every parenthesized import block deleted.
func removeImports(code string) string {
	stripped := reImportRemoval.ReplaceAllString(code, "")
	return stripped
}

// extraImports are the packages every generated test file needs in addition
// to what the generated service code imports. The empty entry is rendered as
// a blank line that separates the standard-library group from the rest.
var extraImports = []string{
	"bytes",
	"encoding/json",
	"encoding/xml",
	"fmt",
	"io",
	"io/ioutil",
	"net/http",
	"testing",
	"time",
	"net/url",
	"",
	"github.com/aws/aws-sdk-go/awstesting",
	"github.com/aws/aws-sdk-go/awstesting/unit",
	"github.com/aws/aws-sdk-go/private/protocol",
	"github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
	"github.com/aws/aws-sdk-go/private/util",
	"github.com/stretchr/testify/assert",
}

// addImports rewrites each parenthesized import block in code so that it
// lists extraImports first, followed by the block's original contents. Code
// without such a block is returned unchanged.
func addImports(code string) string {
	lines := make([]string, 0, len(extraImports))
	for _, pkg := range extraImports {
		if pkg == "" {
			lines = append(lines, "")
			continue
		}
		lines = append(lines, fmt.Sprintf("%q", pkg))
	}
	// "$1" re-inserts the original block's imports after the extra ones.
	replacement := "import (\n" + strings.Join(lines, "\n") + "$1\n)"
	return reImportRemoval.ReplaceAllString(code, replacement)
}
- func (t *testSuite) TestSuite() string {
- var buf bytes.Buffer
- t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
- return strings.ToUpper(x[1:])
- })
- t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
- for idx, c := range t.Cases {
- c.TestSuite = t
- buf.WriteString(c.TestCase(idx) + "\n")
- }
- return buf.String()
- }
// tplInputTestCase renders one input (request building) test. It is executed
// with a tplInputTestCaseData value.
//
// The generated test constructs the suite's client against the endpoint
// "https://test" and creates the operation's request. The input is built from
// ParamsString, or nil when ParamsString is empty. The test then runs the
// protocol package's Build and asserts that no error occurred. After that it
// asserts:
//   - the request body, via BodyAssertions, when an expected body is set;
//   - the full URL ("https://test" + expected URI), when an expected URI is set;
//   - each expected header value.
//
// NOTE(review): header names and values are spliced into double-quoted Go
// string literals without escaping, so a value containing `"` or `\` would
// produce invalid Go source.
var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
{{ if ne .ParamsString "" }}input := {{ .ParamsString }}
req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input){{ else }}req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(nil){{ end }}
r := req.HTTPRequest
// build request
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req)
assert.NoError(t, req.Error)
{{ if ne .TestCase.InputTest.Body "" }}// assert body
assert.NotNil(t, r.Body)
{{ .BodyAssertions }}{{ end }}
{{ if ne .TestCase.InputTest.URI "" }}// assert URL
awstesting.AssertURL(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }}
// assert headers
{{ range $k, $v := .TestCase.InputTest.Headers }}assert.Equal(t, "{{ $v }}", r.Header.Get("{{ $k }}"))
{{ end }}
}
`))
- type tplInputTestCaseData struct {
- TestCase *testCase
- OpName, ParamsString string
- }
- func (t tplInputTestCaseData) BodyAssertions() string {
- code := &bytes.Buffer{}
- protocol := t.TestCase.TestSuite.API.Metadata.Protocol
- // Extract the body bytes
- switch protocol {
- case "rest-xml":
- fmt.Fprintln(code, "body := util.SortXML(r.Body)")
- default:
- fmt.Fprintln(code, "body, _ := ioutil.ReadAll(r.Body)")
- }
- // Generate the body verification code
- expectedBody := util.Trim(t.TestCase.InputTest.Body)
- switch protocol {
- case "ec2", "query":
- fmt.Fprintf(code, "awstesting.AssertQuery(t, `%s`, util.Trim(string(body)))",
- expectedBody)
- case "rest-xml":
- if strings.HasPrefix(expectedBody, "<") {
- fmt.Fprintf(code, "awstesting.AssertXML(t, `%s`, util.Trim(string(body)), %s{})",
- expectedBody, t.TestCase.Given.InputRef.ShapeName)
- } else {
- fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
- expectedBody)
- }
- case "json", "jsonrpc", "rest-json":
- if strings.HasPrefix(expectedBody, "{") {
- fmt.Fprintf(code, "awstesting.AssertJSON(t, `%s`, util.Trim(string(body)))",
- expectedBody)
- } else {
- fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
- expectedBody)
- }
- default:
- fmt.Fprintf(code, "assert.Equal(t, `%s`, util.Trim(string(body)))",
- expectedBody)
- }
- return code.String()
- }
// tplOutputTestCase renders one output (response unmarshaling) test. It is
// executed with a tplOutputTestCaseData value; .Body must already be a Go
// string literal (TestCase produces it with %q).
//
// The generated test creates the operation's request with a nil input. It
// attaches a fake 200 response whose body is .Body and whose headers are the
// fixture's response headers. It then runs the protocol package's
// UnmarshalMeta and Unmarshal, asserts that no error occurred, and runs
// .Assertions against the output value `out`.
//
// NOTE(review): the status code is always 200; testExpectation.StatusCode is
// not consulted. Header names/values are spliced into double-quoted literals
// without escaping.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(unit.Session, &aws.Config{Endpoint: aws.String("https://test")})
buf := bytes.NewReader([]byte({{ .Body }}))
req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set("{{ $k }}", "{{ $v }}")
{{ end }}
// unmarshal response
{{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req)
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
{{ .Assertions }}
}
`))
- type tplOutputTestCaseData struct {
- TestCase *testCase
- Body, OpName, Assertions string
- }
- func (i *testCase) TestCase(idx int) string {
- var buf bytes.Buffer
- opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
- if i.TestSuite.Type == TestSuiteTypeInput { // input test
- // query test should sort body as form encoded values
- switch i.TestSuite.API.Metadata.Protocol {
- case "query", "ec2":
- m, _ := url.ParseQuery(i.InputTest.Body)
- i.InputTest.Body = m.Encode()
- case "rest-xml":
- i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
- case "json", "rest-json":
- i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1)
- }
- input := tplInputTestCaseData{
- TestCase: i,
- OpName: strings.ToUpper(opName[0:1]) + opName[1:],
- ParamsString: api.ParamsStructFromJSON(i.Params, i.Given.InputRef.Shape, false),
- }
- if err := tplInputTestCase.Execute(&buf, input); err != nil {
- panic(err)
- }
- } else if i.TestSuite.Type == TestSuiteTypeOutput {
- output := tplOutputTestCaseData{
- TestCase: i,
- Body: fmt.Sprintf("%q", i.OutputTest.Body),
- OpName: strings.ToUpper(opName[0:1]) + opName[1:],
- Assertions: GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
- }
- if err := tplOutputTestCase.Execute(&buf, output); err != nil {
- panic(err)
- }
- }
- return buf.String()
- }
- // generateTestSuite generates a protocol test suite for a given configuration
- // JSON protocol test file.
- func generateTestSuite(filename string) string {
- inout := "Input"
- if strings.Contains(filename, "output/") {
- inout = "Output"
- }
- var suites []testSuite
- f, err := os.Open(filename)
- if err != nil {
- panic(err)
- }
- err = json.NewDecoder(f).Decode(&suites)
- if err != nil {
- panic(err)
- }
- var buf bytes.Buffer
- buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
- var innerBuf bytes.Buffer
- innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
- for i, suite := range suites {
- svcPrefix := inout + "Service" + strconv.Itoa(i+1)
- suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
- suite.API.Operations = map[string]*api.Operation{}
- for idx, c := range suite.Cases {
- c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
- suite.API.Operations[c.Given.ExportedName] = c.Given
- }
- suite.Type = getType(inout)
- suite.API.NoInitMethods = true // don't generate init methods
- suite.API.NoStringerMethods = true // don't generate stringer methods
- suite.API.NoConstServiceNames = true // don't generate service names
- suite.API.Setup()
- suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
- // Sort in order for deterministic test generation
- names := make([]string, 0, len(suite.API.Shapes))
- for n := range suite.API.Shapes {
- names = append(names, n)
- }
- sort.Strings(names)
- for _, name := range names {
- s := suite.API.Shapes[name]
- s.Rename(svcPrefix + "TestShape" + name)
- }
- svcCode := addImports(suite.API.ServiceGoCode())
- if i == 0 {
- importMatch := reImportRemoval.FindStringSubmatch(svcCode)
- buf.WriteString(importMatch[0] + "\n\n")
- buf.WriteString(preamble + "\n\n")
- }
- svcCode = removeImports(svcCode)
- svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
- svcCode = strings.Replace(svcCode, "func newClient(", "func new"+suite.API.StructName()+"Client(", -1)
- svcCode = strings.Replace(svcCode, "return newClient(", "return new"+suite.API.StructName()+"Client(", -1)
- buf.WriteString(svcCode + "\n\n")
- apiCode := removeImports(suite.API.APIGoCode())
- apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
- apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
- apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
- buf.WriteString(apiCode + "\n\n")
- innerBuf.WriteString(suite.TestSuite() + "\n")
- }
- return buf.String() + innerBuf.String()
- }
- // findMember searches the shape for the member with the matching key name.
- func findMember(shape *api.Shape, key string) string {
- for actualKey := range shape.MemberRefs {
- if strings.ToLower(key) == strings.ToLower(actualKey) {
- return actualKey
- }
- }
- return ""
- }
- // GenerateAssertions builds assertions for a shape based on its type.
- //
- // The shape's recursive values also will have assertions generated for them.
- func GenerateAssertions(out interface{}, shape *api.Shape, prefix string) string {
- switch t := out.(type) {
- case map[string]interface{}:
- keys := util.SortedKeys(t)
- code := ""
- if shape.Type == "map" {
- for _, k := range keys {
- v := t[k]
- s := shape.ValueRef.Shape
- code += GenerateAssertions(v, s, prefix+"[\""+k+"\"]")
- }
- } else {
- for _, k := range keys {
- v := t[k]
- m := findMember(shape, k)
- s := shape.MemberRefs[m].Shape
- code += GenerateAssertions(v, s, prefix+"."+m+"")
- }
- }
- return code
- case []interface{}:
- code := ""
- for i, v := range t {
- s := shape.MemberRef.Shape
- code += GenerateAssertions(v, s, prefix+"["+strconv.Itoa(i)+"]")
- }
- return code
- default:
- switch shape.Type {
- case "timestamp":
- return fmt.Sprintf("assert.Equal(t, time.Unix(%#v, 0).UTC().String(), %s.String())\n", out, prefix)
- case "blob":
- return fmt.Sprintf("assert.Equal(t, %#v, string(%s))\n", out, prefix)
- case "integer", "long":
- return fmt.Sprintf("assert.Equal(t, int64(%#v), *%s)\n", out, prefix)
- default:
- if !reflect.ValueOf(out).IsValid() {
- return fmt.Sprintf("assert.Nil(t, %s)\n", prefix)
- }
- return fmt.Sprintf("assert.Equal(t, %#v, *%s)\n", out, prefix)
- }
- }
- }
- func getType(t string) uint {
- switch t {
- case "Input":
- return TestSuiteTypeInput
- case "Output":
- return TestSuiteTypeOutput
- default:
- panic("Invalid type for test suite")
- }
- }
- func main() {
- out := generateTestSuite(os.Args[1])
- if len(os.Args) == 3 {
- f, err := os.Create(os.Args[2])
- defer f.Close()
- if err != nil {
- panic(err)
- }
- f.WriteString(util.GoFmt(out))
- f.Close()
- c := exec.Command("gofmt", "-s", "-w", os.Args[2])
- if err := c.Run(); err != nil {
- panic(err)
- }
- } else {
- fmt.Println(out)
- }
- }
|