Skip to content

Instantly share code, notes, and snippets.

@adammw
Created June 28, 2022 07:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adammw/941d15c8b1730e3e89fc61138a6a4f24 to your computer and use it in GitHub Desktop.
Save adammw/941d15c8b1730e3e89fc61138a6a4f24 to your computer and use it in GitHub Desktop.
// Copyright 2022 Zendesk, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
package iampolicygen
import (
	"fmt"
	"go/ast"
	"go/types"
	"sort"
	"strings"

	"github.com/go-logr/logr"
	"github.com/thoas/go-funk"
	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/packages"
	"sigs.k8s.io/controller-tools/pkg/markers"
)
// Import paths and comment syntax the analyzers in this file match against.
const (
// awsServicePkg is the import-path prefix of the AWS SDK service packages. A call is only
// considered when its package path starts with this prefix, and the packages analyzed for
// waiters are loaded from directly beneath it.
awsServicePkg = "github.com/aws/aws-sdk-go/service"
// awsClientPkg is the import path whose Client type must be embedded in a struct receiver
// for a method on that struct to count as an AWS service call.
awsClientPkg = "github.com/aws/aws-sdk-go/aws/client"
// commentPrefix is what a comment must start with (after trimming surrounding whitespace)
// to be parsed as an iampolicygen marker.
commentPrefix = "+iampolicygen:"
)
// iampolicygenMarker is the value type registered for the "iampolicygen" comment marker.
type iampolicygenMarker struct {
// Action lists the actions named by a marker comment; each entry is added to the detected
// calls verbatim.
// NOTE(review): the empty `marker:""` tag presumably keeps the library's default field
// naming — confirm against the controller-tools markers documentation.
Action []string `marker:""`
// TODO: support other properties
}
// runAnalysisPass executes runner against a single loaded package by wrapping
// the package's syntax and type information in a minimal analysis.Pass.
//
// Diagnostics the runner reports are written to log at verbosity 1, prefixed
// with their source position. The runner's result and error are returned
// unchanged.
func runAnalysisPass(log logr.Logger, runner func(*analysis.Pass) (interface{}, error), pkg *packages.Package) (interface{}, error) {
	// Diagnostics are purely informational here, so they only go to the log.
	logDiagnostic := func(d analysis.Diagnostic) {
		where := pkg.Fset.Position(d.Pos)
		log.V(1).Info(fmt.Sprintf("%s: %s\n", where, d.Message))
	}

	analyzer := &analysis.Analyzer{Run: runner}
	pass := &analysis.Pass{
		Analyzer:     analyzer,
		Fset:         pkg.Fset,
		Files:        pkg.Syntax,
		OtherFiles:   pkg.OtherFiles,
		IgnoredFiles: pkg.IgnoredFiles,
		Pkg:          pkg.Types,
		TypesInfo:    pkg.TypesInfo,
		TypesSizes:   pkg.TypesSizes,
		Report:       logDiagnostic,
		// The fact and dependency hooks are left unset on purpose: none of
		// the runners used with this helper need them.
	}
	return analyzer.Run(pass)
}
// DetectAWSCalls loads the packages matching patterns, analyzes them, and
// returns the detected "service:Action" names in sorted order.
//
// The work happens in three steps:
//  1. every loaded package is run through runServiceCallAnalyzer and the
//     results are merged;
//  2. entries containing "WaitUntil" are treated as waiters and removed from
//     the result; for each service that owns such a waiter, the matching
//     package under awsServicePkg is loaded and run through
//     runWaiterCallAnalyzer to learn which calls each waiter makes;
//  3. the calls made by every used waiter are merged back into the result.
//
// An error is returned when loading or analysis fails, or when a used waiter
// has no entry in the waiter analysis. The result is sorted so that repeated
// runs over the same input produce the same output; map iteration order would
// otherwise make it vary from run to run.
func DetectAWSCalls(log logr.Logger, patterns ...string) ([]string, error) {
	pkgs, err := load(patterns...)
	if err != nil {
		return nil, err
	}

	detectedCalls := make(map[string]bool)
	for _, pkg := range pkgs {
		result, err := runAnalysisPass(log, runServiceCallAnalyzer, pkg)
		if err != nil {
			return nil, err
		}
		detectedCalls = funk.Union(detectedCalls, result.(map[string]bool)).(map[string]bool)
	}

	// Pull the waiters out of the detected calls, remembering which services
	// they belong to. Deleting from a map while ranging over it is safe in Go.
	waiterServices := map[string]bool{}
	waiters := map[string]bool{}
	for detectedCall := range detectedCalls {
		if !strings.Contains(detectedCall, "WaitUntil") {
			continue
		}
		waiters[detectedCall] = true
		// The service is everything before the first ':'.
		service := strings.SplitN(detectedCall, ":", 2)[0]
		waiterServices[service] = true
		delete(detectedCalls, detectedCall)
	}

	// For each of those services, build a map of waiter -> calls it makes.
	waiterCallMap := map[string]map[string]bool{}
	for service := range waiterServices {
		servicePkgs, err := load(fmt.Sprintf("%s/%s", awsServicePkg, service))
		if err != nil {
			return nil, err
		}
		for _, pkg := range servicePkgs {
			result, err := runAnalysisPass(log, runWaiterCallAnalyzer, pkg)
			if err != nil {
				return nil, err
			}
			waiterCallMap = funk.Union(waiterCallMap, result).(map[string]map[string]bool)
		}
	}

	// Replace every used waiter by the calls it makes.
	for waiter := range waiters {
		perms, ok := waiterCallMap[waiter]
		if !ok {
			return nil, fmt.Errorf("could not find calls for waiter: %s", waiter)
		}
		detectedCalls = funk.Union(detectedCalls, perms).(map[string]bool)
	}

	actions := funk.Keys(detectedCalls).([]string)
	sort.Strings(actions)
	return actions, nil
}
// load resolves patterns to fully type-checked packages (syntax, types and
// type information included; test files excluded).
//
// It is a modified version of
// golang.org/x/tools/go/analysis/checker/internal/checker#load.
//
// An error is returned when no pattern is given, when packages.Load itself
// fails, when any loaded package carries errors (they are printed by
// packages.PrintErrors), or when the patterns match no packages. No packages
// are returned alongside an error.
func load(patterns ...string) ([]*packages.Package, error) {
	if len(patterns) < 1 {
		return nil, fmt.Errorf("need at least one package to search")
	}

	initial, err := packages.Load(&packages.Config{
		Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles |
			packages.NeedImports | packages.NeedTypes | packages.NeedTypesSizes |
			packages.NeedSyntax | packages.NeedTypesInfo,
		Tests: false,
	}, patterns...)
	if err != nil {
		// Keep the real cause instead of replacing it with a generic message.
		return nil, fmt.Errorf("loading packages %s: %w", strings.Join(patterns, " "), err)
	}

	// packages.Load succeeds even when individual packages failed to parse or
	// type-check; those errors live on the packages and must be checked here.
	switch n := packages.PrintErrors(initial); {
	case n > 1:
		return nil, fmt.Errorf("%d errors during loading", n)
	case n == 1:
		return nil, fmt.Errorf("error during loading")
	case len(initial) == 0:
		return nil, fmt.Errorf("%s matched no packages", strings.Join(patterns, " "))
	}
	return initial, nil
}
// runServiceCallAnalyzer is the analysis pass run over user packages. Its
// result is a map[string]bool holding every action named by a
// "+iampolicygen:" comment marker plus everything extractServiceCalls finds
// in the package's files.
//
// Files that belong to package main, or whose name ends in "_test.go", are
// not inspected. A marker comment that cannot be parsed aborts the pass with
// an error.
func runServiceCallAnalyzer(pass *analysis.Pass) (interface{}, error) {
	markerDef, err := markers.MakeDefinition("iampolicygen", markers.DescribesPackage, iampolicygenMarker{})
	if err != nil {
		return nil, err
	}

	found := map[string]bool{}
	for _, f := range pass.Files {
		path := pass.Fset.Position(f.Pos()).Filename
		if f.Name.Name == "main" || strings.HasSuffix(path, "_test.go") {
			continue
		}

		// Actions declared explicitly through comment markers.
		for _, group := range f.Comments {
			text := group.Text()
			if !strings.HasPrefix(strings.TrimSpace(text), commentPrefix) {
				continue
			}
			parsed, err := markerDef.Parse(text)
			if err != nil {
				return nil, fmt.Errorf("cannot parse comment marker: %w", err)
			}
			marker, isMarker := parsed.(iampolicygenMarker)
			if !isMarker {
				return nil, fmt.Errorf("cannot parse comment marker: expected iampolicygenMarker, got %T", parsed)
			}
			pass.Reportf(group.Pos(), "comment - %+v", marker)
			for _, action := range marker.Action {
				found[action] = true
			}
		}

		// Actions inferred from the method calls in the file.
		found = funk.Union(found, extractServiceCalls(f, pass)).(map[string]bool)
	}
	return found, nil
}
// runWaiterCallAnalyzer is the analysis pass run over an AWS service package.
// Its result is a map[string]map[string]bool of the form
// [waiter name] -> [permissions]:true.
//
// Only files whose name ends in "waiters.go" are read, and within them only
// top-level function declarations whose name starts with "WaitUntil". The
// waiter name is "<package name>:<function name>" with the "WithContext",
// "Request" and "Pages" suffixes stripped (in that order), so the variants of
// one waiter share a key. Calls to other waiters are left out of the
// permission set.
func runWaiterCallAnalyzer(pass *analysis.Pass) (interface{}, error) {
	waiterCalls := map[string]map[string]bool{}
	for _, f := range pass.Files {
		path := pass.Fset.Position(f.Pos()).Filename
		if !strings.HasSuffix(path, "waiters.go") {
			continue
		}

		for _, decl := range f.Decls {
			fn, isFunc := decl.(*ast.FuncDecl)
			if !isFunc || !strings.HasPrefix(fn.Name.Name, "WaitUntil") {
				continue
			}

			waiter := fn.Name.Name
			for _, suffix := range []string{"WithContext", "Request", "Pages"} {
				waiter = strings.TrimSuffix(waiter, suffix)
			}

			calls := extractServiceCalls(fn.Body, pass)

			key := pass.Pkg.Name() + ":" + waiter
			perms, seen := waiterCalls[key]
			if !seen {
				perms = map[string]bool{}
				waiterCalls[key] = perms
			}
			for call := range calls {
				if strings.Contains(call, "WaitUntil") {
					continue // a waiter delegating to another waiter variant
				}
				perms[call] = true
			}
		}
	}
	return waiterCalls, nil
}
// extractServiceCalls walks the syntax tree rooted at n and returns the set of
// "service:Action" names for every node extractServiceAction recognises.
//
// When the action is present in IAMDefinition, the dependent actions listed
// for each of its resource types are added to the set as well. An action that
// is missing from IAMDefinition is reported through pass, unless its name
// starts with "WaitUntil". Every finding is also reported through pass.
func extractServiceCalls(n ast.Node, pass *analysis.Pass) map[string]bool {
	found := map[string]bool{}

	visit := func(node ast.Node) bool {
		service, action := extractServiceAction(node, pass)
		if service == "" || action == "" {
			return true
		}

		pass.Reportf(node.Pos(), "found call - %v:%v", service, action)
		found[service+":"+action] = true

		privilege, known := IAMDefinition[service].Privileges[action]
		if !known {
			if !strings.HasPrefix(action, "WaitUntil") {
				pass.Reportf(node.Pos(), "unknown privilege %s:%s ??", service, action)
			}
			return true
		}

		// Pull in whatever other actions this privilege depends on.
		for _, resourceType := range privilege.ResourceTypes {
			if len(resourceType.DependentActions) == 0 {
				continue
			}
			pass.Reportf(node.Pos(), "adding DependentActions %v from %v", resourceType.DependentActions, resourceType.ResourceType)
			for _, dependent := range resourceType.DependentActions {
				found[dependent] = true
			}
		}
		return true
	}

	ast.Inspect(n, visit)
	return found
}
// extractFunCallObject returns the object that the callee of call expression n
// resolves to in typesInfo.
//
// It returns nil when n is not a call expression, when the callee is neither a
// plain identifier nor a selector expression (for example a []byte(...)
// conversion, which can safely be ignored), or when typesInfo has no object
// recorded for the callee identifier.
func extractFunCallObject(n ast.Node, typesInfo *types.Info) types.Object {
	callExpr, isCall := n.(*ast.CallExpr)
	if !isCall {
		return nil
	}
	if ident, isIdent := callExpr.Fun.(*ast.Ident); isIdent {
		return typesInfo.ObjectOf(ident)
	}
	if selector, isSelector := callExpr.Fun.(*ast.SelectorExpr); isSelector {
		return typesInfo.ObjectOf(selector.Sel)
	}
	return nil
}
// extractFunCallRecvType returns the receiver type of the method fObj, with
// any pointer indirection removed.
//
// It returns nil when fObj is not a function or method (its type is not a
// *types.Signature) and when it is a plain function without a receiver.
func extractFunCallRecvType(fObj types.Object) types.Type {
	sig, ok := fObj.Type().(*types.Signature)
	if !ok {
		return nil
	}
	// Package-level functions have no receiver: Recv() is nil for them, and
	// calling Type() on that nil *types.Var would dereference a nil pointer.
	recv := sig.Recv()
	if recv == nil {
		return nil
	}
	return extractPointerElem(recv.Type())
}
// extractPointerElem strips every level of pointer indirection from t and
// returns the first non-pointer type reached. A type that is not a pointer
// (including nil) is returned unchanged.
func extractPointerElem(t types.Type) types.Type {
	for {
		ptr, isPtr := t.(*types.Pointer)
		if !isPtr {
			return t
		}
		t = ptr.Elem()
	}
}
// extractServiceAction maps call expression n to a service name and an action
// name, or returns two empty strings when n is not an AWS service call
// according to isAwsServiceCall.
//
// The service name is the callee's package name with a trailing "iface"
// removed; "resourcegroupstaggingapi" is reported as "tag". The action name is
// the callee's name with the "WithContext", "Request" and "Pages" suffixes
// stripped, in that order.
func extractServiceAction(n ast.Node, pass *analysis.Pass) (serviceName, actionName string) {
	callee := extractFunCallObject(n, pass.TypesInfo)
	if callee == nil {
		return "", ""
	}
	if !isAwsServiceCall(callee, pass) {
		return "", ""
	}

	switch pkgName := strings.TrimSuffix(callee.Pkg().Name(), "iface"); pkgName {
	case "resourcegroupstaggingapi":
		// thank AWS for their consistency - https://docs.aws.amazon.com/resourcegroupstagging/latest/APIReference/overview.html
		serviceName = "tag"
	default:
		serviceName = pkgName
	}

	actionName = callee.Name()
	for _, suffix := range []string{"WithContext", "Request", "Pages"} {
		actionName = strings.TrimSuffix(actionName, suffix)
	}
	return serviceName, actionName
}
// isAwsServiceCall reports whether fObj is a method of an AWS service client.
//
// fObj qualifies when all of the following hold:
//   - it belongs to a package whose path starts with awsServicePkg and is not
//     named "New";
//   - its receiver (pointers stripped) is a named type;
//   - if that type is an interface, its package name ends in "iface" and its
//     own name ends in "API";
//   - otherwise isAwsClientEmbedded accepts it; a rejection on this last rule
//     is reported through pass.
func isAwsServiceCall(fObj types.Object, pass *analysis.Pass) bool {
	// Only functions from the AWS service packages are of interest.
	if fObj == nil || fObj.Pkg() == nil {
		return false
	}
	if !strings.HasPrefix(fObj.Pkg().Path(), awsServicePkg) {
		return false
	}
	if fObj.Name() == "New" {
		return false
	}

	// The receiver has to be a named type.
	recv, isNamed := extractFunCallRecvType(fObj).(*types.Named)
	if !isNamed {
		return false
	}

	// Interface receivers: accept only the service iface APIs.
	if _, isInterface := recv.Underlying().(*types.Interface); isInterface {
		recvObj := recv.Obj()
		return strings.HasSuffix(recvObj.Pkg().Name(), "iface") &&
			strings.HasSuffix(recvObj.Name(), "API")
	}

	// Other receivers: accept only the AWS service structs.
	if isAwsClientEmbedded(recv) {
		return true
	}
	pass.Reportf(fObj.Pos(), "detected %v call on non-AWS client receiver - %v - ignoring", fObj, recv)
	return false
}
// isAwsClientEmbedded reports whether the struct underlying t embeds the
// Client type from awsClientPkg, either by value or through pointers.
//
// It returns false when t is nil or when its underlying type is not a struct
// (for example a named string or map type), so it is safe to call on the
// receiver of any method.
func isAwsClientEmbedded(t types.Type) bool {
	if t == nil {
		return false
	}
	// Underlying() already resolves named types down to their definition;
	// looping on it would never terminate for a non-struct type.
	structT, ok := t.Underlying().(*types.Struct)
	if !ok {
		return false
	}

	// Search the embedded fields for the named type
	// github.com/aws/aws-sdk-go/aws/client.Client
	for i := 0; i < structT.NumFields(); i++ {
		field := structT.Field(i)
		if !field.Embedded() {
			continue
		}
		named, ok := extractPointerElem(field.Type()).(*types.Named)
		if !ok {
			continue
		}
		obj := named.Obj()
		// Pkg() is nil for universe types such as error.
		if obj.Name() == "Client" && obj.Pkg() != nil && obj.Pkg().Path() == awsClientPkg {
			return true
		}
	}
	return false
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment