Skip to content

Instantly share code, notes, and snippets.

@eqvinox
Created February 2, 2023 00:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save eqvinox/9a6c7810e07e8c3373d2ade106678918 to your computer and use it in GitHub Desktop.
Save eqvinox/9a6c7810e07e8c3373d2ade106678918 to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// FRR printfrr extensions type-checking plugin
//
// printfrr("%pI4", addr); -> warn if addr isn't a "struct in_addr *"
//
// WORK IN PROGRESS
#include <unordered_map>
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Type.h"
#include "clang/Basic/DiagnosticSema.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/Parse/ParseAST.h"
#include "clang/Parse/Parser.h"
#include "clang/Sema/Sema.h"
#include "clang/Sema/SemaDiagnostic.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Path.h"
using namespace llvm;
namespace clang {
namespace frr_format {
/* Typedef names treated as "opaque" at the system level.  Values of these
 * types should either be printed with a dedicated length modifier
 * (e.g. size_t -> %zd) or cannot be printed at all without a cast (pid_t
 * might, in principle, be int128_t).  The typedef chain is therefore not
 * followed past any name listed here.
 */
llvm::SmallSet<StringRef, 16> terminal_typedefs;

/* Parsed arguments of __attribute__((frr_format("...", fmt-idx, args-idx))),
 * stored as 0-based call argument positions.
 */
struct FRRFormatInfo {
	uint64_t fmt;		/* position of the format string argument */
	uint64_t args;		/* position of the first data argument;
				 * meaningless unless has_args is set */
	bool has_args;		/* false when args-idx was given as 0 */
};

/* Annotated functions, keyed by the declaration the attribute was attached
 * to; a hash lookup is cheaper than scanning the decl's attributes.
 */
std::unordered_map<const FunctionDecl *, FRRFormatInfo> format_funcs;
/* __attribute__(("printfrr", format-string-idx, args-idx)) */
class FRRFormatAttrInfo : public ParsedAttrInfo {
public:
FRRFormatAttrInfo() {
NumArgs = 3;
static constexpr Spelling S[] = {
{ParsedAttr::AS_GNU, "frr_format"},
};
Spellings = S;
}
bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr,
const Decl *D) const override {
if (!isa<FunctionDecl>(D)) {
S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str)
<< Attr << "functions";
return false;
}
return true;
}
AttrHandling handleDeclAttribute(Sema &S, Decl *D,
const ParsedAttr &Attr) const override {
const auto *TheFunc = dyn_cast_or_null<FunctionDecl>(D);
FRRFormatInfo fi;
if (Attr.getNumArgs() != 3) {
unsigned ID = S.getDiagnostics().getCustomDiagID(
DiagnosticsEngine::Error,
"'frr_format' attribute requires 3 arguments");
S.Diag(Attr.getLoc(), ID);
return AttributeNotApplied;
}
auto *Arg0 = Attr.getArgAsExpr(0);
StringLiteral *Literal =
dyn_cast_or_null<StringLiteral>(Arg0->IgnoreParenCasts());
if (!Literal) {
unsigned ID = S.getDiagnostics().getCustomDiagID(
DiagnosticsEngine::Error, "first argument to the 'frr_format' "
"attribute must be a string literal");
S.Diag(Attr.getLoc(), ID);
return AttributeNotApplied;
}
IntegerLiteral *LitFmt = dyn_cast_or_null<IntegerLiteral>(Attr.getArgAsExpr(1)->IgnoreParenCasts());
if (!LitFmt) {
unsigned ID = S.getDiagnostics().getCustomDiagID(
DiagnosticsEngine::Error, "second argument to the 'frr_format' "
"attribute must be an integer literal");
S.Diag(Attr.getLoc(), ID);
return AttributeNotApplied;
}
IntegerLiteral *LitArgs = dyn_cast_or_null<IntegerLiteral>(Attr.getArgAsExpr(2)->IgnoreParenCasts());
if (!LitArgs) {
unsigned ID = S.getDiagnostics().getCustomDiagID(
DiagnosticsEngine::Error, "third argument to the 'frr_format' "
"attribute must be an integer literal");
S.Diag(Attr.getLoc(), ID);
return AttributeNotApplied;
}
/* TODO: reject invalid (negative) values */
fi.fmt = LitFmt->getValue().getLimitedValue() - 1;
uint64_t args = LitArgs->getValue().getLimitedValue();
fi.args = args - 1;
fi.has_args = args > 0;
format_funcs[TheFunc] = fi;
return AttributeApplied;
}
};
static ParsedAttrInfoRegistry::Add<FRRFormatAttrInfo> Z("frr_format", "FRR printf extensions attribute");
/* #pragma FRR printfrr_ext "%pI4" (struct in_addr *) */
class FRRFormatPragmaHandler : public PragmaHandler {
public:
FRRFormatPragmaHandler() : PragmaHandler("FRR") { }
void HandlePragma(Preprocessor &PP, PragmaIntroducer Introducer,
Token &PragmaTok) {
Token Tok;
StringRef name;
PP.Lex(Tok);
if (Tok.isNot(tok::identifier)) {
PP.Diag(Tok.getLocation(), diag::err_expected) << tok::identifier;
return;
}
name = Tok.getIdentifierInfo()->getName();
if (!name.equals("printfrr_ext")) {
PP.Diag(Tok.getLocation(), diag::err_expected) << tok::string_literal;
return;
}
std::string fmtspec;
if (!PP.LexStringLiteral(Tok, fmtspec, "pragma FRR printfrr_ext", true)) {
PP.Diag(Tok.getLocation(), diag::err_expected) << tok::string_literal;
return;
}
// PP.Lex(Tok); - implicit in LexStringLiteral
if (Tok.isNot(tok::l_paren)) {
PP.Diag(Tok.getLocation(), diag::err_expected) << tok::l_paren;
return;
}
llvm::errs() << "#pragma \"" << fmtspec << "\"\n";
for (PP.Lex(Tok); Tok.isNot(tok::r_paren); PP.Lex(Tok)) {
/* TODO: parse type? */
}
}
};
static PragmaHandlerRegistry::Add<FRRFormatPragmaHandler> Y("FRR", "FRR pragmas");
/* find printfrr calls
 *
 * Walks the AST for calls to functions annotated with frr_format.  For
 * each such call with a string-literal format, the conversions in the
 * format are matched against the data arguments.
 */
