]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Merge branch 'master' into config
[bup.git] / lib / bup / helpers.py
index 1223456b0f889c40516c99d5ec1e974a3994d9c6..d9d177cabf30931a14e074e828204c2b7ed13999 100644 (file)
@@ -1,6 +1,8 @@
 """Helper functions and classes for bup."""
-import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
-from bup import _version
+
+import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
+import heapq, operator, time, platform
+from bup import _version, _helpers
 import bup._helpers as _helpers
 
 # This function should really be in helpers, not in bup.options.  But we
@@ -45,10 +47,14 @@ def _hard_write(fd, buf):
         assert(sz >= 0)
         buf = buf[sz:]
 
+
+_last_prog = 0
 def log(s):
     """Print a log message to stderr."""
+    global _last_prog
     sys.stdout.flush()
     _hard_write(sys.stderr.fileno(), s)
+    _last_prog = 0
 
 
 def debug1(s):
@@ -61,6 +67,39 @@ def debug2(s):
         log(s)
 
 
+istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
+istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
+_last_progress = ''
+def progress(s):
+    """Calls log() if stderr is a TTY.  Does nothing otherwise."""
+    global _last_progress
+    if istty2:
+        log(s)
+        _last_progress = s
+
+
+def qprogress(s):
+    """Calls progress() only if we haven't printed progress in a while.
+    
+    This avoids overloading the stderr buffer with excess junk.
+    """
+    global _last_prog
+    now = time.time()
+    if now - _last_prog > 0.1:
+        progress(s)
+        _last_prog = now
+
+
+def reprogress():
+    """Calls progress() to redisplay the most recent progress message.
+
+    Useful after you've printed some other message that wipes out the
+    progress line.
+    """
+    if _last_progress and _last_progress.endswith('\r'):
+        progress(_last_progress)
+
+
 def mkdirp(d, mode=None):
     """Recursively create directories on path 'd'.
 
@@ -87,6 +126,36 @@ def next(it):
         return None
 
 
+def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
+    if key:
+        samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
+    else:
+        samekey = operator.eq
+    count = 0
+    total = sum(len(it) for it in iters)
+    iters = (iter(it) for it in iters)
+    heap = ((next(it),it) for it in iters)
+    heap = [(e,it) for e,it in heap if e]
+
+    heapq.heapify(heap)
+    pe = None
+    while heap:
+        if not count % pfreq:
+            pfunc(count, total)
+        e, it = heap[0]
+        if not samekey(e, pe):
+            pe = e
+            yield e
+        count += 1
+        try:
+            e = it.next() # Don't use next() function, it's too expensive
+        except StopIteration:
+            heapq.heappop(heap) # remove current
+        else:
+            heapq.heapreplace(heap, (e, it)) # shift current to new location
+    pfinal(count, total)
+
+
 def unlink(f):
     """Delete a file at path 'f' if it currently exists.
 
@@ -134,6 +203,14 @@ def detect_fakeroot():
     return os.getenv("FAKEROOTKEY") != None
 
 
+def is_superuser():
+    if platform.system().startswith('CYGWIN'):
+        import ctypes
+        return ctypes.cdll.shell32.IsUserAnAdmin()
+    else:
+        return os.geteuid() == 0
+
+
 _username = None
 def username():
     """Get the user's login name."""
@@ -176,24 +253,27 @@ def resource_path(subdir=''):
         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
     return os.path.join(_resource_path, subdir)
 
+
 class NotOk(Exception):
     pass
 
-class Conn:
-    """A helper class for bup's client-server protocol."""
-    def __init__(self, inp, outp):
-        self.inp = inp
+
+class BaseConn:
+    def __init__(self, outp):
         self.outp = outp
 
+    def close(self):
+        while self._read(65536): pass
+
     def read(self, size):
         """Read 'size' bytes from input stream."""
         self.outp.flush()
-        return self.inp.read(size)
+        return self._read(size)
 
     def readline(self):
         """Read from input stream until a newline is found."""
         self.outp.flush()
-        return self.inp.readline()
+        return self._readline()
 
     def write(self, data):
         """Write 'data' to output stream."""
@@ -202,12 +282,7 @@ class Conn:
 
     def has_input(self):
         """Return true if input stream is readable."""
-        [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
-        if rl:
-            assert(rl[0] == self.inp.fileno())
-            return True
-        else:
-            return None
+        raise NotImplemented("Subclasses must implement has_input")
 
     def ok(self):
         """Indicate end of output from last sent command."""
@@ -221,7 +296,7 @@ class Conn:
     def _check_ok(self, onempty):
         self.outp.flush()
         rl = ''
-        for rl in linereader(self.inp):
+        for rl in linereader(self):
             #log('%d got line: %r\n' % (os.getpid(), rl))
             if not rl:  # empty line
                 continue
@@ -247,6 +322,146 @@ class Conn:
         return self._check_ok(onempty)
 
 
+class Conn(BaseConn):
+    def __init__(self, inp, outp):
+        BaseConn.__init__(self, outp)
+        self.inp = inp
+
+    def _read(self, size):
+        return self.inp.read(size)
+
+    def _readline(self):
+        return self.inp.readline()
+
+    def has_input(self):
+        [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
+        if rl:
+            assert(rl[0] == self.inp.fileno())
+            return True
+        else:
+            return None
+
+
+def checked_reader(fd, n):
+    while n > 0:
+        rl, _, _ = select.select([fd], [], [])
+        assert(rl[0] == fd)
+        buf = os.read(fd, n)
+        if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
+        yield buf
+        n -= len(buf)
+
+
+MAX_PACKET = 128 * 1024
+def mux(p, outfd, outr, errr):
+    try:
+        fds = [outr, errr]
+        while p.poll() is None:
+            rl, _, _ = select.select(fds, [], [])
+            for fd in rl:
+                if fd == outr:
+                    buf = os.read(outr, MAX_PACKET)
+                    if not buf: break
+                    os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
+                elif fd == errr:
+                    buf = os.read(errr, 1024)
+                    if not buf: break
+                    os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
+    finally:
+        os.write(outfd, struct.pack('!IB', 0, 3))
+
+
+class DemuxConn(BaseConn):
+    """A helper class for bup's client-server protocol."""
+    def __init__(self, infd, outp):
+        BaseConn.__init__(self, outp)
+        # Anything that comes through before the sync string was not
+        # multiplexed and can be assumed to be debug/log before mux init.
+        tail = ''
+        while tail != 'BUPMUX':
+            b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
+            if not b:
+                raise IOError('demux: unexpected EOF during initialization')
+            tail += b
+            sys.stderr.write(tail[:-6])  # pre-mux log messages
+            tail = tail[-6:]
+        self.infd = infd
+        self.reader = None
+        self.buf = None
+        self.closed = False
+
+    def write(self, data):
+        self._load_buf(0)
+        BaseConn.write(self, data)
+
+    def _next_packet(self, timeout):
+        if self.closed: return False
+        rl, wl, xl = select.select([self.infd], [], [], timeout)
+        if not rl: return False
+        assert(rl[0] == self.infd)
+        ns = ''.join(checked_reader(self.infd, 5))
+        n, fdw = struct.unpack('!IB', ns)
+        assert(n <= MAX_PACKET)
+        if fdw == 1:
+            self.reader = checked_reader(self.infd, n)
+        elif fdw == 2:
+            for buf in checked_reader(self.infd, n):
+                sys.stderr.write(buf)
+        elif fdw == 3:
+            self.closed = True
+            debug2("DemuxConn: marked closed\n")
+        return True
+
+    def _load_buf(self, timeout):
+        if self.buf is not None:
+            return True
+        while not self.closed:
+            while not self.reader:
+                if not self._next_packet(timeout):
+                    return False
+            try:
+                self.buf = self.reader.next()
+                return True
+            except StopIteration:
+                self.reader = None
+        return False
+
+    def _read_parts(self, ix_fn):
+        while self._load_buf(None):
+            assert(self.buf is not None)
+            i = ix_fn(self.buf)
+            if i is None or i == len(self.buf):
+                yv = self.buf
+                self.buf = None
+            else:
+                yv = self.buf[:i]
+                self.buf = self.buf[i:]
+            yield yv
+            if i is not None:
+                break
+
+    def _readline(self):
+        def find_eol(buf):
+            try:
+                return buf.index('\n')+1
+            except ValueError:
+                return None
+        return ''.join(self._read_parts(find_eol))
+
+    def _read(self, size):
+        csize = [size]
+        def until_size(buf): # Closes on csize
+            if len(buf) < csize[0]:
+                csize[0] -= len(buf)
+                return None
+            else:
+                return csize[0]
+        return ''.join(self._read_parts(until_size))
+
+    def has_input(self):
+        return self._load_buf(0)
+
+
 def linereader(f):
     """Generate a list of input lines from 'f' without terminating newlines."""
     while 1:
@@ -286,29 +501,44 @@ def slashappend(s):
         return s
 
 
-def _mmap_do(f, sz, flags, prot):
+def _mmap_do(f, sz, flags, prot, close):
     if not sz:
         st = os.fstat(f.fileno())
         sz = st.st_size
+    if not sz:
+        # trying to open a zero-length map gives an error, but an empty
+        # string has all the same behaviour of a zero-length map, ie. it has
+        # no elements :)
+        return ''
     map = mmap.mmap(f.fileno(), sz, flags, prot)
-    f.close()  # map will persist beyond file close
+    if close:
+        f.close()  # map will persist beyond file close
     return map
 
 
-def mmap_read(f, sz = 0):
+def mmap_read(f, sz = 0, close=True):
     """Create a read-only memory mapped region on file 'f'.
-
     If sz is 0, the region will cover the entire file.
     """
-    return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
+    return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
 
 
-def mmap_readwrite(f, sz = 0):
+def mmap_readwrite(f, sz = 0, close=True):
     """Create a read-write memory mapped region on file 'f'.
+    If sz is 0, the region will cover the entire file.
+    """
+    return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
+                    close)
 
+
+def mmap_readwrite_private(f, sz = 0, close=True):
+    """Create a read-write memory mapped region on file 'f'.
     If sz is 0, the region will cover the entire file.
+    The map is private, which means the changes are never flushed back to the
+    file.
     """
-    return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
+    return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
+                    close)
 
 
 def parse_num(s):
@@ -355,11 +585,10 @@ def add_error(e):
     saved_errors.append(e)
     log('%-70s\n' % e)
 
-istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
-def progress(s):
-    """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
-    if istty:
-        log(s)
+
+def clear_errors():
+    global saved_errors
+    saved_errors = []
 
 
 def handle_ctrl_c():
@@ -402,6 +631,7 @@ def columnate(l, prefix):
         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
     return out
 
+
 def parse_date_or_fatal(str, fatal):
     """Parses the given date or calls Option.fatal().
     For now we expect a string that contains a float."""
@@ -413,137 +643,50 @@ def parse_date_or_fatal(str, fatal):
         return date
 
 
-class FSTime():
-    # Class to represent filesystem timestamps.  Use integer
-    # nanoseconds on platforms where we have the higher resolution
-    # lstat.  Use the native python stat representation (floating
-    # point seconds) otherwise.
-
-    def __cmp__(self, x):
-        return self._value.__cmp__(x._value)
-
-    def to_timespec(self):
-        """Return (s, ns) where ns is always non-negative
-        and t = s + ns / 10e8""" # metadata record rep (and libc rep)
-        s_ns = self.secs_nsecs()
-        if s_ns[0] > 0 or s_ns[1] >= 0:
-            return s_ns
-        return (s_ns[0] - 1, 10**9 + s_ns[1]) # ns is negative
-
-    if _helpers.lstat: # Use integer nanoseconds.
-
-        @staticmethod
-        def from_secs(secs):
-            ts = FSTime()
-            ts._value = int(secs * 10**9)
-            return ts
-
-        @staticmethod
-        def from_timespec(timespec):
-            ts = FSTime()
-            ts._value = timespec[0] * 10**9 + timespec[1]
-            return ts
-
-        @staticmethod
-        def from_stat_time(stat_time):
-            return FSTime.from_timespec(stat_time)
-
-        def approx_secs(self):
-            return self._value / 10e8;
-
-        def secs_nsecs(self):
-            "Return a (s, ns) pair: -1.5s -> (-1, -10**9 / 2)."
-            if self._value >= 0:
-                return (self._value / 10**9, self._value % 10**9)
-            abs_val = -self._value
-            return (- (abs_val / 10**9), - (abs_val % 10**9))
-
-    else: # Use python default floating-point seconds.
-
-        @staticmethod
-        def from_secs(secs):
-            ts = FSTime()
-            ts._value = secs
-            return ts
-
-        @staticmethod
-        def from_timespec(timespec):
-            ts = FSTime()
-            ts._value = timespec[0] + (timespec[1] / 10e8)
-            return ts
-
-        @staticmethod
-        def from_stat_time(stat_time):
-            ts = FSTime()
-            ts._value = stat_time
-            return ts
-
-        def approx_secs(self):
-            return self._value
-
-        def secs_nsecs(self):
-            "Return a (s, ns) pair: -1.5s -> (-1, -5**9)."
-            x = math.modf(self._value)
-            return (x[1], x[0] * 10**9)
-
-
-def lutime(path, times):
-    if _helpers.utimensat:
-        atime = times[0].to_timespec()
-        mtime = times[1].to_timespec()
-        return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime),
-                                  _helpers.AT_SYMLINK_NOFOLLOW)
-    else:
-        return None
+def strip_path(prefix, path):
+    """Strips a given prefix from a path.
 
+    First both paths are normalized.
 
-def utime(path, times):
-    if _helpers.utimensat:
-        atime = times[0].to_timespec()
-        mtime = times[1].to_timespec()
-        return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime), 0)
+    Raises an Exception if no prefix is given.
+    """
+    if prefix == None:
+        raise Exception('no path given')
+
+    normalized_prefix = os.path.realpath(prefix)
+    debug2("normalized_prefix: %s\n" % normalized_prefix)
+    normalized_path = os.path.realpath(path)
+    debug2("normalized_path: %s\n" % normalized_path)
+    if normalized_path.startswith(normalized_prefix):
+        return normalized_path[len(normalized_prefix):]
     else:
-        atime = times[0].approx_secs()
-        mtime = times[1].approx_secs()
-        os.utime(path, (atime, mtime))
+        return path
 
 
-class stat_result():
-    pass
+def strip_base_path(path, base_paths):
+    """Strips the base path from a given path.
 
 
-def lstat(path):
-    result = stat_result()
-    if _helpers.lstat:
-        st = _helpers.lstat(path)
-        (result.st_mode,
-         result.st_ino,
-         result.st_dev,
-         result.st_nlink,
-         result.st_uid,
-         result.st_gid,
-         result.st_rdev,
-         result.st_size,
-         atime,
-         mtime,
-         ctime) = st
-    else:
-        st = os.lstat(path)
-        result.st_mode = st.st_mode
-        result.st_ino = st.st_ino
-        result.st_dev = st.st_dev
-        result.st_nlink = st.st_nlink
-        result.st_uid = st.st_uid
-        result.st_gid = st.st_gid
-        result.st_rdev = st.st_rdev
-        result.st_size = st.st_size
-        atime = FSTime.from_stat_time(st.st_atime)
-        mtime = FSTime.from_stat_time(st.st_mtime)
-        ctime = FSTime.from_stat_time(st.st_ctime)
-    result.st_atime = FSTime.from_stat_time(atime)
-    result.st_mtime = FSTime.from_stat_time(mtime)
-    result.st_ctime = FSTime.from_stat_time(ctime)
-    return result
+    Determines the base path for the given string and then strips it
+    using strip_path().
+    Iterates over all base_paths from long to short, to prevent that
+    a too short base_path is removed.
+    """
+    normalized_path = os.path.realpath(path)
+    sorted_base_paths = sorted(base_paths, key=len, reverse=True)
+    for bp in sorted_base_paths:
+        if normalized_path.startswith(os.path.realpath(bp)):
+            return strip_path(bp, normalized_path)
+    return path
+
+
+def graft_path(graft_points, path):
+    normalized_path = os.path.realpath(path)
+    for graft_point in graft_points:
+        old_prefix, new_prefix = graft_point
+        if normalized_path.startswith(old_prefix):
+            return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
+    return normalized_path
 
 
 # hashlib is only available in python 2.5 or higher, but the 'sha' module
@@ -563,10 +706,12 @@ def version_date():
     """Format bup's version date string for output."""
     return _version.DATE.split(' ')[0]
 
+
 def version_commit():
     """Get the commit hash of bup's current version."""
     return _version.COMMIT
 
+
 def version_tag():
     """Format bup's version tag (the official version number).