Index: tools/clang/rewrite_scoped_refptr/RewriteScopedRefptr.cpp |
diff --git a/tools/clang/rewrite_scoped_refptr/RewriteScopedRefptr.cpp b/tools/clang/rewrite_scoped_refptr/RewriteScopedRefptr.cpp |
index 3f83cec2fc556186a8af728bfe0b16ee4f9b9e2b..1365aaf8d1a15fdc3968ffa89c5518b6e75916a0 100644 |
--- a/tools/clang/rewrite_scoped_refptr/RewriteScopedRefptr.cpp |
+++ b/tools/clang/rewrite_scoped_refptr/RewriteScopedRefptr.cpp |
@@ -155,6 +155,47 @@ void GetRewriterCallback::run(const MatchFinder::MatchResult& result) { |
replacements_->insert(Replacement(*result.SourceManager, range, text)); |
} |
+class VarRewriterCallback : public MatchFinder::MatchCallback { |
+ public: |
+ explicit VarRewriterCallback(Replacements* replacements) |
+ : replacements_(replacements) {} |
+ virtual void run(const MatchFinder::MatchResult& result) override; |
+ |
+ private: |
+ Replacements* const replacements_; |
+}; |
+ |
+void VarRewriterCallback::run(const MatchFinder::MatchResult& result) { |
+ const clang::CXXMemberCallExpr* const implicit_call = |
+ result.Nodes.getNodeAs<clang::CXXMemberCallExpr>("call"); |
+ const clang::DeclaratorDecl* const var_decl = |
+ result.Nodes.getNodeAs<clang::DeclaratorDecl>("var"); |
+ |
+ if (!implicit_call || !var_decl) |
+ return; |
+ |
+ const clang::TypeSourceInfo* tsi = var_decl->getTypeSourceInfo(); |
+ |
+ clang::CharSourceRange range = clang::CharSourceRange::getTokenRange( |
+ result.SourceManager->getSpellingLoc(tsi->getTypeLoc().getBeginLoc()), |
+ result.SourceManager->getSpellingLoc(tsi->getTypeLoc().getEndLoc())); |
+ if (!range.isValid()) |
+ return; |
+ |
+ std::string text = clang::Lexer::getSourceText( |
+ range, *result.SourceManager, result.Context->getLangOpts()); |
+ if (text.empty()) |
+ return; |
+ text.erase(text.rfind('*')); |
+ |
+ std::string replacement_text("scoped_refptr<"); |
+ replacement_text += text; |
+ replacement_text += ">"; |
+ |
+ replacements_->insert( |
+ Replacement(*result.SourceManager, range, replacement_text)); |
+} |
+ |
} // namespace |
static llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage); |
@@ -166,25 +207,41 @@ int main(int argc, const char* argv[]) { |
options.getSourcePathList()); |
MatchFinder match_finder; |
+ Replacements replacements; |
// Finds all calls to conversion operator member function. This catches calls |
// to "operator T*", "operator Testable", and "operator bool" equally. |
- StatementMatcher overloaded_call_matcher = memberCallExpr( |
+ auto base_matcher = memberCallExpr( |
thisPointerType(recordDecl(isSameOrDerivedFrom("::scoped_refptr"), |
isTemplateInstantiation())), |
- callee(conversionDecl()), |
- on(id("arg", expr()))); |
+ callee(conversionDecl())); |
+ |
+ // The heuristic for whether or not a conversion is 'unsafe'. An unsafe |
+ // conversion is one where a temporary scoped_refptr<T> is converted to |
+ // another type. The matcher provides an exception for a temporary |
+ // scoped_refptr that is the result of an operator call. In this case, assume |
+ // that it's the result of an iterator dereference, and the container itself |
+ // retains the necessary reference, since this is a common idiom to see in |
+ // loop bodies. |
+ auto is_unsafe_conversion = |
+ bindTemporaryExpr(unless(has(operatorCallExpr()))); |
+ |
+ auto safe_conversion_matcher = memberCallExpr( |
+ base_matcher, on(id("arg", expr(unless(is_unsafe_conversion))))); |
+ |
+ auto unsafe_conversion_matcher = |
+ memberCallExpr(base_matcher, on(id("arg", is_unsafe_conversion))); |
// This catches both user-defined conversions (eg: "operator bool") and |
// standard conversion sequence (C++03 13.3.3.1.1), such as converting a |
// pointer to a bool. |
- StatementMatcher implicit_to_bool = |
+ auto implicit_to_bool = |
implicitCastExpr(hasImplicitDestinationType(isBoolean())); |
// Avoid converting calls to of "operator Testable" -> "bool" and calls of |
// "operator T*" -> "bool". |
- StatementMatcher bool_conversion_matcher = hasParent(expr( |
- anyOf(expr(implicit_to_bool), expr(hasParent(expr(implicit_to_bool)))))); |
+ auto bool_conversion_matcher = hasParent( |
+ expr(anyOf(implicit_to_bool, expr(hasParent(implicit_to_bool))))); |
// Find all calls to an operator overload that do NOT (ultimately) result in |
// being cast to a bool - eg: where it's being converted to T* and rewrite |
@@ -193,29 +250,23 @@ int main(int argc, const char* argv[]) { |
// All bool conversions will be handled with the Testable trick, but that |
// can only be used once "operator T*" is removed, since otherwise it leaves |
// the call ambiguous. |
- Replacements get_replacements; |
- GetRewriterCallback get_callback(&get_replacements); |
- match_finder.addMatcher(id("call", expr(overloaded_call_matcher)), |
- &get_callback); |
- |
-#if 0 |
- // Finds all temporary scoped_refptr<T>'s being assigned to a T*. Note that |
- // this will result in two callbacks--both the above callback to append get() |
- // and this callback will match. |
+ GetRewriterCallback get_callback(&replacements); |
+ match_finder.addMatcher(id("call", safe_conversion_matcher), &get_callback); |
+ |
+ // Find temporary scoped_refptr<T>'s being unsafely assigned to a T*. |
+ VarRewriterCallback var_callback(&replacements); |
match_finder.addMatcher( |
id("var", |
- varDecl(hasInitializer(ignoringImpCasts( |
- id("call", expr(overloaded_call_matcher)))), |
+ varDecl(hasInitializer(ignoringImpCasts(exprWithCleanups( |
+ has(id("call", unsafe_conversion_matcher))))), |
hasType(pointerType()))), |
- &callback); |
+ &var_callback); |
match_finder.addMatcher( |
- binaryOperator( |
- hasOperatorName("="), |
- hasLHS(declRefExpr(to(id("var", varDecl(hasType(pointerType())))))), |
- hasRHS(ignoringParenImpCasts( |
- id("call", expr(overloaded_call_matcher))))), |
- &callback); |
-#endif |
+ constructorDecl(forEachConstructorInitializer(allOf( |
+ withInitializer(ignoringImpCasts( |
+ exprWithCleanups(has(id("call", unsafe_conversion_matcher))))), |
+ forField(id("var", fieldDecl(hasType(pointerType()))))))), |
+ &var_callback); |
std::unique_ptr<clang::tooling::FrontendActionFactory> factory = |
clang::tooling::newFrontendActionFactory(&match_finder); |
@@ -225,7 +276,7 @@ int main(int argc, const char* argv[]) { |
// Serialization format is documented in tools/clang/scripts/run_tool.py |
llvm::outs() << "==== BEGIN EDITS ====\n"; |
- for (const auto& r : get_replacements) { |
+ for (const auto& r : replacements) { |
std::string replacement_text = r.getReplacementText().str(); |
std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0'); |
llvm::outs() << "r:" << r.getFilePath() << ":" << r.getOffset() << ":" |