class FRRFormatVisitor : public RecursiveASTVisitor<FRRFormatVisitor> {
public:
	FRRFormatVisitor(ASTContext &context)
		: Context(context), Diags(context.getDiagnostics())
	{
		WarningInvalidFormatString = Diags.getCustomDiagID(
			DiagnosticsEngine::Error,
			"invalid format string");
		WarningInvalidFormatSpecifier = Diags.getCustomDiagID(
			DiagnosticsEngine::Warning,
			"invalid format specifier '%0'");
		/* '%%' renders as a single '%' in diagnostic text */
		WarningTooFewArgs = Diags.getCustomDiagID(
			DiagnosticsEngine::Warning,
			"more '%%' conversions than data arguments");
	}

	/* how often each length modifier was seen in the current conversion */
	struct flag_state {
		unsigned L;
		unsigned h;
		unsigned j;
		unsigned l;
		unsigned q;
		unsigned t;
		unsigned z;
	};

	/* WIP: dump the argument's type and walk its typedef chain, stopping
	 * at names in terminal_typedefs.  Nothing is diagnosed yet, and
	 * `flags` is not consulted yet. */
	void check_int_arg(char specifier, flag_state &flags, const Expr *arg)
	{
		const QualType QT = arg->getType();
		const Type *T = QT.getTypePtr();
		const Type *PT, *nextPT = T;

		llvm::errs() << "intcheck: specifier %" << specifier << " got type: (";
		// arg->print(llvm::errs());
		QT.print(llvm::errs(), Context.getPrintingPolicy());
		llvm::errs() << "):\n";
		T->dump();

		while (nextPT) {
			PT = nextPT;
			nextPT = nullptr;

			const TypedefType *tt = PT->getAs<TypedefType>();
			if (tt) {
				auto decl = tt->getDecl();
				auto tdname = decl->getDeclName().getAsString();

				llvm::errs() << "typedef type: " << tdname << "\n";
				if (terminal_typedefs.contains(tdname)) {
					llvm::errs() << " -- terminal\n";
					break;
				}
				nextPT = decl->getUnderlyingType().getTypePtr();
			}
		}
		llvm::errs() << "\n";
	}

	/* Scan the format bytes [start, end) and match each conversion
	 * against argv[0..argc).  argv is NULL when the function has no data
	 * arguments to check (args-idx 0, e.g. va_list variants); then only
	 * the format string itself is validated. */
	void check_printfrr(CallExpr *ce, const Expr *fe, const char *start,
			    const char *end, const Expr *const *argv,
			    unsigned argc)
	{
		unsigned argpos = 0;
		bool reported_too_few = false;
		const char *pos = start;

		/* Consume one data argument.  Returns NULL if arguments are
		 * not being checked or have run out; running out is reported
		 * once per call. */
		auto next_arg = [&]() -> const Expr * {
			unsigned idx = argpos++;

			if (!argv)
				return nullptr;
			if (idx >= argc) {
				if (!reported_too_few)
					Diags.Report(fe->getExprLoc(), WarningTooFewArgs);
				reported_too_few = true;
				return nullptr;
			}
			return argv[idx];
		};

		while (pos < end) {
			flag_state flags = {};

			pos = (const char *)memchr(pos, '%', end - pos);
			if (!pos)
				break;
			pos++;

			do {
				if (pos == end || *pos == '\0') {
					Diags.Report(fe->getExprLoc(), WarningInvalidFormatString);
					return;
				}

				static const char intspecs[] = "dDiuoOuUxX";
				static const char terminal[] = "cCaAeEfFgGmpsS";

				switch (*pos++) {
				case '%':
					/* "%%" is a literal percent sign, no argument */
					break;
				case ' ':
				case '#':
				case '-':
				case '+':
				case '\'':
				case '.':
				case '0' ... '9':
					continue;
				case '*':
					/* width/precision is taken from an int
					 * argument, which must be consumed to keep
					 * later arguments aligned.
					 * TODO: check that it is an int */
					next_arg();
					continue;
				case 'L': flags.L++; continue;
				case 'h': flags.h++; continue;
				case 'j': flags.j++; continue;
				case 'l': flags.l++; continue;
				case 'q': flags.q++; continue;
				case 't': flags.t++; continue;
				case 'z': flags.z++; continue;
				case 'm':
					/* no arg for %m */
					break;
				case 'p':
					/* TODO: format check %p extensions */
					next_arg();
					break;
				case 'd':
				case 'i':
					/* TODO: format check %d/%i extensions */
					/* FALLTHRU */
				default:
					if (strchr(intspecs, pos[-1])) {
						/* check base int types */
						if (const Expr *arg = next_arg())
							check_int_arg(pos[-1], flags, arg);
					} else if (strchr(terminal, pos[-1])) {
						/* skip checking these */
						next_arg();
					} else {
						char tmp[2] = { pos[-1], '\0' };
						Diags.Report(fe->getExprLoc(), WarningInvalidFormatSpecifier)
							<< tmp;
					}
				}
				break;
			} while (1);
		}
	}

	bool VisitCallExpr(CallExpr *ce)
	{
		const FunctionDecl *callee = ce->getDirectCallee();
		FRRFormatInfo fi = {};

		/* printfrr attribute may have been applied to earlier decl */
		for (; callee; callee = callee->getPreviousDecl()) {
			auto it = format_funcs.find(callee);
			if (it != format_funcs.end()) {
				fi = it->second;
				break;
			}
		}
		if (!callee)
			return true;

		unsigned argc = ce->getNumArgs();
		if (argc <= fi.fmt) {
			Diags.Report(ce->getRParenLoc(), WarningInvalidFormatString);
			return true;
		}

		const Expr *fmtexpr = ce->getArg(fi.fmt);
		const auto *fmtlit = dyn_cast<StringLiteral>(fmtexpr->IgnoreParenImpCasts());
		if (!fmtlit) {
			//Diags.Report(fmtexpr->getExprLoc(), WarningInvalidFormatString);
			return true;
		}
		/* getString() asserts on wide/UTF-16/UTF-32 literals */
		if (fmtlit->getCharByteWidth() != 1)
			return true;

		StringRef fmtstrref = fmtlit->getString();
		const char *fmtbytes = fmtstrref.data();

		/* The literal's array type counts the terminating NUL, which is
		 * not part of the stored string data; the literal may also be
		 * truncated by its type.  As in clang's own CheckFormatString,
		 * scan min(TypeSize - 1, string size) bytes so we never read
		 * past the literal's storage. */
		const ConstantArrayType *T =
			Context.getAsConstantArrayType(fmtlit->getType());
		assert(T && "String literal not of constant array type!");
		size_t typesize = T->getSize().getZExtValue();
		size_t fmtsize = typesize ? typesize - 1 : 0;
		if (fmtsize > fmtstrref.size())
			fmtsize = fmtstrref.size();

		const Expr *const *args = nullptr;
		if (fi.has_args) {
			/* clamp so neither the pointer nor the count can run
			 * past the actual call arguments */
			uint64_t first = fi.args < argc ? fi.args : argc;

			args = ce->getArgs() + first;
			argc -= first;
		} else
			argc = 0;

		check_printfrr(ce, fmtexpr, fmtbytes, fmtbytes + fmtsize, args, argc);
		return true;
	}

private:
	ASTContext &Context;
	DiagnosticsEngine &Diags;
	unsigned WarningInvalidFormatString;
	unsigned WarningInvalidFormatSpecifier;
	unsigned WarningTooFewArgs;
};
/* Plumbing: once a translation unit has been fully parsed, run
 * FRRFormatVisitor over it.  (Instance is kept but not used yet.)
 */
class FRRFormatConsumer : public ASTConsumer {
public:
	FRRFormatConsumer(CompilerInstance &Instance) : Instance(Instance)
	{
		// CI.getPreprocessor().addPPCallbacks(std::make_unique<FRRFormatPPCallbacks>());
	}

	void Initialize(ASTContext &Context) override
	{
		llvm::errs() << "Consumer Initialize\n";
	}

	void HandleTranslationUnit(ASTContext &context) override
	{
		llvm::errs() << "Consumer Handle TU\n";
		FRRFormatVisitor(context).TraverseDecl(context.getTranslationUnitDecl());
	}

private:
	CompilerInstance &Instance;
};
/* Frontend action: seeds terminal_typedefs, then attaches an
 * FRRFormatConsumer that runs after the main compile action.
 */
class FRRFormatPluginAction : public PluginASTAction {
public:
	std::unique_ptr<clang::ASTConsumer>
	CreateASTConsumer(clang::CompilerInstance &CI, StringRef InFile) override
	{
		/* system typedefs that the integer check must not look through */
		static const char *const opaque[] = {
			"atomic_size_t",
			"atomic_ssize_t",
			"size_t",
			"ssize_t",
			"ptrdiff_t",
			"pid_t",
			"uid_t",
			"gid_t",
			"time_t",
		};

		for (const char *tdname : opaque)
			terminal_typedefs.insert(tdname);

		return std::make_unique<FRRFormatConsumer>(CI);
	}

	PluginASTAction::ActionType getActionType() override
	{
		/* run in addition to, and after, the normal compile */
		return AddAfterMainAction;
	}

	bool ParseArgs(const CompilerInstance &CI,
		       const std::vector<std::string> &Args) override
	{
		/* no plugin arguments are recognized; accept anything */
		return true;
	}
};
} // namespace frr_format
} // namespace clang
/* Exported, non-optimizable global.  Presumably an "anchor" symbol that a
 * static link can reference to force this object file (and the static
 * registrations in it) to be pulled in.  TODO confirm against the build. */
volatile int FRRFormatPluginAnchorSource = 0;
/* Registers FRRFormatPluginAction with clang's frontend plugin registry
 * under the name "frr-format".  That name is what selects the plugin on
 * the command line (e.g. -Xclang -add-plugin -Xclang frr-format; verify
 * against the build system). */
static clang::FrontendPluginRegistry::Add<
clang::frr_format::FRRFormatPluginAction>
X("frr-format", "check FRR printf extensions");
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment