Skip to content

Instantly share code, notes, and snippets.

@lhchavez
Created January 3, 2017 14:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lhchavez/b0d27de062d01b6ec224f63bef38a694 to your computer and use it in GitHub Desktop.
Save lhchavez/b0d27de062d01b6ec224f63bef38a694 to your computer and use it in GitHub Desktop.
WIP of an IWYU ClangTool
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Path.h"
#include <memory>
#include <unordered_map>
#include <vector>
using namespace clang;
using namespace clang::tooling;
using namespace llvm;
// Bookkeeping for one named type referenced from the main file. Instances are
// stored in a map keyed by the qualified type name and are created with
// positional aggregate initialisation, so the member order is significant.
struct TypeSourceInformation {
// Qualified name of the declaration the type resolves to (same as the map key).
std::string TypeName;
// Location of the first use of the type that the visitor recorded.
FullSourceLoc FirstUseLoc;
// Start location of the declaration that names the type.
FullSourceLoc DeclarationLoc;
// Location of the first non-pointer, non-reference use; only meaningful when
// NeedFullType is true (it is an invalid location otherwise).
FullSourceLoc FirstFullUseLoc;
// True once a use by value has been seen, i.e. one whose type is neither a
// pointer nor a reference.
bool NeedFullType;
};
// Bookkeeping for one macro expanded in the main file. Instances are created
// with positional aggregate initialisation, so the member order is significant.
struct MacroExpansionInformation {
// Spelling of the macro name (same as the key of the macro map).
std::string MacroName;
// Location of the first recorded expansion of the macro.
// NOTE: spelled with a lowercase 'l', unlike TypeSourceInformation::FirstUseLoc.
FullSourceLoc FirstUseloc;
// Location of the directive that defines the macro.
FullSourceLoc DeclarationLoc;
};
// One #include directive seen in the main file or in its associated header.
// Instances are created with positional aggregate initialisation, so the
// member order is significant.
struct IncludeInformation {
// File name as reported by the preprocessor for the directive.
std::string FileName;
// IsAngled flag as reported by the preprocessor for the directive.
bool IsAngled;
// True when the directive is located in the main file.
bool FromMainFile;
// True when the directive is located in the main file's associated header
// (and that header is not the main file itself).
bool FromMainHeaderFile;
// Location of the '#' that starts the directive.
FullSourceLoc IncludeLoc;
};
// Option category handed to CommonOptionsParser in main(). Applying a custom
// category to the tool's command-line options makes them the only ones that
// are displayed in the help output.
static cl::OptionCategory MyToolCategory("clang-iwyu options");

// Registers the help text that CommonOptionsParser provides for the options
// shared by all Clang tools.
static cl::extrahelp CommonHelp(CommonOptionsParser::HelpMessage);
/// Resolves a type to the declaration that names it.
///
/// Pointers, references and arrays are peeled off recursively; records and
/// enums map to their declarations; a typedef that is not a record, enum,
/// pointer, reference or array maps to the typedef declaration; elaborated
/// types and template specialisations are unwrapped.
///
/// \param QT the type to resolve; may be a null QualType.
/// \returns the naming declaration, or nullptr when the type is null or does
///          not resolve to a named declaration (e.g. builtin types).
static const NamedDecl *getNamedDecl(QualType QT) {
  // A null QualType has no underlying Type object, so every `QT->` below would
  // dereference a null pointer. Treat it like any other unnamed type.
  if (QT.isNull())
    return nullptr;

  // Indirections and arrays: the interesting declaration is the one behind
  // the pointee / element type.
  if (QT->isPointerType() || QT->isReferenceType())
    return getNamedDecl(QT->getPointeeType());
  if (QT->isArrayType())
    return getNamedDecl(QT->castAsArrayTypeUnsafe()->getElementType());

  // Types that directly carry a declaration.
  if (QT->isRecordType())
    return QT->getAs<RecordType>()->getDecl();
  if (QT->isEnumeralType())
    return QT->getAs<EnumType>()->getDecl();
  if (QT->getTypeClass() == Type::Typedef)
    return QT->getAs<TypedefType>()->getDecl();

  // Sugar that only wraps another type.
  if (const auto *Elab = dyn_cast<ElaboratedType>(QT.getTypePtr()))
    return getNamedDecl(Elab->getNamedType());

  if (const auto *TST = dyn_cast<TemplateSpecializationType>(QT.getTypePtr())) {
    if (TST->isTypeAlias())
      return getNamedDecl(TST->getAliasedType());
    if (TST->isSugared())
      return getNamedDecl(TST->desugar());
    // No underlying type to look through: fall back to the template itself.
    // getAsTemplateDecl() yields nullptr when the name has no declaration.
    return TST->getTemplateName().getAsTemplateDecl();
  }

  return nullptr;
}
/// AST visitor that fills a map of the named types referenced from the main
/// file. For each type it records the first use, the declaration location and
/// whether a use by value (neither pointer nor reference) has been seen,
/// together with the location of the first such use.
class ClangIwyuVisitor : public RecursiveASTVisitor<ClangIwyuVisitor> {
public:
  /// \param Context       AST context used to build locations and to reach
  ///                      the SourceManager; not owned.
  /// \param TypeSourceMap output map keyed by qualified type name; not owned.
  ClangIwyuVisitor(ASTContext *Context,
                   std::unordered_map<std::string, TypeSourceInformation>
                       *const TypeSourceMap)
      : Context(Context), TypeSourceMap(TypeSourceMap) {}

  /// Records the first use of every type written in the main file. Does not
  /// touch NeedFullType.
  bool VisitTypeLoc(TypeLoc TL) {
    if (!Context->getSourceManager().isInMainFile(TL.getLocStart()))
      return true;
    if (const NamedDecl *ND = getNamedDecl(TL.getType()))
      AddQualType(ND, TL.getLocStart());
    return true;
  }

  // The function-like declarations are traversed through
  // TraverseFunctionHelper so that VisitValueDecl can tell whether it is
  // looking at a declaration without a body.
  bool TraverseFunctionDecl(FunctionDecl *Decl) {
    return TraverseFunctionHelper(
        Decl, &RecursiveASTVisitor<ClangIwyuVisitor>::TraverseFunctionDecl);
  }

  bool TraverseCXXMethodDecl(CXXMethodDecl *Decl) {
    return TraverseFunctionHelper(
        Decl, &RecursiveASTVisitor<ClangIwyuVisitor>::TraverseCXXMethodDecl);
  }

  bool TraverseCXXConstructorDecl(CXXConstructorDecl *Decl) {
    return TraverseFunctionHelper(
        Decl,
        &RecursiveASTVisitor<ClangIwyuVisitor>::TraverseCXXConstructorDecl);
  }

  bool TraverseCXXConversionDecl(CXXConversionDecl *Decl) {
    return TraverseFunctionHelper(
        Decl,
        &RecursiveASTVisitor<ClangIwyuVisitor>::TraverseCXXConversionDecl);
  }

  bool TraverseCXXDestructorDecl(CXXDestructorDecl *Decl) {
    return TraverseFunctionHelper(
        Decl,
        &RecursiveASTVisitor<ClangIwyuVisitor>::TraverseCXXDestructorDecl);
  }

  /// Records the type of a value declaration (variable, parameter, field,
  /// function, ...) located in the main file. Declarations visited while
  /// inside a body-less function declaration are skipped entirely.
  bool VisitValueDecl(const ValueDecl *VD) {
    if (InEmptyFunctionDecl)
      return true;
    return RecordTypeUse(VD->getType(), VD->getLocStart());
  }

