Index: tools/telemetry/telemetry/core/forwarders/do_nothing_forwarder.py |
diff --git a/tools/telemetry/telemetry/core/forwarders/do_nothing_forwarder.py b/tools/telemetry/telemetry/core/forwarders/do_nothing_forwarder.py |
index cf94ae9ab73569d7ddeeac1b93eb12c62f091895..0542df5cb952208f07ef79484502076df1b98c61 100644 |
--- a/tools/telemetry/telemetry/core/forwarders/do_nothing_forwarder.py |
+++ b/tools/telemetry/telemetry/core/forwarders/do_nothing_forwarder.py |
@@ -2,6 +2,7 @@ |
# Use of this source code is governed by a BSD-style license that can be |
# found in the LICENSE file. |
+import contextlib |
import logging |
import socket |
@@ -9,6 +10,21 @@ from telemetry.core import forwarders |
from telemetry.core import util |
+class Error(Exception): |
+ """Base class for exceptions in this module.""" |
+ pass |
+ |
+ |
+class PortsMismatchError(Error): |
+ """Raised when local and remote ports are not equal.""" |
+ pass |
+ |
+ |
+class ConnectionError(Error): |
+ """Raised when unable to connect to local TCP ports.""" |
+ pass |
+ |
+ |
class DoNothingForwarderFactory(forwarders.ForwarderFactory): |
def Create(self, port_pairs): |
@@ -16,16 +32,44 @@ class DoNothingForwarderFactory(forwarders.ForwarderFactory): |
class DoNothingForwarder(forwarders.Forwarder): |
+ """Check that no forwarding is needed for the given port pairs. |
+ |
+ The local and remote ports must be equal. Otherwise, the "do nothing" |
+ forwarder does not make sense. (Raises PortsMismatchError.) |
+ |
+ Also, check that all TCP ports support connections. (Raises ConnectionError.) |
+ """ |
def __init__(self, port_pairs): |
super(DoNothingForwarder, self).__init__(port_pairs) |
+ self._CheckPortPairs() |
- for port_pair in port_pairs: |
+ def _CheckPortPairs(self): |
+ # namedtuple._asdict() is a public method. The method starts with an |
+ # underscore to avoid conflicts with attribute names. pylint: disable=W0212 |
+ for protocol, port_pair in self._port_pairs._asdict().items(): |
if not port_pair: |
continue |
local_port, remote_port = port_pair |
- assert local_port == remote_port, 'Local port forwarding is not supported' |
- def IsStarted(): |
- return not socket.socket().connect_ex((self.host_ip, self.host_port)) |
- util.WaitFor(IsStarted, 10) |
- logging.debug('Server started on %s:%d' % (self.host_ip, self.host_port)) |
+ if local_port != remote_port: |
+ raise PortsMismatchError('Local port forwarding is not supported') |
+ if protocol == 'dns': |
+ logging.debug('Connection test SKIPPED for DNS: %s:%d', |
+ self.host_ip, local_port) |
+ continue |
+ try: |
+ self._WaitForConnectionEstablished( |
+ (self.host_ip, local_port), timeout=10) |
+ logging.debug( |
+ 'Connection test succeeded for %s: %s:%d', |
+ protocol.upper(), self.host_ip, local_port) |
+ except util.TimeoutException: |
+ raise ConnectionError( |
+ 'Unable to connect to %s address: %s:%d', |
+ protocol.upper(), self.host_ip, local_port) |
+ |
+ def _WaitForConnectionEstablished(self, address, timeout): |
+ def CanConnect(): |
+ with contextlib.closing(socket.socket()) as s: |
+ return s.connect_ex(address) == 0 |
+ util.WaitFor(CanConnect, timeout) |