]> arthur.barton.de Git - bup.git/blob - index.py
cmd-save: don't recurse into already-valid subdirs.
[bup.git] / index.py
1 import os, stat, time, struct, tempfile
2 from helpers import *
3
# Hash stored for entries that have no computed hash (blank entries, and
# entries added without a hash generator).
EMPTY_SHA = '\0' * 20
# NOTE(review): not referenced in this file; presumably a placeholder hash
# used by callers -- confirm before relying on it.
FAKE_SHA = '\x01' * 20

# Magic bytes expected at the very start of an index file.
INDEX_HDR = 'BUPI\0\0\0\2'

# struct layout of one packed entry, in the field order used by
# Entry.packed(): dev, ctime, mtime, uid, gid, size, mode, gitmode, sha,
# flags, children_ofs, children_n.
INDEX_SIG = '!IIIIIQII20sHII'
ENTLEN = struct.calcsize(INDEX_SIG)

# struct layout of the file footer: a single 64-bit entry count.
FOOTER_SIG = '!Q'
FOOTLEN = struct.calcsize(FOOTER_SIG)

# Bits of Entry.flags.
IX_EXISTS = 0x8000      # cleared by Entry.set_deleted()
IX_HASHVALID = 0x4000   # set by Entry.validate(), cleared by invalidate()
14
class Error(Exception):
    """Exception type raised by this module (see Writer._add)."""
17
18
class Level:
    """Entries collected for one directory level while an index is written.

    ename  -- list of path components naming this level
    parent -- the enclosing Level, or None for the root
    list   -- entries added to this level so far, in insertion order
    count  -- running total contributed by the levels below this one
    """

    def __init__(self, ename, parent):
        self.ename = ename
        self.parent = parent
        self.count = 0
        self.list = []

    def write(self, f):
        """Write every collected entry to f.

        Returns (offset, n): the position of f before anything was written
        and the number of entries in this level.  If at least one entry was
        written and there is a parent, the parent's count grows by this
        level's entries plus this level's own count.
        """
        start = f.tell()
        entries = self.list
        nentries = len(entries)
        if entries:
            for ent in entries:
                ent.write(f)
            up = self.parent
            if up:
                up.count += nentries + self.count
        return (start, nentries)
37
38
def _golevel(level, f, ename, newentry):
    """Move from `level` to the level named `ename`, close it, return its parent.

    Three steps:
      1. While the current level's ename is not a prefix of `ename`, write
         that level to f and append a blank entry describing it to its
         parent's list, then continue from the parent.
      2. Create nested Level objects until the current level's ename has as
         many components as `ename`.
      3. Write that level to f and append `newentry` (or a blank entry when
         newentry is false) to the parent's list, with its children_ofs and
         children_n set from what was written.

    Returns the parent of the level that was just closed (None when the
    closed level had no parent).
    """
    assert(level)

    # Step 1: close nodes back up the tree.
    while level.ename != ename[:len(level.ename)]:
        closed = BlankNewEntry(level.ename[-1])
        (closed.children_ofs, closed.children_n) = level.write(f)
        up = level.parent
        up.list.append(closed)
        level = up

    # Step 2: create nodes down the tree, one path component at a time.
    depth = len(level.ename)
    while depth < len(ename):
        depth += 1
        level = Level(ename[:depth], level)

    # Step 3: we must now be at exactly the requested level.
    assert(ename == level.ename)
    node = newentry or BlankNewEntry(ename and level.ename[-1] or None)
    (node.children_ofs, node.children_n) = level.write(f)
    up = level.parent
    if up:
        up.list.append(node)
    return up
61
62
class Entry:
    """Base class for one index entry.

    The constructor only records the names and zeroes the child pointers;
    subclasses are expected to provide dev, ctime, mtime, uid, gid, size,
    mode, gitmode, sha and flags.
    """

    def __init__(self, basename, name):
        self.basename = str(basename)
        self.name = str(name)
        self.children_ofs = 0
        self.children_n = 0

    def __repr__(self):
        shown = (self.name, self.dev,
                 self.ctime, self.mtime, self.uid, self.gid,
                 self.size, self.flags, self.children_ofs, self.children_n)
        return "(%s,0x%04x,%d,%d,%d,%d,%d,0x%04x,0x%08x/%d)" % shown

    def packed(self):
        """Return this entry's fields serialized according to INDEX_SIG."""
        fields = (self.dev, self.ctime, self.mtime,
                  self.uid, self.gid, self.size, self.mode,
                  self.gitmode, self.sha, self.flags,
                  self.children_ofs, self.children_n)
        return struct.pack(INDEX_SIG, *fields)

    def from_stat(self, st, tstart):
        """Refresh the stat-derived fields from st and mark the entry existing.

        The hash is invalidated when any of dev/ctime/mtime/uid/gid/size or
        the exists flag changed, or when the file's ctime is not older than
        tstart.  (mode is updated but does not take part in the comparison.)
        """
        before = (self.dev, self.ctime, self.mtime,
                  self.uid, self.gid, self.size, self.flags & IX_EXISTS)
        ctime = int(st.st_ctime)
        mtime = int(st.st_mtime)
        after = (st.st_dev, ctime, mtime,
                 st.st_uid, st.st_gid, st.st_size, IX_EXISTS)
        (self.dev, self.ctime, self.mtime,
         self.uid, self.gid, self.size) = after[:6]
        self.mode = st.st_mode
        self.flags |= IX_EXISTS
        if ctime >= tstart or before != after:
            self.invalidate()

    def is_valid(self):
        """True when both the exists and hash-valid flags are set."""
        wanted = IX_HASHVALID|IX_EXISTS
        return (self.flags & wanted) == wanted

    def invalidate(self):
        """Clear the hash-valid flag."""
        self.flags &= ~IX_HASHVALID
        self.set_dirty()

    def validate(self, gitmode, sha):
        """Record gitmode and sha and mark the entry existing and hash-valid."""
        assert(sha)
        self.gitmode = gitmode
        self.sha = sha
        self.flags |= IX_HASHVALID|IX_EXISTS

    def is_deleted(self):
        """True when the exists flag is clear."""
        return (self.flags & IX_EXISTS) == 0

    def set_deleted(self):
        """Clear the exists and hash-valid flags if the entry currently exists."""
        if not (self.flags & IX_EXISTS):
            return
        self.flags &= ~(IX_EXISTS | IX_HASHVALID)
        self.set_dirty()

    def set_dirty(self):
        pass # FIXME

    def is_real(self):
        return not self.is_fake()

    def is_fake(self):
        """True when ctime is zero (as it is for blank entries)."""
        return not self.ctime

    def __cmp__(a, b):
        # Order by name; for equal names valid entries sort before invalid
        # ones, then fake entries before real ones.
        by_name = cmp(a.name, b.name)
        if by_name:
            return by_name
        by_valid = -cmp(a.is_valid(), b.is_valid())
        if by_valid:
            return by_valid
        return -cmp(a.is_fake(), b.is_fake())

    def write(self, f):
        """Write the NUL-terminated basename followed by the packed fields."""
        record = self.basename + '\0' + self.packed()
        f.write(record)
137
138
class NewEntry(Entry):
    """An entry built from explicit field values (ctime/mtime coerced to int)."""

    def __init__(self, basename, name, dev, ctime, mtime, uid, gid,
                 size, mode, gitmode, sha, flags, children_ofs, children_n):
        Entry.__init__(self, basename, name)
        # Truncate the timestamps up front so a bad value fails before any
        # field is assigned.
        ctime = int(ctime)
        mtime = int(mtime)
        self.dev = dev
        self.ctime = ctime
        self.mtime = mtime
        self.uid = uid
        self.gid = gid
        self.size = size
        self.mode = mode
        self.gitmode = gitmode
        self.sha = sha
        self.flags = flags
        self.children_ofs = children_ofs
        self.children_n = children_n
148
149
class BlankNewEntry(NewEntry):
    """A NewEntry with every numeric field zero and the empty hash.

    The basename is used for both the basename and the name.
    """

    def __init__(self, basename):
        NewEntry.__init__(self, basename, basename,
                          dev=0, ctime=0, mtime=0, uid=0, gid=0,
                          size=0, mode=0, gitmode=0, sha=EMPTY_SHA,
                          flags=0, children_ofs=0, children_n=0)
155
156
class ExistingEntry(Entry):
    """An entry decoded from the index buffer `m` at byte offset `ofs`.

    The buffer and offset are remembered so repack() can write modified
    fields back in place.
    """

    def __init__(self, parent, basename, name, m, ofs):
        Entry.__init__(self, basename, name)
        self.parent = parent
        self._m = m
        self._ofs = ofs
        raw = str(buffer(m, ofs, ENTLEN))
        (self.dev, self.ctime, self.mtime, self.uid, self.gid,
         self.size, self.mode, self.gitmode, self.sha,
         self.flags, self.children_ofs, self.children_n
         ) = struct.unpack(INDEX_SIG, raw)

    def repack(self):
        """Write the current fields back into the buffer.

        If this entry is not valid, the parent (when there is one) is
        invalidated and repacked as well, and so on up the chain.
        """
        start = self._ofs
        self._m[start:start+ENTLEN] = self.packed()
        if self.parent and not self.is_valid():
            self.parent.invalidate()
            self.parent.repack()

    def iter(self, name=None, wantrecurse=None):
        """Yield descendants of this entry, children's subtrees before the child.

        With a false `name` every descendant is yielded.  Otherwise only
        the entry called `name` and entries whose name starts with
        `name` + '/' are yielded, and only subtrees that can contain such
        entries are visited.  When `wantrecurse` is given, a child's subtree
        is visited only if wantrecurse(child) is true; the child itself is
        still considered for yielding.
        """
        dname = name
        if dname and not dname.endswith('/'):
            dname += '/'
        pos = self.children_ofs
        assert(pos <= len(self._m))
        assert(self.children_n < 1000000)
        for _ in xrange(self.children_n):
            # Each record is: basename, NUL, then ENTLEN packed bytes.
            nul = self._m.find('\0', pos)
            assert(nul >= 0)
            assert(nul > pos)   # the basename must not be empty
            basename = str(buffer(self._m, pos, nul-pos))
            child = ExistingEntry(self, basename, self.name + basename,
                                  self._m, nul+1)
            if not dname:
                descend = True
            elif child.name.startswith(dname):
                descend = True
            else:
                # a directory that lies on the way down to dname
                descend = (child.name.endswith('/')
                           and dname.startswith(child.name))
            if descend and (not wantrecurse or wantrecurse(child)):
                for sub in child.iter(name=name, wantrecurse=wantrecurse):
                    yield sub
            if not name or child.name == name or child.name.startswith(dname):
                yield child
            pos = nul + 1 + ENTLEN

    def __iter__(self):
        return self.iter()
201             
202
class Reader:
    """Access to an existing index file through a writable memory map.

    A missing file gives an empty reader.  A file whose first bytes are not
    INDEX_HDR is reported with a warning and also treated as empty.

    After close() on a reader that had mapped a file, `m` is None, so the
    reader must not be iterated again.
    """

    def __init__(self, filename):
        # Local import: used only to fetch the active exception in a way
        # that parses on every Python version (the old `except X, e` form
        # is a syntax error on Python 3).
        import sys
        self.filename = filename
        self.m = ''
        self.writable = False
        self.count = 0
        f = None
        try:
            f = open(filename, 'r+')
        except IOError:
            e = sys.exc_info()[1]
            # A missing index simply means "nothing indexed yet".
            if e.errno != errno.ENOENT:
                raise
        if f:
            b = f.read(len(INDEX_HDR))
            if b != INDEX_HDR:
                # log() callers supply their own newline (see MergeIter).
                log('warning: %s: header: expected %r, got %r\n'
                                 % (filename, INDEX_HDR, b))
            else:
                st = os.fstat(f.fileno())
                if st.st_size:
                    self.m = mmap_readwrite(f)
                    self.writable = True
                    # The entry count is stored in the last FOOTLEN bytes.
                    self.count = struct.unpack(FOOTER_SIG,
                          str(buffer(self.m, st.st_size-FOOTLEN, FOOTLEN)))[0]

    def __del__(self):
        self.close()

    def __len__(self):
        """Return the entry count read from the file footer (0 if no file)."""
        return int(self.count)

    def forward_iter(self):
        """Yield every entry in file order, without any parent links.

        Entries produced this way carry only their basename as their name.
        """
        ofs = len(INDEX_HDR)
        while ofs+ENTLEN <= len(self.m)-FOOTLEN:
            eon = self.m.find('\0', ofs)
            assert(eon >= 0)
            assert(eon > ofs)   # the basename must not be empty
            basename = str(buffer(self.m, ofs, eon-ofs))
            yield ExistingEntry(None, basename, basename, self.m, eon+1)
            ofs = eon + 1 + ENTLEN

    def iter(self, name=None, wantrecurse=None):
        """Yield entries by walking the tree from the root entry.

        The root is the last entry before the footer and is named '/'.
        Its descendants are produced by ExistingEntry.iter() with the same
        `name` and `wantrecurse`; the root itself is yielded last, and only
        when no name was given or the name refers to '/'.
        """
        if len(self.m) > len(INDEX_HDR)+ENTLEN:
            dname = name
            if dname and not dname.endswith('/'):
                dname += '/'
            root = ExistingEntry(None, '/', '/',
                                 self.m, len(self.m)-FOOTLEN-ENTLEN)
            for sub in root.iter(name=name, wantrecurse=wantrecurse):
                yield sub
            if not dname or dname == root.name:
                yield root

    def __iter__(self):
        return self.iter()

    def exists(self):
        """Return the mapped buffer; false when no index file was loaded."""
        return self.m

    def save(self):
        """Flush the memory map if a file is mapped."""
        if self.writable and self.m:
            self.m.flush()

    def close(self):
        """Flush and drop the memory map.  Safe to call more than once."""
        self.save()
        if self.writable and self.m:
            self.m = None
            self.writable = False

    def filter(self, prefixes, wantrecurse=None):
        """Yield (name, entry) for the entries under each of `prefixes`.

        Prefixes are normalized by reduce_paths(); `name` is the entry's
        path with the resolved prefix replaced by the prefix as the caller
        spelled it.
        """
        for (rp, path) in reduce_paths(prefixes):
            for e in self.iter(rp, wantrecurse=wantrecurse):
                assert(e.name.startswith(rp))
                name = path + e.name[len(rp):]
                yield (name, e)
281
282
class Writer:
    """Builds a new index in a temporary file next to `filename`.

    Entries must be added in strictly descending order of their path
    component lists.  close() renames the finished temporary file over
    `filename`; abort() (also run on garbage collection) deletes it.
    """

    def __init__(self, filename):
        self.rootlevel = self.level = Level([], None)
        # Set before anything can fail so that __del__ -> abort() is safe
        # even if construction raises part-way through.
        self.f = None
        self.count = 0
        self.lastfile = None
        self.filename = None
        self.filename = filename = realpath(filename)
        dirname = os.path.dirname(filename)
        (ffd,self.tmpname) = tempfile.mkstemp('.tmp', filename, dirname)
        self.f = os.fdopen(ffd, 'wb', 65536)
        self.f.write(INDEX_HDR)

    def __del__(self):
        self.abort()

    def abort(self):
        """Close and delete the temporary file, if it is still open."""
        f = self.f
        self.f = None
        if f:
            f.close()
            os.unlink(self.tmpname)

    def flush(self):
        """Close all open levels and write the footer.

        Does nothing when the tree has already been flushed.  No entries
        can be added afterwards.
        """
        if self.level:
            self.level = _golevel(self.level, self.f, [], None)
            self.count = self.rootlevel.count
            if self.count:
                # one more for the entry written for the root level itself
                self.count += 1
            self.f.write(struct.pack(FOOTER_SIG, self.count))
            self.f.flush()
        assert(self.level is None)

    def close(self):
        """Flush, then move the temporary file into place as `filename`."""
        self.flush()
        f = self.f
        self.f = None
        if f:
            f.close()
            os.rename(self.tmpname, self.filename)

    def _add(self, ename, entry):
        """Append `entry`, whose path components are `ename`.

        Raises Error when `ename` does not sort strictly before the
        previously added path.
        """
        if self.lastfile and self.lastfile <= ename:
            raise Error('%r must come before %r'
                             % (''.join(ename), ''.join(self.lastfile)))
        # Remember this path so the next call can be checked against it.
        self.lastfile = ename
        self.level = _golevel(self.level, self.f, ename, entry)

    def add(self, name, st, hashgen = None):
        """Add the path `name`.

        st      -- stat result for the path, or a false value for a
                   directory recorded without stat data (name must then
                   end with '/')
        hashgen -- optional callable; hashgen(name) must return
                   (gitmode, sha), and the entry is then marked hash-valid
        """
        endswith = name.endswith('/')
        ename = pathsplit(name)
        basename = ename[-1]
        flags = IX_EXISTS
        if hashgen:
            (gitmode, sha) = hashgen(name)
            flags |= IX_HASHVALID
        else:
            (gitmode, sha) = (0, EMPTY_SHA)
        if st:
            # Directory names, and only directory names, end with '/'.
            isdir = stat.S_ISDIR(st.st_mode)
            assert(isdir == endswith)
            e = NewEntry(basename, name, st.st_dev, int(st.st_ctime),
                         int(st.st_mtime), st.st_uid, st.st_gid,
                         st.st_size, st.st_mode, gitmode, sha, flags,
                         0, 0)
        else:
            assert(endswith)
            e = BlankNewEntry(basename)
            e.gitmode = gitmode
            e.sha = sha
            e.flags = flags
        self._add(ename, e)

    def add_ixentry(self, e):
        """Add an already-built entry, resetting its child pointers first."""
        e.children_ofs = e.children_n = 0
        self._add(pathsplit(e.name), e)

    def new_reader(self):
        """Flush and return a Reader over the temporary file."""
        self.flush()
        return Reader(self.tmpname)
365
366
def reduce_paths(paths):
    """Normalize `paths` and drop those covered by another listed directory.

    Each path is resolved with realpath(); when the resolved path is an
    existing directory, both the resolved and the original spelling are
    passed through slashappend().  Paths that do not exist are kept as
    they are.

    Returns a list of (resolved_path, original_path) tuples sorted in
    descending order, with duplicates and anything located under an
    earlier directory removed.
    """
    # Local import: used only to fetch the active exception in a way that
    # parses on every Python version (the old `except X, e` form is a
    # syntax error on Python 3).
    import sys

    xpaths = []
    for p in paths:
        rp = realpath(p)
        try:
            st = os.lstat(rp)
            if stat.S_ISDIR(st.st_mode):
                rp = slashappend(rp)
                p = slashappend(p)
        except OSError:
            e = sys.exc_info()[1]
            if e.errno != errno.ENOENT:
                raise
        xpaths.append((rp, p))
    # Ascending order puts every directory directly before its contents.
    xpaths.sort()

    paths = []
    prev = None
    for (rp, p) in xpaths:
        if prev and (prev == rp
                     or (prev.endswith('/') and rp.startswith(prev))):
            continue # already superseded by previous path
        paths.append((rp, p))
        prev = rp
    paths.sort(reverse=True)
    return paths
392
393
class MergeIter:
    """Merges several entry sources into one stream, largest entry first.

    Consecutive entries with the same name are yielded only once.

    NOTE(review): relies on next() returning a false value when a source is
    exhausted -- presumably the helper from `helpers`, not the builtin;
    confirm.
    """

    def __init__(self, iters):
        self.iters = iters

    def __len__(self):
        # FIXME: doesn't remove duplicated entries between iters.
        # That only happens for parent directories, but will mean the
        # actual iteration returns fewer entries than this function counts.
        total = 0
        for it in self.iters:
            total += len(it)
        return total

    def __iter__(self):
        total = len(self)
        sources = [iter(src) for src in self.iters]
        # One (current entry, source) pair per source that still has data.
        pending = []
        for it in sources:
            head = next(it)
            if head:
                pending.append((head, it))
        seen = 0
        prevname = None
        while pending:
            if seen % 1024 == 0:
                progress('bup: merging indexes (%d/%d)\r' % (seen, total))
            # After sorting, the largest pair is at the end of the list.
            pending.sort()
            (e, it) = pending.pop()
            if not e:
                continue
            if e.name != prevname:
                yield e
                prevname = e.name
            following = next(it)
            if following:
                pending.append((following, it))
            seen += 1
        log('bup: merging indexes (%d/%d), done.\n' % (seen, total))
423                 l.append((n,it))
424             count += 1
425         log('bup: merging indexes (%d/%d), done.\n' % (count, total))