Skip to content

Instantly share code, notes, and snippets.

@jespercockx
Created February 28, 2020 14:27
Show Gist options
  • Save jespercockx/aaebd78c95e01ce0017992c768a8a971 to your computer and use it in GitHub Desktop.
Save jespercockx/aaebd78c95e01ce0017992c768a8a971 to your computer and use it in GitHub Desktop.
open import Agda.Primitive
open import Data.Bool.Base
open import Data.Nat.Base
open import Data.List.Base
open import Data.Product using (_×_; _,_; proj₁; proj₂)
import Data.String.Base as String
open import Data.Unit.Base
open import Function using (id; _∘_)
open import Relation.Nullary
open import Relation.Binary.PropositionalEquality.Core
open import Reflection
open import Reflection.Argument using (unArg)
open import Reflection.Term
import Reflection.Name as Name
-- Generalisable variables: a type signature that mentions ℓ, A, B or C
-- without binding them is implicitly quantified over them.
variable ℓ : Level
variable A B C : Set ℓ
-- A (possibly nested) trace of recursive calls. Each entry pairs a value
-- of type A with the trace recorded for that sub-call; the traces built by
-- make-traced store the result of each recursive call in the first slot.
record Calls (A : Set ℓ) : Set ℓ where
  inductive
  constructor recurse-on
  field
    calls : List (A × Calls A)

open Calls public

-- An entry for a sub-call that made no further recursive calls.
pattern call-on x = x , recurse-on []
-- Collect every application of the function named f that occurs in a term,
-- so that each one can later be turned into a traced call. Calls are listed
-- left to right, an outer call before the calls nested in its arguments.
--
-- Only positions in the same scope as the term itself are searched, so the
-- variables in every returned term still refer to the enclosing clause's
-- context and can be reused verbatim. Anything under a binder (lambda
-- bodies, pattern-lambda clauses, the codomain of a pi type) is skipped.
--
-- TERMINATING: each recursive call is on a subterm, but the recursion into
-- argument lists goes through concatMap and unArg, which the termination
-- checker cannot see through.
{-# TERMINATING #-}
extract-calls : Name → Term → List Term
extract-calls f (var x args) = concatMap (extract-calls f ∘ unArg) args
extract-calls f (con c args) = concatMap (extract-calls f ∘ unArg) args
extract-calls f u@(def g args) =
  (if does (f Name.≟ g) then [ u ] else []) ++
  concatMap (extract-calls f ∘ unArg) args
extract-calls f (lam v t) = [] -- body is under a binder: not supported
-- The clauses of a pattern lambda bind variables, but the arguments it is
-- applied to live in the outer scope, so calls in them can be lifted as-is.
extract-calls f (pat-lam cs args) = concatMap (extract-calls f ∘ unArg) args
-- Likewise the domain of a pi type lies outside its binder; only the
-- codomain (under the binder) is skipped.
extract-calls f (pi a b) = extract-calls f (unArg a)
extract-calls f (sort (set t)) = extract-calls f t
extract-calls f (sort (lit n)) = []
extract-calls f (sort unknown) = []
extract-calls f (lit l) = []
extract-calls f (meta x args) = concatMap (extract-calls f ∘ unArg) args
extract-calls f unknown = []
-- Turn one clause of f into the corresponding clause of its tracer cg-f.
-- The patterns are kept; the body becomes a Calls value with one entry per
-- call of f found (by extract-calls) in the original body, pairing the call
-- f args with cg-f args. Absurd clauses have no body and are copied as-is.
trace-calls : Name → Name → Clause → Clause
trace-calls f cg-f (clause ps t) =
  clause ps (con (quote recurse-on) (vArg trace-list ∷ []))
  where
  -- Reflected list constructors.
  `nil : Term
  `nil = con (quote List.[]) []

  `cons : Term → Term → Term
  `cons hd tl = con (quote List._∷_) (vArg hd ∷ vArg tl ∷ [])

  -- extract-calls only ever returns applications `def _ _`, so the
  -- fallback equation is not expected to be reached.
  mk-entry : Term → Term
  mk-entry (def g xs) = con (quote _,_) (vArg (def g xs) ∷ vArg (def cg-f xs) ∷ [])
  mk-entry _          = unknown

  trace-list : Term
  trace-list = foldr (λ c acc → `cons (mk-entry c) acc) `nil (extract-calls f t)
trace-calls f cg-f (absurd-clause ps) = absurd-clause ps
-- Compute the type of the tracer for a function of type t: every leading
-- pi binder of (the reduced form of) t is kept, extending the context
-- while descending, and the final codomain C is replaced by Calls C.
-- The codomain is wrapped in its unreduced form.
{-# TERMINATING #-}
make-traced-type : Type → TC Type
make-traced-type t = reduce t >>= λ where
  (Π[ x ∶ dom ] cod) →
    extendContext dom (make-traced-type cod) >>= λ cod' →
    return (Π[ x ∶ dom ] cod')
  _ → return (def (quote Calls) (vArg t ∷ []))
-- Declare and define cg-f, the call tracer of the function f: its type is
-- computed by make-traced-type from f's type, and each clause of f is
-- translated by trace-calls. Fails with a type error if f is not a
-- function definition.
make-traced : Name → Name → TC ⊤
make-traced f cg-f =
  getDefinition f >>= λ where
    (function cls) → do
      traced-ty ← getType f >>= make-traced-type
      declareDef (vArg cg-f) traced-ty
      defineFun cg-f (map (trace-calls f cg-f) cls)
    _ → typeError (strErr "not a defined function: " ∷ nameErr f ∷ [])
-- Example call graph: the naive, doubly recursive Fibonacci function.
-- The unquoteDecl below reflects fib's clauses and generates cg-fib, which
-- takes the same arguments as fib but returns Calls ℕ: for each recursive
-- call fib makes, the value that call returns paired with its own trace.
fib : ℕ → ℕ
fib zero = zero
fib (suc zero) = 1
fib (suc (suc n)) = fib (suc n) + fib n
-- Generate cg-fib from fib (one clause of cg-fib per clause of fib).
unquoteDecl cg-fib = make-traced (quote fib) cg-fib
-- The generated definition is equivalent to writing by hand:
-- cg-fib : ℕ → Calls ℕ
-- cg-fib zero .calls = []
-- cg-fib (suc zero) .calls = []
-- cg-fib (suc (suc n)) .calls = (fib (suc n) , cg-fib (suc n)) ∷ (fib n , cg-fib n) ∷ []
-- Sanity check: fully unfolding cg-fib 5. Each entry pairs the value of a
-- recursive call with that call's trace; e.g. the first entry is
-- (fib 4 , cg-fib 4) = (3 , …) and the second is (fib 3 , cg-fib 3) = (2 , …).
_ : cg-fib 5 ≡ recurse-on
  ((3 , recurse-on
         ((2 , recurse-on
                ((1 , recurse-on (call-on 1 ∷ call-on 0 ∷ [])) ∷ call-on 1 ∷ []))
          ∷ (1 , recurse-on (call-on 1 ∷ call-on 0 ∷ []))
          ∷ []))
  ∷ (2 , recurse-on
           ((1 , recurse-on (call-on 1 ∷ call-on 0 ∷ [])) ∷ call-on 1 ∷ []))
  ∷ [])
_ = refl
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment