]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/git.py
git.py: allow the specification of a repo_dir to update_ref()
[bup.git] / lib / bup / git.py
index 82cd787c8ee7f148601747b673d4088bc5f6935e..4825a19820f813227304131ac79374d899000fb2 100644 (file)
@@ -86,18 +86,19 @@ def get_commit_items(id, cp):
     return parse_commit(commit_content)
 
 
     return parse_commit(commit_content)
 
 
-def repo(sub = ''):
+def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
     global repodir
     """Get the path to the git repository or one of its subdirectories."""
     global repodir
-    if not repodir:
+    repo_dir = repo_dir or repodir
+    if not repo_dir:
         raise GitError('You should call check_repo_or_die()')
 
     # If there's a .git subdirectory, then the actual repo is in there.
         raise GitError('You should call check_repo_or_die()')
 
     # If there's a .git subdirectory, then the actual repo is in there.
-    gd = os.path.join(repodir, '.git')
+    gd = os.path.join(repo_dir, '.git')
     if os.path.exists(gd):
         repodir = gd
 
     if os.path.exists(gd):
         repodir = gd
 
-    return os.path.join(repodir, sub)
+    return os.path.join(repo_dir, sub)
 
 
 def shorten_hash(s):
 
 
 def shorten_hash(s):
@@ -152,6 +153,7 @@ def mangle_name(name, mode, gitmode):
     disambiguate normal files from segmented ones.
     """
     if stat.S_ISREG(mode) and not stat.S_ISREG(gitmode):
     disambiguate normal files from segmented ones.
     """
     if stat.S_ISREG(mode) and not stat.S_ISREG(gitmode):
+        assert(stat.S_ISDIR(gitmode))
         return name + '.bup'
     elif name.endswith('.bup') or name[:-1].endswith('.bup'):
         return name + '.bupl'
         return name + '.bup'
     elif name.endswith('.bup') or name[:-1].endswith('.bup'):
         return name + '.bupl'
@@ -762,18 +764,24 @@ def _git_date(date):
     return '%d %s' % (date, time.strftime('%z', time.localtime(date)))
 
 
     return '%d %s' % (date, time.strftime('%z', time.localtime(date)))
 
 
-def _gitenv():
-    os.environ['GIT_DIR'] = os.path.abspath(repo())
+def _gitenv(repo_dir = None):
+    if not repo_dir:
+        repo_dir = repo()
+    def env():
+        os.environ['GIT_DIR'] = os.path.abspath(repo_dir)
+    return env
 
 
 
 
-def list_refs(refname = None):
+def list_refs(refname = None, repo_dir = None):
     """Generate a list of tuples in the form (refname,hash).
     If a ref name is specified, list only this particular ref.
     """
     argv = ['git', 'show-ref', '--']
     if refname:
         argv += [refname]
     """Generate a list of tuples in the form (refname,hash).
     If a ref name is specified, list only this particular ref.
     """
     argv = ['git', 'show-ref', '--']
     if refname:
         argv += [refname]
-    p = subprocess.Popen(argv, preexec_fn = _gitenv, stdout = subprocess.PIPE)
+    p = subprocess.Popen(argv,
+                         preexec_fn = _gitenv(repo_dir),
+                         stdout = subprocess.PIPE)
     out = p.stdout.read().strip()
     rv = p.wait()  # not fatal
     if rv:
     out = p.stdout.read().strip()
     rv = p.wait()  # not fatal
     if rv:
@@ -784,9 +792,9 @@ def list_refs(refname = None):
             yield (name, sha.decode('hex'))
 
 
             yield (name, sha.decode('hex'))
 
 
-def read_ref(refname):
+def read_ref(refname, repo_dir = None):
     """Get the commit id of the most recent commit made on a given ref."""
     """Get the commit id of the most recent commit made on a given ref."""
-    l = list(list_refs(refname))
+    l = list(list_refs(refname, repo_dir))
     if l:
         assert(len(l) == 1)
         return l[0][1]
     if l:
         assert(len(l) == 1)
         return l[0][1]
@@ -794,7 +802,7 @@ def read_ref(refname):
         return None
 
 
         return None
 
 
-def rev_list(ref, count=None):
+def rev_list(ref, count=None, repo_dir=None):
     """Generate a list of reachable commits in reverse chronological order.
 
     This generator walks through commits, from child to parent, that are
     """Generate a list of reachable commits in reverse chronological order.
 
     This generator walks through commits, from child to parent, that are
@@ -809,7 +817,9 @@ def rev_list(ref, count=None):
     if count:
         opts += ['-n', str(atoi(count))]
     argv = ['git', 'rev-list', '--pretty=format:%at'] + opts + [ref, '--']
     if count:
         opts += ['-n', str(atoi(count))]
     argv = ['git', 'rev-list', '--pretty=format:%at'] + opts + [ref, '--']
-    p = subprocess.Popen(argv, preexec_fn = _gitenv, stdout = subprocess.PIPE)
+    p = subprocess.Popen(argv,
+                         preexec_fn = _gitenv(repo_dir),
+                         stdout = subprocess.PIPE)
     commit = None
     for row in p.stdout:
         s = row.strip()
     commit = None
     for row in p.stdout:
         s = row.strip()
@@ -823,18 +833,18 @@ def rev_list(ref, count=None):
         raise GitError, 'git rev-list returned error %d' % rv
 
 
         raise GitError, 'git rev-list returned error %d' % rv
 
 
-def get_commit_dates(refs):
+def get_commit_dates(refs, repo_dir=None):
     """Get the dates for the specified commit refs.  For now, every unique
        string in refs must resolve to a different commit or this
        function will fail."""
     result = []
     for ref in refs:
     """Get the dates for the specified commit refs.  For now, every unique
        string in refs must resolve to a different commit or this
        function will fail."""
     result = []
     for ref in refs:
-        commit = get_commit_items(ref, cp())
+        commit = get_commit_items(ref, cp(repo_dir))
         result.append(commit.author_sec)
     return result
 
 
         result.append(commit.author_sec)
     return result
 
 
-def rev_parse(committish):
+def rev_parse(committish, repo_dir=None):
     """Resolve the full hash for 'committish', if it exists.
 
     Should be roughly equivalent to 'git rev-parse'.
     """Resolve the full hash for 'committish', if it exists.
 
     Should be roughly equivalent to 'git rev-parse'.
@@ -842,12 +852,12 @@ def rev_parse(committish):
     Returns the hex value of the hash if it is found, None if 'committish' does
     not correspond to anything.
     """
     Returns the hex value of the hash if it is found, None if 'committish' does
     not correspond to anything.
     """
-    head = read_ref(committish)
+    head = read_ref(committish, repo_dir=repo_dir)
     if head:
         debug2("resolved from ref: commit = %s\n" % head.encode('hex'))
         return head
 
     if head:
         debug2("resolved from ref: commit = %s\n" % head.encode('hex'))
         return head
 
-    pL = PackIdxList(repo('objects/pack'))
+    pL = PackIdxList(repo('objects/pack', repo_dir=repo_dir))
 
     if len(committish) == 40:
         try:
 
     if len(committish) == 40:
         try:
@@ -861,14 +871,14 @@ def rev_parse(committish):
     return None
 
 
     return None
 
 
-def update_ref(refname, newval, oldval):
+def update_ref(refname, newval, oldval, repo_dir=None):
     """Change the commit pointed to by a branch."""
     if not oldval:
         oldval = ''
     assert(refname.startswith('refs/heads/'))
     p = subprocess.Popen(['git', 'update-ref', refname,
                           newval.encode('hex'), oldval.encode('hex')],
     """Change the commit pointed to by a branch."""
     if not oldval:
         oldval = ''
     assert(refname.startswith('refs/heads/'))
     p = subprocess.Popen(['git', 'update-ref', refname,
                           newval.encode('hex'), oldval.encode('hex')],
-                         preexec_fn = _gitenv)
+                         preexec_fn = _gitenv(repo_dir))
     _git_wait('git update-ref', p)
 
 
     _git_wait('git update-ref', p)
 
 
@@ -898,16 +908,16 @@ def init_repo(path=None):
     if os.path.exists(d) and not os.path.isdir(os.path.join(d, '.')):
         raise GitError('"%s" exists but is not a directory\n' % d)
     p = subprocess.Popen(['git', '--bare', 'init'], stdout=sys.stderr,
     if os.path.exists(d) and not os.path.isdir(os.path.join(d, '.')):
         raise GitError('"%s" exists but is not a directory\n' % d)
     p = subprocess.Popen(['git', '--bare', 'init'], stdout=sys.stderr,
-                         preexec_fn = _gitenv)
+                         preexec_fn = _gitenv())
     _git_wait('git init', p)
     # Force the index version configuration in order to ensure bup works
     # regardless of the version of the installed Git binary.
     p = subprocess.Popen(['git', 'config', 'pack.indexVersion', '2'],
     _git_wait('git init', p)
     # Force the index version configuration in order to ensure bup works
     # regardless of the version of the installed Git binary.
     p = subprocess.Popen(['git', 'config', 'pack.indexVersion', '2'],
-                         stdout=sys.stderr, preexec_fn = _gitenv)
+                         stdout=sys.stderr, preexec_fn = _gitenv())
     _git_wait('git config', p)
     # Enable the reflog
     p = subprocess.Popen(['git', 'config', 'core.logAllRefUpdates', 'true'],
     _git_wait('git config', p)
     # Enable the reflog
     p = subprocess.Popen(['git', 'config', 'core.logAllRefUpdates', 'true'],
-                         stdout=sys.stderr, preexec_fn = _gitenv)
+                         stdout=sys.stderr, preexec_fn = _gitenv())
     _git_wait('git config', p)
 
 
     _git_wait('git config', p)
 
 
@@ -963,7 +973,7 @@ def _git_wait(cmd, p):
 
 
 def _git_capture(argv):
 
 
 def _git_capture(argv):
-    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv)
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv())
     r = p.stdout.read()
     _git_wait(repr(argv), p)
     return r
     r = p.stdout.read()
     _git_wait(repr(argv), p)
     return r
@@ -1002,8 +1012,9 @@ class _AbortableIter:
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
-    def __init__(self):
+    def __init__(self, repo_dir = None):
         global _ver_warned
         global _ver_warned
+        self.repo_dir = repo_dir
         wanted = ('1','5','6')
         if ver() < wanted:
             if not _ver_warned:
         wanted = ('1','5','6')
         if ver() < wanted:
             if not _ver_warned:
@@ -1029,7 +1040,7 @@ class CatPipe:
                                   stdout=subprocess.PIPE,
                                   close_fds = True,
                                   bufsize = 4096,
                                   stdout=subprocess.PIPE,
                                   close_fds = True,
                                   bufsize = 4096,
-                                  preexec_fn = _gitenv)
+                                  preexec_fn = _gitenv(self.repo_dir))
 
     def _fast_get(self, id):
         if not self.p or self.p.poll() != None:
 
     def _fast_get(self, id):
         if not self.p or self.p.poll() != None:
@@ -1078,7 +1089,7 @@ class CatPipe:
 
         p = subprocess.Popen(['git', 'cat-file', type, id],
                              stdout=subprocess.PIPE,
 
         p = subprocess.Popen(['git', 'cat-file', type, id],
                              stdout=subprocess.PIPE,
-                             preexec_fn = _gitenv)
+                             preexec_fn = _gitenv(self.repo_dir))
         for blob in chunkyreader(p.stdout):
             yield blob
         _git_wait('git cat-file', p)
         for blob in chunkyreader(p.stdout):
             yield blob
         _git_wait('git cat-file', p)
@@ -1115,28 +1126,29 @@ class CatPipe:
             log('booger!\n')
 
 
             log('booger!\n')
 
 
-_cp = (None, None)
+_cp = {}
 
 
-def cp():
-    """Create a CatPipe object or reuse an already existing one."""
+def cp(repo_dir=None):
+    """Create a CatPipe object or reuse the already existing one."""
     global _cp
     global _cp
-    cp_dir, cp = _cp
-    cur_dir = os.path.realpath(repo())
-    if cur_dir != cp_dir:
-        cp = CatPipe()
-        _cp = (cur_dir, cp)
+    if not repo_dir:
+        repo_dir = repo()
+    repo_dir = os.path.abspath(repo_dir)
+    cp = _cp.get(repo_dir)
+    if not cp:
+        cp = CatPipe(repo_dir)
+        _cp[repo_dir] = cp
     return cp
 
 
     return cp
 
 
-def tags():
+def tags(repo_dir = None):
     """Return a dictionary of all tags in the form {hash: [tag_names, ...]}."""
     tags = {}
     """Return a dictionary of all tags in the form {hash: [tag_names, ...]}."""
     tags = {}
-    for (n,c) in list_refs():
+    for (n,c) in list_refs(repo_dir = repo_dir):
         if n.startswith('refs/tags/'):
             name = n[10:]
             if not c in tags:
                 tags[c] = []
 
             tags[c].append(name)  # more than one tag can point at 'c'
         if n.startswith('refs/tags/'):
             name = n[10:]
             if not c in tags:
                 tags[c] = []
 
             tags[c].append(name)  # more than one tag can point at 'c'
-
     return tags
     return tags