| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "google_apis/gcm/engine/connection_handler.h" | 5 #include "google_apis/gcm/engine/connection_handler.h" |
| 6 | 6 |
| 7 #include "base/message_loop/message_loop.h" | |
| 8 #include "google/protobuf/io/coded_stream.h" | |
| 9 #include "google_apis/gcm/base/mcs_util.h" | |
| 10 #include "google_apis/gcm/base/socket_stream.h" | |
| 11 #include "net/base/net_errors.h" | |
| 12 #include "net/socket/stream_socket.h" | |
| 13 | |
| 14 using namespace google::protobuf::io; | |
| 15 | |
| 16 namespace gcm { | 7 namespace gcm { |
| 17 | 8 |
| 18 namespace { | 9 ConnectionHandler::ConnectionHandler() { |
| 19 | |
| 20 // # of bytes a MCS version packet consumes. | |
| 21 const int kVersionPacketLen = 1; | |
| 22 // # of bytes a tag packet consumes. | |
| 23 const int kTagPacketLen = 1; | |
| 24 // Max # of bytes a length packet consumes. | |
| 25 const int kSizePacketLenMin = 1; | |
| 26 const int kSizePacketLenMax = 2; | |
| 27 | |
| 28 // The current MCS protocol version. | |
| 29 const int kMCSVersion = 38; | |
| 30 | |
| 31 } // namespace | |
| 32 | |
| 33 ConnectionHandler::ConnectionHandler(base::TimeDelta read_timeout) | |
| 34 : read_timeout_(read_timeout), | |
| 35 handshake_complete_(false), | |
| 36 message_tag_(0), | |
| 37 message_size_(0), | |
| 38 weak_ptr_factory_(this) { | |
| 39 } | 10 } |
| 40 | 11 |
| 41 ConnectionHandler::~ConnectionHandler() { | 12 ConnectionHandler::~ConnectionHandler() { |
| 42 } | 13 } |
| 43 | 14 |
| 44 void ConnectionHandler::Init( | |
| 45 scoped_ptr<net::StreamSocket> socket, | |
| 46 const google::protobuf::MessageLite& login_request, | |
| 47 const ProtoReceivedCallback& read_callback, | |
| 48 const ProtoSentCallback& write_callback, | |
| 49 const ConnectionChangedCallback& connection_callback) { | |
| 50 DCHECK(!read_callback.is_null()); | |
| 51 DCHECK(!write_callback.is_null()); | |
| 52 DCHECK(!connection_callback.is_null()); | |
| 53 | |
| 54 // Invalidate any previously outstanding reads. | |
| 55 weak_ptr_factory_.InvalidateWeakPtrs(); | |
| 56 | |
| 57 handshake_complete_ = false; | |
| 58 message_tag_ = 0; | |
| 59 message_size_ = 0; | |
| 60 socket_ = socket.Pass(); | |
| 61 input_stream_.reset(new SocketInputStream(socket_.get())); | |
| 62 output_stream_.reset(new SocketOutputStream(socket_.get())); | |
| 63 read_callback_ = read_callback; | |
| 64 write_callback_ = write_callback; | |
| 65 connection_callback_ = connection_callback; | |
| 66 | |
| 67 Login(login_request); | |
| 68 } | |
| 69 | |
| 70 bool ConnectionHandler::CanSendMessage() const { | |
| 71 return handshake_complete_ && output_stream_.get() && | |
| 72 output_stream_->GetState() == SocketOutputStream::EMPTY; | |
| 73 } | |
| 74 | |
| 75 void ConnectionHandler::SendMessage( | |
| 76 const google::protobuf::MessageLite& message) { | |
| 77 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); | |
| 78 DCHECK(handshake_complete_); | |
| 79 | |
| 80 { | |
| 81 CodedOutputStream coded_output_stream(output_stream_.get()); | |
| 82 DVLOG(1) << "Writing proto of size " << message.ByteSize(); | |
| 83 int tag = GetMCSProtoTag(message); | |
| 84 DCHECK_NE(tag, -1); | |
| 85 coded_output_stream.WriteRaw(&tag, 1); | |
| 86 coded_output_stream.WriteVarint32(message.ByteSize()); | |
| 87 message.SerializeToCodedStream(&coded_output_stream); | |
| 88 } | |
| 89 | |
| 90 if (output_stream_->Flush( | |
| 91 base::Bind(&ConnectionHandler::OnMessageSent, | |
| 92 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { | |
| 93 OnMessageSent(); | |
| 94 } | |
| 95 } | |
| 96 | |
| 97 void ConnectionHandler::Login( | |
| 98 const google::protobuf::MessageLite& login_request) { | |
| 99 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); | |
| 100 | |
| 101 const char version_byte[1] = {kMCSVersion}; | |
| 102 const char login_request_tag[1] = {kLoginRequestTag}; | |
| 103 { | |
| 104 CodedOutputStream coded_output_stream(output_stream_.get()); | |
| 105 coded_output_stream.WriteRaw(version_byte, 1); | |
| 106 coded_output_stream.WriteRaw(login_request_tag, 1); | |
| 107 coded_output_stream.WriteVarint32(login_request.ByteSize()); | |
| 108 login_request.SerializeToCodedStream(&coded_output_stream); | |
| 109 } | |
| 110 | |
| 111 if (output_stream_->Flush( | |
| 112 base::Bind(&ConnectionHandler::OnMessageSent, | |
| 113 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { | |
| 114 base::MessageLoop::current()->PostTask( | |
| 115 FROM_HERE, | |
| 116 base::Bind(&ConnectionHandler::OnMessageSent, | |
| 117 weak_ptr_factory_.GetWeakPtr())); | |
| 118 } | |
| 119 | |
| 120 read_timeout_timer_.Start(FROM_HERE, | |
| 121 read_timeout_, | |
| 122 base::Bind(&ConnectionHandler::OnTimeout, | |
| 123 weak_ptr_factory_.GetWeakPtr())); | |
| 124 WaitForData(MCS_VERSION_TAG_AND_SIZE); | |
| 125 } | |
| 126 | |
| 127 void ConnectionHandler::OnMessageSent() { | |
| 128 if (!output_stream_.get()) { | |
| 129 // The connection has already been closed. Just return. | |
| 130 DCHECK(!input_stream_.get()); | |
| 131 DCHECK(!read_timeout_timer_.IsRunning()); | |
| 132 return; | |
| 133 } | |
| 134 | |
| 135 if (output_stream_->GetState() != SocketOutputStream::EMPTY) { | |
| 136 int last_error = output_stream_->last_error(); | |
| 137 CloseConnection(); | |
| 138 // If the socket stream had an error, plumb it up, else plumb up FAILED. | |
| 139 if (last_error == net::OK) | |
| 140 last_error = net::ERR_FAILED; | |
| 141 connection_callback_.Run(last_error); | |
| 142 return; | |
| 143 } | |
| 144 | |
| 145 write_callback_.Run(); | |
| 146 } | |
| 147 | |
| 148 void ConnectionHandler::GetNextMessage() { | |
| 149 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || | |
| 150 SocketInputStream::READY == input_stream_->GetState()); | |
| 151 message_tag_ = 0; | |
| 152 message_size_ = 0; | |
| 153 | |
| 154 WaitForData(MCS_TAG_AND_SIZE); | |
| 155 } | |
| 156 | |
| 157 void ConnectionHandler::WaitForData(ProcessingState state) { | |
| 158 DVLOG(1) << "Waiting for MCS data: state == " << state; | |
| 159 | |
| 160 if (!input_stream_) { | |
| 161 // The connection has already been closed. Just return. | |
| 162 DCHECK(!output_stream_.get()); | |
| 163 DCHECK(!read_timeout_timer_.IsRunning()); | |
| 164 return; | |
| 165 } | |
| 166 | |
| 167 if (input_stream_->GetState() != SocketInputStream::EMPTY && | |
| 168 input_stream_->GetState() != SocketInputStream::READY) { | |
| 169 // An error occurred. | |
| 170 int last_error = output_stream_->last_error(); | |
| 171 CloseConnection(); | |
| 172 // If the socket stream had an error, plumb it up, else plumb up FAILED. | |
| 173 if (last_error == net::OK) | |
| 174 last_error = net::ERR_FAILED; | |
| 175 connection_callback_.Run(last_error); | |
| 176 return; | |
| 177 } | |
| 178 | |
| 179 // Used to determine whether a Socket::Read is necessary. | |
| 180 int min_bytes_needed = 0; | |
| 181 // Used to limit the size of the Socket::Read. | |
| 182 int max_bytes_needed = 0; | |
| 183 | |
| 184 switch(state) { | |
| 185 case MCS_VERSION_TAG_AND_SIZE: | |
| 186 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; | |
| 187 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; | |
| 188 break; | |
| 189 case MCS_TAG_AND_SIZE: | |
| 190 min_bytes_needed = kTagPacketLen + kSizePacketLenMin; | |
| 191 max_bytes_needed = kTagPacketLen + kSizePacketLenMax; | |
| 192 break; | |
| 193 case MCS_FULL_SIZE: | |
| 194 // If in this state, the minimum size packet length must already have been | |
| 195 // insufficient, so set both to the max length. | |
| 196 min_bytes_needed = kSizePacketLenMax; | |
| 197 max_bytes_needed = kSizePacketLenMax; | |
| 198 break; | |
| 199 case MCS_PROTO_BYTES: | |
| 200 read_timeout_timer_.Reset(); | |
| 201 // No variability in the message size, set both to the same. | |
| 202 min_bytes_needed = message_size_; | |
| 203 max_bytes_needed = message_size_; | |
| 204 break; | |
| 205 default: | |
| 206 NOTREACHED(); | |
| 207 } | |
| 208 DCHECK_GE(max_bytes_needed, min_bytes_needed); | |
| 209 | |
| 210 int byte_count = input_stream_->UnreadByteCount(); | |
| 211 if (min_bytes_needed - byte_count > 0 && | |
| 212 input_stream_->Refresh( | |
| 213 base::Bind(&ConnectionHandler::WaitForData, | |
| 214 weak_ptr_factory_.GetWeakPtr(), | |
| 215 state), | |
| 216 max_bytes_needed - byte_count) == net::ERR_IO_PENDING) { | |
| 217 return; | |
| 218 } | |
| 219 | |
| 220 // Check for refresh errors. | |
| 221 if (input_stream_->GetState() != SocketInputStream::READY) { | |
| 222 // An error occurred. | |
| 223 int last_error = output_stream_->last_error(); | |
| 224 CloseConnection(); | |
| 225 // If the socket stream had an error, plumb it up, else plumb up FAILED. | |
| 226 if (last_error == net::OK) | |
| 227 last_error = net::ERR_FAILED; | |
| 228 connection_callback_.Run(last_error); | |
| 229 return; | |
| 230 } | |
| 231 | |
| 232 // Received enough bytes, process them. | |
| 233 DVLOG(1) << "Processing MCS data: state == " << state; | |
| 234 switch(state) { | |
| 235 case MCS_VERSION_TAG_AND_SIZE: | |
| 236 OnGotVersion(); | |
| 237 break; | |
| 238 case MCS_TAG_AND_SIZE: | |
| 239 OnGotMessageTag(); | |
| 240 break; | |
| 241 case MCS_FULL_SIZE: | |
| 242 OnGotMessageSize(); | |
| 243 break; | |
| 244 case MCS_PROTO_BYTES: | |
| 245 OnGotMessageBytes(); | |
| 246 break; | |
| 247 default: | |
| 248 NOTREACHED(); | |
| 249 } | |
| 250 } | |
| 251 | |
| 252 void ConnectionHandler::OnGotVersion() { | |
| 253 uint8 version = 0; | |
| 254 { | |
| 255 CodedInputStream coded_input_stream(input_stream_.get()); | |
| 256 coded_input_stream.ReadRaw(&version, 1); | |
| 257 } | |
| 258 if (version < kMCSVersion) { | |
| 259 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); | |
| 260 connection_callback_.Run(net::ERR_FAILED); | |
| 261 return; | |
| 262 } | |
| 263 | |
| 264 input_stream_->RebuildBuffer(); | |
| 265 | |
| 266 // Process the LoginResponse message tag. | |
| 267 OnGotMessageTag(); | |
| 268 } | |
| 269 | |
| 270 void ConnectionHandler::OnGotMessageTag() { | |
| 271 if (input_stream_->GetState() != SocketInputStream::READY) { | |
| 272 LOG(ERROR) << "Failed to receive protobuf tag."; | |
| 273 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); | |
| 274 return; | |
| 275 } | |
| 276 | |
| 277 { | |
| 278 CodedInputStream coded_input_stream(input_stream_.get()); | |
| 279 coded_input_stream.ReadRaw(&message_tag_, 1); | |
| 280 } | |
| 281 | |
| 282 DVLOG(1) << "Received proto of type " | |
| 283 << static_cast<unsigned int>(message_tag_); | |
| 284 | |
| 285 if (!read_timeout_timer_.IsRunning()) { | |
| 286 read_timeout_timer_.Start(FROM_HERE, | |
| 287 read_timeout_, | |
| 288 base::Bind(&ConnectionHandler::OnTimeout, | |
| 289 weak_ptr_factory_.GetWeakPtr())); | |
| 290 } | |
| 291 OnGotMessageSize(); | |
| 292 } | |
| 293 | |
| 294 void ConnectionHandler::OnGotMessageSize() { | |
| 295 if (input_stream_->GetState() != SocketInputStream::READY) { | |
| 296 LOG(ERROR) << "Failed to receive message size."; | |
| 297 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); | |
| 298 return; | |
| 299 } | |
| 300 | |
| 301 bool need_another_byte = false; | |
| 302 int prev_byte_count = input_stream_->ByteCount(); | |
| 303 { | |
| 304 CodedInputStream coded_input_stream(input_stream_.get()); | |
| 305 if (!coded_input_stream.ReadVarint32(&message_size_)) | |
| 306 need_another_byte = true; | |
| 307 } | |
| 308 | |
| 309 if (need_another_byte) { | |
| 310 DVLOG(1) << "Expecting another message size byte."; | |
| 311 if (prev_byte_count >= kSizePacketLenMax) { | |
| 312 // Already had enough bytes, something else went wrong. | |
| 313 LOG(ERROR) << "Failed to process message size."; | |
| 314 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); | |
| 315 return; | |
| 316 } | |
| 317 // Back up by the amount read (should always be 1 byte). | |
| 318 int bytes_read = prev_byte_count - input_stream_->ByteCount(); | |
| 319 DCHECK_EQ(bytes_read, 1); | |
| 320 input_stream_->BackUp(bytes_read); | |
| 321 WaitForData(MCS_FULL_SIZE); | |
| 322 return; | |
| 323 } | |
| 324 | |
| 325 DVLOG(1) << "Proto size: " << message_size_; | |
| 326 | |
| 327 if (message_size_ > 0) | |
| 328 WaitForData(MCS_PROTO_BYTES); | |
| 329 else | |
| 330 OnGotMessageBytes(); | |
| 331 } | |
| 332 | |
| 333 void ConnectionHandler::OnGotMessageBytes() { | |
| 334 read_timeout_timer_.Stop(); | |
| 335 scoped_ptr<google::protobuf::MessageLite> protobuf( | |
| 336 BuildProtobufFromTag(message_tag_)); | |
| 337 // Messages with no content are valid; just use the default protobuf for | |
| 338 // that tag. | |
| 339 if (protobuf.get() && message_size_ == 0) { | |
| 340 base::MessageLoop::current()->PostTask( | |
| 341 FROM_HERE, | |
| 342 base::Bind(&ConnectionHandler::GetNextMessage, | |
| 343 weak_ptr_factory_.GetWeakPtr())); | |
| 344 read_callback_.Run(protobuf.Pass()); | |
| 345 return; | |
| 346 } | |
| 347 | |
| 348 if (!protobuf.get() || | |
| 349 input_stream_->GetState() != SocketInputStream::READY) { | |
| 350 LOG(ERROR) << "Failed to extract protobuf bytes of type " | |
| 351 << static_cast<unsigned int>(message_tag_); | |
| 352 protobuf.reset(); // Return a null pointer to denote an error. | |
| 353 read_callback_.Run(protobuf.Pass()); | |
| 354 return; | |
| 355 } | |
| 356 | |
| 357 { | |
| 358 CodedInputStream coded_input_stream(input_stream_.get()); | |
| 359 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { | |
| 360 NOTREACHED() << "Unable to parse GCM message of type " | |
| 361 << static_cast<unsigned int>(message_tag_); | |
| 362 protobuf.reset(); // Return a null pointer to denote an error. | |
| 363 read_callback_.Run(protobuf.Pass()); | |
| 364 return; | |
| 365 } | |
| 366 } | |
| 367 | |
| 368 input_stream_->RebuildBuffer(); | |
| 369 base::MessageLoop::current()->PostTask( | |
| 370 FROM_HERE, | |
| 371 base::Bind(&ConnectionHandler::GetNextMessage, | |
| 372 weak_ptr_factory_.GetWeakPtr())); | |
| 373 if (message_tag_ == kLoginResponseTag) { | |
| 374 if (handshake_complete_) { | |
| 375 LOG(ERROR) << "Unexpected login response."; | |
| 376 } else { | |
| 377 handshake_complete_ = true; | |
| 378 DVLOG(1) << "GCM Handshake complete."; | |
| 379 } | |
| 380 } | |
| 381 read_callback_.Run(protobuf.Pass()); | |
| 382 } | |
| 383 | |
| 384 void ConnectionHandler::OnTimeout() { | |
| 385 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; | |
| 386 CloseConnection(); | |
| 387 connection_callback_.Run(net::ERR_TIMED_OUT); | |
| 388 } | |
| 389 | |
| 390 void ConnectionHandler::CloseConnection() { | |
| 391 DVLOG(1) << "Closing connection."; | |
| 392 read_callback_.Reset(); | |
| 393 write_callback_.Reset(); | |
| 394 read_timeout_timer_.Stop(); | |
| 395 socket_->Disconnect(); | |
| 396 input_stream_.reset(); | |
| 397 output_stream_.reset(); | |
| 398 weak_ptr_factory_.InvalidateWeakPtrs(); | |
| 399 } | |
| 400 | |
| 401 } // namespace gcm | 15 } // namespace gcm |
| OLD | NEW |