]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Move pwd grp functions to pwdgrp module; require binary fields
[bup.git] / lib / bup / helpers.py
index 5bae5dee943d477a97f6dd65c73cc154eec089f8..f6b71cd1c7f643ba786f01ad1e94493a62e22d76 100644 (file)
@@ -1,13 +1,26 @@
 """Helper functions and classes for bup."""
 
+from __future__ import absolute_import, division
 from collections import namedtuple
+from contextlib import contextmanager
 from ctypes import sizeof, c_void_p
 from os import environ
-from contextlib import contextmanager
+from pipes import quote
+from subprocess import PIPE, Popen
 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
 import hashlib, heapq, math, operator, time, grp, tempfile
 
 from bup import _helpers
+from bup import compat
+# This function should really be in helpers, not in bup.options.  But we
+# want options.py to be standalone so people can include it in other projects.
+from bup.options import _tty_width as tty_width
+
+
+class Nonlocal:
+    """Helper to deal with Python scoping issues"""
+    pass
+
 
 sc_page_size = os.sysconf('SC_PAGE_SIZE')
 assert(sc_page_size > 0)
@@ -16,10 +29,11 @@ sc_arg_max = os.sysconf('SC_ARG_MAX')
 if sc_arg_max == -1:  # "no definite limit" - let's choose 2M
     sc_arg_max = 2 * 1024 * 1024
 
-# This function should really be in helpers, not in bup.options.  But we
-# want options.py to be standalone so people can include it in other projects.
-from bup.options import _tty_width
-tty_width = _tty_width
+def last(iterable):
+    result = None
+    for result in iterable:
+        pass
+    return result
 
 
 def atoi(s):
@@ -41,11 +55,75 @@ def atof(s):
 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
 
 
-# If the platform doesn't have fdatasync (OS X), fall back to fsync.
 try:
-    fdatasync = os.fdatasync
+    _fdatasync = os.fdatasync
 except AttributeError:
-    fdatasync = os.fsync
+    _fdatasync = os.fsync
+
+if sys.platform.startswith('darwin'):
+    # Apparently os.fsync on OS X doesn't guarantee to sync all the way down
+    import fcntl
+    def fdatasync(fd):
+        try:
+            return fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
+        except IOError as e:
+            # Fallback for file systems (SMB) that do not support F_FULLFSYNC
+            if e.errno == errno.ENOTSUP:
+                return _fdatasync(fd)
+            else:
+                raise
+else:
+    fdatasync = _fdatasync
+
+
+def partition(predicate, stream):
+    """Returns (leading_matches_it, rest_it), where leading_matches_it
+    must be completely exhausted before traversing rest_it.
+
+    """
+    stream = iter(stream)
+    ns = Nonlocal()
+    ns.first_nonmatch = None
+    def leading_matches():
+        for x in stream:
+            if predicate(x):
+                yield x
+            else:
+                ns.first_nonmatch = (x,)
+                break
+    def rest():
+        if ns.first_nonmatch:
+            yield ns.first_nonmatch[0]
+            for x in stream:
+                yield x
+    return (leading_matches(), rest())
+
+
+def merge_dict(*xs):
+    result = {}
+    for x in xs:
+        result.update(x)
+    return result
+
+
+def lines_until_sentinel(f, sentinel, ex_type):
+    # sentinel must end with \n and must contain only one \n
+    while True:
+        line = f.readline()
+        if not (line and line.endswith('\n')):
+            raise ex_type('Hit EOF while reading line')
+        if line == sentinel:
+            return
+        yield line
+
+
+def stat_if_exists(path):
+    try:
+        return os.stat(path)
+    except OSError as e:
+        if e.errno != errno.ENOENT:
+            raise
+    return None
 
 
 # Write (blockingly) to sockets that may or may not be in blocking mode.
@@ -136,25 +214,6 @@ def mkdirp(d, mode=None):
             raise
 
 
-_unspecified_next_default = object()
-
-def _fallback_next(it, default=_unspecified_next_default):
-    """Retrieve the next item from the iterator by calling its
-    next() method. If default is given, it is returned if the
-    iterator is exhausted, otherwise StopIteration is raised."""
-
-    if default is _unspecified_next_default:
-        return it.next()
-    else:
-        try:
-            return it.next()
-        except StopIteration:
-            return default
-
-if sys.version_info < (2, 6):
-    next =  _fallback_next
-
-
 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
     if key:
         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
@@ -177,7 +236,7 @@ def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
             yield e
         count += 1
         try:
-            e = it.next() # Don't use next() function, it's too expensive
+            e = next(it)
         except StopIteration:
             heapq.heappop(heap) # remove current
         else:
@@ -198,9 +257,38 @@ def unlink(f):
             raise
 
 
-def readpipe(argv, preexec_fn=None):
+def shstr(cmd):
+    if isinstance(cmd, compat.str_type):
+        return cmd
+    else:
+        return ' '.join(map(quote, cmd))
+
+exc = subprocess.check_call
+
+def exo(cmd,
+        input=None,
+        stdin=None,
+        stderr=None,
+        shell=False,
+        check=True,
+        preexec_fn=None):
+    if input:
+        assert stdin in (None, PIPE)
+        stdin = PIPE
+    p = Popen(cmd,
+              stdin=stdin, stdout=PIPE, stderr=stderr,
+              shell=shell,
+              preexec_fn=preexec_fn)
+    out, err = p.communicate(input)
+    if check and p.returncode != 0:
+        raise Exception('subprocess %r failed with status %d, stderr: %r'
+                        % (' '.join(map(quote, cmd)), p.returncode, err))
+    return out, err, p
+
+def readpipe(argv, preexec_fn=None, shell=False):
     """Run a subprocess and return its output."""
-    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn)
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn,
+                         shell=shell)
     out, err = p.communicate()
     if p.returncode != 0:
         raise Exception('subprocess %r failed with status %d'
@@ -212,7 +300,7 @@ def _argmax_base(command):
     base_size = 2048
     for c in command:
         base_size += len(command) + 1
-    for k, v in environ.iteritems():
+    for k, v in compat.items(environ):
         base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
     return base_size
 
@@ -243,8 +331,8 @@ to return multiple strings in order to respect ARG_MAX)."""
         yield readpipe(command + sub_args, preexec_fn=preexec_fn)
 
 
-def realpath(p):
-    """Get the absolute path of a file.
+def resolve_parent(p):
+    """Return the absolute path of a file without following any final symlink.
 
     Behaves like os.path.realpath, but doesn't follow a symlink for the last
     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
@@ -269,15 +357,17 @@ def detect_fakeroot():
     return os.getenv("FAKEROOTKEY") != None
 
 
-def is_superuser():
-    if sys.platform.startswith('cygwin'):
-        import ctypes
-        return ctypes.cdll.shell32.IsUserAnAdmin()
-    else:
+if sys.platform.startswith('cygwin'):
+    def is_superuser():
+        # https://cygwin.com/ml/cygwin/2015-02/msg00057.html
+        groups = os.getgroups()
+        return 544 in groups or 0 in groups
+else:
+    def is_superuser():
         return os.geteuid() == 0
 
 
-def _cache_key_value(get_value, key, cache):
+def cache_key_value(get_value, key, cache):
     """Return (value, was_cached).  If there is a value in the cache
     for key, use that, otherwise, call get_value(key) which should
     throw a KeyError if there is no value -- in which case the cached
@@ -296,80 +386,6 @@ def _cache_key_value(get_value, key, cache):
     return value, False
 
 
