123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- package main
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "net/url"
- "os"
- "os/exec"
- "regexp"
- "sort"
- "strconv"
- "strings"
- "text/template"
- "github.com/aws/aws-sdk-go/internal/fixtures/helpers"
- "github.com/aws/aws-sdk-go/internal/model/api"
- "github.com/aws/aws-sdk-go/internal/util"
- "github.com/aws/aws-sdk-go/internal/util/utilassert"
- )
- type testSuite struct {
- *api.API
- Description string
- Cases []testCase
- title string
- }
- type testCase struct {
- TestSuite *testSuite
- Given *api.Operation
- Params interface{} `json:",omitempty"`
- Data interface{} `json:"result,omitempty"`
- InputTest testExpectation `json:"serialized"`
- OutputTest testExpectation `json:"response"`
- }
// testExpectation holds the HTTP-level side of a test case: the expected
// serialized request for input tests, or the response fed to the
// unmarshaller for output tests.
type testExpectation struct {
	// Body is the request/response body; URI is the request path and query
	// (input tests only).
	Body, URI string
	// Headers are expected request headers (input tests) or response
	// headers to set (output tests).
	Headers map[string]string
	// StatusCode is decoded from "status_code".
	// NOTE(review): not referenced by the visible generator code; the output
	// template hard-codes 200.
	StatusCode uint `json:"status_code"`
}
// preamble is written once, directly after the generated file's import block
// (see generateTestSuite). Each blank-identifier declaration references one
// of the packages added by extraImports so that the generated file still
// compiles even when no generated test happens to use that package
// (Go rejects unused imports).
const preamble = `
var _ bytes.Buffer // always import bytes
var _ http.Request
var _ json.Marshaler
var _ time.Time
var _ xmlutil.XMLNode
var _ xml.Attr
var _ = ioutil.Discard
var _ = util.Trim("")
var _ = url.Values{}
var _ = io.EOF
`
// Patterns shared by the generator, compiled once at package init.
var (
	// reStripSpace matches one whitespace character followed by a word
	// character; TestSuite uses it to CamelCase a suite description.
	reStripSpace = regexp.MustCompile(`\s(\w)`)

	// reImportRemoval matches a parenthesized import block. Group 1 captures
	// the block's contents; the (?s) flag lets the lazy .+? span newlines.
	reImportRemoval = regexp.MustCompile(`(?s:import \((.+?)\))`)
)

// removeImports deletes every parenthesized import block from code.
func removeImports(code string) string {
	return reImportRemoval.ReplaceAllLiteralString(code, "")
}

// extraImports are prepended to the generated service's import block, in
// this order. The empty entry becomes a blank line that separates the
// standard-library imports from the rest.
var extraImports = []string{
	"bytes",
	"encoding/json",
	"encoding/xml",
	"io",
	"io/ioutil",
	"net/http",
	"testing",
	"time",
	"net/url",
	"",
	"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
	"github.com/aws/aws-sdk-go/internal/util",
	"github.com/stretchr/testify/assert",
}

