OLD | NEW |
---|---|
(Empty) | |
1 // Copyright 2011 Google Inc. All Rights Reserved. | |
2 // Author: ajwong@google.com (Albert Wong) | |
3 | |
4 #include <iostream> | |
5 #include <string> | |
6 | |
7 #include "clang/AST/AST.h" | |
8 #include "clang/AST/ASTConsumer.h" | |
9 #include "clang/AST/Decl.h" | |
10 #include "clang/AST/DeclGroup.h" | |
11 #include "clang/AST/DeclVisitor.h" | |
12 #include "clang/AST/DeclarationName.h" | |
13 #include "clang/AST/ExprCXX.h" | |
14 #include "clang/AST/ParentMap.h" | |
15 #include "clang/AST/StmtVisitor.h" | |
16 #include "clang/AST/RecursiveASTVisitor.h" | |
17 #include "clang/AST/TypeLocVisitor.h" | |
18 #include "clang/Basic/Diagnostic.h" | |
19 #include "clang/Basic/IdentifierTable.h" | |
20 #include "clang/Basic/SourceManager.h" | |
21 #include "clang/Frontend/FrontendPluginRegistry.h" | |
22 #include "clang/Index/DeclReferenceMap.h" | |
23 #include "clang/Lex/Lexer.h" | |
24 #include "clang/Rewrite/ASTConsumers.h" | |
25 #include "clang/Rewrite/Rewriter.h" | |
26 #include "llvm/ADT/DenseSet.h" | |
27 #include "llvm/ADT/OwningPtr.h" | |
28 #include "llvm/ADT/SmallPtrSet.h" | |
29 #include "llvm/ADT/StringExtras.h" | |
30 #include "llvm/Support/MemoryBuffer.h" | |
31 #include "llvm/Support/raw_ostream.h" | |
32 #include "clang/Frontend/CompilerInstance.h" | |
33 #include "llvm/Support/raw_ostream.h" | |
34 | |
35 using namespace clang; | |
36 using namespace std; | |
37 | |
38 class PostTaskVisitor : public RecursiveASTVisitor<PostTaskVisitor> { | |
39 public: | |
40 PostTaskVisitor(ASTContext* context, Rewriter* rewriter) | |
41 : context_(context), rewriter_(rewriter) { | |
42 interesting_classes_.insert("MessageLoop"); | |
43 interesting_classes_.insert("MessageLoopProxy"); | |
44 } | |
45 | |
46 bool TraverseStmt(Stmt *S) { | |
47 // Catch the MessageLoop and MessageLoopProxy calls. | |
48 if (CXXMemberCallExpr* mce = dyn_cast_or_null<CXXMemberCallExpr>(S)) { | |
49 if (IsPostTaskExpr(mce)) { | |
50 if (mce->getNumArgs() < 2) { | |
51 llvm::errs() << "PostTask with less than 2 args?! Inconceivable!"; | |
Nico
2011/09/15 02:50:03
You can emit diagnostics instead, that way the com
awong
2011/09/16 02:50:41
Done.
| |
52 return false; | |
53 } | |
54 MaybeRewriteNewRunnableMethod(mce->getArgs()[1]); | |
55 } | |
56 } | |
57 return RecursiveASTVisitor<PostTaskVisitor>::TraverseStmt(S); | |
58 } | |
59 | |
60 private: | |
61 bool IsPostTaskExpr(CXXMemberCallExpr* call) { | |
62 CXXMethodDecl* method_decl = call->getMethodDecl()->getCanonicalDecl(); | |
63 if (kPostTaskName == method_decl->getNameAsString()) { | |
Nico
2011/09/15 02:50:03
Alternatively, you could use the Resolver stuff in
awong
2011/09/16 02:50:41
Interesting. That is cleaner. Added TODO.
| |
64 string classname = method_decl->getThisType(*context_) | |
65 .getBaseTypeIdentifier()->getName(); | |
66 llvm::outs() << "Found a posttask. Name of class: " | |
67 << classname | |
68 << "\n"; | |
69 if (interesting_classes_.find(classname) != interesting_classes_.end()) { | |
Nico
2011/09/15 02:50:03
nit: .count(classname) > 0
awong
2011/09/16 02:50:41
Done.
| |
70 return true; | |
71 } | |
72 } | |
73 return false; | |
74 } | |
75 | |
76 bool MaybeRewriteNewRunnableMethod(Expr* post_task_arg) { | |
77 // Strip implicit casts. | |
Nico
2011/09/15 02:50:03
Instead:
post_task_arg = post_task_arg->IgnoreImp
awong
2011/09/16 02:50:41
Oh very nice. Went with IgnoreImplicit()->IgnoreP
| |
78 if (ImplicitCastExpr* ice = dyn_cast<ImplicitCastExpr>(post_task_arg)) { | |
79 return MaybeRewriteNewRunnableMethod(ice->getSubExpr()); | |
80 } | |
81 | |
82 CallExpr* ce = dyn_cast<CallExpr>(post_task_arg); | |
83 if (!ce) return false; | |
84 FunctionDecl *fd = ce->getDirectCallee(); | |
85 if (!fd) return false; | |
86 if (kNewRunnableMethodName != | |
87 fd->getNameInfo().getName().getAsIdentifierInfo()->getName()) | |
88 return false; | |
89 | |
90 if (ce->getNumArgs() < 2) { | |
91 llvm::errs() << "NewRunnableMethod with less than 2 args?! Inconceivable! "; | |
92 return false; | |
93 } | |
94 | |
95 // Okay, here's where it gets fun. We need to | |
96 // (1) replace the NRM identifier text with base::Bind. | |
97 // (2) Swap positions of the first two arguments. | |
98 | |
99 // NewRunnableMethod -> base::Bind | |
100 rewriter_->ReplaceText(ce->getCallee()->getSourceRange(), "base::Bind"); | |
101 | |
102 // Swap the argument order. | |
103 Expr* arg1 = ce->getArgs()[0]; | |
104 Expr* arg2 = ce->getArgs()[1]; | |
105 bool failure = rewriter_->ReplaceStmt(arg1, arg2); | |
106 failure |= rewriter_->ReplaceStmt(arg2, arg1); | |
107 | |
108 // TODO(ajwong): | |
109 // (3) Check the RunnableMethodTraits for the second argument, and | |
110 // wrap base::Unretained() if it is declared with | |
111 // DISABLE_RUNNABLE_METHOD_REFCOUNT. | |
112 return failure; | |
113 } | |
114 | |
115 ASTContext* context_; | |
116 Rewriter* rewriter_; | |
117 | |
118 // Names for matching. | |
119 set<string> interesting_classes_; | |
120 static const char kPostTaskName[]; | |
121 static const char kNewRunnableMethodName[]; | |
122 }; | |
123 | |
124 const char PostTaskVisitor::kPostTaskName[] = "PostTask"; | |
125 const char PostTaskVisitor::kNewRunnableMethodName[] = "NewRunnableMethod"; | |
126 | |
127 class TaskToBindConsumer : public ASTConsumer { | |
128 public: | |
129 TaskToBindConsumer() { | |
130 } | |
131 | |
132 virtual void Initialize(ASTContext &context) { | |
133 rewriter_.setSourceMgr(context.getSourceManager(), context.getLangOptions()) ; | |
134 } | |
135 | |
136 virtual void HandleTranslationUnit(ASTContext &context) { | |
137 // modify calls | |
138 TranslationUnitDecl *translation_unit = context.getTranslationUnitDecl(); | |
139 for (DeclContext::decl_iterator it = translation_unit->decls_begin(), | |
140 end = translation_unit->decls_end(); it != end; ++it) { | |
Nico
2011/09/15 02:50:03
You could check
context.getSourceManager()->isF
awong
2011/09/16 02:50:41
Added a note. I think I'm going to try to make th
| |
141 PostTaskVisitor visitor(&context, &rewriter_); | |
142 visitor.TraverseDecl(*it); | |
143 } | |
144 | |
145 // Get the buffer corresponding to MainFileID. | |
146 // If we haven't changed it, then we are done. | |
147 if (rewriter_.buffer_begin() != rewriter_.buffer_end()) { | |
148 llvm::outs() << "Src file changed.\n"; | |
149 } else { | |
150 llvm::errs() << "No changes.\n"; | |
151 } | |
152 | |
153 for (Rewriter::buffer_iterator it = rewriter_.buffer_begin(); | |
154 it != rewriter_.buffer_end(); | |
155 ++it) { | |
156 llvm::outs() << std::string(it->second.begin(), it->second.end()); | |
157 } | |
158 | |
159 std::cout.flush(); | |
160 } | |
161 | |
162 private: | |
163 FileID source_file_; | |
164 Rewriter rewriter_; | |
165 }; | |
166 | |
167 class TaskToBindAction : public PluginASTAction { | |
168 protected: | |
169 ASTConsumer *CreateASTConsumer(CompilerInstance &CI, llvm::StringRef) { | |
170 return new TaskToBindConsumer(); | |
171 } | |
172 | |
173 bool ParseArgs(const CompilerInstance &CI, | |
174 const std::vector<std::string>& args) { | |
175 for (unsigned i = 0, e = args.size(); i != e; ++i) { | |
176 llvm::errs() << "TaskToBind arg = " << args[i] << "\n"; | |
177 | |
178 // Example error handling. | |
179 if (args[i] == "-an-error") { | |
180 Diagnostic &D = CI.getDiagnostics(); | |
181 unsigned DiagID = D.getCustomDiagID( | |
182 Diagnostic::Error, "invalid argument '" + args[i] + "'"); | |
183 D.Report(DiagID); | |
184 return false; | |
185 } | |
186 } | |
187 if (args.size() && args[0] == "help") | |
188 PrintHelp(llvm::errs()); | |
189 | |
190 return true; | |
191 } | |
192 void PrintHelp(llvm::raw_ostream& ros) { | |
193 ros << "Help for TaskToBind plugin goes here\n"; | |
194 } | |
195 | |
196 }; | |
197 | |
198 static FrontendPluginRegistry::Add<TaskToBindAction> | |
199 X("convert-postask", "rewrite PostTask(NRM) to PostTask(Bind)"); | |
OLD | NEW |