]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/compat.py
Detect failures to explicitly close mmaps in py3 too
[bup.git] / lib / bup / compat.py
index fbeee4a478fd3da42976d8b50a05f03558a2996d..00575036e27cc46f7fb6b3408e273831526b9daa 100644 (file)
@@ -13,7 +13,7 @@ py3 = py_maj >= 3
 if py3:
 
     # pylint: disable=unused-import
-    from contextlib import nullcontext
+    from contextlib import ExitStack, nullcontext
     from os import environb as environ
     from os import fsdecode, fsencode
     from shlex import quote
@@ -47,13 +47,17 @@ if py3:
 
         """
         def __init__(self, ex, rethrow=True):
+            self.closed = False
             self.ex = ex
             self.rethrow = rethrow
         def __enter__(self):
             return None
         def __exit__(self, exc_type, exc_value, traceback):
+            self.closed = True
             if not exc_type and self.ex and self.rethrow:
                 raise self.ex
+        def __del__(self):
+            assert self.closed
 
     def items(x):
         return x.items()
@@ -144,12 +148,14 @@ else:  # Python 2
 
         """
         def __init__(self, ex, rethrow=True):
+            self.closed = False
             self.ex = ex
             self.rethrow = rethrow
         def __enter__(self):
             if self.ex:
                 add_ex_tb(self.ex)
         def __exit__(self, exc_type, exc_value, traceback):
+            self.closed = True
             if exc_value:
                 if self.ex:
                     add_ex_tb(exc_value)
@@ -157,6 +163,8 @@ else:  # Python 2
                 return
             if self.rethrow and self.ex:
                 raise self.ex
+        def __del__(self):
+            assert self.closed
 
     def dump_traceback(ex):
         stack = [ex]
@@ -174,6 +182,31 @@ else:  # Python 2
             tb = getattr(ex, '__traceback__', None)
             print_exception(type(ex), ex, tb)
 
+    class ExitStack:
+        def __init__(self):
+            self.contexts = []
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, value_type, value, traceback):
+            init_value = value
+            for ctx in reversed(self.contexts):
+                try:
+                    ctx.__exit__(value_type, value, traceback)
+                except BaseException as ex:
+                    add_ex_tb(ex)
+                    if value:
+                        add_ex_ctx(ex, value)
+                    value_type = type(ex)
+                    value = ex
+                    traceback = ex.__traceback__
+            if value is not init_value:
+                raise value
+
+        def enter_context(self, x):
+            self.contexts.append(x)
+
     def items(x):
         return x.iteritems()