Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(84)

Side by Side Diff: chrome_frame/protocol_sink_wrap.cc

Issue 2620001: A new way of hooking internet protocols. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src/
Patch Set: '' Created 10 years, 6 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
« no previous file with comments | « chrome_frame/protocol_sink_wrap.h ('k') | chrome_frame/test/test_mock_with_web_server.cc » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include <htiframe.h> 5 #include <htiframe.h>
6 #include <mshtml.h> 6 #include <mshtml.h>
7 7
8 #include "chrome_frame/protocol_sink_wrap.h" 8 #include "chrome_frame/protocol_sink_wrap.h"
9 9
10 #include "base/logging.h" 10 #include "base/logging.h"
11 #include "base/registry.h" 11 #include "base/registry.h"
12 #include "base/scoped_bstr_win.h" 12 #include "base/scoped_bstr_win.h"
13 #include "base/singleton.h" 13 #include "base/singleton.h"
14 #include "base/string_util.h" 14 #include "base/string_util.h"
15 15
16 #include "chrome_frame/bind_context_info.h"
17 #include "chrome_frame/function_stub.h"
16 #include "chrome_frame/utils.h" 18 #include "chrome_frame/utils.h"
17 19
18 // BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so 20 // BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so
19 // not in everyone's headers yet. See: 21 // not in everyone's headers yet. See:
20 // http://msdn.microsoft.com/en-us/library/ms775133(VS.85,loband).aspx 22 // http://msdn.microsoft.com/en-us/library/ms775133(VS.85,loband).aspx
21 #ifndef BINDSTATUS_SERVER_MIMETYPEAVAILABLE 23 #ifndef BINDSTATUS_SERVER_MIMETYPEAVAILABLE
22 #define BINDSTATUS_SERVER_MIMETYPEAVAILABLE 54 24 #define BINDSTATUS_SERVER_MIMETYPEAVAILABLE 54
23 #endif 25 #endif
24 26
25 static const char kTextHtmlMimeType[] = "text/html"; 27 static const char kTextHtmlMimeType[] = "text/html";
26 const wchar_t kUrlMonDllName[] = L"urlmon.dll"; 28 const wchar_t kUrlMonDllName[] = L"urlmon.dll";
27 29
28 static const int kInternetProtocolStartIndex = 3; 30 static const int kInternetProtocolStartIndex = 3;
29 static const int kInternetProtocolReadIndex = 9; 31 static const int kInternetProtocolReadIndex = 9;
30 static const int kInternetProtocolStartExIndex = 13; 32 static const int kInternetProtocolStartExIndex = 13;
31 33
32 // TODO(ananta) 34
33 // We should avoid duplicate VTable declarations. 35 // IInternetProtocol/Ex patches.
34 BEGIN_VTABLE_PATCHES(IInternetProtocol) 36 STDMETHODIMP Hook_Start(InternetProtocol_Start_Fn orig_start,
35 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) 37 IInternetProtocol* protocol,
36 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) 38 LPCWSTR url,
39 IInternetProtocolSink* prot_sink,
40 IInternetBindInfo* bind_info,
41 DWORD flags,
42 HANDLE_PTR reserved);
43
44 STDMETHODIMP Hook_StartEx(InternetProtocol_StartEx_Fn orig_start_ex,
45 IInternetProtocolEx* protocol,
46 IUri* uri,
47 IInternetProtocolSink* prot_sink,
48 IInternetBindInfo* bind_info,
49 DWORD flags,
50 HANDLE_PTR reserved);
51
52 STDMETHODIMP Hook_Read(InternetProtocol_Read_Fn orig_read,
53 IInternetProtocol* protocol,
54 void* buffer,
55 ULONG size,
56 ULONG* size_read);
57
58 /////////////////////////////////////////////////////////////////////////////
59 BEGIN_VTABLE_PATCHES(CTransaction)
60 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, Hook_Start)
61 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, Hook_Read)
37 END_VTABLE_PATCHES() 62 END_VTABLE_PATCHES()
38 63
39 BEGIN_VTABLE_PATCHES(IInternetProtocolSecure) 64 BEGIN_VTABLE_PATCHES(CTransaction2)
40 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) 65 VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, Hook_StartEx)
41 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
42 END_VTABLE_PATCHES()
43
44 BEGIN_VTABLE_PATCHES(IInternetProtocolEx)
45 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart)
46 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
47 VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx)
48 END_VTABLE_PATCHES()
49
50 BEGIN_VTABLE_PATCHES(IInternetProtocolExSecure)
51 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart)
52 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead)
53 VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx)
54 END_VTABLE_PATCHES() 66 END_VTABLE_PATCHES()
55 67
56 // 68 //
57 // ProtocolSinkWrap implementation 69 // ProtocolSinkWrap implementation
58 //
59 70
60 // Static map initialization 71 // Static map initialization
61 ProtocolSinkWrap::ProtocolSinkMap ProtocolSinkWrap::sink_map_; 72 ProtData::ProtocolDataMap ProtData::datamap_;
62 CComAutoCriticalSection ProtocolSinkWrap::sink_map_lock_; 73 Lock ProtData::datamap_lock_;
63 74
64 ProtocolSinkWrap::ProtocolSinkWrap() 75 ProtocolSinkWrap::ProtocolSinkWrap() {
65 : protocol_(NULL), renderer_type_(UNDETERMINED), 76 DLOG(INFO) << __FUNCTION__ << StringPrintf(" 0x%08X", this);
66 buffer_size_(0), buffer_pos_(0), is_saved_result_(false),
67 result_code_(0), result_error_(0), report_data_recursiveness_(0),
68 determining_renderer_type_(false) {
69 memset(buffer_, 0, arraysize(buffer_));
70 } 77 }
71 78
72 ProtocolSinkWrap::~ProtocolSinkWrap() { 79 ProtocolSinkWrap::~ProtocolSinkWrap() {
73 // This object may be destroyed before Initialize is called. 80 DLOG(INFO) << __FUNCTION__ << StringPrintf(" 0x%08X", this);
74 if (protocol_ != NULL) { 81 }
75 CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); 82
76 DCHECK(sink_map_.end() != sink_map_.find(protocol_)); 83 ScopedComPtr<IInternetProtocolSink> ProtocolSinkWrap::CreateNewSink(
77 sink_map_.erase(protocol_); 84 IInternetProtocolSink* sink, ProtData* data) {
78 protocol_ = NULL; 85 DCHECK(sink != NULL);
79 } 86 DCHECK(data != NULL);
80 DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size(); 87 CComObject<ProtocolSinkWrap>* new_sink = NULL;
81 } 88 CComObject<ProtocolSinkWrap>::CreateInstance(&new_sink);
82 89 new_sink->delegate_ = sink;
83 bool ProtocolSinkWrap::PatchProtocolHandlers() { 90 new_sink->prot_data_ = data;
84 HRESULT hr = PatchProtocolMethods(CLSID_HttpProtocol, 91 return ScopedComPtr<IInternetProtocolSink>(new_sink);
85 IInternetProtocol_PatchInfo,
86 IInternetProtocolEx_PatchInfo);
87 if (FAILED(hr)) {
88 NOTREACHED() << "Failed to patch IInternetProtocol interface."
89 << " Error: " << hr;
90 return false;
91 }
92
93 hr = PatchProtocolMethods(CLSID_HttpSProtocol,
94 IInternetProtocolSecure_PatchInfo,
95 IInternetProtocolExSecure_PatchInfo);
96 if (FAILED(hr)) {
97 NOTREACHED() << "Failed to patch IInternetProtocol secure interface."
98 << " Error: " << hr;
99 return false;
100 }
101
102 return true;
103 }
104
105 void ProtocolSinkWrap::UnpatchProtocolHandlers() {
106 vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo);
107 vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo);
108 vtable_patch::UnpatchInterfaceMethods(IInternetProtocolSecure_PatchInfo);
109 vtable_patch::UnpatchInterfaceMethods(IInternetProtocolExSecure_PatchInfo);
110 }
111
112 HRESULT ProtocolSinkWrap::CreateProtocolHandlerInstance(
113 const CLSID& clsid, IInternetProtocol** protocol) {
114 if (!protocol) {
115 return E_INVALIDARG;
116 }
117
118 HMODULE module = ::GetModuleHandle(kUrlMonDllName);
119 if (!module) {
120 NOTREACHED() << "urlmon is not yet loaded. Error: " << GetLastError();
121 return E_FAIL;
122 }
123
124 typedef HRESULT (WINAPI* DllGetClassObject_Fn)(REFCLSID, REFIID, LPVOID*);
125 DllGetClassObject_Fn fn = reinterpret_cast<DllGetClassObject_Fn>(
126 ::GetProcAddress(module, "DllGetClassObject"));
127 if (!fn) {
128 NOTREACHED() << "DllGetClassObject not found in urlmon.dll";
129 return E_FAIL;
130 }
131
132 ScopedComPtr<IClassFactory> protocol_class_factory;
133 HRESULT hr = fn(clsid, IID_IClassFactory,
134 reinterpret_cast<LPVOID*>(protocol_class_factory.Receive()));
135 if (FAILED(hr)) {
136 NOTREACHED() << "DllGetclassObject failed. Error: " << hr;
137 return hr;
138 }
139
140 ScopedComPtr<IInternetProtocol> handler_instance;
141 hr = protocol_class_factory->CreateInstance(NULL, IID_IInternetProtocol,
142 reinterpret_cast<void**>(handler_instance.Receive()));
143 if (FAILED(hr)) {
144 NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol."
145 << " Error: " << hr;
146 } else {
147 *protocol = handler_instance.Detach();
148 }
149
150 return hr;
151 }
152
153 HRESULT ProtocolSinkWrap::PatchProtocolMethods(
154 const CLSID& clsid_protocol,
155 vtable_patch::MethodPatchInfo* protocol_patch_info,
156 vtable_patch::MethodPatchInfo* protocol_ex_patch_info) {
157 if (!protocol_patch_info || !protocol_ex_patch_info) {
158 return E_INVALIDARG;
159 }
160
161 ScopedComPtr<IInternetProtocol> http_protocol;
162 HRESULT hr = CreateProtocolHandlerInstance(clsid_protocol,
163 http_protocol.Receive());
164 if (FAILED(hr)) {
165 NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol."
166 << " Error: " << hr;
167 return false;
168 }
169
170 ScopedComPtr<IInternetProtocolEx> ipex;
171 ipex.QueryFrom(http_protocol);
172 if (ipex) {
173 hr = vtable_patch::PatchInterfaceMethods(ipex, protocol_ex_patch_info);
174 } else {
175 hr = vtable_patch::PatchInterfaceMethods(http_protocol,
176 protocol_patch_info);
177 }
178 return hr;
179 }
180
181 // IInternetProtocol/Ex method implementation.
182 HRESULT ProtocolSinkWrap::OnStart(InternetProtocol_Start_Fn orig_start,
183 IInternetProtocol* protocol, LPCWSTR url, IInternetProtocolSink* prot_sink,
184 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
185 DCHECK(orig_start);
186 DLOG_IF(INFO, url != NULL) << "OnStart: " << url;
187
188 ScopedComPtr<IInternetProtocolSink> sink_to_use(MaybeWrapSink(protocol,
189 prot_sink, url));
190 return orig_start(protocol, url, sink_to_use, bind_info, flags, reserved);
191 }
192
193 HRESULT ProtocolSinkWrap::OnStartEx(InternetProtocol_StartEx_Fn orig_start_ex,
194 IInternetProtocolEx* protocol, IUri* uri, IInternetProtocolSink* prot_sink,
195 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
196 DCHECK(orig_start_ex);
197
198 ScopedBstr url;
199 uri->GetPropertyBSTR(Uri_PROPERTY_ABSOLUTE_URI, url.Receive(), 0);
200 DLOG_IF(INFO, url != NULL) << "OnStartEx: " << url;
201
202 ScopedComPtr<IInternetProtocolSink> sink_to_use(MaybeWrapSink(protocol,
203 prot_sink, url));
204 return orig_start_ex(protocol, uri, sink_to_use, bind_info, flags, reserved);
205 }
206
207 HRESULT ProtocolSinkWrap::OnRead(InternetProtocol_Read_Fn orig_read,
208 IInternetProtocol* protocol, void* buffer, ULONG size, ULONG* size_read) {
209 DCHECK(orig_read);
210
211 scoped_refptr<ProtocolSinkWrap> instance =
212 ProtocolSinkWrap::InstanceFromProtocol(protocol);
213 HRESULT hr;
214 if (instance) {
215 DCHECK(instance->protocol_ == protocol);
216 hr = instance->OnReadImpl(buffer, size, size_read, orig_read);
217 } else {
218 hr = orig_read(protocol, buffer, size, size_read);
219 }
220
221 return hr;
222 }
223
224 bool ProtocolSinkWrap::Initialize(IInternetProtocol* protocol,
225 IInternetProtocolSink* original_sink, const wchar_t* url) {
226 DCHECK(original_sink);
227 delegate_ = original_sink;
228 protocol_ = protocol;
229 if (url)
230 url_ = url;
231
232 CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_);
233 DCHECK(sink_map_.end() == sink_map_.find(protocol));
234 sink_map_[protocol] = this;
235 DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size();
236 return true;
237 } 92 }
238 93
239 // IInternetProtocolSink methods 94 // IInternetProtocolSink methods
240 STDMETHODIMP ProtocolSinkWrap::Switch(PROTOCOLDATA* protocol_data) { 95 STDMETHODIMP ProtocolSinkWrap::Switch(PROTOCOLDATA* protocol_data) {
241 HRESULT hr = E_FAIL; 96 HRESULT hr = E_FAIL;
242 if (delegate_) 97 if (delegate_)
243 hr = delegate_->Switch(protocol_data); 98 hr = delegate_->Switch(protocol_data);
244 return hr; 99 return hr;
245 } 100 }
246 101
247 STDMETHODIMP ProtocolSinkWrap::ReportProgress(ULONG status_code, 102 STDMETHODIMP ProtocolSinkWrap::ReportProgress(ULONG status_code,
248 LPCWSTR status_text) { 103 LPCWSTR status_text) {
249 DLOG(INFO) << "ProtocolSinkWrap::ReportProgress: Code:" << status_code << 104 DLOG(INFO) << "ProtocolSinkWrap::ReportProgress: "
250 " Text: " << (status_text ? status_text : L""); 105 << BindStatus2Str(status_code)
251 if (!delegate_) { 106 << " Status: " << (status_text ? status_text : L"");
252 return E_FAIL; 107
253 } 108 HRESULT hr = prot_data_->ReportProgress(delegate_, status_code, status_text);
254 if ((BINDSTATUS_MIMETYPEAVAILABLE == status_code) ||
255 (BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE == status_code)) {
256 // If we have a MIMETYPE and that MIMETYPE is not "text/html". we don't
257 // want to do anything with this.
258 if (status_text) {
259 size_t status_text_length = lstrlenW(status_text);
260 const wchar_t* status_text_end = status_text + std::min(
261 status_text_length, arraysize(kTextHtmlMimeType) - 1);
262 if (!LowerCaseEqualsASCII(status_text, status_text_end,
263 kTextHtmlMimeType)) {
264 renderer_type_ = OTHER;
265 }
266 }
267 }
268
269 HRESULT hr = S_OK;
270 if (delegate_ && renderer_type_ != CHROME) {
271 hr = delegate_->ReportProgress(status_code, status_text);
272 }
273 return hr; 109 return hr;
274 } 110 }
275 111
276 STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress, 112 STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress,
277 ULONG max_progress) { 113 ULONG max_progress) {
278 DCHECK(protocol_);
279 DCHECK(delegate_); 114 DCHECK(delegate_);
280 DLOG(INFO) << "ProtocolSinkWrap::ReportData: flags: " << flags << 115 DLOG(INFO) << "ProtocolSinkWrap::ReportData: " << Bscf2Str(flags) <<
281 " progress: " << progress << " progress_max: " << max_progress; 116 " progress: " << progress << " progress_max: " << max_progress;
282 117
283 scoped_refptr<ProtocolSinkWrap> self_ref(this); 118 HRESULT hr = prot_data_->ReportData(delegate_, flags, progress, max_progress);
284
285 // Maintain a stack depth to make a determination. ReportData is called
286 // recursively in IE8. If the request can be served in a single Read, the
287 // situation ends up like this:
288 // orig_prot
289 // |--> ProtocolSinkWrap::ReportData (BSCF_FIRSTDATANOTIFICATION)
290 // |--> orig_prot->Read(...) - 1st read - S_OK and data
291 // |--> ProtocolSinkWrap::ReportData (BSCF_LASTDATANOTIFICATION)
292 // |--> orig_prot->Read(...) - 2nd read S_FALSE, 0 bytes
293 //
294 // Inner call returns S_FALSE and no data. We try to make a determination
295 // of render type then and incorrectly set it to 'OTHER' as we don't have
296 // any data yet. However, we can make a determination in the context of
297 // outer ReportData since the first read will return S_OK with data. Then
298 // the next Read in the loop will return S_FALSE and we will enter the
299 // determination logic.
300
301 // NOTE: We use the report_data_recursiveness_ variable to detect situations
302 // in which calls to ReportData are re-entrant (such as when the entire
303 // contents of a page fit inside a single packet). In these cases, we
304 // don't care about re-entrant calls beyond the second, and so we compare
305 // report_data_recursiveness_ inside the while loop, making sure we skip
306 // what would otherwise be spurious calls to ReportProgress().
307 report_data_recursiveness_++;
308
309 HRESULT hr = S_OK;
310 if (is_undetermined()) {
311 CheckAndReportChromeMimeTypeForRequest();
312 }
313
314 // we call original only if the renderer type is other
315 if (renderer_type() == OTHER) {
316 hr = delegate_->ReportData(flags, progress, max_progress);
317
318 if (is_saved_result_) {
319 is_saved_result_ = false;
320 delegate_->ReportResult(result_code_, result_error_,
321 result_text_.c_str());
322 }
323 }
324
325 report_data_recursiveness_--;
326 return hr; 119 return hr;
327 } 120 }
328 121
329 STDMETHODIMP ProtocolSinkWrap::ReportResult(HRESULT result, DWORD error, 122 STDMETHODIMP ProtocolSinkWrap::ReportResult(HRESULT result, DWORD error,
330 LPCWSTR result_text) { 123 LPCWSTR result_text) {
331 DLOG(INFO) << "ProtocolSinkWrap::ReportResult: result: " << result << 124 DLOG(INFO) << "ProtocolSinkWrap::ReportResult: result: " << result <<
332 " error: " << error << " Text: " << (result_text ? result_text : L""); 125 " error: " << error << " Text: " << (result_text ? result_text : L"");
333 126 DCHECK_NE(UNDETERMINED, prot_data_->renderer_type());
334 // If this request failed, we don't want to have anything to do with this.
335 if (FAILED(result))
336 renderer_type_ = OTHER;
337
338 // if we are still not sure about the renderer type, cache the result,
339 // othewise urlmon will get confused about getting reported about a
340 // success result for which it never received any data.
341 if (is_undetermined()) {
342 is_saved_result_ = true;
343 result_code_ = result;
344 result_error_ = error;
345 if (result_text)
346 result_text_ = result_text;
347 return S_OK;
348 }
349 127
350 HRESULT hr = E_FAIL; 128 HRESULT hr = E_FAIL;
351 if (delegate_) 129 if (delegate_)
352 hr = delegate_->ReportResult(result, error, result_text); 130 hr = delegate_->ReportResult(result, error, result_text);
353 131
354 return hr; 132 return hr;
355 } 133 }
356 134
357 // IInternetBindInfoEx 135
358 STDMETHODIMP ProtocolSinkWrap::GetBindInfo(DWORD* flags, 136 // Helpers
359 BINDINFO* bind_info_ret) { 137 ScopedComPtr<IBindCtx> BindCtxFromIBindInfo(IInternetBindInfo* bind_info) {
360 ScopedComPtr<IInternetBindInfo> bind_info; 138 LPOLESTR bind_ctx_string = NULL;
361 HRESULT hr = bind_info.QueryFrom(delegate_); 139 ULONG count;
362 if (bind_info) 140 ScopedComPtr<IBindCtx> bind_ctx;
363 hr = bind_info->GetBindInfo(flags, bind_info_ret); 141 bind_info->GetBindString(BINDSTRING_PTR_BIND_CONTEXT, &bind_ctx_string, 1,
364 return hr; 142 &count);
365 } 143 if (bind_ctx_string) {
366 144 IBindCtx* pbc = reinterpret_cast<IBindCtx*>(StringToInt(bind_ctx_string));
367 STDMETHODIMP ProtocolSinkWrap::GetBindString(ULONG string_type, 145 bind_ctx.Attach(pbc);
368 LPOLESTR* string_array, ULONG array_size, ULONG* size_returned) { 146 CoTaskMemFree(bind_ctx_string);
369 ScopedComPtr<IInternetBindInfo> bind_info; 147 }
370 HRESULT hr = bind_info.QueryFrom(delegate_); 148
371 if (bind_info) 149 return bind_ctx;
372 hr = bind_info->GetBindString(string_type, string_array, 150 }
373 array_size, size_returned); 151
374 return hr; 152 bool ShouldWrapSink(IInternetProtocolSink* sink, const wchar_t* url) {
375 } 153 // TODO(stoyan): check the url scheme for http/https.
376 154 ScopedComPtr<IHttpNegotiate> http_negotiate;
377 STDMETHODIMP ProtocolSinkWrap::GetBindInfoEx(DWORD* flags, BINDINFO* bind_info, 155 HRESULT hr = DoQueryService(GUID_NULL, sink, http_negotiate.Receive());
378 DWORD* bindf2, DWORD* reserved) { 156 if (http_negotiate && !IsSubFrameRequest(http_negotiate))
379 ScopedComPtr<IInternetBindInfoEx> bind_info_ex; 157 return true;
380 HRESULT hr = bind_info_ex.QueryFrom(delegate_); 158
381 if (bind_info_ex) 159 return false;
382 hr = bind_info_ex->GetBindInfoEx(flags, bind_info, bindf2, reserved); 160 }
383 return hr; 161
384 } 162 // High level helpers
385 163 bool IsCFRequest(IBindCtx* pbc) {
386 // IServiceProvider 164 ScopedComPtr<BindContextInfo> info;
387 STDMETHODIMP ProtocolSinkWrap::QueryService(REFGUID service_guid, 165 BindContextInfo::FromBindContext(pbc, info.Receive());
388 REFIID riid, void** service) { 166 DCHECK(info);
389 ScopedComPtr<IServiceProvider> service_provider; 167 if (info && info->chrome_request())
390 HRESULT hr = service_provider.QueryFrom(delegate_); 168 return true;
391 if (service_provider) 169
392 hr = service_provider->QueryService(service_guid, riid, service); 170 return false;
393 return hr; 171 }
394 } 172
395 173 void PutProtData(IBindCtx* pbc, ProtData* data) {
396 // IAuthenticate 174 ScopedComPtr<BindContextInfo> info;
397 STDMETHODIMP ProtocolSinkWrap::Authenticate(HWND* window, 175 BindContextInfo::FromBindContext(pbc, info.Receive());
398 LPWSTR* user_name, LPWSTR* password) { 176 if (info)
399 ScopedComPtr<IAuthenticate> authenticate; 177 info->set_prot_data(data);
400 HRESULT hr = authenticate.QueryFrom(delegate_); 178 }
401 if (authenticate) 179
402 hr = authenticate->Authenticate(window, user_name, password); 180 bool IsTextHtml(const wchar_t* status_text) {
403 return hr; 181 if (!status_text)
404 } 182 return false;
405 183 size_t status_text_length = lstrlenW(status_text);
406 // IInternetProtocolEx 184 const wchar_t* status_text_end = status_text +
407 STDMETHODIMP ProtocolSinkWrap::Start(LPCWSTR url, 185 std::min(status_text_length, arraysize(kTextHtmlMimeType) - 1);
408 IInternetProtocolSink *protocol_sink, IInternetBindInfo* bind_info, 186 bool is_text_html = LowerCaseEqualsASCII(status_text, status_text_end,
409 DWORD flags, HANDLE_PTR reserved) { 187 kTextHtmlMimeType);
410 ScopedComPtr<IInternetProtocolRoot> protocol; 188 return is_text_html;
411 HRESULT hr = protocol.QueryFrom(delegate_); 189 }
412 if (protocol) 190
413 hr = protocol->Start(url, protocol_sink, bind_info, flags, reserved); 191 RendererType DetermineRendererType(void* buffer, DWORD size, bool last_chance) {
414 return hr; 192 RendererType type = UNDETERMINED;
415 } 193 if (last_chance)
416 194 type = OTHER;
417 STDMETHODIMP ProtocolSinkWrap::Continue(PROTOCOLDATA* protocol_data) { 195
418 ScopedComPtr<IInternetProtocolRoot> protocol; 196 std::wstring html_contents;
419 HRESULT hr = protocol.QueryFrom(delegate_); 197 // TODO(joshia): detect and handle different content encodings
420 if (protocol) 198 UTF8ToWide(reinterpret_cast<char*>(buffer), size, &html_contents);
421 hr = protocol->Continue(protocol_data); 199
422 return hr; 200 // Note that document_contents_ may have NULL characters in it. While
423 } 201 // browsers may handle this properly, we don't and will stop scanning
424 202 // for the XUACompat content value if we encounter one.
425 STDMETHODIMP ProtocolSinkWrap::Abort(HRESULT reason, DWORD options) { 203 std::wstring xua_compat_content;
426 ScopedComPtr<IInternetProtocolRoot> protocol; 204 UtilGetXUACompatContentValue(html_contents, &xua_compat_content);
427 HRESULT hr = protocol.QueryFrom(delegate_); 205 if (StrStrI(xua_compat_content.c_str(), kChromeContentPrefix)) {
428 if (protocol) 206 type = CHROME;
429 hr = protocol->Abort(reason, options); 207 }
430 return hr; 208
431 } 209 return type;
432 210 }
433 STDMETHODIMP ProtocolSinkWrap::Terminate(DWORD options) { 211
434 ScopedComPtr<IInternetProtocolRoot> protocol; 212 // ProtData
435 HRESULT hr = protocol.QueryFrom(delegate_); 213 ProtData::ProtData(IInternetProtocol* protocol,
436 if (protocol) 214 InternetProtocol_Read_Fn read_fun, const wchar_t* url)
437 hr = protocol->Terminate(options); 215 : has_suggested_mime_type_(false), has_server_mime_type_(false),
438 return hr; 216 report_data_received_(false), buffer_size_(0), buffer_pos_(0),
439 } 217 renderer_type_(UNDETERMINED), protocol_(protocol), read_fun_(read_fun),
440 218 url_(url) {
441 STDMETHODIMP ProtocolSinkWrap::Suspend() { 219 memset(buffer_, 0, arraysize(buffer_));
442 ScopedComPtr<IInternetProtocolRoot> protocol; 220 DLOG(INFO) << __FUNCTION__ << " " << this;
443 HRESULT hr = protocol.QueryFrom(delegate_); 221
444 if (protocol) 222 // Add to map.
445 hr = protocol->Suspend(); 223 AutoLock lock(datamap_lock_);
446 return hr; 224 DCHECK(datamap_.end() == datamap_.find(protocol_));
447 } 225 datamap_[protocol] = this;
448 226 }
449 STDMETHODIMP ProtocolSinkWrap::Resume() { 227
450 ScopedComPtr<IInternetProtocolRoot> protocol; 228 ProtData::~ProtData() {
451 HRESULT hr = protocol.QueryFrom(delegate_); 229 DLOG(INFO) << __FUNCTION__ << " " << this;
452 if (protocol) 230
453 hr = protocol->Resume(); 231 // Remove from map.
454 return hr; 232 AutoLock lock(datamap_lock_);
455 } 233 DCHECK(datamap_.end() != datamap_.find(protocol_));
456 234 datamap_.erase(protocol_);
457 STDMETHODIMP ProtocolSinkWrap::Read(void *buffer, ULONG size, 235 }
458 ULONG* size_read) { 236
459 ScopedComPtr<IInternetProtocol> protocol; 237 HRESULT ProtData::Read(void* buffer, ULONG size, ULONG* size_read) {
460 HRESULT hr = protocol.QueryFrom(delegate_); 238 if (renderer_type_ == UNDETERMINED) {
461 if (protocol) 239 return E_PENDING;
462 hr = protocol->Read(buffer, size, size_read); 240 }
463 return hr; 241
464 } 242 const ULONG bytes_available = buffer_size_ - buffer_pos_;
465 243 const ULONG bytes_to_copy = std::min(bytes_available, size);
466 STDMETHODIMP ProtocolSinkWrap::Seek(LARGE_INTEGER move, DWORD origin, 244 if (bytes_to_copy) {
467 ULARGE_INTEGER* new_pos) { 245 // Copy from the local buffer.
468 ScopedComPtr<IInternetProtocol> protocol; 246 memcpy(buffer, buffer_ + buffer_pos_, bytes_to_copy);
469 HRESULT hr = protocol.QueryFrom(delegate_); 247 *size_read = bytes_to_copy;
470 if (protocol) 248 buffer_pos_ += bytes_to_copy;
471 hr = protocol->Seek(move, origin, new_pos); 249
472 return hr; 250 HRESULT hr = S_OK;
473 } 251 ULONG new_data = 0;
474 252 if (size > bytes_available) {
475 STDMETHODIMP ProtocolSinkWrap::LockRequest(DWORD options) { 253 // User buffer is greater than what we have.
476 ScopedComPtr<IInternetProtocol> protocol; 254 buffer = reinterpret_cast<uint8*>(buffer) + bytes_to_copy;
477 HRESULT hr = protocol.QueryFrom(delegate_); 255 size -= bytes_to_copy;
478 if (protocol) 256 hr = read_fun_(protocol_, buffer, size, &new_data);
479 hr = protocol->LockRequest(options); 257 }
480 return hr; 258
481 } 259 if (size_read)
482 260 *size_read = bytes_to_copy + new_data;
483 STDMETHODIMP ProtocolSinkWrap::UnlockRequest() { 261 return hr;
484 ScopedComPtr<IInternetProtocol> protocol; 262 }
485 HRESULT hr = protocol.QueryFrom(delegate_); 263
486 if (protocol) 264 return read_fun_(protocol_, buffer, size, size_read);
487 hr = protocol->UnlockRequest(); 265 }
488 return hr; 266
489 } 267
490 268 HRESULT ProtData::ReportProgress(IInternetProtocolSink* delegate,
491 STDMETHODIMP ProtocolSinkWrap::StartEx(IUri* uri, 269 ULONG status_code, LPCWSTR status_text) {
492 IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, 270 switch (status_code) {
493 DWORD flags, HANDLE_PTR reserved) { 271 case BINDSTATUS_DIRECTBIND:
494 ScopedComPtr<IInternetProtocolEx> protocol; 272 renderer_type_ = OTHER;
495 HRESULT hr = protocol.QueryFrom(delegate_); 273 break;
496 if (protocol) 274
497 hr = protocol->StartEx(uri, protocol_sink, bind_info, flags, reserved); 275 case BINDSTATUS_REDIRECTING:
498 return hr; 276 url_.empty();
499 } 277 if (status_text)
500 278 url_ = status_text;
501 // IInternetPriority 279 break;
502 STDMETHODIMP ProtocolSinkWrap::SetPriority(LONG priority) { 280
503 ScopedComPtr<IInternetPriority> internet_priority; 281 case BINDSTATUS_SERVER_MIMETYPEAVAILABLE:
504 HRESULT hr = internet_priority.QueryFrom(delegate_); 282 has_server_mime_type_ = true;
505 if (internet_priority) 283 SaveSuggestedMimeType(status_text);
506 hr = internet_priority->SetPriority(priority); 284 return S_OK;
507 return hr; 285
508 } 286 // TODO(stoyan): BINDSTATUS_RAWMIMETYPE
509 287 case BINDSTATUS_MIMETYPEAVAILABLE:
510 STDMETHODIMP ProtocolSinkWrap::GetPriority(LONG* priority) { 288 case BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE:
511 ScopedComPtr<IInternetPriority> internet_priority; 289 SaveSuggestedMimeType(status_text);
512 HRESULT hr = internet_priority.QueryFrom(delegate_); 290 return S_OK;
513 if (internet_priority) 291 }
514 hr = internet_priority->GetPriority(priority); 292
515 return hr; 293 return delegate->ReportProgress(status_code, status_text);
516 } 294 }
517 295
518 // IWrappedProtocol 296 HRESULT ProtData::ReportData(IInternetProtocolSink* delegate,
519 STDMETHODIMP ProtocolSinkWrap::GetWrapperCode(LONG *code, DWORD_PTR reserved) { 297 DWORD flags, ULONG progress, ULONG max_progress) {
520 ScopedComPtr<IWrappedProtocol> wrapped_protocol; 298 if (renderer_type_ != UNDETERMINED) {
521 HRESULT hr = wrapped_protocol.QueryFrom(delegate_); 299 return delegate->ReportData(flags, progress, max_progress);
522 if (wrapped_protocol) 300 }
523 hr = wrapped_protocol->GetWrapperCode(code, reserved); 301
524 return hr; 302 // Do these checks only once.
525 } 303 if (!report_data_received_) {
526 304 report_data_received_ = true;
527 305
528 // public IUriContainer 306 DLOG_IF(INFO, (flags & BSCF_FIRSTDATANOTIFICATION) == 0) <<
529 STDMETHODIMP ProtocolSinkWrap::GetIUri(IUri** uri) { 307 "BUGBUG: BSCF_FIRSTDATANOTIFICATION is not set properly!";
530 ScopedComPtr<IUriContainer> uri_container; 308
531 HRESULT hr = uri_container.QueryFrom(delegate_); 309
532 if (uri_container) 310 // We check here, instead in ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE)
533 hr = uri_container->GetIUri(uri); 311 // to be safe when following multiple redirects.?
534 return hr; 312 if (!IsTextHtml(suggested_mime_type_)) {
535 } 313 renderer_type_ = OTHER;
536 314 FireSugestedMimeType(delegate);
537 // Protected helpers 315 return delegate->ReportData(flags, progress, max_progress);
538 316 }
539 void ProtocolSinkWrap::DetermineRendererType() { 317
540 if (is_undetermined()) { 318 if (!url_.empty() && IsOptInUrl(url_.c_str())) {
541 if (IsOptInUrl(url_.c_str())) { 319 // TODO(stoyan): We may attempt to remove ourselves from the bind context.
542 renderer_type_ = CHROME; 320 renderer_type_ = CHROME;
543 } else { 321 delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType);
544 std::wstring xua_compat_content; 322 return delegate->ReportData(flags, progress, max_progress);
545 // Note that document_contents_ may have NULL characters in it. While
546 // browsers may handle this properly, we don't and will stop scanning for
547 // the XUACompat content value if we encounter one.
548 DCHECK(buffer_size_ < arraysize(buffer_));
549 buffer_[buffer_size_] = 0;
550 std::wstring html_contents;
551 // TODO(joshia): detect and handle different content encodings
552 UTF8ToWide(buffer_, buffer_size_, &html_contents);
553 UtilGetXUACompatContentValue(html_contents, &xua_compat_content);
554 if (StrStrI(xua_compat_content.c_str(), kChromeContentPrefix)) {
555 renderer_type_ = CHROME;
556 } else {
557 renderer_type_ = OTHER;
558 }
559 } 323 }
560 } 324 }
561 } 325
562 326 HRESULT hr = FillBuffer();
563 HRESULT ProtocolSinkWrap::CheckAndReportChromeMimeTypeForRequest() { 327
564 if (!is_undetermined()) 328 bool last_chance = false;
329 if (hr == S_OK || hr == S_FALSE) {
330 last_chance = true;
331 }
332
333 renderer_type_ = DetermineRendererType(buffer_, buffer_size_, last_chance);
334
335 if (renderer_type_ == UNDETERMINED) {
336 // do not report anything, we need more data.
565 return S_OK; 337 return S_OK;
566 338 }
567 // This function could get invoked recursively in the context of 339
568 // IInternetProtocol::Read. Check for the same and bail. 340 if (renderer_type_ == CHROME) {
569 if (determining_renderer_type_) 341 DLOG(INFO) << "Forwarding BINDSTATUS_MIMETYPEAVAILABLE "
570 return S_OK; 342 << kChromeMimeType;
571 343 delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType);
572 determining_renderer_type_ = true; 344 }
573 345
346 if (renderer_type_ == OTHER) {
347 FireSugestedMimeType(delegate);
348 }
349
350 // This is the first data notification we forward.
351 flags |= BSCF_FIRSTDATANOTIFICATION;
352
353 if (hr == S_FALSE) {
354 flags |= (BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE);
355 }
356
357 return delegate->ReportData(flags, progress, max_progress);
358 }
359
360 void ProtData::UpdateUrl(const wchar_t* url) {
361 url_ = url;
362 }
363
364 // S_FALSE - EOF
365 // S_OK - buffer fully filled
366 // E_PENDING - some data added to buffer, but buffer is not yet full
367 // E_XXXX - some other error.
368 HRESULT ProtData::FillBuffer() {
574 HRESULT hr_read = S_OK; 369 HRESULT hr_read = S_OK;
575 while (hr_read == S_OK) { 370
371 while ((hr_read == S_OK) && (buffer_size_ < kMaxContentSniffLength)) {
576 ULONG size_read = 0; 372 ULONG size_read = 0;
577 hr_read = protocol_->Read(buffer_ + buffer_size_, 373 hr_read = read_fun_(protocol_, buffer_ + buffer_size_,
578 kMaxContentSniffLength - buffer_size_, &size_read); 374 kMaxContentSniffLength - buffer_size_, &size_read);
579 buffer_size_ += size_read; 375 buffer_size_ += size_read;
580 376 }
581 // Attempt to determine the renderer type if we have received 377
582 // sufficient data. Do not attempt this when we are called recursively.
583 if (report_data_recursiveness_ < 2 && (S_FALSE == hr_read) ||
584 (buffer_size_ >= kMaxContentSniffLength)) {
585 DetermineRendererType();
586 if (renderer_type() == CHROME) {
587 // Workaround for IE 8 and "nosniff". See:
588 // http://blogs.msdn.com/ie/archive/2008/09/02/ie8-security-part-vi-beta -2-update.aspx
589 delegate_->ReportProgress(
590 BINDSTATUS_SERVER_MIMETYPEAVAILABLE, kChromeMimeType);
591 // For IE < 8.
592 delegate_->ReportProgress(
593 BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType);
594
595 delegate_->ReportProgress(
596 BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, kChromeMimeType);
597
598 delegate_->ReportData(
599 BSCF_FIRSTDATANOTIFICATION, 0, 0);
600
601 delegate_->ReportData(
602 BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, 0, 0);
603 }
604 break;
605 }
606 }
607
608 determining_renderer_type_ = false;
609 return hr_read; 378 return hr_read;
610 } 379 }
611 380
612 HRESULT ProtocolSinkWrap::OnReadImpl(void* buffer, ULONG size, ULONG* size_read, 381 void ProtData::SaveSuggestedMimeType(LPCWSTR status_text) {
613 InternetProtocol_Read_Fn orig_read) { 382 has_suggested_mime_type_ = true;
614 // We want to switch the renderer to chrome, we cannot return any 383 suggested_mime_type_.Allocate(status_text);
615 // data now. 384 }
616 if (CHROME == renderer_type()) 385
617 return S_FALSE; 386 void ProtData::FireSugestedMimeType(IInternetProtocolSink* delegate) {
618 387 if (has_server_mime_type_) {
619 // Serve data from our buffer first. 388 DLOG(INFO) << "Forwarding BINDSTATUS_SERVER_MIMETYPEAVAILABLE "
620 if (OTHER == renderer_type()) { 389 << suggested_mime_type_;
621 const ULONG bytes_to_copy = std::min(buffer_size_ - buffer_pos_, size); 390 delegate->ReportProgress(BINDSTATUS_SERVER_MIMETYPEAVAILABLE,
622 if (bytes_to_copy) { 391 suggested_mime_type_);
623 memcpy(buffer, buffer_ + buffer_pos_, bytes_to_copy); 392 return;
624 *size_read = bytes_to_copy; 393 }
625 buffer_pos_ += bytes_to_copy; 394
626 return S_OK; 395 if (has_suggested_mime_type_) {
627 } 396 DLOG(INFO) << "Forwarding BINDSTATUS_MIMETYPEAVAILABLE "
628 } 397 << suggested_mime_type_;
629 398 delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE,
630 return orig_read(protocol_, buffer, size, size_read); 399 suggested_mime_type_);
631 } 400 }
632 401 }
633 scoped_refptr<ProtocolSinkWrap> ProtocolSinkWrap::InstanceFromProtocol( 402
403 scoped_refptr<ProtData> ProtData::DataFromProtocol(
634 IInternetProtocol* protocol) { 404 IInternetProtocol* protocol) {
635 CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); 405 scoped_refptr<ProtData> instance;
636 scoped_refptr<ProtocolSinkWrap> instance; 406 AutoLock lock(datamap_lock_);
637 ProtocolSinkMap::iterator it = sink_map_.find(protocol); 407 ProtocolDataMap::iterator it = datamap_.find(protocol);
638 if (sink_map_.end() != it) 408 if (datamap_.end() != it)
639 instance = it->second; 409 instance = it->second;
640 return instance; 410 return instance;
641 } 411 }
642 412
643 ScopedComPtr<IInternetProtocolSink> ProtocolSinkWrap::MaybeWrapSink( 413 // IInternetProtocol/Ex hooks.
644 IInternetProtocol* protocol, IInternetProtocolSink* prot_sink, 414 STDMETHODIMP Hook_Start(InternetProtocol_Start_Fn orig_start,
645 const wchar_t* url) { 415 IInternetProtocol* protocol, LPCWSTR url, IInternetProtocolSink* prot_sink,
646 ScopedComPtr<IInternetProtocolSink> sink_to_use(prot_sink); 416 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
647 417 DCHECK(orig_start);
648 // FYI: GUID_NULL doesn't work when the URL is being loaded from history. 418 if (!url || !prot_sink || !bind_info)
649 // asking for IID_IHttpNegotiate as the service id works, but 419 return E_INVALIDARG;
650 // getting the IWebBrowser2 interface still doesn't work. 420 DLOG_IF(INFO, url != NULL) << "OnStart: " << url << PiFlags2Str(flags);
651 ScopedComPtr<IHttpNegotiate> http_negotiate; 421
652 HRESULT hr = DoQueryService(GUID_NULL, prot_sink, http_negotiate.Receive()); 422 ScopedComPtr<IBindCtx> bind_ctx = BindCtxFromIBindInfo(bind_info);
653 423 if (!bind_ctx) {
654 if (http_negotiate && !IsSubFrameRequest(http_negotiate)) { 424 // MSHTML sometimes takes a short path, skips the creation of
655 CComObject<ProtocolSinkWrap>* wrap = NULL; 425 // moniker and binding, by directly grabbing protocol from InternetSession
656 CComObject<ProtocolSinkWrap>::CreateInstance(&wrap); 426 DLOG(INFO) << "DirectBind for " << url;
657 DCHECK(wrap); 427 return orig_start(protocol, url, prot_sink, bind_info, flags, reserved);
658 if (wrap) { 428 }
659 wrap->AddRef(); 429
660 if (wrap->Initialize(protocol, prot_sink, url)) { 430 if (IsCFRequest(bind_ctx)) {
661 sink_to_use = wrap; 431 return orig_start(protocol, url, prot_sink, bind_info, flags, reserved);
662 } 432 }
663 wrap->Release(); 433
664 } 434 scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol);
665 } 435 if (prot_data) {
666 436 DLOG(INFO) << "Found existing ProtData!";
667 return sink_to_use; 437 prot_data->UpdateUrl(url);
668 } 438 ScopedComPtr<IInternetProtocolSink> new_sink =
439 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data);
440 return orig_start(protocol, url, new_sink, bind_info, flags, reserved);
441 }
442
443 if (!ShouldWrapSink(prot_sink, url)) {
444 return orig_start(protocol, url, prot_sink, bind_info, flags, reserved);
445 }
446
447 // Fresh request.
448 InternetProtocol_Read_Fn read_fun = reinterpret_cast<InternetProtocol_Read_Fn>
449 (CTransaction_PatchInfo[1].stub_->argument());
450 prot_data = new ProtData(protocol, read_fun, url);
451 PutProtData(bind_ctx, prot_data);
452
453 ScopedComPtr<IInternetProtocolSink> new_sink =
454 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data);
455 return orig_start(protocol, url, new_sink, bind_info, flags, reserved);
456 }
457
458 STDMETHODIMP Hook_StartEx(InternetProtocol_StartEx_Fn orig_start_ex,
459 IInternetProtocolEx* protocol, IUri* uri, IInternetProtocolSink* prot_sink,
460 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
461 DCHECK(orig_start_ex);
462 if (!uri || !prot_sink || !bind_info)
463 return E_INVALIDARG;
464
465 ScopedBstr url;
466 uri->GetPropertyBSTR(Uri_PROPERTY_ABSOLUTE_URI, url.Receive(), 0);
467 DLOG_IF(INFO, url != NULL) << "OnStartEx: " << url << PiFlags2Str(flags);
468
469 ScopedComPtr<IBindCtx> bind_ctx = BindCtxFromIBindInfo(bind_info);
470 if (!bind_ctx) {
471 // MSHTML sometimes takes a short path, skips the creation of
472 // moniker and binding, by directly grabbing protocol from InternetSession.
473 DLOG(INFO) << "DirectBind for " << url;
474 return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved);
475 }
476
477 if (IsCFRequest(bind_ctx)) {
478 return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved);
479 }
480
481 scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol);
482 if (prot_data) {
483 DLOG(INFO) << "Found existing ProtData!";
484 prot_data->UpdateUrl(url);
485 ScopedComPtr<IInternetProtocolSink> new_sink =
486 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data);
487 return orig_start_ex(protocol, uri, new_sink, bind_info, flags, reserved);
488 }
489
490 if (!ShouldWrapSink(prot_sink, url)) {
491 return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved);
492 }
493
494 // Fresh request.
495 InternetProtocol_Read_Fn read_fun = reinterpret_cast<InternetProtocol_Read_Fn>
496 (CTransaction_PatchInfo[1].stub_->argument());
497 prot_data = new ProtData(protocol, read_fun, url);
498 PutProtData(bind_ctx, prot_data);
499
500 ScopedComPtr<IInternetProtocolSink> new_sink =
501 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data);
502 return orig_start_ex(protocol, uri, new_sink, bind_info, flags, reserved);
503 }
504
505 STDMETHODIMP Hook_Read(InternetProtocol_Read_Fn orig_read,
506 IInternetProtocol* protocol, void* buffer, ULONG size, ULONG* size_read) {
507 DCHECK(orig_read);
508 scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol);
509 if (!prot_data) {
510 return orig_read(protocol, buffer, size, size_read);
511 }
512
513 HRESULT hr = prot_data->Read(buffer, size, size_read);
514 return hr;
515 }
516
517 // Patching / Hooking code.
518 class FakeProtocol : public CComObjectRootEx<CComSingleThreadModel>,
519 public IInternetProtocol {
520 public:
521 BEGIN_COM_MAP(FakeProtocol)
522 COM_INTERFACE_ENTRY(IInternetProtocol)
523 COM_INTERFACE_ENTRY(IInternetProtocolRoot)
524 END_COM_MAP()
525
526 STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink *protocol_sink,
527 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) {
528 transaction_.QueryFrom(protocol_sink);
529 // Return some unusual error code.
530 return INET_E_INVALID_CERTIFICATE;
531 }
532
533 STDMETHOD(Continue)(PROTOCOLDATA* protocol_data) { return S_OK; }
534 STDMETHOD(Abort)(HRESULT reason, DWORD options) { return S_OK; }
535 STDMETHOD(Terminate)(DWORD options) { return S_OK; }
536 STDMETHOD(Suspend)() { return S_OK; }
537 STDMETHOD(Resume)() { return S_OK; }
538 STDMETHOD(Read)(void *buffer, ULONG size, ULONG* size_read) { return S_OK; }
539 STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_pos)
540 { return S_OK; }
541 STDMETHOD(LockRequest)(DWORD options) { return S_OK; }
542 STDMETHOD(UnlockRequest)() { return S_OK; }
543
544 ScopedComPtr<IInternetProtocol> transaction_;
545 };
546
547 struct FakeFactory : public IClassFactory,
548 public CComObjectRootEx<CComSingleThreadModel> {
549 BEGIN_COM_MAP(FakeFactory)
550 COM_INTERFACE_ENTRY(IClassFactory)
551 END_COM_MAP()
552
553 STDMETHOD(CreateInstance)(IUnknown *pUnkOuter, REFIID riid, void **ppvObj) {
554 if (pUnkOuter)
555 return CLASS_E_NOAGGREGATION;
556 HRESULT hr = obj_->QueryInterface(riid, ppvObj);
557 return hr;
558 }
559
560 STDMETHOD(LockServer)(BOOL fLock) {
561 return S_OK;
562 }
563
564 IUnknown* obj_;
565 };
566
567 static void HookTransactionVtable(IInternetProtocol* p) {
568 ScopedComPtr<IInternetProtocolEx> ex;
569 ex.QueryFrom(p);
570
571 HRESULT hr = vtable_patch::PatchInterfaceMethods(p, CTransaction_PatchInfo);
572 if (hr == S_OK && ex) {
573 vtable_patch::PatchInterfaceMethods(ex.get(), CTransaction2_PatchInfo);
574 }
575 }
576
577 void TransactionHooks::InstallHooks() {
578 if (IS_PATCHED(CTransaction)) {
579 DLOG(WARNING) << __FUNCTION__ << " called more than once.";
580 return;
581 }
582
583 CComObjectStackEx<FakeProtocol> prot;
584 CComObjectStackEx<FakeFactory> factory;
585 factory.obj_ = &prot;
586 ScopedComPtr<IInternetSession> session;
587 HRESULT hr = ::CoInternetGetSession(0, session.Receive(), 0);
588 hr = session->RegisterNameSpace(&factory, CLSID_NULL, L"611", 0, 0, 0);
589 DLOG_IF(FATAL, FAILED(hr)) << "Failed to register namespace";
590 if (hr != S_OK)
591 return;
592
593 do {
594 ScopedComPtr<IMoniker> mk;
595 ScopedComPtr<IBindCtx> bc;
596 ScopedComPtr<IStream> stream;
597 hr = ::CreateAsyncBindCtxEx(0, 0, 0, 0, bc.Receive(), 0);
598 DLOG_IF(FATAL, FAILED(hr)) << "CreateAsyncBindCtxEx failed " << hr;
599 if (hr != S_OK)
600 break;
601
602 hr = ::CreateURLMoniker(NULL, L"611://512", mk.Receive());
603 DLOG_IF(FATAL, FAILED(hr)) << "CreateURLMoniker failed " << hr;
604 if (hr != S_OK)
605 break;
606
607 hr = mk->BindToStorage(bc, NULL, IID_IStream,
608 reinterpret_cast<void**>(stream.Receive()));
609 DLOG_IF(FATAL, hr != INET_E_INVALID_CERTIFICATE) <<
610 "BindToStorage failed " << hr;
611 } while (0);
612
613 hr = session->UnregisterNameSpace(&factory, L"611");
614 if (prot.transaction_) {
615 HookTransactionVtable(prot.transaction_);
616 // Explicit release, otherwise ~CComObjectStackEx will complain about
617 // outstanding reference to us, because it runs before ~FakeProtocol
618 prot.transaction_.Release();
619 }
620 }
621
622 void TransactionHooks::RevertHooks() {
623 vtable_patch::UnpatchInterfaceMethods(CTransaction_PatchInfo);
624 vtable_patch::UnpatchInterfaceMethods(CTransaction2_PatchInfo);
625 }
OLDNEW
« no previous file with comments | « chrome_frame/protocol_sink_wrap.h ('k') | chrome_frame/test/test_mock_with_web_server.cc » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698