Commit 28968144 by Mitsuhiro Tanda

update aws-sdk-go v1.0.0

parent 4c5cfd51
...@@ -20,53 +20,53 @@ ...@@ -20,53 +20,53 @@
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/aws", "ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/endpoints", "ImportPath": "github.com/aws/aws-sdk-go/private/endpoints",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/ec2query", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/ec2query",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/query",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/rest",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil", "ImportPath": "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4", "ImportPath": "github.com/aws/aws-sdk-go/private/signer/v4",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/private/waiter", "ImportPath": "github.com/aws/aws-sdk-go/private/waiter",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/cloudwatch", "ImportPath": "github.com/aws/aws-sdk-go/service/cloudwatch",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/aws/aws-sdk-go/service/ec2", "ImportPath": "github.com/aws/aws-sdk-go/service/ec2",
"Comment": "v0.10.4-18-gce51895", "Comment": "v1.0.0",
"Rev": "ce51895e994693d65ab997ae48032bf13a9290b7" "Rev": "abb928e07c4108683d6b4d0b6ca08fe6bc0eee5f"
}, },
{ {
"ImportPath": "github.com/davecgh/go-spew/spew", "ImportPath": "github.com/davecgh/go-spew/spew",
......
...@@ -13,11 +13,11 @@ var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`) ...@@ -13,11 +13,11 @@ var indexRe = regexp.MustCompile(`(.+)\[(-?\d+)?\]$`)
// rValuesAtPath returns a slice of values found in value v. The values // rValuesAtPath returns a slice of values found in value v. The values
// in v are explored recursively so all nested values are collected. // in v are explored recursively so all nested values are collected.
func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) []reflect.Value { func rValuesAtPath(v interface{}, path string, createPath, caseSensitive, nilTerm bool) []reflect.Value {
pathparts := strings.Split(path, "||") pathparts := strings.Split(path, "||")
if len(pathparts) > 1 { if len(pathparts) > 1 {
for _, pathpart := range pathparts { for _, pathpart := range pathparts {
vals := rValuesAtPath(v, pathpart, create, caseSensitive) vals := rValuesAtPath(v, pathpart, createPath, caseSensitive, nilTerm)
if len(vals) > 0 { if len(vals) > 0 {
return vals return vals
} }
...@@ -76,7 +76,16 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) ...@@ -76,7 +76,16 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
return false return false
}) })
if create && value.Kind() == reflect.Ptr && value.IsNil() { if nilTerm && value.Kind() == reflect.Ptr && len(components[1:]) == 0 {
if !value.IsNil() {
value.Set(reflect.Zero(value.Type()))
}
return []reflect.Value{value}
}
if createPath && value.Kind() == reflect.Ptr && value.IsNil() {
// TODO if the value is the terminus it should not be created
// if the value to be set to its position is nil.
value.Set(reflect.New(value.Type().Elem())) value.Set(reflect.New(value.Type().Elem()))
value = value.Elem() value = value.Elem()
} else { } else {
...@@ -84,7 +93,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) ...@@ -84,7 +93,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
} }
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map { if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() { if !createPath && value.IsNil() {
value = reflect.ValueOf(nil) value = reflect.ValueOf(nil)
} }
} }
...@@ -116,7 +125,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) ...@@ -116,7 +125,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
// pull out index // pull out index
i := int(*index) i := int(*index)
if i >= value.Len() { // check out of bounds if i >= value.Len() { // check out of bounds
if create { if createPath {
// TODO resize slice // TODO resize slice
} else { } else {
continue continue
...@@ -127,7 +136,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool) ...@@ -127,7 +136,7 @@ func rValuesAtPath(v interface{}, path string, create bool, caseSensitive bool)
value = reflect.Indirect(value.Index(i)) value = reflect.Indirect(value.Index(i))
if value.Kind() == reflect.Slice || value.Kind() == reflect.Map { if value.Kind() == reflect.Slice || value.Kind() == reflect.Map {
if !create && value.IsNil() { if !createPath && value.IsNil() {
value = reflect.ValueOf(nil) value = reflect.ValueOf(nil)
} }
} }
...@@ -176,8 +185,11 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) { ...@@ -176,8 +185,11 @@ func ValuesAtPath(i interface{}, path string) ([]interface{}, error) {
// SetValueAtPath sets a value at the case insensitive lexical path inside // SetValueAtPath sets a value at the case insensitive lexical path inside
// of a structure. // of a structure.
func SetValueAtPath(i interface{}, path string, v interface{}) { func SetValueAtPath(i interface{}, path string, v interface{}) {
if rvals := rValuesAtPath(i, path, true, false); rvals != nil { if rvals := rValuesAtPath(i, path, true, false, v == nil); rvals != nil {
for _, rval := range rvals { for _, rval := range rvals {
if rval.Kind() == reflect.Ptr && rval.IsNil() {
continue
}
setValue(rval, v) setValue(rval, v)
} }
} }
......
...@@ -105,4 +105,38 @@ func TestSetValueAtPathSuccess(t *testing.T) { ...@@ -105,4 +105,38 @@ func TestSetValueAtPathSuccess(t *testing.T) {
assert.Equal(t, "test0", s2.B.B.C) assert.Equal(t, "test0", s2.B.B.C)
awsutil.SetValueAtPath(&s2, "A", []Struct{{}}) awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A) assert.Equal(t, []Struct{{}}, s2.A)
str := "foo"
s3 := Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", str)
assert.Equal(t, "foo", s3.B.B.C)
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
assert.Equal(t, "foo", s3.B.B.C)
var s4 struct{ Name *string }
awsutil.SetValueAtPath(&s4, "Name", str)
assert.Equal(t, str, *s4.Name)
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
s4 = struct{ Name *string }{Name: &str}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", &str)
assert.Equal(t, str, *s4.Name)
} }
...@@ -41,11 +41,20 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op ...@@ -41,11 +41,20 @@ func New(cfg aws.Config, info metadata.ClientInfo, handlers request.Handlers, op
Handlers: handlers, Handlers: handlers,
} }
maxRetries := aws.IntValue(cfg.MaxRetries) switch retryer, ok := cfg.Retryer.(request.Retryer); {
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries { case ok:
maxRetries = 3 svc.Retryer = retryer
case cfg.Retryer != nil && cfg.Logger != nil:
s := fmt.Sprintf("WARNING: %T does not implement request.Retryer; using DefaultRetryer instead", cfg.Retryer)
cfg.Logger.Log(s)
fallthrough
default:
maxRetries := aws.IntValue(cfg.MaxRetries)
if cfg.MaxRetries == nil || maxRetries == aws.UseServiceDefaultRetries {
maxRetries = 3
}
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
} }
svc.Retryer = DefaultRetryer{NumMaxRetries: maxRetries}
svc.AddDebugHandlers() svc.AddDebugHandlers()
......
...@@ -12,6 +12,9 @@ import ( ...@@ -12,6 +12,9 @@ import (
// is nil also. // is nil also.
const UseServiceDefaultRetries = -1 const UseServiceDefaultRetries = -1
// RequestRetryer is an alias for a type that implements the request.Retryer interface.
type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default, // A Config provides service configuration for service clients. By default,
// all clients will use the {defaults.DefaultConfig} structure. // all clients will use the {defaults.DefaultConfig} structure.
type Config struct { type Config struct {
...@@ -59,6 +62,21 @@ type Config struct { ...@@ -59,6 +62,21 @@ type Config struct {
// configuration. // configuration.
MaxRetries *int MaxRetries *int
// Retryer guides how HTTP requests should be retried in case of recoverable failures.
//
// When nil or the value does not implement the request.Retryer interface,
// the request.DefaultRetryer will be used.
//
// When both Retryer and MaxRetries are non-nil, the former is used and
// the latter ignored.
//
// To set the Retryer field in a type-safe manner and with chaining, use
// the request.WithRetryer helper function:
//
// cfg := request.WithRetryer(aws.NewConfig(), myRetryer)
//
Retryer RequestRetryer
// Disables semantic parameter validation, which validates input for missing // Disables semantic parameter validation, which validates input for missing
// required fields and/or other semantic request input errors. // required fields and/or other semantic request input errors.
DisableParamValidation *bool DisableParamValidation *bool
...@@ -217,6 +235,10 @@ func mergeInConfig(dst *Config, other *Config) { ...@@ -217,6 +235,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.MaxRetries = other.MaxRetries dst.MaxRetries = other.MaxRetries
} }
if other.Retryer != nil {
dst.Retryer = other.Retryer
}
if other.DisableParamValidation != nil { if other.DisableParamValidation != nil {
dst.DisableParamValidation = other.DisableParamValidation dst.DisableParamValidation = other.DisableParamValidation
} }
......
...@@ -44,12 +44,19 @@ func (r *Request) nextPageTokens() []interface{} { ...@@ -44,12 +44,19 @@ func (r *Request) nextPageTokens() []interface{} {
} }
tokens := []interface{}{} tokens := []interface{}{}
tokenAdded := false
for _, outToken := range r.Operation.OutputTokens { for _, outToken := range r.Operation.OutputTokens {
v, _ := awsutil.ValuesAtPath(r.Data, outToken) v, _ := awsutil.ValuesAtPath(r.Data, outToken)
if len(v) > 0 { if len(v) > 0 {
tokens = append(tokens, v[0]) tokens = append(tokens, v[0])
tokenAdded = true
} else {
tokens = append(tokens, nil)
} }
} }
if !tokenAdded {
return nil
}
return tokens return tokens
} }
...@@ -85,9 +92,10 @@ func (r *Request) NextPage() *Request { ...@@ -85,9 +92,10 @@ func (r *Request) NextPage() *Request {
// return true to keep iterating or false to stop. // return true to keep iterating or false to stop.
func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error { func (r *Request) EachPage(fn func(data interface{}, isLastPage bool) (shouldContinue bool)) error {
for page := r; page != nil; page = page.NextPage() { for page := r; page != nil; page = page.NextPage() {
page.Send() if err := page.Send(); err != nil {
shouldContinue := fn(page.Data, !page.HasNextPage()) return err
if page.Error != nil || !shouldContinue { }
if getNextPage := fn(page.Data, !page.HasNextPage()); !getNextPage {
return page.Error return page.Error
} }
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit" "github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/route53"
"github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3"
) )
...@@ -314,7 +315,69 @@ func TestPaginationTruncation(t *testing.T) { ...@@ -314,7 +315,69 @@ func TestPaginationTruncation(t *testing.T) {
assert.Equal(t, []string{"Key1", "Key2"}, results) assert.Equal(t, []string{"Key1", "Key2"}, results)
assert.Nil(t, err) assert.Nil(t, err)
}
func TestPaginationNilToken(t *testing.T) {
client := route53.New(unit.Session)
reqNum := 0
resps := []*route53.ListResourceRecordSetsOutput{
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("first.example.com.")},
},
IsTruncated: aws.Bool(true),
NextRecordName: aws.String("second.example.com."),
NextRecordType: aws.String("MX"),
NextRecordIdentifier: aws.String("second"),
MaxItems: aws.String("1"),
},
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("second.example.com.")},
},
IsTruncated: aws.Bool(true),
NextRecordName: aws.String("third.example.com."),
NextRecordType: aws.String("MX"),
MaxItems: aws.String("1"),
},
{
ResourceRecordSets: []*route53.ResourceRecordSet{
{Name: aws.String("third.example.com.")},
},
IsTruncated: aws.Bool(false),
MaxItems: aws.String("1"),
},
}
client.Handlers.Send.Clear() // mock sending
client.Handlers.Unmarshal.Clear()
client.Handlers.UnmarshalMeta.Clear()
client.Handlers.ValidateResponse.Clear()
idents := []string{}
client.Handlers.Build.PushBack(func(r *request.Request) {
p := r.Params.(*route53.ListResourceRecordSetsInput)
idents = append(idents, aws.StringValue(p.StartRecordIdentifier))
})
client.Handlers.Unmarshal.PushBack(func(r *request.Request) {
r.Data = resps[reqNum]
reqNum++
})
params := &route53.ListResourceRecordSetsInput{
HostedZoneId: aws.String("id-zone"),
}
results := []string{}
err := client.ListResourceRecordSetsPages(params, func(p *route53.ListResourceRecordSetsOutput, last bool) bool {
results = append(results, *p.ResourceRecordSets[0].Name)
return true
})
assert.NoError(t, err)
assert.Equal(t, []string{"", "second", ""}, idents)
assert.Equal(t, []string{"first.example.com.", "second.example.com.", "third.example.com."}, results)
} }
// Benchmarks // Benchmarks
......
...@@ -3,6 +3,7 @@ package request ...@@ -3,6 +3,7 @@ package request
import ( import (
"time" "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
) )
...@@ -15,6 +16,13 @@ type Retryer interface { ...@@ -15,6 +16,13 @@ type Retryer interface {
MaxRetries() int MaxRetries() int
} }
// WithRetryer sets a config Retryer value to the given Config returning it
// for chaining.
func WithRetryer(cfg *aws.Config, retryer Retryer) *aws.Config {
cfg.Retryer = retryer
return cfg
}
// retryableCodes is a collection of service response codes which are retry-able // retryableCodes is a collection of service response codes which are retry-able
// without any further action. // without any further action.
var retryableCodes = map[string]struct{}{ var retryableCodes = map[string]struct{}{
......
...@@ -5,4 +5,4 @@ package aws ...@@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go" const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK // SDKVersion is the version of this SDK
const SDKVersion = "0.10.4" const SDKVersion = "1.0.0"
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil" "github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
...@@ -47,52 +48,74 @@ func (w *Waiter) Wait() error { ...@@ -47,52 +48,74 @@ func (w *Waiter) Wait() error {
res := method.Call([]reflect.Value{in}) res := method.Call([]reflect.Value{in})
req := res[0].Interface().(*request.Request) req := res[0].Interface().(*request.Request)
req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("Waiter")) req.Handlers.Build.PushBack(request.MakeAddToUserAgentFreeFormHandler("Waiter"))
if err := req.Send(); err != nil {
return err
}
err := req.Send()
for _, a := range w.Acceptors { for _, a := range w.Acceptors {
if err != nil && a.Matcher != "error" {
// Only matcher error is valid if there is a request error
continue
}
result := false result := false
var vals []interface{}
switch a.Matcher { switch a.Matcher {
case "pathAll": case "pathAll", "path":
if vals, _ := awsutil.ValuesAtPath(req.Data, a.Argument); req.Error == nil && vals != nil { // Require all matches to be equal for result to match
result = true vals, _ = awsutil.ValuesAtPath(req.Data, a.Argument)
for _, val := range vals { result = true
if !awsutil.DeepEqual(val, a.Expected) { for _, val := range vals {
result = false if !awsutil.DeepEqual(val, a.Expected) {
break result = false
} break
} }
} }
case "pathAny": case "pathAny":
if vals, _ := awsutil.ValuesAtPath(req.Data, a.Argument); req.Error == nil && vals != nil { // Only a single match needs to equal for the result to match
for _, val := range vals { vals, _ = awsutil.ValuesAtPath(req.Data, a.Argument)
if awsutil.DeepEqual(val, a.Expected) { for _, val := range vals {
result = true if awsutil.DeepEqual(val, a.Expected) {
break result = true
} break
} }
} }
case "status": case "status":
s := a.Expected.(int) s := a.Expected.(int)
result = s == req.HTTPResponse.StatusCode result = s == req.HTTPResponse.StatusCode
case "error":
if aerr, ok := err.(awserr.Error); ok {
result = aerr.Code() == a.Expected.(string)
}
case "pathList":
// ignored matcher
default:
logf(client, "WARNING: Waiter for %s encountered unexpected matcher: %s",
w.Config.Operation, a.Matcher)
} }
if result { if !result {
switch a.State { // If there was no matching result found there is nothing more to do
case "success": // for this response, retry the request.
return nil // waiter completed continue
case "failure":
if req.Error == nil {
return awserr.New("ResourceNotReady",
fmt.Sprintf("failed waiting for successful resource state"), nil)
}
return req.Error // waiter failed
case "retry":
// do nothing, just retry
}
break
} }
switch a.State {
case "success":
// waiter completed
return nil
case "failure":
// Waiter failure state triggered
return awserr.New("ResourceNotReady",
fmt.Sprintf("failed waiting for successful resource state"), err)
case "retry":
// clear the error and retry the operation
err = nil
default:
logf(client, "WARNING: Waiter for %s encountered unexpected state: %s",
w.Config.Operation, a.State)
}
}
if err != nil {
return err
} }
time.Sleep(time.Second * time.Duration(w.Delay)) time.Sleep(time.Second * time.Duration(w.Delay))
...@@ -101,3 +124,13 @@ func (w *Waiter) Wait() error { ...@@ -101,3 +124,13 @@ func (w *Waiter) Wait() error {
return awserr.New("ResourceNotReady", return awserr.New("ResourceNotReady",
fmt.Sprintf("exceeded %d wait attempts", w.MaxAttempts), nil) fmt.Sprintf("exceeded %d wait attempts", w.MaxAttempts), nil)
} }
func logf(client reflect.Value, msg string, args ...interface{}) {
cfgVal := client.FieldByName("Config")
if !cfgVal.IsValid() {
return
}
if cfg, ok := cfgVal.Interface().(*aws.Config); ok && cfg.Logger != nil {
cfg.Logger.Log(fmt.Sprintf(msg, args...))
}
}
package waiter_test package waiter_test
import ( import (
"bytes"
"io/ioutil"
"net/http"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
...@@ -41,22 +44,76 @@ func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutpu ...@@ -41,22 +44,76 @@ func (c *mockClient) MockRequest(input *MockInput) (*request.Request, *MockOutpu
return req, output return req, output
} }
var mockAcceptors = []waiter.WaitAcceptor{ func TestWaiterPathAll(t *testing.T) {
{ svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
State: "success", Region: aws.String("mock-region"),
Matcher: "pathAll", })}
Argument: "States[].State", svc.Handlers.Send.Clear() // mock sending
Expected: "running", svc.Handlers.Unmarshal.Clear()
}, svc.Handlers.UnmarshalMeta.Clear()
{ svc.Handlers.ValidateResponse.Clear()
State: "failure",
Matcher: "pathAny", reqNum := 0
Argument: "States[].State", resps := []*MockOutput{
Expected: "stopping", { // Request 1
}, States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
waiterCfg := waiter.Config{
Operation: "Mock",
Delay: 0,
MaxAttempts: 10,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "States[].State",
Expected: "running",
},
},
}
w := waiter.Waiter{
Client: svc,
Input: &MockInput{},
Config: waiterCfg,
}
err := w.Wait()
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
} }
func TestWaiter(t *testing.T) { func TestWaiterPath(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{ svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"), Region: aws.String("mock-region"),
})} })}
...@@ -73,13 +130,13 @@ func TestWaiter(t *testing.T) { ...@@ -73,13 +130,13 @@ func TestWaiter(t *testing.T) {
{State: aws.String("pending")}, {State: aws.String("pending")},
}, },
}, },
{ // Request 1 { // Request 2
States: []*MockState{ States: []*MockState{
{State: aws.String("running")}, {State: aws.String("running")},
{State: aws.String("pending")}, {State: aws.String("pending")},
}, },
}, },
{ // Request 1 { // Request 3
States: []*MockState{ States: []*MockState{
{State: aws.String("running")}, {State: aws.String("running")},
{State: aws.String("running")}, {State: aws.String("running")},
...@@ -104,7 +161,14 @@ func TestWaiter(t *testing.T) { ...@@ -104,7 +161,14 @@ func TestWaiter(t *testing.T) {
Operation: "Mock", Operation: "Mock",
Delay: 0, Delay: 0,
MaxAttempts: 10, MaxAttempts: 10,
Acceptors: mockAcceptors, Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "path",
Argument: "States[].State",
Expected: "running",
},
},
} }
w := waiter.Waiter{ w := waiter.Waiter{
Client: svc, Client: svc,
...@@ -135,13 +199,13 @@ func TestWaiterFailure(t *testing.T) { ...@@ -135,13 +199,13 @@ func TestWaiterFailure(t *testing.T) {
{State: aws.String("pending")}, {State: aws.String("pending")},
}, },
}, },
{ // Request 1 { // Request 2
States: []*MockState{ States: []*MockState{
{State: aws.String("running")}, {State: aws.String("running")},
{State: aws.String("pending")}, {State: aws.String("pending")},
}, },
}, },
{ // Request 1 { // Request 3
States: []*MockState{ States: []*MockState{
{State: aws.String("running")}, {State: aws.String("running")},
{State: aws.String("stopping")}, {State: aws.String("stopping")},
...@@ -166,7 +230,20 @@ func TestWaiterFailure(t *testing.T) { ...@@ -166,7 +230,20 @@ func TestWaiterFailure(t *testing.T) {
Operation: "Mock", Operation: "Mock",
Delay: 0, Delay: 0,
MaxAttempts: 10, MaxAttempts: 10,
Acceptors: mockAcceptors, Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "States[].State",
Expected: "running",
},
{
State: "failure",
Matcher: "pathAny",
Argument: "States[].State",
Expected: "stopping",
},
},
} }
w := waiter.Waiter{ w := waiter.Waiter{
Client: svc, Client: svc,
...@@ -181,3 +258,134 @@ func TestWaiterFailure(t *testing.T) { ...@@ -181,3 +258,134 @@ func TestWaiterFailure(t *testing.T) {
assert.Equal(t, 3, numBuiltReq) assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum) assert.Equal(t, 3, reqNum)
} }
func TestWaiterError(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
},
},
{ // Request 2, error case
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
},
},
}
numBuiltReq := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
numBuiltReq++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
if reqNum == 1 {
r.Error = awserr.New("MockException", "mock exception message", nil)
r.HTTPResponse = &http.Response{
StatusCode: 400,
Status: http.StatusText(400),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
reqNum++
}
})
svc.Handlers.Unmarshal.PushBack(func(r *request.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
return
}
r.Data = resps[reqNum]
reqNum++
})
waiterCfg := waiter.Config{
Operation: "Mock",
Delay: 0,
MaxAttempts: 10,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "pathAll",
Argument: "States[].State",
Expected: "running",
},
{
State: "retry",
Matcher: "error",
Argument: "",
Expected: "MockException",
},
},
}
w := waiter.Waiter{
Client: svc,
Input: &MockInput{},
Config: waiterCfg,
}
err := w.Wait()
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
}
func TestWaiterStatus(t *testing.T) {
svc := &mockClient{Client: awstesting.NewClient(&aws.Config{
Region: aws.String("mock-region"),
})}
svc.Handlers.Send.Clear() // mock sending
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.ValidateResponse.Clear()
reqNum := 0
svc.Handlers.Build.PushBack(func(r *request.Request) {
reqNum++
})
svc.Handlers.Send.PushBack(func(r *request.Request) {
code := 200
if reqNum == 3 {
code = 404
}
r.HTTPResponse = &http.Response{
StatusCode: code,
Status: http.StatusText(code),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
})
waiterCfg := waiter.Config{
Operation: "Mock",
Delay: 0,
MaxAttempts: 10,
Acceptors: []waiter.WaitAcceptor{
{
State: "success",
Matcher: "status",
Argument: "",
Expected: 404,
},
},
}
w := waiter.Waiter{
Client: svc,
Input: &MockInput{},
Config: waiterCfg,
}
err := w.Wait()
assert.NoError(t, err)
assert.Equal(t, 3, reqNum)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment