OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "google_apis/gcm/engine/connection_handler.h" | 5 #include "google_apis/gcm/engine/connection_handler.h" |
6 | 6 |
7 #include "base/message_loop/message_loop.h" | |
8 #include "google/protobuf/io/coded_stream.h" | |
9 #include "google_apis/gcm/base/mcs_util.h" | |
10 #include "google_apis/gcm/base/socket_stream.h" | |
11 #include "net/base/net_errors.h" | |
12 #include "net/socket/stream_socket.h" | |
13 | |
14 using namespace google::protobuf::io; | |
15 | |
16 namespace gcm { | 7 namespace gcm { |
17 | 8 |
18 namespace { | 9 ConnectionHandler::ConnectionHandler() { |
19 | |
20 // # of bytes a MCS version packet consumes. | |
21 const int kVersionPacketLen = 1; | |
22 // # of bytes a tag packet consumes. | |
23 const int kTagPacketLen = 1; | |
24 // Max # of bytes a length packet consumes. | |
25 const int kSizePacketLenMin = 1; | |
26 const int kSizePacketLenMax = 2; | |
27 | |
28 // The current MCS protocol version. | |
29 const int kMCSVersion = 38; | |
30 | |
31 } // namespace | |
32 | |
33 ConnectionHandler::ConnectionHandler(base::TimeDelta read_timeout) | |
34 : read_timeout_(read_timeout), | |
35 handshake_complete_(false), | |
36 message_tag_(0), | |
37 message_size_(0), | |
38 weak_ptr_factory_(this) { | |
39 } | 10 } |
40 | 11 |
41 ConnectionHandler::~ConnectionHandler() { | 12 ConnectionHandler::~ConnectionHandler() { |
42 } | 13 } |
43 | 14 |
44 void ConnectionHandler::Init( | |
45 scoped_ptr<net::StreamSocket> socket, | |
46 const google::protobuf::MessageLite& login_request, | |
47 const ProtoReceivedCallback& read_callback, | |
48 const ProtoSentCallback& write_callback, | |
49 const ConnectionChangedCallback& connection_callback) { | |
50 DCHECK(!read_callback.is_null()); | |
51 DCHECK(!write_callback.is_null()); | |
52 DCHECK(!connection_callback.is_null()); | |
53 | |
54 // Invalidate any previously outstanding reads. | |
55 weak_ptr_factory_.InvalidateWeakPtrs(); | |
56 | |
57 handshake_complete_ = false; | |
58 message_tag_ = 0; | |
59 message_size_ = 0; | |
60 socket_ = socket.Pass(); | |
61 input_stream_.reset(new SocketInputStream(socket_.get())); | |
62 output_stream_.reset(new SocketOutputStream(socket_.get())); | |
63 read_callback_ = read_callback; | |
64 write_callback_ = write_callback; | |
65 connection_callback_ = connection_callback; | |
66 | |
67 Login(login_request); | |
68 } | |
69 | |
70 bool ConnectionHandler::CanSendMessage() const { | |
71 return handshake_complete_ && output_stream_.get() && | |
72 output_stream_->GetState() == SocketOutputStream::EMPTY; | |
73 } | |
74 | |
75 void ConnectionHandler::SendMessage( | |
76 const google::protobuf::MessageLite& message) { | |
77 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); | |
78 DCHECK(handshake_complete_); | |
79 | |
80 { | |
81 CodedOutputStream coded_output_stream(output_stream_.get()); | |
82 DVLOG(1) << "Writing proto of size " << message.ByteSize(); | |
83 int tag = GetMCSProtoTag(message); | |
84 DCHECK_NE(tag, -1); | |
85 coded_output_stream.WriteRaw(&tag, 1); | |
86 coded_output_stream.WriteVarint32(message.ByteSize()); | |
87 message.SerializeToCodedStream(&coded_output_stream); | |
88 } | |
89 | |
90 if (output_stream_->Flush( | |
91 base::Bind(&ConnectionHandler::OnMessageSent, | |
92 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { | |
93 OnMessageSent(); | |
94 } | |
95 } | |
96 | |
97 void ConnectionHandler::Login( | |
98 const google::protobuf::MessageLite& login_request) { | |
99 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); | |
100 | |
101 const char version_byte[1] = {kMCSVersion}; | |
102 const char login_request_tag[1] = {kLoginRequestTag}; | |
103 { | |
104 CodedOutputStream coded_output_stream(output_stream_.get()); | |
105 coded_output_stream.WriteRaw(version_byte, 1); | |
106 coded_output_stream.WriteRaw(login_request_tag, 1); | |
107 coded_output_stream.WriteVarint32(login_request.ByteSize()); | |
108 login_request.SerializeToCodedStream(&coded_output_stream); | |
109 } | |
110 | |
111 if (output_stream_->Flush( | |
112 base::Bind(&ConnectionHandler::OnMessageSent, | |
113 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { | |
114 base::MessageLoop::current()->PostTask( | |
115 FROM_HERE, | |
116 base::Bind(&ConnectionHandler::OnMessageSent, | |
117 weak_ptr_factory_.GetWeakPtr())); | |
118 } | |
119 | |
120 read_timeout_timer_.Start(FROM_HERE, | |
121 read_timeout_, | |
122 base::Bind(&ConnectionHandler::OnTimeout, | |
123 weak_ptr_factory_.GetWeakPtr())); | |
124 WaitForData(MCS_VERSION_TAG_AND_SIZE); | |
125 } | |
126 | |
127 void ConnectionHandler::OnMessageSent() { | |
128 if (!output_stream_.get()) { | |
129 // The connection has already been closed. Just return. | |
130 DCHECK(!input_stream_.get()); | |
131 DCHECK(!read_timeout_timer_.IsRunning()); | |
132 return; | |
133 } | |
134 | |
135 if (output_stream_->GetState() != SocketOutputStream::EMPTY) { | |
136 int last_error = output_stream_->last_error(); | |
137 CloseConnection(); | |
138 // If the socket stream had an error, plumb it up, else plumb up FAILED. | |
139 if (last_error == net::OK) | |
140 last_error = net::ERR_FAILED; | |
141 connection_callback_.Run(last_error); | |
142 return; | |
143 } | |
144 | |
145 write_callback_.Run(); | |
146 } | |
147 | |
148 void ConnectionHandler::GetNextMessage() { | |
149 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || | |
150 SocketInputStream::READY == input_stream_->GetState()); | |
151 message_tag_ = 0; | |
152 message_size_ = 0; | |
153 | |
154 WaitForData(MCS_TAG_AND_SIZE); | |
155 } | |
156 | |
157 void ConnectionHandler::WaitForData(ProcessingState state) { | |
158 DVLOG(1) << "Waiting for MCS data: state == " << state; | |
159 | |
160 if (!input_stream_) { | |
161 // The connection has already been closed. Just return. | |
162 DCHECK(!output_stream_.get()); | |
163 DCHECK(!read_timeout_timer_.IsRunning()); | |
164 return; | |
165 } | |
166 | |
167 if (input_stream_->GetState() != SocketInputStream::EMPTY && | |
168 input_stream_->GetState() != SocketInputStream::READY) { | |
169 // An error occurred. | |
170 int last_error = output_stream_->last_error(); | |
171 CloseConnection(); | |
172 // If the socket stream had an error, plumb it up, else plumb up FAILED. | |
173 if (last_error == net::OK) | |
174 last_error = net::ERR_FAILED; | |
175 connection_callback_.Run(last_error); | |
176 return; | |
177 } | |
178 | |
179 // Used to determine whether a Socket::Read is necessary. | |
180 int min_bytes_needed = 0; | |
181 // Used to limit the size of the Socket::Read. | |
182 int max_bytes_needed = 0; | |
183 | |
184 switch(state) { | |
185 case MCS_VERSION_TAG_AND_SIZE: | |
186 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; | |
187 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; | |
188 break; | |
189 case MCS_TAG_AND_SIZE: | |
190 min_bytes_needed = kTagPacketLen + kSizePacketLenMin; | |
191 max_bytes_needed = kTagPacketLen + kSizePacketLenMax; | |
192 break; | |
193 case MCS_FULL_SIZE: | |
194 // If in this state, the minimum size packet length must already have been | |
195 // insufficient, so set both to the max length. | |
196 min_bytes_needed = kSizePacketLenMax; | |
197 max_bytes_needed = kSizePacketLenMax; | |
198 break; | |
199 case MCS_PROTO_BYTES: | |
200 read_timeout_timer_.Reset(); | |
201 // No variability in the message size, set both to the same. | |
202 min_bytes_needed = message_size_; | |
203 max_bytes_needed = message_size_; | |
204 break; | |
205 default: | |
206 NOTREACHED(); | |
207 } | |
208 DCHECK_GE(max_bytes_needed, min_bytes_needed); | |
209 | |
210 int byte_count = input_stream_->UnreadByteCount(); | |
211 if (min_bytes_needed - byte_count > 0 && | |
212 input_stream_->Refresh( | |
213 base::Bind(&ConnectionHandler::WaitForData, | |
214 weak_ptr_factory_.GetWeakPtr(), | |
215 state), | |
216 max_bytes_needed - byte_count) == net::ERR_IO_PENDING) { | |
217 return; | |
218 } | |
219 | |
220 // Check for refresh errors. | |
221 if (input_stream_->GetState() != SocketInputStream::READY) { | |
222 // An error occurred. | |
223 int last_error = output_stream_->last_error(); | |
224 CloseConnection(); | |
225 // If the socket stream had an error, plumb it up, else plumb up FAILED. | |
226 if (last_error == net::OK) | |
227 last_error = net::ERR_FAILED; | |
228 connection_callback_.Run(last_error); | |
229 return; | |
230 } | |
231 | |
232 // Received enough bytes, process them. | |
233 DVLOG(1) << "Processing MCS data: state == " << state; | |
234 switch(state) { | |
235 case MCS_VERSION_TAG_AND_SIZE: | |
236 OnGotVersion(); | |
237 break; | |
238 case MCS_TAG_AND_SIZE: | |
239 OnGotMessageTag(); | |
240 break; | |
241 case MCS_FULL_SIZE: | |
242 OnGotMessageSize(); | |
243 break; | |
244 case MCS_PROTO_BYTES: | |
245 OnGotMessageBytes(); | |
246 break; | |
247 default: | |
248 NOTREACHED(); | |
249 } | |
250 } | |
251 | |
252 void ConnectionHandler::OnGotVersion() { | |
253 uint8 version = 0; | |
254 { | |
255 CodedInputStream coded_input_stream(input_stream_.get()); | |
256 coded_input_stream.ReadRaw(&version, 1); | |
257 } | |
258 if (version < kMCSVersion) { | |
259 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); | |
260 connection_callback_.Run(net::ERR_FAILED); | |
261 return; | |
262 } | |
263 | |
264 input_stream_->RebuildBuffer(); | |
265 | |
266 // Process the LoginResponse message tag. | |
267 OnGotMessageTag(); | |
268 } | |
269 | |
270 void ConnectionHandler::OnGotMessageTag() { | |
271 if (input_stream_->GetState() != SocketInputStream::READY) { | |
272 LOG(ERROR) << "Failed to receive protobuf tag."; | |
273 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); | |
274 return; | |
275 } | |
276 | |
277 { | |
278 CodedInputStream coded_input_stream(input_stream_.get()); | |
279 coded_input_stream.ReadRaw(&message_tag_, 1); | |
280 } | |
281 | |
282 DVLOG(1) << "Received proto of type " | |
283 << static_cast<unsigned int>(message_tag_); | |
284 | |
285 if (!read_timeout_timer_.IsRunning()) { | |
286 read_timeout_timer_.Start(FROM_HERE, | |
287 read_timeout_, | |
288 base::Bind(&ConnectionHandler::OnTimeout, | |
289 weak_ptr_factory_.GetWeakPtr())); | |
290 } | |
291 OnGotMessageSize(); | |
292 } | |
293 | |
294 void ConnectionHandler::OnGotMessageSize() { | |
295 if (input_stream_->GetState() != SocketInputStream::READY) { | |
296 LOG(ERROR) << "Failed to receive message size."; | |
297 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); | |
298 return; | |
299 } | |
300 | |
301 bool need_another_byte = false; | |
302 int prev_byte_count = input_stream_->ByteCount(); | |
303 { | |
304 CodedInputStream coded_input_stream(input_stream_.get()); | |
305 if (!coded_input_stream.ReadVarint32(&message_size_)) | |
306 need_another_byte = true; | |
307 } | |
308 | |
309 if (need_another_byte) { | |
310 DVLOG(1) << "Expecting another message size byte."; | |
311 if (prev_byte_count >= kSizePacketLenMax) { | |
312 // Already had enough bytes, something else went wrong. | |
313 LOG(ERROR) << "Failed to process message size."; | |
314 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); | |
315 return; | |
316 } | |
317 // Back up by the amount read (should always be 1 byte). | |
318 int bytes_read = prev_byte_count - input_stream_->ByteCount(); | |
319 DCHECK_EQ(bytes_read, 1); | |
320 input_stream_->BackUp(bytes_read); | |
321 WaitForData(MCS_FULL_SIZE); | |
322 return; | |
323 } | |
324 | |
325 DVLOG(1) << "Proto size: " << message_size_; | |
326 | |
327 if (message_size_ > 0) | |
328 WaitForData(MCS_PROTO_BYTES); | |
329 else | |
330 OnGotMessageBytes(); | |
331 } | |
332 | |
333 void ConnectionHandler::OnGotMessageBytes() { | |
334 read_timeout_timer_.Stop(); | |
335 scoped_ptr<google::protobuf::MessageLite> protobuf( | |
336 BuildProtobufFromTag(message_tag_)); | |
337 // Messages with no content are valid; just use the default protobuf for | |
338 // that tag. | |
339 if (protobuf.get() && message_size_ == 0) { | |
340 base::MessageLoop::current()->PostTask( | |
341 FROM_HERE, | |
342 base::Bind(&ConnectionHandler::GetNextMessage, | |
343 weak_ptr_factory_.GetWeakPtr())); | |
344 read_callback_.Run(protobuf.Pass()); | |
345 return; | |
346 } | |
347 | |
348 if (!protobuf.get() || | |
349 input_stream_->GetState() != SocketInputStream::READY) { | |
350 LOG(ERROR) << "Failed to extract protobuf bytes of type " | |
351 << static_cast<unsigned int>(message_tag_); | |
352 protobuf.reset(); // Return a null pointer to denote an error. | |
353 read_callback_.Run(protobuf.Pass()); | |
354 return; | |
355 } | |
356 | |
357 { | |
358 CodedInputStream coded_input_stream(input_stream_.get()); | |
359 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { | |
360 NOTREACHED() << "Unable to parse GCM message of type " | |
361 << static_cast<unsigned int>(message_tag_); | |
362 protobuf.reset(); // Return a null pointer to denote an error. | |
363 read_callback_.Run(protobuf.Pass()); | |
364 return; | |
365 } | |
366 } | |
367 | |
368 input_stream_->RebuildBuffer(); | |
369 base::MessageLoop::current()->PostTask( | |
370 FROM_HERE, | |
371 base::Bind(&ConnectionHandler::GetNextMessage, | |
372 weak_ptr_factory_.GetWeakPtr())); | |
373 if (message_tag_ == kLoginResponseTag) { | |
374 if (handshake_complete_) { | |
375 LOG(ERROR) << "Unexpected login response."; | |
376 } else { | |
377 handshake_complete_ = true; | |
378 DVLOG(1) << "GCM Handshake complete."; | |
379 } | |
380 } | |
381 read_callback_.Run(protobuf.Pass()); | |
382 } | |
383 | |
384 void ConnectionHandler::OnTimeout() { | |
385 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; | |
386 CloseConnection(); | |
387 connection_callback_.Run(net::ERR_TIMED_OUT); | |
388 } | |
389 | |
390 void ConnectionHandler::CloseConnection() { | |
391 DVLOG(1) << "Closing connection."; | |
392 read_callback_.Reset(); | |
393 write_callback_.Reset(); | |
394 read_timeout_timer_.Stop(); | |
395 socket_->Disconnect(); | |
396 input_stream_.reset(); | |
397 output_stream_.reset(); | |
398 weak_ptr_factory_.InvalidateWeakPtrs(); | |
399 } | |
400 | |
401 } // namespace gcm | 15 } // namespace gcm |
OLD | NEW |