OLD | NEW |
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_throttle.h" | 5 #include "net/websockets/websocket_throttle.h" |
6 | 6 |
7 #include <string> | 7 #include <string> |
8 | 8 |
9 #include "base/message_loop.h" | 9 #include "base/message_loop.h" |
10 #include "base/ref_counted.h" | 10 #include "base/ref_counted.h" |
11 #include "base/singleton.h" | 11 #include "base/singleton.h" |
12 #include "base/string_util.h" | 12 #include "base/string_util.h" |
13 #include "net/base/io_buffer.h" | 13 #include "net/base/io_buffer.h" |
14 #include "net/base/sys_addrinfo.h" | 14 #include "net/base/sys_addrinfo.h" |
15 #include "net/socket_stream/socket_stream.h" | 15 #include "net/socket_stream/socket_stream.h" |
| 16 #include "net/websockets/websocket_job.h" |
16 | 17 |
17 namespace net { | 18 namespace net { |
18 | 19 |
19 static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) { | 20 static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) { |
20 switch (addrinfo->ai_family) { | 21 switch (addrinfo->ai_family) { |
21 case AF_INET: { | 22 case AF_INET: { |
22 const struct sockaddr_in* const addr = | 23 const struct sockaddr_in* const addr = |
23 reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr); | 24 reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr); |
24 return StringPrintf("%d:%s", | 25 return StringPrintf("%d:%s", |
25 addrinfo->ai_family, | 26 addrinfo->ai_family, |
26 HexEncode(&addr->sin_addr, 4).c_str()); | 27 HexEncode(&addr->sin_addr, 4).c_str()); |
27 } | 28 } |
28 case AF_INET6: { | 29 case AF_INET6: { |
29 const struct sockaddr_in6* const addr6 = | 30 const struct sockaddr_in6* const addr6 = |
30 reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr); | 31 reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr); |
31 return StringPrintf("%d:%s", | 32 return StringPrintf("%d:%s", |
32 addrinfo->ai_family, | 33 addrinfo->ai_family, |
33 HexEncode(&addr6->sin6_addr, | 34 HexEncode(&addr6->sin6_addr, |
34 sizeof(addr6->sin6_addr)).c_str()); | 35 sizeof(addr6->sin6_addr)).c_str()); |
35 } | 36 } |
36 default: | 37 default: |
37 return StringPrintf("%d:%s", | 38 return StringPrintf("%d:%s", |
38 addrinfo->ai_family, | 39 addrinfo->ai_family, |
39 HexEncode(addrinfo->ai_addr, | 40 HexEncode(addrinfo->ai_addr, |
40 addrinfo->ai_addrlen).c_str()); | 41 addrinfo->ai_addrlen).c_str()); |
41 } | 42 } |
42 } | 43 } |
43 | 44 |
44 // State for WebSocket protocol on each SocketStream. | |
45 // This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName. | |
46 // This is alive between connection starts and handshake is finished. | |
47 // In this class, it doesn't check actual handshake finishes, but only checks | |
48 // end of header is found in read data. | |
49 class WebSocketThrottle::WebSocketState : public SocketStream::UserData { | |
50 public: | |
51 explicit WebSocketState(const AddressList& addrs) | |
52 : address_list_(addrs), | |
53 callback_(NULL), | |
54 waiting_(false), | |
55 handshake_finished_(false), | |
56 buffer_(NULL) { | |
57 } | |
58 ~WebSocketState() {} | |
59 | |
60 int OnStartOpenConnection(CompletionCallback* callback) { | |
61 DCHECK(!callback_); | |
62 if (!waiting_) | |
63 return OK; | |
64 callback_ = callback; | |
65 return ERR_IO_PENDING; | |
66 } | |
67 | |
68 int OnRead(const char* data, int len, CompletionCallback* callback) { | |
69 DCHECK(!waiting_); | |
70 DCHECK(!callback_); | |
71 DCHECK(!handshake_finished_); | |
72 static const int kBufferSize = 8129; | |
73 | |
74 if (!buffer_) { | |
75 // Fast path. | |
76 int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0); | |
77 if (eoh > 0) { | |
78 handshake_finished_ = true; | |
79 return OK; | |
80 } | |
81 buffer_ = new GrowableIOBuffer(); | |
82 buffer_->SetCapacity(kBufferSize); | |
83 } else if (buffer_->RemainingCapacity() < len) { | |
84 buffer_->SetCapacity(buffer_->capacity() + kBufferSize); | |
85 } | |
86 memcpy(buffer_->data(), data, len); | |
87 buffer_->set_offset(buffer_->offset() + len); | |
88 | |
89 int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(), | |
90 buffer_->offset(), 0); | |
91 handshake_finished_ = (eoh > 0); | |
92 return OK; | |
93 } | |
94 | |
95 const AddressList& address_list() const { return address_list_; } | |
96 void SetWaiting() { waiting_ = true; } | |
97 bool IsWaiting() const { return waiting_; } | |
98 bool HandshakeFinished() const { return handshake_finished_; } | |
99 void Wakeup() { | |
100 waiting_ = false; | |
101 // We wrap |callback_| to keep this alive while this is released. | |
102 scoped_refptr<CompletionCallbackRunner> runner = | |
103 new CompletionCallbackRunner(callback_); | |
104 callback_ = NULL; | |
105 MessageLoopForIO::current()->PostTask( | |
106 FROM_HERE, | |
107 NewRunnableMethod(runner.get(), | |
108 &CompletionCallbackRunner::Run)); | |
109 } | |
110 | |
111 static const char* kKeyName; | |
112 | |
113 private: | |
114 class CompletionCallbackRunner | |
115 : public base::RefCountedThreadSafe<CompletionCallbackRunner> { | |
116 public: | |
117 explicit CompletionCallbackRunner(CompletionCallback* callback) | |
118 : callback_(callback) { | |
119 DCHECK(callback_); | |
120 } | |
121 void Run() { | |
122 callback_->Run(OK); | |
123 } | |
124 private: | |
125 friend class base::RefCountedThreadSafe<CompletionCallbackRunner>; | |
126 | |
127 virtual ~CompletionCallbackRunner() {} | |
128 | |
129 CompletionCallback* callback_; | |
130 | |
131 DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner); | |
132 }; | |
133 | |
134 const AddressList& address_list_; | |
135 | |
136 CompletionCallback* callback_; | |
137 // True if waiting another websocket connection is established. | |
138 // False if the websocket is performing handshaking. | |
139 bool waiting_; | |
140 | |
141 // True if the websocket handshake is completed. | |
142 // If true, it will be removed from queue and deleted from the SocketStream | |
143 // UserData soon. | |
144 bool handshake_finished_; | |
145 | |
146 // Buffer for read data to check handshake response message. | |
147 scoped_refptr<GrowableIOBuffer> buffer_; | |
148 | |
149 DISALLOW_COPY_AND_ASSIGN(WebSocketState); | |
150 }; | |
151 | |
152 const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState"; | |
153 | |
154 WebSocketThrottle::WebSocketThrottle() { | 45 WebSocketThrottle::WebSocketThrottle() { |
155 SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this); | |
156 SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this); | |
157 } | 46 } |
158 | 47 |
159 WebSocketThrottle::~WebSocketThrottle() { | 48 WebSocketThrottle::~WebSocketThrottle() { |
160 DCHECK(queue_.empty()); | 49 DCHECK(queue_.empty()); |
161 DCHECK(addr_map_.empty()); | 50 DCHECK(addr_map_.empty()); |
162 } | 51 } |
163 | 52 |
164 int WebSocketThrottle::OnStartOpenConnection( | 53 void WebSocketThrottle::PutInQueue(WebSocketJob* job) { |
165 SocketStream* socket, CompletionCallback* callback) { | 54 queue_.push_back(job); |
166 WebSocketState* state = new WebSocketState(socket->address_list()); | 55 const AddressList& address_list = job->address_list(); |
167 PutInQueue(socket, state); | |
168 return state->OnStartOpenConnection(callback); | |
169 } | |
170 | |
171 int WebSocketThrottle::OnRead(SocketStream* socket, | |
172 const char* data, int len, | |
173 CompletionCallback* callback) { | |
174 WebSocketState* state = static_cast<WebSocketState*>( | |
175 socket->GetUserData(WebSocketState::kKeyName)); | |
176 // If no state, handshake was already completed. Do nothing. | |
177 if (!state) | |
178 return OK; | |
179 | |
180 int result = state->OnRead(data, len, callback); | |
181 if (state->HandshakeFinished()) { | |
182 RemoveFromQueue(socket, state); | |
183 WakeupSocketIfNecessary(); | |
184 } | |
185 return result; | |
186 } | |
187 | |
188 int WebSocketThrottle::OnWrite(SocketStream* socket, | |
189 const char* data, int len, | |
190 CompletionCallback* callback) { | |
191 // Do nothing. | |
192 return OK; | |
193 } | |
194 | |
195 void WebSocketThrottle::OnClose(SocketStream* socket) { | |
196 WebSocketState* state = static_cast<WebSocketState*>( | |
197 socket->GetUserData(WebSocketState::kKeyName)); | |
198 if (!state) | |
199 return; | |
200 RemoveFromQueue(socket, state); | |
201 WakeupSocketIfNecessary(); | |
202 } | |
203 | |
204 void WebSocketThrottle::PutInQueue(SocketStream* socket, | |
205 WebSocketState* state) { | |
206 socket->SetUserData(WebSocketState::kKeyName, state); | |
207 queue_.push_back(state); | |
208 const AddressList& address_list = socket->address_list(); | |
209 for (const struct addrinfo* addrinfo = address_list.head(); | 56 for (const struct addrinfo* addrinfo = address_list.head(); |
210 addrinfo != NULL; | 57 addrinfo != NULL; |
211 addrinfo = addrinfo->ai_next) { | 58 addrinfo = addrinfo->ai_next) { |
212 std::string addrkey = AddrinfoToHashkey(addrinfo); | 59 std::string addrkey = AddrinfoToHashkey(addrinfo); |
213 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); | 60 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); |
214 if (iter == addr_map_.end()) { | 61 if (iter == addr_map_.end()) { |
215 ConnectingQueue* queue = new ConnectingQueue(); | 62 ConnectingQueue* queue = new ConnectingQueue(); |
216 queue->push_back(state); | 63 queue->push_back(job); |
217 addr_map_[addrkey] = queue; | 64 addr_map_[addrkey] = queue; |
218 } else { | 65 } else { |
219 iter->second->push_back(state); | 66 iter->second->push_back(job); |
220 state->SetWaiting(); | 67 job->SetWaiting(); |
221 } | 68 } |
222 } | 69 } |
223 } | 70 } |
224 | 71 |
225 void WebSocketThrottle::RemoveFromQueue(SocketStream* socket, | 72 void WebSocketThrottle::RemoveFromQueue(WebSocketJob* job) { |
226 WebSocketState* state) { | 73 bool in_queue = false; |
227 const AddressList& address_list = socket->address_list(); | 74 for (ConnectingQueue::iterator iter = queue_.begin(); |
| 75 iter != queue_.end(); |
| 76 ++iter) { |
| 77 if (*iter == job) { |
| 78 queue_.erase(iter); |
| 79 in_queue = true; |
| 80 break; |
| 81 } |
| 82 } |
| 83 if (!in_queue) |
| 84 return; |
| 85 const AddressList& address_list = job->address_list(); |
228 for (const struct addrinfo* addrinfo = address_list.head(); | 86 for (const struct addrinfo* addrinfo = address_list.head(); |
229 addrinfo != NULL; | 87 addrinfo != NULL; |
230 addrinfo = addrinfo->ai_next) { | 88 addrinfo = addrinfo->ai_next) { |
231 std::string addrkey = AddrinfoToHashkey(addrinfo); | 89 std::string addrkey = AddrinfoToHashkey(addrinfo); |
232 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); | 90 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); |
233 DCHECK(iter != addr_map_.end()); | 91 DCHECK(iter != addr_map_.end()); |
234 ConnectingQueue* queue = iter->second; | 92 ConnectingQueue* queue = iter->second; |
235 DCHECK(state == queue->front()); | 93 // Job may not be front of queue when job is closed early while waiting. |
236 queue->pop_front(); | 94 for (ConnectingQueue::iterator iter = queue->begin(); |
| 95 iter != queue->end(); |
| 96 ++iter) { |
| 97 if (*iter == job) { |
| 98 queue->erase(iter); |
| 99 break; |
| 100 } |
| 101 } |
237 if (queue->empty()) { | 102 if (queue->empty()) { |
238 delete queue; | 103 delete queue; |
239 addr_map_.erase(iter); | 104 addr_map_.erase(iter); |
240 } | 105 } |
241 } | 106 } |
242 for (ConnectingQueue::iterator iter = queue_.begin(); | |
243 iter != queue_.end(); | |
244 ++iter) { | |
245 if (*iter == state) { | |
246 queue_.erase(iter); | |
247 break; | |
248 } | |
249 } | |
250 socket->SetUserData(WebSocketState::kKeyName, NULL); | |
251 } | 107 } |
252 | 108 |
253 void WebSocketThrottle::WakeupSocketIfNecessary() { | 109 void WebSocketThrottle::WakeupSocketIfNecessary() { |
254 for (ConnectingQueue::iterator iter = queue_.begin(); | 110 for (ConnectingQueue::iterator iter = queue_.begin(); |
255 iter != queue_.end(); | 111 iter != queue_.end(); |
256 ++iter) { | 112 ++iter) { |
257 WebSocketState* state = *iter; | 113 WebSocketJob* job = *iter; |
258 if (!state->IsWaiting()) | 114 if (!job->IsWaiting()) |
259 continue; | 115 continue; |
260 | 116 |
261 bool should_wakeup = true; | 117 bool should_wakeup = true; |
262 const AddressList& address_list = state->address_list(); | 118 const AddressList& address_list = job->address_list(); |
263 for (const struct addrinfo* addrinfo = address_list.head(); | 119 for (const struct addrinfo* addrinfo = address_list.head(); |
264 addrinfo != NULL; | 120 addrinfo != NULL; |
265 addrinfo = addrinfo->ai_next) { | 121 addrinfo = addrinfo->ai_next) { |
266 std::string addrkey = AddrinfoToHashkey(addrinfo); | 122 std::string addrkey = AddrinfoToHashkey(addrinfo); |
267 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); | 123 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); |
268 DCHECK(iter != addr_map_.end()); | 124 DCHECK(iter != addr_map_.end()); |
269 ConnectingQueue* queue = iter->second; | 125 ConnectingQueue* queue = iter->second; |
270 if (state != queue->front()) { | 126 if (job != queue->front()) { |
271 should_wakeup = false; | 127 should_wakeup = false; |
272 break; | 128 break; |
273 } | 129 } |
274 } | 130 } |
275 if (should_wakeup) | 131 if (should_wakeup) |
276 state->Wakeup(); | 132 job->Wakeup(); |
277 } | 133 } |
278 } | 134 } |
279 | 135 |
280 /* static */ | |
281 void WebSocketThrottle::Init() { | |
282 Singleton<WebSocketThrottle>::get(); | |
283 } | |
284 | |
285 } // namespace net | 136 } // namespace net |
OLD | NEW |