-_uid_to_pwd_cache = {}
-_name_to_pwd_cache = {}
-
-def pwd_from_uid(uid):
-    """Return password database entry for uid (may be a cached value).
-    Return None if no entry is found.
-    """
-    global _uid_to_pwd_cache, _name_to_pwd_cache
-    entry, cached = _cache_key_value(pwd.getpwuid, uid, _uid_to_pwd_cache)
-    if entry and not cached:
-        _name_to_pwd_cache[entry.pw_name] = entry
-    return entry
-
-
-def pwd_from_name(name):
-    """Return password database entry for name (may be a cached value).
-    Return None if no entry is found.
-    """
-    global _uid_to_pwd_cache, _name_to_pwd_cache
-    entry, cached = _cache_key_value(pwd.getpwnam, name, _name_to_pwd_cache)
-    if entry and not cached:
-        _uid_to_pwd_cache[entry.pw_uid] = entry
-    return entry
-
-
-_gid_to_grp_cache = {}
-_name_to_grp_cache = {}
-
-def grp_from_gid(gid):
-    """Return password database entry for gid (may be a cached value).
-    Return None if no entry is found.
-    """
-    global _gid_to_grp_cache, _name_to_grp_cache
-    entry, cached = _cache_key_value(grp.getgrgid, gid, _gid_to_grp_cache)
-    if entry and not cached:
-        _name_to_grp_cache[entry.gr_name] = entry
-    return entry
-
-
-def grp_from_name(name):
-    """Return password database entry for name (may be a cached value).
-    Return None if no entry is found.
-    """
-    global _gid_to_grp_cache, _name_to_grp_cache
-    entry, cached = _cache_key_value(grp.getgrnam, name, _name_to_grp_cache)
-    if entry and not cached:
-        _gid_to_grp_cache[entry.gr_gid] = entry
-    return entry
-
-
-_username = None
-def username():
-    """Get the user's login name."""
-    global _username
-    if not _username:
-        uid = os.getuid()
-        _username = pwd_from_uid(uid)[0] or 'user%d' % uid
-    return _username
-
-
-_userfullname = None
-def userfullname():
-    """Get the user's full name."""
-    global _userfullname
-    if not _userfullname:
-        uid = os.getuid()
-        entry = pwd_from_uid(uid)
-        if entry:
-            _userfullname = entry[4].split(',')[0] or entry[0]
-        if not _userfullname:
-            _userfullname = 'user%d' % uid
-    return _userfullname
-
-
 _hostname = None
 def hostname():
     """Get the FQDN of this machine."""
@@ -391,9 +407,9 @@ def format_filesize(size):
     size = float(size)
     if size < unit:
         return "%d" % (size)
-    exponent = int(math.log(size) / math.log(unit))
+    exponent = int(math.log(size) // math.log(unit))
     size_prefix = "KMGTPE"[exponent - 1]
-    return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
+    return "%.1f%s" % (size // math.pow(unit, exponent), size_prefix)
 
 
 class NotOk(Exception):
@@ -562,7 +578,7 @@ class DemuxConn(BaseConn):
                 if not self._next_packet(timeout):
                     return False
             try:
-                self.buf = self.reader.next()
+                self.buf = next(self.reader)
                 return True
             except StopIteration:
                 self.reader = None
@@ -730,7 +746,7 @@ if _mincore:
         pref_chunk_size = 64 * 1024 * 1024
         chunk_size = sc_page_size
         if (sc_page_size < pref_chunk_size):
-            chunk_size = sc_page_size * (pref_chunk_size / sc_page_size)
+            chunk_size = sc_page_size * (pref_chunk_size // sc_page_size)
         _fmincore_chunk_size = chunk_size
 
     def fmincore(fd):
@@ -742,13 +758,13 @@ if _mincore:
             return bytearray(0)
         if not _fmincore_chunk_size:
             _set_fmincore_chunk_size()
-        pages_per_chunk = _fmincore_chunk_size / sc_page_size;
-        page_count = (st.st_size + sc_page_size - 1) / sc_page_size;
-        chunk_count = page_count / _fmincore_chunk_size
+        pages_per_chunk = _fmincore_chunk_size // sc_page_size;
+        page_count = (st.st_size + sc_page_size - 1) // sc_page_size;
+        chunk_count = page_count // _fmincore_chunk_size
         if chunk_count < 1:
             chunk_count = 1
         result = bytearray(page_count)
-        for ci in xrange(chunk_count):
+        for ci in compat.range(chunk_count):
             pos = _fmincore_chunk_size * ci;
             msize = min(_fmincore_chunk_size, st.st_size - pos)
             try:
@@ -758,7 +774,12 @@ if _mincore:
                     # Perhaps the file was a pipe, i.e. "... | bup split ..."
                     return None
                 raise ex
-            _mincore(m, msize, 0, result, ci * pages_per_chunk);
+            try:
+                _mincore(m, msize, 0, result, ci * pages_per_chunk)
+            except OSError as ex:
+                if ex.errno == errno.ENOSYS:
+                    return None
+                raise
         return result
 
 
@@ -832,6 +853,15 @@ def clear_errors():
     saved_errors = []
 
 
+def die_if_errors(msg=None, status=1):
+    global saved_errors
+    if saved_errors:
+        if not msg:
+            msg = 'warning: %d errors encountered\n' % len(saved_errors)
+        log(msg)
+        sys.exit(status)
+
+
 def handle_ctrl_c():
     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
 
@@ -857,15 +887,15 @@ def columnate(l, prefix):
         return ""
     l = l[:]
     clen = max(len(s) for s in l)
-    ncols = (tty_width() - len(prefix)) / (clen + 2)
+    ncols = (tty_width() - len(prefix)) // (clen + 2)
     if ncols <= 1:
         ncols = 1
         clen = 0
     cols = []
     while len(l) % ncols:
         l.append('')
-    rows = len(l)/ncols
-    for s in range(0, len(l), rows):
+    rows = len(l) // ncols
+    for s in compat.range(0, len(l), rows):
         cols.append(l[s:s+rows])
     out = ''
     for row in zip(*cols):
@@ -891,15 +921,15 @@ def parse_excludes(options, fatal):
     for flag in options:
         (option, parameter) = flag
         if option == '--exclude':
-            excluded_paths.append(realpath(parameter))
+            excluded_paths.append(resolve_parent(parameter))
         elif option == '--exclude-from':
             try:
-                f = open(realpath(parameter))
+                f = open(resolve_parent(parameter))
             except IOError as e:
                 raise fatal("couldn't read %s" % parameter)
             for exclude_path in f.readlines():
                 # FIXME: perhaps this should be rstrip('\n')
-                exclude_path = realpath(exclude_path.strip())
+                exclude_path = resolve_parent(exclude_path.strip())
                 if exclude_path:
                     excluded_paths.append(exclude_path)
     return sorted(frozenset(excluded_paths))
@@ -919,7 +949,7 @@ def parse_rx_excludes(options, fatal):
                 fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
         elif option == '--exclude-rx-from':
             try:
-                f = open(realpath(parameter))
+                f = open(resolve_parent(parameter))
             except IOError as e:
                 raise fatal("couldn't read %s" % parameter)
             for pattern in f.readlines():
@@ -957,7 +987,7 @@ def path_components(path):
     Example:
       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
     if not path.startswith('/'):
-        raise Exception, 'path must start with "/": %s' % path
+        raise Exception('path must start with "/": %s' % path)
     # Since we assume path startswith('/'), we can skip the first element.
     result = [('', '/')]
     norm_path = os.path.abspath(path)
@@ -1075,3 +1105,41 @@ else:
         return time.strftime('%z', localtime(t))
     def to_py_time(x):
         return x
+
+
+_some_invalid_save_parts_rx = re.compile(r'[\[ ~^:?*\\]|\.\.|//|@{')
+
+def valid_save_name(name):
+    # Enforce a superset of the restrictions in git-check-ref-format(1)
+    if name == '@' \
+       or name.startswith('/') or name.endswith('/') \
+       or name.endswith('.'):
+        return False
+    if _some_invalid_save_parts_rx.search(name):
+        return False
+    for c in name:
+        if ord(c) < 0x20 or ord(c) == 0x7f:
+            return False
+    for part in name.split('/'):
+        if part.startswith('.') or part.endswith('.lock'):
+            return False
+    return True
+
+
+_period_rx = re.compile(r'^([0-9]+)(s|min|h|d|w|m|y)$')
+
+def period_as_secs(s):
+    if s == 'forever':
+        return float('inf')
+    match = _period_rx.match(s)
+    if not match:
+        return None
+    mag = int(match.group(1))
+    scale = match.group(2)
+    return mag * {'s': 1,
+                  'min': 60,
+                  'h': 60 * 60,
+                  'd': 60 * 60 * 24,
+                  'w': 60 * 60 * 24 * 7,
+                  'm': 60 * 60 * 24 * 31,
+                  'y': 60 * 60 * 24 * 366}[scale]