]> arthur.barton.de Git - bup.git/commitdiff
Allow the specification of a repo_dir to some VFS and git operations
authorRob Browning <rlb@defaultvalue.org>
Wed, 22 Jan 2014 18:55:16 +0000 (12:55 -0600)
committerRob Browning <rlb@defaultvalue.org>
Fri, 17 Oct 2014 19:42:34 +0000 (14:42 -0500)
Previously, these VFS and git operations would only operate on the
default repository (git.repo()).

Have vfs.cp() handle more than one repository (via an internal cache).

Signed-off-by: Rob Browning <rlb@defaultvalue.org>
Tested-by: Rob Browning <rlb@defaultvalue.org>
lib/bup/git.py
lib/bup/vfs.py

index a8f3729f05f439ef545e3ecb3be5cfa316f493b2..543d55d91af6551bccd49b73a7b65347e112d845 100644 (file)
@@ -86,18 +86,19 @@ def get_commit_items(id, cp):
     return parse_commit(commit_content)
 
 
-def repo(sub = ''):
+def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
     global repodir
-    if not repodir:
+    repo_dir = repo_dir or repodir
+    if not repo_dir:
         raise GitError('You should call check_repo_or_die()')
 
     # If there's a .git subdirectory, then the actual repo is in there.
-    gd = os.path.join(repodir, '.git')
+    gd = os.path.join(repo_dir, '.git')
     if os.path.exists(gd):
         repodir = gd
 
-    return os.path.join(repodir, sub)
+    return os.path.join(repo_dir, sub)
 
 
 def shorten_hash(s):
@@ -771,14 +772,16 @@ def _gitenv(repo_dir = None):
     return env
 
 
-def list_refs(refname = None):
+def list_refs(refname = None, repo_dir = None):
     """Generate a list of tuples in the form (refname,hash).
     If a ref name is specified, list only this particular ref.
     """
     argv = ['git', 'show-ref', '--']
     if refname:
         argv += [refname]
-    p = subprocess.Popen(argv, preexec_fn = _gitenv(), stdout = subprocess.PIPE)
+    p = subprocess.Popen(argv,
+                         preexec_fn = _gitenv(repo_dir),
+                         stdout = subprocess.PIPE)
     out = p.stdout.read().strip()
     rv = p.wait()  # not fatal
     if rv:
@@ -789,9 +792,9 @@ def list_refs(refname = None):
             yield (name, sha.decode('hex'))
 
 
-def read_ref(refname):
+def read_ref(refname, repo_dir = None):
     """Get the commit id of the most recent commit made on a given ref."""
-    l = list(list_refs(refname))
+    l = list(list_refs(refname, repo_dir))
     if l:
         assert(len(l) == 1)
         return l[0][1]
@@ -799,7 +802,7 @@ def read_ref(refname):
         return None
 
 
-def rev_list(ref, count=None):
+def rev_list(ref, count=None, repo_dir=None):
     """Generate a list of reachable commits in reverse chronological order.
 
     This generator walks through commits, from child to parent, that are
@@ -814,7 +817,9 @@ def rev_list(ref, count=None):
     if count:
         opts += ['-n', str(atoi(count))]
     argv = ['git', 'rev-list', '--pretty=format:%at'] + opts + [ref, '--']
-    p = subprocess.Popen(argv, preexec_fn = _gitenv(), stdout = subprocess.PIPE)
+    p = subprocess.Popen(argv,
+                         preexec_fn = _gitenv(repo_dir),
+                         stdout = subprocess.PIPE)
     commit = None
     for row in p.stdout:
         s = row.strip()
@@ -828,18 +833,18 @@ def rev_list(ref, count=None):
         raise GitError, 'git rev-list returned error %d' % rv
 
 
-def get_commit_dates(refs):
+def get_commit_dates(refs, repo_dir=None):
     """Get the dates for the specified commit refs.  For now, every unique
        string in refs must resolve to a different commit or this
        function will fail."""
     result = []
     for ref in refs:
-        commit = get_commit_items(ref, cp())
+        commit = get_commit_items(ref, cp(repo_dir))
         result.append(commit.author_sec)
     return result
 
 
-def rev_parse(committish):
+def rev_parse(committish, repo_dir=None):
     """Resolve the full hash for 'committish', if it exists.
 
     Should be roughly equivalent to 'git rev-parse'.
@@ -847,12 +852,12 @@ def rev_parse(committish):
     Returns the hex value of the hash if it is found, None if 'committish' does
     not correspond to anything.
     """
-    head = read_ref(committish)
+    head = read_ref(committish, repo_dir=repo_dir)
     if head:
         debug2("resolved from ref: commit = %s\n" % head.encode('hex'))
         return head
 
-    pL = PackIdxList(repo('objects/pack'))
+    pL = PackIdxList(repo('objects/pack', repo_dir=repo_dir))
 
     if len(committish) == 40:
         try:
@@ -1121,28 +1126,29 @@ class CatPipe:
             log('booger!\n')
 
 
-_cp = (None, None)
+_cp = {}
 
-def cp():
-    """Create a CatPipe object or reuse an already existing one."""
+def cp(repo_dir=None):
+    """Create a CatPipe object or reuse the already existing one."""
     global _cp
-    cp_dir, cp = _cp
-    cur_dir = os.path.realpath(repo())
-    if cur_dir != cp_dir:
-        cp = CatPipe()
-        _cp = (cur_dir, cp)
+    if not repo_dir:
+        repo_dir = repo()
+    repo_dir = os.path.abspath(repo_dir)
+    cp = _cp.get(repo_dir)
+    if not cp:
+        cp = CatPipe(repo_dir)
+        _cp[repo_dir] = cp
     return cp
 
 
-def tags():
+def tags(repo_dir = None):
     """Return a dictionary of all tags in the form {hash: [tag_names, ...]}."""
     tags = {}
-    for (n,c) in list_refs():
+    for (n,c) in list_refs(repo_dir = repo_dir):
         if n.startswith('refs/tags/'):
             name = n[10:]
             if not c in tags:
                 tags[c] = []
 
             tags[c].append(name)  # more than one tag can point at 'c'
-
     return tags
index d76f26f2f98322a09a68314fa90251be13200a3e..3da55d7c90f7a1d3c6e63625eb839431c34904de 100644 (file)
@@ -33,44 +33,44 @@ class TooManySymlinks(NodeError):
     pass
 
 
-def _treeget(hash):
-    it = cp().get(hash.encode('hex'))
+def _treeget(hash, repo_dir=None):
+    it = cp(repo_dir).get(hash.encode('hex'))
     type = it.next()
     assert(type == 'tree')
     return git.tree_decode(''.join(it))
 
 
-def _tree_decode(hash):
+def _tree_decode(hash, repo_dir=None):
     tree = [(int(name,16),stat.S_ISDIR(mode),sha)
             for (mode,name,sha)
-            in _treeget(hash)]
+            in _treeget(hash, repo_dir)]
     assert(tree == list(sorted(tree)))
     return tree
 
 
-def _chunk_len(hash):
-    return sum(len(b) for b in cp().join(hash.encode('hex')))
+def _chunk_len(hash, repo_dir=None):
+    return sum(len(b) for b in cp(repo_dir).join(hash.encode('hex')))
 
 
-def _last_chunk_info(hash):
-    tree = _tree_decode(hash)
+def _last_chunk_info(hash, repo_dir=None):
+    tree = _tree_decode(hash, repo_dir)
     assert(tree)
     (ofs,isdir,sha) = tree[-1]
     if isdir:
-        (subofs, sublen) = _last_chunk_info(sha)
+        (subofs, sublen) = _last_chunk_info(sha, repo_dir)
         return (ofs+subofs, sublen)
     else:
         return (ofs, _chunk_len(sha))
 
 
-def _total_size(hash):
-    (lastofs, lastsize) = _last_chunk_info(hash)
+def _total_size(hash, repo_dir=None):
+    (lastofs, lastsize) = _last_chunk_info(hash, repo_dir)
     return lastofs + lastsize
 
 
-def _chunkiter(hash, startofs):
+def _chunkiter(hash, startofs, repo_dir=None):
     assert(startofs >= 0)
-    tree = _tree_decode(hash)
+    tree = _tree_decode(hash, repo_dir)
 
     # skip elements before startofs
     for i in xrange(len(tree)):
@@ -85,20 +85,20 @@ def _chunkiter(hash, startofs):
         if skipmore < 0:
             skipmore = 0
         if isdir:
-            for b in _chunkiter(sha, skipmore):
+            for b in _chunkiter(sha, skipmore, repo_dir):
                 yield b
         else:
-            yield ''.join(cp().join(sha.encode('hex')))[skipmore:]
+            yield ''.join(cp(repo_dir).join(sha.encode('hex')))[skipmore:]
 
 
 class _ChunkReader:
-    def __init__(self, hash, isdir, startofs):
+    def __init__(self, hash, isdir, startofs, repo_dir=None):
         if isdir:
-            self.it = _chunkiter(hash, startofs)
+            self.it = _chunkiter(hash, startofs, repo_dir)
             self.blob = None
         else:
             self.it = None
-            self.blob = ''.join(cp().join(hash.encode('hex')))[startofs:]
+            self.blob = ''.join(cp(repo_dir).join(hash.encode('hex')))[startofs:]
         self.ofs = startofs
 
     def next(self, size):
@@ -121,12 +121,13 @@ class _ChunkReader:
 
 
 class _FileReader(object):
-    def __init__(self, hash, size, isdir):
+    def __init__(self, hash, size, isdir, repo_dir=None):
         self.hash = hash
         self.ofs = 0
         self.size = size
         self.isdir = isdir
         self.reader = None
+        self._repo_dir = repo_dir
 
     def seek(self, ofs):
         if ofs > self.size:
@@ -143,7 +144,8 @@ class _FileReader(object):
         if count < 0:
             count = self.size - self.ofs
         if not self.reader or self.reader.ofs != self.ofs:
-            self.reader = _ChunkReader(self.hash, self.isdir, self.ofs)
+            self.reader = _ChunkReader(self.hash, self.isdir, self.ofs,
+                                       self._repo_dir)
         try:
             buf = self.reader.next(count)
         except:
@@ -158,12 +160,13 @@ class _FileReader(object):
 
 class Node(object):
     """Base class for file representation."""
-    def __init__(self, parent, name, mode, hash):
+    def __init__(self, parent, name, mode, hash, repo_dir=None):
         self.parent = parent
         self.name = name
         self.mode = mode
         self.hash = hash
         self.ctime = self.mtime = self.atime = 0
+        self._repo_dir = repo_dir
         self._subs = None
         self._metadata = None
 
@@ -309,8 +312,8 @@ class Node(object):
 
 class File(Node):
     """A normal file from bup's repository."""
-    def __init__(self, parent, name, mode, hash, bupmode):
-        Node.__init__(self, parent, name, mode, hash)
+    def __init__(self, parent, name, mode, hash, bupmode, repo_dir=None):
+        Node.__init__(self, parent, name, mode, hash, repo_dir)
         self.bupmode = bupmode
         self._cached_size = None
         self._filereader = None
@@ -323,7 +326,8 @@ class File(Node):
         # object here so we're not constantly re-seeking.
         if not self._filereader:
             self._filereader = _FileReader(self.hash, self.size(),
-                                           self.bupmode == git.BUP_CHUNKED)
+                                           self.bupmode == git.BUP_CHUNKED,
+                                           repo_dir = self._repo_dir)
         self._filereader.seek(0)
         return self._filereader
 
@@ -332,9 +336,11 @@ class File(Node):
         if self._cached_size == None:
             debug1('<<<<File.size() is calculating (for %r)...\n' % self.name)
             if self.bupmode == git.BUP_CHUNKED:
-                self._cached_size = _total_size(self.hash)
+                self._cached_size = _total_size(self.hash,
+                                                repo_dir = self._repo_dir)
             else:
-                self._cached_size = _chunk_len(self.hash)
+                self._cached_size = _chunk_len(self.hash,
+                                               repo_dir = self._repo_dir)
             debug1('<<<<File.size() done.\n')
         return self._cached_size
 
@@ -342,8 +348,9 @@ class File(Node):
 _symrefs = 0
 class Symlink(File):
     """A symbolic link from bup's repository."""
-    def __init__(self, parent, name, hash, bupmode):
-        File.__init__(self, parent, name, 0120000, hash, bupmode)
+    def __init__(self, parent, name, hash, bupmode, repo_dir=None):
+        File.__init__(self, parent, name, 0120000, hash, bupmode,
+                      repo_dir = repo_dir)
 
     def size(self):
         """Get the file size of the file at which this link points."""
@@ -351,7 +358,7 @@ class Symlink(File):
 
     def readlink(self):
         """Get the path that this link points at."""
-        return ''.join(cp().join(self.hash.encode('hex')))
+        return ''.join(cp(self._repo_dir).join(self.hash.encode('hex')))
 
     def dereference(self):
         """Get the node that this link points at.
@@ -381,8 +388,9 @@ class Symlink(File):
 
 class FakeSymlink(Symlink):
     """A symlink that is not stored in the bup repository."""
-    def __init__(self, parent, name, toname):
-        Symlink.__init__(self, parent, name, EMPTY_SHA, git.BUP_NORMAL)
+    def __init__(self, parent, name, toname, repo_dir=None):
+        Symlink.__init__(self, parent, name, EMPTY_SHA, git.BUP_NORMAL,
+                         repo_dir = repo_dir)
         self.toname = toname
 
     def readlink(self):
@@ -413,11 +421,11 @@ class Dir(Node):
 
     def _mksubs(self):
         self._subs = {}
-        it = cp().get(self.hash.encode('hex'))
+        it = cp(self._repo_dir).get(self.hash.encode('hex'))
         type = it.next()
         if type == 'commit':
             del it
-            it = cp().get(self.hash.encode('hex') + ':')
+            it = cp(self._repo_dir).get(self.hash.encode('hex') + ':')
             type = it.next()
         assert(type == 'tree')
         for (mode,mangled_name,sha) in git.tree_decode(''.join(it)):
@@ -431,11 +439,13 @@ class Dir(Node):
             if bupmode == git.BUP_CHUNKED:
                 mode = GIT_MODE_FILE
             if stat.S_ISDIR(mode):
-                self._subs[name] = Dir(self, name, mode, sha)
+                self._subs[name] = Dir(self, name, mode, sha, self._repo_dir)
             elif stat.S_ISLNK(mode):
-                self._subs[name] = Symlink(self, name, sha, bupmode)
+                self._subs[name] = Symlink(self, name, sha, bupmode,
+                                           self._repo_dir)
             else:
-                self._subs[name] = File(self, name, mode, sha, bupmode)
+                self._subs[name] = File(self, name, mode, sha, bupmode,
+                                        self._repo_dir)
 
     def metadata(self):
         """Return this Dir's Metadata() object, if any."""
@@ -464,15 +474,15 @@ class CommitDir(Node):
     separation helps us avoid having too much directories on the same level as
     the number of commits grows big.
     """
-    def __init__(self, parent, name):
-        Node.__init__(self, parent, name, GIT_MODE_TREE, EMPTY_SHA)
+    def __init__(self, parent, name, repo_dir=None):
+        Node.__init__(self, parent, name, GIT_MODE_TREE, EMPTY_SHA, repo_dir)
 
     def _mksubs(self):
         self._subs = {}
-        refs = git.list_refs()
+        refs = git.list_refs(repo_dir = self._repo_dir)
         for ref in refs:
             #debug2('ref name: %s\n' % ref[0])
-            revs = git.rev_list(ref[1].encode('hex'))
+            revs = git.rev_list(ref[1].encode('hex'), repo_dir = self._repo_dir)
             for (date, commit) in revs:
                 #debug2('commit: %s  date: %s\n' % (commit.encode('hex'), date))
                 commithex = commit.encode('hex')
@@ -480,7 +490,7 @@ class CommitDir(Node):
                 dirname = commithex[2:]
                 n1 = self._subs.get(containername)
                 if not n1:
-                    n1 = CommitList(self, containername)
+                    n1 = CommitList(self, containername, self._repo_dir)
                     self._subs[containername] = n1
 
                 if n1.commits.get(dirname):
@@ -492,32 +502,33 @@ class CommitDir(Node):
 
 class CommitList(Node):
     """A list of commits with hashes that start with the current node's name."""
-    def __init__(self, parent, name):
-        Node.__init__(self, parent, name, GIT_MODE_TREE, EMPTY_SHA)
+    def __init__(self, parent, name, repo_dir=None):
+        Node.__init__(self, parent, name, GIT_MODE_TREE, EMPTY_SHA, repo_dir)
         self.commits = {}
 
     def _mksubs(self):
         self._subs = {}
         for (name, (hash, date)) in self.commits.items():
-            n1 = Dir(self, name, GIT_MODE_TREE, hash)
+            n1 = Dir(self, name, GIT_MODE_TREE, hash, self._repo_dir)
             n1.ctime = n1.mtime = date
             self._subs[name] = n1
 
 
 class TagDir(Node):
     """A directory that contains all tags in the repository."""
-    def __init__(self, parent, name):
-        Node.__init__(self, parent, name, GIT_MODE_TREE, EMPTY_SHA)
+    def __init__(self, parent, name, repo_dir = None):
+        Node.__init__(self, parent, name, GIT_MODE_TREE, EMPTY_SHA, repo_dir)
 
     def _mksubs(self):
         self._subs = {}
-        for (name, sha) in git.list_refs():
+        for (name, sha) in git.list_refs(repo_dir = self._repo_dir):
             if name.startswith('refs/tags/'):
                 name = name[10:]
-                date = git.get_commit_dates([sha.encode('hex')])[0]
+                date = git.get_commit_dates([sha.encode('hex')],
+                                            repo_dir=self._repo_dir)[0]
                 commithex = sha.encode('hex')
                 target = '../.commit/%s/%s' % (commithex[:2], commithex[2:])
-                tag1 = FakeSymlink(self, name, target)
+                tag1 = FakeSymlink(self, name, target, repo_dir, self._repo_dir)
                 tag1.ctime = tag1.mtime = date
                 self._subs[name] = tag1
 
@@ -528,34 +539,35 @@ class BranchList(Node):
     Represents each commit as a symlink that points to the commit directory in
     /.commit/??/ . The symlink is named after the commit date.
     """
-    def __init__(self, parent, name, hash):
-        Node.__init__(self, parent, name, GIT_MODE_TREE, hash)
+    def __init__(self, parent, name, hash, repo_dir=None):
+        Node.__init__(self, parent, name, GIT_MODE_TREE, hash, repo_dir)
 
     def _mksubs(self):
         self._subs = {}
 
-        tags = git.tags()
+        tags = git.tags(repo_dir = self._repo_dir)
 
-        revs = list(git.rev_list(self.hash.encode('hex')))
+        revs = list(git.rev_list(self.hash.encode('hex'),
+                                 repo_dir=self._repo_dir))
         latest = revs[0]
         for (date, commit) in revs:
             l = time.localtime(date)
             ls = time.strftime('%Y-%m-%d-%H%M%S', l)
             commithex = commit.encode('hex')
             target = '../.commit/%s/%s' % (commithex[:2], commithex[2:])
-            n1 = FakeSymlink(self, ls, target)
+            n1 = FakeSymlink(self, ls, target, self._repo_dir)
             n1.ctime = n1.mtime = date
             self._subs[ls] = n1
 
             for tag in tags.get(commit, []):
-                t1 = FakeSymlink(self, tag, target)
+                t1 = FakeSymlink(self, tag, target, self._repo_dir)
                 t1.ctime = t1.mtime = date
                 self._subs[tag] = t1
 
         (date, commit) = latest
         commithex = commit.encode('hex')
         target = '../.commit/%s/%s' % (commithex[:2], commithex[2:])
-        n1 = FakeSymlink(self, 'latest', target)
+        n1 = FakeSymlink(self, 'latest', target, self._repo_dir)
         n1.ctime = n1.mtime = date
         self._subs['latest'] = n1
 
@@ -569,25 +581,25 @@ class RefList(Node):
     Also, a special sub-node named '.commit' contains all commit directories
     that are reachable via a ref (e.g. a branch).  See CommitDir for details.
     """
-    def __init__(self, parent):
-        Node.__init__(self, parent, '/', GIT_MODE_TREE, EMPTY_SHA)
+    def __init__(self, parent, repo_dir=None):
+        Node.__init__(self, parent, '/', GIT_MODE_TREE, EMPTY_SHA, repo_dir)
 
     def _mksubs(self):
         self._subs = {}
 
-        commit_dir = CommitDir(self, '.commit')
+        commit_dir = CommitDir(self, '.commit', self._repo_dir)
         self._subs['.commit'] = commit_dir
 
-        tag_dir = TagDir(self, '.tag')
+        tag_dir = TagDir(self, '.tag', self._repo_dir)
         self._subs['.tag'] = tag_dir
 
-        refs_info = [(name[11:], sha) for (name,sha) in git.list_refs() \
+        refs_info = [(name[11:], sha) for (name,sha)
+                     in git.list_refs(repo_dir=self._repo_dir)
                      if name.startswith('refs/heads/')]
-
         dates = git.get_commit_dates([sha.encode('hex')
-                                      for (name, sha) in refs_info])
-
+                                      for (name, sha) in refs_info],
+                                     repo_dir=self._repo_dir)
         for (name, sha), date in zip(refs_info, dates):
-            n1 = BranchList(self, name, sha)
+            n1 = BranchList(self, name, sha, self._repo_dir)
             n1.ctime = n1.mtime = date
             self._subs[name] = n1