OLD | NEW |
1 // Copyright 2014 The Chromium Authors. All rights reserved. | 1 // Copyright 2014 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "remoting/host/gnubby_auth_handler_posix.h" | 5 #include "remoting/host/gnubby_auth_handler_posix.h" |
6 | 6 |
7 #include <unistd.h> | 7 #include <unistd.h> |
8 #include <utility> | 8 #include <utility> |
9 #include <vector> | |
10 | 9 |
11 #include "base/bind.h" | 10 #include "base/bind.h" |
12 #include "base/file_util.h" | 11 #include "base/file_util.h" |
13 #include "base/json/json_reader.h" | 12 #include "base/json/json_reader.h" |
14 #include "base/json/json_writer.h" | 13 #include "base/json/json_writer.h" |
15 #include "base/lazy_instance.h" | 14 #include "base/lazy_instance.h" |
16 #include "base/stl_util.h" | 15 #include "base/stl_util.h" |
17 #include "base/values.h" | 16 #include "base/values.h" |
18 #include "net/socket/unix_domain_socket_posix.h" | 17 #include "net/socket/unix_domain_socket_posix.h" |
19 #include "remoting/base/logging.h" | 18 #include "remoting/base/logging.h" |
20 #include "remoting/host/gnubby_util.h" | 19 #include "remoting/host/gnubby_socket.h" |
21 #include "remoting/proto/control.pb.h" | 20 #include "remoting/proto/control.pb.h" |
22 #include "remoting/protocol/client_stub.h" | 21 #include "remoting/protocol/client_stub.h" |
23 | 22 |
24 namespace remoting { | 23 namespace remoting { |
25 | 24 |
26 namespace { | 25 namespace { |
27 | 26 |
28 const int kMaxRequestLength = 4096; | |
29 | |
30 const char kConnectionId[] = "connectionId"; | 27 const char kConnectionId[] = "connectionId"; |
31 const char kControlMessage[] = "control"; | 28 const char kControlMessage[] = "control"; |
32 const char kControlOption[] = "option"; | 29 const char kControlOption[] = "option"; |
33 const char kDataMessage[] = "data"; | 30 const char kDataMessage[] = "data"; |
| 31 const char kDataPayload[] = "data"; |
| 32 const char kErrorMessage[] = "error"; |
34 const char kGnubbyAuthMessage[] = "gnubby-auth"; | 33 const char kGnubbyAuthMessage[] = "gnubby-auth"; |
35 const char kGnubbyAuthV1[] = "auth-v1"; | 34 const char kGnubbyAuthV1[] = "auth-v1"; |
36 const char kJSONMessage[] = "jsonMessage"; | |
37 const char kMessageType[] = "type"; | 35 const char kMessageType[] = "type"; |
38 | 36 |
39 // The name of the socket to listen for gnubby requests on. | 37 // The name of the socket to listen for gnubby requests on. |
40 base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name = | 38 base::LazyInstance<base::FilePath>::Leaky g_gnubby_socket_name = |
41 LAZY_INSTANCE_INITIALIZER; | 39 LAZY_INSTANCE_INITIALIZER; |
42 | 40 |
43 // STL predicate to match by a StreamListenSocket pointer. | 41 // STL predicate to match by a StreamListenSocket pointer. |
44 class CompareSocket { | 42 class CompareSocket { |
45 public: | 43 public: |
46 explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {} | 44 explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {} |
47 | 45 |
48 bool operator()(const std::pair<int, net::StreamListenSocket*> element) | 46 bool operator()(const std::pair<int, GnubbySocket*> element) const { |
49 const { | 47 return element.second->IsSocket(socket_); |
50 return socket_ == element.second; | |
51 } | 48 } |
52 | 49 |
53 private: | 50 private: |
54 net::StreamListenSocket* socket_; | 51 net::StreamListenSocket* socket_; |
55 }; | 52 }; |
56 | 53 |
57 // Socket authentication function that only allows connections from callers with | 54 // Socket authentication function that only allows connections from callers with |
58 // the current uid. | 55 // the current uid. |
59 bool MatchUid(uid_t user_id, gid_t) { | 56 bool MatchUid(uid_t user_id, gid_t) { |
60 bool allowed = user_id == getuid(); | 57 bool allowed = user_id == getuid(); |
61 if (!allowed) | 58 if (!allowed) |
62 HOST_LOG << "Refused socket connection from uid " << user_id; | 59 HOST_LOG << "Refused socket connection from uid " << user_id; |
63 return allowed; | 60 return allowed; |
64 } | 61 } |
65 | 62 |
66 // Returns the request data length from the first four data bytes. | 63 // Returns the command code (the first byte of the data) if it exists, or -1 if |
67 int GetRequestLength(const char* data) { | 64 // the data is empty. |
68 return ((data[0] & 255) << 24) + ((data[1] & 255) << 16) + | 65 unsigned int GetCommandCode(const std::string& data) { |
69 ((data[2] & 255) << 8) + (data[3] & 255) + 4; | 66 return data.empty() ? -1 : static_cast<unsigned int>(data[0]); |
70 } | 67 } |
71 | 68 |
72 // Returns true if the request data is complete (has at least as many bytes as | 69 // Creates a string of byte data from a ListValue of numbers. Returns true if |
73 // indicated by the size in the first four bytes plus four for the first bytes). | 70 // all of the list elements are numbers. |
74 bool IsRequestComplete(const char* data, int data_len) { | 71 bool ConvertListValueToString(base::ListValue* bytes, std::string* out) { |
75 if (data_len < 4) | 72 out->clear(); |
76 return false; | |
77 return GetRequestLength(data) <= data_len; | |
78 } | |
79 | 73 |
80 // Returns true if the request data size is bigger than the threshold. | 74 unsigned int byte_count = bytes->GetSize(); |
81 bool IsRequestTooLarge(const char* data, int data_len, int max_len) { | 75 if (byte_count != 0) { |
82 if (data_len < 4) | 76 out->reserve(byte_count); |
83 return false; | 77 for (unsigned int i = 0; i < byte_count; i++) { |
84 return GetRequestLength(data) > max_len; | 78 int value; |
| 79 if (!bytes->GetInteger(i, &value)) |
| 80 return false; |
| 81 out->push_back(static_cast<char>(value)); |
| 82 } |
| 83 } |
| 84 return true; |
85 } | 85 } |
86 | 86 |
87 } // namespace | 87 } // namespace |
88 | 88 |
89 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix( | 89 GnubbyAuthHandlerPosix::GnubbyAuthHandlerPosix( |
90 protocol::ClientStub* client_stub) | 90 protocol::ClientStub* client_stub) |
91 : client_stub_(client_stub), last_connection_id_(0) { | 91 : client_stub_(client_stub), last_connection_id_(0) { |
92 DCHECK(client_stub_); | 92 DCHECK(client_stub_); |
93 } | 93 } |
94 | 94 |
(...skipping 27 matching lines...) Expand all Loading... |
122 | 122 |
123 if (type == kControlMessage) { | 123 if (type == kControlMessage) { |
124 std::string option; | 124 std::string option; |
125 if (client_message->GetString(kControlOption, &option) && | 125 if (client_message->GetString(kControlOption, &option) && |
126 option == kGnubbyAuthV1) { | 126 option == kGnubbyAuthV1) { |
127 CreateAuthorizationSocket(); | 127 CreateAuthorizationSocket(); |
128 } else { | 128 } else { |
129 LOG(ERROR) << "Invalid gnubby-auth control option"; | 129 LOG(ERROR) << "Invalid gnubby-auth control option"; |
130 } | 130 } |
131 } else if (type == kDataMessage) { | 131 } else if (type == kDataMessage) { |
132 int connection_id; | 132 ActiveSockets::iterator iter = GetSocketForMessage(client_message); |
133 std::string json_message; | 133 if (iter != active_sockets_.end()) { |
134 if (client_message->GetInteger(kConnectionId, &connection_id) && | 134 base::ListValue* bytes; |
135 client_message->GetString(kJSONMessage, &json_message)) { | 135 std::string response; |
136 ActiveSockets::iterator iter = active_sockets_.find(connection_id); | 136 if (client_message->GetList(kDataPayload, &bytes) && |
137 if (iter != active_sockets_.end()) { | 137 ConvertListValueToString(bytes, &response)) { |
138 HOST_LOG << "Sending gnubby response"; | 138 HOST_LOG << "Sending gnubby response: " << GetCommandCode(response); |
139 | 139 iter->second->SendResponse(response); |
140 std::string response; | |
141 GetGnubbyResponseFromJson(json_message, &response); | |
142 iter->second->Send(response); | |
143 } else { | 140 } else { |
144 LOG(ERROR) << "Received gnubby-auth data for unknown connection"; | 141 LOG(ERROR) << "Invalid gnubby data"; |
| 142 SendErrorAndCloseActiveSocket(iter); |
145 } | 143 } |
146 } else { | 144 } else { |
147 LOG(ERROR) << "Invalid gnubby-auth data message"; | 145 LOG(ERROR) << "Unknown gnubby-auth data connection"; |
| 146 } |
| 147 } else if (type == kErrorMessage) { |
| 148 ActiveSockets::iterator iter = GetSocketForMessage(client_message); |
| 149 if (iter != active_sockets_.end()) { |
| 150 HOST_LOG << "Sending gnubby error"; |
| 151 SendErrorAndCloseActiveSocket(iter); |
| 152 } else { |
| 153 LOG(ERROR) << "Unknown gnubby-auth error connection"; |
148 } | 154 } |
149 } else { | 155 } else { |
150 LOG(ERROR) << "Unknown gnubby-auth message type: " << type; | 156 LOG(ERROR) << "Unknown gnubby-auth message type: " << type; |
151 } | 157 } |
152 } | 158 } |
153 } | 159 } |
154 | 160 |
155 void GnubbyAuthHandlerPosix::DeliverHostDataMessage(int connection_id, | 161 void GnubbyAuthHandlerPosix::DeliverHostDataMessage( |
156 const std::string& data) | 162 int connection_id, |
157 const { | 163 const std::string& data) const { |
158 DCHECK(CalledOnValidThread()); | 164 DCHECK(CalledOnValidThread()); |
159 | 165 |
160 base::DictionaryValue request; | 166 base::DictionaryValue request; |
161 request.SetString(kMessageType, kDataMessage); | 167 request.SetString(kMessageType, kDataMessage); |
162 request.SetInteger(kConnectionId, connection_id); | 168 request.SetInteger(kConnectionId, connection_id); |
163 request.SetString(kJSONMessage, data); | 169 |
| 170 base::ListValue* bytes = new base::ListValue(); |
| 171 for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) { |
| 172 bytes->AppendInteger(static_cast<unsigned char>(*i)); |
| 173 } |
| 174 request.Set(kDataPayload, bytes); |
164 | 175 |
165 std::string request_json; | 176 std::string request_json; |
166 if (!base::JSONWriter::Write(&request, &request_json)) { | 177 if (!base::JSONWriter::Write(&request, &request_json)) { |
167 LOG(ERROR) << "Failed to create request json"; | 178 LOG(ERROR) << "Failed to create request json"; |
168 return; | 179 return; |
169 } | 180 } |
170 | 181 |
171 protocol::ExtensionMessage message; | 182 protocol::ExtensionMessage message; |
172 message.set_type(kGnubbyAuthMessage); | 183 message.set_type(kGnubbyAuthMessage); |
173 message.set_data(request_json); | 184 message.set_data(request_json); |
174 | 185 |
175 client_stub_->DeliverHostMessage(message); | 186 client_stub_->DeliverHostMessage(message); |
176 } | 187 } |
177 | 188 |
178 bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting( | 189 bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting( |
179 net::StreamListenSocket* socket) const { | 190 net::StreamListenSocket* socket) const { |
180 return std::find_if(active_sockets_.begin(), | 191 return std::find_if(active_sockets_.begin(), |
181 active_sockets_.end(), | 192 active_sockets_.end(), |
182 CompareSocket(socket)) != active_sockets_.end(); | 193 CompareSocket(socket)) != active_sockets_.end(); |
183 } | 194 } |
184 | 195 |
| 196 int GnubbyAuthHandlerPosix::GetConnectionIdForTesting( |
| 197 net::StreamListenSocket* socket) const { |
| 198 ActiveSockets::const_iterator iter = std::find_if( |
| 199 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
| 200 return iter->first; |
| 201 } |
| 202 |
| 203 GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting( |
| 204 net::StreamListenSocket* socket) const { |
| 205 ActiveSockets::const_iterator iter = std::find_if( |
| 206 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
| 207 return iter->second; |
| 208 } |
| 209 |
185 void GnubbyAuthHandlerPosix::DidAccept( | 210 void GnubbyAuthHandlerPosix::DidAccept( |
186 net::StreamListenSocket* server, | 211 net::StreamListenSocket* server, |
187 scoped_ptr<net::StreamListenSocket> socket) { | 212 scoped_ptr<net::StreamListenSocket> socket) { |
188 DCHECK(CalledOnValidThread()); | 213 DCHECK(CalledOnValidThread()); |
189 | 214 |
190 active_sockets_[++last_connection_id_] = socket.release(); | 215 int connection_id = ++last_connection_id_; |
| 216 active_sockets_[connection_id] = |
| 217 new GnubbySocket(socket.Pass(), |
| 218 base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut, |
| 219 base::Unretained(this), |
| 220 connection_id)); |
191 } | 221 } |
192 | 222 |
193 void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, | 223 void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
194 const char* data, | 224 const char* data, |
195 int len) { | 225 int len) { |
196 DCHECK(CalledOnValidThread()); | 226 DCHECK(CalledOnValidThread()); |
197 | 227 |
198 ActiveSockets::iterator socket_iter = std::find_if( | 228 ActiveSockets::iterator iter = std::find_if( |
199 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); | 229 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
200 if (socket_iter != active_sockets_.end()) { | 230 if (iter != active_sockets_.end()) { |
201 int connection_id = socket_iter->first; | 231 GnubbySocket* gnubby_socket = iter->second; |
202 | 232 gnubby_socket->AddRequestData(data, len); |
203 ActiveRequests::iterator request_iter = | 233 if (gnubby_socket->IsRequestTooLarge()) { |
204 active_requests_.find(connection_id); | 234 SendErrorAndCloseActiveSocket(iter); |
205 if (request_iter != active_requests_.end()) { | 235 } else if (gnubby_socket->IsRequestComplete()) { |
206 std::vector<char>& saved_vector = request_iter->second; | 236 std::string request_data; |
207 if (IsRequestTooLarge( | 237 gnubby_socket->GetAndClearRequestData(&request_data); |
208 saved_vector.data(), saved_vector.size(), kMaxRequestLength)) { | 238 ProcessGnubbyRequest(iter->first, request_data); |
209 // We can't close a StreamListenSocket; throw away everything but the | |
210 // size bytes. | |
211 saved_vector.resize(4); | |
212 return; | |
213 } | |
214 saved_vector.insert(saved_vector.end(), data, data + len); | |
215 | |
216 if (IsRequestComplete(saved_vector.data(), saved_vector.size())) { | |
217 ProcessGnubbyRequest( | |
218 connection_id, saved_vector.data(), saved_vector.size()); | |
219 active_requests_.erase(request_iter); | |
220 } | |
221 } else if (IsRequestComplete(data, len)) { | |
222 ProcessGnubbyRequest(connection_id, data, len); | |
223 } else { | |
224 if (IsRequestTooLarge(data, len, kMaxRequestLength)) { | |
225 // Only save the size bytes. | |
226 active_requests_[connection_id] = std::vector<char>(data, data + 4); | |
227 } else { | |
228 active_requests_[connection_id] = std::vector<char>(data, data + len); | |
229 } | |
230 } | 239 } |
| 240 } else { |
| 241 LOG(ERROR) << "Received data for unknown connection"; |
231 } | 242 } |
232 } | 243 } |
233 | 244 |
234 void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) { | 245 void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) { |
235 DCHECK(CalledOnValidThread()); | 246 DCHECK(CalledOnValidThread()); |
236 | 247 |
237 ActiveSockets::iterator iter = std::find_if( | 248 ActiveSockets::iterator iter = std::find_if( |
238 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); | 249 active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
239 if (iter != active_sockets_.end()) { | 250 if (iter != active_sockets_.end()) { |
240 active_requests_.erase(iter->first); | |
241 | |
242 delete iter->second; | 251 delete iter->second; |
243 active_sockets_.erase(iter); | 252 active_sockets_.erase(iter); |
244 } | 253 } |
245 } | 254 } |
246 | 255 |
247 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() { | 256 void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() { |
248 DCHECK(CalledOnValidThread()); | 257 DCHECK(CalledOnValidThread()); |
249 | 258 |
250 if (!g_gnubby_socket_name.Get().empty()) { | 259 if (!g_gnubby_socket_name.Get().empty()) { |
251 // If the file already exists, a socket in use error is returned. | 260 // If the file already exists, a socket in use error is returned. |
252 base::DeleteFile(g_gnubby_socket_name.Get(), false); | 261 base::DeleteFile(g_gnubby_socket_name.Get(), false); |
253 | 262 |
254 HOST_LOG << "Listening for gnubby requests on " | 263 HOST_LOG << "Listening for gnubby requests on " |
255 << g_gnubby_socket_name.Get().value(); | 264 << g_gnubby_socket_name.Get().value(); |
256 | 265 |
257 auth_socket_ = net::UnixDomainSocket::CreateAndListen( | 266 auth_socket_ = net::UnixDomainSocket::CreateAndListen( |
258 g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid)); | 267 g_gnubby_socket_name.Get().value(), this, base::Bind(MatchUid)); |
259 if (!auth_socket_.get()) { | 268 if (!auth_socket_.get()) { |
260 LOG(ERROR) << "Failed to open socket for gnubby requests"; | 269 LOG(ERROR) << "Failed to open socket for gnubby requests"; |
261 } | 270 } |
262 } else { | 271 } else { |
263 HOST_LOG << "No gnubby socket name specified"; | 272 HOST_LOG << "No gnubby socket name specified"; |
264 } | 273 } |
265 } | 274 } |
266 | 275 |
267 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(int connection_id, | 276 void GnubbyAuthHandlerPosix::ProcessGnubbyRequest( |
268 const char* data, | 277 int connection_id, |
269 int data_len) { | 278 const std::string& request_data) { |
270 std::string json; | 279 HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data); |
271 if (GetJsonFromGnubbyRequest(data, data_len, &json)) { | 280 DeliverHostDataMessage(connection_id, request_data); |
272 HOST_LOG << "Received gnubby request"; | 281 } |
273 DeliverHostDataMessage(connection_id, json); | 282 |
274 } else { | 283 GnubbyAuthHandlerPosix::ActiveSockets::iterator |
275 LOG(ERROR) << "Could not decode gnubby request"; | 284 GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) { |
| 285 int connection_id; |
| 286 if (message->GetInteger(kConnectionId, &connection_id)) { |
| 287 return active_sockets_.find(connection_id); |
276 } | 288 } |
| 289 return active_sockets_.end(); |
| 290 } |
| 291 |
| 292 void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket( |
| 293 const ActiveSockets::iterator& iter) { |
| 294 iter->second->SendSshError(); |
| 295 |
| 296 delete iter->second; |
| 297 active_sockets_.erase(iter); |
| 298 } |
| 299 |
| 300 void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) { |
| 301 HOST_LOG << "Gnubby request timed out"; |
| 302 ActiveSockets::iterator iter = active_sockets_.find(connection_id); |
| 303 if (iter != active_sockets_.end()) |
| 304 SendErrorAndCloseActiveSocket(iter); |
277 } | 305 } |
278 | 306 |
279 } // namespace remoting | 307 } // namespace remoting |
OLD | NEW |