OLD | NEW |
1 # Copyright 2016 The LUCI Authors. All rights reserved. | 1 # Copyright 2016 The LUCI Authors. All rights reserved. |
2 # Use of this source code is governed under the Apache License, Version 2.0 | 2 # Use of this source code is governed under the Apache License, Version 2.0 |
3 # that can be found in the LICENSE file. | 3 # that can be found in the LICENSE file. |
4 | 4 |
5 # This is a reimplementation of RemoteClientNative but it uses (will use) | 5 # This is a reimplementation of RemoteClientNative but it uses (will use) |
6 # a gRPC method to communicate with a server instead of REST. | 6 # a gRPC method to communicate with a server instead of REST. |
7 | 7 |
8 import json | 8 import json |
9 import logging | 9 import logging |
| 10 import math |
| 11 import random |
| 12 import time |
10 | 13 |
11 import grpc | 14 import grpc |
12 import google.protobuf.json_format | 15 import google.protobuf.json_format |
13 from proto_bot import swarming_bot_pb2 | 16 from proto_bot import swarming_bot_pb2 |
14 from remote_client_errors import InternalError | 17 from remote_client_errors import InternalError |
15 from remote_client_errors import PollError | 18 from remote_client_errors import PollError |
| 19 from utils import net |
16 | 20 |
17 | 21 |
18 # How long to wait for a response from the server. Keeping the same as | 22 # How long to wait for a response from the server. Keeping the same as |
19 # the equivalent in remote_client.py for now. | 23 # the equivalent in remote_client.py for now. |
20 NET_CONNECTION_TIMEOUT_SEC = 5*60 | 24 NET_CONNECTION_TIMEOUT_SEC = 5*60 |
21 | 25 |
22 | 26 |
| 27 # How many times to retry a gRPC call |
| 28 MAX_GRPC_ATTEMPTS = 30 |
| 29 |
| 30 |
| 31 # Longest time to sleep between gRPC calls |
| 32 MAX_GRPC_SLEEP = 10. |
| 33 |
| 34 |
23 class RemoteClientGrpc(object): | 35 class RemoteClientGrpc(object): |
24 """RemoteClientGrpc knows how to make calls via gRPC. | 36 """RemoteClientGrpc knows how to make calls via gRPC. |
25 """ | 37 """ |
26 | 38 |
27 def __init__(self, server): | 39 def __init__(self, server): |
28 logging.info('Communicating with host %s via gRPC', server) | 40 logging.info('Communicating with host %s via gRPC', server) |
29 self._server = server | 41 self._server = server |
30 self._channel = grpc.insecure_channel(server) | 42 self._channel = grpc.insecure_channel(server) |
31 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) | 43 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) |
32 self._log_is_asleep = False | 44 self._log_is_asleep = False |
(...skipping 28 matching lines...) Expand all Loading... |
61 request.output_chunk.data = stdout_and_chunk[0] | 73 request.output_chunk.data = stdout_and_chunk[0] |
62 request.output_chunk.offset = stdout_and_chunk[1] | 74 request.output_chunk.offset = stdout_and_chunk[1] |
63 if exit_code != None: | 75 if exit_code != None: |
64 request.exit_status.code = exit_code | 76 request.exit_status.code = exit_code |
65 | 77 |
66 # Insert everything else. Note that the b64-encoded strings in the dict | 78 # Insert everything else. Note that the b64-encoded strings in the dict |
67 # are automatically decoded by ParseDict. | 79 # are automatically decoded by ParseDict. |
68 google.protobuf.json_format.ParseDict(params, request) | 80 google.protobuf.json_format.ParseDict(params, request) |
69 | 81 |
70 # Perform update | 82 # Perform update |
71 response = self._stub.TaskUpdate(request, | 83 response = call_grpc(self._stub.TaskUpdate, request) |
72 timeout=NET_CONNECTION_TIMEOUT_SEC) | |
73 logging.debug('post_task_update() = %s', request) | 84 logging.debug('post_task_update() = %s', request) |
74 if response.error: | 85 if response.error: |
75 raise InternalError(response.error) | 86 raise InternalError(response.error) |
76 return not response.must_stop | 87 return not response.must_stop |
77 | 88 |
78 def post_task_error(self, task_id, bot_id, message): | 89 def post_task_error(self, task_id, bot_id, message): |
79 request = swarming_bot_pb2.TaskErrorRequest() | 90 request = swarming_bot_pb2.TaskErrorRequest() |
80 request.bot_id = bot_id | 91 request.bot_id = bot_id |
81 request.task_id = task_id | 92 request.task_id = task_id |
82 request.msg = message | 93 request.msg = message |
83 logging.error('post_task_error() = %s', request) | 94 logging.error('post_task_error() = %s', request) |
84 | 95 |
85 response = self._stub.TaskError(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 96 response = call_grpc(self._stub.TaskError, request) |
86 return response.ok | 97 return response.ok |
87 | 98 |
88 def _attributes_json_to_proto(self, json_attr, msg): | 99 def _attributes_json_to_proto(self, json_attr, msg): |
89 msg.version = json_attr['version'] | 100 msg.version = json_attr['version'] |
90 for k, values in sorted(json_attr['dimensions'].iteritems()): | 101 for k, values in sorted(json_attr['dimensions'].iteritems()): |
91 pair = msg.dimensions.add() | 102 pair = msg.dimensions.add() |
92 pair.name = k | 103 pair.name = k |
93 pair.values.extend(values) | 104 pair.values.extend(values) |
94 create_state_proto(json_attr['state'], msg.state) | 105 create_state_proto(json_attr['state'], msg.state) |
95 | 106 |
96 def do_handshake(self, attributes): | 107 def do_handshake(self, attributes): |
97 request = swarming_bot_pb2.HandshakeRequest() | 108 request = swarming_bot_pb2.HandshakeRequest() |
98 self._attributes_json_to_proto(attributes, request.attributes) | 109 self._attributes_json_to_proto(attributes, request.attributes) |
99 response = self._stub.Handshake(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 110 response = call_grpc(self._stub.Handshake, request) |
100 resp = { | 111 resp = { |
101 'server_version': response.server_version, | 112 'server_version': response.server_version, |
102 'bot_version': response.bot_version, | 113 'bot_version': response.bot_version, |
103 'bot_group_cfg_version': response.bot_group_cfg_version, | 114 'bot_group_cfg_version': response.bot_group_cfg_version, |
104 'bot_group_cfg': { | 115 'bot_group_cfg': { |
105 'dimensions': { | 116 'dimensions': { |
106 d.name: d.values for d in response.bot_group_cfg.dimensions | 117 d.name: d.values for d in response.bot_group_cfg.dimensions |
107 }, | 118 }, |
108 }, | 119 }, |
109 } | 120 } |
110 logging.info('Completed handshake: %s', resp) | 121 logging.info('Completed handshake: %s', resp) |
111 return resp | 122 return resp |
112 | 123 |
113 def poll(self, attributes): | 124 def poll(self, attributes): |
114 request = swarming_bot_pb2.PollRequest() | 125 request = swarming_bot_pb2.PollRequest() |
115 self._attributes_json_to_proto(attributes, request.attributes) | 126 self._attributes_json_to_proto(attributes, request.attributes) |
116 # TODO(aludwin): gRPC-specific exception handling (raise PollError). | 127 # TODO(aludwin): gRPC-specific exception handling (raise PollError). |
117 response = self._stub.Poll(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 128 response = call_grpc(self._stub.Poll, request) |
118 | 129 |
119 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: | 130 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: |
120 return 'update', response.version | 131 return 'update', response.version |
121 | 132 |
122 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: | 133 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: |
123 if not self._log_is_asleep: | 134 if not self._log_is_asleep: |
124 logging.info('Going to sleep') | 135 logging.info('Going to sleep') |
125 self._log_is_asleep = True | 136 self._log_is_asleep = True |
126 return 'sleep', response.sleep_time | 137 return 'sleep', response.sleep_time |
127 | 138 |
(...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
167 self._log_is_asleep = False | 178 self._log_is_asleep = False |
168 return 'run', manifest | 179 return 'run', manifest |
169 | 180 |
170 raise PollError('Unknown command in response: %s' % response) | 181 raise PollError('Unknown command in response: %s' % response) |
171 | 182 |
172 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): | 183 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): |
173 # TODO(aludwin): exception handling, pass bot_id | 184 # TODO(aludwin): exception handling, pass bot_id |
174 logging.info('Updating to version: %s', bot_version) | 185 logging.info('Updating to version: %s', bot_version) |
175 request = swarming_bot_pb2.BotUpdateRequest() | 186 request = swarming_bot_pb2.BotUpdateRequest() |
176 request.bot_version = bot_version | 187 request.bot_version = bot_version |
177 response = self._stub.BotUpdate(request, timeout=NET_CONNECTION_TIMEOUT_SEC) | 188 response = call_grpc(self._stub.BotUpdate, request) |
178 with open(new_zip_fn, 'wb') as f: | 189 with open(new_zip_fn, 'wb') as f: |
179 f.write(response.bot_code) | 190 f.write(response.bot_code) |
180 | 191 |
181 def ping(self): | 192 def ping(self): |
182 pass | 193 pass |
183 | 194 |
184 | 195 |
185 def create_state_proto(state_dict, message): | 196 def create_state_proto(state_dict, message): |
186 """ Constructs a State message out of a state dict. | 197 """ Constructs a State message out of a state dict. |
187 | 198 |
(...skipping 27 matching lines...) Expand all Loading... |
215 def insert_dict_as_submessage(message, keyname, value): | 226 def insert_dict_as_submessage(message, keyname, value): |
216 """Encodes a dict as a Protobuf message. | 227 """Encodes a dict as a Protobuf message. |
217 | 228 |
218 The keyname for the message field is passed in to simplify the creation | 229 The keyname for the message field is passed in to simplify the creation |
219 of the submessage in the first place - you need to say getattr and not | 230 of the submessage in the first place - you need to say getattr and not |
220 simply refer to message.keyname since the former actually creates the | 231 simply refer to message.keyname since the former actually creates the |
221 submessage while the latter does not. | 232 submessage while the latter does not. |
222 """ | 233 """ |
223 sub_msg = getattr(message, keyname) | 234 sub_msg = getattr(message, keyname) |
224 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) | 235 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) |
| 236 |
| 237 |
| 238 def call_grpc(method, request): |
| 239 """Retries a command a set number of times""" |
| 240 for attempt in range(1, MAX_GRPC_ATTEMPTS+1): |
| 241 try: |
| 242 return method(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
| 243 except grpc.RpcError as g: |
| 244 if g.code() is not grpc.StatusCode.UNAVAILABLE: |
| 245 raise |
| 246 logging.warning('call_grpc - proxy is unavailable (attempt %d/%d)', |
| 247 attempt, MAX_GRPC_ATTEMPTS) |
| 248 grpc_error = g |
| 249 time.sleep(net.calculate_sleep_before_retry(attempt, MAX_GRPC_SLEEP)) |
| 250 # If we get here, it must be because we got (and saved) an error |
| 251 assert grpc_error is not None |
| 252 raise grpc_error |
OLD | NEW |