OLD | NEW |
| (Empty) |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "ipc/unix_domain_socket_util.h" | |
6 | |
7 #include <stddef.h> | |
8 #include <sys/socket.h> | |
9 | |
10 #include <memory> | |
11 | |
12 #include "base/bind.h" | |
13 #include "base/files/file_path.h" | |
14 #include "base/location.h" | |
15 #include "base/macros.h" | |
16 #include "base/path_service.h" | |
17 #include "base/posix/eintr_wrapper.h" | |
18 #include "base/single_thread_task_runner.h" | |
19 #include "base/synchronization/waitable_event.h" | |
20 #include "base/threading/thread.h" | |
21 #include "base/threading/thread_restrictions.h" | |
22 #include "testing/gtest/include/gtest/gtest.h" | |
23 | |
24 namespace { | |
25 | |
26 class SocketAcceptor : public base::MessageLoopForIO::Watcher { | |
27 public: | |
28 SocketAcceptor(int fd, base::SingleThreadTaskRunner* target_thread) | |
29 : server_fd_(-1), | |
30 target_thread_(target_thread), | |
31 started_watching_event_( | |
32 base::WaitableEvent::ResetPolicy::AUTOMATIC, | |
33 base::WaitableEvent::InitialState::NOT_SIGNALED), | |
34 stopped_watching_event_( | |
35 base::WaitableEvent::ResetPolicy::AUTOMATIC, | |
36 base::WaitableEvent::InitialState::NOT_SIGNALED), | |
37 accepted_event_(base::WaitableEvent::ResetPolicy::AUTOMATIC, | |
38 base::WaitableEvent::InitialState::NOT_SIGNALED) { | |
39 target_thread->PostTask(FROM_HERE, | |
40 base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd)); | |
41 } | |
42 | |
43 ~SocketAcceptor() override { | |
44 Close(); | |
45 } | |
46 | |
47 int server_fd() const { return server_fd_; } | |
48 | |
49 void WaitUntilReady() { | |
50 started_watching_event_.Wait(); | |
51 } | |
52 | |
53 void WaitForAccept() { | |
54 accepted_event_.Wait(); | |
55 } | |
56 | |
57 void Close() { | |
58 if (watcher_.get()) { | |
59 target_thread_->PostTask(FROM_HERE, | |
60 base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this), | |
61 watcher_.release())); | |
62 stopped_watching_event_.Wait(); | |
63 } | |
64 } | |
65 | |
66 private: | |
67 void StartWatching(int fd) { | |
68 watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher); | |
69 base::MessageLoopForIO::current()->WatchFileDescriptor( | |
70 fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this); | |
71 started_watching_event_.Signal(); | |
72 } | |
73 void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher* watcher) { | |
74 watcher->StopWatchingFileDescriptor(); | |
75 delete watcher; | |
76 stopped_watching_event_.Signal(); | |
77 } | |
78 void OnFileCanReadWithoutBlocking(int fd) override { | |
79 ASSERT_EQ(-1, server_fd_); | |
80 IPC::ServerOnConnect(fd, &server_fd_); | |
81 watcher_->StopWatchingFileDescriptor(); | |
82 accepted_event_.Signal(); | |
83 } | |
84 void OnFileCanWriteWithoutBlocking(int fd) override {} | |
85 | |
86 int server_fd_; | |
87 base::SingleThreadTaskRunner* target_thread_; | |
88 std::unique_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_; | |
89 base::WaitableEvent started_watching_event_; | |
90 base::WaitableEvent stopped_watching_event_; | |
91 base::WaitableEvent accepted_event_; | |
92 | |
93 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor); | |
94 }; | |
95 | |
96 const base::FilePath GetChannelDir() { | |
97 base::FilePath tmp_dir; | |
98 PathService::Get(base::DIR_TEMP, &tmp_dir); | |
99 return tmp_dir; | |
100 } | |
101 | |
102 class TestUnixSocketConnection { | |
103 public: | |
104 TestUnixSocketConnection() | |
105 : worker_("WorkerThread"), | |
106 server_listen_fd_(-1), | |
107 server_fd_(-1), | |
108 client_fd_(-1) { | |
109 socket_name_ = GetChannelDir().Append("TestSocket"); | |
110 base::Thread::Options options; | |
111 options.message_loop_type = base::MessageLoop::TYPE_IO; | |
112 worker_.StartWithOptions(options); | |
113 } | |
114 | |
115 bool CreateServerSocket() { | |
116 IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_); | |
117 if (server_listen_fd_ < 0) | |
118 return false; | |
119 struct stat socket_stat; | |
120 stat(socket_name_.value().c_str(), &socket_stat); | |
121 EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode)); | |
122 acceptor_.reset( | |
123 new SocketAcceptor(server_listen_fd_, worker_.task_runner().get())); | |
124 acceptor_->WaitUntilReady(); | |
125 return true; | |
126 } | |
127 | |
128 bool CreateClientSocket() { | |
129 DCHECK(server_listen_fd_ >= 0); | |
130 IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_); | |
131 if (client_fd_ < 0) | |
132 return false; | |
133 acceptor_->WaitForAccept(); | |
134 server_fd_ = acceptor_->server_fd(); | |
135 return server_fd_ >= 0; | |
136 } | |
137 | |
138 virtual ~TestUnixSocketConnection() { | |
139 if (client_fd_ >= 0) | |
140 close(client_fd_); | |
141 if (server_fd_ >= 0) | |
142 close(server_fd_); | |
143 if (server_listen_fd_ >= 0) { | |
144 close(server_listen_fd_); | |
145 unlink(socket_name_.value().c_str()); | |
146 } | |
147 } | |
148 | |
149 int client_fd() const { return client_fd_; } | |
150 int server_fd() const { return server_fd_; } | |
151 | |
152 private: | |
153 base::Thread worker_; | |
154 base::FilePath socket_name_; | |
155 int server_listen_fd_; | |
156 int server_fd_; | |
157 int client_fd_; | |
158 std::unique_ptr<SocketAcceptor> acceptor_; | |
159 }; | |
160 | |
161 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that | |
162 // IPC::CreateClientUnixDomainSocket can successfully connect to. | |
163 TEST(UnixDomainSocketUtil, Connect) { | |
164 TestUnixSocketConnection connection; | |
165 ASSERT_TRUE(connection.CreateServerSocket()); | |
166 ASSERT_TRUE(connection.CreateClientSocket()); | |
167 } | |
168 | |
169 // Ensure that messages can be sent across the resulting socket. | |
170 TEST(UnixDomainSocketUtil, SendReceive) { | |
171 TestUnixSocketConnection connection; | |
172 ASSERT_TRUE(connection.CreateServerSocket()); | |
173 ASSERT_TRUE(connection.CreateClientSocket()); | |
174 | |
175 const char buffer[] = "Hello, server!"; | |
176 size_t buf_len = sizeof(buffer); | |
177 size_t sent_bytes = | |
178 HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0)); | |
179 ASSERT_EQ(buf_len, sent_bytes); | |
180 char recv_buf[sizeof(buffer)]; | |
181 size_t received_bytes = | |
182 HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0)); | |
183 ASSERT_EQ(buf_len, received_bytes); | |
184 ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len)); | |
185 } | |
186 | |
187 } // namespace | |
OLD | NEW |