Index: tools/clang/TaskToBind/TaskToBind.cpp |
diff --git a/tools/clang/TaskToBind/TaskToBind.cpp b/tools/clang/TaskToBind/TaskToBind.cpp |
new file mode 100644 |
index 0000000000000000000000000000000000000000..51770845e2942f42d691c55de8e0c0d55550679b |
--- /dev/null |
+++ b/tools/clang/TaskToBind/TaskToBind.cpp |
@@ -0,0 +1,199 @@ |
+// Copyright 2011 Google Inc. All Rights Reserved. |
+// Author: ajwong@google.com (Albert Wong) |
+ |
+#include <iostream> |
+#include <string> |
+ |
+#include "clang/AST/AST.h" |
+#include "clang/AST/ASTConsumer.h" |
+#include "clang/AST/Decl.h" |
+#include "clang/AST/DeclGroup.h" |
+#include "clang/AST/DeclVisitor.h" |
+#include "clang/AST/DeclarationName.h" |
+#include "clang/AST/ExprCXX.h" |
+#include "clang/AST/ParentMap.h" |
+#include "clang/AST/StmtVisitor.h" |
+#include "clang/AST/RecursiveASTVisitor.h" |
+#include "clang/AST/TypeLocVisitor.h" |
+#include "clang/Basic/Diagnostic.h" |
+#include "clang/Basic/IdentifierTable.h" |
+#include "clang/Basic/SourceManager.h" |
+#include "clang/Frontend/FrontendPluginRegistry.h" |
+#include "clang/Index/DeclReferenceMap.h" |
+#include "clang/Lex/Lexer.h" |
+#include "clang/Rewrite/ASTConsumers.h" |
+#include "clang/Rewrite/Rewriter.h" |
+#include "llvm/ADT/DenseSet.h" |
+#include "llvm/ADT/OwningPtr.h" |
+#include "llvm/ADT/SmallPtrSet.h" |
+#include "llvm/ADT/StringExtras.h" |
+#include "llvm/Support/MemoryBuffer.h" |
+#include "llvm/Support/raw_ostream.h" |
+#include "clang/Frontend/CompilerInstance.h" |
+#include "llvm/Support/raw_ostream.h" |
+ |
+using namespace clang; |
+using namespace std; |
+ |
+class PostTaskVisitor : public RecursiveASTVisitor<PostTaskVisitor> { |
+ public: |
+ PostTaskVisitor(ASTContext* context, Rewriter* rewriter) |
+ : context_(context), rewriter_(rewriter) { |
+ interesting_classes_.insert("MessageLoop"); |
+ interesting_classes_.insert("MessageLoopProxy"); |
+ } |
+ |
+ bool TraverseStmt(Stmt *S) { |
+ // Catch the MessageLoop and MessageLoopProxy calls. |
+ if (CXXMemberCallExpr* mce = dyn_cast_or_null<CXXMemberCallExpr>(S)) { |
+ if (IsPostTaskExpr(mce)) { |
+ if (mce->getNumArgs() < 2) { |
+ llvm::errs() << "PostTask with less than 2 args?! Inconceivable!"; |
Nico
2011/09/15 02:50:03
You can emit diagnostics instead, that way the com
awong
2011/09/16 02:50:41
Done.
|
+ return false; |
+ } |
+ MaybeRewriteNewRunnableMethod(mce->getArgs()[1]); |
+ } |
+ } |
+ return RecursiveASTVisitor<PostTaskVisitor>::TraverseStmt(S); |
+ } |
+ |
+ private: |
+ bool IsPostTaskExpr(CXXMemberCallExpr* call) { |
+ CXXMethodDecl* method_decl = call->getMethodDecl()->getCanonicalDecl(); |
+ if (kPostTaskName == method_decl->getNameAsString()) { |
Nico
2011/09/15 02:50:03
Alternatively, you could use the Resolver stuff in
awong
2011/09/16 02:50:41
Interesting. That is cleaner. Added TODO.
|
+ string classname = method_decl->getThisType(*context_) |
+ .getBaseTypeIdentifier()->getName(); |
+ llvm::outs() << "Found a posttask. Name of class: " |
+ << classname |
+ << "\n"; |
+ if (interesting_classes_.find(classname) != interesting_classes_.end()) { |
Nico
2011/09/15 02:50:03
nit: .count(classname) > 0
awong
2011/09/16 02:50:41
Done.
|
+ return true; |
+ } |
+ } |
+ return false; |
+ } |
+ |
+ bool MaybeRewriteNewRunnableMethod(Expr* post_task_arg) { |
+ // Strip implicit casts. |
Nico
2011/09/15 02:50:03
Instead:
post_task_arg = post_task_arg->IgnoreImp
awong
2011/09/16 02:50:41
Oh very nice. Went with IgnoreImplicit()->IgnoreP
|
+ if (ImplicitCastExpr* ice = dyn_cast<ImplicitCastExpr>(post_task_arg)) { |
+ return MaybeRewriteNewRunnableMethod(ice->getSubExpr()); |
+ } |
+ |
+ CallExpr* ce = dyn_cast<CallExpr>(post_task_arg); |
+ if (!ce) return false; |
+ FunctionDecl *fd = ce->getDirectCallee(); |
+ if (!fd) return false; |
+ if (kNewRunnableMethodName != |
+ fd->getNameInfo().getName().getAsIdentifierInfo()->getName()) |
+ return false; |
+ |
+ if (ce->getNumArgs() < 2) { |
+ llvm::errs() << "NewRunnableMethod with less than 2 args?! Inconceivable!"; |
+ return false; |
+ } |
+ |
+ // Okay, here's where it gets fun. We need to |
+ // (1) replace the NRM identifier text with base::Bind. |
+ // (2) Swap positions of the first two arguments. |
+ |
+ // NewRunnableMethod -> base::Bind |
+ rewriter_->ReplaceText(ce->getCallee()->getSourceRange(), "base::Bind"); |
+ |
+ // Swap the argument order. |
+ Expr* arg1 = ce->getArgs()[0]; |
+ Expr* arg2 = ce->getArgs()[1]; |
+ bool failure = rewriter_->ReplaceStmt(arg1, arg2); |
+ failure |= rewriter_->ReplaceStmt(arg2, arg1); |
+ |
+ // TODO(ajwong): |
+ // (3) Check the RunnableMethodTraits for the second argument, and |
+ // wrap base::Unretained() if it is declared with |
+ // DISABLE_RUNNABLE_METHOD_REFCOUNT. |
+ return failure; |
+ } |
+ |
+ ASTContext* context_; |
+ Rewriter* rewriter_; |
+ |
+ // Names for matching. |
+ set<string> interesting_classes_; |
+ static const char kPostTaskName[]; |
+ static const char kNewRunnableMethodName[]; |
+}; |
+ |
+const char PostTaskVisitor::kPostTaskName[] = "PostTask"; |
+const char PostTaskVisitor::kNewRunnableMethodName[] = "NewRunnableMethod"; |
+ |
+class TaskToBindConsumer : public ASTConsumer { |
+public: |
+ TaskToBindConsumer() { |
+ } |
+ |
+ virtual void Initialize(ASTContext &context) { |
+ rewriter_.setSourceMgr(context.getSourceManager(), context.getLangOptions()); |
+ } |
+ |
+ virtual void HandleTranslationUnit(ASTContext &context) { |
+ // modify calls |
+ TranslationUnitDecl *translation_unit = context.getTranslationUnitDecl(); |
+ for (DeclContext::decl_iterator it = translation_unit->decls_begin(), |
+ end = translation_unit->decls_end(); it != end; ++it) { |
Nico
2011/09/15 02:50:03
You could check
context.getSourceManager()->isF
awong
2011/09/16 02:50:41
Added a note. I think I'm going to try to make th
|
+ PostTaskVisitor visitor(&context, &rewriter_); |
+ visitor.TraverseDecl(*it); |
+ } |
+ |
+ // Get the buffer corresponding to MainFileID. |
+ // If we haven't changed it, then we are done. |
+ if (rewriter_.buffer_begin() != rewriter_.buffer_end()) { |
+ llvm::outs() << "Src file changed.\n"; |
+ } else { |
+ llvm::errs() << "No changes.\n"; |
+ } |
+ |
+ for (Rewriter::buffer_iterator it = rewriter_.buffer_begin(); |
+ it != rewriter_.buffer_end(); |
+ ++it) { |
+ llvm::outs() << std::string(it->second.begin(), it->second.end()); |
+ } |
+ |
+ std::cout.flush(); |
+ } |
+ |
+private: |
+ FileID source_file_; |
+ Rewriter rewriter_; |
+}; |
+ |
+class TaskToBindAction : public PluginASTAction { |
+protected: |
+ ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) { |
+ return new TaskToBindConsumer(); |
+ } |
+ |
+ bool ParseArgs(const CompilerInstance &CI, |
+ const std::vector<std::string>& args) { |
+ for (unsigned i = 0, e = args.size(); i != e; ++i) { |
+ llvm::errs() << "TaskToBind arg = " << args[i] << "\n"; |
+ |
+ // Example error handling. |
+ if (args[i] == "-an-error") { |
+ Diagnostic &D = CI.getDiagnostics(); |
+ unsigned DiagID = D.getCustomDiagID( |
+ Diagnostic::Error, "invalid argument '" + args[i] + "'"); |
+ D.Report(DiagID); |
+ return false; |
+ } |
+ } |
+ if (args.size() && args[0] == "help") |
+ PrintHelp(llvm::errs()); |
+ |
+ return true; |
+ } |
+ void PrintHelp(llvm::raw_ostream& ros) { |
+ ros << "Help for TaskToBind plugin goes here\n"; |
+ } |
+ |
+}; |
+ |
+static FrontendPluginRegistry::Add<TaskToBindAction> |
+X("convert-postask", "rewrite PostTask(NRM) to PostTask(Bind)"); |