| OLD | NEW |
| 1 # Copyright 2016 The LUCI Authors. All rights reserved. | 1 # Copyright 2016 The LUCI Authors. All rights reserved. |
| 2 # Use of this source code is governed under the Apache License, Version 2.0 | 2 # Use of this source code is governed under the Apache License, Version 2.0 |
| 3 # that can be found in the LICENSE file. | 3 # that can be found in the LICENSE file. |
| 4 | 4 |
| 5 # This is a reimplementation of RemoteClientNative but it uses (will use) | 5 # This is a reimplementation of RemoteClientNative but it uses (will use) |
| 6 # a gRPC method to communicate with a server instead of REST. | 6 # a gRPC method to communicate with a server instead of REST. |
| 7 | 7 |
| 8 import copy | 8 import copy |
| 9 import json | 9 import json |
| 10 import logging | 10 import logging |
| 11 import time | 11 import time |
| 12 | 12 |
| 13 import grpc | |
| 14 import google.protobuf.json_format | 13 import google.protobuf.json_format |
| 15 from proto_bot import swarming_bot_pb2 | 14 from proto_bot import swarming_bot_pb2 |
| 16 from remote_client_errors import InternalError | 15 from remote_client_errors import InternalError |
| 17 from remote_client_errors import MintOAuthTokenError | 16 from remote_client_errors import MintOAuthTokenError |
| 18 from remote_client_errors import PollError | 17 from remote_client_errors import PollError |
| 19 from utils import net | 18 from utils import net |
| 20 | 19 from utils import grpc_proxy |
| 21 | |
| 22 # How long to wait for a response from the server. Keeping the same as | |
| 23 # the equivalent in remote_client.py for now. | |
| 24 NET_CONNECTION_TIMEOUT_SEC = 5*60 | |
| 25 | |
| 26 | |
| 27 # How many times to retry a gRPC call | |
| 28 MAX_GRPC_ATTEMPTS = 30 | |
| 29 | |
| 30 | |
| 31 # Longest time to sleep between gRPC calls | |
| 32 MAX_GRPC_SLEEP = 10. | |
| 33 | 20 |
| 34 | 21 |
| 35 class RemoteClientGrpc(object): | 22 class RemoteClientGrpc(object): |
| 36 """RemoteClientGrpc knows how to make calls via gRPC. | 23 """RemoteClientGrpc knows how to make calls via gRPC. |
| 37 | 24 |
| 38 Any non-scalar value that is returned that references values from the proto | 25 Any non-scalar value that is returned that references values from the proto |
| 39 messages should be deepcopy'd since protos make use of weak references and | 26 messages should be deepcopy'd since protos make use of weak references and |
| 40 might be garbage-collected before the values are used. | 27 might be garbage-collected before the values are used. |
| 41 """ | 28 """ |
| 42 | 29 |
| 43 def __init__(self, server): | 30 def __init__(self, server, fake_proxy=None): |
| 44 logging.info('Communicating with host %s via gRPC', server) | 31 logging.info('Communicating with host %s via gRPC', server) |
| 32 if fake_proxy: |
| 33 self._proxy = fake_proxy |
| 34 else: |
| 35 self._proxy = grpc_proxy.Proxy(server, swarming_bot_pb2.BotServiceStub) |
| 45 self._server = server | 36 self._server = server |
| 46 self._channel = grpc.insecure_channel(server) | |
| 47 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) | |
| 48 self._log_is_asleep = False | 37 self._log_is_asleep = False |
| 49 | 38 |
| 50 def is_grpc(self): | 39 def is_grpc(self): |
| 51 return True | 40 return True |
| 52 | 41 |
| 53 def initialize(self, quit_bit=None): | 42 def initialize(self, quit_bit=None): |
| 54 pass | 43 pass |
| 55 | 44 |
| 56 @property | 45 @property |
| 57 def uses_auth(self): | 46 def uses_auth(self): |
| (...skipping 19 matching lines...) Expand all Loading... |
| 77 request.output_chunk.data = stdout_and_chunk[0] | 66 request.output_chunk.data = stdout_and_chunk[0] |
| 78 request.output_chunk.offset = stdout_and_chunk[1] | 67 request.output_chunk.offset = stdout_and_chunk[1] |
| 79 if exit_code != None: | 68 if exit_code != None: |
| 80 request.exit_status.code = exit_code | 69 request.exit_status.code = exit_code |
| 81 | 70 |
| 82 # Insert everything else. Note that the b64-encoded strings in the dict | 71 # Insert everything else. Note that the b64-encoded strings in the dict |
| 83 # are automatically decoded by ParseDict. | 72 # are automatically decoded by ParseDict. |
| 84 google.protobuf.json_format.ParseDict(params, request) | 73 google.protobuf.json_format.ParseDict(params, request) |
| 85 | 74 |
| 86 # Perform update | 75 # Perform update |
| 87 response = call_grpc(self._stub.TaskUpdate, request) | 76 response = self._proxy.call_unary('TaskUpdate', request) |
| 88 logging.debug('post_task_update() = %s', request) | 77 logging.debug('post_task_update() = %s', request) |
| 89 if response.error: | 78 if response.error: |
| 90 raise InternalError(response.error) | 79 raise InternalError(response.error) |
| 91 return not response.must_stop | 80 return not response.must_stop |
| 92 | 81 |
| 93 def post_task_error(self, task_id, bot_id, message): | 82 def post_task_error(self, task_id, bot_id, message): |
| 94 request = swarming_bot_pb2.TaskErrorRequest() | 83 request = swarming_bot_pb2.TaskErrorRequest() |
| 95 request.bot_id = bot_id | 84 request.bot_id = bot_id |
| 96 request.task_id = task_id | 85 request.task_id = task_id |
| 97 request.msg = message | 86 request.msg = message |
| 98 logging.error('post_task_error() = %s', request) | 87 logging.error('post_task_error() = %s', request) |
| 99 | 88 |
| 100 response = call_grpc(self._stub.TaskError, request) | 89 response = self._proxy.call_unary('TaskError', request) |
| 101 return response.ok | 90 return response.ok |
| 102 | 91 |
| 103 def _attributes_json_to_proto(self, json_attr, msg): | 92 def _attributes_json_to_proto(self, json_attr, msg): |
| 104 msg.version = json_attr['version'] | 93 msg.version = json_attr['version'] |
| 105 for k, values in sorted(json_attr['dimensions'].iteritems()): | 94 for k, values in sorted(json_attr['dimensions'].iteritems()): |
| 106 pair = msg.dimensions.add() | 95 pair = msg.dimensions.add() |
| 107 pair.name = k | 96 pair.name = k |
| 108 pair.values.extend(values) | 97 pair.values.extend(values) |
| 109 create_state_proto(json_attr['state'], msg.state) | 98 create_state_proto(json_attr['state'], msg.state) |
| 110 | 99 |
| 111 def do_handshake(self, attributes): | 100 def do_handshake(self, attributes): |
| 112 request = swarming_bot_pb2.HandshakeRequest() | 101 request = swarming_bot_pb2.HandshakeRequest() |
| 113 self._attributes_json_to_proto(attributes, request.attributes) | 102 self._attributes_json_to_proto(attributes, request.attributes) |
| 114 response = call_grpc(self._stub.Handshake, request) | 103 response = self._proxy.call_unary('Handshake', request) |
| 115 resp = { | 104 resp = { |
| 116 'server_version': response.server_version, | 105 'server_version': response.server_version, |
| 117 'bot_version': response.bot_version, | 106 'bot_version': response.bot_version, |
| 118 'bot_group_cfg_version': response.bot_group_cfg_version, | 107 'bot_group_cfg_version': response.bot_group_cfg_version, |
| 119 'bot_group_cfg': { | 108 'bot_group_cfg': { |
| 120 'dimensions': { | 109 'dimensions': { |
| 121 d.name: d.values for d in response.bot_group_cfg.dimensions | 110 d.name: d.values for d in response.bot_group_cfg.dimensions |
| 122 }, | 111 }, |
| 123 }, | 112 }, |
| 124 } | 113 } |
| 125 logging.info('Completed handshake: %s', resp) | 114 logging.info('Completed handshake: %s', resp) |
| 126 return copy.deepcopy(resp) | 115 return copy.deepcopy(resp) |
| 127 | 116 |
| 128 def poll(self, attributes): | 117 def poll(self, attributes): |
| 129 request = swarming_bot_pb2.PollRequest() | 118 request = swarming_bot_pb2.PollRequest() |
| 130 self._attributes_json_to_proto(attributes, request.attributes) | 119 self._attributes_json_to_proto(attributes, request.attributes) |
| 131 # TODO(aludwin): gRPC-specific exception handling (raise PollError). | 120 # TODO(aludwin): gRPC-specific exception handling (raise PollError). |
| 132 response = call_grpc(self._stub.Poll, request) | 121 response = self._proxy.call_unary('Poll', request) |
| 133 | 122 |
| 134 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: | 123 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: |
| 135 return 'update', response.version | 124 return 'update', response.version |
| 136 | 125 |
| 137 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: | 126 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: |
| 138 if not self._log_is_asleep: | 127 if not self._log_is_asleep: |
| 139 logging.info('Going to sleep') | 128 logging.info('Going to sleep') |
| 140 self._log_is_asleep = True | 129 self._log_is_asleep = True |
| 141 return 'sleep', response.sleep_time | 130 return 'sleep', response.sleep_time |
| 142 | 131 |
| (...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 182 self._log_is_asleep = False | 171 self._log_is_asleep = False |
| 183 return 'run', copy.deepcopy(manifest) | 172 return 'run', copy.deepcopy(manifest) |
| 184 | 173 |
| 185 raise PollError('Unknown command in response: %s' % response) | 174 raise PollError('Unknown command in response: %s' % response) |
| 186 | 175 |
| 187 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): | 176 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): |
| 188 # TODO(aludwin): exception handling, pass bot_id | 177 # TODO(aludwin): exception handling, pass bot_id |
| 189 logging.info('Updating to version: %s', bot_version) | 178 logging.info('Updating to version: %s', bot_version) |
| 190 request = swarming_bot_pb2.BotUpdateRequest() | 179 request = swarming_bot_pb2.BotUpdateRequest() |
| 191 request.bot_version = bot_version | 180 request.bot_version = bot_version |
| 192 response = call_grpc(self._stub.BotUpdate, request) | 181 response = self._proxy.call_unary('BotUpdate', request) |
| 193 with open(new_zip_fn, 'wb') as f: | 182 with open(new_zip_fn, 'wb') as f: |
| 194 f.write(response.bot_code) | 183 f.write(response.bot_code) |
| 195 | 184 |
| 196 def ping(self): | 185 def ping(self): |
| 197 pass | 186 pass |
| 198 | 187 |
| 199 def mint_oauth_token(self, task_id, bot_id, account_id, scopes): | 188 def mint_oauth_token(self, task_id, bot_id, account_id, scopes): |
| 200 # pylint: disable=unused-argument | 189 # pylint: disable=unused-argument |
| 201 raise MintOAuthTokenError( | 190 raise MintOAuthTokenError( |
| 202 'mint_oauth_token is not supported in grpc protocol') | 191 'mint_oauth_token is not supported in grpc protocol') |
| (...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 237 | 226 |
| 238 The keyname for the message field is passed in to simplify the creation | 227 The keyname for the message field is passed in to simplify the creation |
| 239 of the submessage in the first place - you need to say getattr and not | 228 of the submessage in the first place - you need to say getattr and not |
| 240 simply refer to message.keyname since the former actually creates the | 229 simply refer to message.keyname since the former actually creates the |
| 241 submessage while the latter does not. | 230 submessage while the latter does not. |
| 242 """ | 231 """ |
| 243 sub_msg = getattr(message, keyname) | 232 sub_msg = getattr(message, keyname) |
| 244 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) | 233 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) |
| 245 | 234 |
| 246 | 235 |
| 247 def call_grpc(method, request): | |
| 248 """Retries a command a set number of times""" | |
| 249 for attempt in range(1, MAX_GRPC_ATTEMPTS+1): | |
| 250 try: | |
| 251 return method(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | |
| 252 except grpc.RpcError as g: | |
| 253 if g.code() is not grpc.StatusCode.UNAVAILABLE: | |
| 254 raise | |
| 255 logging.warning('call_grpc - proxy is unavailable (attempt %d/%d)', | |
| 256 attempt, MAX_GRPC_ATTEMPTS) | |
| 257 grpc_error = g | |
| 258 time.sleep(net.calculate_sleep_before_retry(attempt, MAX_GRPC_SLEEP)) | |
| 259 # If we get here, it must be because we got (and saved) an error | |
| 260 assert grpc_error is not None | |
| 261 raise grpc_error | |
| OLD | NEW |