Index: chrome/test/chromedriver/test/webserver.py |
diff --git a/chrome/test/chromedriver/test/webserver.py b/chrome/test/chromedriver/test/webserver.py |
index 716e0f4a9b47d96cf4775667e6e9ff4df580952f..9de03395b9f744cee3fbaf7f4f743f00fb1735fc 100644 |
--- a/chrome/test/chromedriver/test/webserver.py |
+++ b/chrome/test/chromedriver/test/webserver.py |
@@ -48,6 +48,9 @@ class Request(object): |
def GetPath(self): |
return self._handler.path |
+ def GetHeader(self, name): |
+ return self._handler.headers.getheader(name) |
+ |
class _BaseServer(BaseHTTPServer.HTTPServer): |
"""Internal server that throws if timed out waiting for a request.""" |
@@ -120,19 +123,28 @@ class WebServer(object): |
self._thread.daemon = True |
self._thread.start() |
self._path_data_map = {} |
- self._path_data_lock = threading.Lock() |
+ self._path_callback_map = {} |
+ self._path_maps_lock = threading.Lock() |
def _OnRequest(self, request, responder): |
path = request.GetPath().split('?')[0] |
- # Serve from path -> data map. |
- self._path_data_lock.acquire() |
+ # Serve from path -> callback and data maps. |
+ self._path_maps_lock.acquire() |
try: |
+ if path in self._path_callback_map: |
+ body = self._path_callback_map[path](request) |
+ if body: |
+ responder.SendResponse(body) |
+ else: |
+ responder.SendError(503) |
+ return |
+ |
if path in self._path_data_map: |
responder.SendResponse(self._path_data_map[path]) |
return |
finally: |
- self._path_data_lock.release() |
+ self._path_maps_lock.release() |
# Serve from file. |
path = os.path.normpath( |
@@ -146,11 +158,19 @@ class WebServer(object): |
responder.SendResponseFromFile(path) |
def SetDataForPath(self, path, data): |
- self._path_data_lock.acquire() |
+ self._path_maps_lock.acquire() |
try: |
self._path_data_map[path] = data |
finally: |
- self._path_data_lock.release() |
+ self._path_maps_lock.release() |
+ |
+ def SetCallbackForPath(self, path, func): |
+ self._path_maps_lock.acquire() |
+ try: |
+ self._path_callback_map[path] = func |
+ finally: |
+ self._path_maps_lock.release() |
+ |
def GetUrl(self): |
"""Returns the base URL of the server.""" |