OLD | NEW |
| (Empty) |
1 # -*- coding: utf-8 -*- | |
2 # Copyright 2015 Google Inc. All Rights Reserved. | |
3 # | |
4 # Licensed under the Apache License, Version 2.0 (the "License"); | |
5 # you may not use this file except in compliance with the License. | |
6 # You may obtain a copy of the License at | |
7 # | |
8 # http://www.apache.org/licenses/LICENSE-2.0 | |
9 # | |
10 # Unless required by applicable law or agreed to in writing, software | |
11 # distributed under the License is distributed on an "AS IS" BASIS, | |
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 # See the License for the specific language governing permissions and | |
14 # limitations under the License. | |
15 """Unit tests for daisy chain wrapper class.""" | |
16 | |
17 from __future__ import absolute_import | |
18 | |
19 import os | |
20 import pkgutil | |
21 | |
22 import gslib.cloud_api | |
23 from gslib.daisy_chain_wrapper import DaisyChainWrapper | |
24 from gslib.storage_url import StorageUrlFromString | |
25 import gslib.tests.testcase as testcase | |
26 from gslib.util import TRANSFER_BUFFER_SIZE | |
27 | |
28 | |
29 _TEST_FILE = 'test.txt' | |
30 | |
31 | |
32 class TestDaisyChainWrapper(testcase.GsUtilUnitTestCase): | |
33 """Unit tests for the DaisyChainWrapper class.""" | |
34 | |
35 _temp_test_file = None | |
36 _dummy_url = StorageUrlFromString('gs://bucket/object') | |
37 | |
38 def setUp(self): | |
39 super(TestDaisyChainWrapper, self).setUp() | |
40 self.test_data_file = self._GetTestFile() | |
41 self.test_data_file_len = os.path.getsize(self.test_data_file) | |
42 | |
43 def _GetTestFile(self): | |
44 contents = pkgutil.get_data('gslib', 'tests/test_data/%s' % _TEST_FILE) | |
45 if not self._temp_test_file: | |
46 # Write to a temp file because pkgutil doesn't expose a stream interface. | |
47 self._temp_test_file = self.CreateTempFile( | |
48 file_name=_TEST_FILE, contents=contents) | |
49 return self._temp_test_file | |
50 | |
51 class MockDownloadCloudApi(gslib.cloud_api.CloudApi): | |
52 """Mock CloudApi that implements GetObjectMedia for testing.""" | |
53 | |
54 def __init__(self, write_values): | |
55 """Initialize the mock that will be used by the download thread. | |
56 | |
57 Args: | |
58 write_values: List of values that will be used for calls to write(), | |
59 in order, by the download thread. An Exception class may be part of | |
60 the list; if so, the Exception will be raised after previous | |
61 values are consumed. | |
62 """ | |
63 self._write_values = write_values | |
64 self.get_calls = 0 | |
65 | |
66 def GetObjectMedia(self, unused_bucket_name, unused_object_name, | |
67 download_stream, start_byte=0, end_byte=None, | |
68 **kwargs): | |
69 """Writes self._write_values to the download_stream.""" | |
70 # Writes from start_byte up to, but not including end_byte (if not None). | |
71 # Does not slice values; | |
72 # self._write_values must line up with start/end_byte. | |
73 self.get_calls += 1 | |
74 bytes_read = 0 | |
75 for write_value in self._write_values: | |
76 if bytes_read < start_byte: | |
77 bytes_read += len(write_value) | |
78 continue | |
79 if end_byte and bytes_read >= end_byte: | |
80 break | |
81 if isinstance(write_value, Exception): | |
82 raise write_value | |
83 download_stream.write(write_value) | |
84 bytes_read += len(write_value) | |
85 | |
86 def _WriteFromWrapperToFile(self, daisy_chain_wrapper, file_path): | |
87 """Writes all contents from the DaisyChainWrapper to the named file.""" | |
88 with open(file_path, 'wb') as upload_stream: | |
89 while True: | |
90 data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) | |
91 if not data: | |
92 break | |
93 upload_stream.write(data) | |
94 | |
95 def testDownloadSingleChunk(self): | |
96 """Tests a single call to GetObjectMedia.""" | |
97 write_values = [] | |
98 with open(self.test_data_file, 'rb') as stream: | |
99 while True: | |
100 data = stream.read(TRANSFER_BUFFER_SIZE) | |
101 if not data: | |
102 break | |
103 write_values.append(data) | |
104 upload_file = self.CreateTempFile() | |
105 # Test for a single call even if the chunk size is larger than the data. | |
106 for chunk_size in (self.test_data_file_len, self.test_data_file_len + 1): | |
107 mock_api = self.MockDownloadCloudApi(write_values) | |
108 daisy_chain_wrapper = DaisyChainWrapper( | |
109 self._dummy_url, self.test_data_file_len, mock_api, | |
110 download_chunk_size=chunk_size) | |
111 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) | |
112 # Since the chunk size is >= the file size, only a single GetObjectMedia | |
113 # call should be made. | |
114 self.assertEquals(mock_api.get_calls, 1) | |
115 with open(upload_file, 'rb') as upload_stream: | |
116 with open(self.test_data_file, 'rb') as download_stream: | |
117 self.assertEqual(upload_stream.read(), download_stream.read()) | |
118 | |
119 def testDownloadMultiChunk(self): | |
120 """Tests multiple calls to GetObjectMedia.""" | |
121 upload_file = self.CreateTempFile() | |
122 write_values = [] | |
123 with open(self.test_data_file, 'rb') as stream: | |
124 while True: | |
125 data = stream.read(TRANSFER_BUFFER_SIZE) | |
126 if not data: | |
127 break | |
128 write_values.append(data) | |
129 mock_api = self.MockDownloadCloudApi(write_values) | |
130 daisy_chain_wrapper = DaisyChainWrapper( | |
131 self._dummy_url, self.test_data_file_len, mock_api, | |
132 download_chunk_size=TRANSFER_BUFFER_SIZE) | |
133 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) | |
134 num_expected_calls = self.test_data_file_len / TRANSFER_BUFFER_SIZE | |
135 if self.test_data_file_len % TRANSFER_BUFFER_SIZE: | |
136 num_expected_calls += 1 | |
137 # Since the chunk size is < the file size, multiple calls to GetObjectMedia | |
138 # should be made. | |
139 self.assertEqual(mock_api.get_calls, num_expected_calls) | |
140 with open(upload_file, 'rb') as upload_stream: | |
141 with open(self.test_data_file, 'rb') as download_stream: | |
142 self.assertEqual(upload_stream.read(), download_stream.read()) | |
143 | |
144 def testDownloadWithZeroWrites(self): | |
145 """Tests 0-byte writes to the download stream from GetObjectMedia.""" | |
146 write_values = [] | |
147 with open(self.test_data_file, 'rb') as stream: | |
148 while True: | |
149 write_values.append(b'') | |
150 data = stream.read(TRANSFER_BUFFER_SIZE) | |
151 write_values.append(b'') | |
152 if not data: | |
153 break | |
154 write_values.append(data) | |
155 upload_file = self.CreateTempFile() | |
156 mock_api = self.MockDownloadCloudApi(write_values) | |
157 daisy_chain_wrapper = DaisyChainWrapper( | |
158 self._dummy_url, self.test_data_file_len, mock_api, | |
159 download_chunk_size=self.test_data_file_len) | |
160 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) | |
161 self.assertEquals(mock_api.get_calls, 1) | |
162 with open(upload_file, 'rb') as upload_stream: | |
163 with open(self.test_data_file, 'rb') as download_stream: | |
164 self.assertEqual(upload_stream.read(), download_stream.read()) | |
165 | |
166 def testDownloadWithPartialWrite(self): | |
167 """Tests unaligned writes to the download stream from GetObjectMedia.""" | |
168 with open(self.test_data_file, 'rb') as stream: | |
169 chunk = stream.read(TRANSFER_BUFFER_SIZE) | |
170 one_byte = chunk[0] | |
171 chunk_minus_one_byte = chunk[1:TRANSFER_BUFFER_SIZE] | |
172 half_chunk = chunk[0:TRANSFER_BUFFER_SIZE/2] | |
173 | |
174 write_values_dict = { | |
175 'First byte first chunk unaligned': | |
176 (one_byte, chunk_minus_one_byte, chunk, chunk), | |
177 'Last byte first chunk unaligned': | |
178 (chunk_minus_one_byte, chunk, chunk), | |
179 'First byte second chunk unaligned': | |
180 (chunk, one_byte, chunk_minus_one_byte, chunk), | |
181 'Last byte second chunk unaligned': | |
182 (chunk, chunk_minus_one_byte, one_byte, chunk), | |
183 'First byte final chunk unaligned': | |
184 (chunk, chunk, one_byte, chunk_minus_one_byte), | |
185 'Last byte final chunk unaligned': | |
186 (chunk, chunk, chunk_minus_one_byte, one_byte), | |
187 'Half chunks': | |
188 (half_chunk, half_chunk, half_chunk), | |
189 'Many unaligned': | |
190 (one_byte, half_chunk, one_byte, half_chunk, chunk, | |
191 chunk_minus_one_byte, chunk, one_byte, half_chunk, one_byte) | |
192 } | |
193 upload_file = self.CreateTempFile() | |
194 for case_name, write_values in write_values_dict.iteritems(): | |
195 expected_contents = b'' | |
196 for write_value in write_values: | |
197 expected_contents += write_value | |
198 mock_api = self.MockDownloadCloudApi(write_values) | |
199 daisy_chain_wrapper = DaisyChainWrapper( | |
200 self._dummy_url, len(expected_contents), mock_api, | |
201 download_chunk_size=self.test_data_file_len) | |
202 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) | |
203 with open(upload_file, 'rb') as upload_stream: | |
204 self.assertEqual(upload_stream.read(), expected_contents, | |
205 'Uploaded file contents for case %s did not match' | |
206 % case_name) | |
207 | |
208 def testSeekAndReturn(self): | |
209 """Tests seeking to the end of the wrapper (simulates getting size).""" | |
210 write_values = [] | |
211 with open(self.test_data_file, 'rb') as stream: | |
212 while True: | |
213 data = stream.read(TRANSFER_BUFFER_SIZE) | |
214 if not data: | |
215 break | |
216 write_values.append(data) | |
217 upload_file = self.CreateTempFile() | |
218 mock_api = self.MockDownloadCloudApi(write_values) | |
219 daisy_chain_wrapper = DaisyChainWrapper( | |
220 self._dummy_url, self.test_data_file_len, mock_api, | |
221 download_chunk_size=self.test_data_file_len) | |
222 with open(upload_file, 'wb') as upload_stream: | |
223 current_position = 0 | |
224 daisy_chain_wrapper.seek(0, whence=os.SEEK_END) | |
225 daisy_chain_wrapper.seek(current_position) | |
226 while True: | |
227 data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) | |
228 current_position += len(data) | |
229 daisy_chain_wrapper.seek(0, whence=os.SEEK_END) | |
230 daisy_chain_wrapper.seek(current_position) | |
231 if not data: | |
232 break | |
233 upload_stream.write(data) | |
234 self.assertEquals(mock_api.get_calls, 1) | |
235 with open(upload_file, 'rb') as upload_stream: | |
236 with open(self.test_data_file, 'rb') as download_stream: | |
237 self.assertEqual(upload_stream.read(), download_stream.read()) | |
238 | |
239 def testRestartDownloadThread(self): | |
240 """Tests seek to non-stored position; this restarts the download thread.""" | |
241 write_values = [] | |
242 with open(self.test_data_file, 'rb') as stream: | |
243 while True: | |
244 data = stream.read(TRANSFER_BUFFER_SIZE) | |
245 if not data: | |
246 break | |
247 write_values.append(data) | |
248 upload_file = self.CreateTempFile() | |
249 mock_api = self.MockDownloadCloudApi(write_values) | |
250 daisy_chain_wrapper = DaisyChainWrapper( | |
251 self._dummy_url, self.test_data_file_len, mock_api, | |
252 download_chunk_size=self.test_data_file_len) | |
253 daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) | |
254 daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) | |
255 daisy_chain_wrapper.seek(0) | |
256 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) | |
257 self.assertEquals(mock_api.get_calls, 2) | |
258 with open(upload_file, 'rb') as upload_stream: | |
259 with open(self.test_data_file, 'rb') as download_stream: | |
260 self.assertEqual(upload_stream.read(), download_stream.read()) | |
261 | |
262 def testDownloadThreadException(self): | |
263 """Tests that an exception is propagated via the upload thread.""" | |
264 | |
265 class DownloadException(Exception): | |
266 pass | |
267 | |
268 write_values = [b'a', b'b', | |
269 DownloadException('Download thread forces failure')] | |
270 upload_file = self.CreateTempFile() | |
271 mock_api = self.MockDownloadCloudApi(write_values) | |
272 daisy_chain_wrapper = DaisyChainWrapper( | |
273 self._dummy_url, self.test_data_file_len, mock_api, | |
274 download_chunk_size=self.test_data_file_len) | |
275 try: | |
276 self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) | |
277 self.fail('Expected exception') | |
278 except DownloadException, e: | |
279 self.assertIn('Download thread forces failure', str(e)) | |
280 | |
281 def testInvalidSeek(self): | |
282 """Tests that seeking fails for unsupported seek arguments.""" | |
283 daisy_chain_wrapper = DaisyChainWrapper( | |
284 self._dummy_url, self.test_data_file_len, self.MockDownloadCloudApi([])) | |
285 try: | |
286 # SEEK_CUR is invalid. | |
287 daisy_chain_wrapper.seek(0, whence=os.SEEK_CUR) | |
288 self.fail('Expected exception') | |
289 except IOError, e: | |
290 self.assertIn('does not support seek mode', str(e)) | |
291 | |
292 try: | |
293 # Seeking from the end with an offset is invalid. | |
294 daisy_chain_wrapper.seek(1, whence=os.SEEK_END) | |
295 self.fail('Expected exception') | |
296 except IOError, e: | |
297 self.assertIn('Invalid seek during daisy chain', str(e)) | |
OLD | NEW |