]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/hlinkdb.py
hlinkdb: respect umask/sgid/etc. when creating new db
[bup.git] / lib / bup / hlinkdb.py
index 102ee3bd26096d150d4b18312e186e014bc710e9..f7e5d72153b6d72e7e5934296eb9169bfe4a12d8 100644 (file)
@@ -1,11 +1,17 @@
 
-import errno, os, pickle, tempfile
+from contextlib import ExitStack
+import os, pickle
 
-from bup.compat import pending_raise
+from bup.helpers import atomically_replaced_file, unlink
 
 
-def pickle_load(f):
-    return pickle.load(f, encoding='bytes')
+def pickle_load(filename):
+    try:
+        f = open(filename, 'rb')
+    except FileNotFoundError:
+        return None
+    with f:
+        return pickle.load(f, encoding='bytes')
 
 
 class Error(Exception):
@@ -14,28 +20,13 @@ class Error(Exception):
 class HLinkDB:
     def __init__(self, filename):
         self.closed = False
+        self._cleanup = ExitStack()
+        self._filename = filename
+        self._pending_save = None
         # Map a "dev:ino" node to a list of paths associated with that node.
-        self._node_paths = {}
-        # Map a path to a "dev:ino" node.
+        self._node_paths = pickle_load(filename) or {}
+        # Map a path to a "dev:ino" node (a reverse hard link index).
         self._path_node = {}
-        self._filename = filename
-        self._save_prepared = None
-        self._tmpname = None
-        f = None
-        try:
-            f = open(filename, 'rb')
-        except IOError as e:
-            if e.errno == errno.ENOENT:
-                pass
-            else:
-                raise
-        if f:
-            try:
-                self._node_paths = pickle_load(f)
-            finally:
-                f.close()
-                f = None
-        # Set up the reverse hard link index.
         for node, paths in self._node_paths.items():
             for path in paths:
                 self._path_node[path] = node
@@ -43,59 +34,40 @@ class HLinkDB:
     def prepare_save(self):
         """ Commit all of the relevant data to disk.  Do as much work
         as possible without actually making the changes visible."""
-        if self._save_prepared:
+        if self._pending_save:
             raise Error('save of %r already in progress' % self._filename)
-        if self._node_paths:
-            (dir, name) = os.path.split(self._filename)
-            (ffd, self._tmpname) = tempfile.mkstemp(b'.tmp', name, dir)
-            try:
-                try:
-                    f = os.fdopen(ffd, 'wb', 65536)
-                except:
-                    os.close(ffd)
-                    raise
-                try:
+        with self._cleanup:
+            if self._node_paths:
+                dir, name = os.path.split(self._filename)
+                self._pending_save = atomically_replaced_file(self._filename,
+                                                              mode='wb',
+                                                              buffering=65536)
+                with self._cleanup.enter_context(self._pending_save) as f:
                     pickle.dump(self._node_paths, f, 2)
-                finally:
-                    f.close()
-                    f = None
-            except:
-                tmpname = self._tmpname
-                self._tmpname = None
-                os.unlink(tmpname)
-                raise
-        self._save_prepared = True
+            else: # No data
+                self._cleanup.callback(lambda: unlink(self._filename))
+            self._cleanup = self._cleanup.pop_all()
 
     def commit_save(self):
         self.closed = True
-        if not self._save_prepared:
+        if self._node_paths and not self._pending_save:
             raise Error('cannot commit save of %r; no save prepared'
                         % self._filename)
-        if self._tmpname:
-            os.rename(self._tmpname, self._filename)
-            self._tmpname = None
-        else: # No data -- delete _filename if it exists.
-            try:
-                os.unlink(self._filename)
-            except OSError as e:
-                if e.errno == errno.ENOENT:
-                    pass
-                else:
-                    raise
-        self._save_prepared = None
+        self._cleanup.close()
+        self._pending_save = None
 
     def abort_save(self):
         self.closed = True
-        if self._tmpname:
-            os.unlink(self._tmpname)
-            self._tmpname = None
+        with self._cleanup:
+            if self._pending_save:
+                self._pending_save.cancel()
+        self._pending_save = None
 
     def __enter__(self):
         return self
 
     def __exit__(self, type, value, traceback):
-        with pending_raise(value, rethrow=True):
-            self.abort_save()
+        self.abort_save()
 
     def __del__(self):
         assert self.closed