]> arthur.barton.de Git - bup.git/blobdiff - git.py
Start using wvtest.sh for shell-based tests in test-sh.
[bup.git] / git.py
diff --git a/git.py b/git.py
index 825ef295a94b88b88bd9ec46a94da0bc5fc4fb19..6775bb2876a7d8c10d98b65027db432a4438fe77 100644 (file)
--- a/git.py
+++ b/git.py
@@ -1,10 +1,23 @@
-import os, errno, zlib, time, sha, subprocess, struct, mmap
+import os, errno, zlib, time, sha, subprocess, struct, mmap, stat, re
 from helpers import *
 
 verbose = 0
+home_repodir = os.path.expanduser('~/.bup')
+repodir = None
 
-def repodir(sub = ''):
-    return os.path.join(os.environ.get('BUP_DIR', '.git'), sub)
+
+class GitError(Exception):
+    pass
+
+
+def repo(sub = ''):
+    global repodir
+    if not repodir:
+        raise GitError('You should call check_repo_or_die()')
+    gd = os.path.join(repodir, '.git')
+    if os.path.exists(gd):
+        repodir = gd
+    return os.path.join(repodir, sub)
 
 
 class PackIndex:
@@ -15,7 +28,8 @@ class PackIndex:
                              mmap.MAP_SHARED, mmap.PROT_READ)
         f.close()  # map will persist beyond file close
         assert(str(self.map[0:8]) == '\377tOc\0\0\0\2')
-        self.fanout = list(struct.unpack('!256I', buffer(self.map, 8, 256*4)))
+        self.fanout = list(struct.unpack('!256I',
+                                         str(buffer(self.map, 8, 256*4))))
         self.fanout.append(0)  # entry "-1"
         nsha = self.fanout[255]
         self.ofstable = buffer(self.map,
@@ -25,10 +39,11 @@ class PackIndex:
                                  8 + 256*4 + nsha*20 + nsha*4 + nsha*4)
 
     def _ofs_from_idx(self, idx):
-        ofs = struct.unpack('!I', buffer(self.ofstable, idx*4, 4))[0]
+        ofs = struct.unpack('!I', str(buffer(self.ofstable, idx*4, 4)))[0]
         if ofs & 0x80000000:
             idx64 = ofs & 0x7fffffff
-            ofs = struct.unpack('!I', buffer(self.ofs64table, idx64*8, 8))[0]
+            ofs = struct.unpack('!I',
+                                str(buffer(self.ofs64table, idx64*8, 8)))[0]
         return ofs
 
     def _idx_from_hash(self, hash):
@@ -61,11 +76,12 @@ class PackIndex:
 
 class MultiPackIndex:
     def __init__(self, dir):
-        self.packs = []
+        self.dir = dir
         self.also = {}
-        for f in os.listdir(dir):
+        self.packs = []
+        for f in os.listdir(self.dir):
             if f.endswith('.idx'):
-                self.packs.append(PackIndex(os.path.join(dir, f)))
+                self.packs.append(PackIndex(os.path.join(self.dir, f)))
 
     def exists(self, hash):
         if hash in self.also:
@@ -92,57 +108,85 @@ def calc_hash(type, content):
     return sum.digest()
 
 
+def _shalist_sort_key(ent):
+    (mode, name, id) = ent
+    if stat.S_ISDIR(int(mode, 8)):
+        return name + '/'
+    else:
+        return name
+
+
 _typemap = dict(blob=3, tree=2, commit=1, tag=8)
 class PackWriter:
-    def __init__(self):
+    def __init__(self, objcache_maker=None):
         self.count = 0
-        self.binlist = []
-        self.objcache = MultiPackIndex(repodir('objects/pack'))
+        self.outbytes = 0
         self.filename = None
         self.file = None
+        self.objcache_maker = objcache_maker
+        self.objcache = None
 
     def __del__(self):
         self.close()
 
-    def _open(self):
-        assert(not self.file)
-        self.objcache.zap_also()
-        self.filename = repodir('objects/bup%d' % os.getpid())
-        self.file = open(self.filename + '.pack', 'w+')
-        self.file.write('PACK\0\0\0\2\0\0\0\0')
+    def _make_objcache(self):
+        if not self.objcache:
+            if self.objcache_maker:
+                self.objcache = self.objcache_maker()
+            else:
+                self.objcache = MultiPackIndex(repo('objects/pack'))
 
-    def _write(self, bin, type, content):
+    def _open(self):
         if not self.file:
-            self._open()
+            self._make_objcache()
+            self.filename = repo('objects/bup%d' % os.getpid())
+            self.file = open(self.filename + '.pack', 'w+')
+            self.file.write('PACK\0\0\0\2\0\0\0\0')
+
+    def _raw_write(self, datalist):
+        self._open()
         f = self.file
+        for d in datalist:
+            f.write(d)
+            self.outbytes += len(d)
+        self.count += 1
 
+    def _write(self, bin, type, content):
         if verbose:
             log('>')
-            
+
+        out = []
+
         sz = len(content)
         szbits = (sz & 0x0f) | (_typemap[type]<<4)
         sz >>= 4
         while 1:
             if sz: szbits |= 0x80
-            f.write(chr(szbits))
+            out.append(chr(szbits))
             if not sz:
                 break
             szbits = sz & 0x7f
             sz >>= 7
-        
+
         z = zlib.compressobj(1)
-        f.write(z.compress(content))
-        f.write(z.flush())
+        out.append(z.compress(content))
+        out.append(z.flush())
 
-        self.count += 1
-        self.binlist.append(bin)
+        self._raw_write(out)
         return bin
 
+    def breakpoint(self):
+        id = self._end()
+        self.outbytes = self.count = 0
+        return id
+
     def write(self, type, content):
         return self._write(calc_hash(type, content), type, content)
 
     def maybe_write(self, type, content):
         bin = calc_hash(type, content)
+        if not self.objcache:
+            self._make_objcache()
         if not self.objcache.exists(bin):
             self._write(bin, type, content)
             self.objcache.add(bin)
@@ -152,7 +196,7 @@ class PackWriter:
         return self.maybe_write('blob', blob)
 
     def new_tree(self, shalist):
-        shalist = sorted(shalist, key = lambda x: x[1])
+        shalist = sorted(shalist, key = _shalist_sort_key)
         l = ['%s %s\0%s' % (mode,name,bin) 
              for (mode,name,bin) in shalist]
         return self.maybe_write('tree', ''.join(l))
@@ -160,23 +204,19 @@ class PackWriter:
     def _new_commit(self, tree, parent, author, adate, committer, cdate, msg):
         l = []
         if tree: l.append('tree %s' % tree.encode('hex'))
-        if parent: l.append('parent %s' % parent)
+        if parent: l.append('parent %s' % parent.encode('hex'))
         if author: l.append('author %s %s' % (author, _git_date(adate)))
         if committer: l.append('committer %s %s' % (committer, _git_date(cdate)))
         l.append('')
         l.append(msg)
         return self.maybe_write('commit', '\n'.join(l))
 
-    def new_commit(self, ref, tree, msg):
+    def new_commit(self, parent, tree, msg):
         now = time.time()
         userline = '%s <%s@%s>' % (userfullname(), username(), hostname())
