]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Adjust server and client to accommodate python 3
[bup.git] / lib / bup / helpers.py
index 84b1b978682f375c6bbeaeaa4b7fbde53380f51e..0e9078092200d92f3c489a45f89d0a5e3ca9b4e7 100644 (file)
@@ -6,7 +6,6 @@ from contextlib import contextmanager
 from ctypes import sizeof, c_void_p
 from math import floor
 from os import environ
-from pipes import quote
 from subprocess import PIPE, Popen
 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
 import hashlib, heapq, math, operator, time, grp, tempfile
@@ -113,7 +112,7 @@ def lines_until_sentinel(f, sentinel, ex_type):
     # sentinel must end with \n and must contain only one \n
     while True:
         line = f.readline()
-        if not (line and line.endswith('\n')):
+        if not (line and line.endswith(b'\n')):
             raise ex_type('Hit EOF while reading line')
         if line == sentinel:
             return
@@ -268,11 +267,47 @@ def unlink(f):
             raise
 
 
+_bq_simple_id_rx = re.compile(br'^[-_./a-zA-Z0-9]+$')
+_sq_simple_id_rx = re.compile(r'^[-_./a-zA-Z0-9]+$')
+
+def bquote(x):
+    if x == b'':
+        return b"''"
+    if _bq_simple_id_rx.match(x):
+        return x
+    return b"'%s'" % x.replace(b"'", b"'\"'\"'")
+
+def squote(x):
+    if x == '':
+        return "''"
+    if _sq_simple_id_rx.match(x):
+        return x
+    return "'%s'" % x.replace("'", "'\"'\"'")
+
+def quote(x):
+    if isinstance(x, bytes):
+        return bquote(x)
+    if isinstance(x, compat.str_type):
+        return squote(x)
+    assert False
+
 def shstr(cmd):
-    if isinstance(cmd, compat.str_type):
+    """Return a shell quoted string for cmd if it's a sequence, else cmd.
+
+    cmd must be a string, bytes, or a sequence of one or the other,
+    and the assumption is that if cmd is a string or bytes, then it's
+    already quoted (because it's what's actually being passed to
+    call() and friends.  e.g. log(shstr(cmd)); call(cmd)
+
+    """
+    if isinstance(cmd, (bytes, compat.str_type)):
         return cmd
-    else:
-        return ' '.join(map(quote, cmd))
+    elif all(isinstance(x, bytes) for x in cmd):
+        return b' '.join(map(bquote, cmd))
+    elif all(isinstance(x, compat.str_type) for x in cmd):
+        return ' '.join(map(squote, cmd))
+    raise TypeError('unsupported shstr argument: ' + repr(cmd))
+
 
 exc = subprocess.check_call
 
@@ -448,23 +483,23 @@ class BaseConn:
 
     def ok(self):
         """Indicate end of output from last sent command."""
-        self.write('\nok\n')
+        self.write(b'\nok\n')
 
     def error(self, s):
         """Indicate server error to the client."""
-        s = re.sub(r'\s+', ' ', str(s))
-        self.write('\nerror %s\n' % s)
+        s = re.sub(br'\s+', b' ', s)
+        self.write(b'\nerror %s\n' % s)
 
     def _check_ok(self, onempty):
         self.outp.flush()
-        rl = ''
+        rl = b''
         for rl in linereader(self):
             #log('%d got line: %r\n' % (os.getpid(), rl))
             if not rl:  # empty line
                 continue
-            elif rl == 'ok':
+            elif rl == b'ok':
                 return None
-            elif rl.startswith('error '):
+            elif rl.startswith(b'error '):
                 #log('client: error: %s\n' % rl[6:])
                 return NotOk(rl[6:])
             else:
@@ -841,11 +876,6 @@ def parse_num(s):
     return int(num*mult)
 
 
-def count(l):
-    """Count the number of elements in an iterator. (consumes the iterator)"""
-    return reduce(lambda x,y: x+1, l)
-
-
 saved_errors = []
 def add_error(e):
     """Append an error message to the list of saved errors.