]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Add DemuxConn and `bup mux` for client-server
[bup.git] / lib / bup / helpers.py
index 8a4da97056caf558a3d1aa39c0898f59329b1b4f..62d5b26c32d1a7dbf82a9007ecc51c489a8b748b 100644 (file)
@@ -1,6 +1,6 @@
 """Helper functions and classes for bup."""
 
-import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
+import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
 import heapq, operator
 from bup import _version
 
@@ -205,21 +205,22 @@ def resource_path(subdir=''):
 class NotOk(Exception):
     pass
 
-class Conn:
-    """A helper class for bup's client-server protocol."""
-    def __init__(self, inp, outp):
-        self.inp = inp
+class BaseConn:
+    def __init__(self, outp):
         self.outp = outp
 
+    def close(self):
+        while self._read(65536): pass
+
     def read(self, size):
         """Read 'size' bytes from input stream."""
         self.outp.flush()
-        return self.inp.read(size)
+        return self._read(size)
 
     def readline(self):
         """Read from input stream until a newline is found."""
         self.outp.flush()
-        return self.inp.readline()
+        return self._readline()
 
     def write(self, data):
         """Write 'data' to output stream."""
@@ -228,12 +229,7 @@ class Conn:
 
     def has_input(self):
         """Return true if input stream is readable."""
-        [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
-        if rl:
-            assert(rl[0] == self.inp.fileno())
-            return True
-        else:
-            return None
+        raise NotImplemented("Subclasses must implement has_input")
 
     def ok(self):
         """Indicate end of output from last sent command."""
@@ -247,7 +243,7 @@ class Conn:
     def _check_ok(self, onempty):
         self.outp.flush()
         rl = ''
-        for rl in linereader(self.inp):
+        for rl in linereader(self):
             #log('%d got line: %r\n' % (os.getpid(), rl))
             if not rl:  # empty line
                 continue
@@ -272,6 +268,139 @@ class Conn:
             raise Exception('expected "ok", got %r' % rl)
         return self._check_ok(onempty)
 
+class Conn(BaseConn):
+    def __init__(self, inp, outp):
+        BaseConn.__init__(self, outp)
+        self.inp = inp
+
+    def _read(self, size):
+        return self.inp.read(size)
+
+    def _readline(self):
+        return self.inp.readline()
+
+    def has_input(self):
+        [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
+        if rl:
+            assert(rl[0] == self.inp.fileno())
+            return True
+        else:
+            return None
+
+def checked_reader(fd, n):
+    while n > 0:
+        rl, _, _ = select.select([fd], [], [])
+        assert(rl[0] == fd)
+        buf = os.read(fd, n)
+        if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
+        yield buf
+        n -= len(buf)
+
+MAX_PACKET = 128 * 1024
+def mux(p, outfd, outr, errr):
+    try:
+        fds = [outr, errr]
+        while p.poll() is None:
+            rl, _, _ = select.select(fds, [], [])
+            for fd in rl:
+                if fd == outr:
+                    buf = os.read(outr, MAX_PACKET)
+                    if not buf: break
+                    os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
+                elif fd == errr:
+                    buf = os.read(errr, 1024)
+                    if not buf: break
+                    os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
+    finally:
+        os.write(outfd, struct.pack('!IB', 0, 3))
+
+class DemuxConn(BaseConn):
+    """A helper class for bup's client-server protocol."""
+    def __init__(self, infd, outp):
+        BaseConn.__init__(self, outp)
+        # Anything that comes through before the sync string was not
+        # multiplexed and can be assumed to be debug/log before mux init.
+        tail = ''
+        while tail != 'BUPMUX':
+            tail += os.read(infd, 1024)
+            buf = tail[:-6]
+            tail = tail[-6:]
+            sys.stderr.write(buf)
+        self.infd = infd
+        self.reader = None
+        self.buf = None
+        self.closed = False
+
+    def write(self, data):
+        self._load_buf(0)
+        BaseConn.write(self, data)
+
+    def _next_packet(self, timeout):
+        if self.closed: return False
+        rl, wl, xl = select.select([self.infd], [], [], timeout)
+        if not rl: return False
+        assert(rl[0] == self.infd)
+        ns = ''.join(checked_reader(self.infd, 5))
+        n, fdw = struct.unpack('!IB', ns)
+        assert(n<=MAX_PACKET)
+        if fdw == 1:
+            self.reader = checked_reader(self.infd, n)
+        elif fdw == 2:
+            for buf in checked_reader(self.infd, n):
+                sys.stderr.write(buf)
+        elif fdw == 3:
+            self.closed = True
+            debug2("DemuxConn: marked closed\n")
+        return True
+
+    def _load_buf(self, timeout):
+        if self.buf is not None:
+            return True
+        while not self.closed:
+            while not self.reader:
+                if not self._next_packet(timeout):
+                    return False
+            try:
+                self.buf = self.reader.next()
+                return True
+            except StopIteration:
+                self.reader = None
+        return False
+
+    def _read_parts(self, ix_fn):
+        while self._load_buf(None):
+            assert(self.buf is not None)
+            i = ix_fn(self.buf)
+            if i is None or i == len(self.buf):
+                yv = self.buf
+                self.buf = None
+            else:
+                yv = self.buf[:i]
+                self.buf = self.buf[i:]
+            yield yv
+            if i is not None:
+                break
+
+    def _readline(self):
+        def find_eol(buf):
+            try:
+                return buf.index('\n')+1
+            except ValueError:
+                return None
+        return ''.join(self._read_parts(find_eol))
+
+    def _read(self, size):
+        csize = [size]
+        def until_size(buf): # Closes on csize
+            if len(buf) < csize[0]:
+                csize[0] -= len(buf)
+                return None
+            else:
+                return csize[0]
+        return ''.join(self._read_parts(until_size))
+
+    def has_input(self):
+        return self._load_buf(0)
 
 def linereader(f):
     """Generate a list of input lines from 'f' without terminating newlines."""