OLD | NEW |
1 # Copyright 2016 The LUCI Authors. All rights reserved. | 1 # Copyright 2016 The LUCI Authors. All rights reserved. |
2 # Use of this source code is governed under the Apache License, Version 2.0 | 2 # Use of this source code is governed under the Apache License, Version 2.0 |
3 # that can be found in the LICENSE file. | 3 # that can be found in the LICENSE file. |
4 | 4 |
5 # This is a reimplementation of RemoteClientNative but it uses (will use) | 5 # This is a reimplementation of RemoteClientNative but it uses (will use) |
6 # a gRPC method to communicate with a server instead of REST. | 6 # a gRPC method to communicate with a server instead of REST. |
7 | 7 |
8 import json | 8 import json |
9 import logging | 9 import logging |
10 | 10 |
11 import grpc | 11 import grpc |
12 import google.protobuf.json_format | 12 import google.protobuf.json_format |
13 from proto_bot import swarming_bot_pb2 | 13 from proto_bot import swarming_bot_pb2 |
14 from remote_client_errors import InternalError | 14 from remote_client_errors import InternalError |
15 | 15 |
16 | 16 |
| 17 # How long to wait for a response from the server. Keeping the same as |
| 18 # the equivalent in remote_client.py for now. |
| 19 NET_CONNECTION_TIMEOUT_SEC = 5*60 |
| 20 |
| 21 |
17 class RemoteClientGrpc(object): | 22 class RemoteClientGrpc(object): |
18 """RemoteClientGrpc knows how to make calls via gRPC. | 23 """RemoteClientGrpc knows how to make calls via gRPC. |
19 """ | 24 """ |
20 | 25 |
21 def __init__(self, server): | 26 def __init__(self, server): |
22 logging.info('Communicating with host %s via gRPC', server) | 27 logging.info('Communicating with host %s via gRPC', server) |
23 self._server = server | 28 self._server = server |
24 self._channel = grpc.insecure_channel(server) | 29 self._channel = grpc.insecure_channel(server) |
25 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) | 30 self._stub = swarming_bot_pb2.BotServiceStub(self._channel) |
26 self._log_is_asleep = False | 31 self._log_is_asleep = False |
(...skipping 28 matching lines...) Expand all Loading... |
55 request.output_chunk.data = stdout_and_chunk[0] | 60 request.output_chunk.data = stdout_and_chunk[0] |
56 request.output_chunk.offset = stdout_and_chunk[1] | 61 request.output_chunk.offset = stdout_and_chunk[1] |
57 if exit_code != None: | 62 if exit_code != None: |
58 request.exit_status.code = exit_code | 63 request.exit_status.code = exit_code |
59 | 64 |
60 # Insert everything else. Note that the b64-encoded strings in the dict | 65 # Insert everything else. Note that the b64-encoded strings in the dict |
61 # are automatically decoded by ParseDict. | 66 # are automatically decoded by ParseDict. |
62 google.protobuf.json_format.ParseDict(params, request) | 67 google.protobuf.json_format.ParseDict(params, request) |
63 | 68 |
64 # Perform update | 69 # Perform update |
65 response = self._stub.TaskUpdate(request) | 70 response = self._stub.TaskUpdate(request, |
| 71 timeout=NET_CONNECTION_TIMEOUT_SEC) |
66 logging.debug('post_task_update() = %s', request) | 72 logging.debug('post_task_update() = %s', request) |
67 if response.error: | 73 if response.error: |
68 raise InternalError(response.error) | 74 raise InternalError(response.error) |
69 return not response.must_stop | 75 return not response.must_stop |
70 | 76 |
71 def post_task_error(self, task_id, bot_id, message): | 77 def post_task_error(self, task_id, bot_id, message): |
72 request = swarming_bot_pb2.TaskErrorRequest() | 78 request = swarming_bot_pb2.TaskErrorRequest() |
73 request.bot_id = bot_id | 79 request.bot_id = bot_id |
74 request.task_id = task_id | 80 request.task_id = task_id |
75 request.msg = message | 81 request.msg = message |
76 logging.error('post_task_error() = %s', request) | 82 logging.error('post_task_error() = %s', request) |
77 | 83 |
78 response = self._stub.TaskError(request) | 84 response = self._stub.TaskError(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
79 return response.ok | 85 return response.ok |
80 | 86 |
81 def _attributes_json_to_proto(self, json_attr, msg): | 87 def _attributes_json_to_proto(self, json_attr, msg): |
82 msg.version = json_attr['version'] | 88 msg.version = json_attr['version'] |
83 for k, values in sorted(json_attr['dimensions'].iteritems()): | 89 for k, values in sorted(json_attr['dimensions'].iteritems()): |
84 pair = msg.dimensions.add() | 90 pair = msg.dimensions.add() |
85 pair.name = k | 91 pair.name = k |
86 pair.values.extend(values) | 92 pair.values.extend(values) |
87 create_state_proto(json_attr['state'], msg.state) | 93 create_state_proto(json_attr['state'], msg.state) |
88 | 94 |
89 def do_handshake(self, attributes): | 95 def do_handshake(self, attributes): |
90 request = swarming_bot_pb2.HandshakeRequest() | 96 request = swarming_bot_pb2.HandshakeRequest() |
91 self._attributes_json_to_proto(attributes, request.attributes) | 97 self._attributes_json_to_proto(attributes, request.attributes) |
92 response = self._stub.Handshake(request) | 98 response = self._stub.Handshake(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
93 resp = { | 99 resp = { |
94 'server_version': response.server_version, | 100 'server_version': response.server_version, |
95 'bot_version': response.bot_version, | 101 'bot_version': response.bot_version, |
96 'bot_group_cfg_version': response.bot_group_cfg_version, | 102 'bot_group_cfg_version': response.bot_group_cfg_version, |
97 'bot_group_cfg': { | 103 'bot_group_cfg': { |
98 'dimensions': { | 104 'dimensions': { |
99 d.name: d.values for d in response.bot_group_cfg.dimensions | 105 d.name: d.values for d in response.bot_group_cfg.dimensions |
100 }, | 106 }, |
101 }, | 107 }, |
102 } | 108 } |
103 logging.info('Completed handshake: %s', resp) | 109 logging.info('Completed handshake: %s', resp) |
104 return resp | 110 return resp |
105 | 111 |
106 def poll(self, attributes): | 112 def poll(self, attributes): |
107 request = swarming_bot_pb2.PollRequest() | 113 request = swarming_bot_pb2.PollRequest() |
108 self._attributes_json_to_proto(attributes, request.attributes) | 114 self._attributes_json_to_proto(attributes, request.attributes) |
109 # TODO(aludwin): gRPC-specific exception handling | 115 # TODO(aludwin): gRPC-specific exception handling |
110 response = self._stub.Poll(request) | 116 response = self._stub.Poll(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
111 | 117 |
112 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: | 118 if response.cmd == swarming_bot_pb2.PollResponse.UPDATE: |
113 return 'update', response.version | 119 return 'update', response.version |
114 | 120 |
115 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: | 121 if response.cmd == swarming_bot_pb2.PollResponse.SLEEP: |
116 if not self._log_is_asleep: | 122 if not self._log_is_asleep: |
117 logging.info('Going to sleep') | 123 logging.info('Going to sleep') |
118 self._log_is_asleep = True | 124 self._log_is_asleep = True |
119 return 'sleep', response.sleep_time | 125 return 'sleep', response.sleep_time |
120 | 126 |
(...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
160 self._log_is_asleep = False | 166 self._log_is_asleep = False |
161 return 'run', manifest | 167 return 'run', manifest |
162 | 168 |
163 raise ValueError('Unknown command in response: %s' % response) | 169 raise ValueError('Unknown command in response: %s' % response) |
164 | 170 |
165 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): | 171 def get_bot_code(self, new_zip_fn, bot_version, _bot_id): |
166 # TODO(aludwin): exception handling, pass bot_id | 172 # TODO(aludwin): exception handling, pass bot_id |
167 logging.info('Updating to version: %s', bot_version) | 173 logging.info('Updating to version: %s', bot_version) |
168 request = swarming_bot_pb2.BotUpdateRequest() | 174 request = swarming_bot_pb2.BotUpdateRequest() |
169 request.bot_version = bot_version | 175 request.bot_version = bot_version |
170 response = self._stub.BotUpdate(request) | 176 response = self._stub.BotUpdate(request, timeout=NET_CONNECTION_TIMEOUT_SEC) |
171 with open(new_zip_fn, 'wb') as f: | 177 with open(new_zip_fn, 'wb') as f: |
172 f.write(response.bot_code) | 178 f.write(response.bot_code) |
173 | 179 |
174 def ping(self): | 180 def ping(self): |
175 pass | 181 pass |
176 | 182 |
177 | 183 |
178 def create_state_proto(state_dict, message): | 184 def create_state_proto(state_dict, message): |
179 """ Constructs a State message out of a state dict. | 185 """ Constructs a State message out of a state dict. |
180 | 186 |
(...skipping 27 matching lines...) Expand all Loading... |
208 def insert_dict_as_submessage(message, keyname, value): | 214 def insert_dict_as_submessage(message, keyname, value): |
209 """Encodes a dict as a Protobuf message. | 215 """Encodes a dict as a Protobuf message. |
210 | 216 |
211 The keyname for the message field is passed in to simplify the creation | 217 The keyname for the message field is passed in to simplify the creation |
212 of the submessage in the first place - you need to say getattr and not | 218 of the submessage in the first place - you need to say getattr and not |
213 simply refer to message.keyname since the former actually creates the | 219 simply refer to message.keyname since the former actually creates the |
214 submessage while the latter does not. | 220 submessage while the latter does not. |
215 """ | 221 """ |
216 sub_msg = getattr(message, keyname) | 222 sub_msg = getattr(message, keyname) |
217 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) | 223 google.protobuf.json_format.Parse(json.dumps(value), sub_msg) |
OLD | NEW |