  // NOTE(review): the three hooks below go through RecordTypeUse, which never
  // sets NeedFullType for an expression of pointer type. If the operand of
  // unary `*` and the left-hand side of `->*` always have pointer type (as one
  // would expect), those two hooks only ever record a first use - confirm
  // whether that is the intended behaviour.
  bool VisitUnaryDeref(const UnaryOperator *Op) {
    return VisitExprHelper(Op->getSubExpr());
  }

  bool VisitBinPtrMemD(const BinaryOperator *Op) {
    return VisitExprHelper(Op->getLHS());
  }

  bool VisitBinPtrMemI(const BinaryOperator *Op) {
    return VisitExprHelper(Op->getLHS());
  }

private:
  /// Records the type of an expression located in the main file.
  bool VisitExprHelper(const Expr *E) {
    return RecordTypeUse(E->getType(), E->getLocStart());
  }

  /// Shared implementation for declarations and expressions: registers the
  /// declaration behind \p QT and, when \p QT is neither a pointer nor a
  /// reference, marks the type as needing its full definition, blaming
  /// \p Loc as the first such use.
  ///
  /// \returns always true, so that the AST traversal continues.
  bool RecordTypeUse(QualType QT, SourceLocation Loc) {
    SourceManager &SM = Context->getSourceManager();
    if (!SM.isInMainFile(Loc))
      return true;

    const NamedDecl *ND = getNamedDecl(QT);
    if (!ND)
      return true;

    TypeSourceInformation *TSI = AddQualType(ND, Loc);
    if (!TSI) {
      // AddQualType refuses entries whose use or declaration location is
      // invalid; report where that happened. The two cases are printed
      // separately because mixing the unsigned line number with -1 in one
      // conditional expression would print -1 as a huge unsigned value.
      FullSourceLoc UseLoc = Context->getFullLoc(Loc);
      llvm::outs() << "Could not find type at ";
      if (UseLoc.isValid())
        llvm::outs() << SM.getFilename(UseLoc) << ":"
                     << UseLoc.getSpellingLineNumber();
      else
        llvm::outs() << "<unk>"
                     << ":"
                     << "-1";
      llvm::outs() << "\n";
      return true;
    }

    // Only the first full use is recorded.
    if (TSI->NeedFullType)
      return true;
    // Pointers and references can make do with a forward declaration.
    if (QT->isPointerType() || QT->isReferenceType())
      return true;

    TSI->FirstFullUseLoc = Context->getFullLoc(Loc);
    TSI->NeedFullType = true;
    return true;
  }

  /// Runs the base-class traversal \p Callback on \p Func. For functions in
  /// the main file, InEmptyFunctionDecl reflects whether \p Func lacks a body
  /// for the duration of the traversal; the previous value is restored
  /// afterwards so that nested function traversals cannot clobber the state
  /// of the enclosing one.
  ///
  /// NOTE(review): hasBody() is documented by Clang as also being satisfied
  /// by a body on another redeclaration - confirm that this is what is wanted
  /// for prototypes whose definition appears elsewhere.
  template <typename T, typename Functor>
  bool TraverseFunctionHelper(T *Func, Functor Callback) {
    SourceManager &SM = Context->getSourceManager();
    if (!SM.isInMainFile(Func->getLocStart()))
      return (this->*Callback)(Func);

    const bool SavedInEmptyFunctionDecl = InEmptyFunctionDecl;
    InEmptyFunctionDecl = !Func->hasBody();
    const bool ReturnValue = (this->*Callback)(Func);
    InEmptyFunctionDecl = SavedInEmptyFunctionDecl;
    return ReturnValue;
  }

