]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/hlinkdb.py
hlinkdb: respect umask/sgid/etc. when creating new db
[bup.git] / lib / bup / hlinkdb.py
index 5c1ec0099621b46130217c4622937b7617473e7c..f7e5d72153b6d72e7e5934296eb9169bfe4a12d8 100644 (file)
@@ -1,92 +1,80 @@
-import cPickle, errno, os, tempfile
 
-from bup import compat
+from contextlib import ExitStack
+import os, pickle
+
+from bup.helpers import atomically_replaced_file, unlink
+
+
+def pickle_load(filename):
+    try:
+        f = open(filename, 'rb')
+    except FileNotFoundError:
+        return None
+    with f:
+        return pickle.load(f, encoding='bytes')
+
 
 class Error(Exception):
     pass
 
 class HLinkDB:
     def __init__(self, filename):
+        self.closed = False
+        self._cleanup = ExitStack()
+        self._filename = filename
+        self._pending_save = None
         # Map a "dev:ino" node to a list of paths associated with that node.
-        self._node_paths = {}
-        # Map a path to a "dev:ino" node.
+        self._node_paths = pickle_load(filename) or {}
+        # Map a path to a "dev:ino" node (a reverse hard link index).
         self._path_node = {}
-        self._filename = filename
-        self._save_prepared = None
-        self._tmpname = None
-        f = None
-        try:
-            f = open(filename, 'r')
-        except IOError as e:
-            if e.errno == errno.ENOENT:
-                pass
-            else:
-                raise
-        if f:
-            try:
-                self._node_paths = cPickle.load(f)
-            finally:
-                f.close()
-                f = None
-        # Set up the reverse hard link index.
-        for node, paths in compat.items(self._node_paths):
+        for node, paths in self._node_paths.items():
             for path in paths:
                 self._path_node[path] = node
 
     def prepare_save(self):
         """ Commit all of the relevant data to disk.  Do as much work
         as possible without actually making the changes visible."""
-        if self._save_prepared:
+        if self._pending_save:
             raise Error('save of %r already in progress' % self._filename)
-        if self._node_paths:
-            (dir, name) = os.path.split(self._filename)
-            (ffd, self._tmpname) = tempfile.mkstemp('.tmp', name, dir)
-            try:
-                try:
-                    f = os.fdopen(ffd, 'wb', 65536)
-                except:
-                    os.close(ffd)
-                    raise
-                try:
-                    cPickle.dump(self._node_paths, f, 2)
-                finally:
-                    f.close()
-                    f = None
-            except:
-                tmpname = self._tmpname
-                self._tmpname = None
-                os.unlink(tmpname)
-                raise
-        self._save_prepared = True
+        with self._cleanup:
+            if self._node_paths:
+                dir, name = os.path.split(self._filename)
+                self._pending_save = atomically_replaced_file(self._filename,
+                                                              mode='wb',
+                                                              buffering=65536)
+                with self._cleanup.enter_context(self._pending_save) as f:
+                    pickle.dump(self._node_paths, f, 2)
+            else: # No data
+                self._cleanup.callback(lambda: unlink(self._filename))
+            self._cleanup = self._cleanup.pop_all()
 
     def commit_save(self):
-        if not self._save_prepared:
+        self.closed = True
+        if self._node_paths and not self._pending_save:
             raise Error('cannot commit save of %r; no save prepared'
                         % self._filename)
-        if self._tmpname:
-            os.rename(self._tmpname, self._filename)
-            self._tmpname = None
-        else: # No data -- delete _filename if it exists.
-            try:
-                os.unlink(self._filename)
-            except OSError as e:
-                if e.errno == errno.ENOENT:
-                    pass
-                else:
-                    raise
-        self._save_prepared = None
+        self._cleanup.close()
+        self._pending_save = None
 
     def abort_save(self):
-        if self._tmpname:
-            os.unlink(self._tmpname)
-            self._tmpname = None
+        self.closed = True
+        with self._cleanup:
+            if self._pending_save:
+                self._pending_save.cancel()
+        self._pending_save = None
 
-    def __del__(self):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
         self.abort_save()
 
+    def __del__(self):
+        assert self.closed
+
     def add_path(self, path, dev, ino):
         # Assume path is new.
-        node = '%s:%s' % (dev, ino)
+        node = b'%d:%d' % (dev, ino)
         self._path_node[path] = node
         link_paths = self._node_paths.get(node)
         if link_paths and path not in link_paths:
@@ -114,5 +102,5 @@ class HLinkDB:
             del self._path_node[path]
 
     def node_paths(self, dev, ino):
-        node = '%s:%s' % (dev, ino)
+        node = b'%d:%d' % (dev, ino)
         return self._node_paths[node]