OLD | NEW |
(Empty) | |
| 1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "net/websockets/websocket_throttle.h" |
| 6 |
| 7 #if defined(OS_WIN) |
| 8 #include <ws2tcpip.h> |
| 9 #else |
| 10 #include <netdb.h> |
| 11 #endif |
| 12 |
| 13 #include <string> |
| 14 |
| 15 #include "base/message_loop.h" |
| 16 #include "base/ref_counted.h" |
| 17 #include "base/singleton.h" |
| 18 #include "base/string_util.h" |
| 19 #include "net/base/io_buffer.h" |
| 20 #include "net/socket_stream/socket_stream.h" |
| 21 |
| 22 namespace net { |
| 23 |
| 24 static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) { |
| 25 switch (addrinfo->ai_family) { |
| 26 case AF_INET: { |
| 27 const struct sockaddr_in* const addr = |
| 28 reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr); |
| 29 return StringPrintf("%d:%s", |
| 30 addrinfo->ai_family, |
| 31 HexEncode(&addr->sin_addr, 4).c_str()); |
| 32 } |
| 33 case AF_INET6: { |
| 34 const struct sockaddr_in6* const addr6 = |
| 35 reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr); |
| 36 return StringPrintf("%d:%s", |
| 37 addrinfo->ai_family, |
| 38 HexEncode(&addr6->sin6_addr, |
| 39 sizeof(addr6->sin6_addr)).c_str()); |
| 40 } |
| 41 default: |
| 42 return StringPrintf("%d:%s", |
| 43 addrinfo->ai_family, |
| 44 HexEncode(addrinfo->ai_addr, |
| 45 addrinfo->ai_addrlen).c_str()); |
| 46 } |
| 47 } |
| 48 |
| 49 // State for WebSocket protocol on each SocketStream. |
| 50 // This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName. |
| 51 // This is alive between connection starts and handshake is finished. |
| 52 // In this class, it doesn't check actual handshake finishes, but only checks |
| 53 // end of header is found in read data. |
| 54 class WebSocketThrottle::WebSocketState : public SocketStream::UserData { |
| 55 public: |
| 56 explicit WebSocketState(const AddressList& addrs) |
| 57 : address_list_(addrs), |
| 58 callback_(NULL), |
| 59 waiting_(false), |
| 60 handshake_finished_(false), |
| 61 buffer_(NULL) { |
| 62 } |
| 63 ~WebSocketState() {} |
| 64 |
| 65 int OnStartOpenConnection(CompletionCallback* callback) { |
| 66 DCHECK(!callback_); |
| 67 if (!waiting_) |
| 68 return OK; |
| 69 callback_ = callback; |
| 70 return ERR_IO_PENDING; |
| 71 } |
| 72 |
| 73 int OnRead(const char* data, int len, CompletionCallback* callback) { |
| 74 DCHECK(!waiting_); |
| 75 DCHECK(!callback_); |
| 76 DCHECK(!handshake_finished_); |
| 77 static const int kBufferSize = 8129; |
| 78 |
| 79 if (!buffer_) { |
| 80 // Fast path. |
| 81 int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0); |
| 82 if (eoh > 0) { |
| 83 handshake_finished_ = true; |
| 84 return OK; |
| 85 } |
| 86 buffer_ = new GrowableIOBuffer(); |
| 87 buffer_->SetCapacity(kBufferSize); |
| 88 } else { |
| 89 if (buffer_->RemainingCapacity() < len) { |
| 90 if (!buffer_->SetCapacity(buffer_->capacity() + kBufferSize)) { |
| 91 // TODO(ukai): Check more correctly. |
| 92 // Seek to the last CR or LF and reduce memory usage. |
| 93 LOG(ERROR) << "Too large headers? capacity=" << buffer_->capacity(); |
| 94 handshake_finished_ = true; |
| 95 return OK; |
| 96 } |
| 97 } |
| 98 } |
| 99 memcpy(buffer_->data(), data, len); |
| 100 buffer_->set_offset(buffer_->offset() + len); |
| 101 |
| 102 int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(), |
| 103 buffer_->offset(), 0); |
| 104 handshake_finished_ = (eoh > 0); |
| 105 return OK; |
| 106 } |
| 107 |
| 108 const AddressList& address_list() const { return address_list_; } |
| 109 void SetWaiting() { waiting_ = true; } |
| 110 bool IsWaiting() const { return waiting_; } |
| 111 bool HandshakeFinished() const { return handshake_finished_; } |
| 112 void Wakeup() { |
| 113 waiting_ = false; |
| 114 // We wrap |callback_| to keep this alive while this is released. |
| 115 scoped_refptr<CompletionCallbackRunner> runner = |
| 116 new CompletionCallbackRunner(callback_); |
| 117 callback_ = NULL; |
| 118 MessageLoopForIO::current()->PostTask( |
| 119 FROM_HERE, |
| 120 NewRunnableMethod(runner.get(), |
| 121 &CompletionCallbackRunner::Run)); |
| 122 } |
| 123 |
| 124 static const char* kKeyName; |
| 125 |
| 126 private: |
| 127 class CompletionCallbackRunner |
| 128 : public base::RefCountedThreadSafe<CompletionCallbackRunner> { |
| 129 public: |
| 130 explicit CompletionCallbackRunner(CompletionCallback* callback) |
| 131 : callback_(callback) { |
| 132 DCHECK(callback_); |
| 133 } |
| 134 virtual ~CompletionCallbackRunner() {} |
| 135 void Run() { |
| 136 callback_->Run(OK); |
| 137 } |
| 138 private: |
| 139 CompletionCallback* callback_; |
| 140 |
| 141 DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner); |
| 142 }; |
| 143 |
| 144 const AddressList& address_list_; |
| 145 |
| 146 CompletionCallback* callback_; |
| 147 // True if waiting another websocket connection is established. |
| 148 // False if the websocket is performing handshaking. |
| 149 bool waiting_; |
| 150 |
| 151 // True if the websocket handshake is completed. |
| 152 // If true, it will be removed from queue and deleted from the SocketStream |
| 153 // UserData soon. |
| 154 bool handshake_finished_; |
| 155 |
| 156 // Buffer for read data to check handshake response message. |
| 157 scoped_refptr<GrowableIOBuffer> buffer_; |
| 158 |
| 159 DISALLOW_COPY_AND_ASSIGN(WebSocketState); |
| 160 }; |
| 161 |
| 162 const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState"; |
| 163 |
| 164 WebSocketThrottle::WebSocketThrottle() { |
| 165 SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this); |
| 166 SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this); |
| 167 } |
| 168 |
| 169 WebSocketThrottle::~WebSocketThrottle() { |
| 170 DCHECK(queue_.empty()); |
| 171 DCHECK(addr_map_.empty()); |
| 172 } |
| 173 |
| 174 int WebSocketThrottle::OnStartOpenConnection( |
| 175 SocketStream* socket, CompletionCallback* callback) { |
| 176 WebSocketState* state = new WebSocketState(socket->address_list()); |
| 177 PutInQueue(socket, state); |
| 178 return state->OnStartOpenConnection(callback); |
| 179 } |
| 180 |
| 181 int WebSocketThrottle::OnRead(SocketStream* socket, |
| 182 const char* data, int len, |
| 183 CompletionCallback* callback) { |
| 184 WebSocketState* state = static_cast<WebSocketState*>( |
| 185 socket->GetUserData(WebSocketState::kKeyName)); |
| 186 // If no state, handshake was already completed. Do nothing. |
| 187 if (!state) |
| 188 return OK; |
| 189 |
| 190 int result = state->OnRead(data, len, callback); |
| 191 if (state->HandshakeFinished()) { |
| 192 RemoveFromQueue(socket, state); |
| 193 WakeupSocketIfNecessary(); |
| 194 } |
| 195 return result; |
| 196 } |
| 197 |
| 198 int WebSocketThrottle::OnWrite(SocketStream* socket, |
| 199 const char* data, int len, |
| 200 CompletionCallback* callback) { |
| 201 // Do nothing. |
| 202 return OK; |
| 203 } |
| 204 |
| 205 void WebSocketThrottle::OnClose(SocketStream* socket) { |
| 206 WebSocketState* state = static_cast<WebSocketState*>( |
| 207 socket->GetUserData(WebSocketState::kKeyName)); |
| 208 if (!state) |
| 209 return; |
| 210 RemoveFromQueue(socket, state); |
| 211 WakeupSocketIfNecessary(); |
| 212 } |
| 213 |
| 214 void WebSocketThrottle::PutInQueue(SocketStream* socket, |
| 215 WebSocketState* state) { |
| 216 socket->SetUserData(WebSocketState::kKeyName, state); |
| 217 queue_.push_back(state); |
| 218 const AddressList& address_list = socket->address_list(); |
| 219 for (const struct addrinfo* addrinfo = address_list.head(); |
| 220 addrinfo != NULL; |
| 221 addrinfo = addrinfo->ai_next) { |
| 222 std::string addrkey = AddrinfoToHashkey(addrinfo); |
| 223 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); |
| 224 if (iter == addr_map_.end()) { |
| 225 ConnectingQueue* queue = new ConnectingQueue(); |
| 226 queue->push_back(state); |
| 227 addr_map_[addrkey] = queue; |
| 228 } else { |
| 229 iter->second->push_back(state); |
| 230 state->SetWaiting(); |
| 231 } |
| 232 } |
| 233 } |
| 234 |
| 235 void WebSocketThrottle::RemoveFromQueue(SocketStream* socket, |
| 236 WebSocketState* state) { |
| 237 const AddressList& address_list = socket->address_list(); |
| 238 for (const struct addrinfo* addrinfo = address_list.head(); |
| 239 addrinfo != NULL; |
| 240 addrinfo = addrinfo->ai_next) { |
| 241 std::string addrkey = AddrinfoToHashkey(addrinfo); |
| 242 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); |
| 243 DCHECK(iter != addr_map_.end()); |
| 244 ConnectingQueue* queue = iter->second; |
| 245 DCHECK(state == queue->front()); |
| 246 queue->pop_front(); |
| 247 if (queue->empty()) |
| 248 addr_map_.erase(iter); |
| 249 } |
| 250 for (ConnectingQueue::iterator iter = queue_.begin(); |
| 251 iter != queue_.end(); |
| 252 ++iter) { |
| 253 if (*iter == state) { |
| 254 queue_.erase(iter); |
| 255 break; |
| 256 } |
| 257 } |
| 258 socket->SetUserData(WebSocketState::kKeyName, NULL); |
| 259 } |
| 260 |
| 261 void WebSocketThrottle::WakeupSocketIfNecessary() { |
| 262 for (ConnectingQueue::iterator iter = queue_.begin(); |
| 263 iter != queue_.end(); |
| 264 ++iter) { |
| 265 WebSocketState* state = *iter; |
| 266 if (!state->IsWaiting()) |
| 267 continue; |
| 268 |
| 269 bool should_wakeup = true; |
| 270 const AddressList& address_list = state->address_list(); |
| 271 for (const struct addrinfo* addrinfo = address_list.head(); |
| 272 addrinfo != NULL; |
| 273 addrinfo = addrinfo->ai_next) { |
| 274 std::string addrkey = AddrinfoToHashkey(addrinfo); |
| 275 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); |
| 276 DCHECK(iter != addr_map_.end()); |
| 277 ConnectingQueue* queue = iter->second; |
| 278 if (state != queue->front()) { |
| 279 should_wakeup = false; |
| 280 break; |
| 281 } |
| 282 } |
| 283 if (should_wakeup) |
| 284 state->Wakeup(); |
| 285 } |
| 286 } |
| 287 |
| 288 /* static */ |
| 289 void WebSocketThrottle::Init() { |
| 290 Singleton<WebSocketThrottle>::get(); |
| 291 } |
| 292 |
| 293 } // namespace net |
OLD | NEW |