Index: lib/Transforms/NaCl/RewriteAggFuns.cpp |
diff --git a/lib/Transforms/NaCl/RewriteAggFuns.cpp b/lib/Transforms/NaCl/RewriteAggFuns.cpp |
new file mode 100644 |
index 0000000000000000000000000000000000000000..6d7fe30bf05c393c841562ad47354c020e6efb88 |
--- /dev/null |
+++ b/lib/Transforms/NaCl/RewriteAggFuns.cpp |
@@ -0,0 +1,389 @@ |
+//===- RewriteAggArgs.cpp - Rewrite functions that take aggragate arguments===// |
+// |
+// The LLVM Compiler Infrastructure |
+// |
+// This file is distributed under the University of Illinois Open Source |
+// License. See LICENSE.TXT for details. |
+// |
+//===----------------------------------------------------------------------===// |
+// |
+// These two simple passes convert functions with aggregate types in their |
+// signature into pointers. Neither pass tries to optimize what little generated |
+// IR they produce. |
+// TODO: vectors/structs/arrays smaller than 4 bytes can be represented within a |
+// `i32` and don't need to be put behind a pointer. |
+// |
+//===----------------------------------------------------------------------===// |
+ |
+#include "llvm/Pass.h" |
+#include "llvm/IR/Constants.h" |
+#include "llvm/IR/DataLayout.h" |
+#include "llvm/IR/DebugInfo.h" |
+#include "llvm/IR/Function.h" |
+#include "llvm/IR/Instructions.h" |
+#include "llvm/IR/Module.h" |
+#include "llvm/IR/Type.h" |
+#include "llvm/Support/Format.h" |
+#include "llvm/Transforms/NaCl.h" |
+#include "llvm/Transforms/Utils/Cloning.h" |
+#include "llvm/Transforms/Utils/ValueMapper.h" |
+ |
+using namespace llvm; |
+ |
+namespace { |
+/// Rewrite a function returning an aggregate type to a function with a sret |
+/// argument. |
+struct RewriteAggRet : public ModulePass { |
+ static char ID; |
+ RewriteAggRet() : ModulePass(ID) { |
+ initializeRewriteAggRetPass(*PassRegistry::getPassRegistry()); |
+ } |
+ |
+ inline bool needsRewrite(const Type *Ty) const { |
+ assert(Ty != nullptr); |
+ return Ty->isAggregateType() || Ty->isVectorTy(); |
+ } |
+ |
+ inline bool needsRewrite(const FunctionType *FTy) const { |
+ assert(FTy != nullptr); |
+ return needsRewrite(FTy->getReturnType()); |
+ } |
+ |
+ FunctionType *getNewFTy(FunctionType *FTy) const { |
+ assert(FTy != nullptr); |
+ SmallVector<Type *, 8> Args; |
+ Args.push_back(FTy->getReturnType()->getPointerTo()); |
+ for (auto *Ty : FTy->params()) { |
+ Args.push_back(Ty); |
+ } |
+ |
+ return FunctionType::get(Type::getVoidTy(FTy->getContext()), Args, |
+ FTy->isVarArg()); |
+ } |
+ |
+ bool runOnModule(Module &M) { |
+ bool Changed = false; |
+ SmallVector<Function *, 64> ToDelete; |
+ const DataLayout &DL = M.getDataLayout(); |
+ LLVMContext &C = M.getContext(); |
+ auto DISubprogramMap = makeSubprogramMap(M); |
+ |
+ for (Module::iterator I = M.begin(); I != M.end(); I++) { |
+ Function *OldF = I; |
+ FunctionType *OldFTy = OldF->getFunctionType(); |
+ |
+ FunctionType *NewFTy = OldFTy; |
+ Function *NewF = OldF; |
+ if (needsRewrite(OldFTy)) { |
+ ToDelete.push_back(OldF); |
+ ValueToValueMapTy VMap; |
+ |
+ NewFTy = getNewFTy(OldFTy); |
+ NewF = Function::Create(NewFTy, OldF->getLinkage(), ""); |
+ M.getFunctionList().insert(I, NewF); |
+ |
+ AttributeSet OldAttrs = OldF->getAttributes(); |
+ AttributeSet NewAttrs; |
+ NewAttrs = NewAttrs.addAttribute(C, AttributeSet::ReturnIndex + 1, |
+ Attribute::StructRet); |
+ NewAttrs = NewAttrs.addAttribute(C, AttributeSet::ReturnIndex + 1, |
+ Attribute::NonNull); |
+ NewAttrs = NewAttrs.addAttribute(C, AttributeSet::ReturnIndex + 1, |
+ Attribute::NoCapture); |
+ if (OldFTy->getReturnType()->isSized()) { |
+ NewAttrs = NewAttrs.addDereferenceableAttr( |
+ C, AttributeSet::ReturnIndex + 1, |
+ DL.getTypeAllocSize(OldFTy->getReturnType())); |
+ } |
+ // Move the old attributes to the right one index position: |
+ for (unsigned Index = 0; Index < OldF->arg_size(); Index++) { |
+ NewAttrs = |
+ NewAttrs.addAttributes(C, AttributeSet::ReturnIndex + 2 + Index, |
+ OldAttrs.getParamAttributes(Index)); |
+ } |
+ |
+ Function::arg_iterator SRetArg = NewF->arg_begin(); |
+ for (Function::arg_iterator OldArg = OldF->arg_begin(), |
+ NewArg = ++NewF->arg_begin(); |
+ OldArg != OldF->arg_end(); OldArg++, NewArg++) { |
+ NewArg->setName(OldArg->getName()); |
+ VMap[&*OldArg] = &*NewArg; |
+ } |
+ |
+ NewF->setAttributes(NewAttrs); |
+ |
+ SmallVector<ReturnInst *, 64> Returns; |
+ CloneFunctionInto(NewF, OldF, VMap, false, Returns); |
+ |
+ for (ReturnInst *Ret : Returns) { |
+ // Rewrite each to store to the `sret` argument. |
+ Value *V = Ret->getReturnValue(); |
+ CopyDebug(new StoreInst(V, SRetArg, Ret), Ret); |
+ ReturnInst::Create(C, nullptr, Ret); |
+ Ret->eraseFromParent(); |
+ } |
+ |
+ auto Found = DISubprogramMap.find(OldF); |
+ if (Found != DISubprogramMap.end()) |
+ Found->second->replaceFunction(NewF); |
+ |
+ NewF->takeName(OldF); |
+ OldF->replaceAllUsesWith( |
+ ConstantExpr::getPointerCast(NewF, OldFTy->getPointerTo())); |
+ } |
+ |
+ SmallVector<Instruction *, 64> ToDelete; |
+ for (BasicBlock &BB : *NewF) { |
+ for (Instruction &Inst : BB) { |
+ if (CallInst *Call = dyn_cast<CallInst>(&Inst)) { |
+ // Rewrite the call's return into a parameterized pointer. |
+ Value *Called = Call->getCalledValue(); |
+ FunctionType *CalledFTy = |
+ cast<FunctionType>(Call->getFunctionType()); |
+ if (!needsRewrite(CalledFTy)) { |
+ continue; |
+ } |
+ |
+ FunctionType *NewCalledFTy = getNewFTy(CalledFTy); |
+ |
+ AttributeSet OldAttrs = Call->getAttributes(); |
+ AttributeSet NewAttrs; |
+ NewAttrs = NewAttrs.addAttribute(C, AttributeSet::ReturnIndex + 1, |
+ Attribute::StructRet); |
+ NewAttrs = NewAttrs.addAttribute(C, AttributeSet::ReturnIndex + 1, |
+ Attribute::NonNull); |
+ NewAttrs = NewAttrs.addAttribute(C, AttributeSet::ReturnIndex + 1, |
+ Attribute::NoCapture); |
+ if (OldFTy->getReturnType()->isSized()) { |
+ NewAttrs = NewAttrs.addDereferenceableAttr( |
+ C, AttributeSet::ReturnIndex + 1, |
+ DL.getTypeAllocSize(CalledFTy->getReturnType())); |
+ } |
+ // Move the old attributes to the right one index position: |
+ for (unsigned Index = 0; Index < Call->getNumArgOperands(); |
+ Index++) { |
+ NewAttrs = NewAttrs.addAttributes( |
+ C, AttributeSet::ReturnIndex + 2 + Index, |
+ OldAttrs.getParamAttributes(Index)); |
+ } |
+ |
+ Instruction *Alloca = new AllocaInst( |
+ Call->getType(), nullptr, Call->getName() + ".ret-value", Call); |
+ CopyDebug(Alloca, Call); |
+ |
+ SmallVector<Value *, 8> Args; |
+ Args.push_back(Alloca); |
+ for (Value *Arg : Call->arg_operands()) { |
+ Args.push_back(Arg); |
+ } |
+ |
+ Instruction *BC = CastInst::CreatePointerCast( |
+ Called, NewCalledFTy->getPointerTo(), |
+ Called->getName() + ".fty-cast", Call); |
+ CopyDebug(BC, Call); |
+ CallInst *NewCall = CallInst::Create(BC, Args, "", Call); |
+ CopyDebug(NewCall, Call); |
+ NewCall->setAttributes(NewAttrs); |
+ |
+ Instruction *Load = new LoadInst(Alloca, "", Call); |
+ CopyDebug(Load, Call); |
+ Load->takeName(Call); |
+ Call->replaceAllUsesWith(Load); |
+ ToDelete.push_back(Call); |
+ } |
+ } |
+ } |
+ |
+ for (Instruction *V : ToDelete) { |
+ V->eraseFromParent(); |
+ } |
+ |
+ Changed = true; |
+ } |
+ |
+ for (Function *F : ToDelete) { |
+ F->eraseFromParent(); |
+ } |
+ |
+ return Changed; |
+ } |
+}; |
+ |
+struct RewriteAggArg : public ModulePass { |
+ static char ID; |
+ RewriteAggArg() : ModulePass(ID) { |
+ initializeRewriteAggArgPass(*PassRegistry::getPassRegistry()); |
+ } |
+ |
+ inline bool needsRewrite(const Type *Ty) const { |
+ return Ty->isAggregateType() || Ty->isVectorTy(); |
+ } |
+ |
+ bool needsRewrite(const FunctionType *FTy) const { |
+ for (const auto *Ty : FTy->params()) { |
+ if (needsRewrite(Ty)) { |
+ return true; |
+ } |
+ } |
+ |
+ return false; |
+ } |
+ |
+ FunctionType *getNewFTy(FunctionType *FTy) const { |
+ SmallVector<Type *, 8> Args; |
+ for (auto *Ty : FTy->params()) { |
+ if (needsRewrite(Ty)) { |
+ Args.push_back(Ty->getPointerTo()); |
+ } else { |
+ Args.push_back(Ty); |
+ } |
+ } |
+ |
+ return FunctionType::get(FTy->getReturnType(), Args, FTy->isVarArg()); |
+ } |
+ |
+ bool runOnModule(Module &M) override { |
+ bool Changed = false; |
+ SmallVector<Function *, 64> ToDelete; |
+ LLVMContext &C = M.getContext(); |
+ const DataLayout &DL = M.getDataLayout(); |
+ auto DISubprogramMap = makeSubprogramMap(M); |
+ |
+ for (Module::iterator I = M.begin(); I != M.end(); I++) { |
+ Function *OldF = I; |
+ FunctionType *OldFTy = OldF->getFunctionType(); |
+ Function *NewF = OldF; |
+ FunctionType *NewFTy = OldFTy; |
+ |
+ BasicBlock *Entry = nullptr; |
+ if (needsRewrite(OldFTy)) { |
+ ToDelete.push_back(OldF); |
+ ValueToValueMapTy VMap; |
+ |
+ NewFTy = getNewFTy(OldFTy); |
+ NewF = Function::Create(NewFTy, OldF->getLinkage(), ""); |
+ M.getFunctionList().insert(I, NewF); |
+ Entry = BasicBlock::Create(C, "agg-arg-loads", NewF); |
+ AttributeSet Attrs = OldF->getAttributes(); |
+ unsigned Index = AttributeSet::ReturnIndex + 1; |
+ for (Function::arg_iterator OldArg = OldF->arg_begin(), |
+ NewArg = NewF->arg_begin(); |
+ OldArg != OldF->arg_end(); OldArg++, NewArg++, Index++) { |
+ NewArg->setName(OldArg->getName()); |
+ if (needsRewrite(OldArg->getType())) { |
+ Attrs = Attrs.addAttribute(C, Index, Attribute::NonNull); |
+ Attrs = Attrs.addAttribute(C, Index, Attribute::NoCapture); |
+ if (OldArg->getType()->isSized()) { |
+ Attrs = Attrs.addDereferenceableAttr( |
+ C, Index, DL.getTypeAllocSize(OldArg->getType())); |
+ } |
+ Instruction *Load = |
+ new LoadInst(NewArg, NewArg->getName() + ".load", Entry); |
+ VMap[OldArg] = Load; |
+ } else { |
+ VMap[&*OldArg] = &*NewArg; |
+ } |
+ } |
+ NewF->setAttributes(Attrs); |
+ |
+ { |
+ SmallVector<ReturnInst *, 64> Returns; // unused. |
+ CloneFunctionInto(NewF, OldF, VMap, false, Returns); |
+ } |
+ |
+ auto Found = DISubprogramMap.find(OldF); |
+ if (Found != DISubprogramMap.end()) |
+ Found->second->replaceFunction(NewF); |
+ |
+ NewF->takeName(OldF); |
+ OldF->replaceAllUsesWith( |
+ ConstantExpr::getPointerCast(NewF, OldFTy->getPointerTo())); |
+ } |
+ |
+ SmallVector<Instruction *, 64> ToDelete; |
+ BasicBlock *OriginalEntry = nullptr; |
+ for (BasicBlock &BB : *NewF) { |
+ if (OriginalEntry == nullptr && &BB != Entry) { |
+ OriginalEntry = &BB; |
+ } |
+ for (Instruction &Inst : BB) { |
+ if (CallInst *Call = dyn_cast<CallInst>(&Inst)) { |
+ Value *Called = Call->getCalledValue(); |
+ FunctionType *CalledFTy = |
+ cast<FunctionType>(Call->getFunctionType()); |
+ if (!needsRewrite(CalledFTy)) { |
+ continue; |
+ } |
+ |
+ FunctionType *NewCalledFTy = getNewFTy(CalledFTy); |
+ |
+ // Rewrite the call's arguments into pointers. |
+ |
+ AttributeSet Attrs = Call->getAttributes(); |
+ SmallVector<Value *, 8> Args; |
+ unsigned Index = AttributeSet::ReturnIndex + 1; |
+ for (Value *Arg : Call->arg_operands()) { |
+ if (!needsRewrite(Arg->getType())) { |
+ Args.push_back(Arg); |
+ } else { |
+ Attrs = Attrs.addAttribute(C, Index, Attribute::NonNull); |
+ Attrs = Attrs.addAttribute(C, Index, Attribute::NoCapture); |
+ if (Arg->getType()->isSized()) { |
+ Attrs = Attrs.addDereferenceableAttr( |
+ C, Index, DL.getTypeAllocSize(Arg->getType())); |
+ } |
+ Instruction *Alloca = |
+ new AllocaInst(Arg->getType(), nullptr, |
+ Arg->getName() + ".call-store", Call); |
+ CopyDebug(Alloca, Call); |
+ CopyDebug(new StoreInst(Arg, Alloca, Call), Call); |
+ Args.push_back(Alloca); |
+ } |
+ Index++; |
+ } |
+ |
+ Instruction *BC = CastInst::CreatePointerCast( |
+ Called, NewCalledFTy->getPointerTo(), |
+ Called->getName() + ".fty-cast", Call); |
+ CopyDebug(BC, Call); |
+ CallInst *NewCall = CallInst::Create(BC, Args, "", Call); |
+ NewCall->takeName(Call); |
+ CopyDebug(NewCall, Call); |
+ NewCall->setAttributes(Attrs); |
+ Call->replaceAllUsesWith(NewCall); |
+ ToDelete.push_back(Call); |
+ } |
+ } |
+ } |
+ |
+ for (Instruction *V : ToDelete) { |
+ V->eraseFromParent(); |
+ } |
+ |
+ if (needsRewrite(OldFTy)) { |
+ BranchInst::Create(OriginalEntry, Entry); |
+ } |
+ } |
+ |
+ for (Function *F : ToDelete) { |
+ F->eraseFromParent(); |
+ Changed = true; |
+ } |
+ |
+ return Changed; |
+ } |
+}; |
+} |
+ |
+char RewriteAggRet::ID = 0; |
+INITIALIZE_PASS(RewriteAggRet, "rewrite-aggregate-returns", |
+ "Convert functions which return aggregate types to use `sret`", |
+ false, false) |
+ModulePass *llvm::createRewriteAggRetPass() { return new RewriteAggRet(); } |
+ |
+char RewriteAggArg::ID = 0; |
+INITIALIZE_PASS( |
+ RewriteAggArg, "rewrite-aggregate-arguments", |
+ "Convert functions which have aggregate arguments to use pointers", false, |
+ false) |
+ModulePass *llvm::createRewriteAggArgPass() { return new RewriteAggArg(); } |