Index: chrome_frame/protocol_sink_wrap.cc |
=================================================================== |
--- chrome_frame/protocol_sink_wrap.cc (revision 48838) |
+++ chrome_frame/protocol_sink_wrap.cc (working copy) |
@@ -13,6 +13,8 @@ |
#include "base/singleton.h" |
#include "base/string_util.h" |
+#include "chrome_frame/bind_context_info.h" |
+#include "chrome_frame/function_stub.h" |
#include "chrome_frame/utils.h" |
// BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so |
@@ -29,640 +31,595 @@ |
static const int kInternetProtocolReadIndex = 9; |
static const int kInternetProtocolStartExIndex = 13; |
-// TODO(ananta) |
-// We should avoid duplicate VTable declarations. |
-BEGIN_VTABLE_PATCHES(IInternetProtocol) |
- VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) |
- VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) |
-END_VTABLE_PATCHES() |
-BEGIN_VTABLE_PATCHES(IInternetProtocolSecure) |
- VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) |
- VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) |
-END_VTABLE_PATCHES() |
+// IInternetProtocol/Ex patches. |
+STDMETHODIMP Hook_Start(InternetProtocol_Start_Fn orig_start, |
+ IInternetProtocol* protocol, |
+ LPCWSTR url, |
+ IInternetProtocolSink* prot_sink, |
+ IInternetBindInfo* bind_info, |
+ DWORD flags, |
+ HANDLE_PTR reserved); |
-BEGIN_VTABLE_PATCHES(IInternetProtocolEx) |
- VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) |
- VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) |
- VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx) |
+STDMETHODIMP Hook_StartEx(InternetProtocol_StartEx_Fn orig_start_ex, |
+ IInternetProtocolEx* protocol, |
+ IUri* uri, |
+ IInternetProtocolSink* prot_sink, |
+ IInternetBindInfo* bind_info, |
+ DWORD flags, |
+ HANDLE_PTR reserved); |
+ |
+STDMETHODIMP Hook_Read(InternetProtocol_Read_Fn orig_read, |
+ IInternetProtocol* protocol, |
+ void* buffer, |
+ ULONG size, |
+ ULONG* size_read); |
+ |
+///////////////////////////////////////////////////////////////////////////// |
+BEGIN_VTABLE_PATCHES(CTransaction) |
+ VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, Hook_Start) |
+ VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, Hook_Read) |
END_VTABLE_PATCHES() |
-BEGIN_VTABLE_PATCHES(IInternetProtocolExSecure) |
- VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) |
- VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) |
- VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx) |
+BEGIN_VTABLE_PATCHES(CTransaction2) |
+ VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, Hook_StartEx) |
END_VTABLE_PATCHES() |
// |
// ProtocolSinkWrap implementation |
-// |
// Static map initialization |
-ProtocolSinkWrap::ProtocolSinkMap ProtocolSinkWrap::sink_map_; |
-CComAutoCriticalSection ProtocolSinkWrap::sink_map_lock_; |
+ProtData::ProtocolDataMap ProtData::datamap_; |
+Lock ProtData::datamap_lock_; |
-ProtocolSinkWrap::ProtocolSinkWrap() |
- : protocol_(NULL), renderer_type_(UNDETERMINED), |
- buffer_size_(0), buffer_pos_(0), is_saved_result_(false), |
- result_code_(0), result_error_(0), report_data_recursiveness_(0), |
- determining_renderer_type_(false) { |
- memset(buffer_, 0, arraysize(buffer_)); |
+ProtocolSinkWrap::ProtocolSinkWrap() { |
+ DLOG(INFO) << __FUNCTION__ << StringPrintf(" 0x%08X", this); |
} |
ProtocolSinkWrap::~ProtocolSinkWrap() { |
- // This object may be destroyed before Initialize is called. |
- if (protocol_ != NULL) { |
- CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); |
- DCHECK(sink_map_.end() != sink_map_.find(protocol_)); |
- sink_map_.erase(protocol_); |
- protocol_ = NULL; |
- } |
- DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size(); |
+ DLOG(INFO) << __FUNCTION__ << StringPrintf(" 0x%08X", this); |
} |
-bool ProtocolSinkWrap::PatchProtocolHandlers() { |
- HRESULT hr = PatchProtocolMethods(CLSID_HttpProtocol, |
- IInternetProtocol_PatchInfo, |
- IInternetProtocolEx_PatchInfo); |
- if (FAILED(hr)) { |
- NOTREACHED() << "Failed to patch IInternetProtocol interface." |
- << " Error: " << hr; |
- return false; |
- } |
- |
- hr = PatchProtocolMethods(CLSID_HttpSProtocol, |
- IInternetProtocolSecure_PatchInfo, |
- IInternetProtocolExSecure_PatchInfo); |
- if (FAILED(hr)) { |
- NOTREACHED() << "Failed to patch IInternetProtocol secure interface." |
- << " Error: " << hr; |
- return false; |
- } |
- |
- return true; |
+ScopedComPtr<IInternetProtocolSink> ProtocolSinkWrap::CreateNewSink( |
+ IInternetProtocolSink* sink, ProtData* data) { |
+ DCHECK(sink != NULL); |
+ DCHECK(data != NULL); |
+ CComObject<ProtocolSinkWrap>* new_sink = NULL; |
+ CComObject<ProtocolSinkWrap>::CreateInstance(&new_sink); |
+ new_sink->delegate_ = sink; |
+ new_sink->prot_data_ = data; |
+ return ScopedComPtr<IInternetProtocolSink>(new_sink); |
} |
-void ProtocolSinkWrap::UnpatchProtocolHandlers() { |
- vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo); |
- vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo); |
- vtable_patch::UnpatchInterfaceMethods(IInternetProtocolSecure_PatchInfo); |
- vtable_patch::UnpatchInterfaceMethods(IInternetProtocolExSecure_PatchInfo); |
+// IInternetProtocolSink methods |
+STDMETHODIMP ProtocolSinkWrap::Switch(PROTOCOLDATA* protocol_data) { |
+ HRESULT hr = E_FAIL; |
+ if (delegate_) |
+ hr = delegate_->Switch(protocol_data); |
+ return hr; |
} |
-HRESULT ProtocolSinkWrap::CreateProtocolHandlerInstance( |
- const CLSID& clsid, IInternetProtocol** protocol) { |
- if (!protocol) { |
- return E_INVALIDARG; |
- } |
+STDMETHODIMP ProtocolSinkWrap::ReportProgress(ULONG status_code, |
+ LPCWSTR status_text) { |
+ DLOG(INFO) << "ProtocolSinkWrap::ReportProgress: " |
+ << BindStatus2Str(status_code) |
+ << " Status: " << (status_text ? status_text : L""); |
- HMODULE module = ::GetModuleHandle(kUrlMonDllName); |
- if (!module) { |
- NOTREACHED() << "urlmon is not yet loaded. Error: " << GetLastError(); |
- return E_FAIL; |
- } |
+ HRESULT hr = prot_data_->ReportProgress(delegate_, status_code, status_text); |
+ return hr; |
+} |
- typedef HRESULT (WINAPI* DllGetClassObject_Fn)(REFCLSID, REFIID, LPVOID*); |
- DllGetClassObject_Fn fn = reinterpret_cast<DllGetClassObject_Fn>( |
- ::GetProcAddress(module, "DllGetClassObject")); |
- if (!fn) { |
- NOTREACHED() << "DllGetClassObject not found in urlmon.dll"; |
- return E_FAIL; |
- } |
+STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress, |
+ ULONG max_progress) { |
+ DCHECK(delegate_); |
+ DLOG(INFO) << "ProtocolSinkWrap::ReportData: " << Bscf2Str(flags) << |
+ " progress: " << progress << " progress_max: " << max_progress; |
- ScopedComPtr<IClassFactory> protocol_class_factory; |
- HRESULT hr = fn(clsid, IID_IClassFactory, |
- reinterpret_cast<LPVOID*>(protocol_class_factory.Receive())); |
- if (FAILED(hr)) { |
- NOTREACHED() << "DllGetclassObject failed. Error: " << hr; |
- return hr; |
- } |
+ HRESULT hr = prot_data_->ReportData(delegate_, flags, progress, max_progress); |
+ return hr; |
+} |
- ScopedComPtr<IInternetProtocol> handler_instance; |
- hr = protocol_class_factory->CreateInstance(NULL, IID_IInternetProtocol, |
- reinterpret_cast<void**>(handler_instance.Receive())); |
- if (FAILED(hr)) { |
- NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol." |
- << " Error: " << hr; |
- } else { |
- *protocol = handler_instance.Detach(); |
- } |
+STDMETHODIMP ProtocolSinkWrap::ReportResult(HRESULT result, DWORD error, |
+ LPCWSTR result_text) { |
+ DLOG(INFO) << "ProtocolSinkWrap::ReportResult: result: " << result << |
+ " error: " << error << " Text: " << (result_text ? result_text : L""); |
+ DCHECK_NE(UNDETERMINED, prot_data_->renderer_type()); |
+ HRESULT hr = E_FAIL; |
+ if (delegate_) |
+ hr = delegate_->ReportResult(result, error, result_text); |
+ |
return hr; |
} |
-HRESULT ProtocolSinkWrap::PatchProtocolMethods( |
- const CLSID& clsid_protocol, |
- vtable_patch::MethodPatchInfo* protocol_patch_info, |
- vtable_patch::MethodPatchInfo* protocol_ex_patch_info) { |
- if (!protocol_patch_info || !protocol_ex_patch_info) { |
- return E_INVALIDARG; |
- } |
- ScopedComPtr<IInternetProtocol> http_protocol; |
- HRESULT hr = CreateProtocolHandlerInstance(clsid_protocol, |
- http_protocol.Receive()); |
- if (FAILED(hr)) { |
- NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol." |
- << " Error: " << hr; |
- return false; |
+// Helpers |
+ScopedComPtr<IBindCtx> BindCtxFromIBindInfo(IInternetBindInfo* bind_info) { |
+ LPOLESTR bind_ctx_string = NULL; |
+ ULONG count; |
+ ScopedComPtr<IBindCtx> bind_ctx; |
+ bind_info->GetBindString(BINDSTRING_PTR_BIND_CONTEXT, &bind_ctx_string, 1, |
+ &count); |
+ if (bind_ctx_string) { |
+ IBindCtx* pbc = reinterpret_cast<IBindCtx*>(StringToInt(bind_ctx_string)); |
+ bind_ctx.Attach(pbc); |
+ CoTaskMemFree(bind_ctx_string); |
} |
- ScopedComPtr<IInternetProtocolEx> ipex; |
- ipex.QueryFrom(http_protocol); |
- if (ipex) { |
- hr = vtable_patch::PatchInterfaceMethods(ipex, protocol_ex_patch_info); |
- } else { |
- hr = vtable_patch::PatchInterfaceMethods(http_protocol, |
- protocol_patch_info); |
- } |
- return hr; |
+ return bind_ctx; |
} |
-// IInternetProtocol/Ex method implementation. |
-HRESULT ProtocolSinkWrap::OnStart(InternetProtocol_Start_Fn orig_start, |
- IInternetProtocol* protocol, LPCWSTR url, IInternetProtocolSink* prot_sink, |
- IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
- DCHECK(orig_start); |
- DLOG_IF(INFO, url != NULL) << "OnStart: " << url; |
+bool ShouldWrapSink(IInternetProtocolSink* sink, const wchar_t* url) { |
+ // TODO(stoyan): check the url scheme for http/https. |
+ ScopedComPtr<IHttpNegotiate> http_negotiate; |
+ HRESULT hr = DoQueryService(GUID_NULL, sink, http_negotiate.Receive()); |
+ if (http_negotiate && !IsSubFrameRequest(http_negotiate)) |
+ return true; |
- ScopedComPtr<IInternetProtocolSink> sink_to_use(MaybeWrapSink(protocol, |
- prot_sink, url)); |
- return orig_start(protocol, url, sink_to_use, bind_info, flags, reserved); |
+ return false; |
} |
-HRESULT ProtocolSinkWrap::OnStartEx(InternetProtocol_StartEx_Fn orig_start_ex, |
- IInternetProtocolEx* protocol, IUri* uri, IInternetProtocolSink* prot_sink, |
- IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
- DCHECK(orig_start_ex); |
+// High level helpers |
+bool IsCFRequest(IBindCtx* pbc) { |
+ ScopedComPtr<BindContextInfo> info; |
+ BindContextInfo::FromBindContext(pbc, info.Receive()); |
+ DCHECK(info); |
+ if (info && info->chrome_request()) |
+ return true; |
- ScopedBstr url; |
- uri->GetPropertyBSTR(Uri_PROPERTY_ABSOLUTE_URI, url.Receive(), 0); |
- DLOG_IF(INFO, url != NULL) << "OnStartEx: " << url; |
+ return false; |
+} |
- ScopedComPtr<IInternetProtocolSink> sink_to_use(MaybeWrapSink(protocol, |
- prot_sink, url)); |
- return orig_start_ex(protocol, uri, sink_to_use, bind_info, flags, reserved); |
+void PutProtData(IBindCtx* pbc, ProtData* data) { |
+ ScopedComPtr<BindContextInfo> info; |
+ BindContextInfo::FromBindContext(pbc, info.Receive()); |
+ if (info) |
+ info->set_prot_data(data); |
} |
-HRESULT ProtocolSinkWrap::OnRead(InternetProtocol_Read_Fn orig_read, |
- IInternetProtocol* protocol, void* buffer, ULONG size, ULONG* size_read) { |
- DCHECK(orig_read); |
+bool IsTextHtml(const wchar_t* status_text) { |
+ if (!status_text) |
+ return false; |
+ size_t status_text_length = lstrlenW(status_text); |
+ const wchar_t* status_text_end = status_text + |
+ std::min(status_text_length, arraysize(kTextHtmlMimeType) - 1); |
+ bool is_text_html = LowerCaseEqualsASCII(status_text, status_text_end, |
+ kTextHtmlMimeType); |
+ return is_text_html; |
+} |
- scoped_refptr<ProtocolSinkWrap> instance = |
- ProtocolSinkWrap::InstanceFromProtocol(protocol); |
- HRESULT hr; |
- if (instance) { |
- DCHECK(instance->protocol_ == protocol); |
- hr = instance->OnReadImpl(buffer, size, size_read, orig_read); |
- } else { |
- hr = orig_read(protocol, buffer, size, size_read); |
+RendererType DetermineRendererType(void* buffer, DWORD size, bool last_chance) { |
+ RendererType type = UNDETERMINED; |
+ if (last_chance) |
+ type = OTHER; |
+ |
+ std::wstring html_contents; |
+ // TODO(joshia): detect and handle different content encodings |
+ UTF8ToWide(reinterpret_cast<char*>(buffer), size, &html_contents); |
+ |
+ // Note that document_contents_ may have NULL characters in it. While |
+ // browsers may handle this properly, we don't and will stop scanning |
+ // for the XUACompat content value if we encounter one. |
+ std::wstring xua_compat_content; |
+ UtilGetXUACompatContentValue(html_contents, &xua_compat_content); |
+ if (StrStrI(xua_compat_content.c_str(), kChromeContentPrefix)) { |
+ type = CHROME; |
} |
- return hr; |
+ return type; |
} |
-bool ProtocolSinkWrap::Initialize(IInternetProtocol* protocol, |
- IInternetProtocolSink* original_sink, const wchar_t* url) { |
- DCHECK(original_sink); |
- delegate_ = original_sink; |
- protocol_ = protocol; |
- if (url) |
- url_ = url; |
+// ProtData |
+ProtData::ProtData(IInternetProtocol* protocol, |
+ InternetProtocol_Read_Fn read_fun, const wchar_t* url) |
+ : has_suggested_mime_type_(false), has_server_mime_type_(false), |
+ report_data_received_(false), buffer_size_(0), buffer_pos_(0), |
+ renderer_type_(UNDETERMINED), protocol_(protocol), read_fun_(read_fun), |
+ url_(url) { |
+ memset(buffer_, 0, arraysize(buffer_)); |
+ DLOG(INFO) << __FUNCTION__ << " " << this; |
- CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); |
- DCHECK(sink_map_.end() == sink_map_.find(protocol)); |
- sink_map_[protocol] = this; |
- DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size(); |
- return true; |
+ // Add to map. |
+ AutoLock lock(datamap_lock_); |
+ DCHECK(datamap_.end() == datamap_.find(protocol_)); |
+ datamap_[protocol] = this; |
} |
-// IInternetProtocolSink methods |
-STDMETHODIMP ProtocolSinkWrap::Switch(PROTOCOLDATA* protocol_data) { |
- HRESULT hr = E_FAIL; |
- if (delegate_) |
- hr = delegate_->Switch(protocol_data); |
- return hr; |
+ProtData::~ProtData() { |
+ DLOG(INFO) << __FUNCTION__ << " " << this; |
+ |
+ // Remove from map. |
+ AutoLock lock(datamap_lock_); |
+ DCHECK(datamap_.end() != datamap_.find(protocol_)); |
+ datamap_.erase(protocol_); |
} |
-STDMETHODIMP ProtocolSinkWrap::ReportProgress(ULONG status_code, |
- LPCWSTR status_text) { |
- DLOG(INFO) << "ProtocolSinkWrap::ReportProgress: Code:" << status_code << |
- " Text: " << (status_text ? status_text : L""); |
- if (!delegate_) { |
- return E_FAIL; |
+HRESULT ProtData::Read(void* buffer, ULONG size, ULONG* size_read) { |
+ if (renderer_type_ == UNDETERMINED) { |
+ return E_PENDING; |
} |
- if ((BINDSTATUS_MIMETYPEAVAILABLE == status_code) || |
- (BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE == status_code)) { |
- // If we have a MIMETYPE and that MIMETYPE is not "text/html". we don't |
- // want to do anything with this. |
- if (status_text) { |
- size_t status_text_length = lstrlenW(status_text); |
- const wchar_t* status_text_end = status_text + std::min( |
- status_text_length, arraysize(kTextHtmlMimeType) - 1); |
- if (!LowerCaseEqualsASCII(status_text, status_text_end, |
- kTextHtmlMimeType)) { |
- renderer_type_ = OTHER; |
- } |
+ |
+ const ULONG bytes_available = buffer_size_ - buffer_pos_; |
+ const ULONG bytes_to_copy = std::min(bytes_available, size); |
+ if (bytes_to_copy) { |
+ // Copy from the local buffer. |
+ memcpy(buffer, buffer_ + buffer_pos_, bytes_to_copy); |
+ *size_read = bytes_to_copy; |
+ buffer_pos_ += bytes_to_copy; |
+ |
+ HRESULT hr = S_OK; |
+ ULONG new_data = 0; |
+ if (size > bytes_available) { |
+ // User buffer is greater than what we have. |
+ buffer = reinterpret_cast<uint8*>(buffer) + bytes_to_copy; |
+ size -= bytes_to_copy; |
+ hr = read_fun_(protocol_, buffer, size, &new_data); |
} |
- } |
- HRESULT hr = S_OK; |
- if (delegate_ && renderer_type_ != CHROME) { |
- hr = delegate_->ReportProgress(status_code, status_text); |
+ if (size_read) |
+ *size_read = bytes_to_copy + new_data; |
+ return hr; |
} |
- return hr; |
+ |
+ return read_fun_(protocol_, buffer, size, size_read); |
} |
-STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress, |
- ULONG max_progress) { |
- DCHECK(protocol_); |
- DCHECK(delegate_); |
- DLOG(INFO) << "ProtocolSinkWrap::ReportData: flags: " << flags << |
- " progress: " << progress << " progress_max: " << max_progress; |
- scoped_refptr<ProtocolSinkWrap> self_ref(this); |
+HRESULT ProtData::ReportProgress(IInternetProtocolSink* delegate, |
+ ULONG status_code, LPCWSTR status_text) { |
+ switch (status_code) { |
+ case BINDSTATUS_DIRECTBIND: |
+ renderer_type_ = OTHER; |
+ break; |
- // Maintain a stack depth to make a determination. ReportData is called |
- // recursively in IE8. If the request can be served in a single Read, the |
- // situation ends up like this: |
- // orig_prot |
- // |--> ProtocolSinkWrap::ReportData (BSCF_FIRSTDATANOTIFICATION) |
- // |--> orig_prot->Read(...) - 1st read - S_OK and data |
- // |--> ProtocolSinkWrap::ReportData (BSCF_LASTDATANOTIFICATION) |
- // |--> orig_prot->Read(...) - 2nd read S_FALSE, 0 bytes |
- // |
- // Inner call returns S_FALSE and no data. We try to make a determination |
- // of render type then and incorrectly set it to 'OTHER' as we don't have |
- // any data yet. However, we can make a determination in the context of |
- // outer ReportData since the first read will return S_OK with data. Then |
- // the next Read in the loop will return S_FALSE and we will enter the |
- // determination logic. |
+ case BINDSTATUS_REDIRECTING: |
+ url_.empty(); |
+ if (status_text) |
+ url_ = status_text; |
+ break; |
- // NOTE: We use the report_data_recursiveness_ variable to detect situations |
- // in which calls to ReportData are re-entrant (such as when the entire |
- // contents of a page fit inside a single packet). In these cases, we |
- // don't care about re-entrant calls beyond the second, and so we compare |
- // report_data_recursiveness_ inside the while loop, making sure we skip |
- // what would otherwise be spurious calls to ReportProgress(). |
- report_data_recursiveness_++; |
+ case BINDSTATUS_SERVER_MIMETYPEAVAILABLE: |
+ has_server_mime_type_ = true; |
+ SaveSuggestedMimeType(status_text); |
+ return S_OK; |
- HRESULT hr = S_OK; |
- if (is_undetermined()) { |
- CheckAndReportChromeMimeTypeForRequest(); |
+ // TODO(stoyan): BINDSTATUS_RAWMIMETYPE |
+ case BINDSTATUS_MIMETYPEAVAILABLE: |
+ case BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE: |
+ SaveSuggestedMimeType(status_text); |
+ return S_OK; |
} |
- // we call original only if the renderer type is other |
- if (renderer_type() == OTHER) { |
- hr = delegate_->ReportData(flags, progress, max_progress); |
+ return delegate->ReportProgress(status_code, status_text); |
+} |
- if (is_saved_result_) { |
- is_saved_result_ = false; |
- delegate_->ReportResult(result_code_, result_error_, |
- result_text_.c_str()); |
+HRESULT ProtData::ReportData(IInternetProtocolSink* delegate, |
+ DWORD flags, ULONG progress, ULONG max_progress) { |
+ if (renderer_type_ != UNDETERMINED) { |
+ return delegate->ReportData(flags, progress, max_progress); |
+ } |
+ |
+ // Do these checks only once. |
+ if (!report_data_received_) { |
+ report_data_received_ = true; |
+ |
+ DLOG_IF(INFO, (flags & BSCF_FIRSTDATANOTIFICATION) == 0) << |
+ "BUGBUG: BSCF_FIRSTDATANOTIFICATION is not set properly!"; |
+ |
+ |
+ // We check here, instead in ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE) |
+ // to be safe when following multiple redirects.? |
+ if (!IsTextHtml(suggested_mime_type_)) { |
+ renderer_type_ = OTHER; |
+ FireSugestedMimeType(delegate); |
+ return delegate->ReportData(flags, progress, max_progress); |
} |
+ |
+ if (!url_.empty() && IsOptInUrl(url_.c_str())) { |
+ // TODO(stoyan): We may attempt to remove ourselves from the bind context. |
+ renderer_type_ = CHROME; |
+ delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType); |
+ return delegate->ReportData(flags, progress, max_progress); |
+ } |
} |
- report_data_recursiveness_--; |
- return hr; |
-} |
+ HRESULT hr = FillBuffer(); |
-STDMETHODIMP ProtocolSinkWrap::ReportResult(HRESULT result, DWORD error, |
- LPCWSTR result_text) { |
- DLOG(INFO) << "ProtocolSinkWrap::ReportResult: result: " << result << |
- " error: " << error << " Text: " << (result_text ? result_text : L""); |
+ bool last_chance = false; |
+ if (hr == S_OK || hr == S_FALSE) { |
+ last_chance = true; |
+ } |
- // If this request failed, we don't want to have anything to do with this. |
- if (FAILED(result)) |
- renderer_type_ = OTHER; |
+ renderer_type_ = DetermineRendererType(buffer_, buffer_size_, last_chance); |
- // if we are still not sure about the renderer type, cache the result, |
- // othewise urlmon will get confused about getting reported about a |
- // success result for which it never received any data. |
- if (is_undetermined()) { |
- is_saved_result_ = true; |
- result_code_ = result; |
- result_error_ = error; |
- if (result_text) |
- result_text_ = result_text; |
+ if (renderer_type_ == UNDETERMINED) { |
+ // do not report anything, we need more data. |
return S_OK; |
} |
- HRESULT hr = E_FAIL; |
- if (delegate_) |
- hr = delegate_->ReportResult(result, error, result_text); |
+ if (renderer_type_ == CHROME) { |
+ DLOG(INFO) << "Forwarding BINDSTATUS_MIMETYPEAVAILABLE " |
+ << kChromeMimeType; |
+ delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType); |
+ } |
- return hr; |
-} |
+ if (renderer_type_ == OTHER) { |
+ FireSugestedMimeType(delegate); |
+ } |
-// IInternetBindInfoEx |
-STDMETHODIMP ProtocolSinkWrap::GetBindInfo(DWORD* flags, |
- BINDINFO* bind_info_ret) { |
- ScopedComPtr<IInternetBindInfo> bind_info; |
- HRESULT hr = bind_info.QueryFrom(delegate_); |
- if (bind_info) |
- hr = bind_info->GetBindInfo(flags, bind_info_ret); |
- return hr; |
-} |
+ // This is the first data notification we forward. |
+ flags |= BSCF_FIRSTDATANOTIFICATION; |
-STDMETHODIMP ProtocolSinkWrap::GetBindString(ULONG string_type, |
- LPOLESTR* string_array, ULONG array_size, ULONG* size_returned) { |
- ScopedComPtr<IInternetBindInfo> bind_info; |
- HRESULT hr = bind_info.QueryFrom(delegate_); |
- if (bind_info) |
- hr = bind_info->GetBindString(string_type, string_array, |
- array_size, size_returned); |
- return hr; |
-} |
+ if (hr == S_FALSE) { |
+ flags |= (BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE); |
+ } |
-STDMETHODIMP ProtocolSinkWrap::GetBindInfoEx(DWORD* flags, BINDINFO* bind_info, |
- DWORD* bindf2, DWORD* reserved) { |
- ScopedComPtr<IInternetBindInfoEx> bind_info_ex; |
- HRESULT hr = bind_info_ex.QueryFrom(delegate_); |
- if (bind_info_ex) |
- hr = bind_info_ex->GetBindInfoEx(flags, bind_info, bindf2, reserved); |
- return hr; |
+ return delegate->ReportData(flags, progress, max_progress); |
} |
-// IServiceProvider |
-STDMETHODIMP ProtocolSinkWrap::QueryService(REFGUID service_guid, |
- REFIID riid, void** service) { |
- ScopedComPtr<IServiceProvider> service_provider; |
- HRESULT hr = service_provider.QueryFrom(delegate_); |
- if (service_provider) |
- hr = service_provider->QueryService(service_guid, riid, service); |
- return hr; |
+void ProtData::UpdateUrl(const wchar_t* url) { |
+ url_ = url; |
} |
-// IAuthenticate |
-STDMETHODIMP ProtocolSinkWrap::Authenticate(HWND* window, |
- LPWSTR* user_name, LPWSTR* password) { |
- ScopedComPtr<IAuthenticate> authenticate; |
- HRESULT hr = authenticate.QueryFrom(delegate_); |
- if (authenticate) |
- hr = authenticate->Authenticate(window, user_name, password); |
- return hr; |
-} |
+// S_FALSE - EOF |
+// S_OK - buffer fully filled |
+// E_PENDING - some data added to buffer, but buffer is not yet full |
+// E_XXXX - some other error. |
+HRESULT ProtData::FillBuffer() { |
+ HRESULT hr_read = S_OK; |
-// IInternetProtocolEx |
-STDMETHODIMP ProtocolSinkWrap::Start(LPCWSTR url, |
- IInternetProtocolSink *protocol_sink, IInternetBindInfo* bind_info, |
- DWORD flags, HANDLE_PTR reserved) { |
- ScopedComPtr<IInternetProtocolRoot> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Start(url, protocol_sink, bind_info, flags, reserved); |
- return hr; |
-} |
+ while ((hr_read == S_OK) && (buffer_size_ < kMaxContentSniffLength)) { |
+ ULONG size_read = 0; |
+ hr_read = read_fun_(protocol_, buffer_ + buffer_size_, |
+ kMaxContentSniffLength - buffer_size_, &size_read); |
+ buffer_size_ += size_read; |
+ } |
-STDMETHODIMP ProtocolSinkWrap::Continue(PROTOCOLDATA* protocol_data) { |
- ScopedComPtr<IInternetProtocolRoot> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Continue(protocol_data); |
- return hr; |
+ return hr_read; |
} |
-STDMETHODIMP ProtocolSinkWrap::Abort(HRESULT reason, DWORD options) { |
- ScopedComPtr<IInternetProtocolRoot> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Abort(reason, options); |
- return hr; |
+void ProtData::SaveSuggestedMimeType(LPCWSTR status_text) { |
+ has_suggested_mime_type_ = true; |
+ suggested_mime_type_.Allocate(status_text); |
} |
-STDMETHODIMP ProtocolSinkWrap::Terminate(DWORD options) { |
- ScopedComPtr<IInternetProtocolRoot> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Terminate(options); |
- return hr; |
-} |
+void ProtData::FireSugestedMimeType(IInternetProtocolSink* delegate) { |
+ if (has_server_mime_type_) { |
+ DLOG(INFO) << "Forwarding BINDSTATUS_SERVER_MIMETYPEAVAILABLE " |
+ << suggested_mime_type_; |
+ delegate->ReportProgress(BINDSTATUS_SERVER_MIMETYPEAVAILABLE, |
+ suggested_mime_type_); |
+ return; |
+ } |
-STDMETHODIMP ProtocolSinkWrap::Suspend() { |
- ScopedComPtr<IInternetProtocolRoot> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Suspend(); |
- return hr; |
+ if (has_suggested_mime_type_) { |
+ DLOG(INFO) << "Forwarding BINDSTATUS_MIMETYPEAVAILABLE " |
+ << suggested_mime_type_; |
+ delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, |
+ suggested_mime_type_); |
+ } |
} |
-STDMETHODIMP ProtocolSinkWrap::Resume() { |
- ScopedComPtr<IInternetProtocolRoot> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Resume(); |
- return hr; |
+scoped_refptr<ProtData> ProtData::DataFromProtocol( |
+ IInternetProtocol* protocol) { |
+ scoped_refptr<ProtData> instance; |
+ AutoLock lock(datamap_lock_); |
+ ProtocolDataMap::iterator it = datamap_.find(protocol); |
+ if (datamap_.end() != it) |
+ instance = it->second; |
+ return instance; |
} |
-STDMETHODIMP ProtocolSinkWrap::Read(void *buffer, ULONG size, |
- ULONG* size_read) { |
- ScopedComPtr<IInternetProtocol> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Read(buffer, size, size_read); |
- return hr; |
-} |
+// IInternetProtocol/Ex hooks. |
+STDMETHODIMP Hook_Start(InternetProtocol_Start_Fn orig_start, |
+ IInternetProtocol* protocol, LPCWSTR url, IInternetProtocolSink* prot_sink, |
+ IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
+ DCHECK(orig_start); |
+ if (!url || !prot_sink || !bind_info) |
+ return E_INVALIDARG; |
+ DLOG_IF(INFO, url != NULL) << "OnStart: " << url << PiFlags2Str(flags); |
-STDMETHODIMP ProtocolSinkWrap::Seek(LARGE_INTEGER move, DWORD origin, |
- ULARGE_INTEGER* new_pos) { |
- ScopedComPtr<IInternetProtocol> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->Seek(move, origin, new_pos); |
- return hr; |
-} |
+ ScopedComPtr<IBindCtx> bind_ctx = BindCtxFromIBindInfo(bind_info); |
+ if (!bind_ctx) { |
+ // MSHTML sometimes takes a short path, skips the creation of |
+ // moniker and binding, by directly grabbing protocol from InternetSession |
+ DLOG(INFO) << "DirectBind for " << url; |
+ return orig_start(protocol, url, prot_sink, bind_info, flags, reserved); |
+ } |
-STDMETHODIMP ProtocolSinkWrap::LockRequest(DWORD options) { |
- ScopedComPtr<IInternetProtocol> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->LockRequest(options); |
- return hr; |
-} |
+ if (IsCFRequest(bind_ctx)) { |
+ return orig_start(protocol, url, prot_sink, bind_info, flags, reserved); |
+ } |
-STDMETHODIMP ProtocolSinkWrap::UnlockRequest() { |
- ScopedComPtr<IInternetProtocol> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->UnlockRequest(); |
- return hr; |
-} |
+ scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol); |
+ if (prot_data) { |
+ DLOG(INFO) << "Found existing ProtData!"; |
+ prot_data->UpdateUrl(url); |
+ ScopedComPtr<IInternetProtocolSink> new_sink = |
+ ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
+ return orig_start(protocol, url, new_sink, bind_info, flags, reserved); |
+ } |
-STDMETHODIMP ProtocolSinkWrap::StartEx(IUri* uri, |
- IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, |
- DWORD flags, HANDLE_PTR reserved) { |
- ScopedComPtr<IInternetProtocolEx> protocol; |
- HRESULT hr = protocol.QueryFrom(delegate_); |
- if (protocol) |
- hr = protocol->StartEx(uri, protocol_sink, bind_info, flags, reserved); |
- return hr; |
-} |
+ if (!ShouldWrapSink(prot_sink, url)) { |
+ return orig_start(protocol, url, prot_sink, bind_info, flags, reserved); |
+ } |
-// IInternetPriority |
-STDMETHODIMP ProtocolSinkWrap::SetPriority(LONG priority) { |
- ScopedComPtr<IInternetPriority> internet_priority; |
- HRESULT hr = internet_priority.QueryFrom(delegate_); |
- if (internet_priority) |
- hr = internet_priority->SetPriority(priority); |
- return hr; |
-} |
+ // Fresh request. |
+ InternetProtocol_Read_Fn read_fun = reinterpret_cast<InternetProtocol_Read_Fn> |
+ (CTransaction_PatchInfo[1].stub_->argument()); |
+ prot_data = new ProtData(protocol, read_fun, url); |
+ PutProtData(bind_ctx, prot_data); |
-STDMETHODIMP ProtocolSinkWrap::GetPriority(LONG* priority) { |
- ScopedComPtr<IInternetPriority> internet_priority; |
- HRESULT hr = internet_priority.QueryFrom(delegate_); |
- if (internet_priority) |
- hr = internet_priority->GetPriority(priority); |
- return hr; |
+ ScopedComPtr<IInternetProtocolSink> new_sink = |
+ ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
+ return orig_start(protocol, url, new_sink, bind_info, flags, reserved); |
} |
-// IWrappedProtocol |
-STDMETHODIMP ProtocolSinkWrap::GetWrapperCode(LONG *code, DWORD_PTR reserved) { |
- ScopedComPtr<IWrappedProtocol> wrapped_protocol; |
- HRESULT hr = wrapped_protocol.QueryFrom(delegate_); |
- if (wrapped_protocol) |
- hr = wrapped_protocol->GetWrapperCode(code, reserved); |
- return hr; |
+STDMETHODIMP Hook_StartEx(InternetProtocol_StartEx_Fn orig_start_ex, |
+ IInternetProtocolEx* protocol, IUri* uri, IInternetProtocolSink* prot_sink, |
+ IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
+ DCHECK(orig_start_ex); |
+ if (!uri || !prot_sink || !bind_info) |
+ return E_INVALIDARG; |
+ |
+ ScopedBstr url; |
+ uri->GetPropertyBSTR(Uri_PROPERTY_ABSOLUTE_URI, url.Receive(), 0); |
+ DLOG_IF(INFO, url != NULL) << "OnStartEx: " << url << PiFlags2Str(flags); |
+ |
+ ScopedComPtr<IBindCtx> bind_ctx = BindCtxFromIBindInfo(bind_info); |
+ if (!bind_ctx) { |
+ // MSHTML sometimes takes a short path, skips the creation of |
+ // moniker and binding, by directly grabbing protocol from InternetSession. |
+ DLOG(INFO) << "DirectBind for " << url; |
+ return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved); |
+ } |
+ |
+ if (IsCFRequest(bind_ctx)) { |
+ return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved); |
+ } |
+ |
+ scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol); |
+ if (prot_data) { |
+ DLOG(INFO) << "Found existing ProtData!"; |
+ prot_data->UpdateUrl(url); |
+ ScopedComPtr<IInternetProtocolSink> new_sink = |
+ ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
+ return orig_start_ex(protocol, uri, new_sink, bind_info, flags, reserved); |
+ } |
+ |
+ if (!ShouldWrapSink(prot_sink, url)) { |
+ return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved); |
+ } |
+ |
+ // Fresh request. |
+ InternetProtocol_Read_Fn read_fun = reinterpret_cast<InternetProtocol_Read_Fn> |
+ (CTransaction_PatchInfo[1].stub_->argument()); |
+ prot_data = new ProtData(protocol, read_fun, url); |
+ PutProtData(bind_ctx, prot_data); |
+ |
+ ScopedComPtr<IInternetProtocolSink> new_sink = |
+ ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
+ return orig_start_ex(protocol, uri, new_sink, bind_info, flags, reserved); |
} |
+STDMETHODIMP Hook_Read(InternetProtocol_Read_Fn orig_read, |
+ IInternetProtocol* protocol, void* buffer, ULONG size, ULONG* size_read) { |
+ DCHECK(orig_read); |
+ scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol); |
+ if (!prot_data) { |
+ return orig_read(protocol, buffer, size, size_read); |
+ } |
-// public IUriContainer |
-STDMETHODIMP ProtocolSinkWrap::GetIUri(IUri** uri) { |
- ScopedComPtr<IUriContainer> uri_container; |
- HRESULT hr = uri_container.QueryFrom(delegate_); |
- if (uri_container) |
- hr = uri_container->GetIUri(uri); |
+ HRESULT hr = prot_data->Read(buffer, size, size_read); |
return hr; |
} |
-// Protected helpers |
+// Patching / Hooking code. |
+class FakeProtocol : public CComObjectRootEx<CComSingleThreadModel>, |
+ public IInternetProtocol { |
+ public: |
+ BEGIN_COM_MAP(FakeProtocol) |
+ COM_INTERFACE_ENTRY(IInternetProtocol) |
+ COM_INTERFACE_ENTRY(IInternetProtocolRoot) |
+ END_COM_MAP() |
-void ProtocolSinkWrap::DetermineRendererType() { |
- if (is_undetermined()) { |
- if (IsOptInUrl(url_.c_str())) { |
- renderer_type_ = CHROME; |
- } else { |
- std::wstring xua_compat_content; |
- // Note that document_contents_ may have NULL characters in it. While |
- // browsers may handle this properly, we don't and will stop scanning for |
- // the XUACompat content value if we encounter one. |
- DCHECK(buffer_size_ < arraysize(buffer_)); |
- buffer_[buffer_size_] = 0; |
- std::wstring html_contents; |
- // TODO(joshia): detect and handle different content encodings |
- UTF8ToWide(buffer_, buffer_size_, &html_contents); |
- UtilGetXUACompatContentValue(html_contents, &xua_compat_content); |
- if (StrStrI(xua_compat_content.c_str(), kChromeContentPrefix)) { |
- renderer_type_ = CHROME; |
- } else { |
- renderer_type_ = OTHER; |
- } |
- } |
+ STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink *protocol_sink, |
+ IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
+ transaction_.QueryFrom(protocol_sink); |
+ // Return some unusual error code. |
+ return INET_E_INVALID_CERTIFICATE; |
} |
-} |
-HRESULT ProtocolSinkWrap::CheckAndReportChromeMimeTypeForRequest() { |
- if (!is_undetermined()) |
- return S_OK; |
+ STDMETHOD(Continue)(PROTOCOLDATA* protocol_data) { return S_OK; } |
+ STDMETHOD(Abort)(HRESULT reason, DWORD options) { return S_OK; } |
+ STDMETHOD(Terminate)(DWORD options) { return S_OK; } |
+ STDMETHOD(Suspend)() { return S_OK; } |
+ STDMETHOD(Resume)() { return S_OK; } |
+ STDMETHOD(Read)(void *buffer, ULONG size, ULONG* size_read) { return S_OK; } |
+ STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_pos) |
+ { return S_OK; } |
+ STDMETHOD(LockRequest)(DWORD options) { return S_OK; } |
+ STDMETHOD(UnlockRequest)() { return S_OK; } |
- // This function could get invoked recursively in the context of |
- // IInternetProtocol::Read. Check for the same and bail. |
- if (determining_renderer_type_) |
- return S_OK; |
+ ScopedComPtr<IInternetProtocol> transaction_; |
+}; |
- determining_renderer_type_ = true; |
+struct FakeFactory : public IClassFactory, |
+ public CComObjectRootEx<CComSingleThreadModel> { |
+ BEGIN_COM_MAP(FakeFactory) |
+ COM_INTERFACE_ENTRY(IClassFactory) |
+ END_COM_MAP() |
- HRESULT hr_read = S_OK; |
- while (hr_read == S_OK) { |
- ULONG size_read = 0; |
- hr_read = protocol_->Read(buffer_ + buffer_size_, |
- kMaxContentSniffLength - buffer_size_, &size_read); |
- buffer_size_ += size_read; |
+ STDMETHOD(CreateInstance)(IUnknown *pUnkOuter, REFIID riid, void **ppvObj) { |
+ if (pUnkOuter) |
+ return CLASS_E_NOAGGREGATION; |
+ HRESULT hr = obj_->QueryInterface(riid, ppvObj); |
+ return hr; |
+ } |
- // Attempt to determine the renderer type if we have received |
- // sufficient data. Do not attempt this when we are called recursively. |
- if (report_data_recursiveness_ < 2 && (S_FALSE == hr_read) || |
- (buffer_size_ >= kMaxContentSniffLength)) { |
- DetermineRendererType(); |
- if (renderer_type() == CHROME) { |
- // Workaround for IE 8 and "nosniff". See: |
- // http://blogs.msdn.com/ie/archive/2008/09/02/ie8-security-part-vi-beta-2-update.aspx |
- delegate_->ReportProgress( |
- BINDSTATUS_SERVER_MIMETYPEAVAILABLE, kChromeMimeType); |
- // For IE < 8. |
- delegate_->ReportProgress( |
- BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType); |
+ STDMETHOD(LockServer)(BOOL fLock) { |
+ return S_OK; |
+ } |
- delegate_->ReportProgress( |
- BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, kChromeMimeType); |
+ IUnknown* obj_; |
+}; |
- delegate_->ReportData( |
- BSCF_FIRSTDATANOTIFICATION, 0, 0); |
+static void HookTransactionVtable(IInternetProtocol* p) { |
+ ScopedComPtr<IInternetProtocolEx> ex; |
+ ex.QueryFrom(p); |
- delegate_->ReportData( |
- BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, 0, 0); |
- } |
- break; |
- } |
+ HRESULT hr = vtable_patch::PatchInterfaceMethods(p, CTransaction_PatchInfo); |
+ if (hr == S_OK && ex) { |
+ vtable_patch::PatchInterfaceMethods(ex.get(), CTransaction2_PatchInfo); |
} |
- |
- determining_renderer_type_ = false; |
- return hr_read; |
} |
-HRESULT ProtocolSinkWrap::OnReadImpl(void* buffer, ULONG size, ULONG* size_read, |
- InternetProtocol_Read_Fn orig_read) { |
- // We want to switch the renderer to chrome, we cannot return any |
- // data now. |
- if (CHROME == renderer_type()) |
- return S_FALSE; |
- |
- // Serve data from our buffer first. |
- if (OTHER == renderer_type()) { |
- const ULONG bytes_to_copy = std::min(buffer_size_ - buffer_pos_, size); |
- if (bytes_to_copy) { |
- memcpy(buffer, buffer_ + buffer_pos_, bytes_to_copy); |
- *size_read = bytes_to_copy; |
- buffer_pos_ += bytes_to_copy; |
- return S_OK; |
- } |
+void TransactionHooks::InstallHooks() { |
+ if (IS_PATCHED(CTransaction)) { |
+ DLOG(WARNING) << __FUNCTION__ << " called more than once."; |
+ return; |
} |
- return orig_read(protocol_, buffer, size, size_read); |
-} |
+ CComObjectStackEx<FakeProtocol> prot; |
+ CComObjectStackEx<FakeFactory> factory; |
+ factory.obj_ = &prot; |
+ ScopedComPtr<IInternetSession> session; |
+ HRESULT hr = ::CoInternetGetSession(0, session.Receive(), 0); |
+ hr = session->RegisterNameSpace(&factory, CLSID_NULL, L"611", 0, 0, 0); |
+ DLOG_IF(FATAL, FAILED(hr)) << "Failed to register namespace"; |
+ if (hr != S_OK) |
+ return; |
-scoped_refptr<ProtocolSinkWrap> ProtocolSinkWrap::InstanceFromProtocol( |
- IInternetProtocol* protocol) { |
- CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); |
- scoped_refptr<ProtocolSinkWrap> instance; |
- ProtocolSinkMap::iterator it = sink_map_.find(protocol); |
- if (sink_map_.end() != it) |
- instance = it->second; |
- return instance; |
-} |
+ do { |
+ ScopedComPtr<IMoniker> mk; |
+ ScopedComPtr<IBindCtx> bc; |
+ ScopedComPtr<IStream> stream; |
+ hr = ::CreateAsyncBindCtxEx(0, 0, 0, 0, bc.Receive(), 0); |
+ DLOG_IF(FATAL, FAILED(hr)) << "CreateAsyncBindCtxEx failed " << hr; |
+ if (hr != S_OK) |
+ break; |
-ScopedComPtr<IInternetProtocolSink> ProtocolSinkWrap::MaybeWrapSink( |
- IInternetProtocol* protocol, IInternetProtocolSink* prot_sink, |
- const wchar_t* url) { |
- ScopedComPtr<IInternetProtocolSink> sink_to_use(prot_sink); |
+ hr = ::CreateURLMoniker(NULL, L"611://512", mk.Receive()); |
+ DLOG_IF(FATAL, FAILED(hr)) << "CreateURLMoniker failed " << hr; |
+ if (hr != S_OK) |
+ break; |
- // FYI: GUID_NULL doesn't work when the URL is being loaded from history. |
- // asking for IID_IHttpNegotiate as the service id works, but |
- // getting the IWebBrowser2 interface still doesn't work. |
- ScopedComPtr<IHttpNegotiate> http_negotiate; |
- HRESULT hr = DoQueryService(GUID_NULL, prot_sink, http_negotiate.Receive()); |
+ hr = mk->BindToStorage(bc, NULL, IID_IStream, |
+ reinterpret_cast<void**>(stream.Receive())); |
+ DLOG_IF(FATAL, hr != INET_E_INVALID_CERTIFICATE) << |
+ "BindToStorage failed " << hr; |
+ } while (0); |
- if (http_negotiate && !IsSubFrameRequest(http_negotiate)) { |
- CComObject<ProtocolSinkWrap>* wrap = NULL; |
- CComObject<ProtocolSinkWrap>::CreateInstance(&wrap); |
- DCHECK(wrap); |
- if (wrap) { |
- wrap->AddRef(); |
- if (wrap->Initialize(protocol, prot_sink, url)) { |
- sink_to_use = wrap; |
- } |
- wrap->Release(); |
- } |
+ hr = session->UnregisterNameSpace(&factory, L"611"); |
+ if (prot.transaction_) { |
+ HookTransactionVtable(prot.transaction_); |
+ // Explicit release, otherwise ~CComObjectStackEx will complain about |
+ // outstanding reference to us, because it runs before ~FakeProtocol |
+ prot.transaction_.Release(); |
} |
+} |
- return sink_to_use; |
+void TransactionHooks::RevertHooks() { |
+ vtable_patch::UnpatchInterfaceMethods(CTransaction_PatchInfo); |
+ vtable_patch::UnpatchInterfaceMethods(CTransaction2_PatchInfo); |
} |