Skip to content

Instantly share code, notes, and snippets.

@Merovius
Created September 29, 2022 07:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Merovius/ca9e9199f8f46fea63f99744b61f7ea7 to your computer and use it in GitHub Desktop.
Save Merovius/ca9e9199f8f46fea63f99744b61f7ea7 to your computer and use it in GitHub Desktop.
package callsrc
// Package callsrc helps enforce that a function is only called in certain contexts.
package callsrc
import (
"io"
"os"
"runtime"
"strings"
)
// Context flags accepted by Allow. Each flag is a distinct bit, so flags can
// be combined with bitwise OR; Allow accepts a call if it matches any of the
// selected contexts.
const (
	// Init allows a call from an init() function.
	Init = 1 << 0
	// PkgScope allows a call from package-scope variable initializers.
	PkgScope = 1 << 1
	// MainPkg allows a call originating from package main.
	MainPkg = 1 << 2
	// MainFn allows a call from main.main. Since main.main lives in package
	// main, such a call is also accepted by MainPkg.
	MainFn = 1 << 3
	// TestMain allows a call from TestMain.
	TestMain = 1 << 4
	// TestFunc allows a call from a Test*, Bench* or Fuzz* function. Implies
	// TestMain.
	TestFunc = 1 << 5
)
// has reports whether flags and f share at least one set bit.
func has(flags, f int) bool {
	overlap := flags & f
	return overlap != 0
}
// Allow asserts that a call happens in a particular context. skip is the
// number of frames to skip, with 0 identifying the caller of Allow. flags is a
// bitwise OR of this package's context constants (Init, PkgScope, MainPkg,
// MainFn, TestMain, TestFunc); the call is accepted if any selected context
// matches. If the call is not allowed in the actual context, a message is
// printed to os.Stderr and the program exits with status 1.
func Allow(skip, flags int) {
	if skip < 0 {
		io.WriteString(os.Stderr, "skip must not be negative\n")
		os.Exit(1)
	}
	f := functions(skip, 3)
	// f[0] is the restricted function
	// f[1] is the caller of the restricted function
	// f[2] is its caller, i.e. the testing or runtime package in common contexts
	if len(f) < 2 {
		// skip reaches past the bottom of the stack, so there is no
		// restricted function and caller to check; indexing f would panic.
		io.WriteString(os.Stderr, "skip exceeds the depth of the call stack\n")
		os.Exit(1)
	}
	if f[1].full == "main.main" && has(flags, MainFn) {
		return
	}
	if f[1].pkg == "main" && has(flags, MainPkg) {
		return
	}
	// Package-scope variable initializers run inside the compiler-generated
	// "init" function; user-written init functions are named "init.0",
	// "init.1", ...
	if f[1].name == "init" && has(flags, PkgScope) {
		return
	}
	if strings.HasPrefix(f[1].name, "init.") && has(flags, Init) {
		return
	}
	if len(f) < 3 {
		io.WriteString(os.Stderr, f[0].full+" must not be called from "+f[1].full+"\n")
		os.Exit(1)
	}
	// TestMain is not invoked by package testing: the main.main of the test
	// binary generated by "go test" calls it directly.
	if f[1].name == "TestMain" && f[2].full == "main.main" && has(flags, TestMain) {
		return
	}
	// TODO: Technically, this isn't enough. e.g. testing.AllocsPerRun
	// takes a callback and this callback would be matched here but isn't a
	// Test*, Bench* or Fuzz* function.
	if f[2].pkg == "testing" && has(flags, TestFunc) {
		return
	}
	// Writing to stderr is best effort; the program exits either way.
	io.WriteString(os.Stderr, f[0].full+" must not be called from "+f[1].full+"\n")
	os.Exit(1)
}
type function struct {
full string
pkg string
name string
}
func functions(skip, n int) []function {
pc := make([]uintptr, 1024)
pc = pc[:runtime.Callers(3+skip, pc)]
frames := runtime.CallersFrames(pc)
var out []function
for n > 0 {
f, ok := frames.Next()
if !ok {
break
}
p, n, ok := strings.Cut(f.Function, ".")
if ok {
out = append(out, function{f.Function, p, n})
} else {
out = append(out, function{f.Function, "", f.Function})
}
}
return out
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment