]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/git.py
git.py: allow the specification of a repo_dir to update_ref()
[bup.git] / lib / bup / git.py
index 2a84eeb6f256d9e0d0fbc9bf409a1dfbf5571521..4825a19820f813227304131ac79374d899000fb2 100644 (file)
@@ -3,16 +3,16 @@ bup repositories are in Git format. This library allows us to
 interact with the Git data structures.
 """
 import os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
 interact with the Git data structures.
 """
 import os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
+from collections import namedtuple
+
 from bup.helpers import *
 from bup.helpers import *
-from bup import _helpers, path, midx, bloom
+from bup import _helpers, path, midx, bloom, xstat
 
 max_pack_size = 1000*1000*1000  # larger packs will slow down pruning
 max_pack_objects = 200*1000  # cache memory usage is about 83 bytes per object
 
 max_pack_size = 1000*1000*1000  # larger packs will slow down pruning
 max_pack_objects = 200*1000  # cache memory usage is about 83 bytes per object
-SEEK_END=2  # os.SEEK_END is not defined in python 2.4
 
 verbose = 0
 ignore_midx = 0
 
 verbose = 0
 ignore_midx = 0
-home_repodir = os.path.expanduser('~/.bup')
 repodir = None
 
 _typemap =  { 'blob':3, 'tree':2, 'commit':1, 'tag':4 }
 repodir = None
 
 _typemap =  { 'blob':3, 'tree':2, 'commit':1, 'tag':4 }
@@ -26,18 +26,79 @@ class GitError(Exception):
     pass
 
 
     pass
 
 
-def repo(sub = ''):
+def parse_tz_offset(s):
+    """UTC offset in seconds."""
+    tz_off = (int(s[1:3]) * 60 * 60) + (int(s[3:5]) * 60)
+    if s[0] == '-':
+        return - tz_off
+    return tz_off
+
+
+# FIXME: derived from http://git.rsbx.net/Documents/Git_Data_Formats.txt
+# Make sure that's authoritative.
+_start_end_char = r'[^ .,:;<>"\'\0\n]'
+_content_char = r'[^\0\n<>]'
+_safe_str_rx = '(?:%s{1,2}|(?:%s%s*%s))' \
+    % (_start_end_char,
+       _start_end_char, _content_char, _start_end_char)
+_tz_rx = r'[-+]\d\d[0-5]\d'
+_parent_rx = r'(?:parent [abcdefABCDEF0123456789]{40}\n)'
+_commit_rx = re.compile(r'''tree (?P<tree>[abcdefABCDEF0123456789]{40})
+(?P<parents>%s*)author (?P<author_name>%s) <(?P<author_mail>%s)> (?P<asec>\d+) (?P<atz>%s)
+committer (?P<committer_name>%s) <(?P<committer_mail>%s)> (?P<csec>\d+) (?P<ctz>%s)
+
+(?P<message>(?:.|\n)*)''' % (_parent_rx,
+                             _safe_str_rx, _safe_str_rx, _tz_rx,
+                             _safe_str_rx, _safe_str_rx, _tz_rx))
+_parent_hash_rx = re.compile(r'\s*parent ([abcdefABCDEF0123456789]{40})\s*')
+
+
+# Note that the author_sec and committer_sec values are (UTC) epoch seconds.
+CommitInfo = namedtuple('CommitInfo', ['tree', 'parents',
+                                       'author_name', 'author_mail',
+                                       'author_sec', 'author_offset',
+                                       'committer_name', 'committer_mail',
+                                       'committer_sec', 'committer_offset',
+                                       'message'])
+
+def parse_commit(content):
+    commit_match = re.match(_commit_rx, content)
+    if not commit_match:
+        raise Exception('cannot parse commit %r' % content)
+    matches = commit_match.groupdict()
+    return CommitInfo(tree=matches['tree'],
+                      parents=re.findall(_parent_hash_rx, matches['parents']),
+                      author_name=matches['author_name'],
+                      author_mail=matches['author_mail'],
+                      author_sec=int(matches['asec']),
+                      author_offset=parse_tz_offset(matches['atz']),
+                      committer_name=matches['committer_name'],
+                      committer_mail=matches['committer_mail'],
+                      committer_sec=int(matches['csec']),
+                      committer_offset=parse_tz_offset(matches['ctz']),
+                      message=matches['message'])
+
+
+def get_commit_items(id, cp):
+    commit_it = cp.get(id)
+    assert(commit_it.next() == 'commit')
+    commit_content = ''.join(commit_it)
+    return parse_commit(commit_content)
+
+
+def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
     global repodir
     """Get the path to the git repository or one of its subdirectories."""
     global repodir
