]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/git.py
Update base_version to 0.34~ for 0.34 development
[bup.git] / lib / bup / git.py
index 4c1a95c52a985e5eaee1bb36ef4ca61651559c78..d6a745c02d7d9d355370bf23c85a9d5c645273e1 100644 (file)
@@ -4,19 +4,19 @@ interact with the Git data structures.
 """
 
 from __future__ import absolute_import, print_function
 """
 
 from __future__ import absolute_import, print_function
-import os, sys, zlib, subprocess, struct, stat, re, tempfile, glob
+import os, sys, zlib, subprocess, struct, stat, re, glob
 from array import array
 from binascii import hexlify, unhexlify
 from collections import namedtuple
 from array import array
 from binascii import hexlify, unhexlify
 from collections import namedtuple
+from contextlib import ExitStack
 from itertools import islice
 from itertools import islice
+from shutil import rmtree
 
 from bup import _helpers, hashsplit, path, midx, bloom, xstat
 from bup.compat import (buffer,
                         byte_int, bytes_from_byte, bytes_from_uint,
                         environ,
 
 from bup import _helpers, hashsplit, path, midx, bloom, xstat
 from bup.compat import (buffer,
                         byte_int, bytes_from_byte, bytes_from_uint,
                         environ,
-                        items,
                         pending_raise,
                         pending_raise,
-                        range,
                         reraise)
 from bup.io import path_msg
 from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
                         reraise)
 from bup.io import path_msg
 from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
@@ -27,7 +27,9 @@ from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
                          merge_dict,
                          merge_iter,
                          mmap_read, mmap_readwrite,
                          merge_dict,
                          merge_iter,
                          mmap_read, mmap_readwrite,
+                         nullcontext_if_not,
                          progress, qprogress, stat_if_exists,
                          progress, qprogress, stat_if_exists,
+                         temp_dir,
                          unlink,
                          utc_offset_str)
 
                          unlink,
                          utc_offset_str)
 
@@ -36,7 +38,7 @@ verbose = 0
 repodir = None  # The default repository, once initialized
 
 _typemap =  {b'blob': 3, b'tree': 2, b'commit': 1, b'tag': 4}
 repodir = None  # The default repository, once initialized
 
 _typemap =  {b'blob': 3, b'tree': 2, b'commit': 1, b'tag': 4}
-_typermap = {v: k for k, v in items(_typemap)}
+_typermap = {v: k for k, v in _typemap.items()}
 
 
 _total_searches = 0
 
 
 _total_searches = 0
@@ -370,10 +372,7 @@ def _decode_packobj(buf):
     return (type, zlib.decompress(buf[i+1:]))
 
 
     return (type, zlib.decompress(buf[i+1:]))
 
 
-class PackIdx:
-    def __init__(self):
-        assert(0)
-
+class PackIdx(object):
     def find_offset(self, hash):
         """Get the offset of an object inside the index file."""
         idx = self._idx_from_hash(hash)
     def find_offset(self, hash):
         """Get the offset of an object inside the index file."""
         idx = self._idx_from_hash(hash)
@@ -412,6 +411,8 @@ class PackIdx:
 class PackIdxV1(PackIdx):
     """Object representation of a Git pack index (version 1) file."""
     def __init__(self, filename, f):
 class PackIdxV1(PackIdx):
     """Object representation of a Git pack index (version 1) file."""
     def __init__(self, filename, f):
+        super(PackIdxV1, self).__init__()
+        self.closed = False
         self.name = filename
         self.idxnames = [self.name]
         self.map = mmap_read(f)
         self.name = filename
         self.idxnames = [self.name]
         self.map = mmap_read(f)
@@ -451,15 +452,21 @@ class PackIdxV1(PackIdx):
             yield self.map[ofs : ofs + 20]
 
     def close(self):
             yield self.map[ofs : ofs + 20]
 
     def close(self):
+        self.closed = True
         if self.map is not None:
             self.shatable = None
             self.map.close()
             self.map = None
 
         if self.map is not None:
             self.shatable = None
             self.map.close()
             self.map = None
 
+    def __del__(self):
+        assert self.closed
+
 
 class PackIdxV2(PackIdx):
     """Object representation of a Git pack index (version 2) file."""
     def __init__(self, filename, f):
 
 class PackIdxV2(PackIdx):
     """Object representation of a Git pack index (version 2) file."""
     def __init__(self, filename, f):
+        super(PackIdxV2, self).__init__()
+        self.closed = False
         self.name = filename
         self.idxnames = [self.name]
         self.map = mmap_read(f)
         self.name = filename
         self.idxnames = [self.name]
         self.map = mmap_read(f)
@@ -507,30 +514,62 @@ class PackIdxV2(PackIdx):
             yield self.map[ofs : ofs + 20]
 
     def close(self):
             yield self.map[ofs : ofs + 20]
 
     def close(self):
+        self.closed = True
         if self.map is not None:
             self.shatable = None
             self.map.close()
             self.map = None
 
         if self.map is not None:
             self.shatable = None
             self.map.close()
             self.map = None
 
+    def __del__(self):
+        assert self.closed
+
 
 _mpi_count = 0
 class PackIdxList:
     def __init__(self, dir, ignore_midx=False):
         global _mpi_count
 
 _mpi_count = 0
 class PackIdxList:
     def __init__(self, dir, ignore_midx=False):
         global _mpi_count
+        # Q: was this also intended to prevent opening multiple repos?
         assert(_mpi_count == 0) # these things suck tons of VM; don't waste it
         _mpi_count += 1
         assert(_mpi_count == 0) # these things suck tons of VM; don't waste it
         _mpi_count += 1
+        self.open = True
         self.dir = dir
         self.also = set()
         self.packs = []
         self.do_bloom = False
         self.bloom = None
         self.ignore_midx = ignore_midx
         self.dir = dir
         self.also = set()
         self.packs = []
         self.do_bloom = False
         self.bloom = None
         self.ignore_midx = ignore_midx
-        self.refresh()
+        try:
+            self.refresh()
+        except BaseException as ex:
+            with pending_raise(ex):
+                self.close()
 
 
-    def __del__(self):
+    def close(self):
         global _mpi_count
         global _mpi_count
+        if not self.open:
+            assert _mpi_count == 0
+            return
         _mpi_count -= 1
         _mpi_count -= 1
-        assert(_mpi_count == 0)
+        assert _mpi_count == 0
+        self.also = None
+        self.bloom, bloom = None, self.bloom
+        self.packs, packs = None, self.packs
+        self.open = False
+        with ExitStack() as stack:
+            for pack in packs:
+                stack.enter_context(pack)
+            if bloom:
+                bloom.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        with pending_raise(value, rethrow=False):
+            self.close()
+
+    def __del__(self):
+        assert not self.open
 
     def __iter__(self):
         return iter(idxmerge(self.packs))
 
     def __iter__(self):
         return iter(idxmerge(self.packs))
@@ -611,7 +650,6 @@ class PackIdxList:
                                 broken = True
                         if broken:
                             mx.close()
                                 broken = True
                         if broken:
                             mx.close()
-                            del mx
                             unlink(full)
                         else:
                             midxl.append(mx)
                             unlink(full)
                         else:
                             midxl.append(mx)
@@ -643,14 +681,27 @@ class PackIdxList:
                         continue
                     d[full] = ix
             bfull = os.path.join(self.dir, b'bup.bloom')
                         continue
                     d[full] = ix
             bfull = os.path.join(self.dir, b'bup.bloom')
+            new_packs = set(d.values())
+            for p in self.packs:
+                if not p in new_packs:
+                    p.close()
+            new_packs = list(new_packs)
+            new_packs.sort(reverse=True, key=lambda x: len(x))
+            self.packs = new_packs
             if self.bloom is None and os.path.exists(bfull):
                 self.bloom = bloom.ShaBloom(bfull)
             if self.bloom is None and os.path.exists(bfull):
                 self.bloom = bloom.ShaBloom(bfull)
