OLD | NEW |
(Empty) | |
| 1 # -*- coding: utf-8 -*- |
| 2 # Copyright 2015 Google Inc. All Rights Reserved. |
| 3 # |
| 4 # Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 # you may not use this file except in compliance with the License. |
| 6 # You may obtain a copy of the License at |
| 7 # |
| 8 # http://www.apache.org/licenses/LICENSE-2.0 |
| 9 # |
| 10 # Unless required by applicable law or agreed to in writing, software |
| 11 # distributed under the License is distributed on an "AS IS" BASIS, |
| 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 # See the License for the specific language governing permissions and |
| 14 # limitations under the License. |
| 15 """Unit tests for daisy chain wrapper class.""" |
| 16 |
| 17 from __future__ import absolute_import |
| 18 |
| 19 import os |
| 20 import pkgutil |
| 21 |
| 22 import gslib.cloud_api |
| 23 from gslib.daisy_chain_wrapper import DaisyChainWrapper |
| 24 from gslib.storage_url import StorageUrlFromString |
| 25 import gslib.tests.testcase as testcase |
| 26 from gslib.util import TRANSFER_BUFFER_SIZE |
| 27 |
| 28 |
| 29 _TEST_FILE = 'test.txt' |
| 30 |
| 31 |
| 32 class TestDaisyChainWrapper(testcase.GsUtilUnitTestCase): |
| 33 """Unit tests for the DaisyChainWrapper class.""" |
| 34 |
| 35 _temp_test_file = None |
| 36 _dummy_url = StorageUrlFromString('gs://bucket/object') |
| 37 |
| 38 def setUp(self): |
| 39 super(TestDaisyChainWrapper, self).setUp() |
| 40 self.test_data_file = self._GetTestFile() |
| 41 self.test_data_file_len = os.path.getsize(self.test_data_file) |
| 42 |
| 43 def _GetTestFile(self): |
| 44 contents = pkgutil.get_data('gslib', 'tests/test_data/%s' % _TEST_FILE) |
| 45 if not self._temp_test_file: |
| 46 # Write to a temp file because pkgutil doesn't expose a stream interface. |
| 47 self._temp_test_file = self.CreateTempFile( |
| 48 file_name=_TEST_FILE, contents=contents) |
| 49 return self._temp_test_file |
| 50 |
| 51 class MockDownloadCloudApi(gslib.cloud_api.CloudApi): |
| 52 """Mock CloudApi that implements GetObjectMedia for testing.""" |
| 53 |
| 54 def __init__(self, write_values): |
| 55 """Initialize the mock that will be used by the download thread. |
| 56 |
| 57 Args: |
| 58 write_values: List of values that will be used for calls to write(), |
| 59 in order, by the download thread. An Exception class may be part of |
| 60 the list; if so, the Exception will be raised after previous |
| 61 values are consumed. |
| 62 """ |
| 63 self._write_values = write_values |
| 64 self.get_calls = 0 |
| 65 |
| 66 def GetObjectMedia(self, unused_bucket_name, unused_object_name, |
| 67 download_stream, start_byte=0, end_byte=None, |
| 68 **kwargs): |
| 69 """Writes self._write_values to the download_stream.""" |
| 70 # Writes from start_byte up to, but not including end_byte (if not None). |
| 71 # Does not slice values; |
| 72 # self._write_values must line up with start/end_byte. |
| 73 self.get_calls += 1 |
| 74 bytes_read = 0 |
| 75 for write_value in self._write_values: |
| 76 if bytes_read < start_byte: |
| 77 bytes_read += len(write_value) |
| 78 continue |
| 79 if end_byte and bytes_read >= end_byte: |
| 80 break |
| 81 if isinstance(write_value, Exception): |
| 82 raise write_value |
| 83 download_stream.write(write_value) |
| 84 bytes_read += len(write_value) |
| 85 |
| 86 def _WriteFromWrapperToFile(self, daisy_chain_wrapper, file_path): |
| 87 """Writes all contents from the DaisyChainWrapper to the named file.""" |
| 88 with open(file_path, 'wb') as upload_stream: |
| 89 while True: |
| 90 data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| 91 if not data: |
| 92 break |
| 93 upload_stream.write(data) |
| 94 |
| 95 def testDownloadSingleChunk(self): |
| 96 """Tests a single call to GetObjectMedia.""" |
| 97 write_values = [] |
| 98 with open(self.test_data_file, 'rb') as stream: |
| 99 while True: |
| 100 data = stream.read(TRANSFER_BUFFER_SIZE) |
| 101 if not data: |
| 102 break |
| 103 write_values.append(data) |
| 104 upload_file = self.CreateTempFile() |
| 105 # Test for a single call even if the chunk size is larger than the data. |
| 106 for chunk_size in (self.test_data_file_len, self.test_data_file_len + 1): |
| 107 mock_api = self.MockDownloadCloudApi(write_values) |
| 108 daisy_chain_wrapper = DaisyChainWrapper( |
| 109 self._dummy_url, self.test_data_file_len, mock_api, |
| 110 download_chunk_size=chunk_size) |
| 111 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| 112 # Since the chunk size is >= the file size, only a single GetObjectMedia |
| 113 # call should be made. |
| 114 self.assertEquals(mock_api.get_calls, 1) |
| 115 with open(upload_file, 'rb') as upload_stream: |
| 116 with open(self.test_data_file, 'rb') as download_stream: |
| 117 self.assertEqual(upload_stream.read(), download_stream.read()) |
| 118 |
| 119 def testDownloadMultiChunk(self): |
| 120 """Tests multiple calls to GetObjectMedia.""" |
| 121 upload_file = self.CreateTempFile() |
| 122 write_values = [] |
| 123 with open(self.test_data_file, 'rb') as stream: |
| 124 while True: |
| 125 data = stream.read(TRANSFER_BUFFER_SIZE) |
| 126 if not data: |
| 127 break |
| 128 write_values.append(data) |
| 129 mock_api = self.MockDownloadCloudApi(write_values) |
| 130 daisy_chain_wrapper = DaisyChainWrapper( |
| 131 self._dummy_url, self.test_data_file_len, mock_api, |
| 132 download_chunk_size=TRANSFER_BUFFER_SIZE) |
| 133 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| 134 num_expected_calls = self.test_data_file_len / TRANSFER_BUFFER_SIZE |
| 135 if self.test_data_file_len % TRANSFER_BUFFER_SIZE: |
| 136 num_expected_calls += 1 |
| 137 # Since the chunk size is < the file size, multiple calls to GetObjectMedia |
| 138 # should be made. |
| 139 self.assertEqual(mock_api.get_calls, num_expected_calls) |
| 140 with open(upload_file, 'rb') as upload_stream: |
| 141 with open(self.test_data_file, 'rb') as download_stream: |
| 142 self.assertEqual(upload_stream.read(), download_stream.read()) |
| 143 |
| 144 def testDownloadWithZeroWrites(self): |
| 145 """Tests 0-byte writes to the download stream from GetObjectMedia.""" |
| 146 write_values = [] |
| 147 with open(self.test_data_file, 'rb') as stream: |
| 148 while True: |
| 149 write_values.append(b'') |
| 150 data = stream.read(TRANSFER_BUFFER_SIZE) |
| 151 write_values.append(b'') |
| 152 if not data: |
| 153 break |
| 154 write_values.append(data) |
| 155 upload_file = self.CreateTempFile() |
| 156 mock_api = self.MockDownloadCloudApi(write_values) |
| 157 daisy_chain_wrapper = DaisyChainWrapper( |
| 158 self._dummy_url, self.test_data_file_len, mock_api, |
| 159 download_chunk_size=self.test_data_file_len) |
| 160 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| 161 self.assertEquals(mock_api.get_calls, 1) |
| 162 with open(upload_file, 'rb') as upload_stream: |
| 163 with open(self.test_data_file, 'rb') as download_stream: |
| 164 self.assertEqual(upload_stream.read(), download_stream.read()) |
| 165 |
| 166 def testDownloadWithPartialWrite(self): |
| 167 """Tests unaligned writes to the download stream from GetObjectMedia.""" |
| 168 with open(self.test_data_file, 'rb') as stream: |
| 169 chunk = stream.read(TRANSFER_BUFFER_SIZE) |
| 170 one_byte = chunk[0] |
| 171 chunk_minus_one_byte = chunk[1:TRANSFER_BUFFER_SIZE] |
| 172 half_chunk = chunk[0:TRANSFER_BUFFER_SIZE/2] |
| 173 |
| 174 write_values_dict = { |
| 175 'First byte first chunk unaligned': |
| 176 (one_byte, chunk_minus_one_byte, chunk, chunk), |
| 177 'Last byte first chunk unaligned': |
| 178 (chunk_minus_one_byte, chunk, chunk), |
| 179 'First byte second chunk unaligned': |
| 180 (chunk, one_byte, chunk_minus_one_byte, chunk), |
| 181 'Last byte second chunk unaligned': |
| 182 (chunk, chunk_minus_one_byte, one_byte, chunk), |
| 183 'First byte final chunk unaligned': |
| 184 (chunk, chunk, one_byte, chunk_minus_one_byte), |
| 185 'Last byte final chunk unaligned': |
| 186 (chunk, chunk, chunk_minus_one_byte, one_byte), |
| 187 'Half chunks': |
| 188 (half_chunk, half_chunk, half_chunk), |
| 189 'Many unaligned': |
| 190 (one_byte, half_chunk, one_byte, half_chunk, chunk, |
| 191 chunk_minus_one_byte, chunk, one_byte, half_chunk, one_byte) |
| 192 } |
| 193 upload_file = self.CreateTempFile() |
| 194 for case_name, write_values in write_values_dict.iteritems(): |
| 195 expected_contents = b'' |
| 196 for write_value in write_values: |
| 197 expected_contents += write_value |
| 198 mock_api = self.MockDownloadCloudApi(write_values) |
| 199 daisy_chain_wrapper = DaisyChainWrapper( |
| 200 self._dummy_url, len(expected_contents), mock_api, |
| 201 download_chunk_size=self.test_data_file_len) |
| 202 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| 203 with open(upload_file, 'rb') as upload_stream: |
| 204 self.assertEqual(upload_stream.read(), expected_contents, |
| 205 'Uploaded file contents for case %s did not match' |
| 206 % case_name) |
| 207 |
| 208 def testSeekAndReturn(self): |
| 209 """Tests seeking to the end of the wrapper (simulates getting size).""" |
| 210 write_values = [] |
| 211 with open(self.test_data_file, 'rb') as stream: |
| 212 while True: |
| 213 data = stream.read(TRANSFER_BUFFER_SIZE) |
| 214 if not data: |
| 215 break |
| 216 write_values.append(data) |
| 217 upload_file = self.CreateTempFile() |
| 218 mock_api = self.MockDownloadCloudApi(write_values) |
| 219 daisy_chain_wrapper = DaisyChainWrapper( |
| 220 self._dummy_url, self.test_data_file_len, mock_api, |
| 221 download_chunk_size=self.test_data_file_len) |
| 222 with open(upload_file, 'wb') as upload_stream: |
| 223 current_position = 0 |
| 224 daisy_chain_wrapper.seek(0, whence=os.SEEK_END) |
| 225 daisy_chain_wrapper.seek(current_position) |
| 226 while True: |
| 227 data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| 228 current_position += len(data) |
| 229 daisy_chain_wrapper.seek(0, whence=os.SEEK_END) |
| 230 daisy_chain_wrapper.seek(current_position) |
| 231 if not data: |
| 232 break |
| 233 upload_stream.write(data) |
| 234 self.assertEquals(mock_api.get_calls, 1) |
| 235 with open(upload_file, 'rb') as upload_stream: |
| 236 with open(self.test_data_file, 'rb') as download_stream: |
| 237 self.assertEqual(upload_stream.read(), download_stream.read()) |
| 238 |
| 239 def testRestartDownloadThread(self): |
| 240 """Tests seek to non-stored position; this restarts the download thread.""" |
| 241 write_values = [] |
| 242 with open(self.test_data_file, 'rb') as stream: |
| 243 while True: |
| 244 data = stream.read(TRANSFER_BUFFER_SIZE) |
| 245 if not data: |
| 246 break |
| 247 write_values.append(data) |
| 248 upload_file = self.CreateTempFile() |
| 249 mock_api = self.MockDownloadCloudApi(write_values) |
| 250 daisy_chain_wrapper = DaisyChainWrapper( |
| 251 self._dummy_url, self.test_data_file_len, mock_api, |
| 252 download_chunk_size=self.test_data_file_len) |
| 253 daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| 254 daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| 255 daisy_chain_wrapper.seek(0) |
| 256 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| 257 self.assertEquals(mock_api.get_calls, 2) |
| 258 with open(upload_file, 'rb') as upload_stream: |
| 259 with open(self.test_data_file, 'rb') as download_stream: |
| 260 self.assertEqual(upload_stream.read(), download_stream.read()) |
| 261 |
| 262 def testDownloadThreadException(self): |
| 263 """Tests that an exception is propagated via the upload thread.""" |
| 264 |
| 265 class DownloadException(Exception): |
| 266 pass |
| 267 |
| 268 write_values = [b'a', b'b', |
| 269 DownloadException('Download thread forces failure')] |
| 270 upload_file = self.CreateTempFile() |
| 271 mock_api = self.MockDownloadCloudApi(write_values) |
| 272 daisy_chain_wrapper = DaisyChainWrapper( |
| 273 self._dummy_url, self.test_data_file_len, mock_api, |
| 274 download_chunk_size=self.test_data_file_len) |
| 275 try: |
| 276 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| 277 self.fail('Expected exception') |
| 278 except DownloadException, e: |
| 279 self.assertIn('Download thread forces failure', str(e)) |
| 280 |
| 281 def testInvalidSeek(self): |
| 282 """Tests that seeking fails for unsupported seek arguments.""" |
| 283 daisy_chain_wrapper = DaisyChainWrapper( |
| 284 self._dummy_url, self.test_data_file_len, self.MockDownloadCloudApi([])) |
| 285 try: |
| 286 # SEEK_CUR is invalid. |
| 287 daisy_chain_wrapper.seek(0, whence=os.SEEK_CUR) |
| 288 self.fail('Expected exception') |
| 289 except IOError, e: |
| 290 self.assertIn('does not support seek mode', str(e)) |
| 291 |
| 292 try: |
| 293 # Seeking from the end with an offset is invalid. |
| 294 daisy_chain_wrapper.seek(1, whence=os.SEEK_END) |
| 295 self.fail('Expected exception') |
| 296 except IOError, e: |
| 297 self.assertIn('Invalid seek during daisy chain', str(e)) |
OLD | NEW |