arthur.barton.de Git - bup.git/blob - index.py
test.sh: don't try non-quick fsck on damaged repositories.
[bup.git] / index.py
1 import os, stat, time, struct, tempfile
2 from helpers import *
3
# Placeholder sha (20 NUL bytes) stored when no real hash is known; used by
# BlankNewEntry and by Writer.add when no hashgen is supplied.
EMPTY_SHA = '\0'*20
# 20 bytes of 0x01; not referenced elsewhere in this file.
FAKE_SHA = '\x01'*20
# Magic bytes written at the start of every index file and checked by Reader.
INDEX_HDR = 'BUPI\0\0\0\2'
# struct layout (network byte order) of one fixed-size entry record:
# dev, ctime, mtime, uid, gid, size (Q), mode, gitmode, sha (20s),
# flags (H), children_ofs, children_n.  On disk each record immediately
# follows the entry's NUL-terminated basename.
INDEX_SIG = '!IIIIIQII20sHII'
ENTLEN = struct.calcsize(INDEX_SIG)
# Trailer: a single unsigned 64-bit entry count written after all entries.
FOOTER_SIG = '!Q'
FOOTLEN = struct.calcsize(FOOTER_SIG)

# Entry flag bits.
IX_EXISTS = 0x8000     # set by from_stat()/Writer.add, cleared by set_deleted()
IX_HASHVALID = 0x4000  # set via validate()/hashgen; cleared when stat data changes
14
class Error(Exception):
    """Raised for index-level failures, e.g. Writer entries added out of order."""
17
18
class Level:
    """One directory level of the tree that Writer builds bottom-up.

    ename is the list of path components naming this level and parent is
    the enclosing Level (None for the root).  list queues the entries of
    this level's direct children until write() flushes them, and count
    totals the entries already written by all levels nested below this one.
    """
    def __init__(self, ename, parent):
        self.ename = ename
        self.parent = parent
        self.count = 0
        self.list = []

    def write(self, f):
        """Write every queued child entry to f.

        Returns (ofs, n): the offset in f where the children begin and how
        many were written.  When anything was written and there is a
        parent, the parent's count grows by those entries plus everything
        already counted below this level.
        """
        start = f.tell()
        children = self.list
        total = len(children)
        if total:
            for child in children:
                child.write(f)
            if self.parent:
                self.parent.count += total + self.count
        return (start, total)
37
38
def _golevel(level, f, ename, newentry):
    """Move Writer's current tree position from `level` to path `ename`.

    Levels on the way up whose ename is not a prefix of `ename` are closed:
    their queued children are written to `f` and a BlankNewEntry recording
    the children's offset/count is queued in the parent level.  Missing
    levels on the way down are then created.  Finally the level for `ename`
    itself is closed the same way, using `newentry` (or a blank placeholder
    when `newentry` is falsy) to record its children.

    Returns the parent of the `ename` level, which is None for a parentless
    (root) level.
    """
    # close nodes back up the tree
    assert(level)
    while ename[:len(level.ename)] != level.ename:
        n = BlankNewEntry(level.ename[-1])
        (n.children_ofs,n.children_n) = level.write(f)
        level.parent.list.append(n)
        level = level.parent

    # create nodes down the tree
    while len(level.ename) < len(ename):
        level = Level(ename[:len(level.ename)+1], level)

    # are we in precisely the right place?
    assert(ename == level.ename)
    # When ename is [] (as in Writer.flush) the placeholder's basename is
    # None; such a level has no parent here, so the entry is not queued and
    # only its children are written.
    n = newentry or BlankNewEntry(ename and level.ename[-1] or None)
    (n.children_ofs,n.children_n) = level.write(f)
    if level.parent:
        level.parent.list.append(n)
    level = level.parent

    return level
61
62
class Entry:
    """Shared state and serialization for one index entry.

    This base only sets the names and a zero child pointer
    (children_ofs/children_n).  Subclasses supply dev, ctime, mtime, uid,
    gid, size, mode, gitmode, sha and flags.
    """
    def __init__(self, basename, name):
        self.basename = str(basename)
        self.name = str(name)
        self.children_ofs = self.children_n = 0

    def __repr__(self):
        fields = (self.name, self.dev,
                  self.ctime, self.mtime, self.uid, self.gid,
                  self.size, self.flags, self.children_ofs, self.children_n)
        return "(%s,0x%04x,%d,%d,%d,%d,%d,0x%04x,0x%08x/%d)" % fields

    def packed(self):
        """Return the fixed-size record for this entry, packed with INDEX_SIG."""
        fields = (self.dev, self.ctime, self.mtime,
                  self.uid, self.gid, self.size, self.mode,
                  self.gitmode, self.sha, self.flags,
                  self.children_ofs, self.children_n)
        return struct.pack(INDEX_SIG, *fields)

    def from_stat(self, st, tstart):
        """Refresh the stat-derived fields from st and mark the entry existing.

        The hash is invalidated (and set_dirty() called) when the file's
        ctime is not earlier than tstart, or when any tracked stat field or
        the existence flag changed.
        """
        ctime = int(st.st_ctime)
        mtime = int(st.st_mtime)
        previous = (self.dev, self.ctime, self.mtime,
                    self.uid, self.gid, self.size, self.flags & IX_EXISTS)
        current = (st.st_dev, ctime, mtime,
                   st.st_uid, st.st_gid, st.st_size, IX_EXISTS)
        (self.dev, self.ctime, self.mtime, self.uid, self.gid,
         self.size, self.mode) = (st.st_dev, ctime, mtime, st.st_uid,
                                  st.st_gid, st.st_size, st.st_mode)
        self.flags |= IX_EXISTS
        if ctime >= tstart or previous != current:
            self.flags &= ~IX_HASHVALID
            self.set_dirty()

    def validate(self, sha):
        """Store a freshly computed sha and flag the hash as valid."""
        assert(sha)
        self.sha = sha
        self.flags |= IX_HASHVALID

    def set_deleted(self):
        """Clear both the exists and hash-valid flags."""
        self.flags &= ~(IX_EXISTS | IX_HASHVALID)
        self.set_dirty()

    def set_dirty(self):
        pass # FIXME: not implemented yet

    def __cmp__(self, other):
        # Entries order by full path name.
        return cmp(self.name, other.name)

    def write(self, f):
        """Emit the NUL-terminated basename followed by the packed record."""
        record = self.basename + '\0' + self.packed()
        f.write(record)
117
118
class NewEntry(Entry):
    """An entry built from explicitly supplied field values.

    ctime and mtime are truncated to integers; every other value is stored
    exactly as given.
    """
    def __init__(self, basename, name, dev, ctime, mtime, uid, gid,
                 size, mode, gitmode, sha, flags, children_ofs, children_n):
        Entry.__init__(self, basename, name)
        ctime, mtime = int(ctime), int(mtime)
        self.dev = dev
        self.ctime = ctime
        self.mtime = mtime
        self.uid = uid
        self.gid = gid
        self.size = size
        self.mode = mode
        self.gitmode = gitmode
        self.sha = sha
        self.flags = flags
        self.children_ofs = children_ofs
        self.children_n = children_n
128
129
class BlankNewEntry(NewEntry):
    """Placeholder entry: all numeric fields zero and sha set to EMPTY_SHA.

    The basename doubles as the entry's full name.
    """
    def __init__(self, basename):
        NewEntry.__init__(self, basename, basename,
                          dev=0, ctime=0, mtime=0, uid=0, gid=0,
                          size=0, mode=0, gitmode=0, sha=EMPTY_SHA,
                          flags=0, children_ofs=0, children_n=0)
135
136
class ExistingEntry(Entry):
    """An index entry decoded in place from a mapped index file.

    m is the index contents and ofs the offset of this entry's fixed-size
    record, which sits right after its NUL-terminated basename.  repack()
    writes the current field values back over that same record.
    """
    def __init__(self, basename, name, m, ofs):
        Entry.__init__(self, basename, name)
        self._m = m
        self._ofs = ofs
        record = str(buffer(m, ofs, ENTLEN))
        (self.dev, self.ctime, self.mtime, self.uid, self.gid,
         self.size, self.mode, self.gitmode, self.sha,
         self.flags, self.children_ofs, self.children_n
         ) = struct.unpack(INDEX_SIG, record)

    def repack(self):
        """Overwrite this entry's record in the mapping with packed()."""
        start = self._ofs
        self._m[start:start + ENTLEN] = self.packed()

    def iter(self, name=None):
        """Yield descendants depth-first, each child after its own subtree.

        With name given, only entries equal to name or below it (name taken
        as a directory) are yielded, and subtrees that cannot contain such
        entries are not descended into.
        """
        prefix = name
        if prefix and not prefix.endswith('/'):
            prefix += '/'
        pos = self.children_ofs
        assert(pos <= len(self._m))
        # sanity bound against a corrupt child count
        assert(self.children_n < 1000000)
        for _ in xrange(self.children_n):
            nul = self._m.find('\0', pos)
            # The terminator must exist and the basename must be non-empty.
            assert(nul >= 0)
            assert(nul > pos)
            basename = str(buffer(self._m, pos, nul - pos))
            child = ExistingEntry(basename, self.name + basename,
                                  self._m, nul + 1)
            descend = (not prefix
                       or child.name.startswith(prefix)
                       or (child.name.endswith('/')
                           and prefix.startswith(child.name)))
            if descend:
                for sub in child.iter(name=name):
                    yield sub
            if not name or child.name == name or child.name.startswith(prefix):
                yield child
            pos = nul + 1 + ENTLEN

    def __iter__(self):
        return self.iter()
176             
177
178 class Reader:
179     def __init__(self, filename):
180         self.filename = filename
181         self.m = ''
182         self.writable = False
183         self.count = 0
184         f = None
185         try:
186             f = open(filename, 'r+')
187         except IOError, e:
188             if e.errno == errno.ENOENT:
189                 pass
190             else:
191                 raise
192         if f:
193             b = f.read(len(INDEX_HDR))
194             if b != INDEX_HDR:
195                 log('warning: %s: header: expected %r, got %r'
196                                  % (filename, INDEX_HDR, b))
197             else:
198                 st = os.fstat(f.fileno())
199                 if st.st_size:
200                     self.m = mmap_readwrite(f)
201                     self.writable = True
202                     self.count = struct.unpack(FOOTER_SIG,
203                           str(buffer(self.m, st.st_size-FOOTLEN, FOOTLEN)))[0]
204
205     def __del__(self):
206         self.close()
207
208     def __len__(self):
209         return int(self.count)
210
211     def forward_iter(self):
212         ofs = len(INDEX_HDR)
213         while ofs+ENTLEN <= len(self.m)-FOOTLEN:
214             eon = self.m.find('\0', ofs)
215             assert(eon >= 0)
216             assert(eon >= ofs)
217             assert(eon > ofs)
218             basename = str(buffer(self.m, ofs, eon-ofs))
219             yield ExistingEntry(basename, basename, self.m, eon+1)
220             ofs = eon + 1 + ENTLEN
221
222     def iter(self, name=None):
223         if len(self.m) > len(INDEX_HDR)+ENTLEN:
224             dname = name
225             if dname and not dname.endswith('/'):
226                 dname += '/'
227             root = ExistingEntry('/', '/', self.m, len(self.m)-FOOTLEN-ENTLEN)
228             for sub in root.iter(name=name):
229                 yield sub
230             if not dname or dname == root.name:
231                 yield root
232
233     def __iter__(self):
234         return self.iter()
235
236     def exists(self):
237         return self.m
238
239     def save(self):
240         if self.writable and self.m:
241             self.m.flush()
242
243     def close(self):
244         self.save()
245         if self.writable and self.m:
246             self.m = None
247             self.writable = False
248
249     def filter(self, prefixes):
250         for (rp, path) in reduce_paths(prefixes):
251             for e in self.iter(rp):
252                 assert(e.name.startswith(rp))
253                 name = path + e.name[len(rp):]
254                 yield (name, e)
255
256
class Writer:
    """Write a new index file bottom-up, then rename it over the old one.

    Entries go to a temporary file in the same directory as filename.
    close() renames it over filename; abort() (also run on deletion)
    removes it.  Entries must be added in strictly descending path order,
    so children come before their parent directory.
    """
    def __init__(self, filename):
        self.rootlevel = self.level = Level([], None)
        self.f = None
        self.count = 0
        self.lastfile = None
        self.filename = filename = realpath(filename)
        (dirname, basename) = os.path.split(filename)
        # Pass only the basename as prefix: mkstemp joins prefix onto dir,
        # so passing the full path only worked when it was absolute.
        (ffd, self.tmpname) = tempfile.mkstemp('.tmp', basename, dirname)
        self.f = os.fdopen(ffd, 'wb', 65536)
        self.f.write(INDEX_HDR)

    def __del__(self):
        self.abort()

    def abort(self):
        """Close and delete the temporary file; no-op once closed/aborted."""
        f = self.f
        self.f = None
        if f:
            f.close()
            os.unlink(self.tmpname)

    def flush(self):
        """Close all open levels and write the footer entry count.

        Afterwards self.level is None, so no further entries can be added.
        """
        if self.level:
            self.level = _golevel(self.level, self.f, [], None)
            self.count = self.rootlevel.count
            if self.count:
                # presumably counts the top-level entry itself, which no
                # Level.write adds to a parent's count
                self.count += 1
            self.f.write(struct.pack(FOOTER_SIG, self.count))
            self.f.flush()
        assert(self.level is None)

    def close(self):
        """Flush, then rename the temporary file over self.filename."""
        self.flush()
        f = self.f
        self.f = None
        if f:
            f.close()
            os.rename(self.tmpname, self.filename)

    def _add(self, ename, entry):
        """Queue entry at path components ename, enforcing descending order.

        Raises Error if ename is not strictly less than the previously added
        path.
        """
        if self.lastfile and self.lastfile <= ename:
            # The old code formatted an undefined `e` here (NameError), and
            # its lastfile update sat after the raise, so it never ran.
            raise Error('%r must come before %r'
                        % (''.join(ename), ''.join(self.lastfile)))
        self.lastfile = ename
        self.level = _golevel(self.level, self.f, ename, entry)

    def add(self, name, st, hashgen = None):
        """Add path name (directories end in '/') described by stat result st.

        If hashgen is given it is called as hashgen(name) -> (gitmode, sha)
        and the entry's hash is marked valid.  st may be None only for
        directory names, which then get a blank entry.
        """
        endswith = name.endswith('/')
        ename = pathsplit(name)
        basename = ename[-1]
        flags = IX_EXISTS
        if hashgen:
            (gitmode, sha) = hashgen(name)
            flags |= IX_HASHVALID
        else:
            (gitmode, sha) = (0, EMPTY_SHA)
        if st:
            isdir = stat.S_ISDIR(st.st_mode)
            assert(isdir == endswith)
            e = NewEntry(basename, name, st.st_dev, int(st.st_ctime),
                         int(st.st_mtime), st.st_uid, st.st_gid,
                         st.st_size, st.st_mode, gitmode, sha, flags,
                         0, 0)
        else:
            assert(endswith)
            e = BlankNewEntry(basename)
            e.gitmode = gitmode
            e.sha = sha
            e.flags = flags
        self._add(ename, e)

    def add_ixentry(self, e):
        """Re-add an existing entry; its child pointer is reset first."""
        e.children_ofs = e.children_n = 0
        self._add(pathsplit(e.name), e)

    def new_reader(self):
        """Flush and return a Reader over the temporary file."""
        self.flush()
        return Reader(self.tmpname)
339
340
341 def reduce_paths(paths):
342     xpaths = []
343     for p in paths:
344         rp = realpath(p)
345         try:
346             st = os.lstat(rp)
347             if stat.S_ISDIR(st.st_mode):
348                 rp = slashappend(rp)
349                 p = slashappend(p)
350         except OSError, e:
351             if e.errno != errno.ENOENT:
352                 raise
353         xpaths.append((rp, p))
354     xpaths.sort()
355
356     paths = []
357     prev = None
358     for (rp, p) in xpaths:
359         if prev and (prev == rp 
360                      or (prev.endswith('/') and rp.startswith(prev))):
361             continue # already superceded by previous path
362         paths.append((rp, p))
363         prev = rp
364     paths.sort(reverse=True)
365     return paths
366