Skip to content

Instantly share code, notes, and snippets.

@riyadparvez
Last active August 29, 2015 14:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save riyadparvez/a2c157b24579c6552466 to your computer and use it in GitHub Desktop.
Save riyadparvez/a2c157b24579c6552466 to your computer and use it in GitHub Desktop.
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/APSInt.h"
#include "clang/Driver/Options.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/Expr.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/ASTConsumers.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Lex/Lexer.h"
#include "clang/Rewrite/Core/Rewriter.h"
#include "clang/Rewrite/Frontend/FrontendActions.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include <iostream>
#include <set>
using namespace std;
using namespace clang;
using namespace clang::driver;
using namespace clang::tooling;
using namespace llvm;
// Plain record carrying a single source line number.
struct s {
  int line_no;
};
// File-level state shared by the visitor below and main().
// Accumulates the source edits. The visitor's constructor binds it to the
// current SourceManager, and main() writes out its edit buffer for the main
// file once the tool has run.
Rewriter rewriter;
// Language options object; main() switches on the GNU/C++ related flags.
LangOptions languageOptions;
/// Recursive AST visitor that records source edits in the file-level
/// `rewriter`:
///  - after a variable is initialised with, or assigned from, a direct call to
///    an extern "C" `malloc`/`calloc`, it inserts
///    `s2e_concretize_fork(<var>, sizeof(<var>), 0);` behind the statement;
///  - the first token of every direct call to a function named `func` is
///    replaced with `s2e`.
/// Several callbacks additionally print debugging output to llvm::outs().
///
/// RecursiveASTVisitor dispatches to the Visit* members statically (CRTP), so
/// none of them needs to be virtual.  Every Visit* callback returns true,
/// because returning false aborts the traversal.
class S2EInstrumentationVisitor : public RecursiveASTVisitor<S2EInstrumentationVisitor> {
private:
  std::set<Expr *> expressions;      // cleared at every function; nothing inserts into it yet
  std::set<VarDecl *> declarations;  // variables that received a malloc/calloc result
  FunctionDecl *currentFunctionDecl; // last visited function if it is extern "C", else NULL
  ASTContext *astContext;            // used for getting additional AST info

  /// Returns true when Call directly invokes an extern "C" function named
  /// `malloc` or `calloc`.  Calls without a direct callee (for instance calls
  /// through a function pointer) are never treated as allocator calls.
  static bool IsHeapAllocationCall(CallExpr *Call) {
    FunctionDecl *Callee = Call->getDirectCallee();
    if (!Callee || !Callee->isExternC())
      return false;
    const std::string fname = Callee->getNameInfo().getAsString();
    return fname == "malloc" || fname == "calloc";
  }

  /// Builds the statement that is inserted behind an allocation of `name`.
  static std::string BuildConcretizeCall(const std::string &name) {
    return "\ns2e_concretize_fork(" + name + ", " + "sizeof(" + name + "), " + "0" + ");\n";
  }

  /// Inserts `str` directly behind the ';' that follows the token starting at
  /// LastTokenLoc.  Whitespace and comments between the token and the ';' are
  /// tolerated.  When the next token is not a ';' (e.g. `int *p = malloc(4), *q;`
  /// or an assignment nested in a condition) or the location lies inside a
  /// macro, nothing is inserted, because an extra statement at that position
  /// would produce invalid code.
  /// \returns true if the text was inserted.
  bool InsertAfterSemicolon(SourceLocation LastTokenLoc, const std::string &str) {
    if (LastTokenLoc.isInvalid())
      return false;
    SourceLocation AfterSemi = Lexer::findLocationAfterToken(
        LastTokenLoc, tok::semi, rewriter.getSourceMgr(), rewriter.getLangOpts(),
        /*SkipTrailingWhitespaceAndNewLine=*/false);
    if (AfterSemi.isInvalid()) {
      llvm::outs() << "Skipped instrumentation: no ';' found after "
                   << LastTokenLoc.printToString(rewriter.getSourceMgr()) << "\n";
      return false;
    }
    rewriter.InsertText(AfterSemi, str, true, true);
    return true;
  }

public:
  /// \param CI compiler instance whose ASTContext supplies the SourceManager
  ///           and language options the global rewriter is bound to.
  explicit S2EInstrumentationVisitor(CompilerInstance *CI)
      : currentFunctionDecl(NULL),
        astContext(&(CI->getASTContext())) // initialize private members
  {
    rewriter.setSourceMgr(astContext->getSourceManager(), astContext->getLangOpts());
  }

  /// Called for every declaration.
  ///  - FunctionDecl: resets the per-function bookkeeping.
  ///  - non-parameter VarDecl whose initialiser is (a possibly parenthesised or
  ///    cast) call to malloc/calloc: schedules the s2e_concretize_fork insertion.
  bool VisitDecl(Decl *Declaration)
  {
    if (FunctionDecl *funcDecl = dyn_cast<FunctionDecl>(Declaration)) {
      // Only functions with C language linkage are remembered.
      if (funcDecl->isExternC()) {
        currentFunctionDecl = funcDecl;
      } else {
        currentFunctionDecl = NULL;
      }
      expressions.clear();
      declarations.clear();
#if 0
      llvm::outs() << "function name: " << funcDecl->getNameAsString() << " (return type = " << funcDecl->getResultType().getAsString() << ")\n";
      unsigned paramCount = funcDecl->getNumParams();
      llvm::outs() << "function param count: " << paramCount << "\n";
      for(unsigned i = 0; i < paramCount; ++i) {
        llvm::outs() << "-param #" << i << "\n";
        const ParmVarDecl *currentParam = funcDecl->getParamDecl(i);
        QualType userType = currentParam->getType();
        while(userType->isPointerType()) {
          llvm::outs() << "\tpointer to" << "\n";
          userType = userType->getPointeeType();
        }
        if(userType.isConstQualified()) {
          llvm::outs() << "\tconst" << "\n";
        }
        if(userType->isReferenceType()) {
          llvm::outs() << "\treference to" << "\n";
        }
        userType = userType.getNonReferenceType().getUnqualifiedType();
        llvm::outs() << "\t(type = " << userType.getAsString() << ", name = " << currentParam->getNameAsString() << ")\n";
      }
      llvm::outs() << "\n";
#endif
    }

    VarDecl *varDecl = dyn_cast<VarDecl>(Declaration);
    if (!varDecl || isa<ParmVarDecl>(varDecl) || !varDecl->hasInit())
      return true;
    //llvm::outs() << "variable type: " << varDecl->getType().getAsString() << ", variable name: " << varDecl->getNameAsString();
    Expr *varInit = varDecl->getInit();
    if (!varInit || !varInit->isRValue())
      return true;
    // Works
#if 0
    SourceRange varSourceRange = varInit->getSourceRange();
    if(!varSourceRange.isValid())
      return true;
    CharSourceRange charSourceRange(varSourceRange, true);
    StringRef sourceText = Lexer::getSourceText(charSourceRange, astContext->getSourceManager(), astContext->getLangOpts(), 0);
    //llvm::outs() << ", initialization value: " << sourceText.str();
#endif
    // The allocator call is normally wrapped in a cast (`T *p = malloc(..)` in
    // C, `(T *)malloc(..)` in C++), so look through parentheses and casts.
    CallExpr *Call = dyn_cast<CallExpr>(varInit->IgnoreParenCasts());
    if (Call && IsHeapAllocationCall(Call)) {
      declarations.insert(varDecl);
      //llvm::outs() << str;
      InstrumentStmtAfter(varDecl, BuildConcretizeCall(varDecl->getNameAsString()));
    }
    return true;
  }

  /// If `s` is an assignment whose left-hand side names a variable directly,
  /// stores that variable's qualified name in `name` and returns true.
  /// Otherwise returns false and leaves `name` untouched.
  bool GetAssignedVar(Stmt *s, std::string& name) {
    BinaryOperator *BinOp = dyn_cast<BinaryOperator>(s);
    if (!BinOp || !BinOp->isAssignmentOp())
      return false;
    DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(BinOp->getLHS());
    if (!DRE)
      return false;
    VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl());
    if (!VD)
      return false;
    name = VD->getQualifiedNameAsString();
    return true;
  }

  // Override Statements which includes expressions and more.
  // Currently a no-op; the disabled fragment below is unfinished.
  bool VisitStmt(Stmt *s) {
#if 0
    Stmt *TH = If->getThen();
    // Add braces if needed to then clause
    InstrumentStmt(TH);
    Stmt *EL = If->getElse();
    if (EL) {
      // Add braces if needed to else clause
      InstrumentStmt(EL);
    }
    } else if (isa<ForStmt>(s)) {
      ForStmt *For = cast<ForStmt>(s);
      Stmt *BODY = For->getBody();
      //InstrumentStmt(BODY);
    }
#endif
    return true; // returning false aborts the traversal
  }

  /// Rewrites direct calls to a function named `func`: the token at the start
  /// of the call expression is replaced with `s2e`, and the call's range and
  /// the source text of each argument are printed to llvm::outs().
  /// Calls without a direct callee and calls to other functions are ignored.
  bool VisitCallExpr(CallExpr *CallE) {
    FunctionDecl *FD = CallE->getDirectCallee();
    if (!FD || FD->getNameInfo().getAsString() != "func")
      return true; // keep traversing; false would abort the whole walk

    SourceManager &SM = rewriter.getSourceMgr();

    /** Replace function **/
    SourceRange range = CallE->getSourceRange();
    rewriter.ReplaceText(SourceRange(range.getBegin()), "s2e");
    llvm::outs() << "Begin: " << range.getBegin().printToString(SM)
                 << " End: " << range.getEnd().printToString(SM) << "\n";

    /** Replace function argument **/
    for (CallExpr::arg_iterator it = CallE->arg_begin(), ite = CallE->arg_end(); it != ite; ++it) {
      Expr *arg = *it;
      SourceLocation begin = arg->getLocStart();
      // Use the options the rewriter was bound to, i.e. those of the file
      // actually being processed.
      SourceLocation end = Lexer::getLocForEndOfToken(arg->getLocEnd(), 0, SM, rewriter.getLangOpts());
      // Both ends must be plain file locations, otherwise the two character
      // pointers below would not delimit one contiguous piece of text.
      if (begin.isInvalid() || end.isInvalid() || begin.isMacroID() || end.isMacroID())
        continue;
      const char *first = SM.getCharacterData(begin);
      const char *last = SM.getCharacterData(end);
      if (last < first)
        continue;
      llvm::outs() << llvm::StringRef(first, last - first) << "\n";
      //rewriter.ReplaceText(source, "val");
      //llvm::outs() << source.printToString(rewriter.getSourceMgr()) << "\n";
    }
    return true;
  }

  /// Instruments assignments of the form `var = malloc(..)` / `var = calloc(..)`
  /// (the call may be parenthesised or cast).  The call is dumped for
  /// debugging and the s2e_concretize_fork statement is scheduled behind the
  /// assignment.  Assignments whose left-hand side is not a plain variable are
  /// skipped, since no variable name is available for the inserted call.
  bool VisitBinaryOperator(BinaryOperator* BinaryOp) {
    if (!BinaryOp->isAssignmentOp())
      return true;
    CallExpr *CallE = dyn_cast<CallExpr>(BinaryOp->getRHS()->IgnoreParenCasts());
    if (!CallE || !IsHeapAllocationCall(CallE))
      return true;

    DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(BinaryOp->getLHS());
    if (!DRE)
      return true;
    VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl());
    if (!VD)
      return true;

    declarations.insert(VD);
    //expressions.insert(BinaryOp);
    //CallE->dumpPretty(*astContext);
    CallE->dumpColor();
    llvm::outs() << "\n";
    //llvm::outs() << str;
    InstrumentStmtAfter(BinaryOp, BuildConcretizeCall(VD->getQualifiedNameAsString()));
    return true;
  }

  /// Recognises conditions of the form `p` or `!p`, where `p` names a
  /// variable (parentheses and implicit casts are looked through), and prints
  /// the variable's qualified name to llvm::outs().
  /// \returns true if the condition was such a simple boolean test.
  /// Not a RecursiveASTVisitor callback despite its name; it is called from
  /// VisitIfStmt only.
  bool VisitBooleanCondition(Expr *Cond) {
    Expr *Var = Cond->IgnoreParenImpCasts();
    // Handles if (p) or if (!p) cases.  Logical negation is UO_LNot; UO_Not
    // is the bitwise complement '~'.
    if (UnaryOperator *UnaryOp = dyn_cast<UnaryOperator>(Var)) {
      if (UnaryOp->getOpcode() != UO_LNot)
        return false;
      Var = UnaryOp->getSubExpr()->IgnoreParenImpCasts();
    }
    DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Var);
    if (!DRE)
      return false;
    VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl());
    if (!VD)
      return false;
    llvm::outs() << VD->getQualifiedNameAsString() << "\n";
    return true;
  }

  /// Inspects the condition of every `if`.  Simple boolean conditions are
  /// reported by VisitBooleanCondition; the analysis of comparison conditions
  /// below is disabled work in progress.
  /// NOTE(review): if the disabled part is revived, `BinaryOp` must be checked
  /// for null before use, and `Result++` / `Result--` yield the old value.
  bool VisitIfStmt(IfStmt* If) {
    Expr *Cond = If->getCond();
    if (!Cond || VisitBooleanCondition(Cond)) {
      return true;
    }
#if 0
    BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(Cond);
    if (!BinaryOp->isEqualityOp() &&
        !BinaryOp->isRelationalOp() &&
        !BinaryOp->isComparisonOp()) {
      return true;
    }
    Expr *Lhs = BinaryOp->getLHS();
    std::string name;
    VarDecl *VarD;
    // a == 5, p == NULL, s.x == 0
    if (ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(Lhs)) {
      Expr *OriginalCast = Cast->getSubExpr();
      if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(OriginalCast)) {
        if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
          VarD = VD;
          name = VD->getQualifiedNameAsString();
        }
      } else if (MemberExpr *Member = dyn_cast<MemberExpr>(OriginalCast)) {
        // Doesn't work
        ValueDecl *ValueD = Member->getMemberDecl();
        DeclarationNameInfo Name = Member->getMemberNameInfo();
        name = Name.getAsString();
        //llvm::outs() << "LHS is " << name << "\n";
        if (VarDecl *VD = dyn_cast<VarDecl>(Member->getMemberDecl())) {
          VarD = VD;
          name = VD->getQualifiedNameAsString();
        }
      }
    }
    //
    else if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Lhs)) {
      if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
        VarD = VD;
        name = VD->getQualifiedNameAsString();
      }
    }
    Expr *Rhs = BinaryOp->getRHS();
    if (isa<IntegerLiteral>(Rhs)) {
      IntegerLiteral *IntLit = cast<IntegerLiteral>(Rhs);
      llvm::APSInt Result;
      Rhs->EvaluateAsInt(Result, *astContext);
      //int64_t result = Result.getExtValue();
      llvm::APSInt concreteValue;
      BinaryOperatorKind opc = BinaryOp->getOpcode();
      switch (opc) {
        case (BO_LE) :
        case (BO_GE) :
        case (BO_EQ) : {
          concreteValue = Result;
          break;
        }
        case (BO_NE) : {
          concreteValue = Result++;
          break;
        }
        case (BO_LT) : {
          concreteValue = Result--;
          break;
        }
        case (BO_GT) : {
          concreteValue = Result++;
          break;
        }
        default : {
          break;
        }
      }
      llvm::outs() << "Set value " << concreteValue.toString(10) << "\n";
#if 0
      if (BinaryOp->isEqualityOp()) {
      } else if (BinaryOp->isComparisonOp()) {
      } else if (BinaryOp->isRelationalOp()) {
      }
#endif
      //llvm::outs() << "Integer Literal: " << result.toString(10) << "\n";
      //InstrumentStmtBefore(If);
    } else if (isa<CharacterLiteral>(Rhs)) {
      CharacterLiteral *CharLit = cast<CharacterLiteral>(Rhs);
    } else if (Rhs->isNullPointerConstant(*astContext, Expr::NPC_ValueDependentIsNotNull)) {
      // Works
      //InstrumentStmtBefore(If);
    }
