Index: content/browser/renderer_host/websocket_host.cc |
diff --git a/content/browser/renderer_host/websocket_host.cc b/content/browser/renderer_host/websocket_host.cc |
index ed398f8d1e4457fa4c5776d5494d6f2df39fe8e7..c5eb0e7ef64dafa9b268371bd240a2f6dd36ce49 100644 |
--- a/content/browser/renderer_host/websocket_host.cc |
+++ b/content/browser/renderer_host/websocket_host.cc |
@@ -5,13 +5,17 @@ |
#include "content/browser/renderer_host/websocket_host.h" |
#include "base/basictypes.h" |
+#include "base/memory/weak_ptr.h" |
#include "base/strings/string_util.h" |
#include "content/browser/renderer_host/websocket_dispatcher_host.h" |
+#include "content/browser/ssl/ssl_error_handler.h" |
+#include "content/browser/ssl/ssl_manager.h" |
#include "content/common/websocket_messages.h" |
#include "ipc/ipc_message_macros.h" |
#include "net/http/http_request_headers.h" |
#include "net/http/http_response_headers.h" |
#include "net/http/http_util.h" |
+#include "net/ssl/ssl_info.h" |
#include "net/websockets/websocket_channel.h" |
#include "net/websockets/websocket_event_interface.h" |
#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode |
@@ -80,7 +84,9 @@ ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) { |
// renderer or child process via WebSocketDispatcherHost. |
class WebSocketEventHandler : public net::WebSocketEventInterface { |
public: |
- WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, int routing_id); |
+ WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, |
+ int routing_id, |
+ int render_frame_id); |
virtual ~WebSocketEventHandler(); |
// net::WebSocketEventInterface implementation |
@@ -102,18 +108,50 @@ class WebSocketEventHandler : public net::WebSocketEventInterface { |
scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE; |
virtual ChannelState OnFinishOpeningHandshake( |
scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE; |
+ virtual ChannelState OnSSLCertificateError( |
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, |
+ const GURL& url, |
+ const net::SSLInfo& ssl_info, |
+ bool fatal) OVERRIDE; |
private: |
+ class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate { |
+ public: |
+ SSLErrorHandlerDelegate( |
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks); |
+ virtual ~SSLErrorHandlerDelegate(); |
+ |
+ base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr(); |
+ |
+ // SSLErrorHandler::Delegate methods |
+ virtual void CancelSSLRequest(const GlobalRequestID& id, |
+ int error, |
+ const net::SSLInfo* ssl_info) OVERRIDE; |
+ virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE; |
+ |
+ private: |
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_; |
+ base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate); |
+ }; |
+ |
WebSocketDispatcherHost* const dispatcher_; |
const int routing_id_; |
+ const int render_frame_id_; |
+ scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_; |
DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); |
}; |
WebSocketEventHandler::WebSocketEventHandler( |
WebSocketDispatcherHost* dispatcher, |
- int routing_id) |
- : dispatcher_(dispatcher), routing_id_(routing_id) {} |
+ int routing_id, |
+ int render_frame_id) |
+ : dispatcher_(dispatcher), |
+ routing_id_(routing_id), |
+ render_frame_id_(render_frame_id) { |
+} |
WebSocketEventHandler::~WebSocketEventHandler() { |
DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_; |
@@ -227,18 +265,67 @@ ChannelState WebSocketEventHandler::OnFinishOpeningHandshake( |
response_to_pass)); |
} |
+ChannelState WebSocketEventHandler::OnSSLCertificateError( |
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, |
+ const GURL& url, |
+ const net::SSLInfo& ssl_info, |
+ bool fatal) { |
+ DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError" |
+ << " routing_id=" << routing_id_ << " url=" << url.spec() |
+ << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal; |
+ ssl_error_handler_delegate_.reset( |
+ new SSLErrorHandlerDelegate(callbacks.Pass())); |
+ // We don't need request_id to be unique so just make a fake one. |
+ GlobalRequestID request_id(-1, -1); |
+ SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(), |
+ request_id, |
+ ResourceType::SUB_RESOURCE, |
+ url, |
+ dispatcher_->render_process_id(), |
+ render_frame_id_, |
+ ssl_info, |
+ fatal); |
+ // The above method is always asynchronous. |
+ return WebSocketEventInterface::CHANNEL_ALIVE; |
+} |
+ |
+WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate( |
+ scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks) |
+ : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {} |
+ |
+WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {} |
+ |
+base::WeakPtr<SSLErrorHandler::Delegate> |
+WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { |
+ return weak_ptr_factory_.GetWeakPtr(); |
+} |
+ |
+void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest( |
+ const GlobalRequestID& id, |
+ int error, |
+ const net::SSLInfo* ssl_info) { |
+ DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" |
+ << " error=" << error |
+ << " cert_status=" << (ssl_info ? ssl_info->cert_status |
+ : static_cast<net::CertStatus>(-1)); |
+ callbacks_->CancelSSLRequest(error, ssl_info); |
+} |
+ |
+void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest( |
+ const GlobalRequestID& id) { |
+ DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; |
+ callbacks_->ContinueSSLRequest(); |
+} |
+ |
} // namespace |
WebSocketHost::WebSocketHost(int routing_id, |
WebSocketDispatcherHost* dispatcher, |
net::URLRequestContext* url_request_context) |
- : routing_id_(routing_id) { |
+ : dispatcher_(dispatcher), |
+ url_request_context_(url_request_context), |
+ routing_id_(routing_id) { |
DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id; |
- |
- scoped_ptr<net::WebSocketEventInterface> event_interface( |
- new WebSocketEventHandler(dispatcher, routing_id)); |
- channel_.reset( |
- new net::WebSocketChannel(event_interface.Pass(), url_request_context)); |
} |
WebSocketHost::~WebSocketHost() {} |
@@ -258,15 +345,19 @@ bool WebSocketHost::OnMessageReceived(const IPC::Message& message) { |
void WebSocketHost::OnAddChannelRequest( |
const GURL& socket_url, |
const std::vector<std::string>& requested_protocols, |
- const url::Origin& origin) { |
+ const url::Origin& origin, |
+ int render_frame_id) { |
DVLOG(3) << "WebSocketHost::OnAddChannelRequest" |
<< " routing_id=" << routing_id_ << " socket_url=\"" << socket_url |
<< "\" requested_protocols=\"" |
<< JoinString(requested_protocols, ", ") << "\" origin=\"" |
<< origin.string() << "\""; |
yhirano
2014/06/03 05:44:51
DCHECK(!channel_)
Adam Rice
2014/06/03 06:36:50
Done. I also changed the if() conditions in the ot
|
- channel_->SendAddChannelRequest( |
- socket_url, requested_protocols, origin); |
+ scoped_ptr<net::WebSocketEventInterface> event_interface( |
+ new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id)); |
+ channel_.reset( |
+ new net::WebSocketChannel(event_interface.Pass(), url_request_context_)); |
+ channel_->SendAddChannelRequest(socket_url, requested_protocols, origin); |
} |
void WebSocketHost::OnSendFrame(bool fin, |
@@ -276,14 +367,16 @@ void WebSocketHost::OnSendFrame(bool fin, |
<< " routing_id=" << routing_id_ << " fin=" << fin |
<< " type=" << type << " data is " << data.size() << " bytes"; |
- channel_->SendFrame(fin, MessageTypeToOpCode(type), data); |
+ if (channel_) |
+ channel_->SendFrame(fin, MessageTypeToOpCode(type), data); |
} |
void WebSocketHost::OnFlowControl(int64 quota) { |
DVLOG(3) << "WebSocketHost::OnFlowControl" |
<< " routing_id=" << routing_id_ << " quota=" << quota; |
- channel_->SendFlowControl(quota); |
+ if (channel_) |
+ channel_->SendFlowControl(quota); |
} |
void WebSocketHost::OnDropChannel(bool was_clean, |
@@ -293,8 +386,10 @@ void WebSocketHost::OnDropChannel(bool was_clean, |
<< " routing_id=" << routing_id_ << " was_clean=" << was_clean |
<< " code=" << code << " reason=\"" << reason << "\""; |
- // TODO(yhirano): Handle |was_clean| appropriately. |
- channel_->StartClosingHandshake(code, reason); |
+ if (channel_) { |
+ // TODO(yhirano): Handle |was_clean| appropriately. |
+ channel_->StartClosingHandshake(code, reason); |
+ } |
} |
} // namespace content |