OLD | NEW |
(Empty) | |
| 1 //===- SandboxIndirectCalls.cpp - Add CFI to indirect function calls-------===// |
| 2 // |
| 3 // The LLVM Compiler Infrastructure |
| 4 // |
| 5 // This file is distributed under the University of Illinois Open Source |
| 6 // License. See LICENSE.TXT for details. |
| 7 // |
| 8 //===----------------------------------------------------------------------===// |
| 9 // |
| 10 // XXX |
| 11 // |
| 12 //===----------------------------------------------------------------------===// |
| 13 |
| 14 #include "llvm/IR/Constants.h" |
| 15 // #include "llvm/IR/Function.h" |
| 16 #include "llvm/IR/Instructions.h" |
| 17 // #include "llvm/IR/Intrinsics.h" |
| 18 #include "llvm/IR/Module.h" |
| 19 // #include "llvm/IR/Type.h" |
| 20 #include "llvm/Pass.h" |
| 21 #include "llvm/Support/raw_ostream.h" |
| 22 #include "llvm/Transforms/NaCl.h" |
| 23 |
| 24 using namespace llvm; |
| 25 |
| 26 namespace { |
| 27 // This is a ModulePass so that XXX... |
| 28 class SandboxIndirectCalls : public ModulePass { |
| 29 public: |
| 30 static char ID; // Pass identification, replacement for typeid |
| 31 SandboxIndirectCalls() : ModulePass(ID) { |
| 32 initializeSandboxIndirectCallsPass(*PassRegistry::getPassRegistry()); |
| 33 } |
| 34 |
| 35 virtual bool runOnModule(Module &M); |
| 36 }; |
| 37 } |
| 38 |
| 39 char SandboxIndirectCalls::ID = 0; |
| 40 INITIALIZE_PASS(SandboxIndirectCalls, "sandbox-indirect-calls", |
| 41 "Add CFI to indirect function calls", |
| 42 false, false) |
| 43 |
| 44 bool SandboxIndirectCalls::runOnModule(Module &M) { |
| 45 Type *I32 = Type::getInt32Ty(M.getContext()); |
| 46 Type *IntPtrType = I32; // XXX |
| 47 PointerType *PtrType = Type::getInt8Ty(M.getContext())->getPointerTo(); |
| 48 |
| 49 SmallVector<Constant *, 20> FuncTable; |
| 50 // Reserve index 0. |
| 51 FuncTable.push_back(ConstantPointerNull::get(PtrType)); |
| 52 |
| 53 // Build a function table out of address-taken functions. |
| 54 for (Module::iterator Func = M.begin(), E = M.end(); Func != E; ++Func) { |
| 55 // Look for address-taking references to the function. |
| 56 SmallVector<User *, 10> Users; |
| 57 for (Value::use_iterator U = Func->use_begin(), E = Func->use_end(); |
| 58 U != E; ++U) { |
| 59 if (CallInst *Call = dyn_cast<CallInst>(*U)) { |
| 60 // In PNaCl's normal form, a function referenced by a CallInst |
| 61 // can only appear as the callee, not an argument. |
| 62 if (U.getOperandNo() != Call->getNumArgOperands()) { |
| 63 errs() << "Value: " << **U << "\n"; |
| 64 report_fatal_error("SandboxIndirectCalls: Bad function reference"); |
| 65 } |
| 66 } else { |
| 67 // In PNaCl's normal form, all other references are PtrToInt |
| 68 // instructions or ConstantExprs. |
| 69 if (!(isa<PtrToIntInst>(*U) || |
| 70 (isa<ConstantExpr>(*U) && |
| 71 cast<ConstantExpr>(*U)->getOpcode() == Instruction::PtrToInt))) { |
| 72 errs() << "Value: " << **U << "\n"; |
| 73 report_fatal_error("SandboxIndirectCalls: Bad function reference"); |
| 74 } |
| 75 Users.push_back(*U); |
| 76 } |
| 77 } |
| 78 |
| 79 // If the function is address-taken, allocate it an ID by adding |
| 80 // it to the function table. |
| 81 if (!Users.empty()) { |
| 82 Value *FuncIndex = ConstantInt::get(IntPtrType, FuncTable.size()); |
| 83 // XXX: Remove bitcast when we use multiple tables. |
| 84 FuncTable.push_back(ConstantExpr::getBitCast(Func, PtrType)); |
| 85 |
| 86 for (SmallVectorImpl<User *>::iterator U = Users.begin(), E = Users.end(); |
| 87 U != E; ++U) { |
| 88 (*U)->replaceAllUsesWith(FuncIndex); |
| 89 // XXX: assumes cast is only used once. |
| 90 if (Instruction *Inst = dyn_cast<PtrToIntInst>(*U)) |
| 91 Inst->eraseFromParent(); |
| 92 } |
| 93 } |
| 94 } |
| 95 |
| 96 Constant *TableArray = |
| 97 ConstantArray::get(ArrayType::get(PtrType, FuncTable.size()), FuncTable); |
| 98 Value *FuncTableGV = new GlobalVariable( |
| 99 M, TableArray->getType(), /*isConstant=*/true, |
| 100 GlobalVariable::InternalLinkage, TableArray, |
| 101 "__sfi_function_table"); |
| 102 |
| 103 // Convert indirect function call instructions. |
| 104 for (Module::iterator Func = M.begin(), E = M.end(); Func != E; ++Func) { |
| 105 for (Function::iterator BB = Func->begin(), E = Func->end(); |
| 106 BB != E; |
| 107 ++BB) { |
| 108 for (BasicBlock::iterator Inst = BB->begin(), E = BB->end(); |
| 109 Inst != E; ++Inst) { |
| 110 assert(!isa<InvokeInst>(Inst)); |
| 111 if (CallInst *Call = dyn_cast<CallInst>(Inst)) { |
| 112 Value *Callee = Call->getCalledValue(); |
| 113 if (!isa<Function>(Callee)) { |
| 114 // assert... |
| 115 IntToPtrInst *Cast = cast<IntToPtrInst>(Callee); |
| 116 Value *FuncIndex = Cast->getOperand(0); |
| 117 |
| 118 Value *Indexes[] = { |
| 119 ConstantInt::get(I32, 0), |
| 120 FuncIndex |
| 121 }; |
| 122 Value *Ptr = GetElementPtrInst::Create( |
| 123 FuncTableGV, Indexes, "func_gep", Call); |
| 124 Value *FuncPtr = new LoadInst(Ptr, "func", Call); |
| 125 // XXX: Remove bitcast when we use multiple tables. |
| 126 Value *Bitcast = new BitCastInst(FuncPtr, Cast->getType(), |
| 127 "func_bc", Call); |
| 128 Call->setCalledFunction(Bitcast); |
| 129 // XXX: assumes cast is only used once. |
| 130 Cast->eraseFromParent(); |
| 131 } |
| 132 } |
| 133 } |
| 134 } |
| 135 } |
| 136 |
| 137 return true; |
| 138 } |
| 139 |
| 140 ModulePass *llvm::createSandboxIndirectCallsPass() { |
| 141 return new SandboxIndirectCalls(); |
| 142 } |
OLD | NEW |