#endif
    return true;
  }

  /// Inserts `str` in front of statement `s`.
  void InstrumentStmtBefore(Stmt *s, const std::string& str) {
    // getLocStart() equals getSourceRange().getBegin() for every statement
    // kind, compound or not, so a single code path suffices.
    rewriter.InsertText(s->getLocStart(), str, true, true);
  }

  /// Inserts `str` behind the ';' terminating statement `s`.
  /// Compound statements are left alone.  Nothing is inserted when the
  /// statement's last token is not directly followed by a ';'.
  void InstrumentStmtAfter(Stmt *s, const std::string& str) {
    // Only perform if statement is not compound
    if (isa<CompoundStmt>(s))
      return;
#if 0
    SourceLocation ST = s->getLocStart();
    // Insert opening brace. Note the second true parameter to InsertText()
    // says to indent. Sadly, it will indent to the line after the if, giving:
    // if (expr)
    // {
    // stmt;
    // }
    rewriter.InsertText(ST, "{\n", true, true);
    // Note Stmt::getLocEnd() returns the source location prior to the
    // token at the end of the line. For instance, for:
    // var = 123;
    // ^---- getLocEnd() points here.
#endif
    // getLocEnd() is the start of the statement's last token; the helper
    // locates the ';' that follows it.
    InsertAfterSemicolon(s->getLocEnd(), str);
  }

  /// Inserts `str` behind the ';' terminating declaration `d`.
  /// Nothing is inserted when the declaration's last token is not directly
  /// followed by a ';' (for example when several variables are declared in
  /// one statement).
  void InstrumentStmtAfter(Decl *d, const std::string& str) {
#if 0
    SourceLocation ST = s->getLocStart();
    // Insert opening brace. Note the second true parameter to InsertText()
    // says to indent. Sadly, it will indent to the line after the if, giving:
    // if (expr)
    // {
    // stmt;
    // }
    rewriter.InsertText(ST, "{\n", true, true);
    // Note Stmt::getLocEnd() returns the source location prior to the
    // token at the end of the line. For instance, for:
    // var = 123;
    // ^---- getLocEnd() points here.
#endif
    InsertAfterSemicolon(d->getLocEnd(), str);
  }