// addImports rewrites every parenthesized import block in code so that it
// lists extraImports (one per line) followed by the block's original
// contents. Code without such a block is returned unchanged.
func addImports(code string) string {
	lines := make([]string, 0, len(extraImports))
	for _, path := range extraImports {
		if path == "" {
			lines = append(lines, "") // blank separator line
			continue
		}
		lines = append(lines, strconv.Quote(path))
	}
	// "$1" expands to the block's original contents after the extra imports.
	replacement := "import (\n" + strings.Join(lines, "\n") + "$1\n)"
	return reImportRemoval.ReplaceAllString(code, replacement)
}
- func (t *testSuite) TestSuite() string {
- var buf bytes.Buffer
- t.title = reStripSpace.ReplaceAllStringFunc(t.Description, func(x string) string {
- return strings.ToUpper(x[1:])
- })
- t.title = regexp.MustCompile(`\W`).ReplaceAllString(t.title, "")
- for idx, c := range t.Cases {
- c.TestSuite = t
- buf.WriteString(c.TestCase(idx) + "\n")
- }
- return buf.String()
- }
// tplInputTestCase renders one input (request-building) test. It is executed
// with a tplInputTestCaseData value.
//
// Expected header names and values are emitted with printf "%q", so a value
// containing a double quote or backslash still produces a valid Go string
// literal. This matches the %q quoting TestCase applies to output bodies.
// For plain values the output is the same as plain "..." wrapping.
var tplInputTestCase = template.Must(template.New("inputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(nil)
svc.Endpoint = "https://test"
input := {{ .ParamsString }}
req, _ := svc.{{ .TestCase.Given.ExportedName }}Request(input)
r := req.HTTPRequest
// build request
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Build(req)
assert.NoError(t, req.Error)
{{ if ne .TestCase.InputTest.Body "" }}// assert body
assert.NotNil(t, r.Body)
{{ .BodyAssertions }}{{ end }}
{{ if ne .TestCase.InputTest.URI "" }}// assert URL
assert.Equal(t, "https://test{{ .TestCase.InputTest.URI }}", r.URL.String()){{ end }}
// assert headers
{{ range $k, $v := .TestCase.InputTest.Headers }}assert.Equal(t, {{ printf "%q" $v }}, r.Header.Get({{ printf "%q" $k }}))
{{ end }}
}
`))
- type tplInputTestCaseData struct {
- TestCase *testCase
- OpName, ParamsString string
- }
- func (t tplInputTestCaseData) BodyAssertions() string {
- protocol, code := t.TestCase.TestSuite.API.Metadata.Protocol, ""
- switch protocol {
- case "rest-xml":
- code += "body := util.SortXML(r.Body)\n"
- default:
- code += "body, _ := ioutil.ReadAll(r.Body)\n"
- }
- code += "assert.Equal(t, util.Trim(`" + t.TestCase.InputTest.Body + "`), util.Trim(string(body)))"
- return code
- }
// tplOutputTestCase renders one output (response-unmarshalling) test. It is
// executed with a tplOutputTestCaseData value whose Body is already a quoted
// Go string literal.
//
// Response header names and values are emitted with printf "%q", matching
// the Body quoting. A header value containing a double quote (for example an
// ETag) or a backslash therefore still yields valid Go. For plain values the
// output is the same as plain "..." wrapping.
var tplOutputTestCase = template.Must(template.New("outputcase").Parse(`
func Test{{ .OpName }}(t *testing.T) {
svc := New{{ .TestCase.TestSuite.API.StructName }}(nil)
buf := bytes.NewReader([]byte({{ .Body }}))
req, out := svc.{{ .TestCase.Given.ExportedName }}Request(nil)
req.HTTPResponse = &http.Response{StatusCode: 200, Body: ioutil.NopCloser(buf), Header: http.Header{}}
// set headers
{{ range $k, $v := .TestCase.OutputTest.Headers }}req.HTTPResponse.Header.Set({{ printf "%q" $k }}, {{ printf "%q" $v }})
{{ end }}
// unmarshal response
{{ .TestCase.TestSuite.API.ProtocolPackage }}.UnmarshalMeta(req)
{{ .TestCase.TestSuite.API.ProtocolPackage }}.Unmarshal(req)
assert.NoError(t, req.Error)
// assert response
assert.NotNil(t, out) // ensure out variable is used
{{ .Assertions }}
}
`))
- type tplOutputTestCaseData struct {
- TestCase *testCase
- Body, OpName, Assertions string
- }
- func (i *testCase) TestCase(idx int) string {
- var buf bytes.Buffer
- opName := i.TestSuite.API.StructName() + i.TestSuite.title + "Case" + strconv.Itoa(idx+1)
- if i.Params != nil { // input test
- // query test should sort body as form encoded values
- switch i.TestSuite.API.Metadata.Protocol {
- case "query", "ec2":
- m, _ := url.ParseQuery(i.InputTest.Body)
- i.InputTest.Body = m.Encode()
- case "rest-xml":
- i.InputTest.Body = util.SortXML(bytes.NewReader([]byte(i.InputTest.Body)))
- case "json", "rest-json":
- i.InputTest.Body = strings.Replace(i.InputTest.Body, " ", "", -1)
- }
- input := tplInputTestCaseData{
- TestCase: i,
- OpName: strings.ToUpper(opName[0:1]) + opName[1:],
- ParamsString: helpers.ParamsStructFromJSON(i.Params, i.Given.InputRef.Shape, false),
- }
- if err := tplInputTestCase.Execute(&buf, input); err != nil {
- panic(err)
- }
- } else {
- output := tplOutputTestCaseData{
- TestCase: i,
- Body: fmt.Sprintf("%q", i.OutputTest.Body),
- OpName: strings.ToUpper(opName[0:1]) + opName[1:],
- Assertions: utilassert.GenerateAssertions(i.Data, i.Given.OutputRef.Shape, "out"),
- }
- if err := tplOutputTestCase.Execute(&buf, output); err != nil {
- panic(err)
- }
- }
- return buf.String()
- }
- // generateTestSuite generates a protocol test suite for a given configuration
- // JSON protocol test file.
- func generateTestSuite(filename string) string {
- inout := "Input"
- if strings.Contains(filename, "output/") {
- inout = "Output"
- }
- var suites []testSuite
- f, err := os.Open(filename)
- if err != nil {
- panic(err)
- }
- err = json.NewDecoder(f).Decode(&suites)
- if err != nil {
- panic(err)
- }
- var buf bytes.Buffer
- buf.WriteString("package " + suites[0].ProtocolPackage() + "_test\n\n")
- var innerBuf bytes.Buffer
- innerBuf.WriteString("//\n// Tests begin here\n//\n\n\n")
- for i, suite := range suites {
- svcPrefix := inout + "Service" + strconv.Itoa(i+1)
- suite.API.Metadata.ServiceAbbreviation = svcPrefix + "ProtocolTest"
- suite.API.Operations = map[string]*api.Operation{}
- for idx, c := range suite.Cases {
- c.Given.ExportedName = svcPrefix + "TestCaseOperation" + strconv.Itoa(idx+1)
- suite.API.Operations[c.Given.ExportedName] = c.Given
- }
- suite.API.NoInflections = true // don't require inflections
- suite.API.NoInitMethods = true // don't generate init methods
- suite.API.NoStringerMethods = true // don't generate stringer methods
- suite.API.Setup()
- suite.API.Metadata.EndpointPrefix = suite.API.PackageName()
- // Sort in order for deterministic test generation
- names := make([]string, 0, len(suite.API.Shapes))
- for n := range suite.API.Shapes {
- names = append(names, n)
- }
- sort.Strings(names)
- for _, name := range names {
- s := suite.API.Shapes[name]
- s.Rename(svcPrefix + "TestShape" + name)
- }
- svcCode := addImports(suite.API.ServiceGoCode())
- if i == 0 {
- importMatch := reImportRemoval.FindStringSubmatch(svcCode)
- buf.WriteString(importMatch[0] + "\n\n")
- buf.WriteString(preamble + "\n\n")
- }
- svcCode = removeImports(svcCode)
- svcCode = strings.Replace(svcCode, "func New(", "func New"+suite.API.StructName()+"(", -1)
- buf.WriteString(svcCode + "\n\n")
- apiCode := removeImports(suite.API.APIGoCode())
- apiCode = strings.Replace(apiCode, "var oprw sync.Mutex", "", -1)
- apiCode = strings.Replace(apiCode, "oprw.Lock()", "", -1)
- apiCode = strings.Replace(apiCode, "defer oprw.Unlock()", "", -1)
- buf.WriteString(apiCode + "\n\n")
- innerBuf.WriteString(suite.TestSuite() + "\n")
- }
- return buf.String() + innerBuf.String()
- }
- func main() {
- out := generateTestSuite(os.Args[1])
- if len(os.Args) == 3 {
- f, err := os.Create(os.Args[2])
- defer f.Close()
- if err != nil {
- panic(err)
- }
- f.WriteString(util.GoFmt(out))
- f.Close()
- c := exec.Command("gofmt", "-s", "-w", os.Args[2])
- if err := c.Run(); err != nil {
- panic(err)
- }
- } else {
- fmt.Println(out)
- }
- }
|