Skip to content

Instantly share code, notes, and snippets.

@ilopmar
Last active December 3, 2022 19:41
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ilopmar/6037796 to your computer and use it in GitHub Desktop.
Save ilopmar/6037796 to your computer and use it in GitHub Desktop.
Sample Groovy AST transformation, developed during the 4th Kaleidos Piweek, that logs the execution time of a method
package net.kaleidos.piweek
import org.codehaus.groovy.transform.GroovyASTTransformationClass
import java.lang.annotation.ElementType
import java.lang.annotation.Retention
import java.lang.annotation.RetentionPolicy
import java.lang.annotation.Target
/**
 * Marker for methods that should be instrumented by
 * {@code net.kaleidos.piweek.LogTimeASTTransformation}, which times the method body
 * and prints the elapsed time.
 *
 * <p>Only needed at compile time (source retention); the transformation class is
 * referenced by its fully-qualified name so it is resolved by the compiler.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.SOURCE)
@GroovyASTTransformationClass(["net.kaleidos.piweek.LogTimeASTTransformation"])
@interface LogTime {
}
package net.kaleidos.piweek
import org.codehaus.groovy.ast.ASTNode
import org.codehaus.groovy.ast.ClassNode
import org.codehaus.groovy.ast.MethodNode
import org.codehaus.groovy.ast.VariableScope
import org.codehaus.groovy.ast.builder.AstBuilder
import org.codehaus.groovy.ast.expr.ArgumentListExpression
import org.codehaus.groovy.ast.expr.BinaryExpression
import org.codehaus.groovy.ast.expr.ClassExpression
import org.codehaus.groovy.ast.expr.ConstantExpression
import org.codehaus.groovy.ast.expr.DeclarationExpression
import org.codehaus.groovy.ast.expr.MethodCallExpression
import org.codehaus.groovy.ast.expr.VariableExpression
import org.codehaus.groovy.ast.stmt.BlockStatement
import org.codehaus.groovy.ast.stmt.ExpressionStatement
import org.codehaus.groovy.ast.stmt.Statement
import org.codehaus.groovy.ast.stmt.TryCatchStatement
import org.codehaus.groovy.control.CompilePhase
import org.codehaus.groovy.control.SourceUnit
import org.codehaus.groovy.syntax.Token
import org.codehaus.groovy.syntax.Types
import org.codehaus.groovy.transform.AbstractASTTransformation
import org.codehaus.groovy.transform.GroovyASTTransformation
/**
 * Local AST transformation behind {@code @LogTime}.
 *
 * <p>The body of the annotated method is rewritten into the equivalent of
 * <pre>
 * def start = System.nanoTime()
 * try {
 *     // original body
 * } finally {
 *     def end = System.nanoTime()
 *     def total = (end - start) / 1000000
 *     this.println(total)
 * }
 * </pre>
 * so the elapsed time in milliseconds is printed however the method exits
 * (normal completion, early {@code return}, or an exception).
 */
@GroovyASTTransformation(phase=CompilePhase.CANONICALIZATION)
public class LogTimeASTTransformation extends AbstractASTTransformation {

    // Groovy AST expression API reference:
    // https://docs.groovy-lang.org/latest/html/api/org/codehaus/groovy/ast/expr/Expression.html

    /** Divisor that turns a System.nanoTime() difference into milliseconds. */
    private static final int NANOS_PER_MILLI = 1000000

    /**
     * Instruments the method this invocation was triggered for.
     *
     * @param nodes      for a local transformation, {@code nodes[0]} is the {@code @LogTime}
     *                   annotation and {@code nodes[1]} the annotated node
     * @param sourceUnit the source unit being compiled (not needed here)
     */
    public void visit(ASTNode[] nodes, SourceUnit sourceUnit) {
        // Groovy calls a local transformation once per annotated node, so only nodes[1]
        // is touched. Instrumenting other annotated methods from here would wrap them
        // once per @LogTime occurrence in the module.
        if (nodes == null || nodes.length < 2 || !(nodes[1] instanceof MethodNode)) {
            return
        }
        MethodNode method = (MethodNode) nodes[1]
        Statement originalCode = method.getCode()
        if (originalCode == null) {
            // abstract / interface method: there is no body to time
            return
        }

        String startTimeVarName = randomVarName()
        String endTimeVarName = randomVarName()
        String totalTimeVarName = randomVarName()

        VariableScope timingScope = new VariableScope(method.getVariableScope())

        List<Statement> reportStatements = [
            createSystemNanoTimeAst(endTimeVarName),
            createExecutionTimeAst(totalTimeVarName, startTimeVarName, endTimeVarName),
            createPrintlnAst(totalTimeVarName)
        ]
        BlockStatement reportElapsed = new BlockStatement(reportStatements, new VariableScope(timingScope))

        // The original body stays intact as the try block. Its last expression therefore
        // remains the implicit return value, and the finally block runs on every exit path.
        List<Statement> wrapper = [
            createSystemNanoTimeAst(startTimeVarName),
            new TryCatchStatement(originalCode, reportElapsed)
        ]
        method.setCode(new BlockStatement(wrapper, timingScope))
    }

