Skip to content

Instantly share code, notes, and snippets.

@davidkelliott
Created July 13, 2022 13:26
Show Gist options
  • Save davidkelliott/50eb619dceba5e6ab3613062002831eb to your computer and use it in GitHub Desktop.
Save davidkelliott/50eb619dceba5e6ab3613062002831eb to your computer and use it in GitHub Desktop.
package main
import (
"context"
"encoding/json"
"fmt"
// "strings"
"log"
"strconv"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
// dryRun is sent as the DryRun flag on every mutating EC2 request in this
// program. While it is true, EC2 only validates the requests instead of
// performing them. A request that would have succeeded comes back as a
// DryRunOperation error, which is logged like any other error. Set it to
// false to actually delete resources.
var dryRun = true
// getDefaultVpcId returns the ID of the VPC flagged is-default=true for the
// account/region that client targets. It returns "" when there is no default
// VPC or when the DescribeVpcs call fails; in the failure case the error is
// logged. Each matching VPC's ID and IsDefault flag are logged as they are
// visited, and if several match, the last one wins.
func getDefaultVpcId(client *ec2.Client) string {
	log.Printf("Retrieving default VPC ID")
	output, err := client.DescribeVpcs(context.TODO(), &ec2.DescribeVpcsInput{
		Filters: []types.Filter{
			{
				Name:   aws.String("is-default"),
				Values: []string{"true"},
			},
		},
	})
	if err != nil {
		// output must not be used when err != nil; stop here instead of
		// dereferencing it.
		log.Print(err)
		return ""
	}
	var vpcId string
	for _, vpc := range output.Vpcs {
		vpcId = aws.ToString(vpc.VpcId)
		log.Printf("VPC ID: %s", vpcId)
		// IsDefault is a pointer; aws.ToBool treats nil as false instead of
		// panicking.
		log.Printf("Default: %s", strconv.FormatBool(aws.ToBool(vpc.IsDefault)))
	}
	return vpcId
}
// getDefaultInternetGatewayId returns the ID of the internet gateway whose
// attachment.vpc-id matches vpcId. It returns "" if none is attached or if
// the DescribeInternetGateways call fails; in the failure case the error is
// logged. If several gateways match, the last one returned wins.
func getDefaultInternetGatewayId(client *ec2.Client, vpcId string) string {
	output, err := client.DescribeInternetGateways(context.TODO(), &ec2.DescribeInternetGatewaysInput{
		Filters: []types.Filter{
			{
				Name:   aws.String("attachment.vpc-id"),
				Values: []string{vpcId},
			},
		},
	})
	if err != nil {
		// output must not be used when err != nil; stop here instead of
		// dereferencing it.
		log.Print(err)
		return ""
	}
	var gatewayId string
	for _, gateway := range output.InternetGateways {
		gatewayId = aws.ToString(gateway.InternetGatewayId)
	}
	log.Printf("Retrieved Gateway ID: %s", gatewayId)
	return gatewayId
}
// detachInternetGateway asks EC2 to detach the internet gateway gatewayId
// from the VPC vpcId. The request carries the package-level dryRun flag.
// The outcome is only logged; nothing is returned to the caller. Both
// pointers are dereferenced for logging, so they must be non-nil.
func detachInternetGateway(client *ec2.Client, gatewayId, vpcId *string) {
	log.Printf("Detaching Internet Gateway: %s", *gatewayId)
	req := &ec2.DetachInternetGatewayInput{
		DryRun:            &dryRun,
		InternetGatewayId: gatewayId,
		VpcId:             vpcId,
	}
	if _, err := client.DetachInternetGateway(context.TODO(), req); err != nil {
		log.Print(err)
		return
	}
	log.Printf("Gateway ID: %s detached", *gatewayId)
}
// deleteInternetGateway asks EC2 to delete the internet gateway gatewayId.
// The request carries the package-level dryRun flag. The outcome is only
// logged; nothing is returned to the caller. gatewayId is dereferenced for
// logging, so it must be non-nil.
func deleteInternetGateway(client *ec2.Client, gatewayId *string) {
	log.Printf("Deleting Internet Gateway: %s", *gatewayId)
	req := &ec2.DeleteInternetGatewayInput{
		DryRun:            &dryRun,
		InternetGatewayId: gatewayId,
	}
	if _, err := client.DeleteInternetGateway(context.TODO(), req); err != nil {
		log.Print(err)
		return
	}
	log.Printf("Gateway ID: %s deleted", *gatewayId)
}
// deleteVpc asks EC2 to delete the VPC vpcId. The request carries the
// package-level dryRun flag. The outcome is only logged; nothing is returned
// to the caller. vpcId is dereferenced for logging, so it must be non-nil.
func deleteVpc(client *ec2.Client, vpcId *string) {
	log.Printf("Deleting VPC: %s", *vpcId)
	req := &ec2.DeleteVpcInput{DryRun: &dryRun, VpcId: vpcId}
	if _, err := client.DeleteVpc(context.TODO(), req); err != nil {
		log.Print(err)
		return
	}
	log.Printf("VPC ID: %s deleted", *vpcId)
}
// deleteInternetGatewayAndVpc tears down the default VPC reachable through
// client. Steps:
//  1. Look up the default VPC; do nothing if there is none.
//  2. Detach and delete its internet gateway, if one is attached.
//  3. Delete its subnets.
//  4. Delete the VPC itself.
//
// Every mutating call honours the package-level dryRun flag and logs,
// rather than returns, its errors.
func deleteInternetGatewayAndVpc(client *ec2.Client) {
	defaultVpcId := getDefaultVpcId(client)
	if defaultVpcId == "" {
		// No default VPC here, or the lookup failed (already logged).
		return
	}
	if gatewayId := getDefaultInternetGatewayId(client, defaultVpcId); gatewayId != "" {
		detachInternetGateway(client, &gatewayId, &defaultVpcId)
		deleteInternetGateway(client, &gatewayId)
	}
	// EC2 refuses to delete a VPC that still contains subnets, and a default
	// VPC is created with a default subnet per Availability Zone, so they
	// must go first.
	deleteVpcSubnets(client, &defaultVpcId)
	deleteVpc(client, &defaultVpcId)
}

// deleteVpcSubnets deletes every subnet whose vpc-id is vpcId, walking all
// DescribeSubnets pages. Each DeleteSubnet request carries the
// package-level dryRun flag. Errors are logged: a listing error stops the
// walk, while a per-subnet delete error skips to the next subnet.
func deleteVpcSubnets(client *ec2.Client, vpcId *string) {
	paginator := ec2.NewDescribeSubnetsPaginator(client, &ec2.DescribeSubnetsInput{
		Filters: []types.Filter{
			{
				Name:   aws.String("vpc-id"),
				Values: []string{aws.ToString(vpcId)},
			},
		},
	})
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(context.TODO())
		if err != nil {
			log.Print(err)
			return
		}
		for _, subnet := range page.Subnets {
			subnetId := aws.ToString(subnet.SubnetId)
			log.Printf("Deleting Subnet: %s", subnetId)
			_, err := client.DeleteSubnet(context.TODO(), &ec2.DeleteSubnetInput{
				SubnetId: subnet.SubnetId,
				DryRun:   &dryRun,
			})
			if err != nil {
				log.Print(err)
				continue
			}
			log.Printf("Subnet ID: %s deleted", subnetId)
		}
	}
}
// getSecretsManagerSecret fetches the AWSCURRENT version of secretName from
// AWS Secrets Manager using cfg and returns its SecretString. If the call
// fails, the error is logged and "" is returned. "" is also returned when
// the response carries no SecretString (e.g. a binary-only secret) instead
// of dereferencing a nil pointer.
func getSecretsManagerSecret(cfg aws.Config, secretName string) string {
	client := secretsmanager.NewFromConfig(cfg)
	input := &secretsmanager.GetSecretValueInput{
		SecretId:     aws.String(secretName),
		VersionStage: aws.String("AWSCURRENT"),
	}
	result, err := client.GetSecretValue(context.TODO(), input)
	if err != nil {
		// result must not be used when err != nil.
		log.Print(err)
		return ""
	}
	return aws.ToString(result.SecretString)
}
// getMPAccounts reads the "environment_management" secret, parses it as a
// JSON object, and flattens the string values of every nested object into
// one map keyed by the nested key. If the same key appears in more than one
// nested object, the last one visited wins; map iteration order is
// unspecified.
//
// Top-level values that are not objects are skipped. Nested values that are
// not strings are skipped with a log line instead of panicking. If the
// secret cannot be read or parsed, the error is logged and an empty
// (non-nil) map is returned.
func getMPAccounts(cfg aws.Config) map[string]string {
	accounts := make(map[string]string)
	// Get accounts secret
	environments := getSecretsManagerSecret(cfg, "environment_management")
	var allAccounts map[string]interface{}
	if err := json.Unmarshal([]byte(environments), &allAccounts); err != nil {
		log.Printf("parsing environment_management secret: %v", err)
		return accounts
	}
	for _, record := range allAccounts {
		rec, ok := record.(map[string]interface{})
		if !ok {
			continue
		}
		for key, val := range rec {
			// To restrict the run to specific accounts, filter on key here,
			// e.g. strings.Contains(key, "core-vpc-test").
			id, ok := val.(string)
			if !ok {
				log.Printf("skipping %q: expected string value, got %T", key, val)
				continue
			}
			accounts[key] = id
		}
	}
	return accounts
}
// getAssumeRoleCfg returns a copy of cfg whose Credentials come from
// assuming roleARN through STS. The STS client is built from cfg, so cfg's
// own credentials are the ones used to call AssumeRole. The assumed-role
// credentials are wrapped in an aws.CredentialsCache.
func getAssumeRoleCfg(cfg aws.Config, roleARN string) aws.Config {
	// Copy the caller's config instead of re-reading the shared config from
	// disk on every call. This also avoids carrying on with a zero-value
	// config when such a reload fails.
	newCfg := cfg.Copy()
	stsClient := sts.NewFromConfig(cfg)
	provider := stscreds.NewAssumeRoleProvider(stsClient, roleARN)
	newCfg.Credentials = aws.NewCredentialsCache(provider)
	return newCfg
}
// main loads the default AWS config and reads the account map from the
// environment_management secret. For each account it assumes the
// ModernisationPlatformAccess role and tears down that account's default
// VPC and internet gateway. All mutating EC2 calls honour the package-level
// dryRun flag.
func main() {
	// Load the Shared AWS Configuration (~/.aws/config)
	cfg, err := config.LoadDefaultConfig(context.TODO())
	if err != nil {
		// Every later step needs this base config and its credentials;
		// continuing with a zero-value config would only produce confusing
		// downstream failures.
		log.Fatalf("loading default AWS config: %v", err)
	}
	accounts := getMPAccounts(cfg)
	for accountName, accountId := range accounts {
		log.Printf("Account: %s: %s", accountName, accountId)
		roleARN := fmt.Sprintf("arn:aws:iam::%s:role/ModernisationPlatformAccess", accountId)
		accountCfg := getAssumeRoleCfg(cfg, roleARN)
		client := ec2.NewFromConfig(accountCfg)
		deleteInternetGatewayAndVpc(client)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment