Quellcode durchsuchen

backend: fixes and cleanups in awsvpc backend

Vaidas Jablonskis vor 8 Jahren
Ursprung
Commit
921e7a9e22
1 geänderte Dateien mit 65 neuen und 60 gelöschten Zeilen
  1. 65 60
      backend/awsvpc/awsvpc.go

+ 65 - 60
backend/awsvpc/awsvpc.go

@@ -22,6 +22,7 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/aws/ec2metadata"
+	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/service/ec2"
 	log "github.com/golang/glog"
 	"golang.org/x/net/context"
@@ -80,47 +81,50 @@ func (be *AwsVpcBackend) RegisterNetwork(ctx context.Context, network string, co
 		return nil, fmt.Errorf("failed to acquire lease: %v", err)
 	}
 
+	sess, _ := session.NewSession(aws.NewConfig().WithMaxRetries(5))
+
 	// Figure out this machine's EC2 instance ID and region
-	metadataClient := ec2metadata.New(nil)
+	metadataClient := ec2metadata.New(sess)
 	region, err := metadataClient.Region()
 	if err != nil {
 		return nil, fmt.Errorf("error getting EC2 region name: %v", err)
 	}
+	sess.Config.Region = aws.String(region)
 	instanceID, err := metadataClient.GetMetadata("instance-id")
 	if err != nil {
 		return nil, fmt.Errorf("error getting EC2 instance ID: %v", err)
 	}
 
-	ec2c := ec2.New(&aws.Config{Region: aws.String(region)})
+	ec2c := ec2.New(sess)
+
+	// Find ENI which contains the external network interface IP address
+	eni, err := be.findENI(instanceID, ec2c)
+	if err != nil || eni == nil {
+		return nil, fmt.Errorf("unable to find ENI that matches the %s IP address. %s\n", be.extIface.IfaceAddr, err)
+	}
 
-	if _, err = be.disableSrcDestCheck(instanceID, ec2c); err != nil {
-		log.Infof("Warning- disabling source destination check failed: %v", err)
+	// Try to disable SourceDestCheck on the main network interface
+	if err := be.disableSrcDestCheck(eni.NetworkInterfaceId, ec2c); err != nil {
+		log.Warningf("failed to disable SourceDestCheck on %s: %s.\n", *eni.NetworkInterfaceId, err)
 	}
 
 	if cfg.RouteTableID == "" {
-		log.Infof("RouteTableID not passed as config parameter, detecting ...")
-		if cfg.RouteTableID, err = be.detectRouteTableID(instanceID, ec2c); err != nil {
+		if cfg.RouteTableID, err = be.detectRouteTableID(eni, ec2c); err != nil {
 			return nil, err
 		}
+		log.Infof("Found route table %s.\n", cfg.RouteTableID)
 	}
 
-	log.Info("RouteRouteTableID: ", cfg.RouteTableID)
 	networkConfig, err := be.sm.GetNetworkConfig(ctx, network)
 
-	err = be.cleanupInvalidRoutes(cfg.RouteTableID, networkConfig.Network, ec2c)
+	err = be.cleanupBlackholeRoutes(cfg.RouteTableID, networkConfig.Network, ec2c)
 	if err != nil {
-		log.Errorf("Error cleaning up route table: %v", err)
+		log.Errorf("Error cleaning up blackhole routes: %v", err)
 	}
 
-	matchingRouteFound, err := be.checkMatchingRoutes(cfg.RouteTableID, instanceID, l.Subnet.String(), ec2c)
+	matchingRouteFound, err := be.checkMatchingRoutes(cfg.RouteTableID, l.Subnet.String(), eni.NetworkInterfaceId, ec2c)
 	if err != nil {
 		log.Errorf("Error describing route tables: %v", err)
-
-		if ec2Err, ok := err.(awserr.Error); ok {
-			if ec2Err.Code() == "UnauthorizedOperation" {
-				log.Errorf("Note: DescribeRouteTables permission cannot be bound to any resource")
-			}
-		}
 	}
 
 	if !matchingRouteFound {
@@ -134,7 +138,7 @@ func (be *AwsVpcBackend) RegisterNetwork(ctx context.Context, network string, co
 		}
 
 		// Add the route for this machine's subnet
-		if _, err := be.createRoute(cfg.RouteTableID, instanceID, l.Subnet.String(), ec2c); err != nil {
+		if err := be.createRoute(cfg.RouteTableID, l.Subnet.String(), eni.NetworkInterfaceId, ec2c); err != nil {
 			return nil, fmt.Errorf("unable to add route %s: %v", l.Subnet.String(), err)
 		}
 	}
@@ -145,7 +149,7 @@ func (be *AwsVpcBackend) RegisterNetwork(ctx context.Context, network string, co
 	}, nil
 }
 
-func (be *AwsVpcBackend) cleanupInvalidRoutes(routeTableID string, network ip.IP4Net, ec2c *ec2.EC2) error {
+func (be *AwsVpcBackend) cleanupBlackholeRoutes(routeTableID string, network ip.IP4Net, ec2c *ec2.EC2) error {
 	filter := newFilter()
 	filter.Add("route.state", "blackhole")
 
@@ -160,7 +164,7 @@ func (be *AwsVpcBackend) cleanupInvalidRoutes(routeTableID string, network ip.IP
 			if *route.State == "blackhole" && route.DestinationCidrBlock != nil {
 				_, subnet, err := net.ParseCIDR(*route.DestinationCidrBlock)
 				if err == nil && network.Contains(ip.FromIP(subnet.IP)) {
-					log.Info("Removing route: ", *route.DestinationCidrBlock)
+					log.Info("Removing blackhole route: ", *route.DestinationCidrBlock)
 					deleteRouteInput := &ec2.DeleteRouteInput{RouteTableId: &routeTableID, DestinationCidrBlock: route.DestinationCidrBlock}
 					if _, err := ec2c.DeleteRoute(deleteRouteInput); err != nil {
 						if ec2err, ok := err.(awserr.Error); !ok || ec2err.Code() != "InvalidRoute.NotFound" {
@@ -176,7 +180,7 @@ func (be *AwsVpcBackend) cleanupInvalidRoutes(routeTableID string, network ip.IP
 	return nil
 }
 
-func (be *AwsVpcBackend) checkMatchingRoutes(routeTableID, instanceID, subnet string, ec2c *ec2.EC2) (bool, error) {
+func (be *AwsVpcBackend) checkMatchingRoutes(routeTableID, subnet string, eniID *string, ec2c *ec2.EC2) (bool, error) {
 	matchingRouteFound := false
 
 	filter := newFilter()
@@ -192,14 +196,10 @@ func (be *AwsVpcBackend) checkMatchingRoutes(routeTableID, instanceID, subnet st
 
 	for _, routeTable := range resp.RouteTables {
 		for _, route := range routeTable.Routes {
-			if route.DestinationCidrBlock != nil && subnet == *route.DestinationCidrBlock && *route.State == "active" {
-
-				if *route.InstanceId == instanceID {
-					matchingRouteFound = true
-					break
-				}
-
-				log.Errorf("Deleting invalid *active* matching route: %s, %s \n", *route.DestinationCidrBlock, *route.InstanceId)
+			if route.DestinationCidrBlock != nil && subnet == *route.DestinationCidrBlock &&
+				*route.State == "active" && route.NetworkInterfaceId == eniID {
+				matchingRouteFound = true
+				break
 			}
 		}
 	}
@@ -207,48 +207,34 @@ func (be *AwsVpcBackend) checkMatchingRoutes(routeTableID, instanceID, subnet st
 	return matchingRouteFound, nil
 }
 
-func (be *AwsVpcBackend) createRoute(routeTableID, instanceID, subnet string, ec2c *ec2.EC2) (*ec2.CreateRouteOutput, error) {
+func (be *AwsVpcBackend) createRoute(routeTableID, subnet string, eniID *string, ec2c *ec2.EC2) error {
 	route := &ec2.CreateRouteInput{
 		RouteTableId:         &routeTableID,
-		InstanceId:           &instanceID,
+		NetworkInterfaceId:   eniID,
 		DestinationCidrBlock: &subnet,
 	}
 
-	return ec2c.CreateRoute(route)
-}
-
-func (be *AwsVpcBackend) disableSrcDestCheck(instanceID string, ec2c *ec2.EC2) (*ec2.ModifyInstanceAttributeOutput, error) {
-	modifyAttributes := &ec2.ModifyInstanceAttributeInput{
-		InstanceId:      aws.String(instanceID),
-		SourceDestCheck: &ec2.AttributeBooleanValue{Value: aws.Bool(false)},
+	if _, err := ec2c.CreateRoute(route); err != nil {
+		return err
 	}
-
-	return ec2c.ModifyInstanceAttribute(modifyAttributes)
+	log.Infof("Route added %s - %s.\n", subnet, *eniID)
+	return nil
 }
 
-func (be *AwsVpcBackend) detectRouteTableID(instanceID string, ec2c *ec2.EC2) (string, error) {
-	instancesInput := &ec2.DescribeInstancesInput{
-		InstanceIds: []*string{&instanceID},
-	}
-
-	resp, err := ec2c.DescribeInstances(instancesInput)
-	if err != nil {
-		return "", fmt.Errorf("error getting instance info: %v", err)
-	}
-
-	if len(resp.Reservations) == 0 {
-		return "", fmt.Errorf("no reservations found")
-	}
-
-	if len(resp.Reservations[0].Instances) == 0 {
-		return "", fmt.Errorf("no matching instance found with id: %v", instanceID)
+func (be *AwsVpcBackend) disableSrcDestCheck(eniID *string, ec2c *ec2.EC2) error {
+	attr := &ec2.ModifyNetworkInterfaceAttributeInput{
+		NetworkInterfaceId: eniID,
+		SourceDestCheck:    &ec2.AttributeBooleanValue{Value: aws.Bool(false)},
 	}
+	_, err := ec2c.ModifyNetworkInterfaceAttribute(attr)
+	return err
+}
 
-	subnetID := resp.Reservations[0].Instances[0].SubnetId
-	vpcID := resp.Reservations[0].Instances[0].VpcId
-
-	log.Info("Subnet-ID: ", *subnetID)
-	log.Info("VPC-ID: ", *vpcID)
+// detectRouteTableID detect the routing table that is associated with the ENI,
+// subnet can be implicitly associated with the main routing table
+func (be *AwsVpcBackend) detectRouteTableID(eni *ec2.InstanceNetworkInterface, ec2c *ec2.EC2) (string, error) {
+	subnetID := eni.SubnetId
+	vpcID := eni.VpcId
 
 	filter := newFilter()
 	filter.Add("association.subnet-id", *subnetID)
@@ -285,3 +271,22 @@ func (be *AwsVpcBackend) detectRouteTableID(instanceID string, ec2c *ec2.EC2) (s
 
 	return *res.RouteTables[0].RouteTableId, nil
 }
+
+func (be *AwsVpcBackend) findENI(instanceID string, ec2c *ec2.EC2) (*ec2.InstanceNetworkInterface, error) {
+	instance, err := ec2c.DescribeInstances(&ec2.DescribeInstancesInput{
+		InstanceIds: []*string{aws.String(instanceID)}},
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, n := range instance.Reservations[0].Instances[0].NetworkInterfaces {
+		for _, a := range n.PrivateIpAddresses {
+			if *a.PrivateIpAddress == be.extIface.IfaceAddr.String() {
+				log.Infof("Found %s that has %s IP address.\n", *n.NetworkInterfaceId, be.extIface.IfaceAddr)
+				return n, nil
+			}
+		}
+	}
+	return nil, err
+}