Index: mojo/services/network/udp_socket_impl.cc |
diff --git a/mojo/services/network/udp_socket_impl.cc b/mojo/services/network/udp_socket_impl.cc |
index b4438bec1f07d25b2a45cafd5c8fd2d3b91180e2..524136a32744b49fd542a1c1a468485cc605f2bf 100644 |
--- a/mojo/services/network/udp_socket_impl.cc |
+++ b/mojo/services/network/udp_socket_impl.cc |
@@ -16,6 +16,8 @@ |
#include "mojo/services/network/net_address_type_converters.h" |
#include "net/base/io_buffer.h" |
#include "net/base/net_errors.h" |
+#include "net/base/rand_callback.h" |
+#include "net/udp/datagram_socket.h" |
namespace mojo { |
@@ -33,8 +35,10 @@ UDPSocketImpl::PendingSendRequest::PendingSendRequest() {} |
UDPSocketImpl::PendingSendRequest::~PendingSendRequest() {} |
UDPSocketImpl::UDPSocketImpl() |
- : socket_(nullptr, net::NetLog::Source()), |
- bound_(false), |
+ : socket_(net::DatagramSocket::DEFAULT_BIND, net::RandIntCallback(), |
+ nullptr, net::NetLog::Source()), |
+ state_(NOT_BOUND_OR_CONNECTED), |
+ allow_address_reuse_(false), |
remaining_recv_slots_(0), |
max_pending_send_requests_(kDefaultMaxPendingSendRequests) { |
} |
@@ -45,60 +49,123 @@ UDPSocketImpl::~UDPSocketImpl() { |
void UDPSocketImpl::AllowAddressReuse( |
const Callback<void(NetworkErrorPtr)>& callback) { |
- if (bound_) { |
+ if (IsBoundOrConnected()) { |
callback.Run(MakeNetworkError(net::ERR_FAILED)); |
return; |
} |
- socket_.AllowAddressReuse(); |
+ allow_address_reuse_ = true; |
callback.Run(MakeNetworkError(net::OK)); |
} |
void UDPSocketImpl::Bind( |
NetAddressPtr addr, |
const Callback<void(NetworkErrorPtr, NetAddressPtr)>& callback) { |
- if (bound_) { |
- callback.Run(MakeNetworkError(net::ERR_FAILED), NetAddressPtr()); |
- return; |
- } |
+ int net_result = net::OK; |
+ bool opened = false; |
- net::IPEndPoint ip_end_point = addr.To<net::IPEndPoint>(); |
- if (ip_end_point.GetFamily() == net::ADDRESS_FAMILY_UNSPECIFIED) { |
- callback.Run(MakeNetworkError(net::ERR_ADDRESS_INVALID), NetAddressPtr()); |
- return; |
- } |
+ do { |
+ if (IsBoundOrConnected()) { |
+ net_result = net::ERR_FAILED; |
+ break; |
+ } |
- int net_result = socket_.Listen(ip_end_point); |
- if (net_result != net::OK) { |
- callback.Run(MakeNetworkError(net_result), NetAddressPtr()); |
- return; |
- } |
+ net::IPEndPoint ip_end_point = addr.To<net::IPEndPoint>(); |
+ if (ip_end_point.GetFamily() == net::ADDRESS_FAMILY_UNSPECIFIED) { |
+ net_result = net::ERR_ADDRESS_INVALID; |
+ break; |
+ } |
- net::IPEndPoint bound_ip_end_point; |
- NetAddressPtr bound_addr; |
- net_result = socket_.GetLocalAddress(&bound_ip_end_point); |
- if (net_result == net::OK) |
- bound_addr = NetAddress::From(bound_ip_end_point); |
+ net_result = socket_.Open(ip_end_point.GetFamily()); |
+ if (net_result != net::OK) |
+ break; |
+ opened = true; |
- bound_ = true; |
- callback.Run(MakeNetworkError(net::OK), bound_addr.Pass()); |
+ if (allow_address_reuse_) { |
+ net_result = socket_.AllowAddressReuse(); |
+ if (net_result != net::OK) |
+ break; |
+ } |
- if (remaining_recv_slots_ > 0) { |
- DCHECK(!recvfrom_buffer_.get()); |
- DoRecvFrom(); |
- } |
+ net_result = socket_.Bind(ip_end_point); |
+ if (net_result != net::OK) |
+ break; |
+ |
+ net::IPEndPoint bound_ip_end_point; |
+ net_result = socket_.GetLocalAddress(&bound_ip_end_point); |
+ if (net_result != net::OK) |
+ break; |
+ |
+ state_ = BOUND; |
+ callback.Run(MakeNetworkError(net_result), |
+ NetAddress::From(bound_ip_end_point)); |
+ |
+ if (remaining_recv_slots_ > 0) { |
+ DCHECK(!recvfrom_buffer_.get()); |
+ DoRecvFrom(); |
+ } |
+ return; |
+ } while (false); |
+ |
+ DCHECK(net_result != net::OK); |
+ if (opened) |
+ socket_.Close(); |
+ callback.Run(MakeNetworkError(net_result), nullptr); |
} |
void UDPSocketImpl::Connect( |
NetAddressPtr remote_addr, |
const Callback<void(NetworkErrorPtr, NetAddressPtr)>& callback) { |
- // TODO(yzshen): Implement it. |
+ int net_result = net::OK; |
+ bool opened = false; |
+ |
+ do { |
+ if (IsBoundOrConnected()) { |
+ net_result = net::ERR_FAILED; |
+ break; |
+ } |
+ |
+ net::IPEndPoint ip_end_point = remote_addr.To<net::IPEndPoint>(); |
+ if (ip_end_point.GetFamily() == net::ADDRESS_FAMILY_UNSPECIFIED) { |
+ net_result = net::ERR_ADDRESS_INVALID; |
+ break; |
+ } |
+ |
+ net_result = socket_.Open(ip_end_point.GetFamily()); |
+ if (net_result != net::OK) |
+ break; |
+ opened = true; |
+ |
+ net_result = socket_.Connect(ip_end_point); |
+ if (net_result != net::OK) |
+ break; |
+ |
+ net::IPEndPoint local_ip_end_point; |
+ net_result = socket_.GetLocalAddress(&local_ip_end_point); |
+ if (net_result != net::OK) |
+ break; |
+ |
+ state_ = CONNECTED; |
+ callback.Run(MakeNetworkError(net_result), |
+ NetAddress::From(local_ip_end_point)); |
+ |
+ if (remaining_recv_slots_ > 0) { |
+ DCHECK(!recvfrom_buffer_.get()); |
+ DoRecvFrom(); |
+ } |
+ return; |
+ } while (false); |
+ |
+ DCHECK(net_result != net::OK); |
+ if (opened) |
+ socket_.Close(); |
+ callback.Run(MakeNetworkError(net_result), nullptr); |
} |
void UDPSocketImpl::SetSendBufferSize( |
uint32_t size, |
const Callback<void(NetworkErrorPtr)>& callback) { |
- if (!bound_) { |
+ if (!IsBoundOrConnected()) { |
callback.Run(MakeNetworkError(net::ERR_FAILED)); |
return; |
} |
@@ -113,7 +180,7 @@ void UDPSocketImpl::SetSendBufferSize( |
void UDPSocketImpl::SetReceiveBufferSize( |
uint32_t size, |
const Callback<void(NetworkErrorPtr)>& callback) { |
- if (!bound_) { |
+ if (!IsBoundOrConnected()) { |
callback.Run(MakeNetworkError(net::ERR_FAILED)); |
return; |
} |
@@ -158,7 +225,7 @@ void UDPSocketImpl::ReceiveMore(uint32_t datagram_number) { |
remaining_recv_slots_ += datagram_number; |
- if (bound_ && !recvfrom_buffer_.get()) { |
+ if (IsBoundOrConnected() && !recvfrom_buffer_.get()) { |
DCHECK_EQ(datagram_number, remaining_recv_slots_); |
DoRecvFrom(); |
} |
@@ -167,10 +234,14 @@ void UDPSocketImpl::ReceiveMore(uint32_t datagram_number) { |
void UDPSocketImpl::SendTo(NetAddressPtr dest_addr, |
Array<uint8_t> data, |
const Callback<void(NetworkErrorPtr)>& callback) { |
- if (!bound_) { |
+ if (!IsBoundOrConnected()) { |
callback.Run(MakeNetworkError(net::ERR_FAILED)); |
return; |
} |
+ if (state_ == BOUND && !dest_addr) { |
+ callback.Run(MakeNetworkError(net::ERR_INVALID_ARGUMENT)); |
+ return; |
+ } |
if (sendto_buffer_.get()) { |
if (pending_send_requests_.size() >= max_pending_send_requests_) { |
@@ -192,7 +263,7 @@ void UDPSocketImpl::SendTo(NetAddressPtr dest_addr, |
} |
void UDPSocketImpl::DoRecvFrom() { |
- DCHECK(bound_); |
+ DCHECK(IsBoundOrConnected()); |
DCHECK(!recvfrom_buffer_.get()); |
DCHECK_GT(remaining_recv_slots_, 0u); |
@@ -204,7 +275,7 @@ void UDPSocketImpl::DoRecvFrom() { |
int net_result = socket_.RecvFrom( |
recvfrom_buffer_.get(), |
kMaxReadSize, |
- &recvfrom_address_, |
+ state_ == BOUND ? &recvfrom_address_ : nullptr, |
base::Bind(&UDPSocketImpl::OnRecvFromCompleted, base::Unretained(this))); |
if (net_result != net::ERR_IO_PENDING) |
OnRecvFromCompleted(net_result); |
@@ -213,15 +284,9 @@ void UDPSocketImpl::DoRecvFrom() { |
void UDPSocketImpl::DoSendTo(NetAddressPtr addr, |
Array<uint8_t> data, |
const Callback<void(NetworkErrorPtr)>& callback) { |
- DCHECK(bound_); |
+ DCHECK(IsBoundOrConnected()); |
DCHECK(!sendto_buffer_.get()); |
- net::IPEndPoint ip_end_point = addr.To<net::IPEndPoint>(); |
- if (ip_end_point.GetFamily() == net::ADDRESS_FAMILY_UNSPECIFIED) { |
- callback.Run(MakeNetworkError(net::ERR_ADDRESS_INVALID)); |
- return; |
- } |
- |
if (data.size() > kMaxWriteSize) { |
callback.Run(MakeNetworkError(net::ERR_INVALID_ARGUMENT)); |
return; |
@@ -230,13 +295,27 @@ void UDPSocketImpl::DoSendTo(NetAddressPtr addr, |
if (data.size() > 0) |
memcpy(sendto_buffer_->data(), &data.storage()[0], data.size()); |
- // It is safe to use base::Unretained(this) because |socket_| is owned by this |
- // object. If this object gets destroyed (and so does |socket_|), the callback |
- // won't be called. |
- int net_result = socket_.SendTo(sendto_buffer_.get(), sendto_buffer_->size(), |
- ip_end_point, |
- base::Bind(&UDPSocketImpl::OnSendToCompleted, |
- base::Unretained(this), callback)); |
+ int net_result = net::OK; |
+ if (addr) { |
+ net::IPEndPoint ip_end_point = addr.To<net::IPEndPoint>(); |
+ if (ip_end_point.GetFamily() == net::ADDRESS_FAMILY_UNSPECIFIED) { |
+ callback.Run(MakeNetworkError(net::ERR_ADDRESS_INVALID)); |
+ return; |
+ } |
+ |
+ // It is safe to use base::Unretained(this) because |socket_| is owned by |
+ // this object. If this object gets destroyed (and so does |socket_|), the |
+ // callback won't be called. |
+ net_result = socket_.SendTo(sendto_buffer_.get(), sendto_buffer_->size(), |
+ ip_end_point, |
+ base::Bind(&UDPSocketImpl::OnSendToCompleted, |
+ base::Unretained(this), callback)); |
+ } else { |
+ DCHECK(state_ == CONNECTED); |
+ net_result = socket_.Write(sendto_buffer_.get(), sendto_buffer_->size(), |
+ base::Bind(&UDPSocketImpl::OnSendToCompleted, |
+ base::Unretained(this), callback)); |
+ } |
if (net_result != net::ERR_IO_PENDING) |
OnSendToCompleted(callback, net_result); |
} |
@@ -247,7 +326,9 @@ void UDPSocketImpl::OnRecvFromCompleted(int net_result) { |
NetAddressPtr net_address; |
Array<uint8_t> array; |
if (net_result >= 0) { |
- net_address = NetAddress::From(recvfrom_address_); |
+ if (state_ == BOUND) |
+ net_address = NetAddress::From(recvfrom_address_); |
+ |
std::vector<uint8_t> data(net_result); |
if (net_result > 0) |
memcpy(&data[0], recvfrom_buffer_->data(), net_result); |