]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/bloom.py
Check that all context managed objects are properly closed
[bup.git] / lib / bup / bloom.py
index 05ebd714b696da13ee1172a24df85c1aae1517ae..b686127f294026911f442a8ff4ac7cbcd62a011b 100644 (file)
@@ -84,6 +84,7 @@ from __future__ import absolute_import
 import os, math, struct
 
 from bup import _helpers
+from bup.compat import pending_raise
 from bup.helpers import (debug1, debug2, log, mmap_read, mmap_readwrite,
                          mmap_readwrite_private, unlink)
 
@@ -106,13 +107,15 @@ bloom_add = _helpers.bloom_add
 class ShaBloom:
     """Wrapper which contains data from multiple index files. """
     def __init__(self, filename, f=None, readwrite=False, expected=-1):
+        self.closed = False
         self.name = filename
-        self.rwfile = None
+        self.readwrite = readwrite
+        self.file = None
         self.map = None
         assert(filename.endswith(b'.bloom'))
         if readwrite:
             assert(expected > 0)
-            self.rwfile = f = f or open(filename, 'r+b')
+            self.file = f = f or open(filename, 'r+b')
             f.seek(0)
 
             # Decide if we want to mmap() the pages as writable ('immediate'
@@ -135,13 +138,12 @@ class ShaBloom:
             self.delaywrite = expected > pages
             debug1('bloom: delaywrite=%r\n' % self.delaywrite)
             if self.delaywrite:
-                self.map = mmap_readwrite_private(self.rwfile, close=False)
+                self.map = mmap_readwrite_private(self.file, close=False)
             else:
-                self.map = mmap_readwrite(self.rwfile, close=False)
+                self.map = mmap_readwrite(self.file, close=False)
         else:
-            self.rwfile = None
-            f = f or open(filename, 'rb')
-            self.map = mmap_read(f)
+            self.file = f or open(filename, 'rb')
+            self.map = mmap_read(self.file)
         got = self.map[0:4]
         if got != b'BLOM':
             log('Warning: invalid BLOM header (%r) in %r\n' % (got, filename))
@@ -167,33 +169,46 @@ class ShaBloom:
             self.idxnames = []
 
     def _init_failed(self):
-        if self.map:
-            self.map = None
-        if self.rwfile:
-            self.rwfile.close()
-            self.rwfile = None
         self.idxnames = []
         self.bits = self.entries = 0
+        self.map, tmp_map = None, self.map
+        self.file, tmp_file = None, self.file
+        try:
+            if tmp_map:
+                tmp_map.close()
+        finally:  # This won't handle pending exceptions correctly in py2
+            if self.file:
+                tmp_file.close()
 
     def valid(self):
         return self.map and self.bits
 
+    def close(self):
+        self.closed = True
+        try:
+            if self.map and self.readwrite:
+                debug2("bloom: closing with %d entries\n" % self.entries)
+                self.map[12:16] = struct.pack('!I', self.entries)
+                if self.delaywrite:
+                    self.file.seek(0)
+                    self.file.write(self.map)
+                else:
+                    self.map.flush()
+                self.file.seek(16 + 2**self.bits)
+                if self.idxnames:
+                    self.file.write(b'\0'.join(self.idxnames))
+        finally:  # This won't handle pending exceptions correctly in py2
+            self._init_failed()
+
     def __del__(self):
-        self.close()
+        assert self.closed
 
-    def close(self):
-        if self.map and self.rwfile:
-            debug2("bloom: closing with %d entries\n" % self.entries)
-            self.map[12:16] = struct.pack('!I', self.entries)
-            if self.delaywrite:
-                self.rwfile.seek(0)
-                self.rwfile.write(self.map)
-            else:
-                self.map.flush()
-            self.rwfile.seek(16 + 2**self.bits)
-            if self.idxnames:
-                self.rwfile.write(b'\0'.join(self.idxnames))
-        self._init_failed()
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        with pending_raise(value, rethrow=False):
+            self.close()
 
     def pfalse_positive(self, additional=0):
         n = self.entries + additional