Add tests
authorSofian Brabez <sbz@6dev.net>
Wed, 25 Jul 2018 16:50:21 +0000 (17:50 +0100)
committerSofian Brabez <sbz@6dev.net>
Wed, 25 Jul 2018 16:50:21 +0000 (17:50 +0100)
aws.go
aws_test.go [new file with mode: 0644]
vault.go

diff --git a/aws.go b/aws.go
index 6f4a2dd79f8924fca6f82fea2e07e74cd7fdf2c0..e3b6d2c1656d22baa30194183c597ef8fb1c5c85 100644 (file)
--- a/aws.go
+++ b/aws.go
@@ -6,41 +6,77 @@ import (
        "github.com/aws/aws-sdk-go/aws/session"
        "github.com/aws/aws-sdk-go/service/ec2"
 
+       "errors"
+
        log "github.com/sirupsen/logrus"
 )
 
-func GetInstancesPrivateIps(filter map[string]string) ([]string, error) {
+var (
+       errInvalidFilter         = errors.New("invalid filter map input")
+       errMissingKeyNameFilter  = errors.New("filter doesn't have 'Name' key")
+       errMissingKeyValueFilter = errors.New("filter doesn't have 'Value' key")
+       errNoExistingInstances   = errors.New("no exisiting instances returned")
+)
 
-       sess := session.Must(session.NewSession(&aws.Config{
-               Region: aws.String(endpoints.EuCentral1RegionID),
-       }))
+type Service interface {
+       DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error)
+}
+
+type AWSService struct {
+       service Service
+}
 
-       service := ec2.New(sess)
+func (a *AWSService) GetInstancesPrivateIps(filter map[string]string) ([]string, error) {
+       var name, value string
+       var ok bool
+
+       if filter == nil {
+               return nil, errInvalidFilter
+       }
+
+       if name, ok = filter["Name"]; !ok {
+               return nil, errMissingKeyNameFilter
+       }
+
+       if value, ok = filter["Value"]; !ok {
+               return nil, errMissingKeyValueFilter
+       }
 
        input := &ec2.DescribeInstancesInput{
                Filters: []*ec2.Filter{
                        {
-                               Name:   aws.String(filter["Name"]),
-                               Values: []*string{aws.String(filter["Value"])},
+                               Name:   aws.String(name),
+                               Values: []*string{aws.String(value)},
                        },
                },
        }
 
-       result, err := service.DescribeInstances(input)
+       result, err := a.service.DescribeInstances(input)
        if err != nil {
                log.Debugf("Unable to DescribeInstances using input: %+v", result)
                return nil, err
        }
 
        nodeIps := make([]string, 0, len(result.Reservations))
-
        for _, r := range result.Reservations {
                if len(r.Instances) > 0 {
                        nodeIps = append(nodeIps, *r.Instances[0].PrivateIpAddress)
                }
        }
 
+       if len(nodeIps) == 0 {
+               return nil, errNoExistingInstances
+       }
+
        log.Debugf("Found %d private IPs: %v", len(nodeIps), nodeIps)
 
        return nodeIps, nil
 }
+
+func newEC2Provider() *AWSService {
+       return &AWSService{
+               service: ec2.New(session.Must(session.NewSession(&aws.Config{
+                       Region: aws.String(endpoints.EuCentral1RegionID),
+               }))),
+       }
+}
diff --git a/aws_test.go b/aws_test.go
new file mode 100644 (file)
index 0000000..353d1e9
--- /dev/null
@@ -0,0 +1,67 @@
+package main
+
+import (
+       "os"
+       "testing"
+
+       "github.com/aws/aws-sdk-go/service/ec2"
+)
+
+func init() {
+       os.Unsetenv("AWS_ACCESS_KEY_ID")
+       os.Unsetenv("AWS_SECRET_ACCESS_KEY_ID")
+}
+
+type mockService struct{}
+
+func (s *mockService) DescribeInstances(*ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
+       privateIp := "172.16.1.2"
+       mockResult := ec2.DescribeInstancesOutput{
+               Reservations: []*ec2.Reservation{
+                       {
+                               Instances: []*ec2.Instance{
+                                       {
+                                               PrivateIpAddress: &privateIp,
+                                       },
+                               },
+                       },
+               },
+       }
+
+       return &mockResult, nil
+}
+
+func Test_GetInstancesPrivateIps(t *testing.T) {
+       var mockFilter map[string]string
+       mockAWSInstance := &AWSService{service: &mockService{}}
+       nodeIps, err := mockAWSInstance.GetInstancesPrivateIps(mockFilter)
+       if nodeIps != nil {
+               t.Errorf("Filter is empty want %s, got %s", nodeIps, err)
+       }
+
+       mockFilter = map[string]string{"Name": "Location", "Value": "Mordor"}
+       nodeIps, err = mockAWSInstance.GetInstancesPrivateIps(mockFilter)
+       if err != nil {
+               t.Errorf("Unexpected filter error, got %s", err)
+       }
+
+       if len(nodeIps) != 1 {
+               t.Errorf("Want %s, got %s", []string{"172.16.1.2"}, nodeIps)
+       }
+
+       mockFilter = map[string]string{"Name": "Location"}
+       nodeIps, err = mockAWSInstance.GetInstancesPrivateIps(mockFilter)
+       if err == nil {
+               t.Errorf("Expected filter error, got %s", err)
+       }
+
+       mockFilter = map[string]string{"Value": "Mordor"}
+       nodeIps, err = mockAWSInstance.GetInstancesPrivateIps(mockFilter)
+       if err == nil {
+               t.Errorf("Expected filter error, got %s", err)
+       }
+
+       if len(nodeIps) != 0 {
+               t.Errorf("Want nil, got %s", nodeIps)
+       }
+}
index 4efd63e5f9b2b0d8c7f493a33b7983bdecbf761f..4c212e7683d7f8f8134b53aae450bf7d3b3fb597 100644 (file)
--- a/vault.go
+++ b/vault.go
@@ -51,7 +51,8 @@ func (v *Vault) CheckCluster() (bool, error) {
        vaultFilter["Name"] = "tag:Cluster"
        vaultFilter["Value"] = "theknowledge"
 
-       vaultNodeIps, err := GetInstancesPrivateIps(vaultFilter)
+       provider := newEC2Provider()
+       vaultNodeIps, err := provider.GetInstancesPrivateIps(vaultFilter)
        if err != nil {
                return false, err
        }