OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "chrome_frame/dll_redirector.h" | 5 #include "chrome_frame/dll_redirector.h" |
6 | 6 |
7 #include <aclapi.h> | 7 #include <aclapi.h> |
8 #include <atlbase.h> | 8 #include <atlbase.h> |
9 #include <atlsecurity.h> | 9 #include <atlsecurity.h> |
10 #include <sddl.h> | 10 #include <sddl.h> |
(...skipping 105 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
116 } | 116 } |
117 | 117 |
118 return success; | 118 return success; |
119 } | 119 } |
120 | 120 |
121 | 121 |
122 bool DllRedirector::RegisterAsFirstCFModule() { | 122 bool DllRedirector::RegisterAsFirstCFModule() { |
123 DCHECK(first_module_handle_ == NULL); | 123 DCHECK(first_module_handle_ == NULL); |
124 | 124 |
125 // Build our own file version outside of the lock: | 125 // Build our own file version outside of the lock: |
126 scoped_ptr<Version> our_version(GetCurrentModuleVersion()); | 126 scoped_ptr<base::Version> our_version(GetCurrentModuleVersion()); |
127 | 127 |
128 // We sadly can't use the autolock here since we want to have a timeout. | 128 // We sadly can't use the autolock here since we want to have a timeout. |
129 // Be careful not to return while holding the lock. Also, attempt to do as | 129 // Be careful not to return while holding the lock. Also, attempt to do as |
130 // little as possible while under this lock. | 130 // little as possible while under this lock. |
131 | 131 |
132 bool lock_acquired = false; | 132 bool lock_acquired = false; |
133 CSecurityAttributes sec_attr; | 133 CSecurityAttributes sec_attr; |
134 if (base::win::GetVersion() >= base::win::VERSION_VISTA && | 134 if (base::win::GetVersion() >= base::win::VERSION_VISTA && |
135 BuildSecurityAttributesForLock(&sec_attr)) { | 135 BuildSecurityAttributesForLock(&sec_attr)) { |
136 // On vista and above, we need to explicitly allow low integrity access | 136 // On vista and above, we need to explicitly allow low integrity access |
(...skipping 43 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
180 if (created_beacon) { | 180 if (created_beacon) { |
181 dll_version_.swap(our_version); | 181 dll_version_.swap(our_version); |
182 | 182 |
183 lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()), | 183 lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()), |
184 dll_version_->GetString().c_str(), | 184 dll_version_->GetString().c_str(), |
185 std::min(kSharedMemorySize, | 185 std::min(kSharedMemorySize, |
186 dll_version_->GetString().length() + 1)); | 186 dll_version_->GetString().length() + 1)); |
187 } else { | 187 } else { |
188 char buffer[kSharedMemorySize] = {0}; | 188 char buffer[kSharedMemorySize] = {0}; |
189 memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1); | 189 memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1); |
190 dll_version_.reset(new Version(buffer)); | 190 dll_version_.reset(new base::Version(buffer)); |
191 | 191 |
192 if (!dll_version_->IsValid() || | 192 if (!dll_version_->IsValid() || |
193 dll_version_->Equals(*our_version.get())) { | 193 dll_version_->Equals(*our_version.get())) { |
194 // If we either couldn't parse a valid version out of the shared | 194 // If we either couldn't parse a valid version out of the shared |
195 // memory or we did parse a version and it is the same as our own, | 195 // memory or we did parse a version and it is the same as our own, |
196 // then pretend we're first in to avoid trying to load any other DLLs. | 196 // then pretend we're first in to avoid trying to load any other DLLs. |
197 dll_version_.reset(our_version.release()); | 197 dll_version_.reset(our_version.release()); |
198 created_beacon = true; | 198 created_beacon = true; |
199 } | 199 } |
200 } | 200 } |
(...skipping 30 matching lines...) Expand all Loading... |
231 if (first_module_handle) { | 231 if (first_module_handle) { |
232 proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( | 232 proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( |
233 GetProcAddress(first_module_handle, "DllGetClassObject")); | 233 GetProcAddress(first_module_handle, "DllGetClassObject")); |
234 DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of " | 234 DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of " |
235 "DllGetClassObject from first loaded module."; | 235 "DllGetClassObject from first loaded module."; |
236 } | 236 } |
237 | 237 |
238 return proc_ptr; | 238 return proc_ptr; |
239 } | 239 } |
240 | 240 |
241 Version* DllRedirector::GetCurrentModuleVersion() { | 241 base::Version* DllRedirector::GetCurrentModuleVersion() { |
242 scoped_ptr<FileVersionInfo> file_version_info( | 242 scoped_ptr<FileVersionInfo> file_version_info( |
243 FileVersionInfo::CreateFileVersionInfoForCurrentModule()); | 243 FileVersionInfo::CreateFileVersionInfoForCurrentModule()); |
244 DCHECK(file_version_info.get()); | 244 DCHECK(file_version_info.get()); |
245 | 245 |
246 scoped_ptr<Version> current_version; | 246 scoped_ptr<base::Version> current_version; |
247 if (file_version_info.get()) { | 247 if (file_version_info.get()) { |
248 current_version.reset( | 248 current_version.reset( |
249 new Version(WideToASCII(file_version_info->file_version()))); | 249 new base::Version(WideToASCII(file_version_info->file_version()))); |
250 DCHECK(current_version->IsValid()); | 250 DCHECK(current_version->IsValid()); |
251 } | 251 } |
252 | 252 |
253 return current_version.release(); | 253 return current_version.release(); |
254 } | 254 } |
255 | 255 |
256 HMODULE DllRedirector::GetFirstModule() { | 256 HMODULE DllRedirector::GetFirstModule() { |
257 DCHECK(dll_version_.get()) | 257 DCHECK(dll_version_.get()) |
258 << "Error: Did you call RegisterAsFirstCFModule() first?"; | 258 << "Error: Did you call RegisterAsFirstCFModule() first?"; |
259 | 259 |
260 if (first_module_handle_ == NULL) { | 260 if (first_module_handle_ == NULL) { |
261 first_module_handle_ = LoadVersionedModule(dll_version_.get()); | 261 first_module_handle_ = LoadVersionedModule(dll_version_.get()); |
262 } | 262 } |
263 | 263 |
264 if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) { | 264 if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) { |
265 NOTREACHED() << "Should not be loading own version."; | 265 NOTREACHED() << "Should not be loading own version."; |
266 first_module_handle_ = NULL; | 266 first_module_handle_ = NULL; |
267 } | 267 } |
268 | 268 |
269 return first_module_handle_; | 269 return first_module_handle_; |
270 } | 270 } |
271 | 271 |
272 HMODULE DllRedirector::LoadVersionedModule(Version* version) { | 272 HMODULE DllRedirector::LoadVersionedModule(base::Version* version) { |
273 DCHECK(version); | 273 DCHECK(version); |
274 | 274 |
275 HMODULE hmodule = NULL; | 275 HMODULE hmodule = NULL; |
276 wchar_t system_buffer[MAX_PATH]; | 276 wchar_t system_buffer[MAX_PATH]; |
277 HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); | 277 HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); |
278 system_buffer[0] = 0; | 278 system_buffer[0] = 0; |
279 if (GetModuleFileName(this_module, system_buffer, | 279 if (GetModuleFileName(this_module, system_buffer, |
280 arraysize(system_buffer)) != 0) { | 280 arraysize(system_buffer)) != 0) { |
281 base::FilePath module_path(system_buffer); | 281 base::FilePath module_path(system_buffer); |
282 | 282 |
283 // For a module located in | 283 // For a module located in |
284 // Foo\XXXXXXXXX\<module>.dll, load | 284 // Foo\XXXXXXXXX\<module>.dll, load |
285 // Foo\<version>\<module>.dll: | 285 // Foo\<version>\<module>.dll: |
286 base::FilePath module_name = module_path.BaseName(); | 286 base::FilePath module_name = module_path.BaseName(); |
287 module_path = module_path.DirName() | 287 module_path = module_path.DirName() |
288 .DirName() | 288 .DirName() |
289 .Append(ASCIIToWide(version->GetString())) | 289 .Append(ASCIIToWide(version->GetString())) |
290 .Append(module_name); | 290 .Append(module_name); |
291 | 291 |
292 hmodule = LoadLibrary(module_path.value().c_str()); | 292 hmodule = LoadLibrary(module_path.value().c_str()); |
293 if (hmodule == NULL) { | 293 if (hmodule == NULL) { |
294 DPLOG(ERROR) << "Could not load reported module version " | 294 DPLOG(ERROR) << "Could not load reported module version " |
295 << version->GetString(); | 295 << version->GetString(); |
296 } | 296 } |
297 } else { | 297 } else { |
298 DPLOG(FATAL) << "Failed to get module file name"; | 298 DPLOG(FATAL) << "Failed to get module file name"; |
299 } | 299 } |
300 return hmodule; | 300 return hmodule; |
301 } | 301 } |
OLD | NEW |