]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
save: make --strip-path=/ a no-op
[bup.git] / lib / bup / helpers.py
index da27edb44b29330e6a9998e7862a9cb1747ba248..c15b357733b378105a160ec8e829f92bdfe0bdaa 100644 (file)
@@ -1,8 +1,14 @@
 """Helper functions and classes for bup."""
 
+from ctypes import sizeof, c_void_p
+from os import environ
+from contextlib import contextmanager
 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
-import heapq, operator
-from bup import _version
+import hashlib, heapq, operator, time, grp, tempfile
+
+from bup import _helpers
+import bup._helpers as _helpers
+import math
 
 # This function should really be in helpers, not in bup.options.  But we
 # want options.py to be standalone so people can include it in other projects.
@@ -29,6 +35,13 @@ def atof(s):
 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
 
 
+# If the platform doesn't have fdatasync (OS X), fall back to fsync.
+try:
+    fdatasync = os.fdatasync
+except AttributeError:
+    fdatasync = os.fsync
+
+
 # Write (blockingly) to sockets that may or may not be in blocking mode.
 # We need this because our stderr is sometimes eaten by subprocesses
 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
@@ -46,10 +59,14 @@ def _hard_write(fd, buf):
         assert(sz >= 0)
         buf = buf[sz:]
 
+
+_last_prog = 0
 def log(s):
     """Print a log message to stderr."""
+    global _last_prog
     sys.stdout.flush()
     _hard_write(sys.stderr.fileno(), s)
+    _last_prog = 0
 
 
 def debug1(s):
@@ -62,6 +79,39 @@ def debug2(s):
         log(s)
 
 
+istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
+istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
+_last_progress = ''
+def progress(s):
+    """Calls log() if stderr is a TTY.  Does nothing otherwise."""
+    global _last_progress
+    if istty2:
+        log(s)
+        _last_progress = s
+
+
+def qprogress(s):
+    """Calls progress() only if we haven't printed progress in a while.
+    
+    This avoids overloading the stderr buffer with excess junk.
+    """
+    global _last_prog
+    now = time.time()
+    if now - _last_prog > 0.1:
+        progress(s)
+        _last_prog = now
+
+
+def reprogress():
+    """Calls progress() to redisplay the most recent progress message.
+
+    Useful after you've printed some other message that wipes out the
+    progress line.
+    """
+    if _last_progress and _last_progress.endswith('\r'):
+        progress(_last_progress)
+
+
 def mkdirp(d, mode=None):
     """Recursively create directories on path 'd'.
 
@@ -80,12 +130,23 @@ def mkdirp(d, mode=None):
             raise
 
 
-def next(it):
-    """Get the next item from an iterator, None if we reached the end."""
-    try:
+_unspecified_next_default = object()
+
+def _fallback_next(it, default=_unspecified_next_default):
+    """Retrieve the next item from the iterator by calling its
+    next() method. If default is given, it is returned if the
+    iterator is exhausted, otherwise StopIteration is raised."""
+
+    if default is _unspecified_next_default:
         return it.next()
-    except StopIteration:
-        return None
+    else:
+        try:
+            return it.next()
+        except StopIteration:
+            return default
+
+if sys.version_info < (2, 6):
+    next =  _fallback_next
 
 
 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
@@ -96,7 +157,7 @@ def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
     count = 0
     total = sum(len(it) for it in iters)
     iters = (iter(it) for it in iters)
-    heap = ((next(it),it) for it in iters)
+    heap = ((next(it, None),it) for it in iters)
     heap = [(e,it) for e,it in heap if e]
 
     heapq.heapify(heap)
@@ -127,16 +188,53 @@ def unlink(f):
     try:
         os.unlink(f)
     except OSError, e:
-        if e.errno == errno.ENOENT:
-            pass  # it doesn't exist, that's what you asked for
+        if e.errno != errno.ENOENT:
+            raise
 
 
-def readpipe(argv):
+def readpipe(argv, preexec_fn=None):
     """Run a subprocess and return its output."""
-    p = subprocess.Popen(argv, stdout=subprocess.PIPE)
-    r = p.stdout.read()
-    p.wait()
-    return r
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn)
+    out, err = p.communicate()
+    if p.returncode != 0:
+        raise Exception('subprocess %r failed with status %d'
+                        % (' '.join(argv), p.returncode))
+    return out
+
+
+def _argmax_base(command):
+    base_size = 2048
+    for c in command:
+        base_size += len(command) + 1
+    for k, v in environ.iteritems():
+        base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
+    return base_size
+
+
+def _argmax_args_size(args):
+    return sum(len(x) + 1 + sizeof(c_void_p) for x in args)
+
+
+def batchpipe(command, args, preexec_fn=None, arg_max=_helpers.SC_ARG_MAX):
+    """If args is not empty, yield the output produced by calling the
+command list with args as a sequence of strings (It may be necessary
+to return multiple strings in order to respect ARG_MAX)."""
+    # The optional arg_max arg is a workaround for an issue with the
+    # current wvtest behavior.
+    base_size = _argmax_base(command)
+    while args:
+        room = arg_max - base_size
+        i = 0
+        while i < len(args):
+            next_size = _argmax_args_size(args[i:i+1])
+            if room - next_size < 0:
+                break
+            room -= next_size
+            i += 1
+        sub_args = args[:i]
+        args = args[i:]
+        assert(len(sub_args))
+        yield readpipe(command + sub_args, preexec_fn=preexec_fn)
 
 
 def realpath(p):
@@ -160,16 +258,95 @@ def realpath(p):
     return out
 
 
+def detect_fakeroot():
+    "Return True if we appear to be running under fakeroot."
+    return os.getenv("FAKEROOTKEY") != None
+
+
+def is_superuser():
+    if sys.platform.startswith('cygwin'):
+        import ctypes
+        return ctypes.cdll.shell32.IsUserAnAdmin()
+    else:
+        return os.geteuid() == 0
+
+
+def _cache_key_value(get_value, key, cache):
+    """Return (value, was_cached).  If there is a value in the cache
+    for key, use that, otherwise, call get_value(key) which should
+    throw a KeyError if there is no value -- in which case the cached
+    and returned value will be None.
+    """
+    try: # Do we already have it (or know there wasn't one)?
+        value = cache[key]
+        return value, True
+    except KeyError:
+        pass
+    value = None
+    try:
+        cache[key] = value = get_value(key)
+    except KeyError:
+        cache[key] = None
+    return value, False
+
+
+_uid_to_pwd_cache = {}
+_name_to_pwd_cache = {}
+
+def pwd_from_uid(uid):
+    """Return password database entry for uid (may be a cached value).
+    Return None if no entry is found.
+    """
+    global _uid_to_pwd_cache, _name_to_pwd_cache
+    entry, cached = _cache_key_value(pwd.getpwuid, uid, _uid_to_pwd_cache)
+    if entry and not cached:
+        _name_to_pwd_cache[entry.pw_name] = entry
+    return entry
+
+
+def pwd_from_name(name):
+    """Return password database entry for name (may be a cached value).
+    Return None if no entry is found.
+    """
+    global _uid_to_pwd_cache, _name_to_pwd_cache
+    entry, cached = _cache_key_value(pwd.getpwnam, name, _name_to_pwd_cache)
+    if entry and not cached:
+        _uid_to_pwd_cache[entry.pw_uid] = entry
+    return entry
+
+
+_gid_to_grp_cache = {}
+_name_to_grp_cache = {}
+
+def grp_from_gid(gid):
+    """Return password database entry for gid (may be a cached value).
+    Return None if no entry is found.
+    """
+    global _gid_to_grp_cache, _name_to_grp_cache
+    entry, cached = _cache_key_value(grp.getgrgid, gid, _gid_to_grp_cache)
+    if entry and not cached:
+        _name_to_grp_cache[entry.gr_name] = entry
+    return entry
+
+
+def grp_from_name(name):
+    """Return password database entry for name (may be a cached value).
+    Return None if no entry is found.
+    """
+    global _gid_to_grp_cache, _name_to_grp_cache
+    entry, cached = _cache_key_value(grp.getgrnam, name, _name_to_grp_cache)
+    if entry and not cached:
+        _gid_to_grp_cache[entry.gr_gid] = entry
+    return entry
+
+
 _username = None
 def username():
     """Get the user's login name."""
     global _username
     if not _username:
         uid = os.getuid()
-        try:
-            _username = pwd.getpwuid(uid)[0]
-        except KeyError:
-            _username = 'user%d' % uid
+        _username = pwd_from_uid(uid)[0] or 'user%d' % uid
     return _username
 
 
@@ -179,9 +356,10 @@ def userfullname():
     global _userfullname
     if not _userfullname:
         uid = os.getuid()
-        try:
-            _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
-        except KeyError:
+        entry = pwd_from_uid(uid)
+        if entry:
+            _userfullname = entry[4].split(',')[0] or entry[0]
+        if not _userfullname:
             _userfullname = 'user%d' % uid
     return _userfullname
 
@@ -202,6 +380,15 @@ def resource_path(subdir=''):
         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
     return os.path.join(_resource_path, subdir)
 
+def format_filesize(size):
+    unit = 1024.0
+    size = float(size)
+    if size < unit:
+        return "%d" % (size)
+    exponent = int(math.log(size) / math.log(unit))
+    size_prefix = "KMGTPE"[exponent - 1]
+    return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
+
 
 class NotOk(Exception):
     pass
@@ -442,6 +629,42 @@ def chunkyreader(f, count = None):
             yield b
 
 
+@contextmanager
+def atomically_replaced_file(name, mode='w', buffering=-1):
+    """Yield a file that will be atomically renamed name when leaving the block.
+
+    This contextmanager yields an open file object that is backed by a
+    temporary file which will be renamed (atomically) to the target
+    name if everything succeeds.
+
+    The mode and buffering arguments are handled exactly as with open,
+    and the yielded file will have very restrictive permissions, as
+    per mkstemp.
+
+    E.g.::
+
+        with atomically_replaced_file('foo.txt', 'w') as f:
+            f.write('hello jack.')
+
+    """
+
+    (ffd, tempname) = tempfile.mkstemp(dir=os.path.dirname(name),
+                                       text=('b' not in mode))
+    try:
+        try:
+            f = os.fdopen(ffd, mode, buffering)
+        except:
+            os.close(ffd)
+            raise
+        try:
+            yield f
+        finally:
+            f.close()
+        os.rename(tempname, name)
+    finally:
+        unlink(tempname)  # nonexistant file is ignored
+
+
 def slashappend(s):
     """Append "/" to 's' if it doesn't aleady end in "/"."""
     if s and not s.endswith('/'):
@@ -490,6 +713,26 @@ def mmap_readwrite_private(f, sz = 0, close=True):
                     close)
 
 
+def parse_timestamp(epoch_str):
+    """Return the number of nanoseconds since the epoch that are described
+by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
+throw a ValueError that may contain additional information."""
+    ns_per = {'s' :  1000000000,
+              'ms' : 1000000,
+              'us' : 1000,
+              'ns' : 1}
+    match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
+    if not match:
+        if re.match(r'^([-+]?[0-9]+)$', epoch_str):
+            raise ValueError('must include units, i.e. 100ns, 100ms, ...')
+        raise ValueError()
+    (n, units) = match.group(1, 2)
+    if not n:
+        n = 1
+    n = int(n)
+    return n * ns_per[units]
+
+
 def parse_num(s):
     """Parse data size information into a float number.
 
@@ -535,11 +778,9 @@ def add_error(e):
     log('%-70s\n' % e)
 
 
-istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
-def progress(s):
-    """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
-    if istty:
-        log(s)
+def clear_errors():
+    global saved_errors
+    saved_errors = []
 
 
 def handle_ctrl_c():
@@ -551,7 +792,7 @@ def handle_ctrl_c():
     oldhook = sys.excepthook
     def newhook(exctype, value, traceback):
         if exctype == KeyboardInterrupt:
-            log('Interrupted.\n')
+            log('\nInterrupted.\n')
         else:
             return oldhook(exctype, value, traceback)
     sys.excepthook = newhook
@@ -594,88 +835,156 @@ def parse_date_or_fatal(str, fatal):
         return date
 
 
-def strip_path(prefix, path):
-    """Strips a given prefix from a path.
-
-    First both paths are normalized.
+def parse_excludes(options, fatal):
+    """Traverse the options and extract all excludes, or call Option.fatal()."""
+    excluded_paths = []
 
-    Raises an Exception if no prefix is given.
-    """
-    if prefix == None:
-        raise Exception('no path given')
-
-    normalized_prefix = os.path.realpath(prefix)
-    debug2("normalized_prefix: %s\n" % normalized_prefix)
-    normalized_path = os.path.realpath(path)
-    debug2("normalized_path: %s\n" % normalized_path)
-    if normalized_path.startswith(normalized_prefix):
-        return normalized_path[len(normalized_prefix):]
-    else:
-        return path
-
-
-def strip_base_path(path, base_paths):
-    """Strips the base path from a given path.
-
-
-    Determines the base path for the given string and then strips it
-    using strip_path().
-    Iterates over all base_paths from long to short, to prevent that
-    a too short base_path is removed.
-    """
-    normalized_path = os.path.realpath(path)
-    sorted_base_paths = sorted(base_paths, key=len, reverse=True)
-    for bp in sorted_base_paths:
-        if normalized_path.startswith(os.path.realpath(bp)):
-            return strip_path(bp, normalized_path)
-    return path
-
-
-def graft_path(graft_points, path):
-    normalized_path = os.path.realpath(path)
+    for flag in options:
+        (option, parameter) = flag
+        if option == '--exclude':
+            excluded_paths.append(realpath(parameter))
+        elif option == '--exclude-from':
+            try:
+                f = open(realpath(parameter))
+            except IOError, e:
+                raise fatal("couldn't read %s" % parameter)
+            for exclude_path in f.readlines():
+                # FIXME: perhaps this should be rstrip('\n')
+                exclude_path = realpath(exclude_path.strip())
+                if exclude_path:
+                    excluded_paths.append(exclude_path)
+    return sorted(frozenset(excluded_paths))
+
+
+def parse_rx_excludes(options, fatal):
+    """Traverse the options and extract all rx excludes, or call
+    Option.fatal()."""
+    excluded_patterns = []
+
+    for flag in options:
+        (option, parameter) = flag
+        if option == '--exclude-rx':
+            try:
+                excluded_patterns.append(re.compile(parameter))
+            except re.error, ex:
+                fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
+        elif option == '--exclude-rx-from':
+            try:
+                f = open(realpath(parameter))
+            except IOError, e:
+                raise fatal("couldn't read %s" % parameter)
+            for pattern in f.readlines():
+                spattern = pattern.rstrip('\n')
+                if not spattern:
+                    continue
+                try:
+                    excluded_patterns.append(re.compile(spattern))
+                except re.error, ex:
+                    fatal('invalid --exclude-rx pattern (%s): %s' % (spattern, ex))
+    return excluded_patterns
+
+
+def should_rx_exclude_path(path, exclude_rxs):
+    """Return True if path matches a regular expression in exclude_rxs."""
+    for rx in exclude_rxs:
+        if rx.search(path):
+            debug1('Skipping %r: excluded by rx pattern %r.\n'
+                   % (path, rx.pattern))
+            return True
+    return False
+
+
+# FIXME: Carefully consider the use of functions (os.path.*, etc.)
+# that resolve against the current filesystem in the strip/graft
+# functions for example, but elsewhere as well.  I suspect bup's not
+# always being careful about that.  For some cases, the contents of
+# the current filesystem should be irrelevant, and consulting it might
+# produce the wrong result, perhaps via unintended symlink resolution,
+# for example.
+
+def path_components(path):
+    """Break path into a list of pairs of the form (name,
+    full_path_to_name).  Path must start with '/'.
+    Example:
+      '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
+    if not path.startswith('/'):
+        raise Exception, 'path must start with "/": %s' % path
+    # Since we assume path startswith('/'), we can skip the first element.
+    result = [('', '/')]
+    norm_path = os.path.abspath(path)
+    if norm_path == '/':
+        return result
+    full_path = ''
+    for p in norm_path.split('/')[1:]:
+        full_path += '/' + p
+        result.append((p, full_path))
+    return result
+
+
+def stripped_path_components(path, strip_prefixes):
+    """Strip any prefix in strip_prefixes from path and return a list
+    of path components where each component is (name,
+    none_or_full_fs_path_to_name).  Assume path startswith('/').
+    See thelpers.py for examples."""
+    normalized_path = os.path.abspath(path)
+    sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
+    for bp in sorted_strip_prefixes:
+        normalized_bp = os.path.abspath(bp)
+        if normalized_bp == '/':
+            continue
+        if normalized_path.startswith(normalized_bp):
+            prefix = normalized_path[:len(normalized_bp)]
+            result = []
+            for p in normalized_path[len(normalized_bp):].split('/'):
+                if p: # not root
+                    prefix += '/'
+                prefix += p
+                result.append((p, prefix))
+            return result
+    # Nothing to strip.
+    return path_components(path)
+
+
+def grafted_path_components(graft_points, path):
+    # Create a result that consists of some number of faked graft
+    # directories before the graft point, followed by all of the real
+    # directories from path that are after the graft point.  Arrange
+    # for the directory at the graft point in the result to correspond
+    # to the "orig" directory in --graft orig=new.  See t/thelpers.py
+    # for some examples.
+
+    # Note that given --graft orig=new, orig and new have *nothing* to
+    # do with each other, even if some of their component names
+    # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
+    # equivalent to --graft /foo/bar/baz=/x/y/z, or even
+    # /foo/bar/baz=/x.
+
+    # FIXME: This can't be the best solution...
+    clean_path = os.path.abspath(path)
     for graft_point in graft_points:
         old_prefix, new_prefix = graft_point
-        if normalized_path.startswith(old_prefix):
-            return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
-    return normalized_path
-
-
-# hashlib is only available in python 2.5 or higher, but the 'sha' module
-# produces a DeprecationWarning in python 2.6 or higher.  We want to support
-# python 2.4 and above without any stupid warnings, so let's try using hashlib
-# first, and downgrade if it fails.
-try:
-    import hashlib
-except ImportError:
-    import sha
-    Sha1 = sha.sha
-else:
-    Sha1 = hashlib.sha1
-
-
-def version_date():
-    """Format bup's version date string for output."""
-    return _version.DATE.split(' ')[0]
-
-
-def version_commit():
-    """Get the commit hash of bup's current version."""
-    return _version.COMMIT
-
-
-def version_tag():
-    """Format bup's version tag (the official version number).
-
-    When generated from a commit other than one pointed to with a tag, the
-    returned string will be "unknown-" followed by the first seven positions of
-    the commit hash.
-    """
-    names = _version.NAMES.strip()
-    assert(names[0] == '(')
-    assert(names[-1] == ')')
-    names = names[1:-1]
-    l = [n.strip() for n in names.split(',')]
-    for n in l:
-        if n.startswith('tag: bup-'):
-            return n[9:]
-    return 'unknown-%s' % _version.COMMIT[:7]
+        # Expand prefixes iff not absolute paths.
+        old_prefix = os.path.normpath(old_prefix)
+        new_prefix = os.path.normpath(new_prefix)
+        if clean_path.startswith(old_prefix):
+            escaped_prefix = re.escape(old_prefix)
+            grafted_path = re.sub(r'^' + escaped_prefix, new_prefix, clean_path)
+            # Handle /foo=/ (at least) -- which produces //whatever.
+            grafted_path = '/' + grafted_path.lstrip('/')
+            clean_path_components = path_components(clean_path)
+            # Count the components that were stripped.
+            strip_count = 0 if old_prefix == '/' else old_prefix.count('/')
+            new_prefix_parts = new_prefix.split('/')
+            result_prefix = grafted_path.split('/')[:new_prefix.count('/')]
+            result = [(p, None) for p in result_prefix] \
+                + clean_path_components[strip_count:]
+            # Now set the graft point name to match the end of new_prefix.
+            graft_point = len(result_prefix)
+            result[graft_point] = \
+                (new_prefix_parts[-1], clean_path_components[strip_count][1])
+            if new_prefix == '/': # --graft ...=/ is a special case.
+                return result[1:]
+            return result
+    return path_components(clean_path)
+
+Sha1 = hashlib.sha1