]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Fix typo in documentation for strip_base_path
[bup.git] / lib / bup / helpers.py
index 75cf09cb240a4f8aa9d12b629855a3ed87d92b0e..36b95602e9c8f2088ec9fd55d9693f47619b1d4f 100644 (file)
@@ -1,13 +1,77 @@
+"""Helper functions and classes for bup."""
+
 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
+from bup import _version
+
+# This function should really be in helpers, not in bup.options.  But we
+# want options.py to be standalone so people can include it in other projects.
+from bup.options import _tty_width
+tty_width = _tty_width
+
+
+def atoi(s):
+    """Convert the string 's' to an integer. Return 0 if s is not a number."""
+    try:
+        return int(s or '0')
+    except ValueError:
+        return 0
+
+
+def atof(s):
+    """Convert the string 's' to a float. Return 0 if s is not a number."""
+    try:
+        return float(s or '0')
+    except ValueError:
+        return 0
+
+
+buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
 
 
+# Write (blockingly) to sockets that may or may not be in blocking mode.
+# We need this because our stderr is sometimes eaten by subprocesses
+# (probably ssh) that sometimes make it nonblocking, if only temporarily,
+# leading to race conditions.  Ick.  We'll do it the hard way.
+def _hard_write(fd, buf):
+    while buf:
+        (r,w,x) = select.select([], [fd], [], None)
+        if not w:
+            raise IOError('select(fd) returned without being writable')
+        try:
+            sz = os.write(fd, buf)
+        except OSError, e:
+            if e.errno != errno.EAGAIN:
+                raise
+        assert(sz >= 0)
+        buf = buf[sz:]
+
 def log(s):
-    sys.stderr.write(s)
+    """Print a log message to stderr."""
+    sys.stdout.flush()
+    _hard_write(sys.stderr.fileno(), s)
+
+
+def debug1(s):
+    if buglvl >= 1:
+        log(s)
+
+
+def debug2(s):
+    if buglvl >= 2:
+        log(s)
 
 
-def mkdirp(d):
+def mkdirp(d, mode=None):
+    """Recursively create directories on path 'd'.
+
+    Unlike os.makedirs(), it doesn't raise an exception if the last element of
+    the path already exists.
+    """
     try:
-        os.makedirs(d)
+        if mode:
+            os.makedirs(d, mode)
+        else:
+            os.makedirs(d)
     except OSError, e:
         if e.errno == errno.EEXIST:
             pass
@@ -16,13 +80,19 @@ def mkdirp(d):
 
 
 def next(it):
+    """Get the next item from an iterator, None if we reached the end."""
     try:
         return it.next()
     except StopIteration:
         return None
-    
-    
+
+
 def unlink(f):
+    """Delete a file at path 'f' if it currently exists.
+
+    Unlike os.unlink(), does not throw an exception if the file didn't already
+    exist.
+    """
     try:
         os.unlink(f)
     except OSError, e:
@@ -31,26 +101,20 @@ def unlink(f):
 
 
 def readpipe(argv):
+    """Run a subprocess and return its output."""
     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
     r = p.stdout.read()
     p.wait()
     return r
 
 
-# FIXME: this function isn't very generic, because it splits the filename
-# in an odd way and depends on a terminating '/' to indicate directories.
-# But it's used in a couple of places, so let's put it here.
-def pathsplit(p):
-    l = p.split('/')
-    l = [i+'/' for i in l[:-1]] + l[-1:]
-    if l[-1] == '':
-        l.pop()  # extra blank caused by terminating '/'
-    return l
-
-
-# like os.path.realpath, but doesn't follow a symlink for the last element.
-# (ie. if 'p' itself is itself a symlink, this one won't follow it)
 def realpath(p):
+    """Get the absolute path of a file.
+
+    Behaves like os.path.realpath, but doesn't follow a symlink for the last
+    element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
+    will follow symlinks in p's directory)
+    """
     try:
         st = os.lstat(p)
     except OSError:
@@ -67,6 +131,7 @@ def realpath(p):
 
 _username = None
 def username():
+    """Get the user's login name."""
     global _username
     if not _username:
         uid = os.getuid()
@@ -79,6 +144,7 @@ def username():
 
 _userfullname = None
 def userfullname():
+    """Get the user's full name."""
     global _userfullname
     if not _userfullname:
         uid = os.getuid()
@@ -91,33 +157,46 @@ def userfullname():
 
 _hostname = None
 def hostname():
+    """Get the FQDN of this machine."""
     global _hostname
     if not _hostname:
         _hostname = socket.getfqdn()
     return _hostname
 
 
+_resource_path = None
+def resource_path(subdir=''):
+    global _resource_path
+    if not _resource_path:
+        _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
+    return os.path.join(_resource_path, subdir)
+
 class NotOk(Exception):
     pass
 
 class Conn:
+    """A helper class for bup's client-server protocol."""
     def __init__(self, inp, outp):
         self.inp = inp
         self.outp = outp
 
     def read(self, size):
+        """Read 'size' bytes from input stream."""
         self.outp.flush()
         return self.inp.read(size)
 
     def readline(self):
+        """Read from input stream until a newline is found."""
         self.outp.flush()
         return self.inp.readline()
 
     def write(self, data):
+        """Write 'data' to output stream."""
         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
         self.outp.write(data)
 
     def has_input(self):
+        """Return true if input stream is readable."""
         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
         if rl:
             assert(rl[0] == self.inp.fileno())
@@ -126,9 +205,11 @@ class Conn:
             return None
 
     def ok(self):
+        """Indicate end of output from last sent command."""
         self.write('\nok\n')
 
     def error(self, s):
+        """Indicate server error to the client."""
         s = re.sub(r'\s+', ' ', str(s))
         self.write('\nerror %s\n' % s)
 
@@ -149,17 +230,20 @@ class Conn:
         raise Exception('server exited unexpectedly; see errors above')
 
     def drain_and_check_ok(self):
+        """Remove all data for the current command from input stream."""
         def onempty(rl):
             pass
         return self._check_ok(onempty)
 
     def check_ok(self):
+        """Verify that server action completed successfully."""
         def onempty(rl):
             raise Exception('expected "ok", got %r' % rl)
         return self._check_ok(onempty)
 
 
 def linereader(f):
+    """Generate a list of input lines from 'f' without terminating newlines."""
     while 1:
         line = f.readline()
         if not line:
@@ -168,6 +252,13 @@ def linereader(f):
 
 
 def chunkyreader(f, count = None):
+    """Generate a list of chunks of data read from 'f'.
+
+    If count is None, read until EOF is reached.
+
+    If count is a positive integer, read 'count' bytes from 'f'. If EOF is
+    reached while reading, raise IOError.
+    """
     if count != None:
         while count > 0:
             b = f.read(min(count, 65536))
@@ -182,49 +273,52 @@ def chunkyreader(f, count = None):
             yield b
 
 
-class AutoFlushIter:
-    def __init__(self, it, ondone = None):
-        self.it = it
-        self.ondone = ondone
-
-    def __iter__(self):
-        return self
-        
-    def next(self):
-        return self.it.next()
-        
-    def __del__(self):
-        for i in self.it:
-            pass
-        if self.ondone:
-            self.ondone()
-
-
 def slashappend(s):
+    """Append "/" to 's' if it doesn't aleady end in "/"."""
     if s and not s.endswith('/'):
         return s + '/'
     else:
         return s
 
 
-def _mmap_do(f, len, flags, prot):
-    if not len:
+def _mmap_do(f, sz, flags, prot):
+    if not sz:
         st = os.fstat(f.fileno())
-        len = st.st_size
-    map = mmap.mmap(f.fileno(), len, flags, prot)
+        sz = st.st_size
+    if not sz:
+        # trying to open a zero-length map gives an error, but an empty
+        # string has all the same behaviour of a zero-length map, ie. it has
+        # no elements :)
+        return ''
+    map = mmap.mmap(f.fileno(), sz, flags, prot)
     f.close()  # map will persist beyond file close
     return map
 
 
-def mmap_read(f, len = 0):
-    return _mmap_do(f, len, mmap.MAP_PRIVATE, mmap.PROT_READ)
+def mmap_read(f, sz = 0):
+    """Create a read-only memory mapped region on file 'f'.
+
+    If sz is 0, the region will cover the entire file.
+    """
+    return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
 
 
-def mmap_readwrite(f, len = 0):
-    return _mmap_do(f, len, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
+def mmap_readwrite(f, sz = 0):
+    """Create a read-write memory mapped region on file 'f'.
+
+    If sz is 0, the region will cover the entire file.
+    """
+    return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
 
 
 def parse_num(s):
+    """Parse data size information into a float number.
+
+    Here are some examples of conversions:
+        199.2k means 203981 bytes
+        1GB means 1073741824 bytes
+        2.1 tb means 2199023255552 bytes
+    """
     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
     if not g:
         raise ValueError("can't parse %r as a number" % s)
@@ -246,24 +340,156 @@ def parse_num(s):
     return int(num*mult)
 
 
-# count the number of elements in an iterator (consumes the iterator)
 def count(l):
+    """Count the number of elements in an iterator. (consumes the iterator)"""
     return reduce(lambda x,y: x+1, l)
 
 
-def atoi(s):
-    try:
-        return int(s or '0')
-    except ValueError:
-        return 0
-
-
 saved_errors = []
 def add_error(e):
+    """Append an error message to the list of saved errors.
+
+    Once processing is able to stop and output the errors, the saved errors are
+    accessible in the module variable helpers.saved_errors.
+    """
     saved_errors.append(e)
     log('%-70s\n' % e)
 
 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
 def progress(s):
+    """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
     if istty:
         log(s)
+
+
+def handle_ctrl_c():
+    """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
+
+    The new exception handler will make sure that bup will exit without an ugly
+    stacktrace when Ctrl-C is hit.
+    """
+    oldhook = sys.excepthook
+    def newhook(exctype, value, traceback):
+        if exctype == KeyboardInterrupt:
+            log('Interrupted.\n')
+        else:
+            return oldhook(exctype, value, traceback)
+    sys.excepthook = newhook
+
+
+def columnate(l, prefix):
+    """Format elements of 'l' in columns with 'prefix' leading each line.
+
+    The number of columns is determined automatically based on the string
+    lengths.
+    """
+    if not l:
+        return ""
+    l = l[:]
+    clen = max(len(s) for s in l)
+    ncols = (tty_width() - len(prefix)) / (clen + 2)
+    if ncols <= 1:
+        ncols = 1
+        clen = 0
+    cols = []
+    while len(l) % ncols:
+        l.append('')
+    rows = len(l)/ncols
+    for s in range(0, len(l), rows):
+        cols.append(l[s:s+rows])
+    out = ''
+    for row in zip(*cols):
+        out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
+    return out
+
+def parse_date_or_fatal(str, fatal):
+    """Parses the given date or calls Option.fatal().
+    For now we expect a string that contains a float."""
+    try:
+        date = atof(str)
+    except ValueError, e:
+        raise fatal('invalid date format (should be a float): %r' % e)
+    else:
+        return date
+
+def strip_path(prefix, path):
+    """Strips a given prefix from a path.
+
+    First both paths are normalized.
+
+    Raises an Exception if no prefix is given.
+    """
+    if prefix == None:
+        raise Exception('no path given')
+
+    normalized_prefix = os.path.realpath(prefix)
+    debug2("normalized_prefix: %s\n" % normalized_prefix)
+    normalized_path = os.path.realpath(path)
+    debug2("normalized_path: %s\n" % normalized_path)
+    if normalized_path.startswith(normalized_prefix):
+        return normalized_path[len(normalized_prefix):]
+    else:
+        return path
+
+def strip_base_path(path, base_paths):
+    """Strips the base path from a given path.
+
+
+    Determines the base path for the given string and then strips it
+    using strip_path().
+    Iterates over all base_paths from long to short, to prevent that
+    a too short base_path is removed.
+    """
+    normalized_path = os.path.realpath(path)
+    sorted_base_paths = sorted(base_paths, key=len, reverse=True)
+    for bp in sorted_base_paths:
+        if normalized_path.startswith(os.path.realpath(bp)):
+            return strip_path(bp, normalized_path)
+    return path
+
+def graft_path(graft_points, path):
+    normalized_path = os.path.realpath(path)
+    for graft_point in graft_points:
+        old_prefix, new_prefix = graft_point
+        if normalized_path.startswith(old_prefix):
+            return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
+    return normalized_path
+
+
+# hashlib is only available in python 2.5 or higher, but the 'sha' module
+# produces a DeprecationWarning in python 2.6 or higher.  We want to support
+# python 2.4 and above without any stupid warnings, so let's try using hashlib
+# first, and downgrade if it fails.
+try:
+    import hashlib
+except ImportError:
+    import sha
+    Sha1 = sha.sha
+else:
+    Sha1 = hashlib.sha1
+
+
+def version_date():
+    """Format bup's version date string for output."""
+    return _version.DATE.split(' ')[0]
+
+def version_commit():
+    """Get the commit hash of bup's current version."""
+    return _version.COMMIT
+
+def version_tag():
+    """Format bup's version tag (the official version number).
+
+    When generated from a commit other than one pointed to with a tag, the
+    returned string will be "unknown-" followed by the first seven positions of
+    the commit hash.
+    """
+    names = _version.NAMES.strip()
+    assert(names[0] == '(')
+    assert(names[-1] == ')')
+    names = names[1:-1]
+    l = [n.strip() for n in names.split(',')]
+    for n in l:
+        if n.startswith('tag: bup-'):
+            return n[9:]
+    return 'unknown-%s' % _version.COMMIT[:7]