    /** Builds {@code def <var> = System.nanoTime()}. */
    private Statement createSystemNanoTimeAst(String var) {
        return new ExpressionStatement(
            new DeclarationExpression(
                new VariableExpression(var),
                Token.newSymbol(Types.EQUALS, 0, 0),
                new MethodCallExpression(
                    new ClassExpression(new ClassNode(System)),
                    "nanoTime",
                    MethodCallExpression.NO_ARGUMENTS
                )
            )
        )
    }

    /**
     * Builds {@code def <total> = (<end> - <start>) / 1000000}. With dynamic Groovy
     * arithmetic the result is a decimal number of milliseconds.
     */
    private Statement createExecutionTimeAst(String totalTimeVarName, String startTimeVarName, String endTimeVarName) {
        return new ExpressionStatement(
            new DeclarationExpression(
                new VariableExpression(totalTimeVarName),
                Token.newSymbol(Types.EQUALS, 0, 0),
                new BinaryExpression(
                    new BinaryExpression(
                        new VariableExpression(endTimeVarName),
                        Token.newSymbol(Types.MINUS, 0, 0),
                        new VariableExpression(startTimeVarName)
                    ),
                    Token.newSymbol(Types.DIVIDE, 0, 0),
                    new ConstantExpression(Integer.valueOf(NANOS_PER_MILLI))
                )
            )
        )
    }

    /** Builds {@code this.println(<message>)}, where {@code message} is a variable name. */
    private Statement createPrintlnAst(String message) {
        return new ExpressionStatement(
            new MethodCallExpression(
                new VariableExpression("this"),
                new ConstantExpression("println"),
                new ArgumentListExpression(
                    new VariableExpression(message)
                )
            )
        )
    }

    /**
     * Returns a fresh local-variable name that is unique on every call.
     * A time-seeded Random would repeat its output for calls made within the same
     * millisecond, which would make the generated variables collide.
     */
    private String randomVarName() {
        return "_logTime_" + UUID.randomUUID().toString().replace("-", "")
    }
}
import net.kaleidos.piweek.*
/**
 * Demo method: prints a start banner, pauses for about a second, then prints an end banner.
 * Because it carries @LogTime, the compiler instruments it so that the elapsed time
 * is printed once the body has run.
 */
@LogTime
def hello() {
    println("Hello world - start")
    // stand-in for real work: block for roughly one second
    Object.sleep(1000)
    println("Hello world - end")
}

hello()
// This code prints:
// Hello world - start
// Hello world - end
// 1009.788456    <- elapsed time of hello() in milliseconds
@jagedn
Copy link

jagedn commented May 2, 2020

Thanks for this example. I've adapted it and ported it to Asteroid.
Hope you like it!

import asteroid.AbstractLocalTransformation
import asteroid.Phase
import com.trysplit.annotations.TimeTrace
import groovy.transform.CompileStatic
import org.codehaus.groovy.ast.AnnotationNode
import org.codehaus.groovy.ast.MethodNode
import org.codehaus.groovy.ast.stmt.Statement
import org.codehaus.groovy.control.CompilePhase
import org.codehaus.groovy.syntax.Types

import static asteroid.Expressions.*
import static asteroid.Statements.*

@CompileStatic
@Phase(CompilePhase.SEMANTIC_ANALYSIS)
class TimeTraceTransformationImpl extends AbstractLocalTransformation<TimeTrace, MethodNode> {

    @Override
    void doVisit(final AnnotationNode annotation, final MethodNode methodNode) {
        def oldCode = methodNode.code
        methodNode.code = blockS(
                declareStart(),
                traceStart(),
                tryCatchSBuilder().tryStmt(oldCode).finallyStmt(blockS(
                        declareEnd(),
                        calculateDiff(),
                        traceEnd(),
                        traceDiff()
                )).build()
        )
    }

    Statement declareStart() {
        stmt(varDeclarationX('nanoTimeTraceStart', Long, staticCallX(System, 'nanoTime')))
    }

    Statement traceStart() {
        stmt(callThisX("println", varX('nanoTimeTraceStart')))
    }

    Statement declareEnd() {
        stmt(varDeclarationX('nanoTimeTraceEnd', Long, staticCallX(System, 'nanoTime')))
    }

    Statement calculateDiff() {
        stmt(
                varDeclarationX('nanoTimeTraceDiff', Long,
                        binX(
                                binX('nanoTimeTraceEnd', Types.MINUS, 'nanoTimeTraceStart'),
                                Types.DIVIDE,
                                constX(1000000)
                        )
                )
        )
    }

    Statement traceEnd() {
        stmt(callThisX("println", varX('nanoTimeTraceEnd')))
    }

    Statement traceDiff() {
        stmt(callThisX("println", varX('nanoTimeTraceDiff')))
    }


}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment