Chromium Code Reviews

| OLD | NEW |
|---|---|
| 1 # Copyright 2016 The LUCI Authors. All rights reserved. | 1 # Copyright 2016 The LUCI Authors. All rights reserved. |
| 2 # Use of this source code is governed under the Apache License, Version 2.0 | 2 # Use of this source code is governed under the Apache License, Version 2.0 |
| 3 # that can be found in the LICENSE file. | 3 # that can be found in the LICENSE file. |
| 4 | 4 |
| 5 # This is a reimplementation of RemoteClientNative but it uses (will use) | 5 # This is a reimplementation of RemoteClientNative but it uses (will use) |
| 6 # a gRPC method to communicate with a server instead of REST. | 6 # a gRPC method to communicate with a server instead of REST. |
| 7 | 7 |
| 8 import json | 8 import json |
| 9 import logging | 9 import logging |
| | 10 import time |
| 10 | 11 |
| 11 import grpc | 12 import grpc |
| 12 import google.protobuf.json_format | 13 import google.protobuf.json_format |
| 13 from proto_bot import swarming_bot_pb2 | 14 from proto_bot import swarming_bot_pb2 |
| 14 from remote_client_errors import InternalError | 15 from remote_client_errors import InternalError |
| 15 | 16 |
| 16 | 17 |
| 17 # How long to wait for a response from the server. Keeping the same as | 18 # How long to wait for a response from the server. Keeping the same as |
| 18 # the equivalent in remote_client.py for now. | 19 # the equivalent in remote_client.py for now. |
| 19 NET_CONNECTION_TIMEOUT_SEC = 5*60 | 20 NET_CONNECTION_TIMEOUT_SEC = 5*60 |
| 20 | 21 |
| 21 | 22 |
| | 23 # How many times to retry a gRPC call |
| | 24 MAX_GRPC_ATTEMPTS = 30 |
| | 25 |
| | 26 |
| 22 class RemoteClientGrpc(object): | 27 class RemoteClientGrpc(object): |
| 23 """RemoteClientGrpc knows how to make calls via gRPC. | 28 """RemoteClientGrpc knows how to make calls via gRPC. |
| 24 """ | 29 """ |
| 25 | 30 |
| 26 def __init__(self, server): | 31 def __init__(self, server): |
| 27 logging.info('Communicating with host %s via gRPC', server) | 32 logging.info('Communicating with host %s via gRPC', server) |
| 28 self._server = server | 33 self._server = server |
| 29 self._channel = grpc.insecure_channel(server) | 34 self._channel = grpc.insecure_channel(server) |
| 30 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) | 35 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) |
| 31 self._log_is_asleep = False | 36 self._log_is_asleep = False |
| (...skipping 28 matching lines...) Expand all Loading... | |
| 60 request.output_chunk.data = stdout_and_chunk[0] | 65 request.output_chunk.data = stdout_and_chunk[0] |
| 61 request.output_chunk.offset = stdout_and_chunk[1] | 66 request.output_chunk.offset = stdout_and_chunk[1] |
| 62 if exit_code != None: | 67 if exit_code != None: |
| 63 request.exit_status.code = exit_code | 68 request.exit_status.code = exit_code |
| 64 | 69 |
| 65 # Insert everything else. Note that the b64-encoded strings in the dict | 70 # Insert everything else. Note that the b64-encoded strings in the dict |
| 66 # are automatically decoded by ParseDict. | 71 # are automatically decoded by ParseDict. |
| 67 google.protobuf.json_format.ParseDict(params, request) | 72 google.protobuf.json_format.ParseDict(params, request) |
| 68 | 73 |
| 69 # Perform update | 74 # Perform update |
| 70 response = self._stub.TaskUpdate(request, | 75 response = call_grpc(self._stub.TaskUpdate, request) |
| 71 timeout=NET_CONNECTION_TIMEOUT_SEC) | |
| 72 logging.debug('post_task_update() = %s', request) | 76 logging.debug('post_task_update() = %s', request) |
| 73 if response.error: | 77 if response.error: |
| 74 raise InternalError(response.error) | 78 raise InternalError(response.error) |
| 75 return not response.must_stop | 79 return not response.must_stop |
| 76 | 80 |
| 77 def post_task_error(self, task_id, bot_id, message): | 81 def post_task_error(self, task_id, bot_id, message): |
| 78 request = swarming_bot_pb2.TaskErrorRequest() | 82 request = swarming_bot_pb2.TaskErrorRequest() |
| 79 request.bot_id = bot_id | 83 request.bot_id = bot_id |
| 80 request.task_id = task_id | 84 request.task_id = task_id |
| 81 request.msg = message | 85 request.msg = message |
| 82 logging.error('post_task_error() = %s', request) | 86 logging.error('post_task_error() = %s', request) |
| 83 | 87 |
| 84 response = self._stub.TaskError(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 88 response = call_grpc(self._stub.TaskError, request) |
| 85 return response.ok | 89 return response.ok |
| 86 | 90 |
| 87 def _attributes_json_to_proto(self, json_attr, msg): | 91 def _attributes_json_to_proto(self, json_attr, msg): |
| 88 msg.version = json_attr['version'] | 92 msg.version = json_attr['version'] |
| 89 for k, values in sorted(json_attr['dimensions'].iteritems()): | 93 for k, values in sorted(json_attr['dimensions'].iteritems()): |
| 90 pair = msg.dimensions.add() | 94 pair = msg.dimensions.add() |
| 91 pair.name = k | 95 pair.name = k |
| 92 pair.values.extend(values) | 96 pair.values.extend(values) |
| 93 create_state_proto(json_attr['state'], msg.state) | 97 create_state_proto(json_attr['state'], msg.state) |
| 94 | 98 |
| 95 def do_handshake(self, attributes): | 99 def do_handshake(self, attributes): |
| 96 request = swarming_bot_pb2.HandshakeRequest() | 100 request = swarming_bot_pb2.HandshakeRequest() |
| 97 self._attributes_json_to_proto(attributes, request.attributes) | 101 self._attributes_json_to_proto(attributes, request.attributes) |
| 98 response = self._stub.Handshake(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 102 response = call_grpc(self._stub.Handshake, request) |
| 99 resp = { | 103 resp = { |
| 100 'server_version': response.server_version, | 104 'server_version': response.server_version, |
| 101 'bot_version': response.bot_version, | 105 'bot_version': response.bot_version, |
| 102 'bot_group_cfg_version': response.bot_group_cfg_version, | 106 'bot_group_cfg_version': response.bot_group_cfg_version, |
| 103 'bot_group_cfg': { | 107 'bot_group_cfg': { |
| 104 'dimensions': { | 108 'dimensions': { |
| 105 d.name: d.values for d in response.bot_group_cfg.dimensions | 109 d.name: d.values for d in response.bot_group_cfg.dimensions |
| 106 }, | 110 }, |
| 107 }, | 111 }, |
| 108 } | 112 } |
| 109 logging.info('Completed handshake: %s', resp) | 113 logging.info('Completed handshake: %s', resp) |
| 110 return resp | 114 return resp |
| 111 | 115 |
| 112 def poll(self, attributes): | 116 def poll(self, attributes): |
| 113 request = swarming_bot_pb2.PollRequest() | 117 request = swarming_bot_pb2.PollRequest() |
| 114 self._attributes_json_to_proto(attributes, request.attributes) | 118 self._attributes_json_to_proto(attributes, request.attributes) |
| 115 # TODO(aludwin): gRPC-specific exception handling | 119 # TODO(aludwin): gRPC-specific exception handling |
| 116 response = self._stub.Poll(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 120 response = call_grpc(self._stub.Poll, request) |
| 117 | 121 |
| 118 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: | 122 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: |
| 119 return 'update', response.version | 123 return 'update', response.version |
| 120 | 124 |
| 121 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: | 125 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: |
| 122 if not self._log_is_asleep: | 126 if not self._log_is_asleep: |
| 123 logging.info('Going to sleep') | 127 logging.info('Going to sleep') |
| 124 self._log_is_asleep = True | 128 self._log_is_asleep = True |
| 125 return 'sleep', response.sleep_time | 129 return 'sleep', response.sleep_time |
| 126 | 130 |
| (...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
| 166 self._log_is_asleep = False | 170 self._log_is_asleep = False |
| 167 return 'run', manifest | 171 return 'run', manifest |
| 168 | 172 |
| 169 raise ValueError('Unknown command in response: %s' % response) | 173 raise ValueError('Unknown command in response: %s' % response) |
| 170 | 174 |
| 171 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): | 175 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): |
| 172 # TODO(aludwin): exception handling, pass bot_id | 176 # TODO(aludwin): exception handling, pass bot_id |
| 173 logging.info('Updating to version: %s', bot_version) | 177 logging.info('Updating to version: %s', bot_version) |
| 174 request = swarming_bot_pb2.BotUpdateRequest() | 178 request = swarming_bot_pb2.BotUpdateRequest() |
| 175 request.bot_version = bot_version | 179 request.bot_version = bot_version |
| 176 response = self._stub.BotUpdate(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 180 response = call_grpc(self._stub.BotUpdate, request) |
| 177 with open(new_zip_fn, 'wb') as f: | 181 with open(new_zip_fn, 'wb') as f: |
| 178 f.write(response.bot_code) | 182 f.write(response.bot_code) |
| 179 | 183 |
| 180 def ping(self): | 184 def ping(self): |
| 181 pass | 185 pass |
| 182 | 186 |
| 183 | 187 |
| 184 def create_state_proto(state_dict, message): | 188 def create_state_proto(state_dict, message): |
| 185 """ Constructs a State message out of a state dict. | 189 """ Constructs a State message out of a state dict. |
| 186 | 190 |
| (...skipping 27 matching lines...) Expand all Loading... | |
| 214 def insert_dict_as_submessage(message, keyname, value): | 218 def insert_dict_as_submessage(message, keyname, value): |
| 215 """Encodes a dict as a Protobuf message. | 219 """Encodes a dict as a Protobuf message. |
| 216 | 220 |
| 217 The keyname for the message field is passed in to simplify the creation | 221 The keyname for the message field is passed in to simplify the creation |
| 218 of the submessage in the first place - you need to say getattr and not | 222 of the submessage in the first place - you need to say getattr and not |
| 219 simply refer to message.keyname since the former actually creates the | 223 simply refer to message.keyname since the former actually creates the |
| 220 submessage while the latter does not. | 224 submessage while the latter does not. |
| 221 """ | 225 """ |
| 222 sub_msg = getattr(message, keyname) | 226 sub_msg = getattr(message, keyname) |
| 223 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) | 227 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) |
| | 228 |
| | 229 |
| | 230 def call_grpc(method, request): |
| | 231 """Retries a command a set number of times""" |
| | 232 num_attempts = 0 |
|
M-A Ruel
2016/12/20 15:05:59
for num_attempts in xrange(MAX_GRPC_ATTEMPTS):
aludwin
2016/12/20 19:51:44
Hmm, the reason I made it a for loop in the first
M-A Ruel
2016/12/20 19:55:36
I don't mind much.
| |
| | 233 while True: |
| | 234 try: |
| | 235 num_attempts += 1 |
| | 236 response = method(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
| | 237 if num_attempts > 1: |
| | 238 logging.warning('call_grpc succeeded after %d attempts', num_attempts) |
|
M-A Ruel
2016/12/20 15:05:59
IMHO it's not needed since we can infer from the o
aludwin
2016/12/20 16:42:35
I thought it was nice to see but I agree it's tech
| |
| | 239 return response |
| | 240 except grpc.RpcError as rpc_error: |
| | 241 logging.warning('call_grpc - gRPC error: %s', str(rpc_error)) |
|
M-A Ruel
2016/12/20 15:05:59
str() is not needed
aludwin
2016/12/20 16:42:35
Done.
| |
| | 242 if (rpc_error.code() is grpc.StatusCode.UNAVAILABLE |
| | 243 and num_attempts < MAX_GRPC_ATTEMPTS): |
| | 244 logging.warning('Swallowing UNAVAILABLE error (attempt %d/%d)', |
| | 245 num_attempts, MAX_GRPC_ATTEMPTS) |
| | 246 time.sleep(1) |
|
M-A Ruel
2016/12/20 15:05:59
You need exponential backoff.
e.g. https://en.wiki
aludwin
2016/12/20 16:42:35
What would you recommend as the initial, maximum a
M-A Ruel
2016/12/20 17:01:45
Here's an idea:
https://github.com/luci/luci-py/bl
| |
| | 247 else: |
| | 248 logging.error('Cannot recover from gRPC error; propagating') |
| | 249 raise |
| | 250 except Exception as e: |
|
M-A Ruel
2016/12/20 15:05:59
I don't think this is needed.
aludwin
2016/12/20 16:42:35
Which part? I just want to catch it so I can log i
M-A Ruel
2016/12/20 17:01:45
Whatever catches the unrelated exception will like
| |
| | 251 logging.error('call_grpc - non-gRPC error: %s', str(e)) |
| | 252 raise |
| OLD | NEW |