Skip to content

Instantly share code, notes, and snippets.

@xeoncross
Forked from kwilczynski/dynamodb-setup.sh
Created November 20, 2019 17:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xeoncross/0ff98068359ac1242b621d3813334ce7 to your computer and use it in GitHub Desktop.
Save xeoncross/0ff98068359ac1242b621d3813334ce7 to your computer and use it in GitHub Desktop.
Something put together quickly to automatically associate an Elastic IP address with the current EC2 instance.
# Create the DynamoDB table backing the --exclusive-lock lock (string hash key
# named "key"; the tool's default table name is "ec2-address-range").
# Run only ONE of the two create-table commands: on-demand billing...
aws dynamodb create-table --region <REGION> --profile <PROFILE> --table-name "<TABLE>" --attribute-definitions "AttributeName=key,AttributeType=S" --key-schema "AttributeName=key,KeyType=HASH" --billing-mode "PAY_PER_REQUEST"
# ...or provisioned capacity (1 RCU / 1 WCU).
aws dynamodb create-table --region <REGION> --profile <PROFILE> --table-name "<TABLE>" --attribute-definitions "AttributeName=key,AttributeType=S" --key-schema "AttributeName=key,KeyType=HASH" --billing-mode "PROVISIONED" --provisioned-throughput "ReadCapacityUnits=1,WriteCapacityUnits=1"
# Enable TTL on the "ttl" attribute that the lock writes on every item.
aws dynamodb update-time-to-live --region <REGION> --profile <PROFILE> --table-name "<TABLE>" --time-to-live-specification "Enabled=true,AttributeName=ttl"
# Inspect the current lock item(s).
aws dynamodb scan --region <REGION> --profile <PROFILE> --table-name "<TABLE>"
// ec2-address-association
// $ go get -v github.com/aws/aws-sdk-go/aws
// $ go get -v github.com/google/uuid
// $ go get -v github.com/hashicorp/logutils
// $ go get -v github.com/pkg/errors
// $ go get -v gopkg.in/alecthomas/kingpin.v2
// $ GOOS=linux go build -o ec2-address-association .
package main
import (
"bytes"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/aws/aws-sdk-go/service/dynamodb/expression"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/google/uuid"
"github.com/hashicorp/logutils"
"github.com/pkg/errors"
kingpin "gopkg.in/alecthomas/kingpin.v2"
)
// version is the tool version reported by --version.
const version = "0.1.0"

// Defaults applied when the corresponding flag or environment variable is unset.
const (
	// defaultLogLevel is used when neither --log-level nor
	// EC2_ADDRESS_LOG_LEVEL is given.
	defaultLogLevel = "INFO"

	// Tag names are namespaced as "<defaultTagPrefix>:<name>": the instance tag
	// is read to recover a previous allocation ID, and the address tag (with
	// defaultAddressTagValue) selects candidate Elastic IP addresses.
	defaultTagPrefix       = "ec2:address"
	defaultInstanceTagName = "allocation-id"
	defaultAddressTagName  = "range"
	defaultAddressTagValue = "true"

	// defaultAddressRetry bounds how often an already-taken address is retried.
	defaultAddressRetry = 5

	// DynamoDB table, item key and acquisition attempts for --exclusive-lock.
	defaultLockTableName = "ec2-address-range"
	defaultLockName      = "allocation-lock"
	defaultLockAttempts  = 30
)

// defaultLockTTL is how far ahead the lock item's "ttl" attribute is set.
var defaultLockTTL = 24 * time.Hour
// filterAddressFunc is a predicate used by filterAddresses; returning true
// keeps the address.
type filterAddressFunc func(*ec2.Address) bool

// Address pairs an Elastic IP address description with the EC2 client used
// to act on it.
type Address struct {
	*ec2.Address
	*ec2.EC2
}

// Available reports whether the address has no association ID, i.e. it is
// not currently associated.
func (a *Address) Available() bool {
	associated := a.AssociationId != nil
	return !associated
}

// AssociateAddressInput holds the parameters for Address.Associate.
type AssociateAddressInput struct {
	InstanceID string // instance to associate the address with
	DryRun     bool   // only validate the request
}
// Associate associates the address with input.InstanceID and returns the
// refreshed address description. If input.DryRun is set, the request is only
// validated.
//
// InvalidInstanceID errors are retried, because the instance may not be
// ready yet. A DryRunOperation error counts as success. Any other error
// aborts immediately.
func (a *Address) Associate(input *AssociateAddressInput) (*Address, error) {
	associateParams := &ec2.AssociateAddressInput{
		AllocationId: a.AllocationId,
		InstanceId:   aws.String(input.InstanceID),
		DryRun:       aws.Bool(input.DryRun),
	}
	log.Printf("[TRACE] Associate address request: %v\n", associateParams)

	const (
		awsErrorCode  = "InvalidInstanceID"
		awsDryRunCode = "DryRunOperation"
	)

	var associationID string
	err := retryFunction(5, 250*time.Millisecond, func() error {
		resp, err := a.AssociateAddress(associateParams)
		if err != nil {
			awsErr, ok := err.(awserr.Error)
			if ok && awsErr.Code() == awsErrorCode {
				log.Printf(
					"[DEBUG] EC2 instance %s not ready. Retrying...\n",
					input.InstanceID,
				)
				return err
			}
			if ok && awsErr.Code() == awsDryRunCode {
				log.Println("[DEBUG] Dry run completed successfully.")
				return nil
			}
			return retryStop{err}
		}
		associationID = aws.StringValue(resp.AssociationId)
		return nil
	})
	if err != nil {
		return nil, errors.Wrap(err, "unable to associate address")
	}

	// Look the address up again to return its post-association state. A dry
	// run creates no association, so in that case look it up by allocation ID.
	describeParams := &ec2.DescribeAddressesInput{
		Filters: []*ec2.Filter{
			ec2Filter("association-id", associationID),
		},
	}
	if input.DryRun {
		describeParams.Filters = []*ec2.Filter{
			ec2Filter("allocation-id", aws.StringValue(a.AllocationId)),
		}
	}
	log.Printf("[TRACE] Describe address request: %v\n", describeParams)
	resp, err := a.DescribeAddresses(describeParams)
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe address")
	}
	// A filter that matches nothing yields an empty list, not an error; guard
	// the indexing instead of panicking.
	if len(resp.Addresses) == 0 {
		return nil, errors.New("unable to find associated address")
	}

	return &Address{
		resp.Addresses[0],
		a.EC2,
	}, nil
}
// Instance pairs an EC2 instance description with the EC2 client used to
// act on it.
type Instance struct {
	*ec2.Instance
	*ec2.EC2
}

// TagsInput is the (empty) parameter set for Instance.Tags.
type TagsInput struct{}

// Tags returns the instance's tags as a key/value map, leaving out keys that
// start with "aws:".
func (i *Instance) Tags(_ *TagsInput) map[string]string {
	instanceTags := i.Instance.Tags
	return buildTags(instanceTags)
}

// CreateTagsInput holds the parameters for Instance.CreateTags.
type CreateTagsInput struct {
	Tags map[string]string // tags to set; keys starting with "aws:" are dropped
}
// CreateTags sets input.Tags on the instance (keys starting with "aws:" are
// skipped). Errors whose code contains ".NotFound" are retried. Any other
// error aborts immediately.
func (i *Instance) CreateTags(input *CreateTagsInput) (*Instance, error) {
	req := &ec2.CreateTagsInput{
		Resources: []*string{i.InstanceId},
		Tags:      buildEC2Tags(input.Tags),
	}
	log.Printf("[TRACE] Create tags request: %v\n", req)

	const awsErrorCode = ".NotFound"
	attempt := func() error {
		_, err := i.EC2.CreateTags(req)
		if err == nil {
			return nil
		}
		if awsErr, ok := err.(awserr.Error); ok && strings.Contains(awsErr.Code(), awsErrorCode) {
			return err
		}
		return retryStop{err}
	}

	if err := retryFunction(5, 250*time.Millisecond, attempt); err != nil {
		return nil, errors.Wrap(err, "unable to create tags")
	}
	return i, nil
}
// Bucket pairs an S3 bucket description with the S3 client used to act on it.
type Bucket struct {
	*s3.Bucket
	*s3.S3
}

// ListCurrentObjectsInput holds the parameters for Bucket.ListCurrentObjects.
type ListCurrentObjectsInput struct {
	// Prefix, when non-empty, restricts the listing to keys with this prefix.
	Prefix string
}
// ListCurrentObjects lists the bucket's objects, optionally restricted to
// keys starting with input.Prefix. Each attempt makes a single ListObjects
// call, and no pagination marker is followed. S3 "InternalError" responses
// are retried. Any other error aborts immediately.
func (b *Bucket) ListCurrentObjects(input *ListCurrentObjectsInput) ([]*s3.Object, error) {
	req := &s3.ListObjectsInput{Bucket: b.Bucket.Name}
	if prefix := input.Prefix; prefix != "" {
		req.Prefix = aws.String(prefix)
	}
	log.Printf("[TRACE] List objects request: %v\n", req)

	const awsInternalError = "InternalError"
	var objects []*s3.Object
	list := func() error {
		resp, err := b.ListObjects(req)
		if err == nil {
			objects = resp.Contents
			return nil
		}
		if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == awsInternalError {
			log.Println("[DEBUG] S3 internal error. Retrying...")
			return err
		}
		return retryStop{err}
	}

	if err := retryFunction(5, 250*time.Millisecond, list); err != nil {
		return nil, errors.Wrap(err, "unable to list objects")
	}
	log.Printf("[DEBUG] Found object: %v\n", objects)
	return objects, nil
}
// PutObjectInput holds the parameters for Bucket.PutObject.
type PutObjectInput struct {
	Bucket string        // target bucket name (used instead of the receiver's bucket)
	Key    string        // object key
	Body   io.ReadSeeker // object contents, uploaded from the reader's current offset
}

// PutObject uploads input.Body to input.Bucket/input.Key with a "private" ACL.
// S3 "InternalError" responses are retried. Any other error aborts
// immediately. Every attempt uploads the body from the offset it had when
// PutObject was called.
func (b *Bucket) PutObject(input *PutObjectInput) (*Bucket, error) {
	params := &s3.PutObjectInput{
		Bucket: aws.String(input.Bucket),
		Key:    aws.String(input.Key),
		ACL:    aws.String("private"),
		Body:   input.Body,
	}
	log.Printf("[TRACE] Put object request: %v\n", params)

	// A failed attempt may leave the body partially or fully consumed. Record
	// where it starts and seek back before each attempt, so a retry uploads the
	// complete payload rather than whatever is left in the reader.
	var bodyStart int64
	if input.Body != nil {
		offset, err := input.Body.Seek(0, io.SeekCurrent)
		if err != nil {
			return nil, errors.Wrap(err, "unable to determine body offset")
		}
		bodyStart = offset
	}

	const awsInternalError = "InternalError"
	err := retryFunction(5, 250*time.Millisecond, func() error {
		if input.Body != nil {
			if _, err := input.Body.Seek(bodyStart, io.SeekStart); err != nil {
				return retryStop{errors.Wrap(err, "unable to rewind body")}
			}
		}
		_, err := b.S3.PutObject(params)
		if err != nil {
			awsErr, ok := err.(awserr.Error)
			if ok && awsErr.Code() == awsInternalError {
				log.Println("[DEBUG] S3 internal error. Retrying...")
				return err
			}
			return retryStop{err}
		}
		return nil
	})
	if err != nil {
		return nil, errors.Wrap(err, "unable to put object")
	}
	return b, nil
}
// lock is the state behind a Lock. The lock is a lease stored as a single
// DynamoDB item. Its "key" attribute is the lock name, "identifier" names the
// owner, "expiry" is the lease end in Unix nanoseconds, and "ttl" is a Unix
// timestamp in seconds.
//
// A Lock is driven from a single goroutine: Lock, Unlock, Locked, TryLock and
// Error must not be called concurrently. The lease renewer started by Lock
// runs in its own goroutine. It shares state with the owner only through the
// stop and done channels. The renewer writes lastError just before closing
// done, and the owner reads lastError only after observing done closed.
type lock struct {
	name          string
	identifier    string
	table         string
	renew         bool // whether the owner currently wants the lease renewed
	leaseDuration time.Duration
	renewInterval time.Duration
	lastError     error // error that stopped the renewer, if any

	stop chan struct{} // closed to stop the renewer; nil when none was started
	done chan struct{} // closed by the renewer on exit; nil when none was started
}

// Lock is a DynamoDB-backed lease lock created by Client.CreateLock.
type Lock struct {
	*lock
	*dynamodb.DynamoDB
}

// Lock makes one attempt to acquire (or refresh) the lock. On success it
// ensures a background renewer is running. The renewer re-writes the lease
// every renewInterval until Unlock is called or a renewal fails.
func (l *Lock) Lock() (*Lock, error) {
	if err := l.createLock(); err != nil {
		return nil, err
	}
	l.lock.renew = true
	if !l.renewing() {
		// Either no renewer was ever started, or the previous one has exited.
		// Its channels can be replaced safely in both cases.
		l.lock.lastError = nil
		l.lock.stop = make(chan struct{})
		l.lock.done = make(chan struct{})
		go l.renewLock(l.lock.stop, l.lock.done)
	}
	return l, nil
}

// Unlock stops lease renewal and waits for the renewer to exit, so that an
// in-flight renewal cannot re-create the item afterwards. It then deletes
// the lock item, provided this Lock still owns it.
func (l *Lock) Unlock() (*Lock, error) {
	l.lock.renew = false
	l.stopRenewal()
	if err := l.removeLock(); err != nil {
		return nil, err
	}
	return l, nil
}

// Locked reports whether this Lock currently owns an unexpired lease. It
// also returns any error that stopped the background renewer.
func (l *Lock) Locked() (bool, error) {
	return l.checkLock()
}

// Error returns the error that stopped background lease renewal, if any. It
// returns nil while the renewer is still running, because the renewer
// records an error only right before it exits.
func (l *Lock) Error() error {
	if l.renewing() {
		return nil
	}
	return l.lock.lastError
}

// renewing reports whether a renewer goroutine is currently running.
func (l *Lock) renewing() bool {
	if l.lock.done == nil {
		return false
	}
	select {
	case <-l.lock.done:
		return false
	default:
		return true
	}
}

// stopRenewal signals the renewer (if any) to exit and waits until it has.
func (l *Lock) stopRenewal() {
	if l.lock.stop == nil {
		return
	}
	close(l.lock.stop)
	<-l.lock.done
	l.lock.stop, l.lock.done = nil, nil
}

// TryLockInput holds the parameters for Lock.TryLock.
type TryLockInput struct {
	Attempts int // maximum number of acquisition attempts
}

// TryLock calls Lock repeatedly, with back-off, while the lock is held by
// someone else (ConditionalCheckFailedException). It gives up after
// input.Attempts attempts. Any other error aborts immediately.
func (l *Lock) TryLock(input *TryLockInput) (*Lock, error) {
	log.Printf("[DEBUG] Lock acquire attempts: %d\n", input.Attempts)

	const awsErrorCode = "ConditionalCheckFailedException"
	attempt := 0
	err := retryFunction(input.Attempts, 250*time.Millisecond, func() error {
		attempt++
		_, err := l.Lock()
		if err == nil {
			return nil
		}
		if awsErr, ok := errors.Cause(err).(awserr.RequestFailure); ok && awsErr.Code() == awsErrorCode {
			log.Printf(
				"[DEBUG] Unable to acquire lock. Retrying... (%d/%d)\n",
				attempt,
				input.Attempts,
			)
			return err
		}
		return retryStop{err}
	})
	if err != nil {
		return nil, err
	}
	log.Printf("[DEBUG] Acquired lock: %s\n", l.lock.name)
	return l, nil
}

// createLock writes the lock item with expiry = now + leaseDuration and
// ttl = now + defaultLockTTL. The write is conditional on the existing item
// (if any) belonging to this Lock or having an expired lease. Otherwise
// DynamoDB rejects it, and the error is returned wrapped.
//
// NOTE(review): the `key <> name` clause is presumably what lets the write
// through when no item exists yet. It relies on DynamoDB's treatment of
// comparisons against a missing attribute; confirm.
func (l *Lock) createLock() error {
	now := time.Now()
	builder, err := expression.NewBuilder().WithCondition(
		expression.Or(
			expression.Name("key").NotEqual(expression.Value(l.lock.name)),
			expression.Name("identifier").Equal(expression.Value(l.lock.identifier)),
			expression.Name("expiry").LessThan(expression.Value(now.UnixNano())),
		),
	).Build()
	if err != nil {
		return errors.Wrap(err, "unable to build expression")
	}

	item, err := dynamodbattribute.MarshalMap(map[string]interface{}{
		"key":        l.lock.name,
		"identifier": l.lock.identifier,
		"expiry":     now.Add(l.lock.leaseDuration).UnixNano(),
		"ttl":        now.Add(defaultLockTTL).Unix(),
	})
	if err != nil {
		return errors.Wrap(err, "unable to marshal item")
	}

	params := &dynamodb.PutItemInput{
		ExpressionAttributeNames:  builder.Names(),
		ExpressionAttributeValues: builder.Values(),
		ConditionExpression:       builder.Condition(),
		TableName:                 aws.String(l.lock.table),
		Item:                      item,
	}
	log.Printf("[TRACE] Put item request: %v\n", params)

	const awsErrorCode = "ProvisionedThroughputExceededException"
	err = retryFunction(5, 250*time.Millisecond, func() error {
		_, putErr := l.DynamoDB.PutItem(params)
		if putErr == nil {
			return nil
		}
		if awsErr, ok := putErr.(awserr.Error); ok && awsErr.Code() == awsErrorCode {
			log.Println("[DEBUG] Provisioned throughput exceeded. Retrying...")
			return putErr
		}
		return retryStop{putErr}
	})
	if err != nil {
		return errors.Wrap(err, "unable to put item")
	}
	return nil
}

// removeLock deletes the lock item, but only if it is still owned by this
// Lock. If the lease was taken over by another owner, the delete fails its
// condition and an error is returned instead.
func (l *Lock) removeLock() error {
	builder, err := expression.NewBuilder().WithCondition(
		// Both attributes must match: the key alone is always equal for the
		// addressed item, so requiring the identifier is what prevents
		// deleting a lock that belongs to someone else.
		expression.And(
			expression.Name("key").Equal(expression.Value(l.lock.name)),
			expression.Name("identifier").Equal(expression.Value(l.lock.identifier)),
		),
	).Build()
	if err != nil {
		return errors.Wrap(err, "unable to build expression")
	}

	item, err := dynamodbattribute.MarshalMap(map[string]interface{}{
		"key": l.lock.name,
	})
	if err != nil {
		return errors.Wrap(err, "unable to marshal item")
	}

	params := &dynamodb.DeleteItemInput{
		ExpressionAttributeNames:  builder.Names(),
		ExpressionAttributeValues: builder.Values(),
		ConditionExpression:       builder.Condition(),
		TableName:                 aws.String(l.lock.table),
		Key:                       item,
	}
	log.Printf("[TRACE] Delete item request: %v\n", params)

	const awsErrorCode = "ProvisionedThroughputExceededException"
	err = retryFunction(5, 250*time.Millisecond, func() error {
		_, deleteErr := l.DynamoDB.DeleteItem(params)
		if deleteErr == nil {
			return nil
		}
		if awsErr, ok := deleteErr.(awserr.Error); ok && awsErr.Code() == awsErrorCode {
			log.Println("[DEBUG] Provisioned throughput exceeded. Retrying...")
			return deleteErr
		}
		return retryStop{deleteErr}
	})
	if err != nil {
		return errors.Wrap(err, "unable to delete item")
	}
	return nil
}

// checkLock reports whether the table holds an item for this lock name that
// is owned by this Lock and whose lease has not expired. It also returns
// Error(), the error that stopped the background renewer, if any.
func (l *Lock) checkLock() (bool, error) {
	builder, err := expression.NewBuilder().WithFilter(
		expression.And(
			expression.Name("key").Equal(expression.Value(l.lock.name)),
			expression.Name("identifier").Equal(expression.Value(l.lock.identifier)),
			expression.Name("expiry").GreaterThan(expression.Value(time.Now().UnixNano())),
		),
	).WithProjection(
		expression.NamesList(
			expression.Name("key"),
			expression.Name("identifier"),
			expression.Name("expiry"),
		),
	).Build()
	if err != nil {
		return false, errors.Wrap(err, "unable to build expression")
	}

	params := &dynamodb.ScanInput{
		ConsistentRead:            aws.Bool(true), // see the latest renewal
		ExpressionAttributeNames:  builder.Names(),
		ExpressionAttributeValues: builder.Values(),
		FilterExpression:          builder.Filter(),
		ProjectionExpression:      builder.Projection(),
		TableName:                 aws.String(l.lock.table),
	}
	log.Printf("[TRACE] Scan items request: %v\n", params)

	const awsErrorCode = "ProvisionedThroughputExceededException"
	var count int64
	err = retryFunction(5, 250*time.Millisecond, func() error {
		resp, scanErr := l.DynamoDB.Scan(params)
		if scanErr != nil {
			if awsErr, ok := scanErr.(awserr.Error); ok && awsErr.Code() == awsErrorCode {
				log.Println("[DEBUG] Provisioned throughput exceeded. Retrying...")
				return scanErr
			}
			return retryStop{scanErr}
		}
		count = aws.Int64Value(resp.Count)
		return nil
	})
	if err != nil {
		return false, errors.Wrap(err, "unable to scan items")
	}
	return count > 0, l.Error()
}

// renewLock re-writes the lease every renewInterval until stop is closed or
// a renewal fails. A failure is recorded in lastError. done is closed when
// the goroutine exits.
func (l *Lock) renewLock(stop <-chan struct{}, done chan<- struct{}) {
	defer func() {
		log.Printf("[DEBUG] Stopping to renew lease on lock: %s\n", l.lock.name)
		close(done)
	}()

	for {
		log.Printf("[DEBUG] Sleeping before renewing lease on lock: %s\n", l.lock.renewInterval)
		timer := time.NewTimer(l.lock.renewInterval)
		select {
		case <-stop:
			timer.Stop()
			return
		case <-timer.C:
		}

		log.Printf("[DEBUG] Renewing lease on lock: %s\n", l.lock.name)
		if err := l.createLock(); err != nil {
			log.Printf("[DEBUG] Unable to renew lease on lock: %s: %s\n", l.lock.name, err)
			l.lock.lastError = err
			return
		}
	}
}
// ClientOpt configures a Client. Pass any number of them to NewClient.
type ClientOpt func(*Client)

// Client bundles the AWS session and the service clients used by this tool.
type Client struct {
	// Options set through ClientOpt values.
	// NOTE(review): debug is stored but not read by any visible code.
	debug   bool
	trace   bool
	profile string
	region  string

	// Service clients created by NewClient.
	s3          *s3.S3
	dynamoDB    *dynamodb.DynamoDB
	ec2         *ec2.EC2
	ec2Metadata *ec2metadata.EC2Metadata
	session     *session.Session
}

// WithDebug sets the client's debug flag.
func WithDebug(debug bool) ClientOpt {
	return func(client *Client) { client.debug = debug }
}

// WithTrace enables full HTTP request/response logging in the AWS SDK.
func WithTrace(trace bool) ClientOpt {
	return func(client *Client) { client.trace = trace }
}

// WithProfile selects the AWS shared-config profile for the session.
func WithProfile(profile string) ClientOpt {
	return func(client *Client) { client.profile = profile }
}

// WithRegion forces the AWS region instead of resolving it automatically.
func WithRegion(region string) ClientOpt {
	return func(client *Client) { client.region = region }
}
// NewClient builds a Client from the given options. The EC2 instance
// metadata service must be reachable. The region comes from WithRegion if
// given, otherwise from the session's configuration, otherwise from
// instance metadata.
func NewClient(options ...ClientOpt) (*Client, error) {
	c := &Client{}
	for _, option := range options {
		option(c)
	}

	cfg := aws.NewConfig()
	if c.trace {
		cfg.WithLogLevel(aws.LogDebugWithHTTPBody)
	}

	// The short HTTP timeout and low retry count are only meant for probing
	// the metadata service. aws.Config's With* methods modify the receiver, so
	// apply them to a copy. Otherwise they would also be inherited by the S3,
	// DynamoDB and EC2 clients built from cfg below.
	metadataCfg := cfg.Copy().
		WithHTTPClient(&http.Client{Timeout: 2 * time.Second}).
		WithMaxRetries(2)
	c.ec2Metadata = ec2metadata.New(session.Must(session.NewSession(metadataCfg)))
	if !c.ec2Metadata.Available() {
		return nil, errors.New("metadata service not available")
	}

	c.session = session.Must(session.NewSessionWithOptions(session.Options{
		Profile: c.profile,
	}))
	// Expired credentials will not become valid by retrying, so fail fast.
	c.session.Handlers.Retry.PushFront(func(r *request.Request) {
		if r.IsErrorExpired() {
			log.Println("[DEBUG] Credentials expired. Stop retrying...")
			r.Retryable = aws.Bool(false)
		}
	})

	if c.region == "" {
		c.region = aws.StringValue(c.session.Config.Region)
	}
	if c.region == "" {
		metadataRegion, err := c.Region()
		if err != nil {
			return nil, errors.Wrap(err, "unable to retrieve region")
		}
		c.region = metadataRegion
	}

	cfg = cfg.WithRegion(c.region)
	c.s3 = s3.New(c.session, cfg)
	c.dynamoDB = dynamodb.New(c.session, cfg)
	c.ec2 = ec2.New(c.session, cfg)
	return c, nil
}
// Available reports whether the EC2 instance metadata service is reachable.
func (c *Client) Available() bool {
	metadata := c.ec2Metadata
	return metadata.Available()
}

// Region returns the region of the current instance, as reported by
// instance metadata.
func (c *Client) Region() (string, error) {
	metadata := c.ec2Metadata
	return metadata.Region()
}

// GetMetadata fetches the instance metadata entry at path (e.g. "instance-id").
func (c *Client) GetMetadata(path string) (string, error) {
	metadata := c.ec2Metadata
	return metadata.GetMetadata(path)
}
// DescribeCurrentInstanceInput is the (empty) parameter set for
// Client.DescribeCurrentInstance.
type DescribeCurrentInstanceInput struct{}

// DescribeCurrentInstance reads the instance ID from instance metadata and
// returns the matching EC2 instance description.
func (c *Client) DescribeCurrentInstance(_ *DescribeCurrentInstanceInput) (*Instance, error) {
	instanceID, err := c.GetMetadata("instance-id")
	if err != nil {
		return nil, errors.Wrap(err, "unable to retrieve metadata")
	}

	params := &ec2.DescribeInstancesInput{
		Filters: []*ec2.Filter{ec2Filter("instance-id", instanceID)},
	}
	log.Printf("[TRACE] Describe instance request: %v\n", params)
	resp, err := c.ec2.DescribeInstances(params)
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe instances")
	}
	// A filter that matches nothing presumably yields an empty result rather
	// than an error. Guard the indexing instead of panicking.
	if len(resp.Reservations) == 0 || len(resp.Reservations[0].Instances) == 0 {
		return nil, errors.Errorf("instance %s not found", instanceID)
	}

	instance := resp.Reservations[0].Instances[0]
	log.Printf("[TRACE] Current EC2 instance: %v\n", instance)
	return &Instance{
		instance,
		c.ec2,
	}, nil
}
// DescribeAddressInput selects Elastic IP addresses for Client.DescribeAddress.
type DescribeAddressInput struct {
	Filters       []*ec2.Filter // EC2 API filters for DescribeAddresses
	Include       []string      // if set, keep only addresses inside these IPs/CIDRs
	Exclude       []string      // if set, drop addresses inside these IPs/CIDRs
	OnlyAvailable bool          // keep only addresses without an association
}

// DescribeAddress lists the addresses matching input.Filters and narrows them
// by availability, include list and exclude list, in that order. It then
// returns the one with the lowest IP (byte-wise comparison of the parsed
// IPs). It fails if nothing is left at either stage.
func (c *Client) DescribeAddress(input *DescribeAddressInput) (*Address, error) {
	params := &ec2.DescribeAddressesInput{Filters: input.Filters}
	log.Printf("[TRACE] Describe address request: %v\n", params)
	resp, err := c.ec2.DescribeAddresses(params)
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe addresses")
	}

	candidates := resp.Addresses
	if len(candidates) == 0 {
		return nil, errors.New("no addresses found")
	}

	var keep []filterAddressFunc
	if input.OnlyAvailable {
		keep = append(keep, func(v *ec2.Address) bool {
			return (&Address{v, c.ec2}).Available()
		})
	}
	if len(input.Include) > 0 {
		keep = append(keep, func(v *ec2.Address) bool {
			return addressInsideNetwork(aws.StringValue(v.PublicIp), input.Include)
		})
	}
	if len(input.Exclude) > 0 {
		keep = append(keep, func(v *ec2.Address) bool {
			return addressOutsideNetwork(aws.StringValue(v.PublicIp), input.Exclude)
		})
	}
	for _, predicate := range keep {
		candidates = filterAddresses(candidates, predicate)
	}

	log.Printf("[TRACE] Found addresses: %v\n", candidates)
	if len(candidates) == 0 {
		return nil, errors.New("no addresses found")
	}

	type ranked struct {
		ip      net.IP
		address *ec2.Address
	}
	ordered := make([]ranked, 0, len(candidates))
	for _, candidate := range candidates {
		ordered = append(ordered, ranked{
			ip:      net.ParseIP(aws.StringValue(candidate.PublicIp)),
			address: candidate,
		})
	}
	sort.Slice(ordered, func(x, y int) bool {
		return bytes.Compare(ordered[x].ip, ordered[y].ip) < 0
	})

	chosen := ordered[0].address
	log.Printf("[DEBUG] Found address: %v\n", chosen)
	return &Address{
		chosen,
		c.ec2,
	}, nil
}
// DescribeBucketInput names the S3 bucket for Client.DescribeBucket.
type DescribeBucketInput struct {
	Bucket string
}

// DescribeBucket resolves the bucket's region and replaces the client's S3
// service client with one for that region. It then returns the bucket as
// reported by ListBuckets. It fails if the bucket does not exist or is not
// part of the ListBuckets result.
func (c *Client) DescribeBucket(input *DescribeBucketInput) (*Bucket, error) {
	locationParams := &s3.GetBucketLocationInput{
		Bucket: aws.String(input.Bucket),
	}
	log.Printf("[TRACE] Bucket location request: %v\n", locationParams)
	resp, err := c.s3.GetBucketLocation(locationParams)
	const awsErrorCode = "NoSuchBucket"
	if err != nil {
		awsErr, ok := err.(awserr.Error)
		if ok && awsErr.Code() == awsErrorCode {
			return nil, errors.New("bucket does not exist")
		}
		return nil, errors.Wrap(err, "unable to determine bucket location")
	}

	region := s3.NormalizeBucketLocation(aws.StringValue(resp.LocationConstraint))
	log.Printf("[DEBUG] Found bucket region: %v\n", region)

	// Keep request tracing consistent with the clients built by NewClient.
	cfg := aws.NewConfig().WithRegion(region)
	if c.trace {
		cfg.WithLogLevel(aws.LogDebugWithHTTPBody)
	}
	c.s3 = s3.New(c.session, cfg)

	buckets, err := c.s3.ListBuckets(&s3.ListBucketsInput{})
	if err != nil {
		return nil, errors.Wrap(err, "unable to describe bucket")
	}

	var bucket *s3.Bucket
	for _, b := range buckets.Buckets {
		if aws.StringValue(b.Name) == input.Bucket {
			bucket = b
			break
		}
	}
	// Without this check a bucket missing from the listing would cause a nil
	// pointer dereference below.
	if bucket == nil {
		return nil, errors.Errorf("bucket %s not found in bucket list", input.Bucket)
	}

	return &Bucket{
		&s3.Bucket{
			Name:         bucket.Name,
			CreationDate: bucket.CreationDate,
		},
		c.s3,
	}, nil
}
// CreateLockInput holds the parameters for Client.CreateLock.
type CreateLockInput struct {
	Name          string        // lock item key
	Table         string        // DynamoDB table holding the lock
	LeaseDuration time.Duration // how long each lease write is valid
	RenewInterval time.Duration // how often the lease is re-written
}

// CreateLock returns an (unacquired) Lock backed by the client's DynamoDB
// client. The Lock is identified by a fresh random UUID.
func (c *Client) CreateLock(input *CreateLockInput) (*Lock, error) {
	id, err := uuid.NewRandom()
	if err != nil {
		return nil, errors.Wrap(err, "unable to create unique identifier")
	}

	state := &lock{
		name:          input.Name,
		identifier:    id.String(),
		table:         input.Table,
		leaseDuration: input.LeaseDuration,
		renewInterval: input.RenewInterval,
	}
	log.Printf("[DEBUG] Unique identifier: %s\n", state.identifier)

	return &Lock{
		state,
		c.dynamoDB,
	}, nil
}
// filterAddresses returns, in their original order, the addresses for which
// f returns true. It returns nil when none match.
func filterAddresses(addresses []*ec2.Address, f filterAddressFunc) (results []*ec2.Address) {
	for i := range addresses {
		if !f(addresses[i]) {
			continue
		}
		results = append(results, addresses[i])
	}
	return results
}
// filterNetwork reports whether address matches any entry of addresses. An
// entry is either a CIDR network (e.g. "10.0.0.0/8") or a single IP address.
// Whitespace around entries is ignored, and entries that parse as neither
// are skipped. With inverse set, the result is negated.
//
// A nil address (from an unparseable string) matches nothing. Without this
// guard, net.IP(nil).Equal(nil) is true, so a nil address would "match" any
// entry that fails to parse as an IP.
func filterNetwork(address net.IP, addresses []string, inverse bool) (found bool) {
	if address != nil {
		for _, item := range addresses {
			item = strings.TrimSpace(item)
			if _, network, err := net.ParseCIDR(item); err == nil {
				if network.Contains(address) {
					found = true
					break
				}
				continue
			}
			if ip := net.ParseIP(item); ip != nil && ip.Equal(address) {
				found = true
				break
			}
		}
	}
	if inverse {
		found = !found
	}
	return found
}

// addressInsideNetwork reports whether the IP address string matches any of
// the given IPs/CIDR networks.
func addressInsideNetwork(address string, addresses []string) bool {
	return filterNetwork(net.ParseIP(address), addresses, false)
}

// addressOutsideNetwork reports whether the IP address string matches none
// of the given IPs/CIDR networks.
func addressOutsideNetwork(address string, addresses []string) bool {
	return filterNetwork(net.ParseIP(address), addresses, true)
}
// ec2Filter builds an EC2 API filter with the given name and values.
func ec2Filter(name string, values ...string) *ec2.Filter {
	filter := &ec2.Filter{Name: aws.String(name)}
	filter.Values = aws.StringSlice(values)
	return filter
}
// buildEC2Filters converts AWS CLI style filter strings into EC2 filters.
// Two forms are accepted:
//
//	Name=NAME,Values=VALUE1,VALUE2   one filter with one or more values
//	NAME1=VALUE1,NAME2=VALUE2        shorthand: one value per name
//
// Only the first "=" of a name/value pair separates it, so values may
// contain "=". Malformed entries and entries with an empty name or value are
// skipped rather than causing a panic.
func buildEC2Filters(filters []string) []*ec2.Filter {
	var ec2Filters []*ec2.Filter
	add := func(name, values string) {
		if name == "" || values == "" {
			return
		}
		ec2Filters = append(ec2Filters, ec2Filter(name, strings.Split(values, ",")...))
	}

	for _, filter := range filters {
		if strings.HasPrefix(filter, "Name=") {
			spec := strings.SplitN(filter, ",", 2)
			if len(spec) < 2 {
				continue // "Name=x" without any values
			}
			values := strings.SplitN(spec[1], "=", 2)
			if len(values) < 2 {
				continue // no "Values=" part
			}
			add(strings.TrimPrefix(spec[0], "Name="), values[1])
			continue
		}
		for _, s := range strings.Split(filter, ",") {
			pair := strings.SplitN(s, "=", 2)
			if len(pair) < 2 {
				continue
			}
			add(pair[0], pair[1])
		}
	}
	return ec2Filters
}
// ec2Tag builds an EC2 tag from a key/value pair.
func ec2Tag(name, value string) *ec2.Tag {
	tag := &ec2.Tag{}
	tag.Key = aws.String(name)
	tag.Value = aws.String(value)
	return tag
}

// buildEC2Tags converts a key/value map into EC2 tags, skipping keys that
// start with "aws:". The order follows map iteration and is unspecified.
func buildEC2Tags(tags map[string]string) []*ec2.Tag {
	const awsTagPrefix = "aws:"
	result := make([]*ec2.Tag, 0, len(tags))
	for k, v := range tags {
		if strings.HasPrefix(k, awsTagPrefix) {
			continue
		}
		result = append(result, ec2Tag(k, v))
	}
	return result
}

// buildTags converts EC2 tags into a key/value map, skipping keys that start
// with "aws:". It always returns a non-nil map.
func buildTags(ec2Tags []*ec2.Tag) (tags map[string]string) {
	const awsTagPrefix = "aws:"
	tags = make(map[string]string)
	for _, tag := range ec2Tags {
		key := aws.StringValue(tag.Key)
		if strings.HasPrefix(key, awsTagPrefix) {
			continue
		}
		tags[key] = aws.StringValue(tag.Value)
	}
	return tags
}
// retryStop wraps an error returned by a retryFunction callback to mark it
// as permanent: retryFunction stops immediately and returns the wrapped error.
type retryStop struct {
	error
}

// retryFunction calls callbackFunc until one of three things happens:
//   - it succeeds, and retryFunction returns nil;
//   - it returns a retryStop, and retryFunction returns the wrapped error;
//   - it has been called attempts times, and retryFunction returns the last
//     error.
//
// It is always called at least once. Between attempts, the sleep starts in
// [sleep, 1.5*sleep) and roughly doubles each time (with jitter). Each sleep
// is capped at maxSleep. Without the cap, large attempt counts (e.g. the 30
// lock attempts) would sleep for days, and the Duration could eventually
// overflow.
func retryFunction(attempts int, sleep time.Duration, callbackFunc func() error) error {
	const maxSleep = 30 * time.Second
	for {
		err := callbackFunc()
		if err == nil {
			return nil
		}
		log.Printf("[DEBUG] Retrying attempt: %d\n", attempts)
		if s, ok := err.(retryStop); ok {
			return s.error
		}
		if attempts--; attempts <= 0 {
			return err
		}

		wait := sleep + jitter(sleep)/2
		if wait > maxSleep {
			wait = maxSleep
		}
		log.Printf("[DEBUG] Sleeping before retrying: %s\n", wait)
		time.Sleep(wait)
		sleep = 2 * wait
	}
}

// jitter returns a random duration in [0, t). For t <= 0 it returns 0,
// because rand.Int63n panics on non-positive arguments.
func jitter(t time.Duration) time.Duration {
	if t <= 0 {
		return 0
	}
	return time.Duration(rand.Int63n(int64(t)))
}

// randomDuration returns one of 1, 2, 3, 5, 7 or 11 seconds, chosen at
// random, plus less than 50% random jitter.
func randomDuration() time.Duration {
	intervals := []int{1, 2, 3, 5, 7, 11}
	sleep := time.Duration(intervals[rand.Intn(len(intervals))]) * time.Second
	sleep += jitter(sleep) / 2
	log.Printf("[DEBUG] Duration: %v\n", sleep)
	return sleep
}
// config holds the parsed command-line settings. main fills each field with
// a pointer returned by a kingpin flag definition, except trace, which main
// assigns directly.
type config struct {
	// Logging.
	debug    *bool
	trace    *bool
	logLevel *string

	// Behavior switches.
	dryRun             *bool
	force              *bool
	randomSleep        *bool
	exclusiveLock      *bool
	disableRecovery    *bool
	allowReassociation *bool

	// AWS session.
	profile *string
	region  *string

	// Address selection.
	allocationID    *string
	publicIP        *string
	include         *string
	exclude         *string
	addressTagName  *string
	addressTagValue *string
	ec2Filters      *[]string

	// Recovery and locking.
	// NOTE(review): the visible part of main creates the lock with
	// defaultLockName/defaultLockTableName rather than lockName/lockTable.
	instanceTag *string
	s3Bucket    *string
	lockName    *string
	lockTable   *string
}
// filters is a repeatable command-line flag value. Every occurrence of the
// flag is appended, in order.
type filters []string

// Set appends one flag occurrence. It never fails.
func (f *filters) Set(value string) error {
	current := *f
	*f = append(current, value)
	return nil
}

// IsCumulative reports true so that kingpin accepts the flag more than once.
func (f *filters) IsCumulative() bool { return true }

// String renders the collected values as a bracketed, space-separated list.
func (f *filters) String() string {
	values := []string(*f)
	return fmt.Sprint(values)
}

// Type returns the Go type name of the value.
func (f *filters) Type() string {
	value := *f
	return fmt.Sprintf("%T", value)
}
// Filters registers s as a repeatable flag. It returns a pointer to a
// (non-nil, initially empty) slice that will collect every value given for
// the flag, in order.
func Filters(s kingpin.Settings) (values *[]string) {
	collected := []string{}
	s.SetValue((*filters)(&collected))
	return &collected
}
// init seeds math/rand so that the retry jitter and randomDuration differ
// between runs.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
func main() {
cliArgs := os.Args[1:]
envArgs := os.Getenv("EC2_ADDRESS_ARGUMENTS")
if len(cliArgs) < 1 && envArgs != "" {
ss := strings.Split(envArgs, ",")
for _, s := range ss {
cliArgs = append(cliArgs, strings.Split(s, " ")...)
}
}
flags := kingpin.New("ec2-address-association", "Automatically associate Elastic IP address with the current EC2 instance.")
config := &config{
debug: flags.Flag("debug", "Print diagnostic information.").Bool(),
logLevel: flags.Flag("log-level", `Set the log level (defaults to "INFO").`).PlaceHolder("LEVEL").String(),
dryRun: flags.Flag("dry-run", "A dry-run only. No Elastic IP association will be made.").Bool(),
force: flags.Flag("force", "Force Elastic IP address association.").Bool(),
exclusiveLock: flags.Flag("exclusive-lock", "Use DynamoDB to set an exclusive lock when searching for Elastic IP address.").Bool(),
randomSleep: flags.Flag("random-sleep", "Sleep for a random interval before searching for Elastic IP address.").Bool(),
disableRecovery: flags.Flag("disable-recovery", "Disable Elastic IP address recovery.").Bool(),
allowReassociation: flags.Flag("allow-reassociation", "Allow to reassociate currently associated Elastic IP address.").Bool(),
profile: flags.Flag("profile", "Specify name of the AWS profile to use.").Short('p').String(),
region: flags.Flag("region", "Specify name of the AWS region to use.").Short('r').String(),
allocationID: flags.Flag("allocation-id", "Specify the allocation ID to use.").PlaceHolder("ID").String(),
publicIP: flags.Flag("public-ip", "Specify the public IP address to use.").PlaceHolder("ADDRESS").String(),
include: flags.Flag("include", "A comma-separated list of IP addresses and networks to include.").PlaceHolder("IP,NETWORK").String(),
exclude: flags.Flag("exclude", "A comma-separated list of IP addresses and networks to exclude.").PlaceHolder("IP,NETWORK").String(),
addressTagName: flags.Flag("tag-name", "Name of the tag to use when searching for an Elastic IP address.").PlaceHolder("NAME").String(),
addressTagValue: flags.Flag("tag-value", "Value of the tag to use when searching for an Elastic IP address.").PlaceHolder("VALUE").String(),
ec2Filters: Filters(flags.Flag("filters", "A valid one or more filters (see AWS CLI for details).").PlaceHolder("FILTER")),
instanceTag: flags.Flag("instance-tag", "Name of the EC2 instance tag to use for Elastic IP recovery.").PlaceHolder("NAME").String(),
s3Bucket: flags.Flag("s3-bucket", "S3 bucket name to use for Elastic IP recovery.").PlaceHolder("NAME").String(),
lockName: flags.Flag("lock-name", "Name of the global lock to use.").PlaceHolder("NAME").String(),
lockTable: flags.Flag("lock-table", "Name of the DynamoDB table to use for the global lock.").PlaceHolder("NAME").String(),
}
config.trace = &[]bool{false}[0]
flags.Version(version)
flags.UsageTemplate(customUsageTemplate)
flags.HelpFlag.Short('h')
kingpin.MustParse(flags.Parse(cliArgs))
filter := &logutils.LevelFilter{
Levels: []logutils.LogLevel{"TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", "NONE"},
MinLevel: logutils.LogLevel(defaultLogLevel),
Writer: os.Stdout,
}
if *config.logLevel == "" {
*config.logLevel = os.Getenv("EC2_ADDRESS_LOG_LEVEL")
}
if *config.logLevel == "" {
*config.logLevel = defaultLogLevel
}
*config.logLevel = strings.ToUpper(*config.logLevel)
if *config.logLevel == "TRACE" {
*config.trace = true
}
if *config.trace || *config.logLevel == "DEBUG" || os.Getenv("EC2_ADDRESS_DEBUG") == "1" {
*config.debug = true
}
if *config.debug {
*config.logLevel = "DEBUG"
}
if *config.trace {
*config.logLevel = "TRACE"
}
validLogLevel := false
for _, level := range filter.Levels {
if string(level) == *config.logLevel {
validLogLevel = true
break
}
}
if !validLogLevel {
log.SetFlags(0)
log.Fatalf(
"An invalid log level %q given. See %s --help.\n",
*config.logLevel,
os.Args[0],
)
}
filter.SetMinLevel(logutils.LogLevel(*config.logLevel))
log.SetOutput(filter)
log.Printf("[DEBUG] Command line arguments: %v\n", cliArgs)
if *config.allocationID != "" && *config.publicIP != "" {
log.SetFlags(0)
log.Fatalf(
"Flags `--allocation-id' and `--public-ip' cannot be used together. See %s --help.\n",
os.Args[0],
)
}
if *config.exclusiveLock && *config.randomSleep {
log.SetFlags(0)
log.Fatalf(
"Flags `--exclusive-lock' and `--random-sleep' cannot be used together. See %s --help.\n",
os.Args[0],
)
}
if *config.dryRun {
log.Println("[INFO] Dry run mode has been enabled.")
}
if *config.force {
log.Println("[WARN] Force mode has been enabled.")
}
if *config.exclusiveLock {
log.Println("[INFO] Exclusive lock mode has been enabled.")
}
if *config.randomSleep {
log.Println("[INFO] Random sleep mode has been enabled.")
}
if *config.disableRecovery {
log.Println("[WARN] Recovery mode has been disabled.")
}
if *config.allowReassociation {
log.Println("[WARN] Elastic IP reassociation has been enabled.")
}
log.Println("[INFO] Attempting to associate an Elastic IP address...")
startTime := time.Now().UTC()
var filters []*ec2.Filter
recoverAssociation := false
switch {
case *config.allocationID != "":
log.Printf("[INFO] Using allocation ID: %s\n", *config.allocationID)
filters = append(filters, ec2Filter("allocation-id", *config.allocationID))
case *config.publicIP != "":
log.Printf("[INFO] Using public IP address: %s\n", *config.publicIP)
filters = append(filters, ec2Filter("public-ip", *config.publicIP))
default:
if *config.instanceTag != "" {
log.Printf("[INFO] Using EC2 instance tag: %s\n", *config.instanceTag)
}
if *config.s3Bucket != "" {
log.Printf("[INFO] Using S3 bucket: %s\n", *config.s3Bucket)
}
recoverAssociation = true
}
log.Printf("[DEBUG] Filters: %v\n", filters)
client, err := NewClient(
WithDebug(*config.debug),
WithTrace(*config.trace),
WithProfile(*config.profile),
WithRegion(*config.region),
)
if err != nil {
log.Fatalf("[FATAL] Unable to create an EC2 client: %s\n", err)
}
region, err := client.Region()
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve current AWS region: %s\n", err)
}
log.Printf("[INFO] Using region: %s\n", region)
instance, err := client.DescribeCurrentInstance(&DescribeCurrentInstanceInput{})
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve EC2 instance: %s\n", err)
}
log.Printf("[INFO] Current EC2 instance: %s\n", aws.StringValue(instance.InstanceId))
var bucket *Bucket
var recovered bool
var source string
sources := []string{"EC2"}
if *config.s3Bucket != "" {
sources = append(sources, "S3")
}
if !*config.disableRecovery && recoverAssociation {
var allocationID string
log.Printf(
"[INFO] Attempting to recover Elastic IP address association... (sources: %s)\n",
strings.Join(sources, ", "),
)
source = "EC2"
instanceTag := fmt.Sprintf("%s:%s", defaultTagPrefix, defaultInstanceTagName)
if *config.instanceTag != "" {
instanceTag = *config.instanceTag
}
for key, value := range instance.Tags(&TagsInput{}) {
log.Printf("[TRACE] Tag: key = %s, value = %s\n", key, value)
if key == instanceTag && value != "" {
log.Printf("[DEBUG] Found allocation ID using EC2 instance tag: %v\n", value)
allocationID = value
break
}
}
if allocationID == "" && *config.s3Bucket != "" {
source = "S3"
objectPrefix := fmt.Sprintf("%s/%s", region, aws.StringValue(instance.InstanceId))
bucketInput := &DescribeBucketInput{
Bucket: *config.s3Bucket,
}
bucket, err = client.DescribeBucket(bucketInput)
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve S3 bucket: %s\n", err)
}
objectsListInput := &ListCurrentObjectsInput{
Prefix: objectPrefix,
}
objects, objectsErr := bucket.ListCurrentObjects(objectsListInput)
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve objects from S3 bucket: %s\n", objectsErr)
}
if len(objects) > 0 {
var parts []string
for _, object := range objects {
parts = strings.Split(aws.StringValue(object.Key), "/")
if len(parts) > 2 && parts[len(parts)-1] != "" {
value := parts[len(parts)-1]
log.Printf("[DEBUG] Found allocation ID using S3 bucket: %v\n", value)
allocationID = value
break
}
}
}
}
if allocationID != "" {
log.Printf(
"[INFO] Recovered Elastic IP association (allocation ID: %s). (source: %s)\n",
allocationID,
source,
)
recovered = true
filters = append(filters, ec2Filter("allocation-id", allocationID))
} else {
log.Printf(
"[WARN] Unable to recover Elastic IP association. (sources: %s)\n",
strings.Join(sources, ", "),
)
recovered = false
}
}
if len(filters) < 1 && (*config.publicIP == "" || *config.allocationID == "") {
if len(*config.ec2Filters) > 0 {
log.Printf("[INFO] Using filters: %s\n", strings.Join(*config.ec2Filters, " "))
filters = buildEC2Filters(*config.ec2Filters)
} else {
var tagName string
tagValue := defaultAddressTagValue
if *config.addressTagValue != "" {
tagValue = *config.addressTagValue
}
if *config.addressTagName != "" {
tagName = fmt.Sprintf("tag:%s", *config.addressTagName)
log.Printf("[INFO] Using filter: %s:%s\n", tagName, tagValue)
} else {
tagName = fmt.Sprintf("tag:%s:%s", defaultTagPrefix, defaultAddressTagName)
log.Printf("[INFO] Using default filter: %s:%s\n", tagName, tagValue)
}
filters = append(filters, ec2Filter(tagName, tagValue))
}
}
onlyAvailable := !(*config.allowReassociation || recovered)
if *config.publicIP != "" || *config.allocationID != "" {
onlyAvailable = false
}
log.Printf("[DEBUG] Using only available addresses: %t\n", onlyAvailable)
addressInput := &DescribeAddressInput{
Filters: filters,
OnlyAvailable: onlyAvailable,
}
if *config.include != "" {
includeList := strings.Split(*config.include, ",")
log.Printf("[INFO] Using include list: %s\n", strings.Join(includeList, " "))
addressInput.Include = includeList
}
if *config.exclude != "" {
excludeList := strings.Split(*config.exclude, ",")
log.Printf("[INFO] Using exclude list: %s\n", strings.Join(excludeList, " "))
addressInput.Exclude = excludeList
}
var (
sleep time.Duration
lock *Lock
)
if !recovered {
if *config.randomSleep {
sleep = randomDuration()
log.Printf("[INFO] Sleeping for %s...\n", sleep)
time.Sleep(sleep)
}
if *config.exclusiveLock {
lockInput := &CreateLockInput{
Name: defaultLockName,
Table: defaultLockTableName,
LeaseDuration: 5 * time.Second,
RenewInterval: 1 * time.Second,
}
lock, err = client.CreateLock(lockInput)
if err != nil {
log.Fatalf("[FATAL] Unable to create lock: %s\n", err)
}
defer func() {
ok, errLock := lock.Locked()
if errLock != nil {
log.Printf("[ERROR] Unable to retrieve lock status: %s\n", errLock)
}
if ok {
_, errLock = lock.Unlock()
if errLock != nil {
log.Printf("[ERROR] Unable to release lock: %s\n", errLock)
}
}
}()
log.Println("[INFO] Attempting to acquire lock...")
tryLockInput := &TryLockInput{
Attempts: defaultLockAttempts,
}
_, err = lock.TryLock(tryLockInput)
if err != nil {
log.Fatalf("[FATAL] Unable to acquire lock: %s\n", err)
}
log.Println("[INFO] Lock acquired successfully.")
}
}
var address *Address
availableAddressRetry := defaultAddressRetry
err = retryFunction(5, 250*time.Millisecond, func() error {
availableAddress, availableErr := client.DescribeAddress(addressInput)
if availableErr == nil && addressInput.OnlyAvailable && !availableAddress.Available() {
if availableAddressRetry--; availableAddressRetry > 0 {
return errors.New("address already in use")
}
}
if availableErr != nil {
return availableErr
}
address = availableAddress
return nil
})
if err != nil {
log.Printf("[ERROR] Searching for an Elastic IP address failed: %s\n", err)
return
}
log.Printf("[INFO] Found Elastic IP address: %s\n", aws.StringValue(address.PublicIp))
isAlreadyAllocated := aws.StringValue(address.InstanceId) == aws.StringValue(instance.InstanceId)
if !address.Available() && isAlreadyAllocated {
log.Printf(
"[WARN] Elastic IP address %s already associated with current EC2 instance.\n",
aws.StringValue(address.PublicIp),
)
} else if !address.Available() && !*config.allowReassociation {
log.Printf(
"[ERROR] Elastic IP address %s already has association (%s) with EC2 instance: %s\n",
aws.StringValue(address.PublicIp),
aws.StringValue(address.AssociationId),
aws.StringValue(address.InstanceId),
)
return
}
if !isAlreadyAllocated || *config.force {
associationInput := &AssociateAddressInput{
InstanceID: aws.StringValue(instance.InstanceId),
DryRun: *config.dryRun,
}
address, err = address.Associate(associationInput)
if err != nil {
log.Printf(
"[ERROR] Unable to associate Elastic IP address %s with EC2 instance %s: %s\n",
aws.StringValue(address.PublicIp),
aws.StringValue(instance.InstanceId),
err,
)
return
}
successMessage := fmt.Sprintf(
"Successfully associated Elastic IP address %s (%s) with EC2 instance %s.",
aws.StringValue(address.PublicIp),
aws.StringValue(address.AllocationId),
aws.StringValue(instance.InstanceId),
)
if !*config.dryRun {
successMessage = fmt.Sprintf("%s\b: %s", successMessage, aws.StringValue(address.AssociationId))
}
log.Printf("[INFO] %s\n", successMessage)
if !recovered && *config.exclusiveLock {
_, err = lock.Unlock()
if err != nil {
log.Printf("[ERROR] Unable to release lock: %s\n", err)
} else {
log.Println("[INFO] Lock released successfully.")
}
}
if *config.disableRecovery {
log.Printf(
"[INFO] Updating recovery has been disabled. (sources %s)\n",
strings.Join(sources, ", "),
)
} else {
log.Printf(
"[INFO] Updating recovery with Elastic IP address %s (%s). (sources %s)\n",
aws.StringValue(address.PublicIp),
aws.StringValue(address.AllocationId),
strings.Join(sources, ", "),
)
for _, source := range sources {
switch source {
case "EC2":
instanceTag := fmt.Sprintf("%s:%s", defaultTagPrefix, defaultInstanceTagName)
if *config.instanceTag != "" {
instanceTag = *config.instanceTag
}
createTagsInput := &CreateTagsInput{
Tags: map[string]string{
instanceTag: aws.StringValue(address.AllocationId),
},
}
_, err = instance.CreateTags(createTagsInput)
if err != nil {
log.Fatalf("[FATAL] Unable to create EC2 instance tags: %s\n", err)
}
case "S3":
if bucket == nil {
bucketInput := &DescribeBucketInput{
Bucket: *config.s3Bucket,
}
bucket, err = client.DescribeBucket(bucketInput)
if err != nil {
log.Fatalf("[FATAL] Unable to retrieve S3 bucket: %s\n", err)
}
}
objectKet := fmt.Sprintf(
"%s/%s/%s",
region,
aws.StringValue(instance.InstanceId),
aws.StringValue(address.AllocationId),
)
putObjectInput := &PutObjectInput{
Bucket: *config.s3Bucket,
Key: objectKet,
Body: bytes.NewReader([]byte("")),
}
_, err := bucket.PutObject(putObjectInput)
if err != nil {
log.Fatalf("[FATAL] Unable to add object to S3 bucket: %s\n", err)
}
}
}
}
}
var sleepMessage string
if *config.randomSleep && sleep > 0 {
sleepMessage = fmt.Sprintf(" (sleep time: %s)", sleep)
}
stopTime := time.Since(startTime)
log.Printf("[INFO] Elapsed time: %s%s\n", stopTime-sleep, sleepMessage)
}
// customUsageTemplate is a Go-template-syntax usage template for the
// command-line help output. NOTE(review): the call site is not visible here;
// presumably it is installed via kingpin's UsageTemplate — confirm.
//
// Layout:
//   - A "FormatUsage" sub-template prints its receiver's .Help text, passed
//     through the Wrap function with an argument of 0, only when .Help is
//     non-empty.
//   - The main body prints "Usage: <.App.Name> [<FLAGS> ...]" followed by
//     FormatUsage applied to .App.
//   - If .Context.Flags is non-empty, a "Flags:" section follows, rendered by
//     piping the flags through FlagsToTwoColumns and then FormatTwoColumns.
//
// NOTE(review): Wrap, FlagsToTwoColumns and FormatTwoColumns are not defined
// in this file section; they are assumed to be template functions provided by
// kingpin's usage renderer. The backslashes immediately before newlines look
// like kingpin's line-continuation markers (removed before rendering, as in
// kingpin's built-in templates) — confirm. This string is runtime output:
// changing any byte, including whitespace, changes the rendered help text.
var customUsageTemplate = `{{define "FormatUsage"}}
{{ if .Help }}
{{ .Help | Wrap 0 }}\
{{ end }}\
{{ end }}\
Usage: {{.App.Name }} [<FLAGS> ...]{{ template "FormatUsage" .App }}
{{ if .Context.Flags }}\
Flags:
{{ .Context.Flags | FlagsToTwoColumns | FormatTwoColumns }}
{{ end }}\
`
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"ec2:DescribeAddresses",
"ec2:DescribeInstances",
"ec2:AssociateAddress",
"ec2:CreateTags"
],
"Resource": "*"
},
{
"Effect": "Allow",
"Action": [
"s3:GetBucketLocation",
"s3:ListAllMyBuckets"
],
"Resource": "*"
},
{
"Effect": "Allow",
"Action": [
"s3:PutObject",
"s3:GetObject",
"s3:ListBucket"
],
"Resource": [
"arn:aws:s3:::<BUCKET NAME>",
"arn:aws:s3:::<BUCKET NAME>/*"
]
},
{
"Effect": "Allow",
"Action": [
"dynamodb:PutItem",
"dynamodb:DeleteItem",
"dynamodb:Scan"
],
"Resource": [
"arn:aws:dynamodb:<REGION>:<ACCOUNT ID>:table/<TABLE NAME>"
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment