]> arthur.barton.de Git - bup.git/commitdiff
Add DemuxConn and `bup mux` for client-server
authorBrandon Low <lostlogic@lostlogicx.com>
Thu, 27 Jan 2011 02:30:20 +0000 (18:30 -0800)
committerAvery Pennarun <apenwarr@gmail.com>
Tue, 1 Feb 2011 05:35:19 +0000 (21:35 -0800)
`bup mux` works with any bup command to multiplex its stdout and stderr
streams over a single stdout stream.

DemuxConn works on the client side to demultiplex stderr and data
streams from a single stream, emulating a simple connection.

For now, these are only used in the case of simple socket bup://
client-server connections, because rsh and local connections don't need
them.

Signed-off-by: Brandon Low <lostlogic@lostlogicx.com>
Documentation/bup-mux.md [new file with mode: 0644]
cmd/mux-cmd.py [new file with mode: 0755]
lib/bup/client.py
lib/bup/helpers.py
main.py

diff --git a/Documentation/bup-mux.md b/Documentation/bup-mux.md
new file mode 100644 (file)
index 0000000..1062418
--- /dev/null
@@ -0,0 +1,30 @@
+% bup-mux(1) Bup %BUP_VERSION%
+% Brandon Low <lostlogic@lostlogicx.com>
+% %BUP_DATE%
+
+# NAME
+
+bup-mux - multiplexes data and error streams over a connection
+
+# SYNOPSIS
+
+bup mux \<command\> [options...]
+
+# DESCRIPTION
+
+`bup mux` is used in the bup client-server protocol to
+send both data and debugging/error output over the single
+connection stream.
+
+`bup mux server` might be used in an inetd server setup.
+
+# OPTIONS
+
+command
+:   the subcommand to run
+options
+:   options for command
+
+# BUP
+
+Part of the `bup`(1) suite.
diff --git a/cmd/mux-cmd.py b/cmd/mux-cmd.py
new file mode 100755 (executable)
index 0000000..299dec9
--- /dev/null
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+import os, sys, subprocess, struct
+from bup import options
+from bup.helpers import *
+
+optspec = """
+bup mux command [command arguments...]
+--
+"""
+o = options.Options(optspec)
+(opt, flags, extra) = o.parse(sys.argv[1:])
+if len(extra) < 1:
+    o.fatal('command is required')
+
+cmdpath, cmdfn = os.path.split(__file__)
+subcmd = extra
+subcmd[0] = os.path.join(cmdpath, 'bup-' + subcmd[0])
+
+debug2('bup mux: starting %r\n' % (extra,))
+
+outr, outw = os.pipe()
+errr, errw = os.pipe()
+def close_fds():
+    os.close(outr)
+    os.close(errr)
+p = subprocess.Popen(subcmd, stdout=outw, stderr=errw, preexec_fn=close_fds)
+os.close(outw)
+os.close(errw)
+sys.stdout.write('BUPMUX')
+sys.stdout.flush()
+mux(p, sys.stdout.fileno(), outr, errr)
+os.close(outr)
+os.close(errr)
+prv = p.wait()
+
+if prv:
+    debug1('%s exited with code %d\n' % (extra[0], prv))
+
+debug1('bup mux: done\n')
+
+sys.exit(prv)
index 01287ef299cefc18c6ac9f9eeb14ec3e49049213..3a0bb427a1bee94b93f7cc5d7ac347ac7a4a42a2 100644 (file)
@@ -65,6 +65,7 @@ class Client:
         if is_reverse:
             self.pout = os.fdopen(3, 'rb')
             self.pin = os.fdopen(4, 'wb')
+            self.conn = Conn(self.pout, self.pin)
         else:
             if self.protocol in ('ssh', 'file'):
                 try:
@@ -72,14 +73,14 @@ class Client:
                     self.p = ssh.connect(self.host, self.port, 'server')
                     self.pout = self.p.stdout
                     self.pin = self.p.stdin
+                    self.conn = Conn(self.pout, self.pin)
                 except OSError, e:
                     raise ClientError, 'connect: %s' % e, sys.exc_info()[2]
             elif self.protocol == 'bup':
                 self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                 self.sock.connect((self.host, self.port or 1982))
-                self.pout = self.sock.makefile('rb')
-                self.pin = self.sock.makefile('wb')
-        self.conn = Conn(self.pout, self.pin)
+                self.sockw = self.sock.makefile('wb')
+                self.conn = DemuxConn(self.sock.fileno(), self.sockw)
         if self.dir:
             self.dir = re.sub(r'[\r\n]', ' ', self.dir)
             if create:
@@ -101,10 +102,14 @@ class Client:
     def close(self):
         if self.conn and not self._busy:
             self.conn.write('quit\n')
-        if self.pin and self.pout:
+        if self.pin:
             self.pin.close()
-            while self.pout.read(65536):
-                pass
+        if self.sock and self.sockw:
+            self.sockw.close()
+            self.sock.shutdown(socket.SHUT_WR)
+        if self.conn:
+            self.conn.close()
+        if self.pout:
             self.pout.close()
         if self.sock:
             self.sock.close()
index 8a4da97056caf558a3d1aa39c0898f59329b1b4f..62d5b26c32d1a7dbf82a9007ecc51c489a8b748b 100644 (file)
@@ -1,6 +1,6 @@
 """Helper functions and classes for bup."""
 
-import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
+import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
 import heapq, operator
 from bup import _version
 
@@ -205,21 +205,22 @@ def resource_path(subdir=''):
 class NotOk(Exception):
     pass
 
-class Conn:
-    """A helper class for bup's client-server protocol."""
-    def __init__(self, inp, outp):
-        self.inp = inp
+class BaseConn:
+    def __init__(self, outp):
         self.outp = outp
 
+    def close(self):
+        while self._read(65536): pass
+
     def read(self, size):
         """Read 'size' bytes from input stream."""
         self.outp.flush()
-        return self.inp.read(size)
+        return self._read(size)
 
     def readline(self):
         """Read from input stream until a newline is found."""
         self.outp.flush()
-        return self.inp.readline()
+        return self._readline()
 
     def write(self, data):
         """Write 'data' to output stream."""
@@ -228,12 +229,7 @@ class Conn:
 
     def has_input(self):
         """Return true if input stream is readable."""
-        [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
-        if rl:
-            assert(rl[0] == self.inp.fileno())
-            return True
-        else:
-            return None
+        raise NotImplemented("Subclasses must implement has_input")
 
     def ok(self):
         """Indicate end of output from last sent command."""
@@ -247,7 +243,7 @@ class Conn:
     def _check_ok(self, onempty):
         self.outp.flush()
         rl = ''
-        for rl in linereader(self.inp):
+        for rl in linereader(self):
             #log('%d got line: %r\n' % (os.getpid(), rl))
             if not rl:  # empty line
                 continue
@@ -272,6 +268,139 @@ class Conn:
             raise Exception('expected "ok", got %r' % rl)
         return self._check_ok(onempty)
 
+class Conn(BaseConn):
+    def __init__(self, inp, outp):
+        BaseConn.__init__(self, outp)
+        self.inp = inp
+
+    def _read(self, size):
+        return self.inp.read(size)
+
+    def _readline(self):
+        return self.inp.readline()
+
+    def has_input(self):
+        [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
+        if rl:
+            assert(rl[0] == self.inp.fileno())
+            return True
+        else:
+            return None
+
+def checked_reader(fd, n):
+    while n > 0:
+        rl, _, _ = select.select([fd], [], [])
+        assert(rl[0] == fd)
+        buf = os.read(fd, n)
+        if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
+        yield buf
+        n -= len(buf)
+
+MAX_PACKET = 128 * 1024
+def mux(p, outfd, outr, errr):
+    try:
+        fds = [outr, errr]
+        while p.poll() is None:
+            rl, _, _ = select.select(fds, [], [])
+            for fd in rl:
+                if fd == outr:
+                    buf = os.read(outr, MAX_PACKET)
+                    if not buf: break
+                    os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
+                elif fd == errr:
+                    buf = os.read(errr, 1024)
+                    if not buf: break
+                    os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
+    finally:
+        os.write(outfd, struct.pack('!IB', 0, 3))
+
+class DemuxConn(BaseConn):
+    """A helper class for bup's client-server protocol."""
+    def __init__(self, infd, outp):
+        BaseConn.__init__(self, outp)
+        # Anything that comes through before the sync string was not
+        # multiplexed and can be assumed to be debug/log before mux init.
+        tail = ''
+        while tail != 'BUPMUX':
+            tail += os.read(infd, 1024)
+            buf = tail[:-6]
+            tail = tail[-6:]
+            sys.stderr.write(buf)
+        self.infd = infd
+        self.reader = None
+        self.buf = None
+        self.closed = False
+
+    def write(self, data):
+        self._load_buf(0)
+        BaseConn.write(self, data)
+
+    def _next_packet(self, timeout):
+        if self.closed: return False
+        rl, wl, xl = select.select([self.infd], [], [], timeout)
+        if not rl: return False
+        assert(rl[0] == self.infd)
+        ns = ''.join(checked_reader(self.infd, 5))
+        n, fdw = struct.unpack('!IB', ns)
+        assert(n<=MAX_PACKET)
+        if fdw == 1:
+            self.reader = checked_reader(self.infd, n)
+        elif fdw == 2:
+            for buf in checked_reader(self.infd, n):
+                sys.stderr.write(buf)
+        elif fdw == 3:
+            self.closed = True
+            debug2("DemuxConn: marked closed\n")
+        return True
+
+    def _load_buf(self, timeout):
+        if self.buf is not None:
+            return True
+        while not self.closed:
+            while not self.reader:
+                if not self._next_packet(timeout):
+                    return False
+            try:
+                self.buf = self.reader.next()
+                return True
+            except StopIteration:
+                self.reader = None
+        return False
+
+    def _read_parts(self, ix_fn):
+        while self._load_buf(None):
+            assert(self.buf is not None)
+            i = ix_fn(self.buf)
+            if i is None or i == len(self.buf):
+                yv = self.buf
+                self.buf = None
+            else:
+                yv = self.buf[:i]
+                self.buf = self.buf[i:]
+            yield yv
+            if i is not None:
+                break
+
+    def _readline(self):
+        def find_eol(buf):
+            try:
+                return buf.index('\n')+1
+            except ValueError:
+                return None
+        return ''.join(self._read_parts(find_eol))
+
+    def _read(self, size):
+        csize = [size]
+        def until_size(buf): # Closes on csize
+            if len(buf) < csize[0]:
+                csize[0] -= len(buf)
+                return None
+            else:
+                return csize[0]
+        return ''.join(self._read_parts(until_size))
+
+    def has_input(self):
+        return self._load_buf(0)
 
 def linereader(f):
     """Generate a list of input lines from 'f' without terminating newlines."""
diff --git a/main.py b/main.py
index 196f6937d51609e0c6b341baa951ae21fdf51e61..3d56ed49f44018adebf326d7e5432f55b9a903f4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -128,7 +128,7 @@ if not os.path.exists(subcmd[0]):
     usage()
 
 already_fixed = atoi(os.environ.get('BUP_FORCE_TTY'))
-if subcmd_name in ['ftp', 'help']:
+if subcmd_name in ['mux', 'ftp', 'help']:
     already_fixed = True
 fix_stdout = not already_fixed and os.isatty(1)
 fix_stderr = not already_fixed and os.isatty(2)