Index: mojo/public/cpp/bindings/lib/interface_endpoint_client.cc |
diff --git a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc |
index a9eee86b6aa27ec93664b1937875287e68996ef2..2eab43036a83a24b9bb24d99d4ee6ecbb41039b5 100644 |
--- a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc |
+++ b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc |
@@ -137,48 +137,43 @@ InterfaceEndpointClient::InterfaceEndpointClient( |
bool expect_sync_requests, |
scoped_refptr<base::SingleThreadTaskRunner> runner, |
uint32_t interface_version) |
- : handle_(std::move(handle)), |
+ : expect_sync_requests_(expect_sync_requests), |
+ handle_(std::move(handle)), |
incoming_receiver_(receiver), |
thunk_(this), |
filters_(&thunk_), |
- next_request_id_(1), |
- encountered_error_(false), |
task_runner_(std::move(runner)), |
control_message_proxy_(this), |
control_message_handler_(interface_version), |
weak_ptr_factory_(this) { |
DCHECK(handle_.is_valid()); |
- DCHECK(handle_.is_local()); |
// TODO(yzshen): the way to use validator (or message filter in general) |
// directly is a little awkward. |
if (payload_validator) |
filters_.Append(std::move(payload_validator)); |
- controller_ = handle_.group_controller()->AttachEndpointClient( |
- handle_, this, task_runner_); |
- if (expect_sync_requests) |
- controller_->AllowWokenUpBySyncWatchOnSameThread(); |
+ if (handle_.pending_association()) { |
+ handle_.SetAssociationEventHandler(base::Bind( |
+ &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this))); |
+ } else { |
+ InitControllerIfNecessary(); |
+ } |
} |
InterfaceEndpointClient::~InterfaceEndpointClient() { |
DCHECK(thread_checker_.CalledOnValidThread()); |
- if (handle_.is_valid()) |
+ if (controller_) |
handle_.group_controller()->DetachEndpointClient(handle_); |
} |
AssociatedGroup* InterfaceEndpointClient::associated_group() { |
if (!associated_group_) |
- associated_group_ = handle_.group_controller()->CreateAssociatedGroup(); |
+ associated_group_ = base::MakeUnique<AssociatedGroup>(handle_); |
return associated_group_.get(); |
} |
-uint32_t InterfaceEndpointClient::interface_id() const { |
- DCHECK(thread_checker_.CalledOnValidThread()); |
- return handle_.id(); |
-} |
- |
ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { |
DCHECK(thread_checker_.CalledOnValidThread()); |
DCHECK(!has_pending_responders()); |
@@ -186,8 +181,13 @@ ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { |
if (!handle_.is_valid()) |
return ScopedInterfaceEndpointHandle(); |
- controller_ = nullptr; |
- handle_.group_controller()->DetachEndpointClient(handle_); |
+ handle_.SetAssociationEventHandler( |
+ ScopedInterfaceEndpointHandle::AssociationEventCallback()); |
+ |
+ if (controller_) { |
+ controller_ = nullptr; |
+ handle_.group_controller()->DetachEndpointClient(handle_); |
+ } |
return std::move(handle_); |
} |
@@ -200,7 +200,8 @@ void InterfaceEndpointClient::AddFilter( |
void InterfaceEndpointClient::RaiseError() { |
DCHECK(thread_checker_.CalledOnValidThread()); |
- handle_.group_controller()->RaiseError(); |
+ if (!handle_.pending_association()) |
+ handle_.group_controller()->RaiseError(); |
} |
void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, |
@@ -213,24 +214,40 @@ void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, |
bool InterfaceEndpointClient::Accept(Message* message) { |
DCHECK(thread_checker_.CalledOnValidThread()); |
- DCHECK(controller_); |
DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); |
+ DCHECK(!handle_.pending_association()); |
+ |
+ // This has to been done even if connection error has occurred. For example, |
+ // the message contains a pending associated request. The user may try to use |
+ // the corresponding associated interface pointer after sending this message. |
+ // That associated interface pointer has to join an associated group in order |
+ // to work properly. |
+ if (!message->associated_endpoint_handles()->empty()) |
+ message->SerializeAssociatedEndpointHandles(handle_.group_controller()); |
if (encountered_error_) |
return false; |
+ InitControllerIfNecessary(); |
+ |
return controller_->SendMessage(message); |
} |
bool InterfaceEndpointClient::AcceptWithResponder(Message* message, |
MessageReceiver* responder) { |
DCHECK(thread_checker_.CalledOnValidThread()); |
- DCHECK(controller_); |
DCHECK(message->has_flag(Message::kFlagExpectsResponse)); |
+ DCHECK(!handle_.pending_association()); |
+ |
+ // Please see comments in Accept(). |
+ if (!message->associated_endpoint_handles()->empty()) |
+ message->SerializeAssociatedEndpointHandles(handle_.group_controller()); |
if (encountered_error_) |
return false; |
+ InitControllerIfNecessary(); |
+ |
// Reserve 0 in case we want it to convey special meaning in the future. |
uint64_t request_id = next_request_id_++; |
if (request_id == 0) |
@@ -305,6 +322,42 @@ void InterfaceEndpointClient::NotifyError( |
} |
} |
+void InterfaceEndpointClient::QueryVersion( |
+ const base::Callback<void(uint32_t)>& callback) { |
+ control_message_proxy_.QueryVersion(callback); |
+} |
+ |
+void InterfaceEndpointClient::RequireVersion(uint32_t version) { |
+ control_message_proxy_.RequireVersion(version); |
+} |
+ |
+void InterfaceEndpointClient::FlushForTesting() { |
+ control_message_proxy_.FlushForTesting(); |
+} |
+ |
+void InterfaceEndpointClient::InitControllerIfNecessary() { |
+ if (controller_ || handle_.pending_association()) |
+ return; |
+ |
+ controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this, |
+ task_runner_); |
+ if (expect_sync_requests_) |
+ controller_->AllowWokenUpBySyncWatchOnSameThread(); |
+} |
+ |
+void InterfaceEndpointClient::OnAssociationEvent( |
+ ScopedInterfaceEndpointHandle::AssociationEvent event) { |
+ if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) { |
+ InitControllerIfNecessary(); |
+ } else if (event == |
+ ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) { |
+ task_runner_->PostTask(FROM_HERE, |
+ base::Bind(&InterfaceEndpointClient::NotifyError, |
+ weak_ptr_factory_.GetWeakPtr(), |
+ handle_.disconnect_reason())); |
+ } |
+} |
+ |
bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { |
DCHECK_EQ(handle_.id(), message->interface_id()); |
DCHECK(!encountered_error_); |