]> arthur.barton.de Git - bup.git/blob - lib/bup/index.py
Merge branch 'maint'
[bup.git] / lib / bup / index.py
1 import os, stat, struct, tempfile
2 from bup.helpers import *
3
4 EMPTY_SHA = '\0'*20
5 FAKE_SHA = '\x01'*20
6 INDEX_HDR = 'BUPI\0\0\0\2'
7 INDEX_SIG = '!IIIIIQII20sHII'
8 ENTLEN = struct.calcsize(INDEX_SIG)
9 FOOTER_SIG = '!Q'
10 FOOTLEN = struct.calcsize(FOOTER_SIG)
11
12 IX_EXISTS = 0x8000        # file exists on filesystem
13 IX_HASHVALID = 0x4000     # the stored sha1 matches the filesystem
14 IX_SHAMISSING = 0x2000    # the stored sha1 object doesn't seem to exist
15
16 class Error(Exception):
17     pass
18
19
20 class Level:
21     def __init__(self, ename, parent):
22         self.parent = parent
23         self.ename = ename
24         self.list = []
25         self.count = 0
26
27     def write(self, f):
28         (ofs,n) = (f.tell(), len(self.list))
29         if self.list:
30             count = len(self.list)
31             #log('popping %r with %d entries\n' 
32             #    % (''.join(self.ename), count))
33             for e in self.list:
34                 e.write(f)
35             if self.parent:
36                 self.parent.count += count + self.count
37         return (ofs,n)
38
39
40 def _golevel(level, f, ename, newentry):
41     # close nodes back up the tree
42     assert(level)
43     while ename[:len(level.ename)] != level.ename:
44         n = BlankNewEntry(level.ename[-1])
45         n.flags |= IX_EXISTS
46         (n.children_ofs,n.children_n) = level.write(f)
47         level.parent.list.append(n)
48         level = level.parent
49
50     # create nodes down the tree
51     while len(level.ename) < len(ename):
52         level = Level(ename[:len(level.ename)+1], level)
53
54     # are we in precisely the right place?
55     assert(ename == level.ename)
56     n = newentry or BlankNewEntry(ename and level.ename[-1] or None)
57     (n.children_ofs,n.children_n) = level.write(f)
58     if level.parent:
59         level.parent.list.append(n)
60     level = level.parent
61
62     return level
63
64
65 class Entry:
66     def __init__(self, basename, name):
67         self.basename = str(basename)
68         self.name = str(name)
69         self.children_ofs = 0
70         self.children_n = 0
71
72     def __repr__(self):
73         return ("(%s,0x%04x,%d,%d,%d,%d,%d,%s/%s,0x%04x,0x%08x/%d)" 
74                 % (self.name, self.dev,
75                    self.ctime, self.mtime, self.uid, self.gid,
76                    self.size, oct(self.mode), oct(self.gitmode),
77                    self.flags, self.children_ofs, self.children_n))
78
79     def packed(self):
80         try:
81             return struct.pack(INDEX_SIG,
82                            self.dev, self.ctime, self.mtime, 
83                            self.uid, self.gid, self.size, self.mode,
84                            self.gitmode, self.sha, self.flags,
85                            self.children_ofs, self.children_n)
86         except DeprecationWarning, e:
87             log('pack error: %s (%r)\n' % (e, self))
88             raise
89
90     def from_stat(self, st, tstart):
91         old = (self.dev, self.ctime, self.mtime,
92                self.uid, self.gid, self.size, self.flags & IX_EXISTS)
93         new = (st.st_dev, int(st.st_ctime), int(st.st_mtime),
94                st.st_uid, st.st_gid, st.st_size, IX_EXISTS)
95         self.dev = st.st_dev
96         self.ctime = int(st.st_ctime)
97         self.mtime = int(st.st_mtime)
98         self.uid = st.st_uid
99         self.gid = st.st_gid
100         self.size = st.st_size
101         self.mode = st.st_mode
102         self.flags |= IX_EXISTS
103         if int(st.st_ctime) >= tstart or old != new \
104               or self.sha == EMPTY_SHA or not self.gitmode:
105             self.invalidate()
106         self._fixup()
107         
108     def _fixup(self):
109         if self.uid < 0:
110             self.uid += 0x100000000
111         if self.gid < 0:
112             self.gid += 0x100000000
113         assert(self.uid >= 0)
114         assert(self.gid >= 0)
115
116     def is_valid(self):
117         f = IX_HASHVALID|IX_EXISTS
118         return (self.flags & f) == f
119
120     def invalidate(self):
121         self.flags &= ~IX_HASHVALID
122
123     def validate(self, gitmode, sha):
124         assert(sha)
125         assert(gitmode)
126         self.gitmode = gitmode
127         self.sha = sha
128         self.flags |= IX_HASHVALID|IX_EXISTS
129
130     def exists(self):
131         return not self.is_deleted()
132
133     def sha_missing(self):
134         return (self.flags & IX_SHAMISSING) or not (self.flags & IX_HASHVALID)
135
136     def is_deleted(self):
137         return (self.flags & IX_EXISTS) == 0
138
139     def set_deleted(self):
140         if self.flags & IX_EXISTS:
141             self.flags &= ~(IX_EXISTS | IX_HASHVALID)
142
143     def is_real(self):
144         return not self.is_fake()
145
146     def is_fake(self):
147         return not self.ctime
148
149     def __cmp__(a, b):
150         return (cmp(a.name, b.name)
151                 or -cmp(a.is_valid(), b.is_valid())
152                 or -cmp(a.is_fake(), b.is_fake()))
153
154     def write(self, f):
155         f.write(self.basename + '\0' + self.packed())
156
157
158 class NewEntry(Entry):
159     def __init__(self, basename, name, dev, ctime, mtime, uid, gid,
160                  size, mode, gitmode, sha, flags, children_ofs, children_n):
161         Entry.__init__(self, basename, name)
162         (self.dev, self.ctime, self.mtime, self.uid, self.gid,
163          self.size, self.mode, self.gitmode, self.sha,
164          self.flags, self.children_ofs, self.children_n
165          ) = (dev, int(ctime), int(mtime), uid, gid,
166               size, mode, gitmode, sha, flags, children_ofs, children_n)
167         self._fixup()
168
169
170 class BlankNewEntry(NewEntry):
171     def __init__(self, basename):
172         NewEntry.__init__(self, basename, basename,
173                           0, 0, 0, 0, 0, 0, 0,
174                           0, EMPTY_SHA, 0, 0, 0)
175
176
177 class ExistingEntry(Entry):
178     def __init__(self, parent, basename, name, m, ofs):
179         Entry.__init__(self, basename, name)
180         self.parent = parent
181         self._m = m
182         self._ofs = ofs
183         (self.dev, self.ctime, self.mtime, self.uid, self.gid,
184          self.size, self.mode, self.gitmode, self.sha,
185          self.flags, self.children_ofs, self.children_n
186          ) = struct.unpack(INDEX_SIG, str(buffer(m, ofs, ENTLEN)))
187
188     # effectively, we don't bother messing with IX_SHAMISSING if
189     # not IX_HASHVALID, since it's redundant, and repacking is more
190     # expensive than not repacking.
191     # This is implemented by having sha_missing() check IX_HASHVALID too.
192     def set_sha_missing(self, val):
193         val = val and 1 or 0
194         oldval = self.sha_missing() and 1 or 0
195         if val != oldval:
196             flag = val and IX_SHAMISSING or 0
197             newflags = (self.flags & (~IX_SHAMISSING)) | flag
198             self.flags = newflags
199             self.repack()
200
201     def unset_sha_missing(self, flag):
202         if self.flags & IX_SHAMISSING:
203             self.flags &= ~IX_SHAMISSING
204             self.repack()
205
206     def repack(self):
207         self._m[self._ofs:self._ofs+ENTLEN] = self.packed()
208         if self.parent and not self.is_valid():
209             self.parent.invalidate()
210             self.parent.repack()
211
212     def iter(self, name=None, wantrecurse=None):
213         dname = name
214         if dname and not dname.endswith('/'):
215             dname += '/'
216         ofs = self.children_ofs
217         assert(ofs <= len(self._m))
218         assert(self.children_n < 1000000)
219         for i in xrange(self.children_n):
220             eon = self._m.find('\0', ofs)
221             assert(eon >= 0)
222             assert(eon >= ofs)
223             assert(eon > ofs)
224             basename = str(buffer(self._m, ofs, eon-ofs))
225             child = ExistingEntry(self, basename, self.name + basename,
226                                   self._m, eon+1)
227             if (not dname
228                  or child.name.startswith(dname)
229                  or child.name.endswith('/') and dname.startswith(child.name)):
230                 if not wantrecurse or wantrecurse(child):
231                     for e in child.iter(name=name, wantrecurse=wantrecurse):
232                         yield e
233             if not name or child.name == name or child.name.startswith(dname):
234                 yield child
235             ofs = eon + 1 + ENTLEN
236
237     def __iter__(self):
238         return self.iter()
239             
240
241 class Reader:
242     def __init__(self, filename):
243         self.filename = filename
244         self.m = ''
245         self.writable = False
246         self.count = 0
247         f = None
248         try:
249             f = open(filename, 'r+')
250         except IOError, e:
251             if e.errno == errno.ENOENT:
252                 pass
253             else:
254                 raise
255         if f:
256             b = f.read(len(INDEX_HDR))
257             if b != INDEX_HDR:
258                 log('warning: %s: header: expected %r, got %r'
259                                  % (filename, INDEX_HDR, b))
260             else:
261                 st = os.fstat(f.fileno())
262                 if st.st_size:
263                     self.m = mmap_readwrite(f)
264                     self.writable = True
265                     self.count = struct.unpack(FOOTER_SIG,
266                           str(buffer(self.m, st.st_size-FOOTLEN, FOOTLEN)))[0]
267
268     def __del__(self):
269         self.close()
270
271     def __len__(self):
272         return int(self.count)
273
274     def forward_iter(self):
275         ofs = len(INDEX_HDR)
276         while ofs+ENTLEN <= len(self.m)-FOOTLEN:
277             eon = self.m.find('\0', ofs)
278             assert(eon >= 0)
279             assert(eon >= ofs)
280             assert(eon > ofs)
281             basename = str(buffer(self.m, ofs, eon-ofs))
282             yield ExistingEntry(None, basename, basename, self.m, eon+1)
283             ofs = eon + 1 + ENTLEN
284
285     def iter(self, name=None, wantrecurse=None):
286         if len(self.m) > len(INDEX_HDR)+ENTLEN:
287             dname = name
288             if dname and not dname.endswith('/'):
289                 dname += '/'
290             root = ExistingEntry(None, '/', '/',
291                                  self.m, len(self.m)-FOOTLEN-ENTLEN)
292             for sub in root.iter(name=name, wantrecurse=wantrecurse):
293                 yield sub
294             if not dname or dname == root.name:
295                 yield root
296
297     def __iter__(self):
298         return self.iter()
299
300     def exists(self):
301         return self.m
302
303     def save(self):
304         if self.writable and self.m:
305             self.m.flush()
306
307     def close(self):
308         self.save()
309         if self.writable and self.m:
310             self.m.close()
311             self.m = None
312             self.writable = False
313
314     def filter(self, prefixes, wantrecurse=None):
315         for (rp, path) in reduce_paths(prefixes):
316             for e in self.iter(rp, wantrecurse=wantrecurse):
317                 assert(e.name.startswith(rp))
318                 name = path + e.name[len(rp):]
319                 yield (name, e)
320
321
322 # FIXME: this function isn't very generic, because it splits the filename
323 # in an odd way and depends on a terminating '/' to indicate directories.
324 def pathsplit(p):
325     """Split a path into a list of elements of the file system hierarchy."""
326     l = p.split('/')
327     l = [i+'/' for i in l[:-1]] + l[-1:]
328     if l[-1] == '':
329         l.pop()  # extra blank caused by terminating '/'
330     return l
331
332
333 class Writer:
334     def __init__(self, filename):
335         self.rootlevel = self.level = Level([], None)
336         self.f = None
337         self.count = 0
338         self.lastfile = None
339         self.filename = None
340         self.filename = filename = realpath(filename)
341         (dir,name) = os.path.split(filename)
342         (ffd,self.tmpname) = tempfile.mkstemp('.tmp', filename, dir)
343         self.f = os.fdopen(ffd, 'wb', 65536)
344         self.f.write(INDEX_HDR)
345
346     def __del__(self):
347         self.abort()
348
349     def abort(self):
350         f = self.f
351         self.f = None
352         if f:
353             f.close()
354             os.unlink(self.tmpname)
355
356     def flush(self):
357         if self.level:
358             self.level = _golevel(self.level, self.f, [], None)
359             self.count = self.rootlevel.count
360             if self.count:
361                 self.count += 1
362             self.f.write(struct.pack(FOOTER_SIG, self.count))
363             self.f.flush()
364         assert(self.level == None)
365
366     def close(self):
367         self.flush()
368         f = self.f
369         self.f = None
370         if f:
371             f.close()
372             os.rename(self.tmpname, self.filename)
373
374     def _add(self, ename, entry):
375         if self.lastfile and self.lastfile <= ename:
376             raise Error('%r must come before %r' 
377                              % (''.join(e.name), ''.join(self.lastfile)))
378             self.lastfile = e.name
379         self.level = _golevel(self.level, self.f, ename, entry)
380
381     def add(self, name, st, hashgen = None):
382         endswith = name.endswith('/')
383         ename = pathsplit(name)
384         basename = ename[-1]
385         #log('add: %r %r\n' % (basename, name))
386         flags = IX_EXISTS
387         sha = None
388         if hashgen:
389             (gitmode, sha) = hashgen(name)
390             flags |= IX_HASHVALID
391         else:
392             (gitmode, sha) = (0, EMPTY_SHA)
393         if st:
394             isdir = stat.S_ISDIR(st.st_mode)
395             assert(isdir == endswith)
396             e = NewEntry(basename, name, st.st_dev, int(st.st_ctime),
397                          int(st.st_mtime), st.st_uid, st.st_gid,
398                          st.st_size, st.st_mode, gitmode, sha, flags,
399                          0, 0)
400         else:
401             assert(endswith)
402             e = BlankNewEntry(basename)
403             e.gitmode = gitmode
404             e.sha = sha
405             e.flags = flags
406         self._add(ename, e)
407
408     def add_ixentry(self, e):
409         e.children_ofs = e.children_n = 0
410         self._add(pathsplit(e.name), e)
411
412     def new_reader(self):
413         self.flush()
414         return Reader(self.tmpname)
415
416
417 def reduce_paths(paths):
418     xpaths = []
419     for p in paths:
420         rp = realpath(p)
421         try:
422             st = os.lstat(rp)
423             if stat.S_ISDIR(st.st_mode):
424                 rp = slashappend(rp)
425                 p = slashappend(p)
426             xpaths.append((rp, p))
427         except OSError, e:
428             add_error('reduce_paths: %s' % e)
429     xpaths.sort()
430
431     paths = []
432     prev = None
433     for (rp, p) in xpaths:
434         if prev and (prev == rp 
435                      or (prev.endswith('/') and rp.startswith(prev))):
436             continue # already superceded by previous path
437         paths.append((rp, p))
438         prev = rp
439     paths.sort(reverse=True)
440     return paths
441
442
443 class MergeIter:
444     def __init__(self, iters):
445         self.iters = iters
446
447     def __len__(self):
448         # FIXME: doesn't remove duplicated entries between iters.
449         # That only happens for parent directories, but will mean the
450         # actual iteration returns fewer entries than this function counts.
451         return sum(len(it) for it in self.iters)
452
453     def __iter__(self):
454         total = len(self)
455         l = [iter(it) for it in self.iters]
456         l = [(next(it),it) for it in l]
457         l = filter(lambda x: x[0], l)
458         count = 0
459         lastname = None
460         while l:
461             if not (count % 1024):
462                 progress('bup: merging indexes (%d/%d)\r' % (count, total))
463             l.sort()
464             (e,it) = l.pop()
465             if not e:
466                 continue
467             if e.name != lastname:
468                 yield e
469                 lastname = e.name
470             n = next(it)
471             if n:
472                 l.append((n,it))
473             count += 1
474         log('bup: merging indexes (%d/%d), done.\n' % (count, total))