]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Update base_version to 0.34~ for 0.34 development
[bup.git] / lib / bup / helpers.py
index 35c88cdb42490c8acd9ac0dfdf0baa5d4de0a771..81770339e4955543a498b134f7ad0b329a83c9b6 100644 (file)
@@ -2,28 +2,56 @@
 
 from __future__ import absolute_import, division
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import ExitStack
 from ctypes import sizeof, c_void_p
 from math import floor
 from os import environ
 from subprocess import PIPE, Popen
-import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
-import hashlib, heapq, math, operator, time, grp, tempfile
+from tempfile import mkdtemp
+from shutil import rmtree
+import sys, os, subprocess, errno, select, mmap, stat, re, struct
+import hashlib, heapq, math, operator, time
 
 from bup import _helpers
-from bup import compat
-from bup.compat import byte_int
-from bup.io import path_msg
+from bup import io
+from bup.compat import argv_bytes, byte_int, nullcontext, pending_raise
+from bup.io import byte_stream, path_msg
 # This function should really be in helpers, not in bup.options.  But we
 # want options.py to be standalone so people can include it in other projects.
 from bup.options import _tty_width as tty_width
 
 
+buglvl = int(os.environ.get('BUP_DEBUG', 0))
+
+
 class Nonlocal:
     """Helper to deal with Python scoping issues"""
     pass
 
 
+def nullcontext_if_not(manager):
+    return manager if manager is not None else nullcontext()
+
+
+class finalized:
+    def __init__(self, enter_result=None, finalize=None):
+        assert finalize
+        self.finalize = finalize
+        self.enter_result = enter_result
+    def __enter__(self):
+        return self.enter_result
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.finalize(self.enter_result)
+
+def temp_dir(*args, **kwargs):
+    # This is preferable to tempfile.TemporaryDirectory because the
+    # latter uses @contextmanager, and so will always eventually be
+    # deleted if it's handed to an ExitStack, whenever the stack is
+    # gc'ed, even if you pop_all() (the new stack will also trigger
+    # the deletion) because
+    # https://github.com/python/cpython/issues/88458
+    return finalized(mkdtemp(*args, **kwargs), lambda x: rmtree(x))
+
 sc_page_size = os.sysconf('SC_PAGE_SIZE')
 assert(sc_page_size > 0)
 
@@ -37,26 +65,6 @@ def last(iterable):
         pass
     return result
 
-
-def atoi(s):
-    """Convert s (ascii bytes) to an integer. Return 0 if s is not a number."""
-    try:
-        return int(s or b'0')
-    except ValueError:
-        return 0
-
-
-def atof(s):
-    """Convert s (ascii bytes) to a float. Return 0 if s is not a number."""
-    try:
-        return float(s or b'0')
-    except ValueError:
-        return 0
-
-
-buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
-
-
 try:
     _fdatasync = os.fdatasync
 except AttributeError:
@@ -112,7 +120,7 @@ def lines_until_sentinel(f, sentinel, ex_type):
     # sentinel must end with \n and must contain only one \n
     while True:
         line = f.readline()
-        if not (line and line.endswith('\n')):
+        if not (line and line.endswith(b'\n')):
             raise ex_type('Hit EOF while reading line')
         if line == sentinel:
             return
@@ -165,8 +173,8 @@ def debug2(s):
         log(s)
 
 
-istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
-istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
+istty1 = os.isatty(1) or (int(os.environ.get('BUP_FORCE_TTY', 0)) & 1)
+istty2 = os.isatty(2) or (int(os.environ.get('BUP_FORCE_TTY', 0)) & 2)
 _last_progress = ''
 def progress(s):
     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
@@ -178,7 +186,7 @@ def progress(s):
 
 def qprogress(s):
     """Calls progress() only if we haven't printed progress in a while.
-    
+
     This avoids overloading the stderr buffer with excess junk.
     """
     global _last_prog
@@ -287,9 +295,11 @@ def squote(x):
 def quote(x):
     if isinstance(x, bytes):
         return bquote(x)
-    if isinstance(x, compat.str_type):
+    if isinstance(x, str):
         return squote(x)
     assert False
+    # some versions of pylint get confused
+    return None
 
 def shstr(cmd):
     """Return a shell quoted string for cmd if it's a sequence, else cmd.
@@ -300,11 +310,11 @@ def shstr(cmd):
     call() and friends.  e.g. log(shstr(cmd)); call(cmd)
 
     """
-    if isinstance(cmd, (bytes, compat.str_type)):
+    if isinstance(cmd, (bytes, str)):
         return cmd
     elif all(isinstance(x, bytes) for x in cmd):
         return b' '.join(map(bquote, cmd))
-    elif all(isinstance(x, compat.str_type) for x in cmd):
+    elif all(isinstance(x, str) for x in cmd):
         return ' '.join(map(squote, cmd))
     raise TypeError('unsupported shstr argument: ' + repr(cmd))
 
@@ -317,36 +327,33 @@ def exo(cmd,
         stderr=None,
         shell=False,
         check=True,
-        preexec_fn=None):
+        preexec_fn=None,
+        close_fds=True):
     if input:
         assert stdin in (None, PIPE)
         stdin = PIPE
     p = Popen(cmd,
               stdin=stdin, stdout=PIPE, stderr=stderr,
               shell=shell,
-              preexec_fn=preexec_fn)
+              preexec_fn=preexec_fn,
+              close_fds=close_fds)
     out, err = p.communicate(input)
     if check and p.returncode != 0:
-        raise Exception('subprocess %r failed with status %d, stderr: %r'
-                        % (b' '.join(map(quote, cmd)), p.returncode, err))
+        raise Exception('subprocess %r failed with status %d%s'
+                        % (b' '.join(map(quote, cmd)), p.returncode,
+                           ', stderr: %r' % err if err else ''))
     return out, err, p
 
 def readpipe(argv, preexec_fn=None, shell=False):
     """Run a subprocess and return its output."""
-    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn,
-                         shell=shell)
-    out, err = p.communicate()
-    if p.returncode != 0:
-        raise Exception('subprocess %r failed with status %d'
-                        % (b' '.join(argv), p.returncode))
-    return out
+    return exo(argv, preexec_fn=preexec_fn, shell=shell)[0]
 
 
 def _argmax_base(command):
     base_size = 2048
     for c in command:
         base_size += len(command) + 1
-    for k, v in compat.items(environ):
+    for k, v in environ.items():
         base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
     return base_size
 
@@ -437,7 +444,7 @@ def hostname():
     """Get the FQDN of this machine."""
     global _hostname
     if not _hostname:
-        _hostname = socket.getfqdn().encode('iso-8859-1')
+        _hostname = _helpers.gethostname()
     return _hostname
 
 
@@ -448,7 +455,7 @@ def format_filesize(size):
         return "%d" % (size)
     exponent = int(math.log(size) // math.log(unit))
     size_prefix = "KMGTPE"[exponent - 1]
-    return "%.1f%s" % (size // math.pow(unit, exponent), size_prefix)
+    return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
 
 
 class NotOk(Exception):
@@ -457,16 +464,33 @@ class NotOk(Exception):
 
 class BaseConn:
     def __init__(self, outp):
+        self._base_closed = False
         self.outp = outp
 
     def close(self):
-        while self._read(65536): pass
+        self._base_closed = True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        with pending_raise(exc_value, rethrow=False):
+            self.close()
+
+    def __del__(self):
+        assert self._base_closed
+
+    def _read(self, size):
+        raise NotImplementedError("Subclasses must implement _read")
 
     def read(self, size):
         """Read 'size' bytes from input stream."""
         self.outp.flush()
         return self._read(size)
 
+    def _readline(self, size):
+        raise NotImplementedError("Subclasses must implement _readline")
+
     def readline(self):
         """Read from input stream until a newline is found."""
         self.outp.flush()
@@ -479,27 +503,27 @@ class BaseConn:
 
     def has_input(self):
         """Return true if input stream is readable."""
-        raise NotImplemented("Subclasses must implement has_input")
+        raise NotImplementedError("Subclasses must implement has_input")
 
     def ok(self):
         """Indicate end of output from last sent command."""
-        self.write('\nok\n')
+        self.write(b'\nok\n')
 
     def error(self, s):
         """Indicate server error to the client."""
-        s = re.sub(r'\s+', ' ', str(s))
-        self.write('\nerror %s\n' % s)
+        s = re.sub(br'\s+', b' ', s)
+        self.write(b'\nerror %s\n' % s)
 
     def _check_ok(self, onempty):
         self.outp.flush()
-        rl = ''
+        rl = b''
         for rl in linereader(self):
             #log('%d got line: %r\n' % (os.getpid(), rl))
             if not rl:  # empty line
                 continue
-            elif rl == 'ok':
+            elif rl == b'ok':
                 return None
-            elif rl.startswith('error '):
+            elif rl.startswith(b'error '):
                 #log('client: error: %s\n' % rl[6:])
                 return NotOk(rl[6:])
             else:
@@ -574,14 +598,20 @@ class DemuxConn(BaseConn):
         BaseConn.__init__(self, outp)
         # Anything that comes through before the sync string was not
         # multiplexed and can be assumed to be debug/log before mux init.
-        tail = ''
-        while tail != 'BUPMUX':
+        tail = b''
+        stderr = byte_stream(sys.stderr)
+        while tail != b'BUPMUX':
+            # Make sure to write all pre-BUPMUX output to stderr
             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
             if not b:
-                raise IOError('demux: unexpected EOF during initialization')
+                ex = IOError('demux: unexpected EOF during initialization')
+                with pending_raise(ex):
+                    stderr.write(tail)
+                    stderr.flush()
             tail += b
-            sys.stderr.write(tail[:-6])  # pre-mux log messages
+            stderr.write(tail[:-6])
             tail = tail[-6:]
+        stderr.flush()
         self.infd = infd
         self.reader = None
         self.buf = None
@@ -596,14 +626,20 @@ class DemuxConn(BaseConn):
         rl, wl, xl = select.select([self.infd], [], [], timeout)
         if not rl: return False
         assert(rl[0] == self.infd)
-        ns = ''.join(checked_reader(self.infd, 5))
+        ns = b''.join(checked_reader(self.infd, 5))
         n, fdw = struct.unpack('!IB', ns)
-        assert(n <= MAX_PACKET)
+        if n > MAX_PACKET:
+            # assume that something went wrong and print stuff
+            ns += os.read(self.infd, 1024)
+            stderr = byte_stream(sys.stderr)
+            stderr.write(ns)
+            stderr.flush()
+            raise Exception("Connection broken")
         if fdw == 1:
             self.reader = checked_reader(self.infd, n)
         elif fdw == 2:
             for buf in checked_reader(self.infd, n):
-                sys.stderr.write(buf)
+                byte_stream(sys.stderr).write(buf)
         elif fdw == 3:
             self.closed = True
             debug2("DemuxConn: marked closed\n")
@@ -640,10 +676,10 @@ class DemuxConn(BaseConn):
     def _readline(self):
         def find_eol(buf):
             try:
-                return buf.index('\n')+1
+                return buf.index(b'\n')+1
             except ValueError:
                 return None
-        return ''.join(self._read_parts(find_eol))
+        return b''.join(self._read_parts(find_eol))
 
     def _read(self, size):
         csize = [size]
@@ -653,7 +689,7 @@ class DemuxConn(BaseConn):
                 return None
             else:
                 return csize[0]
-        return ''.join(self._read_parts(until_size))
+        return b''.join(self._read_parts(until_size))
 
     def has_input(self):
         return self._load_buf(0)
@@ -690,40 +726,52 @@ def chunkyreader(f, count = None):
             yield b
 
 
-@contextmanager
-def atomically_replaced_file(name, mode='w', buffering=-1):
-    """Yield a file that will be atomically renamed name when leaving the block.
-
-    This contextmanager yields an open file object that is backed by a
-    temporary file which will be renamed (atomically) to the target
-    name if everything succeeds.
-
-    The mode and buffering arguments are handled exactly as with open,
-    and the yielded file will have very restrictive permissions, as
-    per mkstemp.
-
-    E.g.::
-
-        with atomically_replaced_file('foo.txt', 'w') as f:
-            f.write('hello jack.')
-
-    """
-
-    (ffd, tempname) = tempfile.mkstemp(dir=os.path.dirname(name),
-                                       text=('b' not in mode))
-    try:
-        try:
-            f = os.fdopen(ffd, mode, buffering)
-        except:
-            os.close(ffd)
-            raise
-        try:
-            yield f
-        finally:
-            f.close()
-        os.rename(tempname, name)
-    finally:
-        unlink(tempname)  # nonexistant file is ignored
+class atomically_replaced_file:
+    def __init__(self, path, mode='w', buffering=-1):
+        """Return a context manager supporting the atomic replacement of a file.
+
+        The context manager yields an open file object that has been
+        created in a mkdtemp-style temporary directory in the same
+        directory as the path.  The temporary file will be renamed to
+        the target path (atomically if the platform allows it) if
+        there are no exceptions, and the temporary directory will
+        always be removed.  Calling cancel() will prevent the
+        replacement.
+
+        The file object will have a name attribute containing the
+        file's path, and the mode and buffering arguments will be
+        handled exactly as with open().  The resulting permissions
+        will also match those produced by open().
+
+        E.g.::
+
+          with atomically_replaced_file('foo.txt', 'w') as f:
+              f.write('hello jack.')
+
+        """
+        assert 'w' in mode
+        self.path = path
+        self.mode = mode
+        self.buffering = buffering
+        self.canceled = False
+        self.tmp_path = None
+        self.cleanup = ExitStack()
+    def __enter__(self):
+        with self.cleanup:
+            parent, name = os.path.split(self.path)
+            tmpdir = self.cleanup.enter_context(temp_dir(dir=parent,
+                                                         prefix=name + b'-'))
+            self.tmp_path = tmpdir + b'/pending'
+            f = open(self.tmp_path, mode=self.mode, buffering=self.buffering)
+            f = self.cleanup.enter_context(f)
+            self.cleanup = self.cleanup.pop_all()
+            return f
+    def __exit__(self, exc_type, exc_value, traceback):
+        with self.cleanup:
+            if not (self.canceled or exc_type):
+                os.rename(self.tmp_path, self.path)
+    def cancel(self):
+        self.canceled = True
 
 
 def slashappend(s):
@@ -744,7 +792,7 @@ def _mmap_do(f, sz, flags, prot, close):
         # string has all the same behaviour of a zero-length map, ie. it has
         # no elements :)
         return ''
-    map = mmap.mmap(f.fileno(), sz, flags, prot)
+    map = io.mmap(f.fileno(), sz, flags, prot)
     if close:
         f.close()  # map will persist beyond file close
     return map
@@ -800,26 +848,25 @@ if _mincore:
             _set_fmincore_chunk_size()
         pages_per_chunk = _fmincore_chunk_size // sc_page_size;
         page_count = (st.st_size + sc_page_size - 1) // sc_page_size;
-        chunk_count = page_count // _fmincore_chunk_size
-        if chunk_count < 1:
-            chunk_count = 1
+        chunk_count = (st.st_size + _fmincore_chunk_size - 1) // _fmincore_chunk_size
         result = bytearray(page_count)
-        for ci in compat.range(chunk_count):
+        for ci in range(chunk_count):
             pos = _fmincore_chunk_size * ci;
             msize = min(_fmincore_chunk_size, st.st_size - pos)
             try:
-                m = mmap.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
+                m = io.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
             except mmap.error as ex:
-                if ex.errno == errno.EINVAL or ex.errno == errno.ENODEV:
+                if ex.errno in (errno.EINVAL, errno.ENODEV):
                     # Perhaps the file was a pipe, i.e. "... | bup split ..."
                     return None
                 raise ex
-            try:
-                _mincore(m, msize, 0, result, ci * pages_per_chunk)
-            except OSError as ex:
-                if ex.errno == errno.ENOSYS:
-                    return None
-                raise
+            with m:
+                try:
+                    _mincore(m, msize, 0, result, ci * pages_per_chunk)
+                except OSError as ex:
+                    if ex.errno == errno.ENOSYS:
+                        return None
+                    raise
         return result
 
 
@@ -912,7 +959,7 @@ def handle_ctrl_c():
         if exctype == KeyboardInterrupt:
             log('\nInterrupted.\n')
         else:
-            return oldhook(exctype, value, traceback)
+            oldhook(exctype, value, traceback)
     sys.excepthook = newhook
 
 
@@ -922,8 +969,11 @@ def columnate(l, prefix):
     The number of columns is determined automatically based on the string
     lengths.
     """
+    binary = isinstance(prefix, bytes)
+    nothing = b'' if binary else ''
+    nl = b'\n' if binary else '\n'
     if not l:
-        return ""
+        return nothing
     l = l[:]
     clen = max(len(s) for s in l)
     ncols = (tty_width() - len(prefix)) // (clen + 2)
@@ -932,13 +982,14 @@ def columnate(l, prefix):
         clen = 0
     cols = []
     while len(l) % ncols:
-        l.append('')
+        l.append(nothing)
     rows = len(l) // ncols
-    for s in compat.range(0, len(l), rows):
+    for s in range(0, len(l), rows):
         cols.append(l[s:s+rows])
-    out = ''
+    out = nothing
+    fmt = b'%-*s' if binary else '%-*s'
     for row in zip(*cols):
-        out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
+        out += prefix + nothing.join((fmt % (clen+2, s)) for s in row) + nl
     return out
 
 
@@ -960,12 +1011,12 @@ def parse_excludes(options, fatal):
     for flag in options:
         (option, parameter) = flag
         if option == '--exclude':
-            excluded_paths.append(resolve_parent(parameter))
+            excluded_paths.append(resolve_parent(argv_bytes(parameter)))
         elif option == '--exclude-from':
             try:
-                f = open(resolve_parent(parameter))
+                f = open(resolve_parent(argv_bytes(parameter)), 'rb')
             except IOError as e:
-                raise fatal("couldn't read %s" % parameter)
+                raise fatal("couldn't read %r" % parameter)
             for exclude_path in f.readlines():
                 # FIXME: perhaps this should be rstrip('\n')
                 exclude_path = resolve_parent(exclude_path.strip())
@@ -983,22 +1034,22 @@ def parse_rx_excludes(options, fatal):
         (option, parameter) = flag
         if option == '--exclude-rx':
             try:
-                excluded_patterns.append(re.compile(parameter))
+                excluded_patterns.append(re.compile(argv_bytes(parameter)))
             except re.error as ex:
-                fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
+                fatal('invalid --exclude-rx pattern (%r): %s' % (parameter, ex))
         elif option == '--exclude-rx-from':
             try:
-                f = open(resolve_parent(parameter))
+                f = open(resolve_parent(parameter), 'rb')
             except IOError as e:
-                raise fatal("couldn't read %s" % parameter)
+                raise fatal("couldn't read %r" % parameter)
             for pattern in f.readlines():
-                spattern = pattern.rstrip('\n')
+                spattern = pattern.rstrip(b'\n')
                 if not spattern:
                     continue
                 try:
                     excluded_patterns.append(re.compile(spattern))
                 except re.error as ex:
-                    fatal('invalid --exclude-rx pattern (%s): %s' % (spattern, ex))
+                    fatal('invalid --exclude-rx pattern (%r): %s' % (spattern, ex))
     return excluded_patterns
 
 
@@ -1123,7 +1174,7 @@ if _localtime:
 # module, which doesn't appear willing to ignore the extra items.
 if _localtime:
     def localtime(time):
-        return bup_time(*_helpers.localtime(floor(time)))
+        return bup_time(*_helpers.localtime(int(floor(time))))
     def utc_offset_str(t):
         """Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.
         If the current UTC offset does not represent an integer number
@@ -1165,20 +1216,20 @@ def valid_save_name(name):
     return True
 
 
-_period_rx = re.compile(r'^([0-9]+)(s|min|h|d|w|m|y)$')
+_period_rx = re.compile(br'^([0-9]+)(s|min|h|d|w|m|y)$')
 
 def period_as_secs(s):
-    if s == 'forever':
+    if s == b'forever':
         return float('inf')
     match = _period_rx.match(s)
     if not match:
         return None
     mag = int(match.group(1))
     scale = match.group(2)
-    return mag * {'s': 1,
-                  'min': 60,
-                  'h': 60 * 60,
-                  'd': 60 * 60 * 24,
-                  'w': 60 * 60 * 24 * 7,
-                  'm': 60 * 60 * 24 * 31,
-                  'y': 60 * 60 * 24 * 366}[scale]
+    return mag * {b's': 1,
+                  b'min': 60,
+                  b'h': 60 * 60,
+                  b'd': 60 * 60 * 24,
+                  b'w': 60 * 60 * 24 * 7,
+                  b'm': 60 * 60 * 24 * 31,
+                  b'y': 60 * 60 * 24 * 366}[scale]