-            self.packs = list(set(d.values()))
-            self.packs.sort(reverse=True, key=lambda x: len(x))
-            if self.bloom and self.bloom.valid() and len(self.bloom) >= len(self):
-                self.do_bloom = True
-            else:
-                self.bloom = None
+            try:
+                if self.bloom and self.bloom.valid() and len(self.bloom) >= len(self):
+                    self.do_bloom = True
+                else:
+                    if self.bloom:
+                        self.bloom, bloom_tmp = None, self.bloom
+                        bloom_tmp.close()
+            except BaseException as ex:
+                with pending_raise(ex):
+                    if self.bloom:
+                        self.bloom.close()
+
         debug1('PackIdxList: using %d index%s.\n'
             % (len(self.packs), len(self.packs)!=1 and 'es' or ''))
 
         debug1('PackIdxList: using %d index%s.\n'
             % (len(self.packs), len(self.packs)!=1 and 'es' or ''))
 
@@ -714,24 +765,24 @@ def create_commit_blob(tree, parent,
     l.append(msg)
     return b'\n'.join(l)
 
     l.append(msg)
     return b'\n'.join(l)
 
-
 def _make_objcache():
     return PackIdxList(repo(b'objects/pack'))
 
 # bup-gc assumes that it can disable all PackWriter activities
 # (bloom/midx/cache) via the constructor and close() arguments.
 
 def _make_objcache():
     return PackIdxList(repo(b'objects/pack'))
 
 # bup-gc assumes that it can disable all PackWriter activities
 # (bloom/midx/cache) via the constructor and close() arguments.
 
-class PackWriter:
+class PackWriter(object):
     """Writes Git objects inside a pack file."""
     def __init__(self, objcache_maker=_make_objcache, compression_level=1,
                  run_midx=True, on_pack_finish=None,
                  max_pack_size=None, max_pack_objects=None, repo_dir=None):
     """Writes Git objects inside a pack file."""
     def __init__(self, objcache_maker=_make_objcache, compression_level=1,
                  run_midx=True, on_pack_finish=None,
                  max_pack_size=None, max_pack_objects=None, repo_dir=None):
+        self.closed = False
         self.repo_dir = repo_dir or repo()
         self.file = None
         self.parentfd = None
         self.count = 0
         self.outbytes = 0
         self.repo_dir = repo_dir or repo()
         self.file = None
         self.parentfd = None
         self.count = 0
         self.outbytes = 0
-        self.filename = None
+        self.tmpdir = None
         self.idx = None
         self.objcache_maker = objcache_maker
         self.objcache = None
         self.idx = None
         self.objcache_maker = objcache_maker
         self.objcache = None
@@ -759,24 +810,15 @@ class PackWriter:
 
     def _open(self):
         if not self.file:
 
     def _open(self):
         if not self.file:
-            objdir = dir = os.path.join(self.repo_dir, b'objects')
-            fd, name = tempfile.mkstemp(suffix=b'.pack', dir=objdir)
-            try:
-                self.file = os.fdopen(fd, 'w+b')
-            except:
-                os.close(fd)
-                raise
-            try:
-                self.parentfd = os.open(objdir, os.O_RDONLY)
-            except:
-                f = self.file
-                self.file = None
-                f.close()
-                raise
-            assert name.endswith(b'.pack')
-            self.filename = name[:-5]
-            self.file.write(b'PACK\0\0\0\2\0\0\0\0')
-            self.idx = PackIdxV2Writer()
+            with ExitStack() as err_stack:
+                objdir = dir = os.path.join(self.repo_dir, b'objects')
+                self.tmpdir = err_stack.enter_context(temp_dir(dir=objdir, prefix=b'pack-tmp-'))
+                self.file = err_stack.enter_context(open(self.tmpdir + b'/pack', 'w+b'))
+                self.parentfd = err_stack.enter_context(finalized(os.open(objdir, os.O_RDONLY),
+                                                                  lambda x: os.close(x)))
+                self.file.write(b'PACK\0\0\0\2\0\0\0\0')
+                self.idx = PackIdxV2Writer()
+                err_stack.pop_all()
 
     def _raw_write(self, datalist, sha):
         self._open()
 
     def _raw_write(self, datalist, sha):
         self._open()
@@ -806,8 +848,7 @@ class PackWriter:
     def _write(self, sha, type, content):
         if verbose:
             log('>')
     def _write(self, sha, type, content):
         if verbose:
             log('>')
-        if not sha:
-            sha = calc_hash(type, content)
+        assert sha
         size, crc = self._raw_write(_encode_packobj(type, content,
                                                     self.compression_level),
                                     sha=sha)
         size, crc = self._raw_write(_encode_packobj(type, content,
                                                     self.compression_level),
                                     sha=sha)
@@ -866,53 +907,54 @@ class PackWriter:
 
     def _end(self, run_midx=True, abort=False):
         # Ignores run_midx during abort
 
     def _end(self, run_midx=True, abort=False):
         # Ignores run_midx during abort
-        if not self.file:
-            return None
+        self.tmpdir, tmpdir = None, self.tmpdir
+        self.parentfd, pfd, = None, self.parentfd
         self.file, f = None, self.file
         self.idx, idx = None, self.idx
         self.file, f = None, self.file
         self.idx, idx = None, self.idx
-        self.parentfd, pfd, = None, self.parentfd
-        self.objcache = None
-
-        with finalized(pfd, lambda x: x is not None and os.close(x)), \
-             f:
-
-            if abort:
-                os.unlink(self.filename + b'.pack')
-                return None
+        try:
+            with nullcontext_if_not(self.objcache), \
+                 finalized(pfd, lambda x: x is not None and os.close(x)), \
+                 nullcontext_if_not(f):
+                if abort or not f:
+                    return None
+
+                # update object count
+                f.seek(8)
+                cp = struct.pack('!i', self.count)
+                assert len(cp) == 4
+                f.write(cp)
+
+                # calculate the pack sha1sum
+                f.seek(0)
+                sum = Sha1()
+                for b in chunkyreader(f):
+                    sum.update(b)
+                packbin = sum.digest()
+                f.write(packbin)
+                f.flush()
+                fdatasync(f.fileno())
+                f.close()
 
 
-            # update object count
-            f.seek(8)
-            cp = struct.pack('!i', self.count)
-            assert len(cp) == 4
-            f.write(cp)
-
-            # calculate the pack sha1sum
-            f.seek(0)
-            sum = Sha1()
-            for b in chunkyreader(f):
-                sum.update(b)
-            packbin = sum.digest()
-            f.write(packbin)
-            f.flush()
-            fdatasync(f.fileno())
-            f.close()
-
-            idx.write(self.filename + b'.idx', packbin)
-            nameprefix = os.path.join(self.repo_dir,
-                                      b'objects/pack/pack-' +  hexlify(packbin))
-            if os.path.exists(self.filename + b'.map'):
-                os.unlink(self.filename + b'.map')
-            os.rename(self.filename + b'.pack', nameprefix + b'.pack')
-            os.rename(self.filename + b'.idx', nameprefix + b'.idx')
-            os.fsync(pfd)
-            if run_midx:
-                auto_midx(os.path.join(self.repo_dir, b'objects/pack'))
-            if self.on_pack_finish:
-                self.on_pack_finish(nameprefix)
-            return nameprefix
+                idx.write(tmpdir + b'/idx', packbin)
+                nameprefix = os.path.join(self.repo_dir,
+                                          b'objects/pack/pack-' +  hexlify(packbin))
+                os.rename(tmpdir + b'/pack', nameprefix + b'.pack')
+                os.rename(tmpdir + b'/idx', nameprefix + b'.idx')
+                os.fsync(pfd)
+                if run_midx:
+                    auto_midx(os.path.join(self.repo_dir, b'objects/pack'))
+                if self.on_pack_finish:
+                    self.on_pack_finish(nameprefix)
+                return nameprefix
+        finally:
+            if tmpdir:
+                rmtree(tmpdir)
+            # Must be last -- some of the code above depends on it
+            self.objcache = None
 
     def abort(self):
         """Remove the pack file from disk."""
 
     def abort(self):
         """Remove the pack file from disk."""
+        self.closed = True
         self._end(abort=True)
 
     def breakpoint(self):
         self._end(abort=True)
 
     def breakpoint(self):
@@ -923,8 +965,12 @@ class PackWriter:
 
     def close(self, run_midx=True):
         """Close the pack file and move it to its definitive path."""
 
     def close(self, run_midx=True):
         """Close the pack file and move it to its definitive path."""
+        self.closed = True
         return self._end(run_midx=run_midx)
 
         return self._end(run_midx=run_midx)
 
+    def __del__(self):
+        assert self.closed
+
 
 class PackIdxV2Writer:
     def __init__(self):
 
 class PackIdxV2Writer:
     def __init__(self):
@@ -1083,28 +1129,37 @@ def rev_parse(committish, repo_dir=None):
         debug2("resolved from ref: commit = %s\n" % hexlify(head))
         return head
 
         debug2("resolved from ref: commit = %s\n" % hexlify(head))
         return head
 
-    pL = PackIdxList(repo(b'objects/pack', repo_dir=repo_dir))
-
     if len(committish) == 40:
         try:
             hash = unhexlify(committish)
         except TypeError:
             return None
 
     if len(committish) == 40:
         try:
             hash = unhexlify(committish)
         except TypeError:
             return None
 
-        if pL.exists(hash):
-            return hash
+        with PackIdxList(repo(b'objects/pack', repo_dir=repo_dir)) as pL:
+            if pL.exists(hash):
+                return hash
 
     return None
 
 
 
     return None
 
 
-def update_ref(refname, newval, oldval, repo_dir=None):
-    """Update a repository reference."""
-    if not oldval:
-        oldval = b''
+def update_ref(refname, newval, oldval, repo_dir=None, force=False):
+    """Update a repository reference.
+
+    With force=True, don't care about the previous ref (oldval);
+    with force=False oldval must be either a sha1 or None (for an
+    entirely new branch)
+    """
+    if force:
+        assert oldval is None
+        oldarg = []
+    elif not oldval:
+        oldarg = [b'']
+    else:
+        oldarg = [hexlify(oldval)]
     assert refname.startswith(b'refs/heads/') \
         or refname.startswith(b'refs/tags/')
     p = subprocess.Popen([b'git', b'update-ref', refname,
     assert refname.startswith(b'refs/heads/') \
         or refname.startswith(b'refs/tags/')
     p = subprocess.Popen([b'git', b'update-ref', refname,
-                          hexlify(newval), hexlify(oldval)],
+                          hexlify(newval)] + oldarg,
                          env=_gitenv(repo_dir),
                          close_fds=True)
     _git_wait(b'git update-ref', p)
                          env=_gitenv(repo_dir),
                          close_fds=True)
     _git_wait(b'git update-ref', p)
@@ -1120,25 +1175,24 @@ def delete_ref(refname, oldvalue=None):
     _git_wait('git update-ref', p)
 
 
     _git_wait('git update-ref', p)
 
 
-def guess_repo(path=None):
-    """Set the path value in the global variable "repodir".
-    This makes bup look for an existing bup repository, but not fail if a
-    repository doesn't exist. Usually, if you are interacting with a bup
-    repository, you would not be calling this function but using
-    check_repo_or_die().
+def guess_repo():
+    """Return the global repodir or BUP_DIR when either is set, or ~/.bup.
+    Usually, if you are interacting with a bup repository, you would
+    not be calling this function but using check_repo_or_die().
+
     """
     """
-    global repodir
-    if path:
-        repodir = path
-    if not repodir:
-        repodir = environ.get(b'BUP_DIR')
-        if not repodir:
-            repodir = os.path.expanduser(b'~/.bup')
+    if repodir:
+        return repodir
+    repo = environ.get(b'BUP_DIR')
+    if not repo:
+        repo = os.path.expanduser(b'~/.bup')
+    return repo
 
 
 def init_repo(path=None):
     """Create the Git bare repository for bup in a given path."""
 
 
 def init_repo(path=None):
     """Create the Git bare repository for bup in a given path."""
-    guess_repo(path)
+    global repodir
+    repodir = path or guess_repo()
     d = repo()  # appends a / to the path
     parent = os.path.dirname(os.path.dirname(d))
     if parent and not os.path.exists(parent):
     d = repo()  # appends a / to the path
     parent = os.path.dirname(os.path.dirname(d))
     if parent and not os.path.exists(parent):
@@ -1163,7 +1217,8 @@ def init_repo(path=None):
 
 def check_repo_or_die(path=None):
     """Check to see if a bup repository probably exists, and abort if not."""
 
 def check_repo_or_die(path=None):
     """Check to see if a bup repository probably exists, and abort if not."""
-    guess_repo(path)
+    global repodir
+    repodir = path or guess_repo()
     top = repo()
     pst = stat_if_exists(top + b'/objects/pack')
     if pst and stat.S_ISDIR(pst.st_mode):
     top = repo()
     pst = stat_if_exists(top + b'/objects/pack')
     if pst and stat.S_ISDIR(pst.st_mode):
@@ -1228,38 +1283,6 @@ def require_suitable_git(ver_str=None):
     assert False
 
 
     assert False
 
 
-class _AbortableIter:
-    def __init__(self, it, onabort = None):
-        self.it = it
-        self.onabort = onabort
-        self.done = None
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            return next(self.it)
-        except StopIteration as e:
-            self.done = True
-            raise
-        except:
-            self.abort()
-            raise
-
-    next = __next__
-
-    def abort(self):
-        """Abort iteration and call the abortion callback, if needed."""
-        if not self.done:
-            self.done = True
-            if self.onabort:
-                self.onabort()
-
-    def __del__(self):
-        self.abort()
-
-
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
     def __init__(self, repo_dir = None):
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
     def __init__(self, repo_dir = None):
@@ -1268,12 +1291,15 @@ class CatPipe:
         self.p = self.inprogress = None
 
     def close(self, wait=False):
         self.p = self.inprogress = None
 
     def close(self, wait=False):
-        p = self.p
-        if p:
-            p.stdout.close()
-            p.stdin.close()
-        self.p = None
+        self.p, p = None, self.p
         self.inprogress = None
         self.inprogress = None
+        if p:
+            try:
+                p.stdout.close()
+            finally:
+                # This will handle pending exceptions correctly once
+                # we drop py2
+                p.stdin.close()
         if wait:
             p.wait()
             return p.returncode
         if wait:
             p.wait()
             return p.returncode
@@ -1320,18 +1346,17 @@ class CatPipe:
             raise GitError('expected object (id, type, size), got %r' % info)
         oidx, typ, size = info
         size = int(size)
             raise GitError('expected object (id, type, size), got %r' % info)
         oidx, typ, size = info
         size = int(size)
-        it = _AbortableIter(chunkyreader(self.p.stdout, size),
-                            onabort=self.close)
         try:
         try:
+            it = chunkyreader(self.p.stdout, size)
             yield oidx, typ, size
             yield oidx, typ, size
-            for blob in it:
+            for blob in chunkyreader(self.p.stdout, size):
                 yield blob
             readline_result = self.p.stdout.readline()
             assert readline_result == b'\n'
             self.inprogress = None
                 yield blob
             readline_result = self.p.stdout.readline()
             assert readline_result == b'\n'
             self.inprogress = None
-        except Exception as e:
-            it.abort()
-            raise
+        except Exception as ex:
+            with pending_raise(ex):
+                self.close()
 
     def _join(self, it):
         _, typ, _ = next(it)
 
     def _join(self, it):
         _, typ, _ = next(it)