OLD | NEW |
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include <htiframe.h> | 5 #include <htiframe.h> |
6 #include <mshtml.h> | 6 #include <mshtml.h> |
7 | 7 |
8 #include "chrome_frame/protocol_sink_wrap.h" | 8 #include "chrome_frame/protocol_sink_wrap.h" |
9 | 9 |
10 #include "base/logging.h" | 10 #include "base/logging.h" |
11 #include "base/registry.h" | 11 #include "base/registry.h" |
12 #include "base/scoped_bstr_win.h" | 12 #include "base/scoped_bstr_win.h" |
13 #include "base/singleton.h" | 13 #include "base/singleton.h" |
14 #include "base/string_util.h" | 14 #include "base/string_util.h" |
15 | 15 |
| 16 #include "chrome_frame/bind_context_info.h" |
| 17 #include "chrome_frame/function_stub.h" |
16 #include "chrome_frame/utils.h" | 18 #include "chrome_frame/utils.h" |
17 | 19 |
18 // BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so | 20 // BINDSTATUS_SERVER_MIMETYPEAVAILABLE == 54. Introduced in IE 8, so |
19 // not in everyone's headers yet. See: | 21 // not in everyone's headers yet. See: |
20 // http://msdn.microsoft.com/en-us/library/ms775133(VS.85,loband).aspx | 22 // http://msdn.microsoft.com/en-us/library/ms775133(VS.85,loband).aspx |
21 #ifndef BINDSTATUS_SERVER_MIMETYPEAVAILABLE | 23 #ifndef BINDSTATUS_SERVER_MIMETYPEAVAILABLE |
22 #define BINDSTATUS_SERVER_MIMETYPEAVAILABLE 54 | 24 #define BINDSTATUS_SERVER_MIMETYPEAVAILABLE 54 |
23 #endif | 25 #endif |
24 | 26 |
25 static const char kTextHtmlMimeType[] = "text/html"; | 27 static const char kTextHtmlMimeType[] = "text/html"; |
26 const wchar_t kUrlMonDllName[] = L"urlmon.dll"; | 28 const wchar_t kUrlMonDllName[] = L"urlmon.dll"; |
27 | 29 |
28 static const int kInternetProtocolStartIndex = 3; | 30 static const int kInternetProtocolStartIndex = 3; |
29 static const int kInternetProtocolReadIndex = 9; | 31 static const int kInternetProtocolReadIndex = 9; |
30 static const int kInternetProtocolStartExIndex = 13; | 32 static const int kInternetProtocolStartExIndex = 13; |
31 | 33 |
32 // TODO(ananta) | 34 |
33 // We should avoid duplicate VTable declarations. | 35 // IInternetProtocol/Ex patches. |
34 BEGIN_VTABLE_PATCHES(IInternetProtocol) | 36 STDMETHODIMP Hook_Start(InternetProtocol_Start_Fn orig_start, |
35 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) | 37 IInternetProtocol* protocol, |
36 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) | 38 LPCWSTR url, |
| 39 IInternetProtocolSink* prot_sink, |
| 40 IInternetBindInfo* bind_info, |
| 41 DWORD flags, |
| 42 HANDLE_PTR reserved); |
| 43 |
| 44 STDMETHODIMP Hook_StartEx(InternetProtocol_StartEx_Fn orig_start_ex, |
| 45 IInternetProtocolEx* protocol, |
| 46 IUri* uri, |
| 47 IInternetProtocolSink* prot_sink, |
| 48 IInternetBindInfo* bind_info, |
| 49 DWORD flags, |
| 50 HANDLE_PTR reserved); |
| 51 |
| 52 STDMETHODIMP Hook_Read(InternetProtocol_Read_Fn orig_read, |
| 53 IInternetProtocol* protocol, |
| 54 void* buffer, |
| 55 ULONG size, |
| 56 ULONG* size_read); |
| 57 |
| 58 ///////////////////////////////////////////////////////////////////////////// |
| 59 BEGIN_VTABLE_PATCHES(CTransaction) |
| 60 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, Hook_Start) |
| 61 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, Hook_Read) |
37 END_VTABLE_PATCHES() | 62 END_VTABLE_PATCHES() |
38 | 63 |
39 BEGIN_VTABLE_PATCHES(IInternetProtocolSecure) | 64 BEGIN_VTABLE_PATCHES(CTransaction2) |
40 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) | 65 VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, Hook_StartEx) |
41 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) | |
42 END_VTABLE_PATCHES() | |
43 | |
44 BEGIN_VTABLE_PATCHES(IInternetProtocolEx) | |
45 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) | |
46 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) | |
47 VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx) | |
48 END_VTABLE_PATCHES() | |
49 | |
50 BEGIN_VTABLE_PATCHES(IInternetProtocolExSecure) | |
51 VTABLE_PATCH_ENTRY(kInternetProtocolStartIndex, ProtocolSinkWrap::OnStart) | |
52 VTABLE_PATCH_ENTRY(kInternetProtocolReadIndex, ProtocolSinkWrap::OnRead) | |
53 VTABLE_PATCH_ENTRY(kInternetProtocolStartExIndex, ProtocolSinkWrap::OnStartEx) | |
54 END_VTABLE_PATCHES() | 66 END_VTABLE_PATCHES() |
55 | 67 |
56 // | 68 // |
57 // ProtocolSinkWrap implementation | 69 // ProtocolSinkWrap implementation |
58 // | |
59 | 70 |
60 // Static map initialization | 71 // Static map initialization |
61 ProtocolSinkWrap::ProtocolSinkMap ProtocolSinkWrap::sink_map_; | 72 ProtData::ProtocolDataMap ProtData::datamap_; |
62 CComAutoCriticalSection ProtocolSinkWrap::sink_map_lock_; | 73 Lock ProtData::datamap_lock_; |
63 | 74 |
64 ProtocolSinkWrap::ProtocolSinkWrap() | 75 ProtocolSinkWrap::ProtocolSinkWrap() { |
65 : protocol_(NULL), renderer_type_(UNDETERMINED), | 76 DLOG(INFO) << __FUNCTION__ << StringPrintf(" 0x%08X", this); |
66 buffer_size_(0), buffer_pos_(0), is_saved_result_(false), | |
67 result_code_(0), result_error_(0), report_data_recursiveness_(0), | |
68 determining_renderer_type_(false) { | |
69 memset(buffer_, 0, arraysize(buffer_)); | |
70 } | 77 } |
71 | 78 |
72 ProtocolSinkWrap::~ProtocolSinkWrap() { | 79 ProtocolSinkWrap::~ProtocolSinkWrap() { |
73 // This object may be destroyed before Initialize is called. | 80 DLOG(INFO) << __FUNCTION__ << StringPrintf(" 0x%08X", this); |
74 if (protocol_ != NULL) { | 81 } |
75 CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); | 82 |
76 DCHECK(sink_map_.end() != sink_map_.find(protocol_)); | 83 ScopedComPtr<IInternetProtocolSink> ProtocolSinkWrap::CreateNewSink( |
77 sink_map_.erase(protocol_); | 84 IInternetProtocolSink* sink, ProtData* data) { |
78 protocol_ = NULL; | 85 DCHECK(sink != NULL); |
79 } | 86 DCHECK(data != NULL); |
80 DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size(); | 87 CComObject<ProtocolSinkWrap>* new_sink = NULL; |
81 } | 88 CComObject<ProtocolSinkWrap>::CreateInstance(&new_sink); |
82 | 89 new_sink->delegate_ = sink; |
83 bool ProtocolSinkWrap::PatchProtocolHandlers() { | 90 new_sink->prot_data_ = data; |
84 HRESULT hr = PatchProtocolMethods(CLSID_HttpProtocol, | 91 return ScopedComPtr<IInternetProtocolSink>(new_sink); |
85 IInternetProtocol_PatchInfo, | |
86 IInternetProtocolEx_PatchInfo); | |
87 if (FAILED(hr)) { | |
88 NOTREACHED() << "Failed to patch IInternetProtocol interface." | |
89 << " Error: " << hr; | |
90 return false; | |
91 } | |
92 | |
93 hr = PatchProtocolMethods(CLSID_HttpSProtocol, | |
94 IInternetProtocolSecure_PatchInfo, | |
95 IInternetProtocolExSecure_PatchInfo); | |
96 if (FAILED(hr)) { | |
97 NOTREACHED() << "Failed to patch IInternetProtocol secure interface." | |
98 << " Error: " << hr; | |
99 return false; | |
100 } | |
101 | |
102 return true; | |
103 } | |
104 | |
105 void ProtocolSinkWrap::UnpatchProtocolHandlers() { | |
106 vtable_patch::UnpatchInterfaceMethods(IInternetProtocol_PatchInfo); | |
107 vtable_patch::UnpatchInterfaceMethods(IInternetProtocolEx_PatchInfo); | |
108 vtable_patch::UnpatchInterfaceMethods(IInternetProtocolSecure_PatchInfo); | |
109 vtable_patch::UnpatchInterfaceMethods(IInternetProtocolExSecure_PatchInfo); | |
110 } | |
111 | |
112 HRESULT ProtocolSinkWrap::CreateProtocolHandlerInstance( | |
113 const CLSID& clsid, IInternetProtocol** protocol) { | |
114 if (!protocol) { | |
115 return E_INVALIDARG; | |
116 } | |
117 | |
118 HMODULE module = ::GetModuleHandle(kUrlMonDllName); | |
119 if (!module) { | |
120 NOTREACHED() << "urlmon is not yet loaded. Error: " << GetLastError(); | |
121 return E_FAIL; | |
122 } | |
123 | |
124 typedef HRESULT (WINAPI* DllGetClassObject_Fn)(REFCLSID, REFIID, LPVOID*); | |
125 DllGetClassObject_Fn fn = reinterpret_cast<DllGetClassObject_Fn>( | |
126 ::GetProcAddress(module, "DllGetClassObject")); | |
127 if (!fn) { | |
128 NOTREACHED() << "DllGetClassObject not found in urlmon.dll"; | |
129 return E_FAIL; | |
130 } | |
131 | |
132 ScopedComPtr<IClassFactory> protocol_class_factory; | |
133 HRESULT hr = fn(clsid, IID_IClassFactory, | |
134 reinterpret_cast<LPVOID*>(protocol_class_factory.Receive())); | |
135 if (FAILED(hr)) { | |
136 NOTREACHED() << "DllGetclassObject failed. Error: " << hr; | |
137 return hr; | |
138 } | |
139 | |
140 ScopedComPtr<IInternetProtocol> handler_instance; | |
141 hr = protocol_class_factory->CreateInstance(NULL, IID_IInternetProtocol, | |
142 reinterpret_cast<void**>(handler_instance.Receive())); | |
143 if (FAILED(hr)) { | |
144 NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol." | |
145 << " Error: " << hr; | |
146 } else { | |
147 *protocol = handler_instance.Detach(); | |
148 } | |
149 | |
150 return hr; | |
151 } | |
152 | |
153 HRESULT ProtocolSinkWrap::PatchProtocolMethods( | |
154 const CLSID& clsid_protocol, | |
155 vtable_patch::MethodPatchInfo* protocol_patch_info, | |
156 vtable_patch::MethodPatchInfo* protocol_ex_patch_info) { | |
157 if (!protocol_patch_info || !protocol_ex_patch_info) { | |
158 return E_INVALIDARG; | |
159 } | |
160 | |
161 ScopedComPtr<IInternetProtocol> http_protocol; | |
162 HRESULT hr = CreateProtocolHandlerInstance(clsid_protocol, | |
163 http_protocol.Receive()); | |
164 if (FAILED(hr)) { | |
165 NOTREACHED() << "ClassFactory::CreateInstance failed for InternetProtocol." | |
166 << " Error: " << hr; | |
167 return false; | |
168 } | |
169 | |
170 ScopedComPtr<IInternetProtocolEx> ipex; | |
171 ipex.QueryFrom(http_protocol); | |
172 if (ipex) { | |
173 hr = vtable_patch::PatchInterfaceMethods(ipex, protocol_ex_patch_info); | |
174 } else { | |
175 hr = vtable_patch::PatchInterfaceMethods(http_protocol, | |
176 protocol_patch_info); | |
177 } | |
178 return hr; | |
179 } | |
180 | |
181 // IInternetProtocol/Ex method implementation. | |
182 HRESULT ProtocolSinkWrap::OnStart(InternetProtocol_Start_Fn orig_start, | |
183 IInternetProtocol* protocol, LPCWSTR url, IInternetProtocolSink* prot_sink, | |
184 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { | |
185 DCHECK(orig_start); | |
186 DLOG_IF(INFO, url != NULL) << "OnStart: " << url; | |
187 | |
188 ScopedComPtr<IInternetProtocolSink> sink_to_use(MaybeWrapSink(protocol, | |
189 prot_sink, url)); | |
190 return orig_start(protocol, url, sink_to_use, bind_info, flags, reserved); | |
191 } | |
192 | |
193 HRESULT ProtocolSinkWrap::OnStartEx(InternetProtocol_StartEx_Fn orig_start_ex, | |
194 IInternetProtocolEx* protocol, IUri* uri, IInternetProtocolSink* prot_sink, | |
195 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { | |
196 DCHECK(orig_start_ex); | |
197 | |
198 ScopedBstr url; | |
199 uri->GetPropertyBSTR(Uri_PROPERTY_ABSOLUTE_URI, url.Receive(), 0); | |
200 DLOG_IF(INFO, url != NULL) << "OnStartEx: " << url; | |
201 | |
202 ScopedComPtr<IInternetProtocolSink> sink_to_use(MaybeWrapSink(protocol, | |
203 prot_sink, url)); | |
204 return orig_start_ex(protocol, uri, sink_to_use, bind_info, flags, reserved); | |
205 } | |
206 | |
207 HRESULT ProtocolSinkWrap::OnRead(InternetProtocol_Read_Fn orig_read, | |
208 IInternetProtocol* protocol, void* buffer, ULONG size, ULONG* size_read) { | |
209 DCHECK(orig_read); | |
210 | |
211 scoped_refptr<ProtocolSinkWrap> instance = | |
212 ProtocolSinkWrap::InstanceFromProtocol(protocol); | |
213 HRESULT hr; | |
214 if (instance) { | |
215 DCHECK(instance->protocol_ == protocol); | |
216 hr = instance->OnReadImpl(buffer, size, size_read, orig_read); | |
217 } else { | |
218 hr = orig_read(protocol, buffer, size, size_read); | |
219 } | |
220 | |
221 return hr; | |
222 } | |
223 | |
224 bool ProtocolSinkWrap::Initialize(IInternetProtocol* protocol, | |
225 IInternetProtocolSink* original_sink, const wchar_t* url) { | |
226 DCHECK(original_sink); | |
227 delegate_ = original_sink; | |
228 protocol_ = protocol; | |
229 if (url) | |
230 url_ = url; | |
231 | |
232 CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); | |
233 DCHECK(sink_map_.end() == sink_map_.find(protocol)); | |
234 sink_map_[protocol] = this; | |
235 DLOG(INFO) << "ProtocolSinkWrap: active sinks: " << sink_map_.size(); | |
236 return true; | |
237 } | 92 } |
238 | 93 |
239 // IInternetProtocolSink methods | 94 // IInternetProtocolSink methods |
240 STDMETHODIMP ProtocolSinkWrap::Switch(PROTOCOLDATA* protocol_data) { | 95 STDMETHODIMP ProtocolSinkWrap::Switch(PROTOCOLDATA* protocol_data) { |
241 HRESULT hr = E_FAIL; | 96 HRESULT hr = E_FAIL; |
242 if (delegate_) | 97 if (delegate_) |
243 hr = delegate_->Switch(protocol_data); | 98 hr = delegate_->Switch(protocol_data); |
244 return hr; | 99 return hr; |
245 } | 100 } |
246 | 101 |
247 STDMETHODIMP ProtocolSinkWrap::ReportProgress(ULONG status_code, | 102 STDMETHODIMP ProtocolSinkWrap::ReportProgress(ULONG status_code, |
248 LPCWSTR status_text) { | 103 LPCWSTR status_text) { |
249 DLOG(INFO) << "ProtocolSinkWrap::ReportProgress: Code:" << status_code << | 104 DLOG(INFO) << "ProtocolSinkWrap::ReportProgress: " |
250 " Text: " << (status_text ? status_text : L""); | 105 << BindStatus2Str(status_code) |
251 if (!delegate_) { | 106 << " Status: " << (status_text ? status_text : L""); |
252 return E_FAIL; | 107 |
253 } | 108 HRESULT hr = prot_data_->ReportProgress(delegate_, status_code, status_text); |
254 if ((BINDSTATUS_MIMETYPEAVAILABLE == status_code) || | |
255 (BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE == status_code)) { | |
256 // If we have a MIMETYPE and that MIMETYPE is not "text/html". we don't | |
257 // want to do anything with this. | |
258 if (status_text) { | |
259 size_t status_text_length = lstrlenW(status_text); | |
260 const wchar_t* status_text_end = status_text + std::min( | |
261 status_text_length, arraysize(kTextHtmlMimeType) - 1); | |
262 if (!LowerCaseEqualsASCII(status_text, status_text_end, | |
263 kTextHtmlMimeType)) { | |
264 renderer_type_ = OTHER; | |
265 } | |
266 } | |
267 } | |
268 | |
269 HRESULT hr = S_OK; | |
270 if (delegate_ && renderer_type_ != CHROME) { | |
271 hr = delegate_->ReportProgress(status_code, status_text); | |
272 } | |
273 return hr; | 109 return hr; |
274 } | 110 } |
275 | 111 |
276 STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress, | 112 STDMETHODIMP ProtocolSinkWrap::ReportData(DWORD flags, ULONG progress, |
277 ULONG max_progress) { | 113 ULONG max_progress) { |
278 DCHECK(protocol_); | |
279 DCHECK(delegate_); | 114 DCHECK(delegate_); |
280 DLOG(INFO) << "ProtocolSinkWrap::ReportData: flags: " << flags << | 115 DLOG(INFO) << "ProtocolSinkWrap::ReportData: " << Bscf2Str(flags) << |
281 " progress: " << progress << " progress_max: " << max_progress; | 116 " progress: " << progress << " progress_max: " << max_progress; |
282 | 117 |
283 scoped_refptr<ProtocolSinkWrap> self_ref(this); | 118 HRESULT hr = prot_data_->ReportData(delegate_, flags, progress, max_progress); |
284 | |
285 // Maintain a stack depth to make a determination. ReportData is called | |
286 // recursively in IE8. If the request can be served in a single Read, the | |
287 // situation ends up like this: | |
288 // orig_prot | |
289 // |--> ProtocolSinkWrap::ReportData (BSCF_FIRSTDATANOTIFICATION) | |
290 // |--> orig_prot->Read(...) - 1st read - S_OK and data | |
291 // |--> ProtocolSinkWrap::ReportData (BSCF_LASTDATANOTIFICATION) | |
292 // |--> orig_prot->Read(...) - 2nd read S_FALSE, 0 bytes | |
293 // | |
294 // Inner call returns S_FALSE and no data. We try to make a determination | |
295 // of render type then and incorrectly set it to 'OTHER' as we don't have | |
296 // any data yet. However, we can make a determination in the context of | |
297 // outer ReportData since the first read will return S_OK with data. Then | |
298 // the next Read in the loop will return S_FALSE and we will enter the | |
299 // determination logic. | |
300 | |
301 // NOTE: We use the report_data_recursiveness_ variable to detect situations | |
302 // in which calls to ReportData are re-entrant (such as when the entire | |
303 // contents of a page fit inside a single packet). In these cases, we | |
304 // don't care about re-entrant calls beyond the second, and so we compare | |
305 // report_data_recursiveness_ inside the while loop, making sure we skip | |
306 // what would otherwise be spurious calls to ReportProgress(). | |
307 report_data_recursiveness_++; | |
308 | |
309 HRESULT hr = S_OK; | |
310 if (is_undetermined()) { | |
311 CheckAndReportChromeMimeTypeForRequest(); | |
312 } | |
313 | |
314 // we call original only if the renderer type is other | |
315 if (renderer_type() == OTHER) { | |
316 hr = delegate_->ReportData(flags, progress, max_progress); | |
317 | |
318 if (is_saved_result_) { | |
319 is_saved_result_ = false; | |
320 delegate_->ReportResult(result_code_, result_error_, | |
321 result_text_.c_str()); | |
322 } | |
323 } | |
324 | |
325 report_data_recursiveness_--; | |
326 return hr; | 119 return hr; |
327 } | 120 } |
328 | 121 |
329 STDMETHODIMP ProtocolSinkWrap::ReportResult(HRESULT result, DWORD error, | 122 STDMETHODIMP ProtocolSinkWrap::ReportResult(HRESULT result, DWORD error, |
330 LPCWSTR result_text) { | 123 LPCWSTR result_text) { |
331 DLOG(INFO) << "ProtocolSinkWrap::ReportResult: result: " << result << | 124 DLOG(INFO) << "ProtocolSinkWrap::ReportResult: result: " << result << |
332 " error: " << error << " Text: " << (result_text ? result_text : L""); | 125 " error: " << error << " Text: " << (result_text ? result_text : L""); |
333 | 126 DCHECK_NE(UNDETERMINED, prot_data_->renderer_type()); |
334 // If this request failed, we don't want to have anything to do with this. | |
335 if (FAILED(result)) | |
336 renderer_type_ = OTHER; | |
337 | |
338 // if we are still not sure about the renderer type, cache the result, | |
339 // othewise urlmon will get confused about getting reported about a | |
340 // success result for which it never received any data. | |
341 if (is_undetermined()) { | |
342 is_saved_result_ = true; | |
343 result_code_ = result; | |
344 result_error_ = error; | |
345 if (result_text) | |
346 result_text_ = result_text; | |
347 return S_OK; | |
348 } | |
349 | 127 |
350 HRESULT hr = E_FAIL; | 128 HRESULT hr = E_FAIL; |
351 if (delegate_) | 129 if (delegate_) |
352 hr = delegate_->ReportResult(result, error, result_text); | 130 hr = delegate_->ReportResult(result, error, result_text); |
353 | 131 |
354 return hr; | 132 return hr; |
355 } | 133 } |
356 | 134 |
357 // IInternetBindInfoEx | 135 |
358 STDMETHODIMP ProtocolSinkWrap::GetBindInfo(DWORD* flags, | 136 // Helpers |
359 BINDINFO* bind_info_ret) { | 137 ScopedComPtr<IBindCtx> BindCtxFromIBindInfo(IInternetBindInfo* bind_info) { |
360 ScopedComPtr<IInternetBindInfo> bind_info; | 138 LPOLESTR bind_ctx_string = NULL; |
361 HRESULT hr = bind_info.QueryFrom(delegate_); | 139 ULONG count; |
362 if (bind_info) | 140 ScopedComPtr<IBindCtx> bind_ctx; |
363 hr = bind_info->GetBindInfo(flags, bind_info_ret); | 141 bind_info->GetBindString(BINDSTRING_PTR_BIND_CONTEXT, &bind_ctx_string, 1, |
364 return hr; | 142 &count); |
365 } | 143 if (bind_ctx_string) { |
366 | 144 IBindCtx* pbc = reinterpret_cast<IBindCtx*>(StringToInt(bind_ctx_string)); |
367 STDMETHODIMP ProtocolSinkWrap::GetBindString(ULONG string_type, | 145 bind_ctx.Attach(pbc); |
368 LPOLESTR* string_array, ULONG array_size, ULONG* size_returned) { | 146 CoTaskMemFree(bind_ctx_string); |
369 ScopedComPtr<IInternetBindInfo> bind_info; | 147 } |
370 HRESULT hr = bind_info.QueryFrom(delegate_); | 148 |
371 if (bind_info) | 149 return bind_ctx; |
372 hr = bind_info->GetBindString(string_type, string_array, | 150 } |
373 array_size, size_returned); | 151 |
374 return hr; | 152 bool ShouldWrapSink(IInternetProtocolSink* sink, const wchar_t* url) { |
375 } | 153 // TODO(stoyan): check the url scheme for http/https. |
376 | 154 ScopedComPtr<IHttpNegotiate> http_negotiate; |
377 STDMETHODIMP ProtocolSinkWrap::GetBindInfoEx(DWORD* flags, BINDINFO* bind_info, | 155 HRESULT hr = DoQueryService(GUID_NULL, sink, http_negotiate.Receive()); |
378 DWORD* bindf2, DWORD* reserved) { | 156 if (http_negotiate && !IsSubFrameRequest(http_negotiate)) |
379 ScopedComPtr<IInternetBindInfoEx> bind_info_ex; | 157 return true; |
380 HRESULT hr = bind_info_ex.QueryFrom(delegate_); | 158 |
381 if (bind_info_ex) | 159 return false; |
382 hr = bind_info_ex->GetBindInfoEx(flags, bind_info, bindf2, reserved); | 160 } |
383 return hr; | 161 |
384 } | 162 // High level helpers |
385 | 163 bool IsCFRequest(IBindCtx* pbc) { |
386 // IServiceProvider | 164 ScopedComPtr<BindContextInfo> info; |
387 STDMETHODIMP ProtocolSinkWrap::QueryService(REFGUID service_guid, | 165 BindContextInfo::FromBindContext(pbc, info.Receive()); |
388 REFIID riid, void** service) { | 166 DCHECK(info); |
389 ScopedComPtr<IServiceProvider> service_provider; | 167 if (info && info->chrome_request()) |
390 HRESULT hr = service_provider.QueryFrom(delegate_); | 168 return true; |
391 if (service_provider) | 169 |
392 hr = service_provider->QueryService(service_guid, riid, service); | 170 return false; |
393 return hr; | 171 } |
394 } | 172 |
395 | 173 void PutProtData(IBindCtx* pbc, ProtData* data) { |
396 // IAuthenticate | 174 ScopedComPtr<BindContextInfo> info; |
397 STDMETHODIMP ProtocolSinkWrap::Authenticate(HWND* window, | 175 BindContextInfo::FromBindContext(pbc, info.Receive()); |
398 LPWSTR* user_name, LPWSTR* password) { | 176 if (info) |
399 ScopedComPtr<IAuthenticate> authenticate; | 177 info->set_prot_data(data); |
400 HRESULT hr = authenticate.QueryFrom(delegate_); | 178 } |
401 if (authenticate) | 179 |
402 hr = authenticate->Authenticate(window, user_name, password); | 180 bool IsTextHtml(const wchar_t* status_text) { |
403 return hr; | 181 if (!status_text) |
404 } | 182 return false; |
405 | 183 size_t status_text_length = lstrlenW(status_text); |
406 // IInternetProtocolEx | 184 const wchar_t* status_text_end = status_text + |
407 STDMETHODIMP ProtocolSinkWrap::Start(LPCWSTR url, | 185 std::min(status_text_length, arraysize(kTextHtmlMimeType) - 1); |
408 IInternetProtocolSink *protocol_sink, IInternetBindInfo* bind_info, | 186 bool is_text_html = LowerCaseEqualsASCII(status_text, status_text_end, |
409 DWORD flags, HANDLE_PTR reserved) { | 187 kTextHtmlMimeType); |
410 ScopedComPtr<IInternetProtocolRoot> protocol; | 188 return is_text_html; |
411 HRESULT hr = protocol.QueryFrom(delegate_); | 189 } |
412 if (protocol) | 190 |
413 hr = protocol->Start(url, protocol_sink, bind_info, flags, reserved); | 191 RendererType DetermineRendererType(void* buffer, DWORD size, bool last_chance) { |
414 return hr; | 192 RendererType type = UNDETERMINED; |
415 } | 193 if (last_chance) |
416 | 194 type = OTHER; |
417 STDMETHODIMP ProtocolSinkWrap::Continue(PROTOCOLDATA* protocol_data) { | 195 |
418 ScopedComPtr<IInternetProtocolRoot> protocol; | 196 std::wstring html_contents; |
419 HRESULT hr = protocol.QueryFrom(delegate_); | 197 // TODO(joshia): detect and handle different content encodings |
420 if (protocol) | 198 UTF8ToWide(reinterpret_cast<char*>(buffer), size, &html_contents); |
421 hr = protocol->Continue(protocol_data); | 199 |
422 return hr; | 200 // Note that document_contents_ may have NULL characters in it. While |
423 } | 201 // browsers may handle this properly, we don't and will stop scanning |
424 | 202 // for the XUACompat content value if we encounter one. |
425 STDMETHODIMP ProtocolSinkWrap::Abort(HRESULT reason, DWORD options) { | 203 std::wstring xua_compat_content; |
426 ScopedComPtr<IInternetProtocolRoot> protocol; | 204 UtilGetXUACompatContentValue(html_contents, &xua_compat_content); |
427 HRESULT hr = protocol.QueryFrom(delegate_); | 205 if (StrStrI(xua_compat_content.c_str(), kChromeContentPrefix)) { |
428 if (protocol) | 206 type = CHROME; |
429 hr = protocol->Abort(reason, options); | 207 } |
430 return hr; | 208 |
431 } | 209 return type; |
432 | 210 } |
433 STDMETHODIMP ProtocolSinkWrap::Terminate(DWORD options) { | 211 |
434 ScopedComPtr<IInternetProtocolRoot> protocol; | 212 // ProtData |
435 HRESULT hr = protocol.QueryFrom(delegate_); | 213 ProtData::ProtData(IInternetProtocol* protocol, |
436 if (protocol) | 214 InternetProtocol_Read_Fn read_fun, const wchar_t* url) |
437 hr = protocol->Terminate(options); | 215 : has_suggested_mime_type_(false), has_server_mime_type_(false), |
438 return hr; | 216 report_data_received_(false), buffer_size_(0), buffer_pos_(0), |
439 } | 217 renderer_type_(UNDETERMINED), protocol_(protocol), read_fun_(read_fun), |
440 | 218 url_(url) { |
441 STDMETHODIMP ProtocolSinkWrap::Suspend() { | 219 memset(buffer_, 0, arraysize(buffer_)); |
442 ScopedComPtr<IInternetProtocolRoot> protocol; | 220 DLOG(INFO) << __FUNCTION__ << " " << this; |
443 HRESULT hr = protocol.QueryFrom(delegate_); | 221 |
444 if (protocol) | 222 // Add to map. |
445 hr = protocol->Suspend(); | 223 AutoLock lock(datamap_lock_); |
446 return hr; | 224 DCHECK(datamap_.end() == datamap_.find(protocol_)); |
447 } | 225 datamap_[protocol] = this; |
448 | 226 } |
449 STDMETHODIMP ProtocolSinkWrap::Resume() { | 227 |
450 ScopedComPtr<IInternetProtocolRoot> protocol; | 228 ProtData::~ProtData() { |
451 HRESULT hr = protocol.QueryFrom(delegate_); | 229 DLOG(INFO) << __FUNCTION__ << " " << this; |
452 if (protocol) | 230 |
453 hr = protocol->Resume(); | 231 // Remove from map. |
454 return hr; | 232 AutoLock lock(datamap_lock_); |
455 } | 233 DCHECK(datamap_.end() != datamap_.find(protocol_)); |
456 | 234 datamap_.erase(protocol_); |
457 STDMETHODIMP ProtocolSinkWrap::Read(void *buffer, ULONG size, | 235 } |
458 ULONG* size_read) { | 236 |
459 ScopedComPtr<IInternetProtocol> protocol; | 237 HRESULT ProtData::Read(void* buffer, ULONG size, ULONG* size_read) { |
460 HRESULT hr = protocol.QueryFrom(delegate_); | 238 if (renderer_type_ == UNDETERMINED) { |
461 if (protocol) | 239 return E_PENDING; |
462 hr = protocol->Read(buffer, size, size_read); | 240 } |
463 return hr; | 241 |
464 } | 242 const ULONG bytes_available = buffer_size_ - buffer_pos_; |
465 | 243 const ULONG bytes_to_copy = std::min(bytes_available, size); |
466 STDMETHODIMP ProtocolSinkWrap::Seek(LARGE_INTEGER move, DWORD origin, | 244 if (bytes_to_copy) { |
467 ULARGE_INTEGER* new_pos) { | 245 // Copy from the local buffer. |
468 ScopedComPtr<IInternetProtocol> protocol; | 246 memcpy(buffer, buffer_ + buffer_pos_, bytes_to_copy); |
469 HRESULT hr = protocol.QueryFrom(delegate_); | 247 *size_read = bytes_to_copy; |
470 if (protocol) | 248 buffer_pos_ += bytes_to_copy; |
471 hr = protocol->Seek(move, origin, new_pos); | 249 |
472 return hr; | 250 HRESULT hr = S_OK; |
473 } | 251 ULONG new_data = 0; |
474 | 252 if (size > bytes_available) { |
475 STDMETHODIMP ProtocolSinkWrap::LockRequest(DWORD options) { | 253 // User buffer is greater than what we have. |
476 ScopedComPtr<IInternetProtocol> protocol; | 254 buffer = reinterpret_cast<uint8*>(buffer) + bytes_to_copy; |
477 HRESULT hr = protocol.QueryFrom(delegate_); | 255 size -= bytes_to_copy; |
478 if (protocol) | 256 hr = read_fun_(protocol_, buffer, size, &new_data); |
479 hr = protocol->LockRequest(options); | 257 } |
480 return hr; | 258 |
481 } | 259 if (size_read) |
482 | 260 *size_read = bytes_to_copy + new_data; |
483 STDMETHODIMP ProtocolSinkWrap::UnlockRequest() { | 261 return hr; |
484 ScopedComPtr<IInternetProtocol> protocol; | 262 } |
485 HRESULT hr = protocol.QueryFrom(delegate_); | 263 |
486 if (protocol) | 264 return read_fun_(protocol_, buffer, size, size_read); |
487 hr = protocol->UnlockRequest(); | 265 } |
488 return hr; | 266 |
489 } | 267 |
490 | 268 HRESULT ProtData::ReportProgress(IInternetProtocolSink* delegate, |
491 STDMETHODIMP ProtocolSinkWrap::StartEx(IUri* uri, | 269 ULONG status_code, LPCWSTR status_text) { |
492 IInternetProtocolSink* protocol_sink, IInternetBindInfo* bind_info, | 270 switch (status_code) { |
493 DWORD flags, HANDLE_PTR reserved) { | 271 case BINDSTATUS_DIRECTBIND: |
494 ScopedComPtr<IInternetProtocolEx> protocol; | 272 renderer_type_ = OTHER; |
495 HRESULT hr = protocol.QueryFrom(delegate_); | 273 break; |
496 if (protocol) | 274 |
497 hr = protocol->StartEx(uri, protocol_sink, bind_info, flags, reserved); | 275 case BINDSTATUS_REDIRECTING: |
498 return hr; | 276 url_.empty(); |
499 } | 277 if (status_text) |
500 | 278 url_ = status_text; |
501 // IInternetPriority | 279 break; |
502 STDMETHODIMP ProtocolSinkWrap::SetPriority(LONG priority) { | 280 |
503 ScopedComPtr<IInternetPriority> internet_priority; | 281 case BINDSTATUS_SERVER_MIMETYPEAVAILABLE: |
504 HRESULT hr = internet_priority.QueryFrom(delegate_); | 282 has_server_mime_type_ = true; |
505 if (internet_priority) | 283 SaveSuggestedMimeType(status_text); |
506 hr = internet_priority->SetPriority(priority); | 284 return S_OK; |
507 return hr; | 285 |
508 } | 286 // TODO(stoyan): BINDSTATUS_RAWMIMETYPE |
509 | 287 case BINDSTATUS_MIMETYPEAVAILABLE: |
510 STDMETHODIMP ProtocolSinkWrap::GetPriority(LONG* priority) { | 288 case BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE: |
511 ScopedComPtr<IInternetPriority> internet_priority; | 289 SaveSuggestedMimeType(status_text); |
512 HRESULT hr = internet_priority.QueryFrom(delegate_); | 290 return S_OK; |
513 if (internet_priority) | 291 } |
514 hr = internet_priority->GetPriority(priority); | 292 |
515 return hr; | 293 return delegate->ReportProgress(status_code, status_text); |
516 } | 294 } |
517 | 295 |
518 // IWrappedProtocol | 296 HRESULT ProtData::ReportData(IInternetProtocolSink* delegate, |
519 STDMETHODIMP ProtocolSinkWrap::GetWrapperCode(LONG *code, DWORD_PTR reserved) { | 297 DWORD flags, ULONG progress, ULONG max_progress) { |
520 ScopedComPtr<IWrappedProtocol> wrapped_protocol; | 298 if (renderer_type_ != UNDETERMINED) { |
521 HRESULT hr = wrapped_protocol.QueryFrom(delegate_); | 299 return delegate->ReportData(flags, progress, max_progress); |
522 if (wrapped_protocol) | 300 } |
523 hr = wrapped_protocol->GetWrapperCode(code, reserved); | 301 |
524 return hr; | 302 // Do these checks only once. |
525 } | 303 if (!report_data_received_) { |
526 | 304 report_data_received_ = true; |
527 | 305 |
528 // public IUriContainer | 306 DLOG_IF(INFO, (flags & BSCF_FIRSTDATANOTIFICATION) == 0) << |
529 STDMETHODIMP ProtocolSinkWrap::GetIUri(IUri** uri) { | 307 "BUGBUG: BSCF_FIRSTDATANOTIFICATION is not set properly!"; |
530 ScopedComPtr<IUriContainer> uri_container; | 308 |
531 HRESULT hr = uri_container.QueryFrom(delegate_); | 309 |
532 if (uri_container) | 310 // We check here, instead in ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE) |
533 hr = uri_container->GetIUri(uri); | 311 // to be safe when following multiple redirects.? |
534 return hr; | 312 if (!IsTextHtml(suggested_mime_type_)) { |
535 } | 313 renderer_type_ = OTHER; |
536 | 314 FireSugestedMimeType(delegate); |
537 // Protected helpers | 315 return delegate->ReportData(flags, progress, max_progress); |
538 | 316 } |
539 void ProtocolSinkWrap::DetermineRendererType() { | 317 |
540 if (is_undetermined()) { | 318 if (!url_.empty() && IsOptInUrl(url_.c_str())) { |
541 if (IsOptInUrl(url_.c_str())) { | 319 // TODO(stoyan): We may attempt to remove ourselves from the bind context. |
542 renderer_type_ = CHROME; | 320 renderer_type_ = CHROME; |
543 } else { | 321 delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType); |
544 std::wstring xua_compat_content; | 322 return delegate->ReportData(flags, progress, max_progress); |
545 // Note that document_contents_ may have NULL characters in it. While | |
546 // browsers may handle this properly, we don't and will stop scanning for | |
547 // the XUACompat content value if we encounter one. | |
548 DCHECK(buffer_size_ < arraysize(buffer_)); | |
549 buffer_[buffer_size_] = 0; | |
550 std::wstring html_contents; | |
551 // TODO(joshia): detect and handle different content encodings | |
552 UTF8ToWide(buffer_, buffer_size_, &html_contents); | |
553 UtilGetXUACompatContentValue(html_contents, &xua_compat_content); | |
554 if (StrStrI(xua_compat_content.c_str(), kChromeContentPrefix)) { | |
555 renderer_type_ = CHROME; | |
556 } else { | |
557 renderer_type_ = OTHER; | |
558 } | |
559 } | 323 } |
560 } | 324 } |
561 } | 325 |
562 | 326 HRESULT hr = FillBuffer(); |
563 HRESULT ProtocolSinkWrap::CheckAndReportChromeMimeTypeForRequest() { | 327 |
564 if (!is_undetermined()) | 328 bool last_chance = false; |
| 329 if (hr == S_OK || hr == S_FALSE) { |
| 330 last_chance = true; |
| 331 } |
| 332 |
| 333 renderer_type_ = DetermineRendererType(buffer_, buffer_size_, last_chance); |
| 334 |
| 335 if (renderer_type_ == UNDETERMINED) { |
| 336 // do not report anything, we need more data. |
565 return S_OK; | 337 return S_OK; |
566 | 338 } |
567 // This function could get invoked recursively in the context of | 339 |
568 // IInternetProtocol::Read. Check for the same and bail. | 340 if (renderer_type_ == CHROME) { |
569 if (determining_renderer_type_) | 341 DLOG(INFO) << "Forwarding BINDSTATUS_MIMETYPEAVAILABLE " |
570 return S_OK; | 342 << kChromeMimeType; |
571 | 343 delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType); |
572 determining_renderer_type_ = true; | 344 } |
573 | 345 |
| 346 if (renderer_type_ == OTHER) { |
| 347 FireSugestedMimeType(delegate); |
| 348 } |
| 349 |
| 350 // This is the first data notification we forward. |
| 351 flags |= BSCF_FIRSTDATANOTIFICATION; |
| 352 |
| 353 if (hr == S_FALSE) { |
| 354 flags |= (BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE); |
| 355 } |
| 356 |
| 357 return delegate->ReportData(flags, progress, max_progress); |
| 358 } |
| 359 |
| 360 void ProtData::UpdateUrl(const wchar_t* url) { |
| 361 url_ = url; |
| 362 } |
| 363 |
| 364 // S_FALSE - EOF |
| 365 // S_OK - buffer fully filled |
| 366 // E_PENDING - some data added to buffer, but buffer is not yet full |
| 367 // E_XXXX - some other error. |
| 368 HRESULT ProtData::FillBuffer() { |
574 HRESULT hr_read = S_OK; | 369 HRESULT hr_read = S_OK; |
575 while (hr_read == S_OK) { | 370 |
| 371 while ((hr_read == S_OK) && (buffer_size_ < kMaxContentSniffLength)) { |
576 ULONG size_read = 0; | 372 ULONG size_read = 0; |
577 hr_read = protocol_->Read(buffer_ + buffer_size_, | 373 hr_read = read_fun_(protocol_, buffer_ + buffer_size_, |
578 kMaxContentSniffLength - buffer_size_, &size_read); | 374 kMaxContentSniffLength - buffer_size_, &size_read); |
579 buffer_size_ += size_read; | 375 buffer_size_ += size_read; |
580 | 376 } |
581 // Attempt to determine the renderer type if we have received | 377 |
582 // sufficient data. Do not attempt this when we are called recursively. | |
583 if (report_data_recursiveness_ < 2 && (S_FALSE == hr_read) || | |
584 (buffer_size_ >= kMaxContentSniffLength)) { | |
585 DetermineRendererType(); | |
586 if (renderer_type() == CHROME) { | |
587 // Workaround for IE 8 and "nosniff". See: | |
588 // http://blogs.msdn.com/ie/archive/2008/09/02/ie8-security-part-vi-beta
-2-update.aspx | |
589 delegate_->ReportProgress( | |
590 BINDSTATUS_SERVER_MIMETYPEAVAILABLE, kChromeMimeType); | |
591 // For IE < 8. | |
592 delegate_->ReportProgress( | |
593 BINDSTATUS_MIMETYPEAVAILABLE, kChromeMimeType); | |
594 | |
595 delegate_->ReportProgress( | |
596 BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, kChromeMimeType); | |
597 | |
598 delegate_->ReportData( | |
599 BSCF_FIRSTDATANOTIFICATION, 0, 0); | |
600 | |
601 delegate_->ReportData( | |
602 BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, 0, 0); | |
603 } | |
604 break; | |
605 } | |
606 } | |
607 | |
608 determining_renderer_type_ = false; | |
609 return hr_read; | 378 return hr_read; |
610 } | 379 } |
611 | 380 |
612 HRESULT ProtocolSinkWrap::OnReadImpl(void* buffer, ULONG size, ULONG* size_read, | 381 void ProtData::SaveSuggestedMimeType(LPCWSTR status_text) { |
613 InternetProtocol_Read_Fn orig_read) { | 382 has_suggested_mime_type_ = true; |
614 // We want to switch the renderer to chrome, we cannot return any | 383 suggested_mime_type_.Allocate(status_text); |
615 // data now. | 384 } |
616 if (CHROME == renderer_type()) | 385 |
617 return S_FALSE; | 386 void ProtData::FireSugestedMimeType(IInternetProtocolSink* delegate) { |
618 | 387 if (has_server_mime_type_) { |
619 // Serve data from our buffer first. | 388 DLOG(INFO) << "Forwarding BINDSTATUS_SERVER_MIMETYPEAVAILABLE " |
620 if (OTHER == renderer_type()) { | 389 << suggested_mime_type_; |
621 const ULONG bytes_to_copy = std::min(buffer_size_ - buffer_pos_, size); | 390 delegate->ReportProgress(BINDSTATUS_SERVER_MIMETYPEAVAILABLE, |
622 if (bytes_to_copy) { | 391 suggested_mime_type_); |
623 memcpy(buffer, buffer_ + buffer_pos_, bytes_to_copy); | 392 return; |
624 *size_read = bytes_to_copy; | 393 } |
625 buffer_pos_ += bytes_to_copy; | 394 |
626 return S_OK; | 395 if (has_suggested_mime_type_) { |
627 } | 396 DLOG(INFO) << "Forwarding BINDSTATUS_MIMETYPEAVAILABLE " |
628 } | 397 << suggested_mime_type_; |
629 | 398 delegate->ReportProgress(BINDSTATUS_MIMETYPEAVAILABLE, |
630 return orig_read(protocol_, buffer, size, size_read); | 399 suggested_mime_type_); |
631 } | 400 } |
632 | 401 } |
633 scoped_refptr<ProtocolSinkWrap> ProtocolSinkWrap::InstanceFromProtocol( | 402 |
| 403 scoped_refptr<ProtData> ProtData::DataFromProtocol( |
634 IInternetProtocol* protocol) { | 404 IInternetProtocol* protocol) { |
635 CComCritSecLock<CComAutoCriticalSection> lock(sink_map_lock_); | 405 scoped_refptr<ProtData> instance; |
636 scoped_refptr<ProtocolSinkWrap> instance; | 406 AutoLock lock(datamap_lock_); |
637 ProtocolSinkMap::iterator it = sink_map_.find(protocol); | 407 ProtocolDataMap::iterator it = datamap_.find(protocol); |
638 if (sink_map_.end() != it) | 408 if (datamap_.end() != it) |
639 instance = it->second; | 409 instance = it->second; |
640 return instance; | 410 return instance; |
641 } | 411 } |
642 | 412 |
643 ScopedComPtr<IInternetProtocolSink> ProtocolSinkWrap::MaybeWrapSink( | 413 // IInternetProtocol/Ex hooks. |
644 IInternetProtocol* protocol, IInternetProtocolSink* prot_sink, | 414 STDMETHODIMP Hook_Start(InternetProtocol_Start_Fn orig_start, |
645 const wchar_t* url) { | 415 IInternetProtocol* protocol, LPCWSTR url, IInternetProtocolSink* prot_sink, |
646 ScopedComPtr<IInternetProtocolSink> sink_to_use(prot_sink); | 416 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
647 | 417 DCHECK(orig_start); |
648 // FYI: GUID_NULL doesn't work when the URL is being loaded from history. | 418 if (!url || !prot_sink || !bind_info) |
649 // asking for IID_IHttpNegotiate as the service id works, but | 419 return E_INVALIDARG; |
650 // getting the IWebBrowser2 interface still doesn't work. | 420 DLOG_IF(INFO, url != NULL) << "OnStart: " << url << PiFlags2Str(flags); |
651 ScopedComPtr<IHttpNegotiate> http_negotiate; | 421 |
652 HRESULT hr = DoQueryService(GUID_NULL, prot_sink, http_negotiate.Receive()); | 422 ScopedComPtr<IBindCtx> bind_ctx = BindCtxFromIBindInfo(bind_info); |
653 | 423 if (!bind_ctx) { |
654 if (http_negotiate && !IsSubFrameRequest(http_negotiate)) { | 424 // MSHTML sometimes takes a short path, skips the creation of |
655 CComObject<ProtocolSinkWrap>* wrap = NULL; | 425 // moniker and binding, by directly grabbing protocol from InternetSession |
656 CComObject<ProtocolSinkWrap>::CreateInstance(&wrap); | 426 DLOG(INFO) << "DirectBind for " << url; |
657 DCHECK(wrap); | 427 return orig_start(protocol, url, prot_sink, bind_info, flags, reserved); |
658 if (wrap) { | 428 } |
659 wrap->AddRef(); | 429 |
660 if (wrap->Initialize(protocol, prot_sink, url)) { | 430 if (IsCFRequest(bind_ctx)) { |
661 sink_to_use = wrap; | 431 return orig_start(protocol, url, prot_sink, bind_info, flags, reserved); |
662 } | 432 } |
663 wrap->Release(); | 433 |
664 } | 434 scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol); |
665 } | 435 if (prot_data) { |
666 | 436 DLOG(INFO) << "Found existing ProtData!"; |
667 return sink_to_use; | 437 prot_data->UpdateUrl(url); |
668 } | 438 ScopedComPtr<IInternetProtocolSink> new_sink = |
| 439 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
| 440 return orig_start(protocol, url, new_sink, bind_info, flags, reserved); |
| 441 } |
| 442 |
| 443 if (!ShouldWrapSink(prot_sink, url)) { |
| 444 return orig_start(protocol, url, prot_sink, bind_info, flags, reserved); |
| 445 } |
| 446 |
| 447 // Fresh request. |
| 448 InternetProtocol_Read_Fn read_fun = reinterpret_cast<InternetProtocol_Read_Fn> |
| 449 (CTransaction_PatchInfo[1].stub_->argument()); |
| 450 prot_data = new ProtData(protocol, read_fun, url); |
| 451 PutProtData(bind_ctx, prot_data); |
| 452 |
| 453 ScopedComPtr<IInternetProtocolSink> new_sink = |
| 454 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
| 455 return orig_start(protocol, url, new_sink, bind_info, flags, reserved); |
| 456 } |
| 457 |
| 458 STDMETHODIMP Hook_StartEx(InternetProtocol_StartEx_Fn orig_start_ex, |
| 459 IInternetProtocolEx* protocol, IUri* uri, IInternetProtocolSink* prot_sink, |
| 460 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
| 461 DCHECK(orig_start_ex); |
| 462 if (!uri || !prot_sink || !bind_info) |
| 463 return E_INVALIDARG; |
| 464 |
| 465 ScopedBstr url; |
| 466 uri->GetPropertyBSTR(Uri_PROPERTY_ABSOLUTE_URI, url.Receive(), 0); |
| 467 DLOG_IF(INFO, url != NULL) << "OnStartEx: " << url << PiFlags2Str(flags); |
| 468 |
| 469 ScopedComPtr<IBindCtx> bind_ctx = BindCtxFromIBindInfo(bind_info); |
| 470 if (!bind_ctx) { |
| 471 // MSHTML sometimes takes a short path, skips the creation of |
| 472 // moniker and binding, by directly grabbing protocol from InternetSession. |
| 473 DLOG(INFO) << "DirectBind for " << url; |
| 474 return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved); |
| 475 } |
| 476 |
| 477 if (IsCFRequest(bind_ctx)) { |
| 478 return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved); |
| 479 } |
| 480 |
| 481 scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol); |
| 482 if (prot_data) { |
| 483 DLOG(INFO) << "Found existing ProtData!"; |
| 484 prot_data->UpdateUrl(url); |
| 485 ScopedComPtr<IInternetProtocolSink> new_sink = |
| 486 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
| 487 return orig_start_ex(protocol, uri, new_sink, bind_info, flags, reserved); |
| 488 } |
| 489 |
| 490 if (!ShouldWrapSink(prot_sink, url)) { |
| 491 return orig_start_ex(protocol, uri, prot_sink, bind_info, flags, reserved); |
| 492 } |
| 493 |
| 494 // Fresh request. |
| 495 InternetProtocol_Read_Fn read_fun = reinterpret_cast<InternetProtocol_Read_Fn> |
| 496 (CTransaction_PatchInfo[1].stub_->argument()); |
| 497 prot_data = new ProtData(protocol, read_fun, url); |
| 498 PutProtData(bind_ctx, prot_data); |
| 499 |
| 500 ScopedComPtr<IInternetProtocolSink> new_sink = |
| 501 ProtocolSinkWrap::CreateNewSink(prot_sink, prot_data); |
| 502 return orig_start_ex(protocol, uri, new_sink, bind_info, flags, reserved); |
| 503 } |
| 504 |
| 505 STDMETHODIMP Hook_Read(InternetProtocol_Read_Fn orig_read, |
| 506 IInternetProtocol* protocol, void* buffer, ULONG size, ULONG* size_read) { |
| 507 DCHECK(orig_read); |
| 508 scoped_refptr<ProtData> prot_data = ProtData::DataFromProtocol(protocol); |
| 509 if (!prot_data) { |
| 510 return orig_read(protocol, buffer, size, size_read); |
| 511 } |
| 512 |
| 513 HRESULT hr = prot_data->Read(buffer, size, size_read); |
| 514 return hr; |
| 515 } |
| 516 |
| 517 // Patching / Hooking code. |
| 518 class FakeProtocol : public CComObjectRootEx<CComSingleThreadModel>, |
| 519 public IInternetProtocol { |
| 520 public: |
| 521 BEGIN_COM_MAP(FakeProtocol) |
| 522 COM_INTERFACE_ENTRY(IInternetProtocol) |
| 523 COM_INTERFACE_ENTRY(IInternetProtocolRoot) |
| 524 END_COM_MAP() |
| 525 |
| 526 STDMETHOD(Start)(LPCWSTR url, IInternetProtocolSink *protocol_sink, |
| 527 IInternetBindInfo* bind_info, DWORD flags, HANDLE_PTR reserved) { |
| 528 transaction_.QueryFrom(protocol_sink); |
| 529 // Return some unusual error code. |
| 530 return INET_E_INVALID_CERTIFICATE; |
| 531 } |
| 532 |
| 533 STDMETHOD(Continue)(PROTOCOLDATA* protocol_data) { return S_OK; } |
| 534 STDMETHOD(Abort)(HRESULT reason, DWORD options) { return S_OK; } |
| 535 STDMETHOD(Terminate)(DWORD options) { return S_OK; } |
| 536 STDMETHOD(Suspend)() { return S_OK; } |
| 537 STDMETHOD(Resume)() { return S_OK; } |
| 538 STDMETHOD(Read)(void *buffer, ULONG size, ULONG* size_read) { return S_OK; } |
| 539 STDMETHOD(Seek)(LARGE_INTEGER move, DWORD origin, ULARGE_INTEGER* new_pos) |
| 540 { return S_OK; } |
| 541 STDMETHOD(LockRequest)(DWORD options) { return S_OK; } |
| 542 STDMETHOD(UnlockRequest)() { return S_OK; } |
| 543 |
| 544 ScopedComPtr<IInternetProtocol> transaction_; |
| 545 }; |
| 546 |
| 547 struct FakeFactory : public IClassFactory, |
| 548 public CComObjectRootEx<CComSingleThreadModel> { |
| 549 BEGIN_COM_MAP(FakeFactory) |
| 550 COM_INTERFACE_ENTRY(IClassFactory) |
| 551 END_COM_MAP() |
| 552 |
| 553 STDMETHOD(CreateInstance)(IUnknown *pUnkOuter, REFIID riid, void **ppvObj) { |
| 554 if (pUnkOuter) |
| 555 return CLASS_E_NOAGGREGATION; |
| 556 HRESULT hr = obj_->QueryInterface(riid, ppvObj); |
| 557 return hr; |
| 558 } |
| 559 |
| 560 STDMETHOD(LockServer)(BOOL fLock) { |
| 561 return S_OK; |
| 562 } |
| 563 |
| 564 IUnknown* obj_; |
| 565 }; |
| 566 |
| 567 static void HookTransactionVtable(IInternetProtocol* p) { |
| 568 ScopedComPtr<IInternetProtocolEx> ex; |
| 569 ex.QueryFrom(p); |
| 570 |
| 571 HRESULT hr = vtable_patch::PatchInterfaceMethods(p, CTransaction_PatchInfo); |
| 572 if (hr == S_OK && ex) { |
| 573 vtable_patch::PatchInterfaceMethods(ex.get(), CTransaction2_PatchInfo); |
| 574 } |
| 575 } |
| 576 |
| 577 void TransactionHooks::InstallHooks() { |
| 578 if (IS_PATCHED(CTransaction)) { |
| 579 DLOG(WARNING) << __FUNCTION__ << " called more than once."; |
| 580 return; |
| 581 } |
| 582 |
| 583 CComObjectStackEx<FakeProtocol> prot; |
| 584 CComObjectStackEx<FakeFactory> factory; |
| 585 factory.obj_ = &prot; |
| 586 ScopedComPtr<IInternetSession> session; |
| 587 HRESULT hr = ::CoInternetGetSession(0, session.Receive(), 0); |
| 588 hr = session->RegisterNameSpace(&factory, CLSID_NULL, L"611", 0, 0, 0); |
| 589 DLOG_IF(FATAL, FAILED(hr)) << "Failed to register namespace"; |
| 590 if (hr != S_OK) |
| 591 return; |
| 592 |
| 593 do { |
| 594 ScopedComPtr<IMoniker> mk; |
| 595 ScopedComPtr<IBindCtx> bc; |
| 596 ScopedComPtr<IStream> stream; |
| 597 hr = ::CreateAsyncBindCtxEx(0, 0, 0, 0, bc.Receive(), 0); |
| 598 DLOG_IF(FATAL, FAILED(hr)) << "CreateAsyncBindCtxEx failed " << hr; |
| 599 if (hr != S_OK) |
| 600 break; |
| 601 |
| 602 hr = ::CreateURLMoniker(NULL, L"611://512", mk.Receive()); |
| 603 DLOG_IF(FATAL, FAILED(hr)) << "CreateURLMoniker failed " << hr; |
| 604 if (hr != S_OK) |
| 605 break; |
| 606 |
| 607 hr = mk->BindToStorage(bc, NULL, IID_IStream, |
| 608 reinterpret_cast<void**>(stream.Receive())); |
| 609 DLOG_IF(FATAL, hr != INET_E_INVALID_CERTIFICATE) << |
| 610 "BindToStorage failed " << hr; |
| 611 } while (0); |
| 612 |
| 613 hr = session->UnregisterNameSpace(&factory, L"611"); |
| 614 if (prot.transaction_) { |
| 615 HookTransactionVtable(prot.transaction_); |
| 616 // Explicit release, otherwise ~CComObjectStackEx will complain about |
| 617 // outstanding reference to us, because it runs before ~FakeProtocol |
| 618 prot.transaction_.Release(); |
| 619 } |
| 620 } |
| 621 |
| 622 void TransactionHooks::RevertHooks() { |
| 623 vtable_patch::UnpatchInterfaceMethods(CTransaction_PatchInfo); |
| 624 vtable_patch::UnpatchInterfaceMethods(CTransaction2_PatchInfo); |
| 625 } |
OLD | NEW |