-
-
Save adammw/941d15c8b1730e3e89fc61138a6a4f24 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright 2022 Zendesk, Inc. | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. | |
package iampolicygen | |
import ( | |
"fmt" | |
"github.com/go-logr/logr" | |
"github.com/thoas/go-funk" | |
"go/ast" | |
"go/types" | |
"golang.org/x/tools/go/analysis" | |
"golang.org/x/tools/go/packages" | |
"sigs.k8s.io/controller-tools/pkg/markers" | |
"strings" | |
) | |
// awsServicePkg is the import-path prefix shared by the aws-sdk-go service
// packages; calls are only attributed to AWS when the callee lives under it,
// and waiter packages are loaded as awsServicePkg + "/" + service.
const awsServicePkg = "github.com/aws/aws-sdk-go/service"

// awsClientPkg is the import path of the package declaring the Client type
// that concrete service client structs are expected to embed.
const awsClientPkg = "github.com/aws/aws-sdk-go/aws/client"

// commentPrefix marks a source comment line as an iampolicygen marker.
const commentPrefix = "+iampolicygen:"
// iampolicygenMarker holds the arguments of a "+iampolicygen:" comment marker.
// runServiceCallAnalyzer registers it with controller-tools' markers package
// (markers.MakeDefinition) and decodes each matching comment into this type.
type iampolicygenMarker struct {
	// Action lists IAM actions that are added verbatim to the detected set.
	// NOTE(review): the empty `marker:""` tag presumably makes this the
	// marker's anonymous/unnamed argument per controller-tools conventions —
	// confirm against the markers package before documenting the syntax.
	Action []string `marker:""`
	// TODO: support other properties
}
func runAnalysisPass(log logr.Logger, runner func(*analysis.Pass) (interface{}, error), pkg *packages.Package) (interface{}, error) { | |
pass := &analysis.Pass{ | |
Analyzer: &analysis.Analyzer{ | |
Run: runner, | |
}, | |
Fset: pkg.Fset, | |
Files: pkg.Syntax, | |
OtherFiles: pkg.OtherFiles, | |
IgnoredFiles: pkg.IgnoredFiles, | |
Pkg: pkg.Types, | |
TypesInfo: pkg.TypesInfo, | |
TypesSizes: pkg.TypesSizes, | |
Report: func(d analysis.Diagnostic) { | |
posn := pkg.Fset.Position(d.Pos) | |
log.V(1).Info(fmt.Sprintf("%s: %s\n", posn, d.Message)) | |
}, | |
// not supporting the fact or dependency interface since it's not required for us | |
} | |
return pass.Analyzer.Run(pass) | |
} | |
func DetectAWSCalls(log logr.Logger, patterns ...string) ([]string, error) { | |
pkgs, err := load(patterns...) | |
if err != nil { | |
return nil, err | |
} | |
detectedCalls := make(map[string]bool) | |
for _, pkg := range pkgs { | |
result, err := runAnalysisPass(log, runServiceCallAnalyzer, pkg) | |
if err != nil { | |
return nil, err | |
} | |
detectedCalls = funk.Union(detectedCalls, result.(map[string]bool)).(map[string]bool) | |
} | |
// extract any waiters | |
waiterServices := map[string]bool{} | |
waiters := map[string]bool{} | |
for detectedCall := range detectedCalls { | |
if !strings.Contains(detectedCall, "WaitUntil") { | |
continue | |
} | |
waiters[detectedCall] = true | |
// extract service | |
w := strings.SplitN(detectedCall, ":", 2) | |
service := w[0] | |
waiterServices[service] = true | |
// remove from detected calls | |
delete(detectedCalls, detectedCall) | |
} | |
// for each service, generate a map of waiters to their permissions | |
waiterCallMap := map[string]map[string]bool{} | |
for service := range waiterServices { | |
pkgs, err := load(fmt.Sprintf("%s/%s", awsServicePkg, service)) | |
if err != nil { | |
return nil, err | |
} | |
for _, pkg := range pkgs { | |
result, err := runAnalysisPass(log, runWaiterCallAnalyzer, pkg) | |
if err != nil { | |
return nil, err | |
} | |
waiterCallMap = funk.Union(waiterCallMap, result).(map[string]map[string]bool) | |
} | |
} | |
// insert permissions for used waiters into detected calls | |
for waiter := range waiters { | |
perms, ok := waiterCallMap[waiter] | |
if !ok { | |
return nil, fmt.Errorf("could not find calls for waiter: %s", waiter) | |
} | |
detectedCalls = funk.Union(detectedCalls, perms).(map[string]bool) | |
} | |
return funk.Keys(detectedCalls).([]string), nil | |
} | |
// modified version of golang.org/x/tools/go/analysis/checker/internal/checker#load | |
func load(patterns ...string) ([]*packages.Package, error) { | |
if len(patterns) < 1 { | |
return nil, fmt.Errorf("need at least one package to search") | |
} | |
initial, err := packages.Load(&packages.Config{ | |
Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | | |
packages.NeedImports | packages.NeedTypes | packages.NeedTypesSizes | | |
packages.NeedSyntax | packages.NeedTypesInfo, | |
Tests: false, | |
}, patterns...) | |
if err != nil { | |
n := packages.PrintErrors(initial) | |
switch { | |
case n > 1: | |
err = fmt.Errorf("%d errors during loading", n) | |
case n == 1: | |
err = fmt.Errorf("error during loading") | |
case len(initial) == 0: | |
err = fmt.Errorf("%s matched no packages", strings.Join(patterns, " ")) | |
} | |
} | |
return initial, err | |
} | |
func runServiceCallAnalyzer(pass *analysis.Pass) (interface{}, error) { | |
detectedCalls := map[string]bool{} | |
markerDef, err := markers.MakeDefinition("iampolicygen", markers.DescribesPackage, iampolicygenMarker{}) | |
if err != nil { | |
return nil, err | |
} | |
for _, file := range pass.Files { | |
filename := pass.Fset.Position(file.Pos()).Filename | |
if file.Name.Name == "main" { | |
continue // skip testing package main | |
} | |
if strings.HasSuffix(filename, "_test.go") { | |
continue // skip test files | |
} | |
// extract from comments | |
for _, comment := range file.Comments { | |
if strings.HasPrefix(strings.TrimSpace(comment.Text()), commentPrefix) { | |
data, err := markerDef.Parse(comment.Text()) | |
if err != nil { | |
return nil, fmt.Errorf("cannot parse comment marker: %w", err) | |
} | |
statement, ok := data.(iampolicygenMarker) | |
if !ok { | |
return nil, fmt.Errorf("cannot parse comment marker: expected iampolicygenMarker, got %T", data) | |
} | |
pass.Reportf(comment.Pos(), "comment - %+v", statement) | |
for _, action := range statement.Action { | |
detectedCalls[action] = true | |
} | |
} | |
} | |
// extract from method calls | |
detectedCalls = funk.Union(detectedCalls, extractServiceCalls(file, pass)).(map[string]bool) | |
} | |
return detectedCalls, nil | |
} | |
// returns result of map [waiter name] -> [permissions]:true | |
func runWaiterCallAnalyzer(pass *analysis.Pass) (interface{}, error) { | |
detectedCallMap := map[string]map[string]bool{} | |
for _, file := range pass.Files { | |
filename := pass.Fset.Position(file.Pos()).Filename | |
if !strings.HasSuffix(filename, "waiters.go") { | |
continue // skip all but waiters file | |
} | |
ast.Inspect(file, func(n ast.Node) bool { | |
funcDecl, ok := n.(*ast.FuncDecl) | |
if ok && strings.HasPrefix(funcDecl.Name.Name, "WaitUntil") { | |
actionName := funcDecl.Name.Name | |
actionName = strings.TrimSuffix(actionName, "WithContext") | |
actionName = strings.TrimSuffix(actionName, "Request") | |
actionName = strings.TrimSuffix(actionName, "Pages") | |
serviceCalls := extractServiceCalls(funcDecl.Body, pass) | |
key := fmt.Sprintf("%s:%s", pass.Pkg.Name(), actionName) | |
if _, ok := detectedCallMap[key]; !ok { | |
detectedCallMap[key] = map[string]bool{} | |
} | |
for serviceCall := range serviceCalls { | |
if !strings.Contains(serviceCall, "WaitUntil") { | |
detectedCallMap[key][serviceCall] = true | |
} | |
} | |
} | |
return n == file | |
}) | |
} | |
return detectedCallMap, nil | |
} | |
func extractServiceCalls(n ast.Node, pass *analysis.Pass) map[string]bool { | |
detectedCalls := map[string]bool{} | |
ast.Inspect(n, func(n ast.Node) bool { | |
serviceName, actionName := extractServiceAction(n, pass) | |
if serviceName != "" && actionName != "" { | |
pass.Reportf(n.Pos(), "found call - %v:%v", serviceName, actionName) | |
detectedCalls[fmt.Sprintf("%s:%s", serviceName, actionName)] = true | |
privilege, ok := IAMDefinition[serviceName].Privileges[actionName] | |
if ok { | |
// add dependent actions | |
for _, resourceType := range privilege.ResourceTypes { | |
if len(resourceType.DependentActions) > 0 { | |
pass.Reportf(n.Pos(), "adding DependentActions %v from %v", resourceType.DependentActions, resourceType.ResourceType) | |
for _, action := range resourceType.DependentActions { | |
detectedCalls[action] = true | |
} | |
} | |
} | |
} else if !strings.HasPrefix(actionName, "WaitUntil") { | |
pass.Reportf(n.Pos(), "unknown privilege %s:%s ??", serviceName, actionName) | |
} | |
} | |
return true | |
}) | |
return detectedCalls | |
} | |
// extractFunCallObject returns the object that a call expression's callee
// resolves to in typesInfo. The callee is the identifier itself for f(...) and
// the selected identifier for x.f(...). The result is nil in three cases:
//   - n is not a call expression;
//   - the callee has any other shape (e.g. the []byte(s) conversion,
//     function literals, or index expressions);
//   - typesInfo records no object for the identifier.
func extractFunCallObject(n ast.Node, typesInfo *types.Info) types.Object {
	call, isCall := n.(*ast.CallExpr)
	if !isCall {
		return nil
	}

	var callee *ast.Ident
	if sel, isSel := call.Fun.(*ast.SelectorExpr); isSel {
		callee = sel.Sel
	} else if id, isIdent := call.Fun.(*ast.Ident); isIdent {
		callee = id
	} else {
		return nil // e.g. []byte() initializers; can safely ignore
	}
	return typesInfo.ObjectOf(callee)
}
func extractFunCallRecvType(fObj types.Object) types.Type { | |
sig, ok := fObj.Type().(*types.Signature) | |
if !ok { | |
return nil | |
} | |
return extractPointerElem(sig.Recv().Type()) | |
} | |
// extractPointerElem strips every level of pointer indirection from t, so T,
// *T and **T all yield T. Non-pointer types, including named types whose
// underlying type is a pointer, are returned unchanged.
func extractPointerElem(t types.Type) types.Type {
	for {
		ptr, isPtr := t.(*types.Pointer)
		if !isPtr {
			return t
		}
		t = ptr.Elem()
	}
}
func extractServiceAction(n ast.Node, pass *analysis.Pass) (serviceName, actionName string) { | |
fObj := extractFunCallObject(n, pass.TypesInfo) | |
if fObj == nil || !isAwsServiceCall(fObj, pass) { | |
return "", "" | |
} | |
serviceName = fObj.Pkg().Name() | |
serviceName = strings.TrimSuffix(serviceName, "iface") | |
if serviceName == "resourcegroupstaggingapi" { | |
// thank AWS for their consistency - https://docs.aws.amazon.com/resourcegroupstagging/latest/APIReference/overview.html | |
serviceName = "tag" | |
} | |
actionName = fObj.Name() | |
actionName = strings.TrimSuffix(actionName, "WithContext") | |
actionName = strings.TrimSuffix(actionName, "Request") | |
actionName = strings.TrimSuffix(actionName, "Pages") | |
return serviceName, actionName | |
} | |
func isAwsServiceCall(fObj types.Object, pass *analysis.Pass) bool { | |
// filter only to function calls to AWS service packages | |
if fObj == nil || | |
fObj.Pkg() == nil || | |
!strings.HasPrefix(fObj.Pkg().Path(), awsServicePkg) || | |
fObj.Name() == "New" { | |
return false | |
} | |
// extract function call receiver - expecting a named type | |
fRecvT := extractFunCallRecvType(fObj) | |
namedFRecv, ok := fRecvT.(*types.Named) | |
if fRecvT == nil || !ok { | |
return false | |
} | |
// filter out interface callers that aren't the AWS service iface APIs | |
_, isInterface := namedFRecv.Underlying().(*types.Interface) | |
if isInterface && | |
(!strings.HasSuffix(namedFRecv.Obj().Pkg().Name(), "iface") || !strings.HasSuffix(namedFRecv.Obj().Name(), "API")) { | |
return false | |
} | |
// filter out non-interface callers that aren't the AWS service structs | |
if !isInterface && !isAwsClientEmbedded(namedFRecv) { | |
pass.Reportf(fObj.Pos(), "detected %v call on non-AWS client receiver - %v - ignoring", fObj, namedFRecv) | |
return false | |
} | |
return true | |
} | |
func isAwsClientEmbedded(t types.Type) bool { | |
var structT *types.Struct | |
var ok bool | |
// search for underlying struct | |
for { | |
if structT, ok = t.(*types.Struct); ok { | |
break | |
} | |
t = t.Underlying() | |
} | |
// search all fields for a named type: | |
// github.com/aws/aws-sdk-go/aws/client.Client | |
for i := 0; i < structT.NumFields(); i++ { | |
field := structT.Field(i) | |
if !field.Embedded() { | |
continue | |
} | |
fType := extractPointerElem(field.Type()) | |
namedFType, ok := fType.(*types.Named) | |
if ok && | |
namedFType.Obj().Name() == "Client" && | |
namedFType.Obj().Pkg().Path() == awsClientPkg { | |
return true | |
} | |
} | |
return false | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment