]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Only define helpers.next() if Python's isn't new enough.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import hashlib, heapq, operator, time, grp
5 from bup import _version, _helpers
6 import bup._helpers as _helpers
7 import math
8
9 # This function should really be in helpers, not in bup.options.  But we
10 # want options.py to be standalone so people can include it in other projects.
11 from bup.options import _tty_width
12 tty_width = _tty_width
13
14
15 def atoi(s):
16     """Convert the string 's' to an integer. Return 0 if s is not a number."""
17     try:
18         return int(s or '0')
19     except ValueError:
20         return 0
21
22
23 def atof(s):
24     """Convert the string 's' to a float. Return 0 if s is not a number."""
25     try:
26         return float(s or '0')
27     except ValueError:
28         return 0
29
30
31 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
32
33
34 # If the platform doesn't have fdatasync (OS X), fall back to fsync.
35 try:
36     fdatasync = os.fdatasync
37 except AttributeError:
38     fdatasync = os.fsync
39
40
41 # Write (blockingly) to sockets that may or may not be in blocking mode.
42 # We need this because our stderr is sometimes eaten by subprocesses
43 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
44 # leading to race conditions.  Ick.  We'll do it the hard way.
45 def _hard_write(fd, buf):
46     while buf:
47         (r,w,x) = select.select([], [fd], [], None)
48         if not w:
49             raise IOError('select(fd) returned without being writable')
50         try:
51             sz = os.write(fd, buf)
52         except OSError, e:
53             if e.errno != errno.EAGAIN:
54                 raise
55         assert(sz >= 0)
56         buf = buf[sz:]
57
58
59 _last_prog = 0
60 def log(s):
61     """Print a log message to stderr."""
62     global _last_prog
63     sys.stdout.flush()
64     _hard_write(sys.stderr.fileno(), s)
65     _last_prog = 0
66
67
68 def debug1(s):
69     if buglvl >= 1:
70         log(s)
71
72
73 def debug2(s):
74     if buglvl >= 2:
75         log(s)
76
77
78 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
79 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
80 _last_progress = ''
81 def progress(s):
82     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
83     global _last_progress
84     if istty2:
85         log(s)
86         _last_progress = s
87
88
89 def qprogress(s):
90     """Calls progress() only if we haven't printed progress in a while.
91     
92     This avoids overloading the stderr buffer with excess junk.
93     """
94     global _last_prog
95     now = time.time()
96     if now - _last_prog > 0.1:
97         progress(s)
98         _last_prog = now
99
100
101 def reprogress():
102     """Calls progress() to redisplay the most recent progress message.
103
104     Useful after you've printed some other message that wipes out the
105     progress line.
106     """
107     if _last_progress and _last_progress.endswith('\r'):
108         progress(_last_progress)
109
110
111 def mkdirp(d, mode=None):
112     """Recursively create directories on path 'd'.
113
114     Unlike os.makedirs(), it doesn't raise an exception if the last element of
115     the path already exists.
116     """
117     try:
118         if mode:
119             os.makedirs(d, mode)
120         else:
121             os.makedirs(d)
122     except OSError, e:
123         if e.errno == errno.EEXIST:
124             pass
125         else:
126             raise
127
128
129 _unspecified_next_default = object()
130
131 def _fallback_next(it, default=_unspecified_next_default):
132     """Retrieve the next item from the iterator by calling its
133     next() method. If default is given, it is returned if the
134     iterator is exhausted, otherwise StopIteration is raised."""
135
136     if default is _unspecified_next_default:
137         return it.next()
138     else:
139         try:
140             return it.next()
141         except StopIteration:
142             return default
143
144 if sys.version_info < (2, 6):
145     next =  _fallback_next
146
147
148 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
149     if key:
150         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
151     else:
152         samekey = operator.eq
153     count = 0
154     total = sum(len(it) for it in iters)
155     iters = (iter(it) for it in iters)
156     heap = ((next(it, None),it) for it in iters)
157     heap = [(e,it) for e,it in heap if e]
158
159     heapq.heapify(heap)
160     pe = None
161     while heap:
162         if not count % pfreq:
163             pfunc(count, total)
164         e, it = heap[0]
165         if not samekey(e, pe):
166             pe = e
167             yield e
168         count += 1
169         try:
170             e = it.next() # Don't use next() function, it's too expensive
171         except StopIteration:
172             heapq.heappop(heap) # remove current
173         else:
174             heapq.heapreplace(heap, (e, it)) # shift current to new location
175     pfinal(count, total)
176
177
178 def unlink(f):
179     """Delete a file at path 'f' if it currently exists.
180
181     Unlike os.unlink(), does not throw an exception if the file didn't already
182     exist.
183     """
184     try:
185         os.unlink(f)
186     except OSError, e:
187         if e.errno == errno.ENOENT:
188             pass  # it doesn't exist, that's what you asked for
189
190
191 def readpipe(argv):
192     """Run a subprocess and return its output."""
193     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
194     out, err = p.communicate()
195     if p.returncode != 0:
196         raise Exception('subprocess %r failed with status %d'
197                         % (' '.join(argv), p.returncode))
198     return out
199
200
201 def realpath(p):
202     """Get the absolute path of a file.
203
204     Behaves like os.path.realpath, but doesn't follow a symlink for the last
205     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
206     will follow symlinks in p's directory)
207     """
208     try:
209         st = os.lstat(p)
210     except OSError:
211         st = None
212     if st and stat.S_ISLNK(st.st_mode):
213         (dir, name) = os.path.split(p)
214         dir = os.path.realpath(dir)
215         out = os.path.join(dir, name)
216     else:
217         out = os.path.realpath(p)
218     #log('realpathing:%r,%r\n' % (p, out))
219     return out
220
221
222 def detect_fakeroot():
223     "Return True if we appear to be running under fakeroot."
224     return os.getenv("FAKEROOTKEY") != None
225
226
227 def is_superuser():
228     if sys.platform.startswith('cygwin'):
229         import ctypes
230         return ctypes.cdll.shell32.IsUserAnAdmin()
231     else:
232         return os.geteuid() == 0
233
234
235 def _cache_key_value(get_value, key, cache):
236     """Return (value, was_cached).  If there is a value in the cache
237     for key, use that, otherwise, call get_value(key) which should
238     throw a KeyError if there is no value -- in which case the cached
239     and returned value will be None.
240     """
241     try: # Do we already have it (or know there wasn't one)?
242         value = cache[key]
243         return value, True
244     except KeyError:
245         pass
246     value = None
247     try:
248         cache[key] = value = get_value(key)
249     except KeyError:
250         cache[key] = None
251     return value, False
252
253
254 _uid_to_pwd_cache = {}
255 _name_to_pwd_cache = {}
256
257 def pwd_from_uid(uid):
258     """Return password database entry for uid (may be a cached value).
259     Return None if no entry is found.
260     """
261     global _uid_to_pwd_cache, _name_to_pwd_cache
262     entry, cached = _cache_key_value(pwd.getpwuid, uid, _uid_to_pwd_cache)
263     if entry and not cached:
264         _name_to_pwd_cache[entry.pw_name] = entry
265     return entry
266
267
268 def pwd_from_name(name):
269     """Return password database entry for name (may be a cached value).
270     Return None if no entry is found.
271     """
272     global _uid_to_pwd_cache, _name_to_pwd_cache
273     entry, cached = _cache_key_value(pwd.getpwnam, name, _name_to_pwd_cache)
274     if entry and not cached:
275         _uid_to_pwd_cache[entry.pw_uid] = entry
276     return entry
277
278
279 _gid_to_grp_cache = {}
280 _name_to_grp_cache = {}
281
282 def grp_from_gid(gid):
283     """Return password database entry for gid (may be a cached value).
284     Return None if no entry is found.
285     """
286     global _gid_to_grp_cache, _name_to_grp_cache
287     entry, cached = _cache_key_value(grp.getgrgid, gid, _gid_to_grp_cache)
288     if entry and not cached:
289         _name_to_grp_cache[entry.gr_name] = entry
290     return entry
291
292
293 def grp_from_name(name):
294     """Return password database entry for name (may be a cached value).
295     Return None if no entry is found.
296     """
297     global _gid_to_grp_cache, _name_to_grp_cache
298     entry, cached = _cache_key_value(grp.getgrnam, name, _name_to_grp_cache)
299     if entry and not cached:
300         _gid_to_grp_cache[entry.gr_gid] = entry
301     return entry
302
303
304 _username = None
305 def username():
306     """Get the user's login name."""
307     global _username
308     if not _username:
309         uid = os.getuid()
310         _username = pwd_from_uid(uid)[0] or 'user%d' % uid
311     return _username
312
313
314 _userfullname = None
315 def userfullname():
316     """Get the user's full name."""
317     global _userfullname
318     if not _userfullname:
319         uid = os.getuid()
320         entry = pwd_from_uid(uid)
321         if entry:
322             _userfullname = entry[4].split(',')[0] or entry[0]
323         if not _userfullname:
324             _userfullname = 'user%d' % uid
325     return _userfullname
326
327
328 _hostname = None
329 def hostname():
330     """Get the FQDN of this machine."""
331     global _hostname
332     if not _hostname:
333         _hostname = socket.getfqdn()
334     return _hostname
335
336
337 _resource_path = None
338 def resource_path(subdir=''):
339     global _resource_path
340     if not _resource_path:
341         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
342     return os.path.join(_resource_path, subdir)
343
344 def format_filesize(size):
345     unit = 1024.0
346     size = float(size)
347     if size < unit:
348         return "%d" % (size)
349     exponent = int(math.log(size) / math.log(unit))
350     size_prefix = "KMGTPE"[exponent - 1]
351     return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
352
353
354 class NotOk(Exception):
355     pass
356
357
358 class BaseConn:
359     def __init__(self, outp):
360         self.outp = outp
361
362     def close(self):
363         while self._read(65536): pass
364
365     def read(self, size):
366         """Read 'size' bytes from input stream."""
367         self.outp.flush()
368         return self._read(size)
369
370     def readline(self):
371         """Read from input stream until a newline is found."""
372         self.outp.flush()
373         return self._readline()
374
375     def write(self, data):
376         """Write 'data' to output stream."""
377         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
378         self.outp.write(data)
379
380     def has_input(self):
381         """Return true if input stream is readable."""
382         raise NotImplemented("Subclasses must implement has_input")
383
384     def ok(self):
385         """Indicate end of output from last sent command."""
386         self.write('\nok\n')
387
388     def error(self, s):
389         """Indicate server error to the client."""
390         s = re.sub(r'\s+', ' ', str(s))
391         self.write('\nerror %s\n' % s)
392
393     def _check_ok(self, onempty):
394         self.outp.flush()
395         rl = ''
396         for rl in linereader(self):
397             #log('%d got line: %r\n' % (os.getpid(), rl))
398             if not rl:  # empty line
399                 continue
400             elif rl == 'ok':
401                 return None
402             elif rl.startswith('error '):
403                 #log('client: error: %s\n' % rl[6:])
404                 return NotOk(rl[6:])
405             else:
406                 onempty(rl)
407         raise Exception('server exited unexpectedly; see errors above')
408
409     def drain_and_check_ok(self):
410         """Remove all data for the current command from input stream."""
411         def onempty(rl):
412             pass
413         return self._check_ok(onempty)
414
415     def check_ok(self):
416         """Verify that server action completed successfully."""
417         def onempty(rl):
418             raise Exception('expected "ok", got %r' % rl)
419         return self._check_ok(onempty)
420
421
422 class Conn(BaseConn):
423     def __init__(self, inp, outp):
424         BaseConn.__init__(self, outp)
425         self.inp = inp
426
427     def _read(self, size):
428         return self.inp.read(size)
429
430     def _readline(self):
431         return self.inp.readline()
432
433     def has_input(self):
434         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
435         if rl:
436             assert(rl[0] == self.inp.fileno())
437             return True
438         else:
439             return None
440
441
442 def checked_reader(fd, n):
443     while n > 0:
444         rl, _, _ = select.select([fd], [], [])
445         assert(rl[0] == fd)
446         buf = os.read(fd, n)
447         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
448         yield buf
449         n -= len(buf)
450
451
452 MAX_PACKET = 128 * 1024
453 def mux(p, outfd, outr, errr):
454     try:
455         fds = [outr, errr]
456         while p.poll() is None:
457             rl, _, _ = select.select(fds, [], [])
458             for fd in rl:
459                 if fd == outr:
460                     buf = os.read(outr, MAX_PACKET)
461                     if not buf: break
462                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
463                 elif fd == errr:
464                     buf = os.read(errr, 1024)
465                     if not buf: break
466                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
467     finally:
468         os.write(outfd, struct.pack('!IB', 0, 3))
469
470
471 class DemuxConn(BaseConn):
472     """A helper class for bup's client-server protocol."""
473     def __init__(self, infd, outp):
474         BaseConn.__init__(self, outp)
475         # Anything that comes through before the sync string was not
476         # multiplexed and can be assumed to be debug/log before mux init.
477         tail = ''
478         while tail != 'BUPMUX':
479             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
480             if not b:
481                 raise IOError('demux: unexpected EOF during initialization')
482             tail += b
483             sys.stderr.write(tail[:-6])  # pre-mux log messages
484             tail = tail[-6:]
485         self.infd = infd
486         self.reader = None
487         self.buf = None
488         self.closed = False
489
490     def write(self, data):
491         self._load_buf(0)
492         BaseConn.write(self, data)
493
494     def _next_packet(self, timeout):
495         if self.closed: return False
496         rl, wl, xl = select.select([self.infd], [], [], timeout)
497         if not rl: return False
498         assert(rl[0] == self.infd)
499         ns = ''.join(checked_reader(self.infd, 5))
500         n, fdw = struct.unpack('!IB', ns)
501         assert(n <= MAX_PACKET)
502         if fdw == 1:
503             self.reader = checked_reader(self.infd, n)
504         elif fdw == 2:
505             for buf in checked_reader(self.infd, n):
506                 sys.stderr.write(buf)
507         elif fdw == 3:
508             self.closed = True
509             debug2("DemuxConn: marked closed\n")
510         return True
511
512     def _load_buf(self, timeout):
513         if self.buf is not None:
514             return True
515         while not self.closed:
516             while not self.reader:
517                 if not self._next_packet(timeout):
518                     return False
519             try:
520                 self.buf = self.reader.next()
521                 return True
522             except StopIteration:
523                 self.reader = None
524         return False
525
526     def _read_parts(self, ix_fn):
527         while self._load_buf(None):
528             assert(self.buf is not None)
529             i = ix_fn(self.buf)
530             if i is None or i == len(self.buf):
531                 yv = self.buf
532                 self.buf = None
533             else:
534                 yv = self.buf[:i]
535                 self.buf = self.buf[i:]
536             yield yv
537             if i is not None:
538                 break
539
540     def _readline(self):
541         def find_eol(buf):
542             try:
543                 return buf.index('\n')+1
544             except ValueError:
545                 return None
546         return ''.join(self._read_parts(find_eol))
547
548     def _read(self, size):
549         csize = [size]
550         def until_size(buf): # Closes on csize
551             if len(buf) < csize[0]:
552                 csize[0] -= len(buf)
553                 return None
554             else:
555                 return csize[0]
556         return ''.join(self._read_parts(until_size))
557
558     def has_input(self):
559         return self._load_buf(0)
560
561
562 def linereader(f):
563     """Generate a list of input lines from 'f' without terminating newlines."""
564     while 1:
565         line = f.readline()
566         if not line:
567             break
568         yield line[:-1]
569
570
571 def chunkyreader(f, count = None):
572     """Generate a list of chunks of data read from 'f'.
573
574     If count is None, read until EOF is reached.
575
576     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
577     reached while reading, raise IOError.
578     """
579     if count != None:
580         while count > 0:
581             b = f.read(min(count, 65536))
582             if not b:
583                 raise IOError('EOF with %d bytes remaining' % count)
584             yield b
585             count -= len(b)
586     else:
587         while 1:
588             b = f.read(65536)
589             if not b: break
590             yield b
591
592
593 def slashappend(s):
594     """Append "/" to 's' if it doesn't aleady end in "/"."""
595     if s and not s.endswith('/'):
596         return s + '/'
597     else:
598         return s
599
600
601 def _mmap_do(f, sz, flags, prot, close):
602     if not sz:
603         st = os.fstat(f.fileno())
604         sz = st.st_size
605     if not sz:
606         # trying to open a zero-length map gives an error, but an empty
607         # string has all the same behaviour of a zero-length map, ie. it has
608         # no elements :)
609         return ''
610     map = mmap.mmap(f.fileno(), sz, flags, prot)
611     if close:
612         f.close()  # map will persist beyond file close
613     return map
614
615
616 def mmap_read(f, sz = 0, close=True):
617     """Create a read-only memory mapped region on file 'f'.
618     If sz is 0, the region will cover the entire file.
619     """
620     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
621
622
623 def mmap_readwrite(f, sz = 0, close=True):
624     """Create a read-write memory mapped region on file 'f'.
625     If sz is 0, the region will cover the entire file.
626     """
627     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
628                     close)
629
630
631 def mmap_readwrite_private(f, sz = 0, close=True):
632     """Create a read-write memory mapped region on file 'f'.
633     If sz is 0, the region will cover the entire file.
634     The map is private, which means the changes are never flushed back to the
635     file.
636     """
637     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
638                     close)
639
640
641 def parse_timestamp(epoch_str):
642     """Return the number of nanoseconds since the epoch that are described
643 by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
644 throw a ValueError that may contain additional information."""
645     ns_per = {'s' :  1000000000,
646               'ms' : 1000000,
647               'us' : 1000,
648               'ns' : 1}
649     match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
650     if not match:
651         if re.match(r'^([-+]?[0-9]+)$', epoch_str):
652             raise ValueError('must include units, i.e. 100ns, 100ms, ...')
653         raise ValueError()
654     (n, units) = match.group(1, 2)
655     if not n:
656         n = 1
657     n = int(n)
658     return n * ns_per[units]
659
660
661 def parse_num(s):
662     """Parse data size information into a float number.
663
664     Here are some examples of conversions:
665         199.2k means 203981 bytes
666         1GB means 1073741824 bytes
667         2.1 tb means 2199023255552 bytes
668     """
669     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
670     if not g:
671         raise ValueError("can't parse %r as a number" % s)
672     (val, unit) = g.groups()
673     num = float(val)
674     unit = unit.lower()
675     if unit in ['t', 'tb']:
676         mult = 1024*1024*1024*1024
677     elif unit in ['g', 'gb']:
678         mult = 1024*1024*1024
679     elif unit in ['m', 'mb']:
680         mult = 1024*1024
681     elif unit in ['k', 'kb']:
682         mult = 1024
683     elif unit in ['', 'b']:
684         mult = 1
685     else:
686         raise ValueError("invalid unit %r in number %r" % (unit, s))
687     return int(num*mult)
688
689
690 def count(l):
691     """Count the number of elements in an iterator. (consumes the iterator)"""
692     return reduce(lambda x,y: x+1, l)
693
694
695 saved_errors = []
696 def add_error(e):
697     """Append an error message to the list of saved errors.
698
699     Once processing is able to stop and output the errors, the saved errors are
700     accessible in the module variable helpers.saved_errors.
701     """
702     saved_errors.append(e)
703     log('%-70s\n' % e)
704
705
706 def clear_errors():
707     global saved_errors
708     saved_errors = []
709
710
711 def handle_ctrl_c():
712     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
713
714     The new exception handler will make sure that bup will exit without an ugly
715     stacktrace when Ctrl-C is hit.
716     """
717     oldhook = sys.excepthook
718     def newhook(exctype, value, traceback):
719         if exctype == KeyboardInterrupt:
720             log('\nInterrupted.\n')
721         else:
722             return oldhook(exctype, value, traceback)
723     sys.excepthook = newhook
724
725
726 def columnate(l, prefix):
727     """Format elements of 'l' in columns with 'prefix' leading each line.
728
729     The number of columns is determined automatically based on the string
730     lengths.
731     """
732     if not l:
733         return ""
734     l = l[:]
735     clen = max(len(s) for s in l)
736     ncols = (tty_width() - len(prefix)) / (clen + 2)
737     if ncols <= 1:
738         ncols = 1
739         clen = 0
740     cols = []
741     while len(l) % ncols:
742         l.append('')
743     rows = len(l)/ncols
744     for s in range(0, len(l), rows):
745         cols.append(l[s:s+rows])
746     out = ''
747     for row in zip(*cols):
748         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
749     return out
750
751
752 def parse_date_or_fatal(str, fatal):
753     """Parses the given date or calls Option.fatal().
754     For now we expect a string that contains a float."""
755     try:
756         date = atof(str)
757     except ValueError, e:
758         raise fatal('invalid date format (should be a float): %r' % e)
759     else:
760         return date
761
762
763 def parse_excludes(options, fatal):
764     """Traverse the options and extract all excludes, or call Option.fatal()."""
765     excluded_paths = []
766
767     for flag in options:
768         (option, parameter) = flag
769         if option == '--exclude':
770             excluded_paths.append(realpath(parameter))
771         elif option == '--exclude-from':
772             try:
773                 f = open(realpath(parameter))
774             except IOError, e:
775                 raise fatal("couldn't read %s" % parameter)
776             for exclude_path in f.readlines():
777                 excluded_paths.append(realpath(exclude_path.strip()))
778     return sorted(frozenset(excluded_paths))
779
780
781 def parse_rx_excludes(options, fatal):
782     """Traverse the options and extract all rx excludes, or call
783     Option.fatal()."""
784     excluded_patterns = []
785
786     for flag in options:
787         (option, parameter) = flag
788         if option == '--exclude-rx':
789             try:
790                 excluded_patterns.append(re.compile(parameter))
791             except re.error, ex:
792                 fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
793         elif option == '--exclude-rx-from':
794             try:
795                 f = open(realpath(parameter))
796             except IOError, e:
797                 raise fatal("couldn't read %s" % parameter)
798             for pattern in f.readlines():
799                 spattern = pattern.rstrip('\n')
800                 try:
801                     excluded_patterns.append(re.compile(spattern))
802                 except re.error, ex:
803                     fatal('invalid --exclude-rx pattern (%s): %s' % (spattern, ex))
804     return excluded_patterns
805
806
807 def should_rx_exclude_path(path, exclude_rxs):
808     """Return True if path matches a regular expression in exclude_rxs."""
809     for rx in exclude_rxs:
810         if rx.search(path):
811             debug1('Skipping %r: excluded by rx pattern %r.\n'
812                    % (path, rx.pattern))
813             return True
814     return False
815
816
817 # FIXME: Carefully consider the use of functions (os.path.*, etc.)
818 # that resolve against the current filesystem in the strip/graft
819 # functions for example, but elsewhere as well.  I suspect bup's not
820 # always being careful about that.  For some cases, the contents of
821 # the current filesystem should be irrelevant, and consulting it might
822 # produce the wrong result, perhaps via unintended symlink resolution,
823 # for example.
824
825 def path_components(path):
826     """Break path into a list of pairs of the form (name,
827     full_path_to_name).  Path must start with '/'.
828     Example:
829       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
830     if not path.startswith('/'):
831         raise Exception, 'path must start with "/": %s' % path
832     # Since we assume path startswith('/'), we can skip the first element.
833     result = [('', '/')]
834     norm_path = os.path.abspath(path)
835     if norm_path == '/':
836         return result
837     full_path = ''
838     for p in norm_path.split('/')[1:]:
839         full_path += '/' + p
840         result.append((p, full_path))
841     return result
842
843
844 def stripped_path_components(path, strip_prefixes):
845     """Strip any prefix in strip_prefixes from path and return a list
846     of path components where each component is (name,
847     none_or_full_fs_path_to_name).  Assume path startswith('/').
848     See thelpers.py for examples."""
849     normalized_path = os.path.abspath(path)
850     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
851     for bp in sorted_strip_prefixes:
852         normalized_bp = os.path.abspath(bp)
853         if normalized_path.startswith(normalized_bp):
854             prefix = normalized_path[:len(normalized_bp)]
855             result = []
856             for p in normalized_path[len(normalized_bp):].split('/'):
857                 if p: # not root
858                     prefix += '/'
859                 prefix += p
860                 result.append((p, prefix))
861             return result
862     # Nothing to strip.
863     return path_components(path)
864
865
866 def grafted_path_components(graft_points, path):
867     # Create a result that consists of some number of faked graft
868     # directories before the graft point, followed by all of the real
869     # directories from path that are after the graft point.  Arrange
870     # for the directory at the graft point in the result to correspond
871     # to the "orig" directory in --graft orig=new.  See t/thelpers.py
872     # for some examples.
873
874     # Note that given --graft orig=new, orig and new have *nothing* to
875     # do with each other, even if some of their component names
876     # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
877     # equivalent to --graft /foo/bar/baz=/x/y/z, or even
878     # /foo/bar/baz=/x.
879
880     # FIXME: This can't be the best solution...
881     clean_path = os.path.abspath(path)
882     for graft_point in graft_points:
883         old_prefix, new_prefix = graft_point
884         # Expand prefixes iff not absolute paths.
885         old_prefix = os.path.normpath(old_prefix)
886         new_prefix = os.path.normpath(new_prefix)
887         if clean_path.startswith(old_prefix):
888             escaped_prefix = re.escape(old_prefix)
889             grafted_path = re.sub(r'^' + escaped_prefix, new_prefix, clean_path)
890             # Handle /foo=/ (at least) -- which produces //whatever.
891             grafted_path = '/' + grafted_path.lstrip('/')
892             clean_path_components = path_components(clean_path)
893             # Count the components that were stripped.
894             strip_count = 0 if old_prefix == '/' else old_prefix.count('/')
895             new_prefix_parts = new_prefix.split('/')
896             result_prefix = grafted_path.split('/')[:new_prefix.count('/')]
897             result = [(p, None) for p in result_prefix] \
898                 + clean_path_components[strip_count:]
899             # Now set the graft point name to match the end of new_prefix.
900             graft_point = len(result_prefix)
901             result[graft_point] = \
902                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
903             if new_prefix == '/': # --graft ...=/ is a special case.
904                 return result[1:]
905             return result
906     return path_components(clean_path)
907
908 Sha1 = hashlib.sha1
909
910 def version_date():
911     """Format bup's version date string for output."""
912     return _version.DATE.split(' ')[0]
913
914
915 def version_commit():
916     """Get the commit hash of bup's current version."""
917     return _version.COMMIT
918
919
920 def version_tag():
921     """Format bup's version tag (the official version number).
922
923     When generated from a commit other than one pointed to with a tag, the
924     returned string will be "unknown-" followed by the first seven positions of
925     the commit hash.
926     """
927     names = _version.NAMES.strip()
928     assert(names[0] == '(')
929     assert(names[-1] == ')')
930     names = names[1:-1]
931     l = [n.strip() for n in names.split(',')]
932     for n in l:
933         if n.startswith('tag: bup-'):
934             return n[9:]
935     return 'unknown-%s' % _version.COMMIT[:7]