]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/git.py
Make PackWriter a "with" context manager
[bup.git] / lib / bup / git.py
index 950a2d4ebb41e0399851d8b71ee49d6e9c301d5b..35bf838fe42ca875c8d707013771e74c8b29d946 100644 (file)
@@ -3,17 +3,19 @@ bup repositories are in Git format. This library allows us to
 interact with the Git data structures.
 """
 
+from __future__ import absolute_import
 import errno, os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
 from collections import namedtuple
 from itertools import islice
+from numbers import Integral
 
-from bup import _helpers, hashsplit, path, midx, bloom, xstat
+from bup import _helpers, compat, hashsplit, path, midx, bloom, xstat
 from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
                          fdatasync,
                          hostname, localtime, log, merge_iter,
                          mmap_read, mmap_readwrite,
                          parse_num,
-                         progress, qprogress, stat_if_exists,
+                         progress, qprogress, shstr, stat_if_exists,
                          unlink, username, userfullname,
                          utc_offset_str)
 
@@ -35,7 +37,7 @@ class GitError(Exception):
 def _git_wait(cmd, p):
     rv = p.wait()
     if rv != 0:
-        raise GitError('%s returned %d' % (cmd, rv))
+        raise GitError('%s returned %d' % (shstr(cmd), rv))
 
 def _git_capture(argv):
     p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv())
@@ -111,7 +113,8 @@ def parse_commit(content):
 
 def get_commit_items(id, cp):
     commit_it = cp.get(id)
-    assert(commit_it.next() == 'commit')
+    _, typ, _ = next(commit_it)
+    assert(typ == 'commit')
     commit_content = ''.join(commit_it)
     return parse_commit(commit_content)
 
@@ -131,7 +134,6 @@ def _git_date_str(epoch_sec, tz_offset_sec):
 
 def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
-    global repodir
     repo_dir = repo_dir or repodir
     if not repo_dir:
         raise GitError('You should call check_repo_or_die()')
@@ -139,7 +141,7 @@ def repo(sub = '', repo_dir=None):
     # If there's a .git subdirectory, then the actual repo is in there.
     gd = os.path.join(repo_dir, '.git')
     if os.path.exists(gd):
-        repodir = gd
+        repo_dir = gd
 
     return os.path.join(repo_dir, sub)
 
@@ -550,7 +552,7 @@ class PackIdxList:
             if self.bloom is None and os.path.exists(bfull):
                 self.bloom = bloom.ShaBloom(bfull)
             self.packs = list(set(d.values()))
-            self.packs.sort(lambda x,y: -cmp(len(x),len(y)))
+            self.packs.sort(reverse=True, key=lambda x: len(x))
             if self.bloom and self.bloom.valid() and len(self.bloom) >= len(self):
                 self.do_bloom = True
             else:
@@ -606,8 +608,8 @@ class PackWriter:
     """Writes Git objects inside a pack file."""
     def __init__(self, objcache_maker=_make_objcache, compression_level=1,
                  run_midx=True, on_pack_finish=None,
-                 max_pack_size=None, max_pack_objects=None):
-        self.repo_dir = repo()
+                 max_pack_size=None, max_pack_objects=None, repo_dir=None):
+        self.repo_dir = repo_dir or repo()
         self.file = None
         self.parentfd = None
         self.count = 0
@@ -635,6 +637,12 @@ class PackWriter:
     def __del__(self):
         self.close()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.close()
+
     def _open(self):
         if not self.file:
             objdir = dir = os.path.join(self.repo_dir, 'objects')
@@ -880,13 +888,13 @@ def _gitenv(repo_dir = None):
     return env
 
 
-def list_refs(refnames=None, repo_dir=None,
+def list_refs(patterns=None, repo_dir=None,
               limit_to_heads=False, limit_to_tags=False):
     """Yield (refname, hash) tuples for all repository refs unless