#if 0
  virtual bool VisitFunctionDecl(FunctionDecl *func) {
    numFunctions++;
    string funcName = func->getNameInfo().getName().getAsString();
    if (funcName == "do_math") {
      rewriter.ReplaceText(func->getLocation(), funcName.length(), "add5");
      errs() << "** Rewrote function def: " << funcName << "\n";
    }
    return true;
  }
  virtual bool VisitStmt(Stmt *st) {
    if (ReturnStmt *ret = dyn_cast<ReturnStmt>(st)) {
      rewriter.ReplaceText(ret->getRetValue()->getLocStart(), 6, "val");
      errs() << "** Rewrote ReturnStmt\n";
    }
    if (CallExpr *call = dyn_cast<CallExpr>(st)) {
      rewriter.ReplaceText(call->getLocStart(), 7, "add5");
      errs() << "** Rewrote function call\n";
    }
    return true;
  }
  // Override Binary Operator expressions
  virtual Expr *VisitBinaryOperator(BinaryOperator *E) {
    // Determine type of binary operator
    if (E->isLogicalOp()) {
      // Insert function call at start of first expression.
      // Note getLocStart() should work as well as getExprLoc()
      rewriter.InsertText(E->getLHS()->getExprLoc(),
                          E->getOpcode() == BO_LAnd ? "L_AND(" : "L_OR(", true);
      // Replace operator ("||" or "&&") with ","
      rewriter.ReplaceText(E->getOperatorLoc(), E->getOpcodeStr().size(), ",");
      // Insert closing paren at end of right-hand expression
      rewriter.InsertTextAfterToken(E->getRHS()->getLocEnd(), ")");
    } else
    // Note isComparisonOp() is like isRelationalOp() but includes == and !=
    if (E->isRelationalOp()) {
      llvm::errs() << "Relational Op " << E->getOpcodeStr() << "\n";
    } else
    // Handles == and != comparisons
    if (E->isEqualityOp()) {
      llvm::errs() << "Equality Op " << E->getOpcodeStr() << "\n";
    }
    return E;
  }
  /*
  virtual bool VisitReturnStmt(ReturnStmt *ret) {
  rewriter.ReplaceText(ret->getRetValue()->getLocStart(), 6, "val");
  errs() << "** Rewrote ReturnStmt\n";
  return true;
  }
  virtual bool VisitCallExpr(CallExpr *call) {
  rewriter.ReplaceText(call->getLocStart(), 7, "add5");
  errs() << "** Rewrote function call\n";
  return true;
  }
  */