-        oldref = ref and _read_ref(ref) or None
-        commit = self._new_commit(tree, oldref,
+        commit = self._new_commit(tree, parent,
                                   userline, now, userline, now,
                                   msg)
-        self.close()  # UGLY: needed so _update_ref can see the new objects
-        if ref:
-            _update_ref(ref, commit.encode('hex'), oldref)
         return commit
 
     def abort(self):
@@ -186,7 +226,7 @@ class PackWriter:
             f.close()
             os.unlink(self.filename + '.pack')
 
-    def close(self):
+    def _end(self):
         f = self.file
         if not f: return None
         self.file = None
@@ -207,53 +247,207 @@ class PackWriter:
         f.write(sum.digest())
         
         f.close()
+        self.objcache = None
 
         p = subprocess.Popen(['git', 'index-pack', '-v',
+                              '--index-version=2',
                               self.filename + '.pack'],
                              preexec_fn = _gitenv,
                              stdout = subprocess.PIPE)
         out = p.stdout.read().strip()
-        if p.wait() or not out:
-            raise Exception('git index-pack returned an error')
-        nameprefix = repodir('objects/pack/%s' % out)
+        _git_wait('git index-pack', p)
+        if not out:
+            raise GitError('git index-pack produced no output')
+        nameprefix = repo('objects/pack/%s' % out)
         os.rename(self.filename + '.pack', nameprefix + '.pack')
         os.rename(self.filename + '.idx', nameprefix + '.idx')
         return nameprefix
 
+    def close(self):
+        return self._end()
+
 
 def _git_date(date):
     return time.strftime('%s %z', time.localtime(date))
 
 
 def _gitenv():
-    os.environ['GIT_DIR'] = os.path.abspath(repodir())
+    os.environ['GIT_DIR'] = os.path.abspath(repo())
 
 
-def _read_ref(refname):
+def read_ref(refname):
     p = subprocess.Popen(['git', 'show-ref', '--', refname],
                          preexec_fn = _gitenv,
                          stdout = subprocess.PIPE)
     out = p.stdout.read().strip()
-    p.wait()
+    rv = p.wait()  # not fatal
+    if rv:
+        assert(not out)
     if out:
-        return out.split()[0]
+        return out.split()[0].decode('hex')
     else:
         return None
 
 
-def _update_ref(refname, newval, oldval):
+def update_ref(refname, newval, oldval):
     if not oldval:
         oldval = ''
-    p = subprocess.Popen(['git', 'update-ref', '--', refname, newval, oldval],
+    p = subprocess.Popen(['git', 'update-ref', '--', refname,
+                          newval.encode('hex'), oldval.encode('hex')],
                          preexec_fn = _gitenv)
-    p.wait()
-    return newval
+    _git_wait('git update-ref', p)
+
+
+def guess_repo(path=None):
+    global repodir
+    if path:
+        repodir = path
+    if not repodir:
+        repodir = os.environ.get('BUP_DIR')
+        if not repodir:
+            repodir = os.path.expanduser('~/.bup')
 
 
-def init_repo():
-    d = repodir()
+def init_repo(path=None):
+    guess_repo(path)
+    d = repo()
     if os.path.exists(d) and not os.path.isdir(os.path.join(d, '.')):
-        raise Exception('"%d" exists but is not a directory\n' % d)
-    p = subprocess.Popen(['git', 'init', '--bare'],
+        raise GitError('"%d" exists but is not a directory\n' % d)
+    p = subprocess.Popen(['git', '--bare', 'init'], stdout=sys.stderr,
                          preexec_fn = _gitenv)
-    return p.wait()
+    _git_wait('git init', p)
+    p = subprocess.Popen(['git', 'config', 'pack.indexVersion', '2'],
+                         stdout=sys.stderr, preexec_fn = _gitenv)
+    _git_wait('git config', p)
+
+
+def check_repo_or_die(path=None):
+    guess_repo(path)
+    if not os.path.isdir(repo('objects/pack/.')):
+        if repodir == home_repodir:
+            init_repo()
+        else:
+            log('error: %r is not a bup/git repository\n' % repo())
+            exit(15)
+
+
+def _treeparse(buf):
+    ofs = 0
+    while ofs < len(buf):
+        z = buf[ofs:].find('\0')
+        assert(z > 0)
+        spl = buf[ofs:ofs+z].split(' ', 1)
+        assert(len(spl) == 2)
+        sha = buf[ofs+z+1:ofs+z+1+20]
+        ofs += z+1+20
+        yield (spl[0], spl[1], sha)
+
+_ver = None
+def ver():
+    global _ver
+    if not _ver:
+        p = subprocess.Popen(['git', '--version'],
+                             stdout=subprocess.PIPE)
+        gvs = p.stdout.read()
+        _git_wait('git --version', p)
+        m = re.match(r'git version (\S+.\S+)', gvs)
+        if not m:
+            raise GitError('git --version weird output: %r' % gvs)
+        _ver = tuple(m.group(1).split('.'))
+    needed = ('1','5','4')
+    if _ver < needed:
+        raise GitError('git version %s or higher is required; you have %s'
+                       % ('.'.join(needed), '.'.join(_ver)))
+    return _ver
+
+
+def _git_wait(cmd, p):
+    rv = p.wait()
+    if rv != 0:
+        raise GitError('%s returned %d' % (cmd, rv))
+
+
+def _git_capture(argv):
+    p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn = _gitenv)
+    r = p.stdout.read()
+    _git_wait(repr(argv), p)
+    return r
+
+
+_ver_warned = 0
+class CatPipe:
+    def __init__(self):
+        global _ver_warned
+        wanted = ('1','5','6')
+        if ver() < wanted:
+            if not _ver_warned:
+                log('warning: git version < %s; bup will be slow.\n'
+                    % '.'.join(wanted))
+                _ver_warned = 1
+            self.get = self._slow_get
+        else:
+            self.p = subprocess.Popen(['git', 'cat-file', '--batch'],
+                                      stdin=subprocess.PIPE, 
+                                      stdout=subprocess.PIPE,
+                                      preexec_fn = _gitenv)
+            self.get = self._fast_get
+
+    def _fast_get(self, id):
+        assert(id.find('\n') < 0)
+        assert(id.find('\r') < 0)
+        assert(id[0] != '-')
+        self.p.stdin.write('%s\n' % id)
+        hdr = self.p.stdout.readline()
+        if hdr.endswith(' missing\n'):
+            raise GitError('blob %r is missing' % id)
+        spl = hdr.split(' ')
+        if len(spl) != 3 or len(spl[0]) != 40:
+            raise GitError('expected blob, got %r' % spl)
+        (hex, type, size) = spl
+        yield type
+        for blob in chunkyreader(self.p.stdout, int(spl[2])):
+            yield blob
+        assert(self.p.stdout.readline() == '\n')
+
+    def _slow_get(self, id):
+        assert(id.find('\n') < 0)
+        assert(id.find('\r') < 0)
+        assert(id[0] != '-')
+        type = _git_capture(['git', 'cat-file', '-t', id]).strip()
+        yield type
+
+        p = subprocess.Popen(['git', 'cat-file', type, id],
+                             stdout=subprocess.PIPE,
+                             preexec_fn = _gitenv)
+        for blob in chunkyreader(p.stdout):
+            yield blob
+        _git_wait('git cat-file', p)
+
+    def _join(self, it):
+        type = it.next()
+        if type == 'blob':
+            for blob in it:
+                yield blob
+        elif type == 'tree':
+            treefile = ''.join(it)
+            for (mode, name, sha) in _treeparse(treefile):
+                for blob in self.join(sha.encode('hex')):
+                    yield blob
+        elif type == 'commit':
+            treeline = ''.join(it).split('\n')[0]
+            assert(treeline.startswith('tree '))
+            for blob in self.join(treeline[5:]):
+                yield blob
+        else:
+            raise GitError('invalid object type %r: expected blob/tree/commit'
+                           % type)
+
+    def join(self, id):
+        for d in self._join(self.get(id)):
+            yield d
+        
+
+def cat(id):
+    c = CatPipe()
+    for d in c.join(id):
+        yield d