  /// Looks up the entry for \p ND, creating it on first use with \p LocStart
  /// as the first-use location.
  ///
  /// \returns the map entry, or nullptr when a new entry would be needed but
  ///          the use location or the declaration location is invalid.
  TypeSourceInformation *AddQualType(const NamedDecl *ND,
                                     SourceLocation LocStart) {
    std::string QualTypeName = ND->getQualifiedNameAsString();

    auto Existing = TypeSourceMap->find(QualTypeName);
    if (Existing != TypeSourceMap->end())
      return &Existing->second;

    FullSourceLoc UseLoc = Context->getFullLoc(LocStart);
    FullSourceLoc DeclLoc = Context->getFullLoc(ND->getLocStart());
    if (!UseLoc.isValid() || !DeclLoc.isValid())
      return nullptr;

    // insert() hands back an iterator to the new element, which avoids a
    // second hash lookup through operator[].
    auto Inserted = TypeSourceMap->insert(std::make_pair(
        QualTypeName, TypeSourceInformation{QualTypeName, UseLoc, DeclLoc,
                                            FullSourceLoc(), false}));
    return &Inserted.first->second;
  }

  // True while traversing a main-file function declaration without a body.
  bool InEmptyFunctionDecl = false;
  ASTContext *const Context;
  std::unordered_map<std::string, TypeSourceInformation> *const TypeSourceMap;
};
class ClangIwyuConsumer : public ASTConsumer {
public:
ClangIwyuConsumer(ASTContext *Context,
std::unordered_map<std::string, TypeSourceInformation>
*const TypeSourceMap)
: Visitor(Context, TypeSourceMap) {}
void HandleTranslationUnit(ASTContext &Context) override {
Visitor.TraverseDecl(Context.getTranslationUnitDecl());
}
private:
ClangIwyuVisitor Visitor;
};
class IncludeMacroCallbacks : public PPCallbacks {
public:
IncludeMacroCallbacks(
SourceManager *SourceMgr, ASTContext *const Context,
std::unordered_map<std::string, MacroExpansionInformation>
*const MacroExpansionMap,
std::vector<IncludeInformation> *const IncludeList)
: SourceMgr(SourceMgr), Context(Context),
MacroExpansionMap(MacroExpansionMap), IncludeList(IncludeList),
MainFileID(SourceMgr->getMainFileID()) {
SourceLocation Loc = SourceMgr->getLocForStartOfFile(MainFileID);
SmallString<256> CanonicalNameBuf(SourceMgr->getFilename(Loc));
llvm::sys::fs::make_absolute(CanonicalNameBuf);
llvm::sys::path::native(CanonicalNameBuf);
llvm::sys::path::remove_dots(CanonicalNameBuf, /* remove_dot_dot */ true);
llvm::sys::path::replace_extension(CanonicalNameBuf, ".h");
HeaderFilename = CanonicalNameBuf.str();
}
void FileChanged(SourceLocation Loc, FileChangeReason Reason,
SrcMgr::CharacteristicKind FileType,
FileID PrevFID = FileID()) override {
if (Reason != EnterFile)
return;
SmallString<256> CanonicalNameBuf(SourceMgr->getFilename(Loc));
llvm::sys::fs::make_absolute(CanonicalNameBuf);
llvm::sys::path::native(CanonicalNameBuf);
llvm::sys::path::remove_dots(CanonicalNameBuf, /* remove_dot_dot */ true);
std::string Filename = CanonicalNameBuf.str();
if (Filename == HeaderFilename)
MainHeaderFileID = SourceMgr->getFileID(Loc);
}
void MacroExpands(const Token &MacroNameTok, const MacroDefinition &MD,
SourceRange Range, const MacroArgs *Args) override {
if (!SourceMgr->isInMainFile(MacroNameTok.getLocation()))
return;
if (MD.getMacroInfo()->isBuiltinMacro())
return;
std::string MacroName = MacroNameTok.getIdentifierInfo()->getName();
if (MacroExpansionMap->find(MacroName) != MacroExpansionMap->end())
return;
FullSourceLoc UseLoc = Context->getFullLoc(MacroNameTok.getLocation());
FullSourceLoc DeclLoc =
Context->getFullLoc(MD.getLocalDirective()->getLocation());
if (!UseLoc.isValid() || !DeclLoc.isValid())
return;
MacroExpansionMap->insert(std::make_pair(
MacroName, MacroExpansionInformation{MacroName, UseLoc, DeclLoc}));
}
void InclusionDirective(SourceLocation HashLoc, const Token &IncludeTok,
StringRef FileName, bool IsAngled,
CharSourceRange FilenameRange, const FileEntry *File,
StringRef SearchPath, StringRef RelativePath,
const clang::Module *Imported) override {
FileID FID = SourceMgr->getFileID(HashLoc);
if (FID != MainFileID && FID != MainHeaderFileID)
return;
FullSourceLoc IncludeLoc = Context->getFullLoc(HashLoc);
if (!IncludeLoc.isValid())
return;
IncludeList->emplace_back(IncludeInformation{
FileName, IsAngled, FID == MainFileID,
MainFileID != MainHeaderFileID && FID == MainHeaderFileID, IncludeLoc});
}
private:
SourceManager *const SourceMgr;
ASTContext *const Context;
std::unordered_map<std::string, MacroExpansionInformation>
*const MacroExpansionMap;
std::vector<IncludeInformation> *const IncludeList;
std::string HeaderFilename;
const FileID MainFileID;
FileID MainHeaderFileID;
};
class ClangIwyuAction : public ASTFrontendAction {
public:
void EndSourceFileAction() override {
SourceManager &SM = getCompilerInstance().getSourceManager();
for (const auto &Include : IncludeList) {
llvm::outs() << "include " << Include.FileName
<< (Include.FromMainFile ? " from main file" : "")
<< (Include.FromMainHeaderFile ? " from header file" : "")
<< "\n";
}
for (const auto &ME : MacroExpansionMap) {
llvm::outs() << "macro " << ME.first << " from "
<< SM.getFilename(ME.second.DeclarationLoc) << ":"
<< ME.second.DeclarationLoc.getSpellingLineNumber() << "\n";
}
for (const auto &TS : TypeSourceMap) {
llvm::outs() << TS.first << " " << TS.second.NeedFullType
<< SM.getFilename(TS.second.DeclarationLoc) << ":"
<< TS.second.DeclarationLoc.getSpellingLineNumber() << "\n";
if (TS.second.NeedFullType) {
llvm::outs() << "\tBlame " << SM.getFilename(TS.second.FirstFullUseLoc)
<< ":" << TS.second.FirstFullUseLoc.getSpellingLineNumber()
<< "\n";
}
}
}
std::unique_ptr<ASTConsumer>
CreateASTConsumer(CompilerInstance &Compiler,
llvm::StringRef InFile) override {
TypeSourceMap.clear();
MacroExpansionMap.clear();
IncludeList.clear();
Compiler.getDiagnostics().setClient(new IgnoringDiagConsumer());
Preprocessor &PP = Compiler.getPreprocessor();
SourceManager &SM = Compiler.getSourceManager();
PP.addPPCallbacks(std::unique_ptr<PPCallbacks>(new IncludeMacroCallbacks(
&SM, &Compiler.getASTContext(), &MacroExpansionMap, &IncludeList)));
return std::unique_ptr<ASTConsumer>(
new ClangIwyuConsumer(&Compiler.getASTContext(), &TypeSourceMap));
}
private:
std::unordered_map<std::string, TypeSourceInformation> TypeSourceMap;
std::unordered_map<std::string, MacroExpansionInformation> MacroExpansionMap;
std::vector<IncludeInformation> IncludeList;
};
int main(int argc, const char **argv) {
// CommonOptionsParser constructor will parse arguments and create a
// CompilationDatabase. In case of error it will terminate the program.
CommonOptionsParser OptionsParser(argc, argv, MyToolCategory);
ClangTool Tool(OptionsParser.getCompilations(),
OptionsParser.getSourcePathList());
std::unique_ptr<FrontendActionFactory> FrontendFactory =
newFrontendActionFactory<ClangIwyuAction>();
return Tool.run(FrontendFactory.get());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment