Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Something put together quickly to automatically associate an Elastic IP address with the current EC2 instance.
aws dynamodb create-table --region <REGION> --profile <PROFILE> --table-name "<TABLE>" --attribute-definitions "AttributeName=key,AttributeType=S" --key-schema "AttributeName=key,KeyType=HASH" --billing-mode "PAY_PER_REQUEST"
aws dynamodb create-table --region <REGION> --profile <PROFILE> --table-name "<TABLE>" --attribute-definitions "AttributeName=key,AttributeType=S" --key-schema "AttributeName=key,KeyType=HASH" --billing-mode "PROVISIONED" --provisioned-throughput "ReadCapacityUnits=1,WriteCapacityUnits=1"
aws dynamodb update-time-to-live --region <REGION> --profile <PROFILE> --table-name "<TABLE>" --time-to-live-specification "Enabled=true,AttributeName=ttl"
aws dynamodb scan --region <REGION> --profile <PROFILE> --table-name "<TABLE>"
// ec2-address-association
// $ go get -v github.com/aws/aws-sdk-go/aws
// $ go get -v github.com/google/uuid
// $ go get -v github.com/hashicorp/logutils
// $ go get -v github.com/pkg/errors
// $ go get -v gopkg.in/alecthomas/kingpin.v2
// $ GOOS=linux go build -o ec2-address-association .
package main
import (
"bytes"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/aws/aws-sdk-go/service/dynamodb/expression"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/google/uuid"
"github.com/hashicorp/logutils"
"github.com/pkg/errors"
kingpin "gopkg.in/alecthomas/kingpin.v2"
)
// Default values used throughout the program when no flag overrides them.
const (
	// version is reported by the --version flag.
	version = "0.1.0"

	// defaultLogLevel is used when neither --log-level nor
	// EC2_ADDRESS_LOG_LEVEL is set.
	defaultLogLevel = "INFO"

	// Tag names are built by joining defaultTagPrefix and one of the
	// suffixes below with ":".
	defaultTagPrefix       = "ec2:address"
	defaultInstanceTagName = "allocation-id"
	defaultAddressTagName  = "range"
	defaultAddressTagValue = "true"

	// defaultAddressRetry bounds how often an address that is found but
	// already in use is looked up again.
	defaultAddressRetry = 5

	// Exclusive lock (DynamoDB) settings.
	defaultLockTableName = "ec2-address-range"
	defaultLockName      = "allocation-lock"
	defaultLockAttempts  = 30
)

// defaultLockTTL is how far ahead the "ttl" attribute of a lock item is set.
// That attribute is the table's DynamoDB TTL attribute, so abandoned lock
// items are eventually removed.
var defaultLockTTL = 24 * time.Hour
// filterAddressFunc is a predicate used by filterAddresses; returning true
// keeps the address.
type filterAddressFunc func(*ec2.Address) bool

// Address pairs an Elastic IP address description with the EC2 client used
// to act on it.
type Address struct {
	*ec2.Address
	*ec2.EC2
}

// Available reports whether the address has no AssociationId set. The rest
// of the program treats that as "not currently associated".
func (a *Address) Available() bool {
	associated := a.AssociationId != nil
	return !associated
}

// AssociateAddressInput holds the parameters for Address.Associate.
type AssociateAddressInput struct {
	// InstanceID is the EC2 instance to associate the address with.
	InstanceID string
	// DryRun is forwarded as the DryRun flag of the AssociateAddress call.
	DryRun bool
}
// Associate associates the Elastic IP address with input.InstanceID. It then
// returns a freshly described copy of the address.
//
// AssociateAddress is retried (up to 5 attempts with backoff) while EC2
// answers with InvalidInstanceID. Any other error stops the retries
// immediately. In dry-run mode the DryRunOperation response counts as
// success. Because no association ID exists then, the address is
// re-described by its allocation ID instead.
//
// An error is returned when the association fails, when the follow-up
// DescribeAddresses call fails, or when that call returns no address.
func (a *Address) Associate(input *AssociateAddressInput) (*Address, error) {
	associateParams := &ec2.AssociateAddressInput{
		AllocationId: a.AllocationId,
		InstanceId:   aws.String(input.InstanceID),
		DryRun:       aws.Bool(input.DryRun),
	}
	log.Printf("[TRACE] Associate address request: %v\n", associateParams)

	const (
		awsErrorCode  = "InvalidInstanceID"
		awsDryRunCode = "DryRunOperation"
	)

	var associationID string
	err := retryFunction(5, 250*time.Millisecond, func() error {
		resp, err := a.AssociateAddress(associateParams)
		if err != nil {
			awsErr, ok := err.(awserr.Error)
			if ok && awsErr.Code() == awsErrorCode {
				log.Printf(
					"[DEBUG] EC2 instance %s not ready. Retrying...\n",
					input.InstanceID,
				)
				return err
			}
			if ok && awsErr.Code() == awsDryRunCode {
				log.Println("[DEBUG] Dry run completed successfully.")
				return nil
			}
			return retryStop{err}
		}
		associationID = aws.StringValue(resp.AssociationId)
		return nil
	})
	if err != nil {
		return nil, errors.Wrap(err, "unable to associate address")
	}

	// A dry run produces no association ID, so look the address up by its
	// allocation ID instead.
	filter := ec2Filter("association-id", associationID)
	if input.DryRun {
		filter = ec2Filter("allocation-id", aws.StringValue(a.AllocationId))
	}
	describeParams := &ec2.DescribeAddressesInput{
		Filters: []*ec2.Filter{filter},
	}
	log.Printf("[TRACE] Describe address request: %v\n", describeParams)
	resp, err := a.DescribeAddresses(describeParams)
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe address")
	}
	// Filters that match nothing yield an empty list rather than an error,
	// so check before indexing instead of panicking.
	if len(resp.Addresses) < 1 {
		return nil, errors.New("unable to describe address: no address found")
	}
	return &Address{
		resp.Addresses[0],
		a.EC2,
	}, nil
}
// Instance pairs an EC2 instance description with the EC2 client used to act
// on it.
type Instance struct {
	*ec2.Instance
	*ec2.EC2
}

// TagsInput is the (empty) parameter set for Instance.Tags.
type TagsInput struct{}

// Tags returns the instance's tags as a key/value map. Keys with the
// "aws:" prefix are left out (see buildTags).
func (i *Instance) Tags(_ *TagsInput) map[string]string {
	instanceTags := i.Instance.Tags
	return buildTags(instanceTags)
}

// CreateTagsInput holds the tags to add to an instance.
type CreateTagsInput struct {
	// Tags maps tag keys to tag values.
	Tags map[string]string
}

// CreateTags adds input.Tags to the instance and returns the receiver.
// Errors whose AWS error code contains ".NotFound" are retried (up to 5
// attempts with backoff). Every other error aborts immediately and is
// returned wrapped.
func (i *Instance) CreateTags(input *CreateTagsInput) (*Instance, error) {
	params := &ec2.CreateTagsInput{
		Resources: []*string{i.InstanceId},
		Tags:      buildEC2Tags(input.Tags),
	}
	log.Printf("[TRACE] Create tags request: %v\n", params)

	// Presumably ".NotFound" codes mean the instance is not visible yet;
	// those are worth retrying.
	const awsErrorCode = ".NotFound"
	retryable := func(err error) bool {
		awsErr, ok := err.(awserr.Error)
		return ok && strings.Contains(awsErr.Code(), awsErrorCode)
	}

	err := retryFunction(5, 250*time.Millisecond, func() error {
		_, tagErr := i.EC2.CreateTags(params)
		switch {
		case tagErr == nil:
			return nil
		case retryable(tagErr):
			return tagErr
		default:
			return retryStop{tagErr}
		}
	})
	if err != nil {
		return nil, errors.Wrap(err, "unable to create tags")
	}
	return i, nil
}
// Bucket pairs an S3 bucket description with the S3 client used to access it.
type Bucket struct {
	*s3.Bucket
	*s3.S3
}

// ListCurrentObjectsInput holds the parameters for Bucket.ListCurrentObjects.
type ListCurrentObjectsInput struct {
	// Prefix limits the listing to keys starting with it. Empty means no
	// limit.
	Prefix string
}

// ListCurrentObjects lists the bucket's objects with one ListObjects request
// (no pagination is performed). S3 InternalError responses are retried (up
// to 5 attempts with backoff). Any other error aborts and is returned
// wrapped.
func (b *Bucket) ListCurrentObjects(input *ListCurrentObjectsInput) ([]*s3.Object, error) {
	params := &s3.ListObjectsInput{Bucket: b.Bucket.Name}
	if prefix := input.Prefix; prefix != "" {
		params.Prefix = aws.String(prefix)
	}
	log.Printf("[TRACE] List objects request: %v\n", params)

	const awsInternalError = "InternalError"
	var objects []*s3.Object
	err := retryFunction(5, 250*time.Millisecond, func() error {
		resp, listErr := b.ListObjects(params)
		if listErr == nil {
			objects = resp.Contents
			return nil
		}
		if awsErr, ok := listErr.(awserr.Error); ok && awsErr.Code() == awsInternalError {
			log.Println("[DEBUG] S3 internal error. Retrying...")
			return listErr
		}
		return retryStop{listErr}
	})
	if err != nil {
		return nil, errors.Wrap(err, "unable to list objects")
	}
	log.Printf("[DEBUG] Found object: %v\n", objects)
	return objects, nil
}
// PutObjectInput holds the parameters for Bucket.PutObject.
type PutObjectInput struct {
	// Bucket is the target bucket name. When empty, the receiving Bucket's
	// own name is used.
	Bucket string
	// Key is the object key to write.
	Key string
	// Body supplies the object content.
	Body io.ReadSeeker
}

// PutObject uploads input.Body under input.Key with the "private" canned ACL
// and returns the receiver. S3 InternalError responses are retried (up to 5
// attempts with backoff). Any other error aborts and is returned wrapped.
//
// NOTE(review): the same Body is reused across retry attempts; confirm the
// SDK rewinds it for every new PutObject call.
func (b *Bucket) PutObject(input *PutObjectInput) (*Bucket, error) {
	bucketName := aws.String(input.Bucket)
	// Fall back to this bucket's own name so callers need not repeat it.
	if input.Bucket == "" && b.Bucket != nil && b.Bucket.Name != nil {
		bucketName = b.Bucket.Name
	}
	params := &s3.PutObjectInput{
		Bucket: bucketName,
		Key:    aws.String(input.Key),
		ACL:    aws.String("private"),
		Body:   input.Body,
	}
	log.Printf("[TRACE] Put object request: %v\n", params)

	const awsInternalError = "InternalError"
	err := retryFunction(5, 250*time.Millisecond, func() error {
		_, err := b.S3.PutObject(params)
		if err != nil {
			awsErr, ok := err.(awserr.Error)
			if ok && awsErr.Code() == awsInternalError {
				log.Println("[DEBUG] S3 internal error. Retrying...")
				return err
			}
			return retryStop{err}
		}
		return nil
	})
	if err != nil {
		return nil, errors.Wrap(err, "unable to put object")
	}
	return b, nil
}
// lock holds the state of a lease lock kept as an item in a DynamoDB table.
type lock struct {
// name is written as the item's "key" attribute.
name string
// identifier is written as the item's "identifier" attribute and
// distinguishes this holder from others.
identifier string
// table is the DynamoDB table holding the lock item.
table string
// renew is polled by renewLock after every sleep; Unlock clears it to stop
// the renewal loop.
renew bool
// leaseDuration is added to the write time to form the item's "expiry".
leaseDuration time.Duration
// renewInterval is how long renewLock sleeps between renewals.
renewInterval time.Duration
// lastError is the error that ended the renewal loop, if any.
lastError error
}
// Lock is a lease lock backed by DynamoDB, paired with the client used to
// read and write it.
// NOTE(review): CreateLock builds this struct positionally, so field order
// matters. renew and lastError are written by the renewal goroutine without
// synchronization; consider a mutex or atomic values.
type Lock struct {
*lock
*dynamodb.DynamoDB
}
// Lock tries once to acquire the lock with createLock. On success it starts
// the background renewal goroutine (renewLock) and returns the receiver.
func (l *Lock) Lock() (*Lock, error) {
	if err := l.createLock(); err != nil {
		return nil, err
	}
	go l.renewLock()
	return l, nil
}

// Unlock clears the renew flag, which makes the renewal goroutine exit after
// its current sleep. It then deletes the lock item via removeLock.
func (l *Lock) Unlock() (*Lock, error) {
	l.renew = false
	if err := l.removeLock(); err != nil {
		return nil, err
	}
	return l, nil
}

// Locked returns the result of checkLock: the lock status plus either the
// scan error or the error recorded by the renewal goroutine.
func (l *Lock) Locked() (bool, error) {
	return l.checkLock()
}

// Error returns the error that stopped lease renewal, or nil.
func (l *Lock) Error() error {
	return l.lock.lastError
}
// TryLockInput holds the parameters for Lock.TryLock.
type TryLockInput struct {
	// Attempts is the maximum number of acquisition attempts.
	Attempts int
}

// TryLock calls Lock up to input.Attempts times, with backoff between tries.
// It retries only when the conditional write failed with
// ConditionalCheckFailedException, meaning the lock is currently held. Any
// other error aborts immediately. The logged attempt counter counts down
// from input.Attempts.
func (l *Lock) TryLock(input *TryLockInput) (*Lock, error) {
	total := input.Attempts
	log.Printf("[DEBUG] Lock acquire attempts: %d\n", total)

	const awsErrorCode = "ConditionalCheckFailedException"
	remaining := total
	err := retryFunction(total, 250*time.Millisecond, func() error {
		_, lockErr := l.Lock()
		if lockErr == nil {
			return nil
		}
		awsErr, ok := errors.Cause(lockErr).(awserr.RequestFailure)
		if !ok || awsErr.Code() != awsErrorCode {
			return retryStop{lockErr}
		}
		log.Printf(
			"[DEBUG] Unable to acquire lock. Retrying... (%d/%d)\n",
			remaining,
			total,
		)
		remaining--
		return lockErr
	})
	if err != nil {
		return nil, err
	}
	log.Printf("[DEBUG] Acquired lock: %s\n", l.lock.name)
	return l, nil
}
// createLock makes one conditional PutItem of the lock item. The write is
// accepted when, for the stored item,
//
//	key <> name OR identifier = this holder OR expiry < now
//
// The written item has "expiry" = now+leaseDuration (Unix nanoseconds) and
// "ttl" = now+defaultLockTTL (Unix seconds). Only
// ProvisionedThroughputExceededException is retried (up to 5 attempts).
// Every other error, including a failed condition, is wrapped with
// errors.Wrap, so callers can inspect the AWS error via errors.Cause.
func (l *Lock) createLock() error {
	acquirable := expression.Or(
		expression.Name("key").NotEqual(expression.Value(l.lock.name)),
		expression.Name("identifier").Equal(expression.Value(l.lock.identifier)),
		expression.Name("expiry").LessThan(expression.Value(time.Now().UnixNano())),
	)
	builder, err := expression.NewBuilder().WithCondition(acquirable).Build()
	if err != nil {
		return errors.Wrap(err, "unable to build expression")
	}

	now := time.Now()
	item, err := dynamodbattribute.MarshalMap(map[string]interface{}{
		"key":        l.lock.name,
		"identifier": l.lock.identifier,
		"expiry":     now.Add(l.lock.leaseDuration).UnixNano(),
		"ttl":        now.Add(defaultLockTTL).Unix(),
	})
	if err != nil {
		return errors.Wrap(err, "unable to marshal item")
	}

	params := &dynamodb.PutItemInput{
		TableName:                 aws.String(l.lock.table),
		Item:                      item,
		ConditionExpression:       builder.Condition(),
		ExpressionAttributeNames:  builder.Names(),
		ExpressionAttributeValues: builder.Values(),
	}
	log.Printf("[TRACE] Put item request: %v\n", params)

	const awsErrorCode = "ProvisionedThroughputExceededException"
	putErr := retryFunction(5, 250*time.Millisecond, func() error {
		_, callErr := l.DynamoDB.PutItem(params)
		if callErr == nil {
			return nil
		}
		if awsErr, ok := callErr.(awserr.Error); ok && awsErr.Code() == awsErrorCode {
			log.Println("[DEBUG] Provisioned throughput exceeded. Retrying...")
			return callErr
		}
		return retryStop{callErr}
	})
	if putErr != nil {
		return errors.Wrap(putErr, "unable to put item")
	}
	return nil
}
// removeLock deletes the lock item. The delete is conditioned on the stored
// item carrying both this lock's name and this holder's identifier, so a
// lock taken over by another holder (for example after this lease expired)
// is left alone. Only ProvisionedThroughputExceededException is retried (up
// to 5 attempts). Every other error, including a failed condition, is
// returned wrapped.
func (l *Lock) removeLock() error {
	// Both attributes must match. With Or, "key = name" is always true for
	// the item addressed by Key, so any holder's lock would be deleted.
	builder, err := expression.NewBuilder().WithCondition(
		expression.And(
			expression.Name("key").Equal(expression.Value(l.lock.name)),
			expression.Name("identifier").Equal(expression.Value(l.lock.identifier)),
		),
	).Build()
	if err != nil {
		return errors.Wrap(err, "unable to build expression")
	}

	key, err := dynamodbattribute.MarshalMap(map[string]interface{}{
		"key": l.lock.name,
	})
	if err != nil {
		return errors.Wrap(err, "unable to marshal item")
	}

	params := &dynamodb.DeleteItemInput{
		ExpressionAttributeNames:  builder.Names(),
		ExpressionAttributeValues: builder.Values(),
		ConditionExpression:       builder.Condition(),
		TableName:                 aws.String(l.lock.table),
		Key:                       key,
	}
	log.Printf("[TRACE] Delete item request: %v\n", params)

	const awsErrorCode = "ProvisionedThroughputExceededException"
	err = retryFunction(5, 250*time.Millisecond, func() error {
		_, deleteErr := l.DynamoDB.DeleteItem(params)
		if deleteErr != nil {
			awsErr, ok := deleteErr.(awserr.Error)
			if ok && awsErr.Code() == awsErrorCode {
				log.Println("[DEBUG] Provisioned throughput exceeded. Retrying...")
				return deleteErr
			}
			return retryStop{deleteErr}
		}
		return nil
	})
	if err != nil {
		return errors.Wrap(err, "unable to delete item")
	}
	return nil
}
// checkLock reports whether this holder currently owns an unexpired lock. It
// scans the lock table for an item that has this lock's name, this holder's
// identifier, and an expiry in the future.
//
// The error result is the scan error when scanning fails. Otherwise it is
// the last error recorded by the renewal goroutine (possibly nil).
// ProvisionedThroughputExceededException is retried (up to 5 attempts).
//
// NOTE(review): only the first Scan page is examined. That is fine for a
// table holding a handful of lock items, but not for a large table.
func (l *Lock) checkLock() (bool, error) {
	// All three conditions must hold. With Or, any row for this lock name
	// (owned by anyone, expired or not) would count as "locked".
	builder, err := expression.NewBuilder().WithFilter(
		expression.And(
			expression.Name("key").Equal(expression.Value(l.lock.name)),
			expression.Name("identifier").Equal(expression.Value(l.lock.identifier)),
			expression.Name("expiry").GreaterThan(expression.Value(time.Now().UnixNano())),
		),
	).WithProjection(
		expression.NamesList(
			expression.Name("key"),
			expression.Name("identifier"),
			expression.Name("expiry"),
		),
	).Build()
	if err != nil {
		return false, errors.Wrap(err, "unable to build expression")
	}

	params := &dynamodb.ScanInput{
		ExpressionAttributeNames:  builder.Names(),
		ExpressionAttributeValues: builder.Values(),
		FilterExpression:          builder.Filter(),
		ProjectionExpression:      builder.Projection(),
		TableName:                 aws.String(l.lock.table),
	}
	log.Printf("[TRACE] Scan items request: %v\n", params)

	const awsErrorCode = "ProvisionedThroughputExceededException"
	var count int64
	err = retryFunction(5, 250*time.Millisecond, func() error {
		resp, scanErr := l.DynamoDB.Scan(params)
		// Test the error Scan just returned. The outer err is always nil at
		// this point.
		if scanErr != nil {
			awsErr, ok := scanErr.(awserr.Error)
			if ok && awsErr.Code() == awsErrorCode {
				log.Println("[DEBUG] Provisioned throughput exceeded. Retrying...")
				return scanErr
			}
			return retryStop{scanErr}
		}
		count = aws.Int64Value(resp.Count)
		return nil
	})
	if err != nil {
		return false, errors.Wrap(err, "unable to scan items")
	}
	return count > 0, l.lock.lastError
}
// renewLock is the renewal loop that Lock starts in a goroutine. It sets the
// renew flag and then repeats:
//
//	sleep renewInterval; stop if renew was cleared; rewrite the item via createLock
//
// A failed renewal is stored in lastError and ends the loop.
func (l *Lock) renewLock() {
	name := l.lock.name
	l.lock.renew = true
	for {
		log.Printf("[DEBUG] Sleeping before renewing lease on lock: %s\n", l.lock.renewInterval)
		time.Sleep(l.lock.renewInterval)
		if !l.lock.renew {
			break
		}

		log.Printf("[DEBUG] Renewing lease on lock: %s\n", name)
		if err := l.createLock(); err != nil {
			log.Printf("[DEBUG] Unable to renew lease on lock: %s: %s\n", name, err)
			l.lock.lastError = err
			break
		}
	}
	log.Printf("[DEBUG] Stopping to renew lease on lock: %s\n", name)
}
// ClientOpt is a functional option applied by NewClient.
type ClientOpt func(*Client)

// Client bundles the AWS session with the EC2, EC2 metadata, S3 and DynamoDB
// service clients used by this program.
type Client struct {
	debug   bool
	trace   bool
	profile string
	region  string

	s3          *s3.S3
	dynamoDB    *dynamodb.DynamoDB
	ec2         *ec2.EC2
	ec2Metadata *ec2metadata.EC2Metadata
	session     *session.Session
}

// WithDebug returns an option that sets the client's debug flag.
func WithDebug(debug bool) ClientOpt {
	return func(c *Client) { c.debug = debug }
}

// WithTrace returns an option that, when true, makes NewClient enable SDK
// request/response logging including HTTP bodies.
func WithTrace(trace bool) ClientOpt {
	return func(c *Client) { c.trace = trace }
}

// WithProfile returns an option selecting the shared-config profile used for
// the AWS session.
func WithProfile(profile string) ClientOpt {
	return func(c *Client) { c.profile = profile }
}

// WithRegion returns an option setting the AWS region. An empty region makes
// NewClient fall back to the session configuration and then to instance
// metadata.
func WithRegion(region string) ClientOpt {
	return func(c *Client) { c.region = region }
}
// NewClient builds a Client from the given options.
//
// Steps:
//   - Verify that the EC2 instance metadata service is reachable, using
//     settings private to that check (2-second HTTP timeout, at most 2
//     retries).
//   - Create the shared session for the selected profile, with retries
//     disabled once credentials have expired.
//   - Pick the region from WithRegion, else the session configuration, else
//     instance metadata.
//   - Create the S3, DynamoDB and EC2 clients for that region.
//
// With trace enabled, the SDK logs requests and responses with HTTP bodies.
func NewClient(options ...ClientOpt) (*Client, error) {
	c := &Client{}
	for _, option := range options {
		option(c)
	}

	config := aws.NewConfig()
	if c.trace {
		config.WithLogLevel(aws.LogDebugWithHTTPBody)
	}

	// aws.Config's With* setters modify their receiver in place, so apply the
	// metadata-only settings to a copy. Otherwise the short timeout and low
	// retry count would also be imposed on the service clients below.
	metadataConfig := config.Copy().
		WithHTTPClient(&http.Client{Timeout: 2 * time.Second}).
		WithMaxRetries(2)
	c.ec2Metadata = ec2metadata.New(session.Must(session.NewSession(metadataConfig)))
	if !c.ec2Metadata.Available() {
		return nil, errors.New("metadata service not available")
	}

	c.session = session.Must(session.NewSessionWithOptions(session.Options{
		Profile: c.profile,
	}))
	// Retrying with expired credentials cannot succeed, so stop early.
	c.session.Handlers.Retry.PushFront(func(r *request.Request) {
		if r.IsErrorExpired() {
			log.Println("[DEBUG] Credentials expired. Stop retrying...")
			r.Retryable = aws.Bool(false)
		}
	})

	if c.region == "" {
		c.region = aws.StringValue(c.session.Config.Region)
	}
	if c.region == "" {
		metadataRegion, err := c.Region()
		if err != nil {
			return nil, errors.Wrap(err, "unable to retrieve region")
		}
		c.region = metadataRegion
	}

	config = config.WithRegion(c.region)
	c.s3 = s3.New(c.session, config)
	c.dynamoDB = dynamodb.New(c.session, config)
	c.ec2 = ec2.New(c.session, config)
	return c, nil
}
// Available reports whether the EC2 instance metadata service responds.
func (c *Client) Available() bool {
	metadata := c.ec2Metadata
	return metadata.Available()
}

// Region returns the region reported by the instance metadata service. This
// can differ from the region the service clients were configured with (see
// NewClient).
func (c *Client) Region() (string, error) {
	metadata := c.ec2Metadata
	return metadata.Region()
}

// GetMetadata fetches the instance metadata value at path, for example
// "instance-id".
func (c *Client) GetMetadata(path string) (string, error) {
	metadata := c.ec2Metadata
	return metadata.GetMetadata(path)
}
// DescribeCurrentInstanceInput is the (empty) parameter set for
// Client.DescribeCurrentInstance.
type DescribeCurrentInstanceInput struct{}

// DescribeCurrentInstance reads the instance ID from instance metadata and
// returns the matching EC2 instance description. It returns an error when
// the metadata lookup or the DescribeInstances call fails, or when no
// instance is returned.
func (c *Client) DescribeCurrentInstance(_ *DescribeCurrentInstanceInput) (*Instance, error) {
	instanceID, err := c.GetMetadata("instance-id")
	if err != nil {
		return nil, errors.Wrap(err, "unable to retrieve metadata")
	}
	params := &ec2.DescribeInstancesInput{
		Filters: []*ec2.Filter{ec2Filter("instance-id", instanceID)},
	}
	log.Printf("[TRACE] Describe instance request: %v\n", params)
	resp, err := c.ec2.DescribeInstances(params)
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe instances")
	}
	// A filter query can come back empty rather than failing, so check before
	// indexing instead of panicking.
	if len(resp.Reservations) < 1 || len(resp.Reservations[0].Instances) < 1 {
		return nil, errors.Errorf("unable to describe instances: instance %s not found", instanceID)
	}
	instance := resp.Reservations[0].Instances[0]
	log.Printf("[TRACE] Current EC2 instance: %v\n", instance)
	return &Instance{
		instance,
		c.ec2,
	}, nil
}
// DescribeAddressInput selects Elastic IP addresses for Client.DescribeAddress.
type DescribeAddressInput struct {
	// Filters are passed to the EC2 DescribeAddresses call.
	Filters []*ec2.Filter
	// Include keeps only addresses matching one of these IPs or CIDR networks.
	Include []string
	// Exclude drops addresses matching one of these IPs or CIDR networks.
	Exclude []string
	// OnlyAvailable keeps only addresses without an association ID.
	OnlyAvailable bool
}

// DescribeAddress lists the addresses matching input.Filters and narrows them
// with the optional availability, include and exclude criteria. It returns
// the address whose parsed public IP sorts lowest byte-wise. An error is
// returned when the EC2 call fails or no address remains.
func (c *Client) DescribeAddress(input *DescribeAddressInput) (*Address, error) {
	params := &ec2.DescribeAddressesInput{Filters: input.Filters}
	log.Printf("[TRACE] Describe address request: %v\n", params)
	resp, err := c.ec2.DescribeAddresses(params)
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe addresses")
	}

	candidates := resp.Addresses
	if len(candidates) == 0 {
		return nil, errors.New("no addresses found")
	}

	// Optional narrowing steps, applied in order: availability, include
	// list, exclude list.
	var steps []filterAddressFunc
	if input.OnlyAvailable {
		steps = append(steps, func(v *ec2.Address) bool {
			return (&Address{v, c.ec2}).Available()
		})
	}
	if len(input.Include) > 0 {
		steps = append(steps, func(v *ec2.Address) bool {
			return addressInsideNetwork(aws.StringValue(v.PublicIp), input.Include)
		})
	}
	if len(input.Exclude) > 0 {
		steps = append(steps, func(v *ec2.Address) bool {
			return addressOutsideNetwork(aws.StringValue(v.PublicIp), input.Exclude)
		})
	}
	for _, step := range steps {
		candidates = filterAddresses(candidates, step)
	}

	log.Printf("[TRACE] Found addresses: %v\n", candidates)
	if len(candidates) == 0 {
		return nil, errors.New("no addresses found")
	}

	// Rank the candidates by the byte form of their parsed public IP.
	type ranked struct {
		ip      net.IP
		address *ec2.Address
	}
	order := make([]*ranked, 0, len(candidates))
	for _, candidate := range candidates {
		order = append(order, &ranked{
			ip:      net.ParseIP(aws.StringValue(candidate.PublicIp)),
			address: candidate,
		})
	}
	sort.Slice(order, func(i, j int) bool {
		return bytes.Compare(order[i].ip, order[j].ip) < 0
	})

	best := order[0].address
	log.Printf("[DEBUG] Found address: %v\n", best)
	return &Address{best, c.ec2}, nil
}
// DescribeBucketInput holds the parameters for Client.DescribeBucket.
type DescribeBucketInput struct {
	// Bucket is the bucket name.
	Bucket string
}

// DescribeBucket looks up the named bucket's region and replaces the
// client's S3 client with one bound to that region. It returns the bucket
// paired with that client. The creation date comes from ListBuckets when the
// bucket is listed there; otherwise only the name is set.
//
// Errors: "bucket does not exist" for NoSuchBucket. Otherwise the wrapped
// GetBucketLocation or ListBuckets error.
func (c *Client) DescribeBucket(input *DescribeBucketInput) (*Bucket, error) {
	locationParams := &s3.GetBucketLocationInput{
		Bucket: aws.String(input.Bucket),
	}
	log.Printf("[TRACE] Bucket location request: %v\n", locationParams)
	resp, err := c.s3.GetBucketLocation(locationParams)
	const awsErrorCode = "NoSuchBucket"
	if err != nil {
		awsErr, ok := err.(awserr.Error)
		if ok && awsErr.Code() == awsErrorCode {
			return nil, errors.New("bucket does not exist")
		}
		return nil, errors.Wrap(err, "unable to determine bucket location")
	}
	region := s3.NormalizeBucketLocation(aws.StringValue(resp.LocationConstraint))
	log.Printf("[DEBUG] Found bucket region: %v\n", region)

	// NOTE(review): this client does not inherit the log level configured in
	// NewClient.
	c.s3 = s3.New(c.session, aws.NewConfig().WithRegion(region))
	buckets, err := c.s3.ListBuckets(&s3.ListBucketsInput{})
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe bucket")
	}

	// ListBuckets returns only buckets owned by the caller, so a bucket that
	// GetBucketLocation could see may be absent. Fall back to the name alone
	// instead of dereferencing nil.
	found := &s3.Bucket{Name: aws.String(input.Bucket)}
	for _, b := range buckets.Buckets {
		if aws.StringValue(b.Name) == input.Bucket {
			found = &s3.Bucket{
				Name:         b.Name,
				CreationDate: b.CreationDate,
			}
			break
		}
	}
	return &Bucket{
		found,
		c.s3,
	}, nil
}
// CreateLockInput holds the parameters for Client.CreateLock.
type CreateLockInput struct {
	// Name is the value stored in the lock item's "key" attribute.
	Name string
	// Table is the DynamoDB table holding the lock item.
	Table string
	// LeaseDuration is how far past each write the lock's expiry is set.
	LeaseDuration time.Duration
	// RenewInterval is the pause between lease renewals.
	RenewInterval time.Duration
}

// CreateLock prepares, but does not acquire, the lock described by input.
// The lock is identified by a freshly generated random UUID.
func (c *Client) CreateLock(input *CreateLockInput) (*Lock, error) {
	id, err := uuid.NewRandom()
	if err != nil {
		return nil, errors.Wrap(err, "unable to create unique identifier")
	}
	state := &lock{
		name:          input.Name,
		identifier:    id.String(),
		table:         input.Table,
		leaseDuration: input.LeaseDuration,
		renewInterval: input.RenewInterval,
	}
	log.Printf("[DEBUG] Unique identifier: %s\n", state.identifier)
	return &Lock{lock: state, DynamoDB: c.dynamoDB}, nil
}
// filterAddresses returns, in their original order, the addresses for which
// f reports true. It returns nil when none match.
func filterAddresses(addresses []*ec2.Address, f filterAddressFunc) (results []*ec2.Address) {
	for i := range addresses {
		if !f(addresses[i]) {
			continue
		}
		results = append(results, addresses[i])
	}
	return results
}
// filterNetwork reports whether address matches any entry in addresses. A
// CIDR entry matches by containment; a plain IP entry matches by equality.
//
// Surrounding whitespace in entries is ignored and empty entries are
// skipped. A nil address (from an unparseable IP string) matches nothing.
// When inverse is true the result is negated.
func filterNetwork(address net.IP, addresses []string, inverse bool) (found bool) {
	if address != nil {
		for _, item := range addresses {
			// Entries usually come from splitting user input on ",", so
			// tolerate "a, b".
			item = strings.TrimSpace(item)
			if item == "" {
				continue
			}
			if _, network, err := net.ParseCIDR(item); err == nil {
				if network.Contains(address) {
					found = true
					break
				}
				continue
			}
			// Guard against nil: net.IP(nil).Equal(nil) reports true.
			if ip := net.ParseIP(item); ip != nil && ip.Equal(address) {
				found = true
				break
			}
		}
	}
	if inverse {
		found = !found
	}
	return found
}

// addressInsideNetwork reports whether the IP string address matches any of
// the given IPs or CIDR networks.
func addressInsideNetwork(address string, addresses []string) bool {
	return filterNetwork(net.ParseIP(address), addresses, false)
}

// addressOutsideNetwork reports whether the IP string address matches none
// of the given IPs or CIDR networks.
func addressOutsideNetwork(address string, addresses []string) bool {
	return filterNetwork(net.ParseIP(address), addresses, true)
}
// ec2Filter builds an EC2 API filter with the given name that matches any of
// values.
func ec2Filter(name string, values ...string) *ec2.Filter {
	f := &ec2.Filter{}
	f.Name = aws.String(name)
	f.Values = aws.StringSlice(values)
	return f
}
// buildEC2Filters converts --filters flag values into EC2 filters. Two forms
// are accepted:
//
//	Name=<name>,Values=<v1>,<v2>,...   (AWS CLI style; values split on ",")
//	<name>=<value>[,<name>=<value>...] (short form; one value per pair)
//
// Only the first "=" of each pair separates key and value, so values may
// themselves contain "=". Pairs with an empty name or value are dropped.
func buildEC2Filters(filters []string) []*ec2.Filter {
	var pairs [][2]string
	for _, filter := range filters {
		if strings.HasPrefix(filter, "Name=") {
			// The first comma separates the name from the Values part. A
			// missing Values part yields an empty value; it is dropped below
			// instead of causing an out-of-range panic.
			fields := strings.SplitN(filter, ",", 2)
			name := strings.TrimPrefix(fields[0], "Name=")
			var values string
			if len(fields) == 2 {
				if kv := strings.SplitN(fields[1], "=", 2); len(kv) == 2 {
					values = kv[1]
				}
			}
			pairs = append(pairs, [2]string{name, values})
			continue
		}
		for _, s := range strings.Split(filter, ",") {
			if kv := strings.SplitN(s, "=", 2); len(kv) == 2 {
				pairs = append(pairs, [2]string{kv[0], kv[1]})
			}
		}
	}

	var ec2Filters []*ec2.Filter
	for _, pair := range pairs {
		if pair[0] == "" || pair[1] == "" {
			continue
		}
		ec2Filters = append(ec2Filters, ec2Filter(pair[0], strings.Split(pair[1], ",")...))
	}
	return ec2Filters
}
// ec2Tag builds an EC2 tag from a key and a value.
func ec2Tag(name, value string) *ec2.Tag {
	return &ec2.Tag{Key: aws.String(name), Value: aws.String(value)}
}

// buildEC2Tags converts a key/value map into EC2 tags, skipping keys with the
// reserved "aws:" prefix. The output order follows map iteration order, which
// is unspecified.
func buildEC2Tags(tags map[string]string) []*ec2.Tag {
	const awsTagPrefix = "aws:"
	ec2Tags := make([]*ec2.Tag, 0, len(tags))
	for key, value := range tags {
		if strings.HasPrefix(key, awsTagPrefix) {
			continue
		}
		ec2Tags = append(ec2Tags, ec2Tag(key, value))
	}
	return ec2Tags
}

// buildTags converts EC2 tags into a key/value map, skipping keys with the
// reserved "aws:" prefix. The returned map is never nil.
func buildTags(ec2Tags []*ec2.Tag) (tags map[string]string) {
	const awsTagPrefix = "aws:"
	tags = make(map[string]string)
	for _, tag := range ec2Tags {
		key := aws.StringValue(tag.Key)
		if strings.HasPrefix(key, awsTagPrefix) {
			continue
		}
		tags[key] = aws.StringValue(tag.Value)
	}
	return tags
}
// retryStop wraps an error to tell retryFunction to stop retrying and return
// the wrapped error immediately.
type retryStop struct {
	error
}

// retryFunction calls callbackFunc until it returns nil, it returns a
// retryStop, or attempts calls have been made. Whichever comes first ends
// the loop.
//
// The first pause is sleep plus up to 50% random jitter. Each later pause
// starts from double the previous jittered pause. The final error is
// returned unchanged; a retryStop is unwrapped first.
func retryFunction(attempts int, sleep time.Duration, callbackFunc func() error) error {
	for {
		err := callbackFunc()
		if err == nil {
			return nil
		}
		log.Printf("[DEBUG] Retrying attempt: %d\n", attempts)
		if stop, ok := err.(retryStop); ok {
			return stop.error
		}
		if attempts--; attempts <= 0 {
			return err
		}
		sleep += jitter(sleep) / 2
		log.Printf("[DEBUG] Sleeping before retrying: %s\n", sleep)
		time.Sleep(sleep)
		sleep *= 2
	}
}

// jitter returns a random duration in [0, t). A non-positive t yields 0,
// because rand.Int63n panics for n <= 0.
func jitter(t time.Duration) time.Duration {
	if t <= 0 {
		return 0
	}
	return time.Duration(rand.Int63n(int64(t)))
}

// randomDuration picks one of 1, 2, 3, 5, 7 or 11 seconds at random and adds
// up to 50% jitter.
func randomDuration() time.Duration {
	intervals := []int{1, 2, 3, 5, 7, 11}
	index := rand.Int() % len(intervals)
	sleep := time.Duration(intervals[index]) * time.Second
	sleep = sleep + jitter(sleep)/2
	log.Printf("[DEBUG] Duration: %v\n", sleep)
	return sleep
}
// config holds pointers to the values of the parsed command-line flags.
type config struct {
	// Logging.
	debug    *bool
	trace    *bool
	logLevel *string

	// Behaviour switches.
	dryRun             *bool
	force              *bool
	randomSleep        *bool
	exclusiveLock      *bool
	disableRecovery    *bool
	allowReassociation *bool

	// AWS session selection.
	profile *string
	region  *string

	// Address selection.
	allocationID    *string
	publicIP        *string
	include         *string
	exclude         *string
	addressTagName  *string
	addressTagValue *string
	ec2Filters      *[]string

	// Recovery sources.
	instanceTag *string
	s3Bucket    *string

	// Exclusive lock.
	lockName  *string
	lockTable *string
}
// filters accumulates every occurrence of a repeatable flag. Filters
// registers it as a kingpin flag value.
type filters []string

// Set appends one flag occurrence; it never fails.
func (f *filters) Set(value string) error {
	*f = append(*f, value)
	return nil
}

// IsCumulative marks the flag as repeatable.
func (f *filters) IsCumulative() bool {
	return true
}

// String renders the collected values, e.g. "[a b]".
func (f *filters) String() string {
	values := []string(*f)
	return fmt.Sprintf("%v", values)
}

// Type returns the Go type name of the value.
func (f *filters) Type() string {
	return fmt.Sprintf("%T", *f)
}
// Filters registers a repeatable flag on s. It returns a pointer to the
// (initially empty, non-nil) slice that collects every value given.
func Filters(s kingpin.Settings) (values *[]string) {
	target := &[]string{}
	s.SetValue((*filters)(target))
	return target
}
// init seeds the package-level math/rand source from the current time, so
// jitter and randomDuration differ between runs.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
func main() {
cliArgs := os.Args[1:]
envArgs := os.Getenv("EC2_ADDRESS_ARGUMENTS")
if len(cliArgs) < 1 && envArgs != "" {
ss := strings.Split(envArgs, ",")
for _, s := range ss {
cliArgs = append(cliArgs, strings.Split(s, " ")...)
}
}
flags := kingpin.New("ec2-address-association", "Automatically associate Elastic IP address with the current EC2 instance.")
config := &config{
debug: flags.Flag("debug", "Print diagnostic information.").Bool(),
logLevel: flags.Flag("log-level", `Set the log level (defaults to "INFO").`).PlaceHolder("LEVEL").String(),
dryRun: flags.Flag("dry-run", "A dry-run only. No Elastic IP association will be made.").Bool(),
force: flags.Flag("force", "Force Elastic IP address association.").Bool(),
exclusiveLock: flags.Flag("exclusive-lock", "Use DynamoDB to set an exclusive lock when searching for Elastic IP address.").Bool(),
randomSleep: flags.Flag("random-sleep", "Sleep for a random interval before searching for Elastic IP address.").Bool(),
disableRecovery: flags.Flag("disable-recovery", "Disable Elastic IP address recovery.").Bool(),
allowReassociation: flags.Flag("allow-reassociation", "Allow to reassociate currently associated Elastic IP address.").Bool(),
profile: flags.Flag("profile", "Specify name of the AWS profile to use.").Short('p').String(),
region: flags.Flag("region", "Specify name of the AWS region to use.").Short('r').String(),
allocationID: flags.Flag("allocation-id", "Specify the allocation ID to use.").PlaceHolder("ID").String(),
publicIP: flags.Flag("public-ip", "Specify the public IP address to use.").PlaceHolder("ADDRESS").String(),
include: flags.Flag("include", "A comma-separated list of IP addresses and networks to include.").PlaceHolder("IP,NETWORK").String(),
exclude: flags.Flag("exclude", "A comma-separated list of IP addresses and networks to exclude.").PlaceHolder("IP,NETWORK").String(),
addressTagName: flags.Flag("tag-name", "Name of the tag to use when searching for an Elastic IP address.").PlaceHolder("NAME").String(),
addressTagValue: flags.Flag("tag-value", "Value of the tag to use when searching for an Elastic IP address.").PlaceHolder("VALUE").String(),
ec2Filters: Filters(flags.Flag("filters", "A valid one or more filters (see AWS CLI for details).").PlaceHolder("FILTER")),
instanceTag: flags.Flag("instance-tag", "Name of the EC2 instance tag to use for Elastic IP recovery.").PlaceHolder("NAME").String(),
s3Bucket: flags.Flag("s3-bucket", "S3 bucket name to use for Elastic IP recovery.").PlaceHolder("NAME").String(),
lockName: flags.Flag("lock-name", "Name of the global lock to use.").PlaceHolder("NAME").String(),
lockTable: flags.Flag("lock-table", "Name of the DynamoDB table to use for the global lock.").PlaceHolder("NAME").String(),
}
config.trace = &[]bool{false}[0]
flags.Version(version)
flags.UsageTemplate(customUsageTemplate)
flags.HelpFlag.Short('h')
kingpin.MustParse(flags.Parse(cliArgs))
filter := &logutils.LevelFilter{
Levels: []logutils.LogLevel{"TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", "NONE"},
MinLevel: logutils.LogLevel(defaultLogLevel),
Writer: os.Stdout,
}
if *config.logLevel == "" {
*config.logLevel = os.Getenv("EC2_ADDRESS_LOG_LEVEL")
}
if *config.logLevel == "" {
*config.logLevel = defaultLogLevel
}
*config.logLevel = strings.ToUpper(*config.logLevel)
if *config.logLevel == "TRACE" {
*config.trace = true
}
if *config.trace || *config.logLevel == "DEBUG" || os.Getenv("EC2_ADDRESS_DEBUG") == "1" {
*config.debug = true
}
if *config.debug {
*config.logLevel = "DEBUG"
}
if *config.trace {
*config.logLevel = "TRACE"
}
validLogLevel := false
for _, level := range filter.Levels {
if string(level) == *config.logLevel {
validLogLevel = true
break
}
}
if !validLogLevel {
log.SetFlags(0)
log.Fatalf(
"An invalid log level %q given. See %s --help.\n",
*config.logLevel,
os.Args[0],
)
}
filter.SetMinLevel(logutils.LogLevel(*config.logLevel))
log.SetOutput(filter)
log.Printf("[DEBUG] Command line arguments: %v\n", cliArgs)
if *config.allocationID != "" && *config.publicIP != "" {
log.SetFlags(0)
log.Fatalf(
"Flags `--allocation-id' and `--public-ip' cannot be used together. See %s --help.\n",
os.Args[0],
)
}
if *config.exclusiveLock && *config.randomSleep {
log.SetFlags(0)
log.Fatalf(
"Flags `--exclusive-lock' and `--random-sleep' cannot be used together. See %s --help.\n",
os.Args[0],
)
}
if *config.dryRun {
log.Println("[INFO] Dry run mode has been enabled.")
}
if *config.force {
log.Println("[WARN] Force mode has been enabled.")
}
if *config.exclusiveLock {
log.Println("[INFO] Exclusive lock mode has been enabled.")
}
if *config.randomSleep {
log.Println("[INFO] Random sleep mode has been enabled.")
}
if *config.disableRecovery {
log.Println("[WARN] Recovery mode has been disabled.")
}
if *config.allowReassociation {
log.Println("[WARN] Elastic IP reassociation has been enabled.")
}
log.Println("[INFO] Attempting to associate an Elastic IP address...")
startTime := time.Now().UTC()
var filters []*ec2.Filter
recoverAssociation := false
switch {
case *config.allocationID != "":
log.Printf("[INFO] Using allocation ID: %s\n", *config.allocationID)
filters = append(filters, ec2Filter("allocation-id", *config.allocationID))
case *config.publicIP != "":
log.Printf("[INFO] Using public IP address: %s\n", *config.publicIP)
filters = append(filters, ec2Filter("public-ip", *config.publicIP))
default:
if *config.instanceTag != "" {
log.Printf("[INFO] Using EC2 instance tag: %s\n", *config.instanceTag)
}
if *config.s3Bucket != "" {
log.Printf("[INFO] Using S3 bucket: %s\n", *config.s3Bucket)
}
recoverAssociation = true
}
log.Printf("[DEBUG] Filters: %v\n", filters)
client, err := NewClient(
WithDebug(*config.debug),
WithTrace(*config.trace),
WithProfile(*config.profile),
WithRegion(*config.region),
)
if err != nil {
log.Fatalf("[FATAL] Unable to create an EC2 client: %s\n", err)
}
region, err := client.Region()
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve current AWS region: %s\n", err)
}
log.Printf("[INFO] Using region: %s\n", region)
instance, err := client.DescribeCurrentInstance(&DescribeCurrentInstanceInput{})
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve EC2 instance: %s\n", err)
}
log.Printf("[INFO] Current EC2 instance: %s\n", aws.StringValue(instance.InstanceId))
var bucket *Bucket
var recovered bool
var source string
sources := []string{"EC2"}
if *config.s3Bucket != "" {
sources = append(sources, "S3")
}
if !*config.disableRecovery && recoverAssociation {
var allocationID string
log.Printf(
"[INFO] Attempting to recover Elastic IP address association... (sources: %s)\n",
strings.Join(sources, ", "),
)
source = "EC2"
instanceTag := fmt.Sprintf("%s:%s", defaultTagPrefix, defaultInstanceTagName)
if *config.instanceTag != "" {
instanceTag = *config.instanceTag
}
for key, value := range instance.Tags(&TagsInput{}) {
log.Printf("[TRACE] Tag: key = %s, value = %s\n", key, value)
if key == instanceTag && value != "" {
log.Printf("[DEBUG] Found allocation ID using EC2 instance tag: %v\n", value)
allocationID = value
break
}
}
if allocationID == "" && *config.s3Bucket != "" {
source = "S3"
objectPrefix := fmt.Sprintf("%s/%s", region, aws.StringValue(instance.InstanceId))
bucketInput := &DescribeBucketInput{
Bucket: *config.s3Bucket,
}
bucket, err = client.DescribeBucket(bucketInput)
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve S3 bucket: %s\n", err)
}
objectsListInput := &ListCurrentObjectsInput{
Prefix: objectPrefix,
}
objects, objectsErr := bucket.ListCurrentObjects(objectsListInput)
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve objects from S3 bucket: %s\n", objectsErr)
}
if len(objects) > 0 {
var parts []string
for _, object := range objects {
parts = strings.Split(aws.StringValue(object.Key), "/")
if len(parts) > 2 && parts[len(parts)-1] != "" {
value := parts[len(parts)-1]
log.Printf("[DEBUG] Found allocation ID using S3 bucket: %v\n", value)
allocationID = value
break
}
}
}
}
if allocationID != "" {
log.Printf(
"[INFO] Recovered Elastic IP association (allocation ID: %s). (source: %s)\n",
allocationID,
source,
)
recovered = true
filters = append(filters, ec2Filter("allocation-id", allocationID))
} else {
log.Printf(
"[WARN] Unable to recover Elastic IP association. (sources: %s)\n",
strings.Join(sources, ", "),
)
recovered = false
}
}
if len(filters) < 1 && (*config.publicIP == "" || *config.allocationID == "") {
if len(*config.ec2Filters) > 0 {
log.Printf("[INFO] Using filters: %s\n", strings.Join(*config.ec2Filters, " "))
filters = buildEC2Filters(*config.ec2Filters)
} else {
var tagName string
tagValue := defaultAddressTagValue
if *config.addressTagValue != "" {
tagValue = *config.addressTagValue
}
if *config.addressTagName != "" {
tagName = fmt.Sprintf("tag:%s", *config.addressTagName)
log.Printf("[INFO] Using filter: %s:%s\n", tagName, tagValue)
} else {
tagName = fmt.Sprintf("tag:%s:%s", defaultTagPrefix, defaultAddressTagName)
log.Printf("[INFO] Using default filter: %s:%s\n", tagName, tagValue)
}
filters = append(filters, ec2Filter(tagName, tagValue))
}
}
onlyAvailable := !(*config.allowReassociation || recovered)
if *config.publicIP != "" || *config.allocationID != "" {
onlyAvailable = false
}
log.Printf("[DEBUG] Using only available addresses: %t\n", onlyAvailable)
addressInput := &DescribeAddressInput{
Filters: filters,
OnlyAvailable: onlyAvailable,
}
if *config.include != "" {
includeList := strings.Split(*config.include, ",")
log.Printf("[INFO] Using include list: %s\n", strings.Join(includeList, " "))
addressInput.Include = includeList
}
if *config.exclude != "" {
excludeList := strings.Split(*config.exclude, ",")
log.Printf("[INFO] Using exclude list: %s\n", strings.Join(excludeList, " "))
addressInput.Exclude = excludeList
}
var (
sleep time.Duration
lock *Lock
)
if !recovered {
if *config.randomSleep {
sleep = randomDuration()
log.Printf("[INFO] Sleeping for %s...\n", sleep)
time.Sleep(sleep)
}
if *config.exclusiveLock {
lockInput := &CreateLockInput{
Name: defaultLockName,
Table: defaultLockTableName,
LeaseDuration: 5 * time.Second,
RenewInterval: 1 * time.Second,
}
lock, err = client.CreateLock(lockInput)
if err != nil {
log.Fatalf("[FATAL] Unable to create lock: %s\n", err)
}
defer func() {
ok, errLock := lock.Locked()
if errLock != nil {
log.Printf("[ERROR] Unable to retrieve lock status: %s\n", errLock)
}
if ok {
_, errLock = lock.Unlock()
if errLock != nil {
log.Printf("[ERROR] Unable to release lock: %s\n", errLock)
}
}
}()
log.Println("[INFO] Attempting to acquire lock...")
tryLockInput := &TryLockInput{
Attempts: defaultLockAttempts,
}
_, err = lock.TryLock(tryLockInput)
if err != nil {
log.Fatalf("[FATAL] Unable to acquire lock: %s\n", err)
}
log.Println("[INFO] Lock acquired successfully.")
}
}
var address *Address
availableAddressRetry := defaultAddressRetry
err = retryFunction(5, 250*time.Millisecond, func() error {
availableAddress, availableErr := client.DescribeAddress(addressInput)
if availableErr == nil && addressInput.OnlyAvailable && !availableAddress.Available() {
if availableAddressRetry--; availableAddressRetry > 0 {
return errors.New("address already in use")
}
}
if availableErr != nil {
return availableErr
}
address = availableAddress
return nil
})
if err != nil {
log.Printf("[ERROR] Searching for an Elastic IP address failed: %s\n", err)
return
}
log.Printf("[INFO] Found Elastic IP address: %s\n", aws.StringValue(address.PublicIp))
isAlreadyAllocated := aws.StringValue(address.InstanceId) == aws.StringValue(instance.InstanceId)
if !address.Available() && isAlreadyAllocated {
log.Printf(
"[WARN] Elastic IP address %s already associated with current EC2 instance.\n",
aws.StringValue(address.PublicIp),
)
} else if !address.Available() && !*config.allowReassociation {
log.Printf(
"[ERROR] Elastic IP address %s already has association (%s) with EC2 instance: %s\n",
aws.StringValue(address.PublicIp),
aws.StringValue(address.AssociationId),
aws.StringValue(address.InstanceId),
)
return
}
if !isAlreadyAllocated || *config.force {
associationInput := &AssociateAddressInput{
InstanceID: aws.StringValue(instance.InstanceId),
DryRun: *config.dryRun,
}
address, err = address.Associate(associationInput)
if err != nil {
log.Printf(
"[ERROR] Unable to associate Elastic IP address %s with EC2 instance %s: %s\n",
aws.StringValue(address.PublicIp),
aws.StringValue(instance.InstanceId),
err,
)
return
}
successMessage := fmt.Sprintf(
"Successfully associated Elastic IP address %s (%s) with EC2 instance %s.",
aws.StringValue(address.PublicIp),
aws.StringValue(address.AllocationId),
aws.StringValue(instance.InstanceId),
)
if !*config.dryRun {
successMessage = fmt.Sprintf("%s\b: %s", successMessage, aws.StringValue(address.AssociationId))
}
log.Printf("[INFO] %s\n", successMessage)
if !recovered && *config.exclusiveLock {
_, err = lock.Unlock()
if err != nil {
log.Printf("[ERROR] Unable to release lock: %s\n", err)
} else {
log.Println("[INFO] Lock released successfully.")
}
}
if *config.disableRecovery {
log.Printf(
"[INFO] Updating recovery has been disabled. (sources %s)\n",
strings.Join(sources, ", "),
)
} else {
log.Printf(
"[INFO] Updating recovery with Elastic IP address %s (%s). (sources %s)\n",
aws.StringValue(address.PublicIp),
aws.StringValue(address.AllocationId),
strings.Join(sources, ", "),
)
for _, source := range sources {
switch source {
case "EC2":
instanceTag := fmt.Sprintf("%s:%s", defaultTagPrefix, defaultInstanceTagName)
if *config.instanceTag != "" {
instanceTag = *config.instanceTag
}
createTagsInput := &CreateTagsInput{
Tags: map[string]string{
instanceTag: aws.StringValue(address.AllocationId),
},
}
_, err = instance.CreateTags(createTagsInput)
if err != nil {
log.Fatalf("[FATAL] Unable to create EC2 instance tags: %s\n", err)
}
case "S3":
if bucket == nil {
bucketInput := &DescribeBucketInput{
Bucket: *config.s3Bucket,
}
bucket, err = client.DescribeBucket(bucketInput)
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve S3 bucket: %s\n", err)
}
}
objectKet := fmt.Sprintf(
"%s/%s/%s",
region,
aws.StringValue(instance.InstanceId),
aws.StringValue(address.AllocationId),
)
putObjectInput := &PutObjectInput{
Bucket: *config.s3Bucket,
Key: objectKet,
Body: bytes.NewReader([]byte("")),
}
_, err := bucket.PutObject(putObjectInput)
if err != nil {
log.Fatalf("[FATAL] Unable to add object to S3 bucket: %s\n", err)
}
}
}
}
}
var sleepMessage string
if *config.randomSleep && sleep > 0 {
sleepMessage = fmt.Sprintf(" (sleep time: %s)", sleep)
}
stopTime := time.Since(startTime)
log.Printf("[INFO] Elapsed time: %s%s\n", stopTime-sleep, sleepMessage)
}
// customUsageTemplate is a text/template body used to render the program's
// command-line usage/help output.
//
// Structure:
//   - It defines a "FormatUsage" sub-template. When the value passed to it has
//     a non-empty .Help, that text is piped through the "Wrap" template function
//     with argument 0.
//   - The top-level template prints a single "Usage:" line containing
//     .App.Name and a "[<FLAGS> ...]" placeholder, then applies "FormatUsage"
//     to .App.
//   - When .Context.Flags is non-empty, it prints a "Flags:" heading followed by
//     the flags, piped through "FlagsToTwoColumns" and then "FormatTwoColumns".
//
// Wrap, FlagsToTwoColumns and FormatTwoColumns are not built into Go's
// text/template. They are presumably the helper functions kingpin supplies
// when rendering usage templates. NOTE(review): confirm against the kingpin.v2
// version in use.
//
// The trailing backslashes at line ends appear to follow the line-continuation
// convention of kingpin's bundled templates, which suppresses newlines in the
// rendered output. Go's text/template does not itself treat a backslash
// specially. NOTE(review): confirm that kingpin strips backslash-newline
// sequences before parsing.
//
// NOTE(review): the call that registers this template with kingpin
// (presumably UsageTemplate) is outside this chunk. Confirm it before editing.
// The literal is runtime output: any change to it, including whitespace,
// changes the printed help text.
var customUsageTemplate = `{{define "FormatUsage"}}
{{ if .Help }}
{{ .Help | Wrap 0 }}\
{{ end }}\
{{ end }}\
Usage: {{.App.Name }} [<FLAGS> ...]{{ template "FormatUsage" .App }}
{{ if .Context.Flags }}\
Flags:
{{ .Context.Flags | FlagsToTwoColumns | FormatTwoColumns }}
{{ end }}\
`
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"ec2:DescribeAddresses",
"ec2:DescribeInstances",
"ec2:AssociateAddress",
"ec2:CreateTags"
],
"Resource": "*"
},
{
"Effect": "Allow",
"Action": [
"s3:GetBucketLocation",
"s3:ListAllMyBuckets"
],
"Resource": "*"
},
{
"Effect": "Allow",
"Action": [
"s3:PutObject",
"s3:GetObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::<BUCKET NAME>",
"arn:aws:s3:::<BUCKET NAME>/*"
]
},
{
"Effect": "Allow",
"Action": [
"dynamodb:PutItem",
"dynamodb:DeleteItem",
"dynamodb:Scan"
],
"Resource": [
"arn:aws:dynamodb:<REGION>:<ACCOUNT ID>:table/<TABLE NAME>"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.