-    if not repodir:
+    repo_dir = repo_dir or repodir
+    if not repo_dir:
         raise GitError('You should call check_repo_or_die()')
 
     # If there's a .git subdirectory, then the actual repo is in there.
         raise GitError('You should call check_repo_or_die()')
 
     # If there's a .git subdirectory, then the actual repo is in there.
-    gd = os.path.join(repodir, '.git')
+    gd = os.path.join(repo_dir, '.git')
     if os.path.exists(gd):
         repodir = gd
 
     if os.path.exists(gd):
         repodir = gd
 
-    return os.path.join(repodir, sub)
+    return os.path.join(repo_dir, sub)
 
 
 def shorten_hash(s):
 
 
 def shorten_hash(s):
@@ -89,9 +150,10 @@ def mangle_name(name, mode, gitmode):
     """Mangle a file name to present an abstract name for segmented files.
     Mangled file names will have the ".bup" extension added to them. If a
     file's name already ends with ".bup", a ".bupl" extension is added to
     """Mangle a file name to present an abstract name for segmented files.
     Mangled file names will have the ".bup" extension added to them. If a
     file's name already ends with ".bup", a ".bupl" extension is added to
-    disambiguate normal files from semgmented ones.
+    disambiguate normal files from segmented ones.
     """
     if stat.S_ISREG(mode) and not stat.S_ISREG(gitmode):
     """
     if stat.S_ISREG(mode) and not stat.S_ISREG(gitmode):
+        assert(stat.S_ISDIR(gitmode))
         return name + '.bup'
     elif name.endswith('.bup') or name[:-1].endswith('.bup'):
         return name + '.bupl'
         return name + '.bup'
     elif name.endswith('.bup') or name[:-1].endswith('.bup'):
         return name + '.bupl'
@@ -107,9 +169,9 @@ def demangle_name(name):
     the following:
 
     * BUP_NORMAL  : files that should be read as-is from the repository
     the following:
 
     * BUP_NORMAL  : files that should be read as-is from the repository
-    * BUP_CHUNKED : files that were chunked and need to be assembled
+    * BUP_CHUNKED : files that were chunked and need to be reassembled
 
 
-    For more information on the name mangling algorythm, see mangle_name()
+    For more information on the name mangling algorithm, see mangle_name()
     """
     if name.endswith('.bupl'):
         return (name[:-5], BUP_NORMAL)
     """
     if name.endswith('.bupl'):
         return (name[:-5], BUP_NORMAL)
@@ -127,7 +189,7 @@ def calc_hash(type, content):
     return sum.digest()
 
 
     return sum.digest()
 
 
-def _shalist_sort_key(ent):
+def shalist_item_sort_key(ent):
     (mode, name, id) = ent
     assert(mode+0 == mode)
     if stat.S_ISDIR(mode):
     (mode, name, id) = ent
     assert(mode+0 == mode)
     if stat.S_ISDIR(mode):
@@ -138,7 +200,7 @@ def _shalist_sort_key(ent):
 
 def tree_encode(shalist):
     """Generate a git tree object from (mode,name,hash) tuples."""
 
 def tree_encode(shalist):
     """Generate a git tree object from (mode,name,hash) tuples."""
-    shalist = sorted(shalist, key = _shalist_sort_key)
+    shalist = sorted(shalist, key = shalist_item_sort_key)
     l = []
     for (mode,name,bin) in shalist:
         assert(mode)
     l = []
     for (mode,name,bin) in shalist:
         assert(mode)
@@ -155,13 +217,13 @@ def tree_decode(buf):
     """Generate a list of (mode,name,hash) from the git tree object in buf."""
     ofs = 0
     while ofs < len(buf):
     """Generate a list of (mode,name,hash) from the git tree object in buf."""
     ofs = 0
     while ofs < len(buf):
-        z = buf[ofs:].find('\0')
-        assert(z > 0)
-        spl = buf[ofs:ofs+z].split(' ', 1)
+        z = buf.find('\0', ofs)
+        assert(z > ofs)
+        spl = buf[ofs:z].split(' ', 1)
         assert(len(spl) == 2)
         mode,name = spl
         assert(len(spl) == 2)
         mode,name = spl
-        sha = buf[ofs+z+1:ofs+z+1+20]
-        ofs += z+1+20
+        sha = buf[z+1:z+1+20]
+        ofs = z+1+20
         yield (int(mode, 8), name, sha)
 
 
         yield (int(mode, 8), name, sha)
 
 
@@ -177,6 +239,10 @@ def _encode_packobj(type, content, compression_level=1):
             break
         szbits = sz & 0x7f
         sz >>= 7
             break
         szbits = sz & 0x7f
         sz >>= 7
+    if compression_level > 9:
+        compression_level = 9
+    elif compression_level < 0:
+        compression_level = 0
     z = zlib.compressobj(compression_level)
     yield szout
     yield z.compress(content)
     z = zlib.compressobj(compression_level)
     yield szout
     yield z.compress(content)
@@ -404,12 +470,13 @@ class PackIdxList:
                                     '  used by %s\n') % (n, mxf))
                                 broken = True
                         if broken:
                                     '  used by %s\n') % (n, mxf))
                                 broken = True
                         if broken:
+                            mx.close()
                             del mx
                             unlink(full)
                         else:
                             midxl.append(mx)
                 midxl.sort(key=lambda ix:
                             del mx
                             unlink(full)
                         else:
                             midxl.append(mx)
                 midxl.sort(key=lambda ix:
-                           (-len(ix), -os.stat(ix.name).st_mtime))
+                           (-len(ix), -xstat.stat(ix.name).st_mtime))
                 for ix in midxl:
                     any_needed = False
                     for sub in ix.idxnames:
                 for ix in midxl:
                     any_needed = False
                     for sub in ix.idxnames:
@@ -425,6 +492,7 @@ class PackIdxList:
                     elif not ix.force_keep:
                         debug1('midx: removing redundant: %s\n'
                                % os.path.basename(ix.name))
                     elif not ix.force_keep:
                         debug1('midx: removing redundant: %s\n'
                                % os.path.basename(ix.name))
+                        ix.close()
                         unlink(ix.name)
             for full in glob.glob(os.path.join(self.dir,'*.idx')):
                 if not d.get(full):
                         unlink(ix.name)
             for full in glob.glob(os.path.join(self.dir,'*.idx')):
                 if not d.get(full):
@@ -651,54 +719,69 @@ class PackWriter:
         return self._end(run_midx=run_midx)
 
     def _write_pack_idx_v2(self, filename, idx, packbin):
         return self._end(run_midx=run_midx)
 
     def _write_pack_idx_v2(self, filename, idx, packbin):
+        ofs64_count = 0
+        for section in idx:
+            for entry in section:
+                if entry[2] >= 2**31:
+                    ofs64_count += 1
+
+        # Length: header + fan-out + shas-and-crcs + overflow-offsets
+        index_len = 8 + (4 * 256) + (28 * self.count) + (8 * ofs64_count)
+        idx_map = None
         idx_f = open(filename, 'w+b')
         idx_f = open(filename, 'w+b')
-        idx_f.write('\377tOc\0\0\0\2')
-
-        ofs64_ofs = 8 + 4*256 + 28*self.count
-        idx_f.truncate(ofs64_ofs)
-        idx_f.seek(0)
-        idx_map = mmap_readwrite(idx_f, close=False)
-        idx_f.seek(0, SEEK_END)
-        count = _helpers.write_idx(idx_f, idx_map, idx, self.count)
-        assert(count == self.count)
-        idx_map.close()
-        idx_f.write(packbin)
-
-        idx_f.seek(0)
-        idx_sum = Sha1()
-        b = idx_f.read(8 + 4*256)
-        idx_sum.update(b)
-
-        obj_list_sum = Sha1()
-        for b in chunkyreader(idx_f, 20*self.count):
+        try:
+            idx_f.truncate(index_len)
+            idx_map = mmap_readwrite(idx_f, close=False)
+            count = _helpers.write_idx(filename, idx_map, idx, self.count)
+            assert(count == self.count)
+        finally:
+            if idx_map: idx_map.close()
+            idx_f.close()
+
+        idx_f = open(filename, 'a+b')
+        try:
+            idx_f.write(packbin)
+            idx_f.seek(0)
+            idx_sum = Sha1()
+            b = idx_f.read(8 + 4*256)
             idx_sum.update(b)
             idx_sum.update(b)
-            obj_list_sum.update(b)
-        namebase = obj_list_sum.hexdigest()
 
 
-        for b in chunkyreader(idx_f):
-            idx_sum.update(b)
-        idx_f.write(idx_sum.digest())
-        idx_f.close()
+            obj_list_sum = Sha1()
+            for b in chunkyreader(idx_f, 20*self.count):
+                idx_sum.update(b)
+                obj_list_sum.update(b)
+            namebase = obj_list_sum.hexdigest()
 
 
-        return namebase
+            for b in chunkyreader(idx_f):
+                idx_sum.update(b)
+            idx_f.write(idx_sum.digest())
+            return namebase
+        finally:
+            idx_f.close()
 
 
 def _git_date(date):
     return '%d %s' % (date, time.strftime('%z', time.localtime(date)))
 
 
 
 
 def _git_date(date):
     return '%d %s' % (date, time.strftime('%z', time.localtime(date)))
 
 
-def _gitenv():
-    os.environ['GIT_DIR'] = os.path.abspath(repo())
+def _gitenv(repo_dir = None):
+    if not repo_dir:
+        repo_dir = repo()
+    def env():
+        os.environ['GIT_DIR'] = os.path.abspath(repo_dir)
+    return env
 
 
 
 
-def list_refs(refname = None):
+def list_refs(refname = None, repo_dir = None):
     """Generate a list of tuples in the form (refname,hash).
     If a ref name is specified, list only this particular ref.
     """
     argv = ['git', 'show-ref', '--']
     if refname:
         argv += [refname]
     """Generate a list of tuples in the form (refname,hash).
     If a ref name is specified, list only this particular ref.
     """
     argv = ['git', 'show-ref', '--']
     if refname:
         argv += [refname]
-    p = subprocess.Popen(argv, preexec_fn = _gitenv, stdout = subprocess.PIPE)
+    p = subprocess.Popen(argv,
+                         preexec_fn = _gitenv(repo_dir),
+                         stdout = subprocess.PIPE)
     out = p.stdout.read().strip()
     rv = p.wait()  # not fatal
     if rv:
     out = p.stdout.read().strip()
     rv = p.wait()  # not fatal
     if rv:
@@ -709,9 +792,9 @@ def list_refs(refname = None):
             yield (name, sha.decode('hex'))
 
 
             yield (name, sha.decode('hex'))
 
 
-def read_ref(refname):
+def read_ref(refname, repo_dir = None):
     """Get the commit id of the most recent commit made on a given ref."""
     """Get the commit id of the most recent commit made on a given ref."""
-    l = list(list_refs(refname))
+    l = list(list_refs(refname, repo_dir))
     if l:
         assert(len(l) == 1)
         return l[0][1]
     if l:
         assert(len(l) == 1)
         return l[0][1]
@@ -719,7 +802,7 @@ def read_ref(refname):
         return None
 
 
         return None
 
 
-def rev_list(ref, count=None):
+def rev_list(ref, count=None, repo_dir=None):
     """Generate a list of reachable commits in reverse chronological order.
 
     This generator walks through commits, from child to parent, that are
     """Generate a list of reachable commits in reverse chronological order.
 
     This generator walks through commits, from child to parent, that are
@@ -733,8 +816,10 @@ def rev_list(ref, count=None):
     opts = []
     if count:
         opts += ['-n', str(atoi(count))]
     opts = []
     if count:
         opts += ['-n', str(atoi(count))]
-    argv = ['git', 'rev-list', '--pretty=format:%ct'] + opts + [ref, '--']
-    p = subprocess.Popen(argv, preexec_fn = _gitenv, stdout = subprocess.PIPE)
+    argv = ['git', 'rev-list', '--pretty=format:%at'] + opts + [ref, '--']
+    p = subprocess.Popen(argv,
+                         preexec_fn = _gitenv(repo_dir),
+                         stdout = subprocess.PIPE)
     commit = None
     for row in p.stdout:
         s = row.strip()
     commit = None
     for row in p.stdout:
         s = row.strip()
@@ -748,14 +833,18 @@ def rev_list(ref, count=None):
         raise GitError, 'git rev-list returned error %d' % rv
 
 
         raise GitError, 'git rev-list returned error %d' % rv
 
 
-def rev_get_date(ref):
-    """Get the date of the latest commit on the specified ref."""
-    for (date, commit) in rev_list(ref, count=1):
-        return date
-    raise GitError, 'no such commit %r' % ref
+def get_commit_dates(refs, repo_dir=None):
+    """Get the dates for the specified commit refs.  For now, every unique
+       string in refs must resolve to a different commit or this
+       function will fail."""
+    result = []
+    for ref in refs:
+        commit = get_commit_items(ref, cp(repo_dir))
+        result.append(commit.author_sec)
+    return result
 
 
 
 
-def rev_parse(committish):
+def rev_parse(committish, repo_dir=None):
     """Resolve the full hash for 'committish', if it exists.
 
     Should be roughly equivalent to 'git rev-parse'.
     """Resolve the full hash for 'committish', if it exists.
 
     Should be roughly equivalent to 'git rev-parse'.
@@ -763,12 +852,12 @@ def rev_parse(committish):
     Returns the hex value of the hash if it is found, None if 'committish' does
     not correspond to anything.
     """
     Returns the hex value of the hash if it is found, None if 'committish' does
     not correspond to anything.
     """
-    head = read_ref(committish)
+    head = read_ref(committish, repo_dir=repo_dir)
     if head:
         debug2("resolved from ref: commit = %s\n" % head.encode('hex'))
         return head
 
     if head:
         debug2("resolved from ref: commit = %s\n" % head.encode('hex'))
         return head
 
-    pL = PackIdxList(repo('objects/pack'))
+    pL = PackIdxList(repo('objects/pack', repo_dir=repo_dir))
 
     if len(committish) == 40:
         try:
 
     if len(committish) == 40:
         try:
@@ -782,14 +871,14 @@ def rev_parse(committish):
     return None
 
 
     return None
 
 
-def update_ref(refname, newval, oldval):
+def update_ref(refname, newval, oldval, repo_dir=None):
     """Change the commit pointed to by a branch."""
     if not oldval:
         oldval = ''
     assert(refname.startswith('refs/heads/'))
     p = subprocess.Popen(['git', 'update-ref', refname,
                           newval.encode('hex'), oldval.encode('hex')],
     """Change the commit pointed to by a branch."""
     if not oldval:
         oldval = ''
     assert(refname.startswith('refs/heads/'))
     p = subprocess.Popen(['git', 'update-ref', refname,
                           newval.encode('hex'), oldval.encode('hex')],
-                         preexec_fn = _gitenv)
+                         preexec_fn = _gitenv(repo_dir))
     _git_wait('git update-ref', p)
 
 
     _git_wait('git update-ref', p)
 
 
@@ -817,14 +906,18 @@ def init_repo(path=None):
     if parent and not os.path.exists(parent):
         raise GitError('parent directory "%s" does not exist\n' % parent)
     if os.path.exists(d) and not os.path.isdir(os.path.join(d, '.')):
     if parent and not os.path.exists(parent):
         raise GitError('parent directory "%s" does not exist\n' % parent)
     if os.path.exists(d) and not os.path.isdir(os.path.join(d, '.')):
-        raise GitError('"%d" exists but is not a directory\n' % d)
+        raise GitError('"%s" exists but is not a directory\n' % d)
     p = subprocess.Popen(['git', '--bare', 'init'], stdout=sys.stderr,
     p = subprocess.Popen(['git', '--bare', 'init'], stdout=sys.stderr,
-                         preexec_fn = _gitenv)
+                         preexec_fn = _gitenv())
     _git_wait('git init', p)
     # Force the index version configuration in order to ensure bup works
     # regardless of the version of the installed Git binary.
     p = subprocess.Popen(['git', 'config', 'pack.indexVersion', '2'],
     _git_wait('git init', p)
     # Force the index version configuration in order to ensure bup works
     # regardless of the version of the installed Git binary.
     p = subprocess.Popen(['git', 'config', 'pack.indexVersion', '2'],
-                         stdout=sys.stderr, preexec_fn = _gitenv)
+                         stdout=sys.stderr, preexec_fn = _gitenv())
+    _git_wait('git config', p)
+    # Enable the reflog
+    p = subprocess.Popen(['git', 'config', 'core.logAllRefUpdates', 'true'],
+                         stdout=sys.stderr, preexec_fn = _gitenv())
     _git_wait('git config', p)
 
 
     _git_wait('git config', p)
 
 
@@ -838,11 +931,9 @@ def check_repo_or_die(path=None):
         os.stat(repo('objects/pack/.'))
     except OSError, e:
         if e.errno == errno.ENOENT:
         os.stat(repo('objects/pack/.'))
     except OSError, e:
         if e.errno == errno.ENOENT:
-            if repodir != home_repodir:
-                log('error: %r is not a bup/git repository\n' % repo())
-                sys.exit(15)
-            else:
-                init_repo()
+            log('error: %r is not a bup repository; run "bup init"\n'
+                % repo())
+            sys.exit(15)
         else:
             log('error: %s\n' % e)
             sys.exit(14)
         else:
             log('error: %s\n' % e)
             sys.exit(14)
@@ -882,7 +973,7 @@ def _git_wait(cmd, p):
 
 
 def _git_capture(argv):
 
 
 def _git_capture(argv):
-    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv)
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv())
     r = p.stdout.read()
     _git_wait(repr(argv), p)
     return r
     r = p.stdout.read()
     _git_wait(repr(argv), p)
     return r
@@ -921,8 +1012,9 @@ class _AbortableIter:
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
-    def __init__(self):
+    def __init__(self, repo_dir = None):
         global _ver_warned
         global _ver_warned
+        self.repo_dir = repo_dir
         wanted = ('1','5','6')
         if ver() < wanted:
             if not _ver_warned:
         wanted = ('1','5','6')
         if ver() < wanted:
             if not _ver_warned:
@@ -948,15 +1040,16 @@ class CatPipe:
                                   stdout=subprocess.PIPE,
                                   close_fds = True,
                                   bufsize = 4096,
                                   stdout=subprocess.PIPE,
                                   close_fds = True,
                                   bufsize = 4096,
-                                  preexec_fn = _gitenv)
+                                  preexec_fn = _gitenv(self.repo_dir))
 
     def _fast_get(self, id):
         if not self.p or self.p.poll() != None:
             self._restart()
         assert(self.p)
 
     def _fast_get(self, id):
         if not self.p or self.p.poll() != None:
             self._restart()
         assert(self.p)
-        assert(self.p.poll() == None)
+        poll_result = self.p.poll()
+        assert(poll_result == None)
         if self.inprogress:
         if self.inprogress:
-            log('_fast_get: opening %r while %r is open'
+            log('_fast_get: opening %r while %r is open\n'
                 % (id, self.inprogress))
         assert(not self.inprogress)
         assert(id.find('\n') < 0)
                 % (id, self.inprogress))
         assert(not self.inprogress)
         assert(id.find('\n') < 0)
@@ -980,7 +1073,8 @@ class CatPipe:
             yield type
             for blob in it:
                 yield blob
             yield type
             for blob in it:
                 yield blob
-            assert(self.p.stdout.readline() == '\n')
+            readline_result = self.p.stdout.readline()
+            assert(readline_result == '\n')
             self.inprogress = None
         except Exception, e:
             it.abort()
             self.inprogress = None
         except Exception, e:
             it.abort()
@@ -995,7 +1089,7 @@ class CatPipe:
 
         p = subprocess.Popen(['git', 'cat-file', type, id],
                              stdout=subprocess.PIPE,
 
         p = subprocess.Popen(['git', 'cat-file', type, id],
                              stdout=subprocess.PIPE,
-                             preexec_fn = _gitenv)
+                             preexec_fn = _gitenv(self.repo_dir))
         for blob in chunkyreader(p.stdout):
             yield blob
         _git_wait('git cat-file', p)
         for blob in chunkyreader(p.stdout):
             yield blob
         _git_wait('git cat-file', p)
@@ -1031,15 +1125,30 @@ class CatPipe:
         except StopIteration:
             log('booger!\n')
 
         except StopIteration:
             log('booger!\n')
 
-def tags():
+
+_cp = {}
+
+def cp(repo_dir=None):
+    """Create a CatPipe object or reuse the already existing one."""
+    global _cp
+    if not repo_dir:
+        repo_dir = repo()
+    repo_dir = os.path.abspath(repo_dir)
+    cp = _cp.get(repo_dir)
+    if not cp:
+        cp = CatPipe(repo_dir)
+        _cp[repo_dir] = cp
+    return cp
+
+
+def tags(repo_dir = None):
     """Return a dictionary of all tags in the form {hash: [tag_names, ...]}."""
     tags = {}
     """Return a dictionary of all tags in the form {hash: [tag_names, ...]}."""
     tags = {}
-    for (n,c) in list_refs():
+    for (n,c) in list_refs(repo_dir = repo_dir):
         if n.startswith('refs/tags/'):
             name = n[10:]
             if not c in tags:
                 tags[c] = []
 
             tags[c].append(name)  # more than one tag can point at 'c'
         if n.startswith('refs/tags/'):
             name = n[10:]
             if not c in tags:
                 tags[c] = []
 
             tags[c].append(name)  # more than one tag can point at 'c'
-
     return tags
     return tags