]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
19035d1d1c3ed74f5b6b78fb0a404e17ca8ed158
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import hashlib, heapq, operator, time, grp
5 from bup import _version, _helpers
6 import bup._helpers as _helpers
7 import math
8
9 # This function should really be in helpers, not in bup.options.  But we
10 # want options.py to be standalone so people can include it in other projects.
11 from bup.options import _tty_width
12 tty_width = _tty_width
13
14
15 def atoi(s):
16     """Convert the string 's' to an integer. Return 0 if s is not a number."""
17     try:
18         return int(s or '0')
19     except ValueError:
20         return 0
21
22
23 def atof(s):
24     """Convert the string 's' to a float. Return 0 if s is not a number."""
25     try:
26         return float(s or '0')
27     except ValueError:
28         return 0
29
30
31 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
32
33
34 # If the platform doesn't have fdatasync (OS X), fall back to fsync.
35 try:
36     fdatasync = os.fdatasync
37 except AttributeError:
38     fdatasync = os.fsync
39
40
41 # Write (blockingly) to sockets that may or may not be in blocking mode.
42 # We need this because our stderr is sometimes eaten by subprocesses
43 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
44 # leading to race conditions.  Ick.  We'll do it the hard way.
45 def _hard_write(fd, buf):
46     while buf:
47         (r,w,x) = select.select([], [fd], [], None)
48         if not w:
49             raise IOError('select(fd) returned without being writable')
50         try:
51             sz = os.write(fd, buf)
52         except OSError, e:
53             if e.errno != errno.EAGAIN:
54                 raise
55         assert(sz >= 0)
56         buf = buf[sz:]
57
58
59 _last_prog = 0
60 def log(s):
61     """Print a log message to stderr."""
62     global _last_prog
63     sys.stdout.flush()
64     _hard_write(sys.stderr.fileno(), s)
65     _last_prog = 0
66
67
68 def debug1(s):
69     if buglvl >= 1:
70         log(s)
71
72
73 def debug2(s):
74     if buglvl >= 2:
75         log(s)
76
77
78 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
79 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
80 _last_progress = ''
81 def progress(s):
82     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
83     global _last_progress
84     if istty2:
85         log(s)
86         _last_progress = s
87
88
89 def qprogress(s):
90     """Calls progress() only if we haven't printed progress in a while.
91     
92     This avoids overloading the stderr buffer with excess junk.
93     """
94     global _last_prog
95     now = time.time()
96     if now - _last_prog > 0.1:
97         progress(s)
98         _last_prog = now
99
100
101 def reprogress():
102     """Calls progress() to redisplay the most recent progress message.
103
104     Useful after you've printed some other message that wipes out the
105     progress line.
106     """
107     if _last_progress and _last_progress.endswith('\r'):
108         progress(_last_progress)
109
110
111 def mkdirp(d, mode=None):
112     """Recursively create directories on path 'd'.
113
114     Unlike os.makedirs(), it doesn't raise an exception if the last element of
115     the path already exists.
116     """
117     try:
118         if mode:
119             os.makedirs(d, mode)
120         else:
121             os.makedirs(d)
122     except OSError, e:
123         if e.errno == errno.EEXIST:
124             pass
125         else:
126             raise
127
128
129 def next(it):
130     """Get the next item from an iterator, None if we reached the end."""
131     try:
132         return it.next()
133     except StopIteration:
134         return None
135
136
137 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
138     if key:
139         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
140     else:
141         samekey = operator.eq
142     count = 0
143     total = sum(len(it) for it in iters)
144     iters = (iter(it) for it in iters)
145     heap = ((next(it),it) for it in iters)
146     heap = [(e,it) for e,it in heap if e]
147
148     heapq.heapify(heap)
149     pe = None
150     while heap:
151         if not count % pfreq:
152             pfunc(count, total)
153         e, it = heap[0]
154         if not samekey(e, pe):
155             pe = e
156             yield e
157         count += 1
158         try:
159             e = it.next() # Don't use next() function, it's too expensive
160         except StopIteration:
161             heapq.heappop(heap) # remove current
162         else:
163             heapq.heapreplace(heap, (e, it)) # shift current to new location
164     pfinal(count, total)
165
166
167 def unlink(f):
168     """Delete a file at path 'f' if it currently exists.
169
170     Unlike os.unlink(), does not throw an exception if the file didn't already
171     exist.
172     """
173     try:
174         os.unlink(f)
175     except OSError, e:
176         if e.errno == errno.ENOENT:
177             pass  # it doesn't exist, that's what you asked for
178
179
180 def readpipe(argv):
181     """Run a subprocess and return its output."""
182     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
183     r = p.stdout.read()
184     p.wait()
185     return r
186
187
188 def realpath(p):
189     """Get the absolute path of a file.
190
191     Behaves like os.path.realpath, but doesn't follow a symlink for the last
192     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
193     will follow symlinks in p's directory)
194     """
195     try:
196         st = os.lstat(p)
197     except OSError:
198         st = None
199     if st and stat.S_ISLNK(st.st_mode):
200         (dir, name) = os.path.split(p)
201         dir = os.path.realpath(dir)
202         out = os.path.join(dir, name)
203     else:
204         out = os.path.realpath(p)
205     #log('realpathing:%r,%r\n' % (p, out))
206     return out
207
208
209 def detect_fakeroot():
210     "Return True if we appear to be running under fakeroot."
211     return os.getenv("FAKEROOTKEY") != None
212
213
214 def is_superuser():
215     if sys.platform.startswith('cygwin'):
216         import ctypes
217         return ctypes.cdll.shell32.IsUserAnAdmin()
218     else:
219         return os.geteuid() == 0
220
221
222 def _cache_key_value(get_value, key, cache):
223     """Return (value, was_cached).  If there is a value in the cache
224     for key, use that, otherwise, call get_value(key) which should
225     throw a KeyError if there is no value -- in which case the cached
226     and returned value will be None.
227     """
228     try: # Do we already have it (or know there wasn't one)?
229         value = cache[key]
230         return value, True
231     except KeyError:
232         pass
233     value = None
234     try:
235         cache[key] = value = get_value(key)
236     except KeyError:
237         cache[key] = None
238     return value, False
239
240
241 _uid_to_pwd_cache = {}
242 _name_to_pwd_cache = {}
243
244 def pwd_from_uid(uid):
245     """Return password database entry for uid (may be a cached value).
246     Return None if no entry is found.
247     """
248     global _uid_to_pwd_cache, _name_to_pwd_cache
249     entry, cached = _cache_key_value(pwd.getpwuid, uid, _uid_to_pwd_cache)
250     if entry and not cached:
251         _name_to_pwd_cache[entry.pw_name] = entry
252     return entry
253
254
255 def pwd_from_name(name):
256     """Return password database entry for name (may be a cached value).
257     Return None if no entry is found.
258     """
259     global _uid_to_pwd_cache, _name_to_pwd_cache
260     entry, cached = _cache_key_value(pwd.getpwnam, name, _name_to_pwd_cache)
261     if entry and not cached:
262         _uid_to_pwd_cache[entry.pw_uid] = entry
263     return entry
264
265
266 _gid_to_grp_cache = {}
267 _name_to_grp_cache = {}
268
269 def grp_from_gid(gid):
270     """Return password database entry for gid (may be a cached value).
271     Return None if no entry is found.
272     """
273     global _gid_to_grp_cache, _name_to_grp_cache
274     entry, cached = _cache_key_value(grp.getgrgid, gid, _gid_to_grp_cache)
275     if entry and not cached:
276         _name_to_grp_cache[entry.gr_name] = entry
277     return entry
278
279
280 def grp_from_name(name):
281     """Return password database entry for name (may be a cached value).
282     Return None if no entry is found.
283     """
284     global _gid_to_grp_cache, _name_to_grp_cache
285     entry, cached = _cache_key_value(grp.getgrnam, name, _name_to_grp_cache)
286     if entry and not cached:
287         _gid_to_grp_cache[entry.gr_gid] = entry
288     return entry
289
290
291 _username = None
292 def username():
293     """Get the user's login name."""
294     global _username
295     if not _username:
296         uid = os.getuid()
297         _username = pwd_from_uid(uid)[0] or 'user%d' % uid
298     return _username
299
300
301 _userfullname = None
302 def userfullname():
303     """Get the user's full name."""
304     global _userfullname
305     if not _userfullname:
306         uid = os.getuid()
307         entry = pwd_from_uid(uid)
308         if entry:
309             _userfullname = entry[4].split(',')[0] or entry[0]
310         if not _userfullname:
311             _userfullname = 'user%d' % uid
312     return _userfullname
313
314
315 _hostname = None
316 def hostname():
317     """Get the FQDN of this machine."""
318     global _hostname
319     if not _hostname:
320         _hostname = socket.getfqdn()
321     return _hostname
322
323
324 _resource_path = None
325 def resource_path(subdir=''):
326     global _resource_path
327     if not _resource_path:
328         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
329     return os.path.join(_resource_path, subdir)
330
331 def format_filesize(size):
332     unit = 1024.0
333     size = float(size)
334     if size < unit:
335         return "%d" % (size)
336     exponent = int(math.log(size) / math.log(unit))
337     size_prefix = "KMGTPE"[exponent - 1]
338     return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
339
340
341 class NotOk(Exception):
342     pass
343
344
345 class BaseConn:
346     def __init__(self, outp):
347         self.outp = outp
348
349     def close(self):
350         while self._read(65536): pass
351
352     def read(self, size):
353         """Read 'size' bytes from input stream."""
354         self.outp.flush()
355         return self._read(size)
356
357     def readline(self):
358         """Read from input stream until a newline is found."""
359         self.outp.flush()
360         return self._readline()
361
362     def write(self, data):
363         """Write 'data' to output stream."""
364         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
365         self.outp.write(data)
366
367     def has_input(self):
368         """Return true if input stream is readable."""
369         raise NotImplemented("Subclasses must implement has_input")
370
371     def ok(self):
372         """Indicate end of output from last sent command."""
373         self.write('\nok\n')
374
375     def error(self, s):
376         """Indicate server error to the client."""
377         s = re.sub(r'\s+', ' ', str(s))
378         self.write('\nerror %s\n' % s)
379
380     def _check_ok(self, onempty):
381         self.outp.flush()
382         rl = ''
383         for rl in linereader(self):
384             #log('%d got line: %r\n' % (os.getpid(), rl))
385             if not rl:  # empty line
386                 continue
387             elif rl == 'ok':
388                 return None
389             elif rl.startswith('error '):
390                 #log('client: error: %s\n' % rl[6:])
391                 return NotOk(rl[6:])
392             else:
393                 onempty(rl)
394         raise Exception('server exited unexpectedly; see errors above')
395
396     def drain_and_check_ok(self):
397         """Remove all data for the current command from input stream."""
398         def onempty(rl):
399             pass
400         return self._check_ok(onempty)
401
402     def check_ok(self):
403         """Verify that server action completed successfully."""
404         def onempty(rl):
405             raise Exception('expected "ok", got %r' % rl)
406         return self._check_ok(onempty)
407
408
409 class Conn(BaseConn):
410     def __init__(self, inp, outp):
411         BaseConn.__init__(self, outp)
412         self.inp = inp
413
414     def _read(self, size):
415         return self.inp.read(size)
416
417     def _readline(self):
418         return self.inp.readline()
419
420     def has_input(self):
421         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
422         if rl:
423             assert(rl[0] == self.inp.fileno())
424             return True
425         else:
426             return None
427
428
429 def checked_reader(fd, n):
430     while n > 0:
431         rl, _, _ = select.select([fd], [], [])
432         assert(rl[0] == fd)
433         buf = os.read(fd, n)
434         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
435         yield buf
436         n -= len(buf)
437
438
439 MAX_PACKET = 128 * 1024
440 def mux(p, outfd, outr, errr):
441     try:
442         fds = [outr, errr]
443         while p.poll() is None:
444             rl, _, _ = select.select(fds, [], [])
445             for fd in rl:
446                 if fd == outr:
447                     buf = os.read(outr, MAX_PACKET)
448                     if not buf: break
449                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
450                 elif fd == errr:
451                     buf = os.read(errr, 1024)
452                     if not buf: break
453                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
454     finally:
455         os.write(outfd, struct.pack('!IB', 0, 3))
456
457
458 class DemuxConn(BaseConn):
459     """A helper class for bup's client-server protocol."""
460     def __init__(self, infd, outp):
461         BaseConn.__init__(self, outp)
462         # Anything that comes through before the sync string was not
463         # multiplexed and can be assumed to be debug/log before mux init.
464         tail = ''
465         while tail != 'BUPMUX':
466             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
467             if not b:
468                 raise IOError('demux: unexpected EOF during initialization')
469             tail += b
470             sys.stderr.write(tail[:-6])  # pre-mux log messages
471             tail = tail[-6:]
472         self.infd = infd
473         self.reader = None
474         self.buf = None
475         self.closed = False
476
477     def write(self, data):
478         self._load_buf(0)
479         BaseConn.write(self, data)
480
481     def _next_packet(self, timeout):
482         if self.closed: return False
483         rl, wl, xl = select.select([self.infd], [], [], timeout)
484         if not rl: return False
485         assert(rl[0] == self.infd)
486         ns = ''.join(checked_reader(self.infd, 5))
487         n, fdw = struct.unpack('!IB', ns)
488         assert(n <= MAX_PACKET)
489         if fdw == 1:
490             self.reader = checked_reader(self.infd, n)
491         elif fdw == 2:
492             for buf in checked_reader(self.infd, n):
493                 sys.stderr.write(buf)
494         elif fdw == 3:
495             self.closed = True
496             debug2("DemuxConn: marked closed\n")
497         return True
498
499     def _load_buf(self, timeout):
500         if self.buf is not None:
501             return True
502         while not self.closed:
503             while not self.reader:
504                 if not self._next_packet(timeout):
505                     return False
506             try:
507                 self.buf = self.reader.next()
508                 return True
509             except StopIteration:
510                 self.reader = None
511         return False
512
513     def _read_parts(self, ix_fn):
514         while self._load_buf(None):
515             assert(self.buf is not None)
516             i = ix_fn(self.buf)
517             if i is None or i == len(self.buf):
518                 yv = self.buf
519                 self.buf = None
520             else:
521                 yv = self.buf[:i]
522                 self.buf = self.buf[i:]
523             yield yv
524             if i is not None:
525                 break
526
527     def _readline(self):
528         def find_eol(buf):
529             try:
530                 return buf.index('\n')+1
531             except ValueError:
532                 return None
533         return ''.join(self._read_parts(find_eol))
534
535     def _read(self, size):
536         csize = [size]
537         def until_size(buf): # Closes on csize
538             if len(buf) < csize[0]:
539                 csize[0] -= len(buf)
540                 return None
541             else:
542                 return csize[0]
543         return ''.join(self._read_parts(until_size))
544
545     def has_input(self):
546         return self._load_buf(0)
547
548
549 def linereader(f):
550     """Generate a list of input lines from 'f' without terminating newlines."""
551     while 1:
552         line = f.readline()
553         if not line:
554             break
555         yield line[:-1]
556
557
558 def chunkyreader(f, count = None):
559     """Generate a list of chunks of data read from 'f'.
560
561     If count is None, read until EOF is reached.
562
563     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
564     reached while reading, raise IOError.
565     """
566     if count != None:
567         while count > 0:
568             b = f.read(min(count, 65536))
569             if not b:
570                 raise IOError('EOF with %d bytes remaining' % count)
571             yield b
572             count -= len(b)
573     else:
574         while 1:
575             b = f.read(65536)
576             if not b: break
577             yield b
578
579
580 def slashappend(s):
581     """Append "/" to 's' if it doesn't aleady end in "/"."""
582     if s and not s.endswith('/'):
583         return s + '/'
584     else:
585         return s
586
587
588 def _mmap_do(f, sz, flags, prot, close):
589     if not sz:
590         st = os.fstat(f.fileno())
591         sz = st.st_size
592     if not sz:
593         # trying to open a zero-length map gives an error, but an empty
594         # string has all the same behaviour of a zero-length map, ie. it has
595         # no elements :)
596         return ''
597     map = mmap.mmap(f.fileno(), sz, flags, prot)
598     if close:
599         f.close()  # map will persist beyond file close
600     return map
601
602
603 def mmap_read(f, sz = 0, close=True):
604     """Create a read-only memory mapped region on file 'f'.
605     If sz is 0, the region will cover the entire file.
606     """
607     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
608
609
610 def mmap_readwrite(f, sz = 0, close=True):
611     """Create a read-write memory mapped region on file 'f'.
612     If sz is 0, the region will cover the entire file.
613     """
614     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
615                     close)
616
617
618 def mmap_readwrite_private(f, sz = 0, close=True):
619     """Create a read-write memory mapped region on file 'f'.
620     If sz is 0, the region will cover the entire file.
621     The map is private, which means the changes are never flushed back to the
622     file.
623     """
624     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
625                     close)
626
627
628 def parse_num(s):
629     """Parse data size information into a float number.
630
631     Here are some examples of conversions:
632         199.2k means 203981 bytes
633         1GB means 1073741824 bytes
634         2.1 tb means 2199023255552 bytes
635     """
636     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
637     if not g:
638         raise ValueError("can't parse %r as a number" % s)
639     (val, unit) = g.groups()
640     num = float(val)
641     unit = unit.lower()
642     if unit in ['t', 'tb']:
643         mult = 1024*1024*1024*1024
644     elif unit in ['g', 'gb']:
645         mult = 1024*1024*1024
646     elif unit in ['m', 'mb']:
647         mult = 1024*1024
648     elif unit in ['k', 'kb']:
649         mult = 1024
650     elif unit in ['', 'b']:
651         mult = 1
652     else:
653         raise ValueError("invalid unit %r in number %r" % (unit, s))
654     return int(num*mult)
655
656
657 def count(l):
658     """Count the number of elements in an iterator. (consumes the iterator)"""
659     return reduce(lambda x,y: x+1, l)
660
661
662 saved_errors = []
663 def add_error(e):
664     """Append an error message to the list of saved errors.
665
666     Once processing is able to stop and output the errors, the saved errors are
667     accessible in the module variable helpers.saved_errors.
668     """
669     saved_errors.append(e)
670     log('%-70s\n' % e)
671
672
673 def clear_errors():
674     global saved_errors
675     saved_errors = []
676
677
678 def handle_ctrl_c():
679     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
680
681     The new exception handler will make sure that bup will exit without an ugly
682     stacktrace when Ctrl-C is hit.
683     """
684     oldhook = sys.excepthook
685     def newhook(exctype, value, traceback):
686         if exctype == KeyboardInterrupt:
687             log('\nInterrupted.\n')
688         else:
689             return oldhook(exctype, value, traceback)
690     sys.excepthook = newhook
691
692
693 def columnate(l, prefix):
694     """Format elements of 'l' in columns with 'prefix' leading each line.
695
696     The number of columns is determined automatically based on the string
697     lengths.
698     """
699     if not l:
700         return ""
701     l = l[:]
702     clen = max(len(s) for s in l)
703     ncols = (tty_width() - len(prefix)) / (clen + 2)
704     if ncols <= 1:
705         ncols = 1
706         clen = 0
707     cols = []
708     while len(l) % ncols:
709         l.append('')
710     rows = len(l)/ncols
711     for s in range(0, len(l), rows):
712         cols.append(l[s:s+rows])
713     out = ''
714     for row in zip(*cols):
715         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
716     return out
717
718
719 def parse_date_or_fatal(str, fatal):
720     """Parses the given date or calls Option.fatal().
721     For now we expect a string that contains a float."""
722     try:
723         date = atof(str)
724     except ValueError, e:
725         raise fatal('invalid date format (should be a float): %r' % e)
726     else:
727         return date
728
729
730 def parse_excludes(options, fatal):
731     """Traverse the options and extract all excludes, or call Option.fatal()."""
732     excluded_paths = []
733
734     for flag in options:
735         (option, parameter) = flag
736         if option == '--exclude':
737             excluded_paths.append(realpath(parameter))
738         elif option == '--exclude-from':
739             try:
740                 f = open(realpath(parameter))
741             except IOError, e:
742                 raise fatal("couldn't read %s" % parameter)
743             for exclude_path in f.readlines():
744                 excluded_paths.append(realpath(exclude_path.strip()))
745     return excluded_paths
746
747
748 def parse_rx_excludes(options, fatal):
749     """Traverse the options and extract all rx excludes, or call
750     Option.fatal()."""
751     rxs = [v for f, v in options if f == '--exclude-rx']
752     for i in range(len(rxs)):
753         try:
754             rxs[i] = re.compile(rxs[i])
755         except re.error, ex:
756             o.fatal('invalid --exclude-rx pattern (%s):' % (ex, rxs[i]))
757     return rxs
758
759
760 def should_rx_exclude_path(path, exclude_rxs):
761     """Return True if path matches a regular expression in exclude_rxs."""
762     for rx in exclude_rxs:
763         if rx.search(path):
764             debug1('Skipping %r: excluded by rx pattern %r.\n'
765                    % (path, rx.pattern))
766             return True
767     return False
768
769
770 # FIXME: Carefully consider the use of functions (os.path.*, etc.)
771 # that resolve against the current filesystem in the strip/graft
772 # functions for example, but elsewhere as well.  I suspect bup's not
773 # always being careful about that.  For some cases, the contents of
774 # the current filesystem should be irrelevant, and consulting it might
775 # produce the wrong result, perhaps via unintended symlink resolution,
776 # for example.
777
778 def path_components(path):
779     """Break path into a list of pairs of the form (name,
780     full_path_to_name).  Path must start with '/'.
781     Example:
782       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
783     if not path.startswith('/'):
784         raise Exception, 'path must start with "/": %s' % path
785     # Since we assume path startswith('/'), we can skip the first element.
786     result = [('', '/')]
787     norm_path = os.path.abspath(path)
788     if norm_path == '/':
789         return result
790     full_path = ''
791     for p in norm_path.split('/')[1:]:
792         full_path += '/' + p
793         result.append((p, full_path))
794     return result
795
796
797 def stripped_path_components(path, strip_prefixes):
798     """Strip any prefix in strip_prefixes from path and return a list
799     of path components where each component is (name,
800     none_or_full_fs_path_to_name).  Assume path startswith('/').
801     See thelpers.py for examples."""
802     normalized_path = os.path.abspath(path)
803     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
804     for bp in sorted_strip_prefixes:
805         normalized_bp = os.path.abspath(bp)
806         if normalized_path.startswith(normalized_bp):
807             prefix = normalized_path[:len(normalized_bp)]
808             result = []
809             for p in normalized_path[len(normalized_bp):].split('/'):
810                 if p: # not root
811                     prefix += '/'
812                 prefix += p
813                 result.append((p, prefix))
814             return result
815     # Nothing to strip.
816     return path_components(path)
817
818
819 def grafted_path_components(graft_points, path):
820     # Create a result that consists of some number of faked graft
821     # directories before the graft point, followed by all of the real
822     # directories from path that are after the graft point.  Arrange
823     # for the directory at the graft point in the result to correspond
824     # to the "orig" directory in --graft orig=new.  See t/thelpers.py
825     # for some examples.
826
827     # Note that given --graft orig=new, orig and new have *nothing* to
828     # do with each other, even if some of their component names
829     # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
830     # equivalent to --graft /foo/bar/baz=/x/y/z, or even
831     # /foo/bar/baz=/x.
832
833     # FIXME: This can't be the best solution...
834     clean_path = os.path.abspath(path)
835     for graft_point in graft_points:
836         old_prefix, new_prefix = graft_point
837         # Expand prefixes iff not absolute paths.
838         old_prefix = os.path.normpath(old_prefix)
839         new_prefix = os.path.normpath(new_prefix)
840         if clean_path.startswith(old_prefix):
841             escaped_prefix = re.escape(old_prefix)
842             grafted_path = re.sub(r'^' + escaped_prefix, new_prefix, clean_path)
843             # Handle /foo=/ (at least) -- which produces //whatever.
844             grafted_path = '/' + grafted_path.lstrip('/')
845             clean_path_components = path_components(clean_path)
846             # Count the components that were stripped.
847             strip_count = 0 if old_prefix == '/' else old_prefix.count('/')
848             new_prefix_parts = new_prefix.split('/')
849             result_prefix = grafted_path.split('/')[:new_prefix.count('/')]
850             result = [(p, None) for p in result_prefix] \
851                 + clean_path_components[strip_count:]
852             # Now set the graft point name to match the end of new_prefix.
853             graft_point = len(result_prefix)
854             result[graft_point] = \
855                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
856             if new_prefix == '/': # --graft ...=/ is a special case.
857                 return result[1:]
858             return result
859     return path_components(clean_path)
860
861 Sha1 = hashlib.sha1
862
863 def version_date():
864     """Format bup's version date string for output."""
865     return _version.DATE.split(' ')[0]
866
867
868 def version_commit():
869     """Get the commit hash of bup's current version."""
870     return _version.COMMIT
871
872
873 def version_tag():
874     """Format bup's version tag (the official version number).
875
876     When generated from a commit other than one pointed to with a tag, the
877     returned string will be "unknown-" followed by the first seven positions of
878     the commit hash.
879     """
880     names = _version.NAMES.strip()
881     assert(names[0] == '(')
882     assert(names[-1] == ')')
883     names = names[1:-1]
884     l = [n.strip() for n in names.split(',')]
885     for n in l:
886         if n.startswith('tag: bup-'):
887             return n[9:]
888     return 'unknown-%s' % _version.COMMIT[:7]