#endif
};
/// AST consumer that runs an S2EInstrumentationVisitor over every top-level
/// declaration handed to it by the frontend.
class S2EInstrumentationASTConsumer : public ASTConsumer {
private:
  // Held by value so the visitor is destroyed together with the consumer
  // (it used to be allocated with `new` and never freed).
  S2EInstrumentationVisitor visitor; // doesn't have to be private
public:
  // override the constructor in order to pass CI
  /// \param CI compiler instance forwarded to the visitor's constructor.
  explicit S2EInstrumentationASTConsumer(CompilerInstance *CI)
      : visitor(CI) // initialize the visitor
  { }
#if 0
  // override this to call our ExampleVisitor on the entire source file
  virtual void HandleTranslationUnit(ASTContext &Context) {
    /* we can use ASTContext to get the TranslationUnitDecl, which is
    a single Decl that collectively represents the entire source file */
    visitor.TraverseDecl(Context.getTranslationUnitDecl());
    //visitor.TraverseStmt();
  }
#endif
  // override this to call our ExampleVisitor on each top-level Decl
  /// Traverses each declaration of the group with the visitor.
  /// \returns true always.
  virtual bool HandleTopLevelDecl(DeclGroupRef DG) {
    // a DeclGroupRef may have multiple Decls, so we iterate through each one
    for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; ++i) {
      Decl *D = *i;
      visitor.TraverseDecl(D); // recursively visit each AST node in Decl "D"
      //D->dump();
    }
    return true;
  }
};
/// Frontend action whose only job is to supply the S2E instrumentation
/// consumer for a compiler instance.
class S2EInstrumentationFrontendAction : public ASTFrontendAction {
public:
  /// Creates a fresh consumer bound to `CI`; `file` is not used.
  /// The consumer is returned as a raw heap pointer.
  /// NOTE(review): presumably the caller takes ownership - confirm against
  /// the Clang version in use.
  virtual ASTConsumer *CreateASTConsumer(CompilerInstance &CI, StringRef file) {
    CompilerInstance *instance = &CI; // the consumer expects a pointer
    ASTConsumer *consumer = new S2EInstrumentationASTConsumer(instance);
    return consumer;
  }
};
int main(int argc, const char **argv) {
// Parse the command-line args passed to your code
CommonOptionsParser op(argc, argv);
// Create a new Clang Tool instance (a LibTooling environment)
ClangTool Tool(op.getCompilations(), op.getSourcePathList());
languageOptions.GNUMode = 1;
languageOptions.CXXExceptions = 1;
languageOptions.RTTI = 1;
languageOptions.Bool = 1;
languageOptions.CPlusPlus = 1;
// Run the Clang Tool, creating a new FrontendAction (explained below)
//int result = Tool.run(newFrontendActionFactory<RewriteMacrosAction>());
int result = Tool.run(newFrontendActionFactory<S2EInstrumentationFrontendAction>());
//result = Tool.run(newFrontendActionFactory<S2EInstrumentationFrontendAction>());
// Print out the rewritten source code ("rewriter" is a global var.)
rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(errs());
//rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(outs());
return result;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment