]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/git.py
git.repo: don't unintentionally set global repodir
[bup.git] / lib / bup / git.py
index 3f4cbaf567d39fc27a3ff039cc36b3994768df01..be9672210ee92e1e04d686b190040a774730c2f7 100644 (file)
@@ -3,18 +3,19 @@ bup repositories are in Git format. This library allows us to
 interact with the Git data structures.
 """
 
 interact with the Git data structures.
 """
 
+from __future__ import absolute_import
 import errno, os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
 from collections import namedtuple
 from itertools import islice
 from numbers import Integral
 
 import errno, os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
 from collections import namedtuple
 from itertools import islice
 from numbers import Integral
 
-from bup import _helpers, hashsplit, path, midx, bloom, xstat
+from bup import _helpers, compat, hashsplit, path, midx, bloom, xstat
 from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
                          fdatasync,
                          hostname, localtime, log, merge_iter,
                          mmap_read, mmap_readwrite,
                          parse_num,
 from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
                          fdatasync,
                          hostname, localtime, log, merge_iter,
                          mmap_read, mmap_readwrite,
                          parse_num,
-                         progress, qprogress, stat_if_exists,
+                         progress, qprogress, shstr, stat_if_exists,
                          unlink, username, userfullname,
                          utc_offset_str)
 
                          unlink, username, userfullname,
                          utc_offset_str)
 
@@ -36,7 +37,7 @@ class GitError(Exception):
 def _git_wait(cmd, p):
     rv = p.wait()
     if rv != 0:
 def _git_wait(cmd, p):
     rv = p.wait()
     if rv != 0:
-        raise GitError('%s returned %d' % (cmd, rv))
+        raise GitError('%s returned %d' % (shstr(cmd), rv))
 
 def _git_capture(argv):
     p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv())
 
 def _git_capture(argv):
     p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv())
@@ -112,7 +113,8 @@ def parse_commit(content):
 
 def get_commit_items(id, cp):
     commit_it = cp.get(id)
 
 def get_commit_items(id, cp):
     commit_it = cp.get(id)
-    assert(next(commit_it) == 'commit')
+    _, typ, _ = next(commit_it)
+    assert(typ == 'commit')
     commit_content = ''.join(commit_it)
     return parse_commit(commit_content)
 
     commit_content = ''.join(commit_it)
     return parse_commit(commit_content)
 
@@ -132,7 +134,6 @@ def _git_date_str(epoch_sec, tz_offset_sec):
 
 def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
 
 def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
-    global repodir
     repo_dir = repo_dir or repodir
     if not repo_dir:
         raise GitError('You should call check_repo_or_die()')
     repo_dir = repo_dir or repodir
     if not repo_dir:
         raise GitError('You should call check_repo_or_die()')
@@ -140,7 +141,7 @@ def repo(sub = '', repo_dir=None):
     # If there's a .git subdirectory, then the actual repo is in there.
     gd = os.path.join(repo_dir, '.git')
     if os.path.exists(gd):
     # If there's a .git subdirectory, then the actual repo is in there.
     gd = os.path.join(repo_dir, '.git')
     if os.path.exists(gd):
-        repodir = gd
+        repo_dir = gd
 
     return os.path.join(repo_dir, sub)
 
 
     return os.path.join(repo_dir, sub)
 
@@ -551,7 +552,7 @@ class PackIdxList:
             if self.bloom is None and os.path.exists(bfull):
                 self.bloom = bloom.ShaBloom(bfull)
             self.packs = list(set(d.values()))
             if self.bloom is None and os.path.exists(bfull):
                 self.bloom = bloom.ShaBloom(bfull)
             self.packs = list(set(d.values()))
-            self.packs.sort(lambda x,y: -cmp(len(x),len(y)))
+            self.packs.sort(reverse=True, key=lambda x: len(x))
             if self.bloom and self.bloom.valid() and len(self.bloom) >= len(self):
                 self.do_bloom = True
             else:
             if self.bloom and self.bloom.valid() and len(self.bloom) >= len(self):
                 self.do_bloom = True
             else:
@@ -922,34 +923,52 @@ def read_ref(refname, repo_dir = None):
         return None
 
 
         return None
 
 
-def rev_list(ref, count=None, repo_dir=None):
-    """Generate a list of reachable commits in reverse chronological order.
+def rev_list_invocation(ref_or_refs, count=None, format=None):
+    if isinstance(ref_or_refs, compat.str_type):
+        refs = (ref_or_refs,)
+    else:
+        refs = ref_or_refs
+    argv = ['git', 'rev-list']
+    if isinstance(count, Integral):
+        argv.extend(['-n', str(count)])
+    elif count:
+        raise ValueError('unexpected count argument %r' % count)
+
+    if format:
+        argv.append('--pretty=format:' + format)
+    for ref in refs:
+        assert not ref.startswith('-')
+        argv.append(ref)
+    argv.append('--')
+    return argv
+
 
 
-    This generator walks through commits, from child to parent, that are
-    reachable via the specified ref and yields a series of tuples of the form
-    (date,hash).
+def rev_list(ref_or_refs, count=None, parse=None, format=None, repo_dir=None):
+    """Yield information about commits as per "git rev-list".  If a format
+    is not provided, yield one hex hash at a time.  If a format is
+    provided, pass it to rev-list and call parse(git_stdout) for each
+    commit with the stream positioned just after the rev-list "commit
+    HASH" header line.  When a format is provided yield (oidx,
+    parse(git_stdout)) for each commit.
 
 
-    If count is a non-zero integer, limit the number of commits to "count"
-    objects.
     """
     """
-    assert(not ref.startswith('-'))
-    opts = []
-    if isinstance(count, Integral):
-        opts += ['-n', str(count)]
-    else:
-        assert not count
-    argv = ['git', 'rev-list', '--pretty=format:%at'] + opts + [ref, '--']
-    p = subprocess.Popen(argv,
+    assert bool(parse) == bool(format)
+    p = subprocess.Popen(rev_list_invocation(ref_or_refs, count=count,
+                                             format=format),
                          preexec_fn = _gitenv(repo_dir),
                          stdout = subprocess.PIPE)
                          preexec_fn = _gitenv(repo_dir),
                          stdout = subprocess.PIPE)
-    commit = None
-    for row in p.stdout:
-        s = row.strip()
-        if s.startswith('commit '):
-            commit = s[7:].decode('hex')
-        else:
-            date = int(s)
-            yield (date, commit)
+    if not format:
+        for line in p.stdout:
+            yield line.strip()
+    else:
+        line = p.stdout.readline()
+        while line:
+            s = line.strip()
+            if not s.startswith('commit '):
+                raise Exception('unexpected line ' + s)
+            yield s[7:], parse(p.stdout)
+            line = p.stdout.readline()
+
     rv = p.wait()  # not fatal
     if rv:
         raise GitError, 'git rev-list returned error %d' % rv
     rv = p.wait()  # not fatal
     if rv:
         raise GitError, 'git rev-list returned error %d' % rv
@@ -1127,12 +1146,6 @@ class _AbortableIter:
         self.abort()
 
 
         self.abort()
 
 
-class MissingObject(KeyError):
-    def __init__(self, id):
-        self.id = id
-        KeyError.__init__(self, 'object %r is missing' % id.encode('hex'))
-
-
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
@@ -1161,10 +1174,9 @@ class CatPipe:
                                   bufsize = 4096,
                                   preexec_fn = _gitenv(self.repo_dir))
 
                                   bufsize = 4096,
                                   preexec_fn = _gitenv(self.repo_dir))
 
-    def get(self, id, size=False):
-        """Yield the object type, and then an iterator over the data referred
-        to by the id ref.  If size is true, yield (obj_type, obj_size)
-        instead of just the type.
+    def get(self, ref):
+        """Yield (oidx, type, size), followed by the data referred to by ref.
+        If ref does not exist, only yield (None, None, None).
 
         """
         if not self.p or self.p.poll() != None:
 
         """
         if not self.p or self.p.poll() != None:
@@ -1173,30 +1185,28 @@ class CatPipe:
         poll_result = self.p.poll()
         assert(poll_result == None)
         if self.inprogress:
         poll_result = self.p.poll()
         assert(poll_result == None)
         if self.inprogress:
-            log('get: opening %r while %r is open\n' % (id, self.inprogress))
+            log('get: opening %r while %r is open\n' % (ref, self.inprogress))
         assert(not self.inprogress)
         assert(not self.inprogress)
-        assert(id.find('\n') < 0)
-        assert(id.find('\r') < 0)
-        assert(not id.startswith('-'))
-        self.inprogress = id
-        self.p.stdin.write('%s\n' % id)
+        assert(ref.find('\n') < 0)
+        assert(ref.find('\r') < 0)
+        assert(not ref.startswith('-'))
+        self.inprogress = ref
+        self.p.stdin.write('%s\n' % ref)
         self.p.stdin.flush()
         hdr = self.p.stdout.readline()
         if hdr.endswith(' missing\n'):
             self.inprogress = None
         self.p.stdin.flush()
         hdr = self.p.stdout.readline()
         if hdr.endswith(' missing\n'):
             self.inprogress = None
-            raise MissingObject(id.decode('hex'))
-        spl = hdr.split(' ')
-        if len(spl) != 3 or len(spl[0]) != 40:
-            raise GitError('expected blob, got %r' % spl)
-        hex, typ, sz = spl
-        sz = int(sz)
-        it = _AbortableIter(chunkyreader(self.p.stdout, sz),
+            yield None, None, None
+            return
+        info = hdr.split(' ')
+        if len(info) != 3 or len(info[0]) != 40:
+            raise GitError('expected object (id, type, size), got %r' % spl)
+        oidx, typ, size = info
+        size = int(size)
+        it = _AbortableIter(chunkyreader(self.p.stdout, size),
                             onabort=self._abort)
         try:
                             onabort=self._abort)
         try:
-            if size:
-                yield typ, sz
-            else:
-                yield typ
+            yield oidx, typ, size
             for blob in it:
                 yield blob
             readline_result = self.p.stdout.readline()
             for blob in it:
                 yield blob
             readline_result = self.p.stdout.readline()
@@ -1207,23 +1217,23 @@ class CatPipe:
             raise
 
     def _join(self, it):
             raise
 
     def _join(self, it):
-        type = next(it)
-        if type == 'blob':
+        _, typ, _ = next(it)
+        if typ == 'blob':
             for blob in it:
                 yield blob
             for blob in it:
                 yield blob
-        elif type == 'tree':
+        elif typ == 'tree':
             treefile = ''.join(it)
             for (mode, name, sha) in tree_decode(treefile):
                 for blob in self.join(sha.encode('hex')):
                     yield blob
             treefile = ''.join(it)
             for (mode, name, sha) in tree_decode(treefile):
                 for blob in self.join(sha.encode('hex')):
                     yield blob
-        elif type == 'commit':
+        elif typ == 'commit':
             treeline = ''.join(it).split('\n')[0]
             assert(treeline.startswith('tree '))
             for blob in self.join(treeline[5:]):
                 yield blob
         else:
             raise GitError('invalid object type %r: expected blob/tree/commit'
             treeline = ''.join(it).split('\n')[0]
             assert(treeline.startswith('tree '))
             for blob in self.join(treeline[5:]):
                 yield blob
         else:
             raise GitError('invalid object type %r: expected blob/tree/commit'
-                           % type)
+                           % typ)
 
     def join(self, id):
         """Generate a list of the content of all blobs that can be reached
 
     def join(self, id):
         """Generate a list of the content of all blobs that can be reached
@@ -1265,7 +1275,13 @@ def tags(repo_dir = None):
     return tags
 
 
     return tags
 
 
-WalkItem = namedtuple('WalkItem', ['id', 'type', 'mode',
+class MissingObject(KeyError):
+    def __init__(self, oid):
+        self.oid = oid
+        KeyError.__init__(self, 'object %r is missing' % oid.encode('hex'))
+
+
+WalkItem = namedtuple('WalkItem', ['oid', 'type', 'mode',
                                    'path', 'chunk_path', 'data'])
 # The path is the mangled path, and if an item represents a fragment
 # of a chunked file, the chunk_path will be the chunked subtree path
                                    'path', 'chunk_path', 'data'])
 # The path is the mangled path, and if an item represents a fragment
 # of a chunked file, the chunk_path will be the chunked subtree path
@@ -1279,39 +1295,42 @@ WalkItem = namedtuple('WalkItem', ['id', 'type', 'mode',
 #   ...
 
 
 #   ...
 
 
-def walk_object(cat_pipe, id,
+def walk_object(cat_pipe, oidx,
                 stop_at=None,
                 include_data=None):
                 stop_at=None,
                 include_data=None):
-    """Yield everything reachable from id via cat_pipe as a WalkItem,
-    stopping whenever stop_at(id) returns true.  Throw MissingObject
+    """Yield everything reachable from oidx via cat_pipe as a WalkItem,
+    stopping whenever stop_at(oidx) returns true.  Throw MissingObject
     if a hash encountered is missing from the repository, and don't
     read or return blob content in the data field unless include_data
     is set.
     """
     # Maintain the pending stack on the heap to avoid stack overflow
     if a hash encountered is missing from the repository, and don't
     read or return blob content in the data field unless include_data
     is set.
     """
     # Maintain the pending stack on the heap to avoid stack overflow
-    pending = [(id, [], [], None)]
+    pending = [(oidx, [], [], None)]
     while len(pending):
     while len(pending):
-        id, parent_path, chunk_path, mode = pending.pop()
-        if stop_at and stop_at(id):
+        oidx, parent_path, chunk_path, mode = pending.pop()
+        oid = oidx.decode('hex')
+        if stop_at and stop_at(oidx):
             continue
 
         if (not include_data) and mode and stat.S_ISREG(mode):
             # If the object is a "regular file", then it's a leaf in
             # the graph, so we can skip reading the data if the caller
             # hasn't requested it.
             continue
 
         if (not include_data) and mode and stat.S_ISREG(mode):
             # If the object is a "regular file", then it's a leaf in
             # the graph, so we can skip reading the data if the caller
             # hasn't requested it.
-            yield WalkItem(id=id, type='blob',
+            yield WalkItem(oid=oid, type='blob',
                            chunk_path=chunk_path, path=parent_path,
                            mode=mode,
                            data=None)
             continue
 
                            chunk_path=chunk_path, path=parent_path,
                            mode=mode,
                            data=None)
             continue
 
-        item_it = cat_pipe.get(id)
-        type = next(item_it)
-        if type not in ('blob', 'commit', 'tree'):
-            raise Exception('unexpected repository object type %r' % type)
+        item_it = cat_pipe.get(oidx)
+        get_oidx, typ, _ = next(item_it)
+        if not get_oidx:
+            raise MissingObject(oidx.decode('hex'))
+        if typ not in ('blob', 'commit', 'tree'):
+            raise Exception('unexpected repository object type %r' % typ)
 
         # FIXME: set the mode based on the type when the mode is None
 
         # FIXME: set the mode based on the type when the mode is None
-        if type == 'blob' and not include_data:
+        if typ == 'blob' and not include_data:
             # Dump data until we can ask cat_pipe not to fetch it
             for ignored in item_it:
                 pass
             # Dump data until we can ask cat_pipe not to fetch it
             for ignored in item_it:
                 pass
@@ -1319,18 +1338,18 @@ def walk_object(cat_pipe, id,
         else:
             data = ''.join(item_it)
 
         else:
             data = ''.join(item_it)
 
-        yield WalkItem(id=id, type=type,
+        yield WalkItem(oid=oid, type=typ,
                        chunk_path=chunk_path, path=parent_path,
                        mode=mode,
                        data=(data if include_data else None))
 
                        chunk_path=chunk_path, path=parent_path,
                        mode=mode,
                        data=(data if include_data else None))
 
-        if type == 'commit':
+        if typ == 'commit':
             commit_items = parse_commit(data)
             for pid in commit_items.parents:
                 pending.append((pid, parent_path, chunk_path, mode))
             pending.append((commit_items.tree, parent_path, chunk_path,
                             hashsplit.GIT_MODE_TREE))
             commit_items = parse_commit(data)
             for pid in commit_items.parents:
                 pending.append((pid, parent_path, chunk_path, mode))
             pending.append((commit_items.tree, parent_path, chunk_path,
                             hashsplit.GIT_MODE_TREE))
-        elif type == 'tree':
+        elif typ == 'tree':
             for mode, name, ent_id in tree_decode(data):
                 demangled, bup_type = demangle_name(name, mode)
                 if chunk_path:
             for mode, name, ent_id in tree_decode(data):
                 demangled, bup_type = demangle_name(name, mode)
                 if chunk_path: