]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/git.py
walk_object: rewrite as nonrecursive
[bup.git] / lib / bup / git.py
index 543d55d91af6551bccd49b73a7b65347e112d845..460d6b11a2d9273aa0b2dfd4f0bfcb2599df5fe4 100644 (file)
@@ -2,11 +2,19 @@
 bup repositories are in Git format. This library allows us to
 interact with the Git data structures.
 """
-import os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
+
+import errno, os, sys, zlib, time, subprocess, struct, stat, re, tempfile, glob
 from collections import namedtuple
+from itertools import islice
+
+from bup import _helpers, hashsplit, path, midx, bloom, xstat
+from bup.helpers import (Sha1, add_error, chunkyreader, debug1, debug2,
+                         fdatasync,
+                         hostname, localtime, log, merge_iter,
+                         mmap_read, mmap_readwrite,
+                         progress, qprogress, unlink, username, userfullname,
+                         utc_offset_str)
 
-from bup.helpers import *
-from bup import _helpers, path, midx, bloom, xstat
 
 max_pack_size = 1000*1000*1000  # larger packs will slow down pruning
 max_pack_objects = 200*1000  # cache memory usage is about 83 bytes per object
@@ -86,6 +94,19 @@ def get_commit_items(id, cp):
     return parse_commit(commit_content)
 
 
+def _local_git_date_str(epoch_sec):
+    return '%d %s' % (epoch_sec, utc_offset_str(epoch_sec))
+
+
+def _git_date_str(epoch_sec, tz_offset_sec):
+    offs =  tz_offset_sec // 60
+    return '%d %s%02d%02d' \
+        % (epoch_sec,
+           '+' if offs >= 0 else '-',
+           abs(offs) // 60,
+           abs(offs) % 60)
+
+
 def repo(sub = '', repo_dir=None):
     """Get the path to the git repository or one of its subdirectories."""
     global repodir
@@ -128,7 +149,7 @@ def auto_midx(objdir):
     args = [path.exe(), 'midx', '--auto', '--dir', objdir]
     try:
         rv = subprocess.call(args, stdout=open('/dev/null', 'w'))
-    except OSError, e:
+    except OSError as e:
         # make sure 'args' gets printed to help with debugging
         add_error('%r: exception: %s' % (args, e))
         raise
@@ -138,7 +159,7 @@ def auto_midx(objdir):
     args = [path.exe(), 'bloom', '--dir', objdir]
     try:
         rv = subprocess.call(args, stdout=open('/dev/null', 'w'))
-    except OSError, e:
+    except OSError as e:
         # make sure 'args' gets printed to help with debugging
         add_error('%r: exception: %s' % (args, e))
         raise
@@ -162,7 +183,7 @@ def mangle_name(name, mode, gitmode):
 
 
 (BUP_NORMAL, BUP_CHUNKED) = (0,1)
-def demangle_name(name):
+def demangle_name(name, mode):
     """Remove name mangling from a file name, if necessary.
 
     The return value is a tuple (demangled_filename,mode), where mode is one of
@@ -177,6 +198,9 @@ def demangle_name(name):
         return (name[:-5], BUP_NORMAL)
     elif name.endswith('.bup'):
         return (name[:-4], BUP_CHUNKED)
+    elif name.endswith('.bupm'):
+        return (name[:-5],
+                BUP_CHUNKED if stat.S_ISDIR(mode) else BUP_NORMAL)
     else:
         return (name, BUP_NORMAL)
 
@@ -228,6 +252,8 @@ def tree_decode(buf):
 
 
 def _encode_packobj(type, content, compression_level=1):
+    if compression_level not in (0, 1, 2, 3, 4, 5, 6, 7, 8, 9):
+        raise ValueError('invalid compression level %s' % compression_level)
     szout = ''
     sz = len(content)
     szbits = (sz & 0x0f) | (_typemap[type]<<4)
@@ -239,10 +265,6 @@ def _encode_packobj(type, content, compression_level=1):
             break
         szbits = sz & 0x7f
         sz >>= 7
-    if compression_level > 9:
-        compression_level = 9
-    elif compression_level < 0:
-        compression_level = 0
     z = zlib.compressobj(compression_level)
     yield szout
     yield z.compress(content)
@@ -498,7 +520,7 @@ class PackIdxList:
                 if not d.get(full):
                     try:
                         ix = open_idx(full)
-                    except GitError, e:
+                    except GitError as e:
                         add_error(e)
                         continue
                     d[full] = ix
@@ -555,25 +577,44 @@ def idxmerge(idxlist, final_progress=True):
 def _make_objcache():
     return PackIdxList(repo('objects/pack'))
 
+# bup-gc assumes that it can disable all PackWriter activities
+# (bloom/midx/cache) via the constructor and close() arguments.
+
 class PackWriter:
     """Writes Git objects inside a pack file."""
-    def __init__(self, objcache_maker=_make_objcache, compression_level=1):
+    def __init__(self, objcache_maker=_make_objcache, compression_level=1,
+                 run_midx=True, on_pack_finish=None):
+        self.file = None
+        self.parentfd = None
         self.count = 0
         self.outbytes = 0
         self.filename = None
-        self.file = None
         self.idx = None
         self.objcache_maker = objcache_maker
         self.objcache = None
         self.compression_level = compression_level
+        self.run_midx=run_midx
+        self.on_pack_finish = on_pack_finish
 
     def __del__(self):
         self.close()
 
     def _open(self):
         if not self.file:
-            (fd,name) = tempfile.mkstemp(suffix='.pack', dir=repo('objects'))
-            self.file = os.fdopen(fd, 'w+b')
+            objdir = dir=repo('objects')
+            fd, name = tempfile.mkstemp(suffix='.pack', dir=objdir)
+            try:
+                self.file = os.fdopen(fd, 'w+b')
+            except:
+                os.close(fd)
+                raise
+            try:
+                self.parentfd = os.open(objdir, os.O_RDONLY)
+            except:
+                f = self.file
+                self.file = None
+                f.close()
+                raise
             assert(name.endswith('.pack'))
             self.filename = name[:-5]
             self.file.write('PACK\0\0\0\2\0\0\0\0')
@@ -590,7 +631,7 @@ class PackWriter:
         oneblob = ''.join(datalist)
         try:
             f.write(oneblob)
-        except IOError, e:
+        except IOError as e:
             raise GitError, e, sys.exc_info()[2]
         nw = len(oneblob)
         crc = zlib.crc32(oneblob) & 0xffffffff
@@ -618,7 +659,7 @@ class PackWriter:
 
     def breakpoint(self):
         """Clear byte and object counts and return the last processed id."""
-        id = self._end()
+        id = self._end(self.run_midx)
         self.outbytes = self.count = 0
         return id
 
@@ -634,11 +675,15 @@ class PackWriter:
         self._require_objcache()
         return self.objcache.exists(id, want_source=want_source)
 
+    def write(self, sha, type, content):
+        """Write an object to the pack file.  Fails if sha exists()."""
+        self._write(sha, type, content)
+
     def maybe_write(self, type, content):
         """Write an object to the pack file if not present and return its id."""
         sha = calc_hash(type, content)
         if not self.exists(sha):
-            self._write(sha, type, content)
+            self.write(sha, type, content)
             self._require_objcache()
             self.objcache.add(sha)
         return sha
@@ -652,55 +697,71 @@ class PackWriter:
         content = tree_encode(shalist)
         return self.maybe_write('tree', content)
 
-    def _new_commit(self, tree, parent, author, adate, committer, cdate, msg):
+    def new_commit(self, tree, parent,
+                   author, adate_sec, adate_tz,
+                   committer, cdate_sec, cdate_tz,
+                   msg):
+        """Create a commit object in the pack.  The date_sec values must be
+        epoch-seconds, and if a tz is None, the local timezone is assumed."""
+        if adate_tz:
+            adate_str = _git_date_str(adate_sec, adate_tz)
+        else:
+            adate_str = _local_git_date_str(adate_sec)
+        if cdate_tz:
+            cdate_str = _git_date_str(cdate_sec, cdate_tz)
+        else:
+            cdate_str = _local_git_date_str(cdate_sec)
         l = []
         if tree: l.append('tree %s' % tree.encode('hex'))
         if parent: l.append('parent %s' % parent.encode('hex'))
-        if author: l.append('author %s %s' % (author, _git_date(adate)))
-        if committer: l.append('committer %s %s' % (committer, _git_date(cdate)))
+        if author: l.append('author %s %s' % (author, adate_str))
+        if committer: l.append('committer %s %s' % (committer, cdate_str))
         l.append('')
         l.append(msg)
         return self.maybe_write('commit', '\n'.join(l))
 
-    def new_commit(self, parent, tree, date, msg):
-        """Create a commit object in the pack."""
-        userline = '%s <%s@%s>' % (userfullname(), username(), hostname())
-        commit = self._new_commit(tree, parent,
-                                  userline, date, userline, date,
-                                  msg)
-        return commit
-
     def abort(self):
         """Remove the pack file from disk."""
         f = self.file
         if f:
-            self.idx = None
+            pfd = self.parentfd
             self.file = None
-            f.close()
-            os.unlink(self.filename + '.pack')
+            self.parentfd = None
+            self.idx = None
+            try:
+                try:
+                    os.unlink(self.filename + '.pack')
+                finally:
+                    f.close()
+            finally:
+                if pfd is not None:
+                    os.close(pfd)
 
     def _end(self, run_midx=True):
         f = self.file
         if not f: return None
         self.file = None
-        self.objcache = None
-        idx = self.idx
-        self.idx = None
+        try:
+            self.objcache = None
+            idx = self.idx
+            self.idx = None
 
-        # update object count
-        f.seek(8)
-        cp = struct.pack('!i', self.count)
-        assert(len(cp) == 4)
-        f.write(cp)
-
-        # calculate the pack sha1sum
-        f.seek(0)
-        sum = Sha1()
-        for b in chunkyreader(f):
-            sum.update(b)
-        packbin = sum.digest()
-        f.write(packbin)
-        f.close()
+            # update object count
+            f.seek(8)
+            cp = struct.pack('!i', self.count)
+            assert(len(cp) == 4)
+            f.write(cp)
+
+            # calculate the pack sha1sum
+            f.seek(0)
+            sum = Sha1()
+            for b in chunkyreader(f):
+                sum.update(b)
+            packbin = sum.digest()
+            f.write(packbin)
+            fdatasync(f.fileno())
+        finally:
+            f.close()
 
         obj_list_sha = self._write_pack_idx_v2(self.filename + '.idx', idx, packbin)
 
@@ -709,9 +770,17 @@ class PackWriter:
             os.unlink(self.filename + '.map')
         os.rename(self.filename + '.pack', nameprefix + '.pack')
         os.rename(self.filename + '.idx', nameprefix + '.idx')
+        try:
+            os.fsync(self.parentfd)
+        finally:
+            os.close(self.parentfd)
 
         if run_midx:
             auto_midx(repo('objects/pack'))
+
+        if self.on_pack_finish:
+            self.on_pack_finish(nameprefix)
+
         return nameprefix
 
     def close(self, run_midx=True):
@@ -731,11 +800,15 @@ class PackWriter:
         idx_f = open(filename, 'w+b')
         try:
             idx_f.truncate(index_len)
+            fdatasync(idx_f.fileno())
             idx_map = mmap_readwrite(idx_f, close=False)
-            count = _helpers.write_idx(filename, idx_map, idx, self.count)
-            assert(count == self.count)
+            try:
+                count = _helpers.write_idx(filename, idx_map, idx, self.count)
+                assert(count == self.count)
+                idx_map.flush()
+            finally:
+                idx_map.close()
         finally:
-            if idx_map: idx_map.close()
             idx_f.close()
 
         idx_f = open(filename, 'a+b')
@@ -755,15 +828,12 @@ class PackWriter:
             for b in chunkyreader(idx_f):
                 idx_sum.update(b)
             idx_f.write(idx_sum.digest())
+            fdatasync(idx_f.fileno())
             return namebase
         finally:
             idx_f.close()
 
 
-def _git_date(date):
-    return '%d %s' % (date, time.strftime('%z', time.localtime(date)))
-
-
 def _gitenv(repo_dir = None):
     if not repo_dir:
         repo_dir = repo()
@@ -772,11 +842,21 @@ def _gitenv(repo_dir = None):
     return env
 
 
-def list_refs(refname = None, repo_dir = None):
-    """Generate a list of tuples in the form (refname,hash).
-    If a ref name is specified, list only this particular ref.
+def list_refs(refname=None, repo_dir=None,
+              limit_to_heads=False, limit_to_tags=False):
+    """Yield (refname, hash) tuples for all repository refs unless a ref
+    name is specified.  Given a ref name, only include tuples for that
+    particular ref.  The limits restrict the result items to
+    refs/heads or refs/tags.  If both limits are specified, items from
+    both sources will be included.
+
     """
-    argv = ['git', 'show-ref', '--']
+    argv = ['git', 'show-ref']
+    if limit_to_heads:
+        argv.append('--heads')
+    if limit_to_tags:
+        argv.append('--tags')
+    argv.append('--')
     if refname:
         argv += [refname]
     p = subprocess.Popen(argv,
@@ -794,7 +874,8 @@ def list_refs(refname = None, repo_dir = None):
 
 def read_ref(refname, repo_dir = None):
     """Get the commit id of the most recent commit made on a given ref."""
-    l = list(list_refs(refname, repo_dir))
+    refs = list_refs(refname, repo_dir=repo_dir, limit_to_heads=True)
+    l = tuple(islice(refs, 2))
     if l:
         assert(len(l) == 1)
         return l[0][1]
@@ -871,13 +952,23 @@ def rev_parse(committish, repo_dir=None):
     return None
 
 
-def update_ref(refname, newval, oldval):
-    """Change the commit pointed to by a branch."""
+def update_ref(refname, newval, oldval, repo_dir=None):
+    """Update a repository reference."""
     if not oldval:
         oldval = ''
-    assert(refname.startswith('refs/heads/'))
+    assert(refname.startswith('refs/heads/') \
+           or refname.startswith('refs/tags/'))
     p = subprocess.Popen(['git', 'update-ref', refname,
                           newval.encode('hex'), oldval.encode('hex')],
+                         preexec_fn = _gitenv(repo_dir))
+    _git_wait('git update-ref', p)
+
+
+def delete_ref(refname, oldvalue=None):
+    """Delete a repository reference (see git update-ref(1))."""
+    assert(refname.startswith('refs/'))
+    oldvalue = [] if not oldvalue else [oldvalue]
+    p = subprocess.Popen(['git', 'update-ref', '-d', refname] + oldvalue,
                          preexec_fn = _gitenv())
     _git_wait('git update-ref', p)
 
@@ -929,7 +1020,7 @@ def check_repo_or_die(path=None):
     guess_repo(path)
     try:
         os.stat(repo('objects/pack/.'))
-    except OSError, e:
+    except OSError as e:
         if e.errno == errno.ENOENT:
             log('error: %r is not a bup repository; run "bup init"\n'
                 % repo())
@@ -991,7 +1082,7 @@ class _AbortableIter:
     def next(self):
         try:
             return self.it.next()
-        except StopIteration, e:
+        except StopIteration as e:
             self.done = True
             raise
         except:
@@ -1009,6 +1100,12 @@ class _AbortableIter:
         self.abort()
 
 
+class MissingObject(KeyError):
+    def __init__(self, id):
+        self.id = id
+        KeyError.__init__(self, 'object %r is missing' % id.encode('hex'))
+
+
 _ver_warned = 0
 class CatPipe:
     """Link to 'git cat-file' that is used to retrieve blob data."""
@@ -1061,7 +1158,7 @@ class CatPipe:
         hdr = self.p.stdout.readline()
         if hdr.endswith(' missing\n'):
             self.inprogress = None
-            raise KeyError('blob %r is missing' % id)
+            raise MissingObject(id.decode('hex'))
         spl = hdr.split(' ')
         if len(spl) != 3 or len(spl[0]) != 40:
             raise GitError('expected blob, got %r' % spl)
@@ -1076,7 +1173,7 @@ class CatPipe:
             readline_result = self.p.stdout.readline()
             assert(readline_result == '\n')
             self.inprogress = None
-        except Exception, e:
+        except Exception as e:
             it.abort()
             raise
 
@@ -1144,11 +1241,80 @@ def cp(repo_dir=None):
 def tags(repo_dir = None):
     """Return a dictionary of all tags in the form {hash: [tag_names, ...]}."""
     tags = {}
-    for (n,c) in list_refs(repo_dir = repo_dir):
-        if n.startswith('refs/tags/'):
-            name = n[10:]
-            if not c in tags:
-                tags[c] = []
-
-            tags[c].append(name)  # more than one tag can point at 'c'
+    for n, c in list_refs(repo_dir = repo_dir, limit_to_tags=True):
+        assert(n.startswith('refs/tags/'))
+        name = n[10:]
+        if not c in tags:
+            tags[c] = []
+        tags[c].append(name)  # more than one tag can point at 'c'
     return tags
+
+
+WalkItem = namedtuple('WalkItem', ['id', 'type', 'mode',
+                                   'path', 'chunk_path', 'data'])
+# The path is the mangled path, and if an item represents a fragment
+# of a chunked file, the chunk_path will be the chunked subtree path
+# for the chunk, i.e. ['', '2d3115e', ...].  The top-level path for a
+# chunked file will have a chunk_path of [''].  So some chunk subtree
+# of the file '/foo/bar/baz' might look like this:
+#
+#   item.path = ['foo', 'bar', 'baz.bup']
+#   item.chunk_path = ['', '2d3115e', '016b097']
+#   item.type = 'tree'
+#   ...
+
+
+def walk_object(cat_pipe, id,
+                stop_at=None,
+                include_data=None):
+    """Yield everything reachable from id via cat_pipe as a WalkItem,
+    stopping whenever stop_at(id) returns true.  Throw MissingObject
+    if a hash encountered is missing from the repository.
+
+    """
+    # Maintain the pending stack on the heap to avoid stack overflow
+    pending = [(id, [], [], None)]
+    while len(pending):
+        id, parent_path, chunk_path, mode = pending.pop()
+        if stop_at and stop_at(id):
+            continue
+
+        item_it = cat_pipe.get(id)  # FIXME: use include_data
+        type = item_it.next()
+        if type not in ('blob', 'commit', 'tree'):
+            raise Exception('unexpected repository object type %r' % type)
+
+        # FIXME: set the mode based on the type when the mode is None
+        if type == 'blob' and not include_data:
+            # Dump data until we can ask cat_pipe not to fetch it
+            for ignored in item_it:
+                pass
+            data = None
+        else:
+            data = ''.join(item_it)
+
+        yield WalkItem(id=id, type=type,
+                       chunk_path=chunk_path, path=parent_path,
+                       mode=mode,
+                       data=(data if include_data else None))
+
+        if type == 'commit':
+            commit_items = parse_commit(data)
+            for pid in commit_items.parents:
+                pending.append((pid, parent_path, chunk_path, mode))
+            pending.append((commit_items.tree, parent_path, chunk_path,
+                            hashsplit.GIT_MODE_TREE))
+        elif type == 'tree':
+            for mode, name, ent_id in tree_decode(data):
+                demangled, bup_type = demangle_name(name, mode)
+                if chunk_path:
+                    sub_path = parent_path
+                    sub_chunk_path = chunk_path + [name]
+                else:
+                    sub_path = parent_path + [name]
+                    if bup_type == BUP_CHUNKED:
+                        sub_chunk_path = ['']
+                    else:
+                        sub_chunk_path = chunk_path
+                pending.append((ent_id.encode('hex'), sub_path, sub_chunk_path,
+                                mode))