OLD | NEW |
1 // Copyright 2016 The Chromium Authors. All rights reserved. | 1 // Copyright 2016 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "remoting/host/security_key/gnubby_socket.h" | 5 #include "remoting/host/security_key/security_key_socket.h" |
6 | 6 |
7 #include <utility> | 7 #include <utility> |
8 | 8 |
9 #include "base/callback_helpers.h" | 9 #include "base/callback_helpers.h" |
10 #include "base/macros.h" | 10 #include "base/macros.h" |
11 #include "base/timer/timer.h" | 11 #include "base/timer/timer.h" |
12 #include "net/base/io_buffer.h" | 12 #include "net/base/io_buffer.h" |
13 #include "net/base/net_errors.h" | 13 #include "net/base/net_errors.h" |
14 #include "net/socket/stream_socket.h" | 14 #include "net/socket/stream_socket.h" |
15 | 15 |
16 namespace remoting { | 16 namespace remoting { |
17 | 17 |
18 namespace { | 18 namespace { |
19 | 19 |
20 const size_t kRequestSizeBytes = 4; | 20 const size_t kRequestSizeBytes = 4; |
21 const size_t kMaxRequestLength = 16384; | 21 const size_t kMaxRequestLength = 16384; |
22 const size_t kRequestReadBufferLength = kRequestSizeBytes + kMaxRequestLength; | 22 const size_t kRequestReadBufferLength = kRequestSizeBytes + kMaxRequestLength; |
23 | 23 |
24 // SSH Failure Code | 24 // SSH Failure Code |
25 const char kSshError[] = {0x05}; | 25 const char kSshError[] = {0x05}; |
26 | 26 |
27 } // namespace | 27 } // namespace |
28 | 28 |
29 GnubbySocket::GnubbySocket(std::unique_ptr<net::StreamSocket> socket, | 29 SecurityKeySocket::SecurityKeySocket(std::unique_ptr<net::StreamSocket> socket, |
30 base::TimeDelta timeout, | 30 base::TimeDelta timeout, |
31 const base::Closure& timeout_callback) | 31 const base::Closure& timeout_callback) |
32 : socket_(std::move(socket)), | 32 : socket_(std::move(socket)), |
33 read_completed_(false), | 33 read_completed_(false), |
34 read_buffer_(new net::IOBufferWithSize(kRequestReadBufferLength)) { | 34 read_buffer_(new net::IOBufferWithSize(kRequestReadBufferLength)) { |
35 timer_.reset(new base::Timer(false, false)); | 35 timer_.reset(new base::Timer(false, false)); |
36 timer_->Start(FROM_HERE, timeout, timeout_callback); | 36 timer_->Start(FROM_HERE, timeout, timeout_callback); |
37 } | 37 } |
38 | 38 |
39 GnubbySocket::~GnubbySocket() {} | 39 SecurityKeySocket::~SecurityKeySocket() {} |
40 | 40 |
41 bool GnubbySocket::GetAndClearRequestData(std::string* data_out) { | 41 bool SecurityKeySocket::GetAndClearRequestData(std::string* data_out) { |
42 DCHECK(thread_checker_.CalledOnValidThread()); | 42 DCHECK(thread_checker_.CalledOnValidThread()); |
43 DCHECK(read_completed_); | 43 DCHECK(read_completed_); |
44 | 44 |
45 if (!read_completed_) | 45 if (!read_completed_) |
46 return false; | 46 return false; |
47 if (!IsRequestComplete() || IsRequestTooLarge()) | 47 if (!IsRequestComplete() || IsRequestTooLarge()) |
48 return false; | 48 return false; |
49 // The request size is not part of the data; don't send it. | 49 // The request size is not part of the data; don't send it. |
50 data_out->assign(request_data_.begin() + kRequestSizeBytes, | 50 data_out->assign(request_data_.begin() + kRequestSizeBytes, |
51 request_data_.end()); | 51 request_data_.end()); |
52 request_data_.clear(); | 52 request_data_.clear(); |
53 return true; | 53 return true; |
54 } | 54 } |
55 | 55 |
56 void GnubbySocket::SendResponse(const std::string& response_data) { | 56 void SecurityKeySocket::SendResponse(const std::string& response_data) { |
57 DCHECK(thread_checker_.CalledOnValidThread()); | 57 DCHECK(thread_checker_.CalledOnValidThread()); |
58 DCHECK(!write_buffer_); | 58 DCHECK(!write_buffer_); |
59 | 59 |
60 std::string response_length_string = GetResponseLengthAsBytes(response_data); | 60 std::string response_length_string = GetResponseLengthAsBytes(response_data); |
61 int response_len = response_length_string.size() + response_data.size(); | 61 int response_len = response_length_string.size() + response_data.size(); |
62 std::unique_ptr<std::string> response( | 62 std::unique_ptr<std::string> response( |
63 new std::string(response_length_string + response_data)); | 63 new std::string(response_length_string + response_data)); |
64 write_buffer_ = new net::DrainableIOBuffer( | 64 write_buffer_ = new net::DrainableIOBuffer( |
65 new net::StringIOBuffer(std::move(response)), response_len); | 65 new net::StringIOBuffer(std::move(response)), response_len); |
66 DoWrite(); | 66 DoWrite(); |
67 } | 67 } |
68 | 68 |
69 void GnubbySocket::SendSshError() { | 69 void SecurityKeySocket::SendSshError() { |
70 DCHECK(thread_checker_.CalledOnValidThread()); | 70 DCHECK(thread_checker_.CalledOnValidThread()); |
71 | 71 |
72 SendResponse(std::string(kSshError, arraysize(kSshError))); | 72 SendResponse(std::string(kSshError, arraysize(kSshError))); |
73 } | 73 } |
74 | 74 |
75 void GnubbySocket::StartReadingRequest( | 75 void SecurityKeySocket::StartReadingRequest( |
76 const base::Closure& request_received_callback) { | 76 const base::Closure& request_received_callback) { |
77 DCHECK(thread_checker_.CalledOnValidThread()); | 77 DCHECK(thread_checker_.CalledOnValidThread()); |
78 DCHECK(request_received_callback_.is_null()); | 78 DCHECK(request_received_callback_.is_null()); |
79 | 79 |
80 request_received_callback_ = request_received_callback; | 80 request_received_callback_ = request_received_callback; |
81 DoRead(); | 81 DoRead(); |
82 } | 82 } |
83 | 83 |
84 void GnubbySocket::OnDataWritten(int result) { | 84 void SecurityKeySocket::OnDataWritten(int result) { |
85 DCHECK(thread_checker_.CalledOnValidThread()); | 85 DCHECK(thread_checker_.CalledOnValidThread()); |
86 DCHECK(write_buffer_); | 86 DCHECK(write_buffer_); |
87 | 87 |
88 if (result < 0) { | 88 if (result < 0) { |
89 LOG(ERROR) << "Error sending response: " << result; | 89 LOG(ERROR) << "Error sending response: " << result; |
90 return; | 90 return; |
91 } | 91 } |
92 ResetTimer(); | 92 ResetTimer(); |
93 write_buffer_->DidConsume(result); | 93 write_buffer_->DidConsume(result); |
94 DoWrite(); | 94 DoWrite(); |
95 } | 95 } |
96 | 96 |
97 void GnubbySocket::DoWrite() { | 97 void SecurityKeySocket::DoWrite() { |
98 DCHECK(thread_checker_.CalledOnValidThread()); | 98 DCHECK(thread_checker_.CalledOnValidThread()); |
99 DCHECK(write_buffer_); | 99 DCHECK(write_buffer_); |
100 | 100 |
101 if (!write_buffer_->BytesRemaining()) { | 101 if (!write_buffer_->BytesRemaining()) { |
102 write_buffer_ = nullptr; | 102 write_buffer_ = nullptr; |
103 return; | 103 return; |
104 } | 104 } |
105 int result = socket_->Write( | 105 int result = socket_->Write( |
106 write_buffer_.get(), write_buffer_->BytesRemaining(), | 106 write_buffer_.get(), write_buffer_->BytesRemaining(), |
107 base::Bind(&GnubbySocket::OnDataWritten, base::Unretained(this))); | 107 base::Bind(&SecurityKeySocket::OnDataWritten, base::Unretained(this))); |
108 if (result != net::ERR_IO_PENDING) | 108 if (result != net::ERR_IO_PENDING) |
109 OnDataWritten(result); | 109 OnDataWritten(result); |
110 } | 110 } |
111 | 111 |
112 void GnubbySocket::OnDataRead(int result) { | 112 void SecurityKeySocket::OnDataRead(int result) { |
113 DCHECK(thread_checker_.CalledOnValidThread()); | 113 DCHECK(thread_checker_.CalledOnValidThread()); |
114 | 114 |
115 if (result <= 0) { | 115 if (result <= 0) { |
116 if (result < 0) | 116 if (result < 0) |
117 LOG(ERROR) << "Error reading request: " << result; | 117 LOG(ERROR) << "Error reading request: " << result; |
118 read_completed_ = true; | 118 read_completed_ = true; |
119 base::ResetAndReturn(&request_received_callback_).Run(); | 119 base::ResetAndReturn(&request_received_callback_).Run(); |
120 return; | 120 return; |
121 } | 121 } |
122 | 122 |
123 ResetTimer(); | 123 ResetTimer(); |
124 request_data_.insert(request_data_.end(), read_buffer_->data(), | 124 request_data_.insert(request_data_.end(), read_buffer_->data(), |
125 read_buffer_->data() + result); | 125 read_buffer_->data() + result); |
126 if (IsRequestComplete()) { | 126 if (IsRequestComplete()) { |
127 read_completed_ = true; | 127 read_completed_ = true; |
128 base::ResetAndReturn(&request_received_callback_).Run(); | 128 base::ResetAndReturn(&request_received_callback_).Run(); |
129 return; | 129 return; |
130 } | 130 } |
131 | 131 |
132 DoRead(); | 132 DoRead(); |
133 } | 133 } |
134 | 134 |
135 void GnubbySocket::DoRead() { | 135 void SecurityKeySocket::DoRead() { |
136 DCHECK(thread_checker_.CalledOnValidThread()); | 136 DCHECK(thread_checker_.CalledOnValidThread()); |
137 | 137 |
138 int result = socket_->Read( | 138 int result = socket_->Read( |
139 read_buffer_.get(), kRequestReadBufferLength, | 139 read_buffer_.get(), kRequestReadBufferLength, |
140 base::Bind(&GnubbySocket::OnDataRead, base::Unretained(this))); | 140 base::Bind(&SecurityKeySocket::OnDataRead, base::Unretained(this))); |
141 if (result != net::ERR_IO_PENDING) | 141 if (result != net::ERR_IO_PENDING) |
142 OnDataRead(result); | 142 OnDataRead(result); |
143 } | 143 } |
144 | 144 |
145 bool GnubbySocket::IsRequestComplete() const { | 145 bool SecurityKeySocket::IsRequestComplete() const { |
146 DCHECK(thread_checker_.CalledOnValidThread()); | 146 DCHECK(thread_checker_.CalledOnValidThread()); |
147 | 147 |
148 if (request_data_.size() < kRequestSizeBytes) | 148 if (request_data_.size() < kRequestSizeBytes) |
149 return false; | 149 return false; |
150 return GetRequestLength() <= request_data_.size(); | 150 return GetRequestLength() <= request_data_.size(); |
151 } | 151 } |
152 | 152 |
153 bool GnubbySocket::IsRequestTooLarge() const { | 153 bool SecurityKeySocket::IsRequestTooLarge() const { |
154 DCHECK(thread_checker_.CalledOnValidThread()); | 154 DCHECK(thread_checker_.CalledOnValidThread()); |
155 | 155 |
156 if (request_data_.size() < kRequestSizeBytes) | 156 if (request_data_.size() < kRequestSizeBytes) |
157 return false; | 157 return false; |
158 return GetRequestLength() > kMaxRequestLength; | 158 return GetRequestLength() > kMaxRequestLength; |
159 } | 159 } |
160 | 160 |
161 size_t GnubbySocket::GetRequestLength() const { | 161 size_t SecurityKeySocket::GetRequestLength() const { |
162 DCHECK(request_data_.size() >= kRequestSizeBytes); | 162 DCHECK(request_data_.size() >= kRequestSizeBytes); |
163 | 163 |
164 return ((request_data_[0] & 255) << 24) + ((request_data_[1] & 255) << 16) + | 164 return ((request_data_[0] & 255) << 24) + ((request_data_[1] & 255) << 16) + |
165 ((request_data_[2] & 255) << 8) + (request_data_[3] & 255) + | 165 ((request_data_[2] & 255) << 8) + (request_data_[3] & 255) + |
166 kRequestSizeBytes; | 166 kRequestSizeBytes; |
167 } | 167 } |
168 | 168 |
169 std::string GnubbySocket::GetResponseLengthAsBytes( | 169 std::string SecurityKeySocket::GetResponseLengthAsBytes( |
170 const std::string& response) const { | 170 const std::string& response) const { |
171 std::string response_len; | 171 std::string response_len; |
172 response_len.reserve(kRequestSizeBytes); | 172 response_len.reserve(kRequestSizeBytes); |
173 int len = response.size(); | 173 int len = response.size(); |
174 | 174 |
175 response_len.push_back((len >> 24) & 255); | 175 response_len.push_back((len >> 24) & 255); |
176 response_len.push_back((len >> 16) & 255); | 176 response_len.push_back((len >> 16) & 255); |
177 response_len.push_back((len >> 8) & 255); | 177 response_len.push_back((len >> 8) & 255); |
178 response_len.push_back(len & 255); | 178 response_len.push_back(len & 255); |
179 | 179 |
180 return response_len; | 180 return response_len; |
181 } | 181 } |
182 | 182 |
183 void GnubbySocket::ResetTimer() { | 183 void SecurityKeySocket::ResetTimer() { |
184 if (timer_->IsRunning()) | 184 if (timer_->IsRunning()) |
185 timer_->Reset(); | 185 timer_->Reset(); |
186 } | 186 } |
187 | 187 |
188 } // namespace remoting | 188 } // namespace remoting |
OLD | NEW |