| OLD | NEW |
| 1 # Copyright 2016 The LUCI Authors. All rights reserved. | 1 # Copyright 2016 The LUCI Authors. All rights reserved. |
| 2 # Use of this source code is governed under the Apache License, Version 2.0 | 2 # Use of this source code is governed under the Apache License, Version 2.0 |
| 3 # that can be found in the LICENSE file. | 3 # that can be found in the LICENSE file. |
| 4 | 4 |
| 5 # This is a reimplementation of RemoteClientNative but it uses (will use) | 5 # This is a reimplementation of RemoteClientNative but it uses (will use) |
| 6 # a gRPC method to communicate with a server instead of REST. | 6 # a gRPC method to communicate with a server instead of REST. |
| 7 | 7 |
| 8 import json | 8 import json |
| 9 import logging | 9 import logging |
| 10 import math |
| 11 import random |
| 12 import time |
| 10 | 13 |
| 11 import grpc | 14 import grpc |
| 12 import google.protobuf.json_format | 15 import google.protobuf.json_format |
| 13 from proto_bot import swarming_bot_pb2 | 16 from proto_bot import swarming_bot_pb2 |
| 14 from remote_client_errors import InternalError | 17 from remote_client_errors import InternalError |
| 15 from remote_client_errors import PollError | 18 from remote_client_errors import PollError |
| 19 from utils import net |
| 16 | 20 |
| 17 | 21 |
| 18 # How long to wait for a response from the server. Keeping the same as | 22 # How long to wait for a response from the server. Keeping the same as |
| 19 # the equivalent in remote_client.py for now. | 23 # the equivalent in remote_client.py for now. |
| 20 NET_CONNECTION_TIMEOUT_SEC = 5*60 | 24 NET_CONNECTION_TIMEOUT_SEC = 5*60 |
| 21 | 25 |
| 22 | 26 |
| 27 # How many times to retry a gRPC call |
| 28 MAX_GRPC_ATTEMPTS = 30 |
| 29 |
| 30 |
| 31 # Longest time to sleep between gRPC calls |
| 32 MAX_GRPC_SLEEP = 10. |
| 33 |
| 34 |
| 23 class RemoteClientGrpc(object): | 35 class RemoteClientGrpc(object): |
| 24 """RemoteClientGrpc knows how to make calls via gRPC. | 36 """RemoteClientGrpc knows how to make calls via gRPC. |
| 25 """ | 37 """ |
| 26 | 38 |
| 27 def __init__(self, server): | 39 def __init__(self, server): |
| 28 logging.info('Communicating with host %s via gRPC', server) | 40 logging.info('Communicating with host %s via gRPC', server) |
| 29 self._server = server | 41 self._server = server |
| 30 self._channel = grpc.insecure_channel(server) | 42 self._channel = grpc.insecure_channel(server) |
| 31 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) | 43 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) |
| 32 self._log_is_asleep = False | 44 self._log_is_asleep = False |
| (...skipping 28 matching lines...) Expand all Loading... |
| 61 request.output_chunk.data = stdout_and_chunk[0] | 73 request.output_chunk.data = stdout_and_chunk[0] |
| 62 request.output_chunk.offset = stdout_and_chunk[1] | 74 request.output_chunk.offset = stdout_and_chunk[1] |
| 63 if exit_code != None: | 75 if exit_code != None: |
| 64 request.exit_status.code = exit_code | 76 request.exit_status.code = exit_code |
| 65 | 77 |
| 66 # Insert everything else. Note that the b64-encoded strings in the dict | 78 # Insert everything else. Note that the b64-encoded strings in the dict |
| 67 # are automatically decoded by ParseDict. | 79 # are automatically decoded by ParseDict. |
| 68 google.protobuf.json_format.ParseDict(params, request) | 80 google.protobuf.json_format.ParseDict(params, request) |
| 69 | 81 |
| 70 # Perform update | 82 # Perform update |
| 71 response = self._stub.TaskUpdate(request, | 83 response = call_grpc(self._stub.TaskUpdate, request) |
| 72 timeout=NET_CONNECTION_TIMEOUT_SEC) | |
| 73 logging.debug('post_task_update() = %s', request) | 84 logging.debug('post_task_update() = %s', request) |
| 74 if response.error: | 85 if response.error: |
| 75 raise InternalError(response.error) | 86 raise InternalError(response.error) |
| 76 return not response.must_stop | 87 return not response.must_stop |
| 77 | 88 |
| 78 def post_task_error(self, task_id, bot_id, message): | 89 def post_task_error(self, task_id, bot_id, message): |
| 79 request = swarming_bot_pb2.TaskErrorRequest() | 90 request = swarming_bot_pb2.TaskErrorRequest() |
| 80 request.bot_id = bot_id | 91 request.bot_id = bot_id |
| 81 request.task_id = task_id | 92 request.task_id = task_id |
| 82 request.msg = message | 93 request.msg = message |
| 83 logging.error('post_task_error() = %s', request) | 94 logging.error('post_task_error() = %s', request) |
| 84 | 95 |
| 85 response = self._stub.TaskError(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 96 response = call_grpc(self._stub.TaskError, request) |
| 86 return response.ok | 97 return response.ok |
| 87 | 98 |
| 88 def _attributes_json_to_proto(self, json_attr, msg): | 99 def _attributes_json_to_proto(self, json_attr, msg): |
| 89 msg.version = json_attr['version'] | 100 msg.version = json_attr['version'] |
| 90 for k, values in sorted(json_attr['dimensions'].iteritems()): | 101 for k, values in sorted(json_attr['dimensions'].iteritems()): |
| 91 pair = msg.dimensions.add() | 102 pair = msg.dimensions.add() |
| 92 pair.name = k | 103 pair.name = k |
| 93 pair.values.extend(values) | 104 pair.values.extend(values) |
| 94 create_state_proto(json_attr['state'], msg.state) | 105 create_state_proto(json_attr['state'], msg.state) |
| 95 | 106 |
| 96 def do_handshake(self, attributes): | 107 def do_handshake(self, attributes): |
| 97 request = swarming_bot_pb2.HandshakeRequest() | 108 request = swarming_bot_pb2.HandshakeRequest() |
| 98 self._attributes_json_to_proto(attributes, request.attributes) | 109 self._attributes_json_to_proto(attributes, request.attributes) |
| 99 response = self._stub.Handshake(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 110 response = call_grpc(self._stub.Handshake, request) |
| 100 resp = { | 111 resp = { |
| 101 'server_version': response.server_version, | 112 'server_version': response.server_version, |
| 102 'bot_version': response.bot_version, | 113 'bot_version': response.bot_version, |
| 103 'bot_group_cfg_version': response.bot_group_cfg_version, | 114 'bot_group_cfg_version': response.bot_group_cfg_version, |
| 104 'bot_group_cfg': { | 115 'bot_group_cfg': { |
| 105 'dimensions': { | 116 'dimensions': { |
| 106 d.name: d.values for d in response.bot_group_cfg.dimensions | 117 d.name: d.values for d in response.bot_group_cfg.dimensions |
| 107 }, | 118 }, |
| 108 }, | 119 }, |
| 109 } | 120 } |
| 110 logging.info('Completed handshake: %s', resp) | 121 logging.info('Completed handshake: %s', resp) |
| 111 return resp | 122 return resp |
| 112 | 123 |
| 113 def poll(self, attributes): | 124 def poll(self, attributes): |
| 114 request = swarming_bot_pb2.PollRequest() | 125 request = swarming_bot_pb2.PollRequest() |
| 115 self._attributes_json_to_proto(attributes, request.attributes) | 126 self._attributes_json_to_proto(attributes, request.attributes) |
| 116 # TODO(aludwin): gRPC-specific exception handling (raise PollError). | 127 # TODO(aludwin): gRPC-specific exception handling (raise PollError). |
| 117 response = self._stub.Poll(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 128 response = call_grpc(self._stub.Poll, request) |
| 118 | 129 |
| 119 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: | 130 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: |
| 120 return 'update', response.version | 131 return 'update', response.version |
| 121 | 132 |
| 122 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: | 133 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: |
| 123 if not self._log_is_asleep: | 134 if not self._log_is_asleep: |
| 124 logging.info('Going to sleep') | 135 logging.info('Going to sleep') |
| 125 self._log_is_asleep = True | 136 self._log_is_asleep = True |
| 126 return 'sleep', response.sleep_time | 137 return 'sleep', response.sleep_time |
| 127 | 138 |
| (...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 167 self._log_is_asleep = False | 178 self._log_is_asleep = False |
| 168 return 'run', manifest | 179 return 'run', manifest |
| 169 | 180 |
| 170 raise PollError('Unknown command in response: %s' % response) | 181 raise PollError('Unknown command in response: %s' % response) |
| 171 | 182 |
| 172 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): | 183 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): |
| 173 # TODO(aludwin): exception handling, pass bot_id | 184 # TODO(aludwin): exception handling, pass bot_id |
| 174 logging.info('Updating to version: %s', bot_version) | 185 logging.info('Updating to version: %s', bot_version) |
| 175 request = swarming_bot_pb2.BotUpdateRequest() | 186 request = swarming_bot_pb2.BotUpdateRequest() |
| 176 request.bot_version = bot_version | 187 request.bot_version = bot_version |
| 177 response = self._stub.BotUpdate(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 188 response = call_grpc(self._stub.BotUpdate, request) |
| 178 with open(new_zip_fn, 'wb') as f: | 189 with open(new_zip_fn, 'wb') as f: |
| 179 f.write(response.bot_code) | 190 f.write(response.bot_code) |
| 180 | 191 |
| 181 def ping(self): | 192 def ping(self): |
| 182 pass | 193 pass |
| 183 | 194 |
| 184 | 195 |
| 185 def create_state_proto(state_dict, message): | 196 def create_state_proto(state_dict, message): |
| 186 """ Constructs a State message out of a state dict. | 197 """ Constructs a State message out of a state dict. |
| 187 | 198 |
| (...skipping 27 matching lines...) Expand all Loading... |
| 215 def insert_dict_as_submessage(message, keyname, value): | 226 def insert_dict_as_submessage(message, keyname, value): |
| 216 """Encodes a dict as a Protobuf message. | 227 """Encodes a dict as a Protobuf message. |
| 217 | 228 |
| 218 The keyname for the message field is passed in to simplify the creation | 229 The keyname for the message field is passed in to simplify the creation |
| 219 of the submessage in the first place - you need to say getattr and not | 230 of the submessage in the first place - you need to say getattr and not |
| 220 simply refer to message.keyname since the former actually creates the | 231 simply refer to message.keyname since the former actually creates the |
| 221 submessage while the latter does not. | 232 submessage while the latter does not. |
| 222 """ | 233 """ |
| 223 sub_msg = getattr(message, keyname) | 234 sub_msg = getattr(message, keyname) |
| 224 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) | 235 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) |
| 236 |
| 237 |
| 238 def call_grpc(method, request): |
| 239 """Retries a command a set number of times""" |
| 240 for attempt in range(1, MAX_GRPC_ATTEMPTS+1): |
| 241 try: |
| 242 return method(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
| 243 except grpc.RpcError as g: |
| 244 if g.code() is not grpc.StatusCode.UNAVAILABLE: |
| 245 raise |
| 246 logging.warning('call_grpc - proxy is unavailable (attempt %d/%d)', |
| 247 attempt, MAX_GRPC_ATTEMPTS) |
| 248 grpc_error = g |
| 249 time.sleep(net.calculate_sleep_before_retry(attempt, MAX_GRPC_SLEEP)) |
| 250 # If we get here, it must be because we got (and saved) an error |
| 251 assert grpc_error is not None |
| 252 raise grpc_error |
| OLD | NEW |