]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
save: make --strip-path=/ a no-op
[bup.git] / lib / bup / helpers.py
index c090bc0ab09e0f13a68019b32b88ba3cf5a957bf..c15b357733b378105a160ec8e829f92bdfe0bdaa 100644 (file)
@@ -1,8 +1,12 @@
 """Helper functions and classes for bup."""
 
+from ctypes import sizeof, c_void_p
+from os import environ
+from contextlib import contextmanager
 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
-import hashlib, heapq, operator, time, grp
-from bup import _version, _helpers
+import hashlib, heapq, operator, time, grp, tempfile
+
+from bup import _helpers
 import bup._helpers as _helpers
 import math
 
@@ -126,12 +130,23 @@ def mkdirp(d, mode=None):
             raise
 
 
-def next(it):
-    """Get the next item from an iterator, None if we reached the end."""
-    try:
+_unspecified_next_default = object()
+
+def _fallback_next(it, default=_unspecified_next_default):
+    """Retrieve the next item from the iterator by calling its
+    next() method. If default is given, it is returned if the
+    iterator is exhausted, otherwise StopIteration is raised."""
+
+    if default is _unspecified_next_default:
         return it.next()
-    except StopIteration:
-        return None
+    else:
+        try:
+            return it.next()
+        except StopIteration:
+            return default
+
+if sys.version_info < (2, 6):
+    next =  _fallback_next
 
 
 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
@@ -142,7 +157,7 @@ def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
     count = 0
     total = sum(len(it) for it in iters)
     iters = (iter(it) for it in iters)
-    heap = ((next(it),it) for it in iters)
+    heap = ((next(it, None),it) for it in iters)
     heap = [(e,it) for e,it in heap if e]
 
     heapq.heapify(heap)
@@ -173,16 +188,53 @@ def unlink(f):
     try:
         os.unlink(f)
     except OSError, e:
-        if e.errno == errno.ENOENT:
-            pass  # it doesn't exist, that's what you asked for
+        if e.errno != errno.ENOENT:
+            raise
 
 
-def readpipe(argv):
+def readpipe(argv, preexec_fn=None):
     """Run a subprocess and return its output."""
-    p = subprocess.Popen(argv, stdout=subprocess.PIPE)
-    r = p.stdout.read()
-    p.wait()
-    return r
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn)
+    out, err = p.communicate()
+    if p.returncode != 0:
+        raise Exception('subprocess %r failed with status %d'
+                        % (' '.join(argv), p.returncode))
+    return out
+
+
+def _argmax_base(command):
+    base_size = 2048
+    for c in command:
+        base_size += len(command) + 1
+    for k, v in environ.iteritems():
+        base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
+    return base_size
+
+
+def _argmax_args_size(args):
+    return sum(len(x) + 1 + sizeof(c_void_p) for x in args)
+
+
+def batchpipe(command, args, preexec_fn=None, arg_max=_helpers.SC_ARG_MAX):
+    """If args is not empty, yield the output produced by calling the
+command list with args as a sequence of strings (It may be necessary
+to return multiple strings in order to respect ARG_MAX)."""
+    # The optional arg_max arg is a workaround for an issue with the
+    # current wvtest behavior.
+    base_size = _argmax_base(command)
+    while args:
+        room = arg_max - base_size
+        i = 0
+        while i < len(args):
+            next_size = _argmax_args_size(args[i:i+1])
+            if room - next_size < 0:
+                break
+            room -= next_size
+            i += 1
+        sub_args = args[:i]
+        args = args[i:]
+        assert(len(sub_args))
+        yield readpipe(command + sub_args, preexec_fn=preexec_fn)
 
 
 def realpath(p):
@@ -577,6 +629,42 @@ def chunkyreader(f, count = None):
             yield b
 
 
+@contextmanager
+def atomically_replaced_file(name, mode='w', buffering=-1):
+    """Yield a file that will be atomically renamed name when leaving the block.
+
+    This contextmanager yields an open file object that is backed by a
+    temporary file which will be renamed (atomically) to the target
+    name if everything succeeds.
+
+    The mode and buffering arguments are handled exactly as with open,
+    and the yielded file will have very restrictive permissions, as
+    per mkstemp.
+
+    E.g.::
+
+        with atomically_replaced_file('foo.txt', 'w') as f:
+            f.write('hello jack.')
+
+    """
+
+    (ffd, tempname) = tempfile.mkstemp(dir=os.path.dirname(name),
+                                       text=('b' not in mode))
+    try:
+        try:
+            f = os.fdopen(ffd, mode, buffering)
+        except:
+            os.close(ffd)
+            raise
+        try:
+            yield f
+        finally:
+            f.close()
+        os.rename(tempname, name)
+    finally:
+        unlink(tempname)  # nonexistant file is ignored
+
+
 def slashappend(s):
     """Append "/" to 's' if it doesn't aleady end in "/"."""
     if s and not s.endswith('/'):
@@ -625,6 +713,26 @@ def mmap_readwrite_private(f, sz = 0, close=True):
                     close)
 
 
+def parse_timestamp(epoch_str):
+    """Return the number of nanoseconds since the epoch that are described
+by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
+throw a ValueError that may contain additional information."""
+    ns_per = {'s' :  1000000000,
+              'ms' : 1000000,
+              'us' : 1000,
+              'ns' : 1}
+    match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
+    if not match:
+        if re.match(r'^([-+]?[0-9]+)$', epoch_str):
+            raise ValueError('must include units, i.e. 100ns, 100ms, ...')
+        raise ValueError()
+    (n, units) = match.group(1, 2)
+    if not n:
+        n = 1
+    n = int(n)
+    return n * ns_per[units]
+
+
 def parse_num(s):
     """Parse data size information into a float number.
 
@@ -741,8 +849,11 @@ def parse_excludes(options, fatal):
             except IOError, e:
                 raise fatal("couldn't read %s" % parameter)
             for exclude_path in f.readlines():
-                excluded_paths.append(realpath(exclude_path.strip()))
-    return excluded_paths
+                # FIXME: perhaps this should be rstrip('\n')
+                exclude_path = realpath(exclude_path.strip())
+                if exclude_path:
+                    excluded_paths.append(exclude_path)
+    return sorted(frozenset(excluded_paths))
 
 
 def parse_rx_excludes(options, fatal):
@@ -764,6 +875,8 @@ def parse_rx_excludes(options, fatal):
                 raise fatal("couldn't read %s" % parameter)
             for pattern in f.readlines():
                 spattern = pattern.rstrip('\n')
+                if not spattern:
+                    continue
                 try:
                     excluded_patterns.append(re.compile(spattern))
                 except re.error, ex:
@@ -817,6 +930,8 @@ def stripped_path_components(path, strip_prefixes):
     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
     for bp in sorted_strip_prefixes:
         normalized_bp = os.path.abspath(bp)
+        if normalized_bp == '/':
+            continue
         if normalized_path.startswith(normalized_bp):
             prefix = normalized_path[:len(normalized_bp)]
             result = []
@@ -873,30 +988,3 @@ def grafted_path_components(graft_points, path):
     return path_components(clean_path)
 
 Sha1 = hashlib.sha1
-
-def version_date():
-    """Format bup's version date string for output."""
-    return _version.DATE.split(' ')[0]
-
-
-def version_commit():
-    """Get the commit hash of bup's current version."""
-    return _version.COMMIT
-
-
-def version_tag():
-    """Format bup's version tag (the official version number).
-
-    When generated from a commit other than one pointed to with a tag, the
-    returned string will be "unknown-" followed by the first seven positions of
-    the commit hash.
-    """
-    names = _version.NAMES.strip()
-    assert(names[0] == '(')
-    assert(names[-1] == ')')
-    names = names[1:-1]
-    l = [n.strip() for n in names.split(',')]
-    for n in l:
-        if n.startswith('tag: bup-'):
-            return n[9:]
-    return 'unknown-%s' % _version.COMMIT[:7]