Skip to content

Instantly share code, notes, and snippets.

@titpetric
Created November 29, 2022 16:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save titpetric/3c3aa2d9badea464e3074a86b2cafa34 to your computer and use it in GitHub Desktop.
Save titpetric/3c3aa2d9badea464e3074a86b2cafa34 to your computer and use it in GitHub Desktop.
Count imports (detect import pollution)
package main
import (
"errors"
"fmt"
"go/parser"
"go/token"
"log"
"os"
"path/filepath"
"runtime/debug"
"sort"
"strings"
)
// listFiles walks the current working directory recursively and returns the
// paths (relative to ".") of every Go source file that is not a test file.
// Paths come back in the lexical order in which filepath.Walk visits them.
//
// Any error reported by the walk aborts it and is returned as-is, together
// with a nil slice.
func listFiles() ([]string, error) {
	files := []string{}
	addFile := func(name string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Directories may carry a ".go" suffix as well (a vendored
		// ".../nats.go", for instance). They are still descended into, but
		// must not be reported as source files: parsing them would fail.
		if info.IsDir() {
			return nil
		}
		// skip test files
		if strings.HasSuffix(name, "_test.go") {
			return nil
		}
		if strings.HasSuffix(name, ".go") {
			files = append(files, name)
		}
		return nil
	}
	if err := filepath.Walk(".", addFile); err != nil {
		return nil, err
	}
	return files, nil
}
// importGrouping tallies how often each third-party import path is seen
// across the files handed to collectImports.
type importGrouping struct {
	// imports maps an unquoted import path to the number of import
	// declarations that referenced it.
	imports map[string]int
}

// newImportGrouping returns an empty, ready-to-use importGrouping. The error
// result is always nil; it is kept so that existing callers keep compiling.
func newImportGrouping() (*importGrouping, error) {
	return &importGrouping{
		imports: make(map[string]int),
	}, nil
}

// add records one more occurrence of the import path name.
func (g *importGrouping) add(name string) {
	g.imports[name]++
}

// print writes the collected imports to standard output, one per line in the
// form "#rank<TAB>path (count)", most frequently imported first.
func (g *importGrouping) print() {
	keys := make([]string, 0, len(g.imports))
	for key := range g.imports {
		keys = append(keys, key)
	}
	// Map iteration order is random, so ties on the count are broken by path
	// to keep the report identical from run to run.
	sort.Slice(keys, func(i, j int) bool {
		ci, cj := g.imports[keys[i]], g.imports[keys[j]]
		if ci != cj {
			return ci > cj
		}
		return keys[i] < keys[j]
	})
	for idx, key := range keys {
		fmt.Printf("#%d\t%s (%d)\n", idx+1, key, g.imports[key])
	}
}

// collectImports parses the import declarations of the Go source file name
// and counts every import that is neither part of packagePath nor looks like
// a standard-library package.
//
// An import is treated as part of packagePath when it equals packagePath or
// lies below it (prefix packagePath + "/"). An empty packagePath disables
// that filter. Standard-library packages are approximated as import paths
// that contain no dot.
//
// The error from parser.ParseFile (unreadable file, malformed package clause
// or import declarations) is returned unchanged.
func (g *importGrouping) collectImports(name string, packagePath string) error {
	fset := token.NewFileSet()
	// Only the import declarations are needed, so stop parsing after them;
	// syntax errors further down the file no longer abort the scan.
	node, err := parser.ParseFile(fset, name, nil, parser.ImportsOnly)
	if err != nil {
		return err
	}
	selfPrefix := packagePath + "/"
	for _, spec := range node.Imports {
		// Path.Value is the literal exactly as written, including its "..."
		// or `...` delimiters. Import paths cannot themselves contain quote
		// or backquote characters, so trimming the delimiters yields the
		// path and makes both literal forms count as the same import.
		importPath := strings.Trim(spec.Path.Value, "\"`")
		// skip self-referencing; the packagePath != "" guard matters because
		// an empty path would otherwise turn the prefix into "/" semantics
		// nobody intended.
		if packagePath != "" && (importPath == packagePath || strings.HasPrefix(importPath, selfPrefix)) {
			continue
		}
		// skip stdlib (no dot in name)
		if !strings.Contains(importPath, ".") {
			continue
		}
		g.add(importPath)
	}
	return nil
}
// start runs the whole report: it takes the main module path from the build
// information embedded in the running binary, tallies the imports of every
// Go file returned by listFiles, and prints the ranking.
//
// It returns an error when no build information is available, or the first
// error reported while listing or parsing files.
func start() error {
	// The main module path is used to recognise imports of the project's
	// own packages.
	buildInfo, found := debug.ReadBuildInfo()
	if !found {
		return errors.New("Failed to read build info")
	}
	modulePath := buildInfo.Main.Path
	fmt.Println("Package path:", modulePath)

	// Accumulator for the per-import counts.
	counter, err := newImportGrouping()
	if err != nil {
		return err
	}

	sources, err := listFiles()
	if err != nil {
		return err
	}

	for idx := range sources {
		err = counter.collectImports(sources[idx], modulePath)
		if err != nil {
			return err
		}
	}

	counter.print()
	return nil
}
// main runs the import report and terminates the process through log.Fatal
// when start reports an error.
func main() {
	err := start()
	if err == nil {
		return
	}
	log.Fatal(err)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment