]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/helpers.py
Check that all context managed objects are properly closed
[bup.git] / lib / bup / helpers.py
index f38435557f29c1642609205e3765a1bd64d5647b..fdc683bd7c2ca739699e760a0b1eda5c3a8f156d 100644 (file)
@@ -7,23 +7,41 @@ from ctypes import sizeof, c_void_p
 from math import floor
 from os import environ
 from subprocess import PIPE, Popen
-import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
-import hashlib, heapq, math, operator, time, grp, tempfile
+import sys, os, subprocess, errno, select, mmap, stat, re, struct
+import hashlib, heapq, math, operator, time, tempfile
 
 from bup import _helpers
 from bup import compat
-from bup.compat import argv_bytes, byte_int
+from bup.compat import argv_bytes, byte_int, nullcontext, pending_raise
 from bup.io import byte_stream, path_msg
 # This function should really be in helpers, not in bup.options.  But we
 # want options.py to be standalone so people can include it in other projects.
 from bup.options import _tty_width as tty_width
 
 
+buglvl = int(os.environ.get('BUP_DEBUG', 0))
+
+
 class Nonlocal:
     """Helper to deal with Python scoping issues"""
     pass
 
 
+def nullcontext_if_not(manager):
+    return manager if manager is not None else nullcontext()
+
+
+@contextmanager
+def finalized(enter_result=None, finalize=None):
+    assert finalize
+    try:
+        yield enter_result
+    except BaseException as ex:
+        with pending_raise(ex):
+            finalize(enter_result)
+    finalize(enter_result)
+
+
 sc_page_size = os.sysconf('SC_PAGE_SIZE')
 assert(sc_page_size > 0)
 
@@ -37,26 +55,6 @@ def last(iterable):
         pass
     return result
 
-
-def atoi(s):
-    """Convert s (ascii bytes) to an integer. Return 0 if s is not a number."""
-    try:
-        return int(s or b'0')
-    except ValueError:
-        return 0
-
-
-def atof(s):
-    """Convert s (ascii bytes) to a float. Return 0 if s is not a number."""
-    try:
-        return float(s or b'0')
-    except ValueError:
-        return 0
-
-
-buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
-
-
 try:
     _fdatasync = os.fdatasync
 except AttributeError:
@@ -165,8 +163,8 @@ def debug2(s):
         log(s)
 
 
-istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
-istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
+istty1 = os.isatty(1) or (int(os.environ.get('BUP_FORCE_TTY', 0)) & 1)
+istty2 = os.isatty(2) or (int(os.environ.get('BUP_FORCE_TTY', 0)) & 2)
 _last_progress = ''
 def progress(s):
     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
@@ -178,7 +176,7 @@ def progress(s):
 
 def qprogress(s):
     """Calls progress() only if we haven't printed progress in a while.
-    
+
     This avoids overloading the stderr buffer with excess junk.
     """
     global _last_prog
@@ -290,6 +288,8 @@ def quote(x):
     if isinstance(x, compat.str_type):
         return squote(x)
     assert False
+    # some versions of pylint get confused
+    return None
 
 def shstr(cmd):
     """Return a shell quoted string for cmd if it's a sequence, else cmd.
@@ -434,7 +434,7 @@ def hostname():
     """Get the FQDN of this machine."""
     global _hostname
     if not _hostname:
-        _hostname = socket.getfqdn().encode('iso-8859-1')
+        _hostname = _helpers.gethostname()
     return _hostname
 
 
@@ -445,7 +445,7 @@ def format_filesize(size):
         return "%d" % (size)
     exponent = int(math.log(size) // math.log(unit))
     size_prefix = "KMGTPE"[exponent - 1]
-    return "%.1f%s" % (size // math.pow(unit, exponent), size_prefix)
+    return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
 
 
 class NotOk(Exception):
@@ -454,11 +454,16 @@ class NotOk(Exception):
 
 class BaseConn:
     def __init__(self, outp):
+        self._base_closed = False
         self.outp = outp
 
     def close(self):
+        self._base_closed = True
         while self._read(65536): pass
 
+    def __del__(self):
+        assert self._base_closed
+
     def _read(self, size):
         raise NotImplementedError("Subclasses must implement _read")
 
@@ -578,13 +583,19 @@ class DemuxConn(BaseConn):
         # Anything that comes through before the sync string was not
         # multiplexed and can be assumed to be debug/log before mux init.
         tail = b''
+        stderr = byte_stream(sys.stderr)
         while tail != b'BUPMUX':
+            # Make sure to write all pre-BUPMUX output to stderr
             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
             if not b:
-                raise IOError('demux: unexpected EOF during initialization')
+                ex = IOError('demux: unexpected EOF during initialization')
+                with pending_raise(ex):
+                    stderr.write(tail)
+                    stderr.flush()
             tail += b
-            byte_stream(sys.stderr).write(tail[:-6])  # pre-mux log messages
+            stderr.write(tail[:-6])
             tail = tail[-6:]
+        stderr.flush()
         self.infd = infd
         self.reader = None
         self.buf = None
@@ -601,7 +612,13 @@ class DemuxConn(BaseConn):
         assert(rl[0] == self.infd)
         ns = b''.join(checked_reader(self.infd, 5))
         n, fdw = struct.unpack('!IB', ns)
-        assert(n <= MAX_PACKET)
+        if n > MAX_PACKET:
+            # assume that something went wrong and print stuff
+            ns += os.read(self.infd, 1024)
+            stderr = byte_stream(sys.stderr)
+            stderr.write(ns)
+            stderr.flush()
+            raise Exception("Connection broken")
         if fdw == 1:
             self.reader = checked_reader(self.infd, n)
         elif fdw == 2:
@@ -747,7 +764,7 @@ def _mmap_do(f, sz, flags, prot, close):
         # string has all the same behaviour of a zero-length map, ie. it has
         # no elements :)
         return ''
-    map = mmap.mmap(f.fileno(), sz, flags, prot)
+    map = compat.mmap(f.fileno(), sz, flags, prot)
     if close:
         f.close()  # map will persist beyond file close
     return map
@@ -803,15 +820,13 @@ if _mincore:
             _set_fmincore_chunk_size()
         pages_per_chunk = _fmincore_chunk_size // sc_page_size;
         page_count = (st.st_size + sc_page_size - 1) // sc_page_size;
-        chunk_count = page_count // _fmincore_chunk_size
-        if chunk_count < 1:
-            chunk_count = 1
+        chunk_count = (st.st_size + _fmincore_chunk_size - 1) // _fmincore_chunk_size
         result = bytearray(page_count)
         for ci in compat.range(chunk_count):
             pos = _fmincore_chunk_size * ci;
             msize = min(_fmincore_chunk_size, st.st_size - pos)
             try:
-                m = mmap.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
+                m = compat.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
             except mmap.error as ex:
                 if ex.errno == errno.EINVAL or ex.errno == errno.ENODEV:
                     # Perhaps the file was a pipe, i.e. "... | bup split ..."
@@ -915,7 +930,7 @@ def handle_ctrl_c():
         if exctype == KeyboardInterrupt:
             log('\nInterrupted.\n')
         else:
-            return oldhook(exctype, value, traceback)
+            oldhook(exctype, value, traceback)
     sys.excepthook = newhook
 
 
@@ -1130,7 +1145,7 @@ if _localtime:
 # module, which doesn't appear willing to ignore the extra items.
 if _localtime:
     def localtime(time):
-        return bup_time(*_helpers.localtime(floor(time)))
+        return bup_time(*_helpers.localtime(int(floor(time))))
     def utc_offset_str(t):
         """Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.
         If the current UTC offset does not represent an integer number