-    refnames are specified.  In that case, only include tuples for
-    those refs.  The limits restrict the result items to refs/heads or
-    refs/tags.  If both limits are specified, items from both sources
-    will be included.
+    patterns are specified.  In that case, only include tuples for
+    refs matching those patterns (cf. git-show-ref(1)).  The limits
+    restrict the result items to refs/heads or refs/tags.  If both
+    limits are specified, items from both sources will be included.
 
     """
     argv = ['git', 'show-ref']
@@ -895,8 +903,8 @@ def list_refs(refnames=None, repo_dir=None,
     if limit_to_tags:
         argv.append('--tags')
     argv.append('--')
-    if refnames:
-        argv += refnames
+    if patterns:
+        argv.extend(patterns)
     p = subprocess.Popen(argv,
                          preexec_fn = _gitenv(repo_dir),
                          stdout = subprocess.PIPE)
@@ -912,7 +920,7 @@ def list_refs(refnames=None, repo_dir=None,
 
 def read_ref(refname, repo_dir = None):
     """Get the commit id of the most recent commit made on a given ref."""
-    refs = list_refs(refnames=[refname], repo_dir=repo_dir, limit_to_heads=True)
+    refs = list_refs(patterns=[refname], repo_dir=repo_dir, limit_to_heads=True)
     l = tuple(islice(refs, 2))
     if l:
         assert(len(l) == 1)
@@ -921,32 +929,52 @@ def read_ref(refname, repo_dir = None):
         return None
 
 
-def rev_list(ref, count=None, repo_dir=None):
-    """Generate a list of reachable commits in reverse chronological order.
+def rev_list_invocation(ref_or_refs, count=None, format=None):
+    if isinstance(ref_or_refs, compat.str_type):
+        refs = (ref_or_refs,)
+    else:
+        refs = ref_or_refs
+    argv = ['git', 'rev-list']
+    if isinstance(count, Integral):
+        argv.extend(['-n', str(count)])
+    elif count:
+        raise ValueError('unexpected count argument %r' % count)
+
+    if format:
+        argv.append('--pretty=format:' + format)
+    for ref in refs:
+        assert not ref.startswith('-')
+        argv.append(ref)
+    argv.append('--')
+    return argv
+
 
-    This generator walks through commits, from child to parent, that are
-    reachable via the specified ref and yields a series of tuples of the form
-    (date,hash).
+def rev_list(ref_or_refs, count=None, parse=None, format=None, repo_dir=None):
+    """Yield information about commits as per "git rev-list".  If a format
+    is not provided, yield one hex hash at a time.  If a format is
+    provided, pass it to rev-list and call parse(git_stdout) for each
+    commit with the stream positioned just after the rev-list "commit
+    HASH" header line.  When a format is provided yield (oidx,
+    parse(git_stdout)) for each commit.
 
-    If count is a non-zero integer, limit the number of commits to "count"
-    objects.
     """
-    assert(not ref.startswith('-'))
-    opts = []
-    if count:
-        opts += ['-n', str(atoi(count))]
-    argv = ['git', 'rev-list', '--pretty=format:%at'] + opts + [ref, '--']
-    p = subprocess.Popen(argv,
+    assert bool(parse) == bool(format)
+    p = subprocess.Popen(rev_list_invocation(ref_or_refs, count=count,
+                                             format=format),
                          preexec_fn = _gitenv(repo_dir),
                          stdout = subprocess.PIPE)
-    commit = None
-    for row in p.stdout:
-        s = row.strip()
-        if s.startswith('commit '):
-            commit = s[7:].decode('hex')
-        else:
-            date = int(s)
-            yield (date, commit)
+    if not format:
+        for line in p.stdout:
+            yield line.strip()
+    else:
+        line = p.stdout.readline()
+        while line:
+            s = line.strip()
+            if not s.startswith('commit '):
+                raise Exception('unexpected line ' + s)
+            yield s[7:], parse(p.stdout)
+            line = p.stdout.readline()
+
     rv = p.wait()  # not fatal
     if rv:
         raise GitError, 'git rev-list returned error %d' % rv
@@ -1105,7 +1133,7 @@ class _AbortableIter:
 
     def next(self):
         try:
-            return self.it.next()
+            return next(self.it)
         except StopIteration as e:
             self.done = True
             raise
@@ -1124,12 +1152,6 @@ class _AbortableIter:
         self.abort()
 
 
-class MissingObject(KeyError):
-    def __init__(self, id):
-        self.id = id
-        KeyError.__init__(self, 'object %r is missing' % id.encode('hex'))
-
-
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
@@ -1138,14 +1160,9 @@ class CatPipe:
         self.repo_dir = repo_dir
         wanted = ('1','5','6')
         if ver() < wanted:
-            if not _ver_warned:
-                log('warning: git version < %s; bup will be slow.\n'
-                    % '.'.join(wanted))
-                _ver_warned = 1
-            self.get = self._slow_get
-        else:
-            self.p = self.inprogress = None
-            self.get = self._fast_get
+            log('error: git version must be at least 1.5.6\n')
+            sys.exit(1)
+        self.p = self.inprogress = None
 
     def _abort(self):
         if self.p:
@@ -1163,35 +1180,39 @@ class CatPipe:
                                   bufsize = 4096,
                                   preexec_fn = _gitenv(self.repo_dir))
 
-    def _fast_get(self, id):
+    def get(self, ref):
+        """Yield (oidx, type, size), followed by the data referred to by ref.
+        If ref does not exist, only yield (None, None, None).
+
+        """
         if not self.p or self.p.poll() != None:
             self.restart()
         assert(self.p)
         poll_result = self.p.poll()
         assert(poll_result == None)
         if self.inprogress:
-            log('_fast_get: opening %r while %r is open\n'
-                % (id, self.inprogress))
+            log('get: opening %r while %r is open\n' % (ref, self.inprogress))
         assert(not self.inprogress)
-        assert(id.find('\n') < 0)
-        assert(id.find('\r') < 0)
-        assert(not id.startswith('-'))
-        self.inprogress = id
-        self.p.stdin.write('%s\n' % id)
+        assert(ref.find('\n') < 0)
+        assert(ref.find('\r') < 0)
+        assert(not ref.startswith('-'))
+        self.inprogress = ref
+        self.p.stdin.write('%s\n' % ref)
         self.p.stdin.flush()
         hdr = self.p.stdout.readline()
         if hdr.endswith(' missing\n'):
             self.inprogress = None
-            raise MissingObject(id.decode('hex'))
-        spl = hdr.split(' ')
-        if len(spl) != 3 or len(spl[0]) != 40:
-            raise GitError('expected blob, got %r' % spl)
-        (hex, type, size) = spl
-
-        it = _AbortableIter(chunkyreader(self.p.stdout, int(spl[2])),
-                           onabort = self._abort)
+            yield None, None, None
+            return
+        info = hdr.split(' ')
+        if len(info) != 3 or len(info[0]) != 40:
+            raise GitError('expected object (id, type, size), got %r' % spl)
+        oidx, typ, size = info
+        size = int(size)
+        it = _AbortableIter(chunkyreader(self.p.stdout, size),
+                            onabort=self._abort)
         try:
-            yield type
+            yield oidx, typ, size
             for blob in it:
                 yield blob
             readline_result = self.p.stdout.readline()
@@ -1201,38 +1222,24 @@ class CatPipe:
             it.abort()
             raise
 
-    def _slow_get(self, id):
-        assert(id.find('\n') < 0)
-        assert(id.find('\r') < 0)
-        assert(id[0] != '-')
-        type = _git_capture(['git', 'cat-file', '-t', id]).strip()
-        yield type
-
-        p = subprocess.Popen(['git', 'cat-file', type, id],
-                             stdout=subprocess.PIPE,
-                             preexec_fn = _gitenv(self.repo_dir))
-        for blob in chunkyreader(p.stdout):
-            yield blob
-        _git_wait('git cat-file', p)
-
     def _join(self, it):
-        type = it.next()
-        if type == 'blob':
+        _, typ, _ = next(it)
+        if typ == 'blob':
             for blob in it:
                 yield blob
-        elif type == 'tree':
+        elif typ == 'tree':
             treefile = ''.join(it)
             for (mode, name, sha) in tree_decode(treefile):
                 for blob in self.join(sha.encode('hex')):
                     yield blob
