Index: tools/usb_gadget/server.py |
diff --git a/tools/usb_gadget/server.py b/tools/usb_gadget/server.py |
index 9f237264213fdd5b04db69571f6bd12a5b2f48a0..e5aa8b8e5d46d71aa80e9e76133bced15233b16d 100644 |
--- a/tools/usb_gadget/server.py |
+++ b/tools/usb_gadget/server.py |
@@ -5,10 +5,16 @@ |
"""WSGI application to manage a USB gadget. |
""" |
+import datetime |
+import hashlib |
import re |
+import subprocess |
import sys |
+import time |
+import urllib2 |
from tornado import httpserver |
+from tornado import ioloop |
from tornado import web |
import default_gadget |
@@ -20,6 +26,9 @@ chip = None |
claimed_by = None |
default = default_gadget.DefaultGadget() |
gadget = None |
+hardware = None |
+interface = None |
+port = None |
def SwitchGadget(new_gadget): |
@@ -45,6 +54,68 @@ class VersionHandler(web.RequestHandler): |
self.write(version) |
+class UpdateHandler(web.RequestHandler): |
+ |
+ def post(self): |
+ fileinfo = self.request.files['file'][0] |
+ |
+ match = VERSION_PATTERN.match(fileinfo['filename']) |
+ if match is None: |
+ self.write('Filename must contain MD5 hash.') |
+ self.set_status(400) |
+ return |
+ |
+ content = fileinfo['body'] |
+ md5sum = hashlib.md5(content).hexdigest() |
+ if md5sum != match.group(1): |
+ self.write('File hash does not match.') |
+ self.set_status(400) |
+ return |
+ |
+ filename = 'usb_gadget-{}.zip'.format(md5sum) |
+ with open(filename, 'wb') as f: |
+ f.write(content) |
+ |
+ args = ['/usr/bin/python', filename, |
+ '--interface', interface, |
+ '--port', str(port), |
+ '--hardware', hardware] |
+ if claimed_by is not None: |
+ args.extend(['--start-claimed', claimed_by]) |
+ |
+ print 'Reloading with version {}...'.format(md5sum) |
+ |
+ global http_server |
+ if chip.IsConfigured(): |
+ chip.Destroy() |
+ http_server.stop() |
+ |
+ child = subprocess.Popen(args, close_fds=True) |
+ |
+ while True: |
+ child.poll() |
+ if child.returncode is not None: |
+ self.write('New package exited with error {}.' |
+ .format(child.returncode)) |
+ self.set_status(500) |
+ |
+ http_server = httpserver.HTTPServer(app) |
+ http_server.listen(port) |
+ SwitchGadget(gadget) |
+ return |
+ |
+ try: |
+ f = urllib2.urlopen('http://{}/version'.format(address)) |
+ if f.getcode() == 200: |
+ # Update complete, wait 1 second to make sure buffers are flushed. |
+ io_loop = ioloop.IOLoop.instance() |
+ io_loop.add_timeout(datetime.timedelta(seconds=1), io_loop.stop) |
+ return |
+ except urllib2.URLError: |
+ pass |
+ time.sleep(0.1) |
+ |
+ |
class ClaimHandler(web.RequestHandler): |
def post(self): |
@@ -88,6 +159,7 @@ class ReconnectHandler(web.RequestHandler): |
app = web.Application([ |
(r'/version', VersionHandler), |
+ (r'/update', UpdateHandler), |
(r'/claim', ClaimHandler), |
(r'/unclaim', UnclaimHandler), |
(r'/unconfigure', UnconfigureHandler), |