| OLD | NEW | 
|---|
| (Empty) |  | 
|  | 1 // Copyright 2017 The Chromium Authors. All rights reserved. | 
|  | 2 // Use of this source code is governed by a BSD-style license that can be | 
|  | 3 // found in the LICENSE file. | 
|  | 4 | 
|  | 5 #include "mojo/edk/system/channel.h" | 
|  | 6 #include "base/memory/ptr_util.h" | 
|  | 7 #include "testing/gmock/include/gmock/gmock.h" | 
|  | 8 #include "testing/gtest/include/gtest/gtest.h" | 
|  | 9 | 
|  | 10 namespace mojo { | 
|  | 11 namespace edk { | 
|  | 12 namespace { | 
|  | 13 | 
|  | 14 class TestChannel : public Channel { | 
|  | 15  public: | 
|  | 16   TestChannel(Channel::Delegate* delegate) : Channel(delegate) {} | 
|  | 17 | 
|  | 18   char* GetReadBufferTest(size_t* buffer_capacity) { | 
|  | 19     return GetReadBuffer(buffer_capacity); | 
|  | 20   } | 
|  | 21 | 
|  | 22   bool OnReadCompleteTest(size_t bytes_read, size_t* next_read_size_hint) { | 
|  | 23     return OnReadComplete(bytes_read, next_read_size_hint); | 
|  | 24   } | 
|  | 25 | 
|  | 26   MOCK_METHOD4(GetReadPlatformHandles, | 
|  | 27                bool(size_t num_handles, | 
|  | 28                     const void* extra_header, | 
|  | 29                     size_t extra_header_size, | 
|  | 30                     ScopedPlatformHandleVectorPtr* handles)); | 
|  | 31   MOCK_METHOD0(Start, void()); | 
|  | 32   MOCK_METHOD0(ShutDownImpl, void()); | 
|  | 33   MOCK_METHOD0(LeakHandle, void()); | 
|  | 34 | 
|  | 35   void Write(MessagePtr message) {} | 
|  | 36 | 
|  | 37  protected: | 
|  | 38   ~TestChannel() override {} | 
|  | 39 }; | 
|  | 40 | 
|  | 41 // Not using GMock as I don't think it supports movable types. | 
|  | 42 class MockChannelDelegate : public Channel::Delegate { | 
|  | 43  public: | 
|  | 44   MockChannelDelegate() {} | 
|  | 45 | 
|  | 46   size_t GetReceivedPayloadSize() const { return payload_size_; } | 
|  | 47 | 
|  | 48   const void* GetReceivedPayload() const { return payload_.get(); } | 
|  | 49 | 
|  | 50  protected: | 
|  | 51   void OnChannelMessage(const void* payload, | 
|  | 52                         size_t payload_size, | 
|  | 53                         ScopedPlatformHandleVectorPtr handles) override { | 
|  | 54     payload_.reset(new char[payload_size]); | 
|  | 55     memcpy(payload_.get(), payload, payload_size); | 
|  | 56     payload_size_ = payload_size; | 
|  | 57   } | 
|  | 58 | 
|  | 59   // Notify that an error has occured and the Channel will cease operation. | 
|  | 60   void OnChannelError() override {} | 
|  | 61 | 
|  | 62  private: | 
|  | 63   size_t payload_size_ = 0; | 
|  | 64   std::unique_ptr<char[]> payload_; | 
|  | 65 }; | 
|  | 66 | 
|  | 67 Channel::MessagePtr CreateDefaultMessage(bool legacy_message) { | 
|  | 68   const size_t payload_size = 100; | 
|  | 69   Channel::MessagePtr message = base::MakeUnique<Channel::Message>( | 
|  | 70       payload_size, 0, | 
|  | 71       legacy_message ? Channel::Message::MessageType::NORMAL_LEGACY | 
|  | 72                      : Channel::Message::MessageType::NORMAL); | 
|  | 73   char* payload = static_cast<char*>(message->mutable_payload()); | 
|  | 74   for (size_t i = 0; i < payload_size; i++) { | 
|  | 75     payload[i] = static_cast<char>(i); | 
|  | 76   } | 
|  | 77   return message; | 
|  | 78 } | 
|  | 79 | 
|  | 80 void TestMemoryEqual(const void* data1, | 
|  | 81                      size_t data1_size, | 
|  | 82                      const void* data2, | 
|  | 83                      size_t data2_size) { | 
|  | 84   ASSERT_EQ(data1_size, data2_size); | 
|  | 85   const unsigned char* data1_char = static_cast<const unsigned char*>(data1); | 
|  | 86   const unsigned char* data2_char = static_cast<const unsigned char*>(data2); | 
|  | 87   for (size_t i = 0; i < data1_size; i++) { | 
|  | 88     // ASSERT so we don't log tons of errors if the data is different. | 
|  | 89     ASSERT_EQ(data1_char[i], data2_char[i]); | 
|  | 90   } | 
|  | 91 } | 
|  | 92 | 
|  | 93 void TestMessagesAreEqual(Channel::Message* message1, | 
|  | 94                           Channel::Message* message2, | 
|  | 95                           bool legacy_messages) { | 
|  | 96   // If any of the message is null, this is probably not what you wanted to | 
|  | 97   // test. | 
|  | 98   ASSERT_NE(nullptr, message1); | 
|  | 99   ASSERT_NE(nullptr, message2); | 
|  | 100 | 
|  | 101   ASSERT_EQ(message1->payload_size(), message2->payload_size()); | 
|  | 102   EXPECT_EQ(message1->has_handles(), message2->has_handles()); | 
|  | 103 | 
|  | 104   TestMemoryEqual(message1->payload(), message1->payload_size(), | 
|  | 105                   message2->payload(), message2->payload_size()); | 
|  | 106 | 
|  | 107   if (legacy_messages) | 
|  | 108     return; | 
|  | 109 | 
|  | 110   ASSERT_EQ(message1->extra_header_size(), message2->extra_header_size()); | 
|  | 111   TestMemoryEqual(message1->extra_header(), message1->extra_header_size(), | 
|  | 112                   message2->extra_header(), message2->extra_header_size()); | 
|  | 113 } | 
|  | 114 | 
|  | 115 TEST(ChannelTest, LegacyMessageDeserialization) { | 
|  | 116   Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */); | 
|  | 117   Channel::MessagePtr deserialized_message = | 
|  | 118       Channel::Message::Deserialize(message->data(), message->data_num_bytes()); | 
|  | 119   TestMessagesAreEqual(message.get(), deserialized_message.get(), | 
|  | 120                        true /* legacy_message */); | 
|  | 121 } | 
|  | 122 | 
|  | 123 TEST(ChannelTest, NonLegacyMessageDeserialization) { | 
|  | 124   Channel::MessagePtr message = | 
|  | 125       CreateDefaultMessage(false /* legacy_message */); | 
|  | 126   Channel::MessagePtr deserialized_message = | 
|  | 127       Channel::Message::Deserialize(message->data(), message->data_num_bytes()); | 
|  | 128   TestMessagesAreEqual(message.get(), deserialized_message.get(), | 
|  | 129                        false /* legacy_message */); | 
|  | 130 } | 
|  | 131 | 
|  | 132 TEST(ChannelTest, OnReadLegacyMessage) { | 
|  | 133   size_t buffer_size = 100 * 1024; | 
|  | 134   Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */); | 
|  | 135 | 
|  | 136   MockChannelDelegate channel_delegate; | 
|  | 137   scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate); | 
|  | 138   char* read_buffer = channel->GetReadBufferTest(&buffer_size); | 
|  | 139   ASSERT_LT(message->data_num_bytes(), | 
|  | 140             buffer_size);  // Bad test. Increase buffer | 
|  | 141                            // size. | 
|  | 142   memcpy(read_buffer, message->data(), message->data_num_bytes()); | 
|  | 143 | 
|  | 144   size_t next_read_size_hint = 0; | 
|  | 145   EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(), | 
|  | 146                                           &next_read_size_hint)); | 
|  | 147 | 
|  | 148   TestMemoryEqual(message->payload(), message->payload_size(), | 
|  | 149                   channel_delegate.GetReceivedPayload(), | 
|  | 150                   channel_delegate.GetReceivedPayloadSize()); | 
|  | 151 } | 
|  | 152 | 
|  | 153 TEST(ChannelTest, OnReadNonLegacyMessage) { | 
|  | 154   size_t buffer_size = 100 * 1024; | 
|  | 155   Channel::MessagePtr message = | 
|  | 156       CreateDefaultMessage(false /* legacy_message */); | 
|  | 157 | 
|  | 158   MockChannelDelegate channel_delegate; | 
|  | 159   scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate); | 
|  | 160   char* read_buffer = channel->GetReadBufferTest(&buffer_size); | 
|  | 161   ASSERT_LT(message->data_num_bytes(), | 
|  | 162             buffer_size);  // Bad test. Increase buffer | 
|  | 163                            // size. | 
|  | 164   memcpy(read_buffer, message->data(), message->data_num_bytes()); | 
|  | 165 | 
|  | 166   size_t next_read_size_hint = 0; | 
|  | 167   EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(), | 
|  | 168                                           &next_read_size_hint)); | 
|  | 169 | 
|  | 170   TestMemoryEqual(message->payload(), message->payload_size(), | 
|  | 171                   channel_delegate.GetReceivedPayload(), | 
|  | 172                   channel_delegate.GetReceivedPayloadSize()); | 
|  | 173 } | 
|  | 174 | 
|  | 175 }  // namespace | 
|  | 176 }  // namespace edk | 
|  | 177 }  // namespace mojo | 
| OLD | NEW | 
|---|