-        elif type == 'commit':
+        elif typ == 'commit':
             treeline = ''.join(it).split('\n')[0]
             assert(treeline.startswith('tree '))
             for blob in self.join(treeline[5:]):
                 yield blob
         else:
             raise GitError('invalid object type %r: expected blob/tree/commit'
-                           % type)
+                           % typ)
 
     def join(self, id):
         """Generate a list of the content of all blobs that can be reached
@@ -1274,7 +1281,13 @@ def tags(repo_dir = None):
     return tags
 
 
-WalkItem = namedtuple('WalkItem', ['id', 'type', 'mode',
+class MissingObject(KeyError):
+    def __init__(self, oid):
+        self.oid = oid
+        KeyError.__init__(self, 'object %r is missing' % oid.encode('hex'))
+
+
+WalkItem = namedtuple('WalkItem', ['oid', 'type', 'mode',
                                    'path', 'chunk_path', 'data'])
 # The path is the mangled path, and if an item represents a fragment
 # of a chunked file, the chunk_path will be the chunked subtree path
@@ -1288,39 +1301,42 @@ WalkItem = namedtuple('WalkItem', ['id', 'type', 'mode',
 #   ...
 
 
-def walk_object(cat_pipe, id,
+def walk_object(cat_pipe, oidx,
                 stop_at=None,
                 include_data=None):
-    """Yield everything reachable from id via cat_pipe as a WalkItem,
-    stopping whenever stop_at(id) returns true.  Throw MissingObject
+    """Yield everything reachable from oidx via cat_pipe as a WalkItem,
+    stopping whenever stop_at(oidx) returns true.  Throw MissingObject
     if a hash encountered is missing from the repository, and don't
     read or return blob content in the data field unless include_data
     is set.
     """
     # Maintain the pending stack on the heap to avoid stack overflow
-    pending = [(id, [], [], None)]
+    pending = [(oidx, [], [], None)]
     while len(pending):
-        id, parent_path, chunk_path, mode = pending.pop()
-        if stop_at and stop_at(id):
+        oidx, parent_path, chunk_path, mode = pending.pop()
+        oid = oidx.decode('hex')
+        if stop_at and stop_at(oidx):
             continue
 
         if (not include_data) and mode and stat.S_ISREG(mode):
             # If the object is a "regular file", then it's a leaf in
             # the graph, so we can skip reading the data if the caller
             # hasn't requested it.
-            yield WalkItem(id=id, type='blob',
+            yield WalkItem(oid=oid, type='blob',
                            chunk_path=chunk_path, path=parent_path,
                            mode=mode,
                            data=None)
             continue
 
-        item_it = cat_pipe.get(id)
-        type = item_it.next()
-        if type not in ('blob', 'commit', 'tree'):
-            raise Exception('unexpected repository object type %r' % type)
+        item_it = cat_pipe.get(oidx)
+        get_oidx, typ, _ = next(item_it)
+        if not get_oidx:
+            raise MissingObject(oidx.decode('hex'))
+        if typ not in ('blob', 'commit', 'tree'):
+            raise Exception('unexpected repository object type %r' % typ)
 
         # FIXME: set the mode based on the type when the mode is None
-        if type == 'blob' and not include_data:
+        if typ == 'blob' and not include_data:
             # Dump data until we can ask cat_pipe not to fetch it
             for ignored in item_it:
                 pass
@@ -1328,18 +1344,18 @@ def walk_object(cat_pipe, id,
         else:
             data = ''.join(item_it)
 
-        yield WalkItem(id=id, type=type,
+        yield WalkItem(oid=oid, type=typ,
                        chunk_path=chunk_path, path=parent_path,
                        mode=mode,
                        data=(data if include_data else None))
 
-        if type == 'commit':
+        if typ == 'commit':
             commit_items = parse_commit(data)
             for pid in commit_items.parents:
                 pending.append((pid, parent_path, chunk_path, mode))
             pending.append((commit_items.tree, parent_path, chunk_path,
                             hashsplit.GIT_MODE_TREE))
-        elif type == 'tree':
+        elif typ == 'tree':
             for mode, name, ent_id in tree_decode(data):
                 demangled, bup_type = demangle_name(name, mode)
                 if chunk_path: