]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Adjust server and client to accommodate python 3
[bup.git] / lib / bup / helpers.py
index 64d5b26bfcbdf7d5e919f247819ac313a3486174..0e9078092200d92f3c489a45f89d0a5e3ca9b4e7 100644 (file)
@@ -4,14 +4,16 @@ from __future__ import absolute_import, division
 from collections import namedtuple
 from contextlib import contextmanager
 from ctypes import sizeof, c_void_p
+from math import floor
 from os import environ
-from pipes import quote
 from subprocess import PIPE, Popen
 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
 import hashlib, heapq, math, operator, time, grp, tempfile
 
 from bup import _helpers
 from bup import compat
+from bup.compat import byte_int
+from bup.io import path_msg
 # This function should really be in helpers, not in bup.options.  But we
 # want options.py to be standalone so people can include it in other projects.
 from bup.options import _tty_width as tty_width
@@ -37,17 +39,17 @@ def last(iterable):
 
 
 def atoi(s):
-    """Convert the string 's' to an integer. Return 0 if s is not a number."""
+    """Convert s (ascii bytes) to an integer. Return 0 if s is not a number."""
     try:
-        return int(s or '0')
+        return int(s or b'0')
     except ValueError:
         return 0
 
 
 def atof(s):
-    """Convert the string 's' to a float. Return 0 if s is not a number."""
+    """Convert s (ascii bytes) to a float. Return 0 if s is not a number."""
     try:
-        return float(s or '0')
+        return float(s or b'0')
     except ValueError:
         return 0
 
@@ -110,7 +112,7 @@ def lines_until_sentinel(f, sentinel, ex_type):
     # sentinel must end with \n and must contain only one \n
     while True:
         line = f.readline()
-        if not (line and line.endswith('\n')):
+        if not (line and line.endswith(b'\n')):
             raise ex_type('Hit EOF while reading line')
         if line == sentinel:
             return
@@ -214,6 +216,13 @@ def mkdirp(d, mode=None):
             raise
 
 
+class MergeIterItem:
+    def __init__(self, entry, read_it):
+        self.entry = entry
+        self.read_it = read_it
+    def __lt__(self, x):
+        return self.entry < x.entry
+
 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
     if key:
         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
@@ -223,14 +232,14 @@ def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
     total = sum(len(it) for it in iters)
     iters = (iter(it) for it in iters)
     heap = ((next(it, None),it) for it in iters)
-    heap = [(e,it) for e,it in heap if e]
+    heap = [MergeIterItem(e, it) for e, it in heap if e]
 
     heapq.heapify(heap)
     pe = None
     while heap:
         if not count % pfreq:
             pfunc(count, total)
-        e, it = heap[0]
+        e, it = heap[0].entry, heap[0].read_it
         if not samekey(e, pe):
             pe = e
             yield e
@@ -240,7 +249,8 @@ def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
         except StopIteration:
             heapq.heappop(heap) # remove current
         else:
-            heapq.heapreplace(heap, (e, it)) # shift current to new location
+            # shift current to new location
+            heapq.heapreplace(heap, MergeIterItem(e, it))
     pfinal(count, total)
 
 
@@ -257,11 +267,47 @@ def unlink(f):
             raise
 
 
+_bq_simple_id_rx = re.compile(br'^[-_./a-zA-Z0-9]+$')
+_sq_simple_id_rx = re.compile(r'^[-_./a-zA-Z0-9]+$')
+
+def bquote(x):
+    if x == b'':
+        return b"''"
+    if _bq_simple_id_rx.match(x):
+        return x
+    return b"'%s'" % x.replace(b"'", b"'\"'\"'")
+
+def squote(x):
+    if x == '':
+        return "''"
+    if _sq_simple_id_rx.match(x):
+        return x
+    return "'%s'" % x.replace("'", "'\"'\"'")
+
+def quote(x):
+    if isinstance(x, bytes):
+        return bquote(x)
+    if isinstance(x, compat.str_type):
+        return squote(x)
+    assert False
+
 def shstr(cmd):
-    if isinstance(cmd, compat.str_type):
+    """Return a shell quoted string for cmd if it's a sequence, else cmd.
+
+    cmd must be a string, bytes, or a sequence of one or the other,
+    and the assumption is that if cmd is a string or bytes, then it's
+    already quoted (because it's what's actually being passed to
+    call() and friends.  e.g. log(shstr(cmd)); call(cmd)
+
+    """
+    if isinstance(cmd, (bytes, compat.str_type)):
         return cmd
-    else:
-        return ' '.join(map(quote, cmd))
+    elif all(isinstance(x, bytes) for x in cmd):
+        return b' '.join(map(bquote, cmd))
+    elif all(isinstance(x, compat.str_type) for x in cmd):
+        return ' '.join(map(squote, cmd))
+    raise TypeError('unsupported shstr argument: ' + repr(cmd))
+
 
 exc = subprocess.check_call
 
@@ -282,7 +328,7 @@ def exo(cmd,
     out, err = p.communicate(input)
     if check and p.returncode != 0:
         raise Exception('subprocess %r failed with status %d, stderr: %r'
-                        % (' '.join(map(quote, cmd)), p.returncode, err))
+                        % (b' '.join(map(quote, cmd)), p.returncode, err))
     return out, err, p
 
 def readpipe(argv, preexec_fn=None, shell=False):
@@ -292,7 +338,7 @@ def readpipe(argv, preexec_fn=None, shell=False):
     out, err = p.communicate()
     if p.returncode != 0:
         raise Exception('subprocess %r failed with status %d'
-                        % (' '.join(argv), p.returncode))
+                        % (b' '.join(argv), p.returncode))
     return out
 
 
@@ -391,7 +437,7 @@ def hostname():
     """Get the FQDN of this machine."""
     global _hostname
     if not _hostname:
-        _hostname = socket.getfqdn()
+        _hostname = socket.getfqdn().encode('iso-8859-1')
     return _hostname
 
 
@@ -437,23 +483,23 @@ class BaseConn:
 
     def ok(self):
         """Indicate end of output from last sent command."""
-        self.write('\nok\n')
+        self.write(b'\nok\n')
 
     def error(self, s):
         """Indicate server error to the client."""
-        s = re.sub(r'\s+', ' ', str(s))
-        self.write('\nerror %s\n' % s)
+        s = re.sub(br'\s+', b' ', s)
+        self.write(b'\nerror %s\n' % s)
 
     def _check_ok(self, onempty):
         self.outp.flush()
-        rl = ''
+        rl = b''
         for rl in linereader(self):
             #log('%d got line: %r\n' % (os.getpid(), rl))
             if not rl:  # empty line
                 continue
-            elif rl == 'ok':
+            elif rl == b'ok':
                 return None
-            elif rl.startswith('error '):
+            elif rl.startswith(b'error '):
                 #log('client: error: %s\n' % rl[6:])
                 return NotOk(rl[6:])
             else:
@@ -682,8 +728,9 @@ def atomically_replaced_file(name, mode='w', buffering=-1):
 
 def slashappend(s):
     """Append "/" to 's' if it doesn't aleady end in "/"."""
-    if s and not s.endswith('/'):
-        return s + '/'
+    assert isinstance(s, bytes)
+    if s and not s.endswith(b'/'):
+        return s + b'/'
     else:
         return s
 
@@ -797,13 +844,17 @@ throw a ValueError that may contain additional information."""
 
 
 def parse_num(s):
-    """Parse data size information into a float number.
+    """Parse string or bytes as a possibly unit suffixed number.
 
-    Here are some examples of conversions:
+    For example:
         199.2k means 203981 bytes
         1GB means 1073741824 bytes
         2.1 tb means 2199023255552 bytes
     """
+    if isinstance(s, bytes):
+        # FIXME: should this raise a ValueError for UnicodeDecodeError
+        # (perhaps with the latter as the context).
+        s = s.decode('ascii')
     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
     if not g:
         raise ValueError("can't parse %r as a number" % s)
@@ -825,11 +876,6 @@ def parse_num(s):
     return int(num*mult)
 
 
-def count(l):
-    """Count the number of elements in an iterator. (consumes the iterator)"""
-    return reduce(lambda x,y: x+1, l)
-
-
 saved_errors = []
 def add_error(e):
     """Append an error message to the list of saved errors.
@@ -979,16 +1025,16 @@ def path_components(path):
     full_path_to_name).  Path must start with '/'.
     Example:
       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
-    if not path.startswith('/'):
-        raise Exception('path must start with "/": %s' % path)
+    if not path.startswith(b'/'):
+        raise Exception('path must start with "/": %s' % path_msg(path))
     # Since we assume path startswith('/'), we can skip the first element.
-    result = [('', '/')]
+    result = [(b'', b'/')]
     norm_path = os.path.abspath(path)
-    if norm_path == '/':
+    if norm_path == b'/':
         return result
-    full_path = ''
-    for p in norm_path.split('/')[1:]:
-        full_path += '/' + p
+    full_path = b''
+    for p in norm_path.split(b'/')[1:]:
+        full_path += b'/' + p
         result.append((p, full_path))
     return result
 
@@ -1002,14 +1048,14 @@ def stripped_path_components(path, strip_prefixes):
     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
     for bp in sorted_strip_prefixes:
         normalized_bp = os.path.abspath(bp)
-        if normalized_bp == '/':
+        if normalized_bp == b'/':
             continue
         if normalized_path.startswith(normalized_bp):
             prefix = normalized_path[:len(normalized_bp)]
             result = []
-            for p in normalized_path[len(normalized_bp):].split('/'):
+            for p in normalized_path[len(normalized_bp):].split(b'/'):
                 if p: # not root
-                    prefix += '/'
+                    prefix += b'/'
                 prefix += p
                 result.append((p, prefix))
             return result
@@ -1040,21 +1086,21 @@ def grafted_path_components(graft_points, path):
         new_prefix = os.path.normpath(new_prefix)
         if clean_path.startswith(old_prefix):
             escaped_prefix = re.escape(old_prefix)
-            grafted_path = re.sub(r'^' + escaped_prefix, new_prefix, clean_path)
+            grafted_path = re.sub(br'^' + escaped_prefix, new_prefix, clean_path)
             # Handle /foo=/ (at least) -- which produces //whatever.
-            grafted_path = '/' + grafted_path.lstrip('/')
+            grafted_path = b'/' + grafted_path.lstrip(b'/')
             clean_path_components = path_components(clean_path)
             # Count the components that were stripped.
-            strip_count = 0 if old_prefix == '/' else old_prefix.count('/')
-            new_prefix_parts = new_prefix.split('/')
-            result_prefix = grafted_path.split('/')[:new_prefix.count('/')]
+            strip_count = 0 if old_prefix == b'/' else old_prefix.count(b'/')
+            new_prefix_parts = new_prefix.split(b'/')
+            result_prefix = grafted_path.split(b'/')[:new_prefix.count(b'/')]
             result = [(p, None) for p in result_prefix] \
                 + clean_path_components[strip_count:]
             # Now set the graft point name to match the end of new_prefix.
             graft_point = len(result_prefix)
             result[graft_point] = \
                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
-            if new_prefix == '/': # --graft ...=/ is a special case.
+            if new_prefix == b'/': # --graft ...=/ is a special case.
                 return result[1:]
             return result
     return path_components(clean_path)
@@ -1077,7 +1123,7 @@ if _localtime:
 # module, which doesn't appear willing to ignore the extra items.
 if _localtime:
     def localtime(time):
-        return bup_time(*_helpers.localtime(time))
+        return bup_time(*_helpers.localtime(floor(time)))
     def utc_offset_str(t):
         """Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.
         If the current UTC offset does not represent an integer number
@@ -1087,7 +1133,7 @@ if _localtime:
         offmin = abs(off) // 60
         m = offmin % 60
         h = (offmin - m) // 60
-        return "%+03d%02d" % (-h if off < 0 else h, m)
+        return b'%+03d%02d' % (-h if off < 0 else h, m)
     def to_py_time(x):
         if isinstance(x, time.struct_time):
             return x
@@ -1095,26 +1141,26 @@ if _localtime:
 else:
     localtime = time.localtime
     def utc_offset_str(t):
-        return time.strftime('%z', localtime(t))
+        return time.strftime(b'%z', localtime(t))
     def to_py_time(x):
         return x
 
 
-_some_invalid_save_parts_rx = re.compile(r'[\[ ~^:?*\\]|\.\.|//|@{')
+_some_invalid_save_parts_rx = re.compile(br'[\[ ~^:?*\\]|\.\.|//|@{')
 
 def valid_save_name(name):
     # Enforce a superset of the restrictions in git-check-ref-format(1)
-    if name == '@' \
-       or name.startswith('/') or name.endswith('/') \
-       or name.endswith('.'):
+    if name == b'@' \
+       or name.startswith(b'/') or name.endswith(b'/') \
+       or name.endswith(b'.'):
         return False
     if _some_invalid_save_parts_rx.search(name):
         return False
     for c in name:
-        if ord(c) < 0x20 or ord(c) == 0x7f:
+        if byte_int(c) < 0x20 or byte_int(c) == 0x7f:
             return False
-    for part in name.split('/'):
-        if part.startswith('.') or part.endswith('.lock'):
+    for part in name.split(b'/'):
+        if part.startswith(b'.') or part.endswith(b'.lock'):
             return False
     return True