OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "chrome_frame/dll_redirector.h" | 5 #include "chrome_frame/dll_redirector.h" |
6 | 6 |
7 #include <aclapi.h> | 7 #include <aclapi.h> |
8 #include <atlbase.h> | 8 #include <atlbase.h> |
9 #include <atlsecurity.h> | 9 #include <atlsecurity.h> |
10 #include <sddl.h> | 10 #include <sddl.h> |
(...skipping 106 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
117 } | 117 } |
118 | 118 |
119 return success; | 119 return success; |
120 } | 120 } |
121 | 121 |
122 | 122 |
123 bool DllRedirector::RegisterAsFirstCFModule() { | 123 bool DllRedirector::RegisterAsFirstCFModule() { |
124 DCHECK(first_module_handle_ == NULL); | 124 DCHECK(first_module_handle_ == NULL); |
125 | 125 |
126 // Build our own file version outside of the lock: | 126 // Build our own file version outside of the lock: |
127 scoped_ptr<base::Version> our_version(GetCurrentModuleVersion()); | 127 scoped_ptr<Version> our_version(GetCurrentModuleVersion()); |
128 | 128 |
129 // We sadly can't use the autolock here since we want to have a timeout. | 129 // We sadly can't use the autolock here since we want to have a timeout. |
130 // Be careful not to return while holding the lock. Also, attempt to do as | 130 // Be careful not to return while holding the lock. Also, attempt to do as |
131 // little as possible while under this lock. | 131 // little as possible while under this lock. |
132 | 132 |
133 bool lock_acquired = false; | 133 bool lock_acquired = false; |
134 CSecurityAttributes sec_attr; | 134 CSecurityAttributes sec_attr; |
135 if (base::win::GetVersion() >= base::win::VERSION_VISTA && | 135 if (base::win::GetVersion() >= base::win::VERSION_VISTA && |
136 BuildSecurityAttributesForLock(&sec_attr)) { | 136 BuildSecurityAttributesForLock(&sec_attr)) { |
137 // On vista and above, we need to explicitly allow low integrity access | 137 // On vista and above, we need to explicitly allow low integrity access |
(...skipping 43 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
181 if (created_beacon) { | 181 if (created_beacon) { |
182 dll_version_.swap(our_version); | 182 dll_version_.swap(our_version); |
183 | 183 |
184 lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()), | 184 lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()), |
185 dll_version_->GetString().c_str(), | 185 dll_version_->GetString().c_str(), |
186 std::min(kSharedMemorySize, | 186 std::min(kSharedMemorySize, |
187 dll_version_->GetString().length() + 1)); | 187 dll_version_->GetString().length() + 1)); |
188 } else { | 188 } else { |
189 char buffer[kSharedMemorySize] = {0}; | 189 char buffer[kSharedMemorySize] = {0}; |
190 memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1); | 190 memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1); |
191 dll_version_.reset(new base::Version(buffer)); | 191 dll_version_.reset(new Version(buffer)); |
192 | 192 |
193 if (!dll_version_->IsValid() || | 193 if (!dll_version_->IsValid() || |
194 dll_version_->Equals(*our_version.get())) { | 194 dll_version_->Equals(*our_version.get())) { |
195 // If we either couldn't parse a valid version out of the shared | 195 // If we either couldn't parse a valid version out of the shared |
196 // memory or we did parse a version and it is the same as our own, | 196 // memory or we did parse a version and it is the same as our own, |
197 // then pretend we're first in to avoid trying to load any other DLLs. | 197 // then pretend we're first in to avoid trying to load any other DLLs. |
198 dll_version_.reset(our_version.release()); | 198 dll_version_.reset(our_version.release()); |
199 created_beacon = true; | 199 created_beacon = true; |
200 } | 200 } |
201 } | 201 } |
(...skipping 30 matching lines...) Expand all Loading... |
232 if (first_module_handle) { | 232 if (first_module_handle) { |
233 proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( | 233 proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( |
234 GetProcAddress(first_module_handle, "DllGetClassObject")); | 234 GetProcAddress(first_module_handle, "DllGetClassObject")); |
235 DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of " | 235 DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of " |
236 "DllGetClassObject from first loaded module."; | 236 "DllGetClassObject from first loaded module."; |
237 } | 237 } |
238 | 238 |
239 return proc_ptr; | 239 return proc_ptr; |
240 } | 240 } |
241 | 241 |
242 base::Version* DllRedirector::GetCurrentModuleVersion() { | 242 Version* DllRedirector::GetCurrentModuleVersion() { |
243 scoped_ptr<FileVersionInfo> file_version_info( | 243 scoped_ptr<FileVersionInfo> file_version_info( |
244 FileVersionInfo::CreateFileVersionInfoForCurrentModule()); | 244 FileVersionInfo::CreateFileVersionInfoForCurrentModule()); |
245 DCHECK(file_version_info.get()); | 245 DCHECK(file_version_info.get()); |
246 | 246 |
247 scoped_ptr<base::Version> current_version; | 247 scoped_ptr<Version> current_version; |
248 if (file_version_info.get()) { | 248 if (file_version_info.get()) { |
249 current_version.reset( | 249 current_version.reset( |
250 new base::Version(WideToASCII(file_version_info->file_version()))); | 250 new Version(WideToASCII(file_version_info->file_version()))); |
251 DCHECK(current_version->IsValid()); | 251 DCHECK(current_version->IsValid()); |
252 } | 252 } |
253 | 253 |
254 return current_version.release(); | 254 return current_version.release(); |
255 } | 255 } |
256 | 256 |
257 HMODULE DllRedirector::GetFirstModule() { | 257 HMODULE DllRedirector::GetFirstModule() { |
258 DCHECK(dll_version_.get()) | 258 DCHECK(dll_version_.get()) |
259 << "Error: Did you call RegisterAsFirstCFModule() first?"; | 259 << "Error: Did you call RegisterAsFirstCFModule() first?"; |
260 | 260 |
261 if (first_module_handle_ == NULL) { | 261 if (first_module_handle_ == NULL) { |
262 first_module_handle_ = LoadVersionedModule(dll_version_.get()); | 262 first_module_handle_ = LoadVersionedModule(dll_version_.get()); |
263 } | 263 } |
264 | 264 |
265 if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) { | 265 if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) { |
266 NOTREACHED() << "Should not be loading own version."; | 266 NOTREACHED() << "Should not be loading own version."; |
267 first_module_handle_ = NULL; | 267 first_module_handle_ = NULL; |
268 } | 268 } |
269 | 269 |
270 return first_module_handle_; | 270 return first_module_handle_; |
271 } | 271 } |
272 | 272 |
273 HMODULE DllRedirector::LoadVersionedModule(base::Version* version) { | 273 HMODULE DllRedirector::LoadVersionedModule(Version* version) { |
274 DCHECK(version); | 274 DCHECK(version); |
275 | 275 |
276 HMODULE hmodule = NULL; | 276 HMODULE hmodule = NULL; |
277 wchar_t system_buffer[MAX_PATH]; | 277 wchar_t system_buffer[MAX_PATH]; |
278 HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); | 278 HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); |
279 system_buffer[0] = 0; | 279 system_buffer[0] = 0; |
280 if (GetModuleFileName(this_module, system_buffer, | 280 if (GetModuleFileName(this_module, system_buffer, |
281 arraysize(system_buffer)) != 0) { | 281 arraysize(system_buffer)) != 0) { |
282 base::FilePath module_path(system_buffer); | 282 base::FilePath module_path(system_buffer); |
283 | 283 |
284 // For a module located in | 284 // For a module located in |
285 // Foo\XXXXXXXXX\<module>.dll, load | 285 // Foo\XXXXXXXXX\<module>.dll, load |
286 // Foo\<version>\<module>.dll: | 286 // Foo\<version>\<module>.dll: |
287 base::FilePath module_name = module_path.BaseName(); | 287 base::FilePath module_name = module_path.BaseName(); |
288 module_path = module_path.DirName() | 288 module_path = module_path.DirName() |
289 .DirName() | 289 .DirName() |
290 .Append(base::ASCIIToWide(version->GetString())) | 290 .Append(base::ASCIIToWide(version->GetString())) |
291 .Append(module_name); | 291 .Append(module_name); |
292 | 292 |
293 hmodule = LoadLibrary(module_path.value().c_str()); | 293 hmodule = LoadLibrary(module_path.value().c_str()); |
294 if (hmodule == NULL) { | 294 if (hmodule == NULL) { |
295 DPLOG(ERROR) << "Could not load reported module version " | 295 DPLOG(ERROR) << "Could not load reported module version " |
296 << version->GetString(); | 296 << version->GetString(); |
297 } | 297 } |
298 } else { | 298 } else { |
299 DPLOG(FATAL) << "Failed to get module file name"; | 299 DPLOG(FATAL) << "Failed to get module file name"; |
300 } | 300 } |
301 return hmodule; | 301 return hmodule; |
302 } | 302 } |
OLD | NEW |