| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "chrome_frame/dll_redirector.h" | 5 #include "chrome_frame/dll_redirector.h" |
| 6 | 6 |
| 7 #include <aclapi.h> | 7 #include <aclapi.h> |
| 8 #include <atlbase.h> | 8 #include <atlbase.h> |
| 9 #include <atlsecurity.h> | 9 #include <atlsecurity.h> |
| 10 #include <sddl.h> | 10 #include <sddl.h> |
| (...skipping 106 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 117 } | 117 } |
| 118 | 118 |
| 119 return success; | 119 return success; |
| 120 } | 120 } |
| 121 | 121 |
| 122 | 122 |
| 123 bool DllRedirector::RegisterAsFirstCFModule() { | 123 bool DllRedirector::RegisterAsFirstCFModule() { |
| 124 DCHECK(first_module_handle_ == NULL); | 124 DCHECK(first_module_handle_ == NULL); |
| 125 | 125 |
| 126 // Build our own file version outside of the lock: | 126 // Build our own file version outside of the lock: |
| 127 scoped_ptr<base::Version> our_version(GetCurrentModuleVersion()); | 127 scoped_ptr<Version> our_version(GetCurrentModuleVersion()); |
| 128 | 128 |
| 129 // We sadly can't use the autolock here since we want to have a timeout. | 129 // We sadly can't use the autolock here since we want to have a timeout. |
| 130 // Be careful not to return while holding the lock. Also, attempt to do as | 130 // Be careful not to return while holding the lock. Also, attempt to do as |
| 131 // little as possible while under this lock. | 131 // little as possible while under this lock. |
| 132 | 132 |
| 133 bool lock_acquired = false; | 133 bool lock_acquired = false; |
| 134 CSecurityAttributes sec_attr; | 134 CSecurityAttributes sec_attr; |
| 135 if (base::win::GetVersion() >= base::win::VERSION_VISTA && | 135 if (base::win::GetVersion() >= base::win::VERSION_VISTA && |
| 136 BuildSecurityAttributesForLock(&sec_attr)) { | 136 BuildSecurityAttributesForLock(&sec_attr)) { |
| 137 // On vista and above, we need to explicitly allow low integrity access | 137 // On vista and above, we need to explicitly allow low integrity access |
| (...skipping 43 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 181 if (created_beacon) { | 181 if (created_beacon) { |
| 182 dll_version_.swap(our_version); | 182 dll_version_.swap(our_version); |
| 183 | 183 |
| 184 lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()), | 184 lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()), |
| 185 dll_version_->GetString().c_str(), | 185 dll_version_->GetString().c_str(), |
| 186 std::min(kSharedMemorySize, | 186 std::min(kSharedMemorySize, |
| 187 dll_version_->GetString().length() + 1)); | 187 dll_version_->GetString().length() + 1)); |
| 188 } else { | 188 } else { |
| 189 char buffer[kSharedMemorySize] = {0}; | 189 char buffer[kSharedMemorySize] = {0}; |
| 190 memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1); | 190 memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1); |
| 191 dll_version_.reset(new base::Version(buffer)); | 191 dll_version_.reset(new Version(buffer)); |
| 192 | 192 |
| 193 if (!dll_version_->IsValid() || | 193 if (!dll_version_->IsValid() || |
| 194 dll_version_->Equals(*our_version.get())) { | 194 dll_version_->Equals(*our_version.get())) { |
| 195 // If we either couldn't parse a valid version out of the shared | 195 // If we either couldn't parse a valid version out of the shared |
| 196 // memory or we did parse a version and it is the same as our own, | 196 // memory or we did parse a version and it is the same as our own, |
| 197 // then pretend we're first in to avoid trying to load any other DLLs. | 197 // then pretend we're first in to avoid trying to load any other DLLs. |
| 198 dll_version_.reset(our_version.release()); | 198 dll_version_.reset(our_version.release()); |
| 199 created_beacon = true; | 199 created_beacon = true; |
| 200 } | 200 } |
| 201 } | 201 } |
| (...skipping 30 matching lines...) Expand all Loading... |
| 232 if (first_module_handle) { | 232 if (first_module_handle) { |
| 233 proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( | 233 proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>( |
| 234 GetProcAddress(first_module_handle, "DllGetClassObject")); | 234 GetProcAddress(first_module_handle, "DllGetClassObject")); |
| 235 DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of " | 235 DPLOG_IF(ERROR, !proc_ptr) << "DllRedirector: Could not get address of " |
| 236 "DllGetClassObject from first loaded module."; | 236 "DllGetClassObject from first loaded module."; |
| 237 } | 237 } |
| 238 | 238 |
| 239 return proc_ptr; | 239 return proc_ptr; |
| 240 } | 240 } |
| 241 | 241 |
| 242 base::Version* DllRedirector::GetCurrentModuleVersion() { | 242 Version* DllRedirector::GetCurrentModuleVersion() { |
| 243 scoped_ptr<FileVersionInfo> file_version_info( | 243 scoped_ptr<FileVersionInfo> file_version_info( |
| 244 FileVersionInfo::CreateFileVersionInfoForCurrentModule()); | 244 FileVersionInfo::CreateFileVersionInfoForCurrentModule()); |
| 245 DCHECK(file_version_info.get()); | 245 DCHECK(file_version_info.get()); |
| 246 | 246 |
| 247 scoped_ptr<base::Version> current_version; | 247 scoped_ptr<Version> current_version; |
| 248 if (file_version_info.get()) { | 248 if (file_version_info.get()) { |
| 249 current_version.reset( | 249 current_version.reset( |
| 250 new base::Version(WideToASCII(file_version_info->file_version()))); | 250 new Version(WideToASCII(file_version_info->file_version()))); |
| 251 DCHECK(current_version->IsValid()); | 251 DCHECK(current_version->IsValid()); |
| 252 } | 252 } |
| 253 | 253 |
| 254 return current_version.release(); | 254 return current_version.release(); |
| 255 } | 255 } |
| 256 | 256 |
| 257 HMODULE DllRedirector::GetFirstModule() { | 257 HMODULE DllRedirector::GetFirstModule() { |
| 258 DCHECK(dll_version_.get()) | 258 DCHECK(dll_version_.get()) |
| 259 << "Error: Did you call RegisterAsFirstCFModule() first?"; | 259 << "Error: Did you call RegisterAsFirstCFModule() first?"; |
| 260 | 260 |
| 261 if (first_module_handle_ == NULL) { | 261 if (first_module_handle_ == NULL) { |
| 262 first_module_handle_ = LoadVersionedModule(dll_version_.get()); | 262 first_module_handle_ = LoadVersionedModule(dll_version_.get()); |
| 263 } | 263 } |
| 264 | 264 |
| 265 if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) { | 265 if (first_module_handle_ == reinterpret_cast<HMODULE>(&__ImageBase)) { |
| 266 NOTREACHED() << "Should not be loading own version."; | 266 NOTREACHED() << "Should not be loading own version."; |
| 267 first_module_handle_ = NULL; | 267 first_module_handle_ = NULL; |
| 268 } | 268 } |
| 269 | 269 |
| 270 return first_module_handle_; | 270 return first_module_handle_; |
| 271 } | 271 } |
| 272 | 272 |
| 273 HMODULE DllRedirector::LoadVersionedModule(base::Version* version) { | 273 HMODULE DllRedirector::LoadVersionedModule(Version* version) { |
| 274 DCHECK(version); | 274 DCHECK(version); |
| 275 | 275 |
| 276 HMODULE hmodule = NULL; | 276 HMODULE hmodule = NULL; |
| 277 wchar_t system_buffer[MAX_PATH]; | 277 wchar_t system_buffer[MAX_PATH]; |
| 278 HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); | 278 HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase); |
| 279 system_buffer[0] = 0; | 279 system_buffer[0] = 0; |
| 280 if (GetModuleFileName(this_module, system_buffer, | 280 if (GetModuleFileName(this_module, system_buffer, |
| 281 arraysize(system_buffer)) != 0) { | 281 arraysize(system_buffer)) != 0) { |
| 282 base::FilePath module_path(system_buffer); | 282 base::FilePath module_path(system_buffer); |
| 283 | 283 |
| 284 // For a module located in | 284 // For a module located in |
| 285 // Foo\XXXXXXXXX\<module>.dll, load | 285 // Foo\XXXXXXXXX\<module>.dll, load |
| 286 // Foo\<version>\<module>.dll: | 286 // Foo\<version>\<module>.dll: |
| 287 base::FilePath module_name = module_path.BaseName(); | 287 base::FilePath module_name = module_path.BaseName(); |
| 288 module_path = module_path.DirName() | 288 module_path = module_path.DirName() |
| 289 .DirName() | 289 .DirName() |
| 290 .Append(base::ASCIIToWide(version->GetString())) | 290 .Append(base::ASCIIToWide(version->GetString())) |
| 291 .Append(module_name); | 291 .Append(module_name); |
| 292 | 292 |
| 293 hmodule = LoadLibrary(module_path.value().c_str()); | 293 hmodule = LoadLibrary(module_path.value().c_str()); |
| 294 if (hmodule == NULL) { | 294 if (hmodule == NULL) { |
| 295 DPLOG(ERROR) << "Could not load reported module version " | 295 DPLOG(ERROR) << "Could not load reported module version " |
| 296 << version->GetString(); | 296 << version->GetString(); |
| 297 } | 297 } |
| 298 } else { | 298 } else { |
| 299 DPLOG(FATAL) << "Failed to get module file name"; | 299 DPLOG(FATAL) << "Failed to get module file name"; |
| 300 } | 300 } |
| 301 return hmodule; | 301 return hmodule; |
| 302 } | 302 } |
| OLD | NEW |