]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
e7fbc5b8992a18433d3f54cc5241c8adac5de18c
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 from collections import namedtuple
4 from ctypes import sizeof, c_void_p
5 from os import environ
6 from contextlib import contextmanager
7 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
8 import hashlib, heapq, math, operator, time, grp, tempfile
9
10 from bup import _helpers
11
12 sc_page_size = os.sysconf('SC_PAGE_SIZE')
13 assert(sc_page_size > 0)
14
15 sc_arg_max = os.sysconf('SC_ARG_MAX')
16 if sc_arg_max == -1:  # "no definite limit" - let's choose 2M
17     sc_arg_max = 2 * 1024 * 1024
18
19 # This function should really be in helpers, not in bup.options.  But we
20 # want options.py to be standalone so people can include it in other projects.
21 from bup.options import _tty_width
22 tty_width = _tty_width
23
24
25 def atoi(s):
26     """Convert the string 's' to an integer. Return 0 if s is not a number."""
27     try:
28         return int(s or '0')
29     except ValueError:
30         return 0
31
32
33 def atof(s):
34     """Convert the string 's' to a float. Return 0 if s is not a number."""
35     try:
36         return float(s or '0')
37     except ValueError:
38         return 0
39
40
41 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
42
43
44 # If the platform doesn't have fdatasync (OS X), fall back to fsync.
45 try:
46     fdatasync = os.fdatasync
47 except AttributeError:
48     fdatasync = os.fsync
49
50
51 # Write (blockingly) to sockets that may or may not be in blocking mode.
52 # We need this because our stderr is sometimes eaten by subprocesses
53 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
54 # leading to race conditions.  Ick.  We'll do it the hard way.
55 def _hard_write(fd, buf):
56     while buf:
57         (r,w,x) = select.select([], [fd], [], None)
58         if not w:
59             raise IOError('select(fd) returned without being writable')
60         try:
61             sz = os.write(fd, buf)
62         except OSError, e:
63             if e.errno != errno.EAGAIN:
64                 raise
65         assert(sz >= 0)
66         buf = buf[sz:]
67
68
69 _last_prog = 0
70 def log(s):
71     """Print a log message to stderr."""
72     global _last_prog
73     sys.stdout.flush()
74     _hard_write(sys.stderr.fileno(), s)
75     _last_prog = 0
76
77
78 def debug1(s):
79     if buglvl >= 1:
80         log(s)
81
82
83 def debug2(s):
84     if buglvl >= 2:
85         log(s)
86
87
88 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
89 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
90 _last_progress = ''
91 def progress(s):
92     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
93     global _last_progress
94     if istty2:
95         log(s)
96         _last_progress = s
97
98
99 def qprogress(s):
100     """Calls progress() only if we haven't printed progress in a while.
101     
102     This avoids overloading the stderr buffer with excess junk.
103     """
104     global _last_prog
105     now = time.time()
106     if now - _last_prog > 0.1:
107         progress(s)
108         _last_prog = now
109
110
111 def reprogress():
112     """Calls progress() to redisplay the most recent progress message.
113
114     Useful after you've printed some other message that wipes out the
115     progress line.
116     """
117     if _last_progress and _last_progress.endswith('\r'):
118         progress(_last_progress)
119
120
121 def mkdirp(d, mode=None):
122     """Recursively create directories on path 'd'.
123
124     Unlike os.makedirs(), it doesn't raise an exception if the last element of
125     the path already exists.
126     """
127     try:
128         if mode:
129             os.makedirs(d, mode)
130         else:
131             os.makedirs(d)
132     except OSError, e:
133         if e.errno == errno.EEXIST:
134             pass
135         else:
136             raise
137
138
139 _unspecified_next_default = object()
140
141 def _fallback_next(it, default=_unspecified_next_default):
142     """Retrieve the next item from the iterator by calling its
143     next() method. If default is given, it is returned if the
144     iterator is exhausted, otherwise StopIteration is raised."""
145
146     if default is _unspecified_next_default:
147         return it.next()
148     else:
149         try:
150             return it.next()
151         except StopIteration:
152             return default
153
154 if sys.version_info < (2, 6):
155     next =  _fallback_next
156
157
158 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
159     if key:
160         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
161     else:
162         samekey = operator.eq
163     count = 0
164     total = sum(len(it) for it in iters)
165     iters = (iter(it) for it in iters)
166     heap = ((next(it, None),it) for it in iters)
167     heap = [(e,it) for e,it in heap if e]
168
169     heapq.heapify(heap)
170     pe = None
171     while heap:
172         if not count % pfreq:
173             pfunc(count, total)
174         e, it = heap[0]
175         if not samekey(e, pe):
176             pe = e
177             yield e
178         count += 1
179         try:
180             e = it.next() # Don't use next() function, it's too expensive
181         except StopIteration:
182             heapq.heappop(heap) # remove current
183         else:
184             heapq.heapreplace(heap, (e, it)) # shift current to new location
185     pfinal(count, total)
186
187
188 def unlink(f):
189     """Delete a file at path 'f' if it currently exists.
190
191     Unlike os.unlink(), does not throw an exception if the file didn't already
192     exist.
193     """
194     try:
195         os.unlink(f)
196     except OSError, e:
197         if e.errno != errno.ENOENT:
198             raise
199
200
201 def readpipe(argv, preexec_fn=None):
202     """Run a subprocess and return its output."""
203     p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn)
204     out, err = p.communicate()
205     if p.returncode != 0:
206         raise Exception('subprocess %r failed with status %d'
207                         % (' '.join(argv), p.returncode))
208     return out
209
210
211 def _argmax_base(command):
212     base_size = 2048
213     for c in command:
214         base_size += len(command) + 1
215     for k, v in environ.iteritems():
216         base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
217     return base_size
218
219
220 def _argmax_args_size(args):
221     return sum(len(x) + 1 + sizeof(c_void_p) for x in args)
222
223
224 def batchpipe(command, args, preexec_fn=None, arg_max=sc_arg_max):
225     """If args is not empty, yield the output produced by calling the
226 command list with args as a sequence of strings (It may be necessary
227 to return multiple strings in order to respect ARG_MAX)."""
228     # The optional arg_max arg is a workaround for an issue with the
229     # current wvtest behavior.
230     base_size = _argmax_base(command)
231     while args:
232         room = arg_max - base_size
233         i = 0
234         while i < len(args):
235             next_size = _argmax_args_size(args[i:i+1])
236             if room - next_size < 0:
237                 break
238             room -= next_size
239             i += 1
240         sub_args = args[:i]
241         args = args[i:]
242         assert(len(sub_args))
243         yield readpipe(command + sub_args, preexec_fn=preexec_fn)
244
245
246 def realpath(p):
247     """Get the absolute path of a file.
248
249     Behaves like os.path.realpath, but doesn't follow a symlink for the last
250     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
251     will follow symlinks in p's directory)
252     """
253     try:
254         st = os.lstat(p)
255     except OSError:
256         st = None
257     if st and stat.S_ISLNK(st.st_mode):
258         (dir, name) = os.path.split(p)
259         dir = os.path.realpath(dir)
260         out = os.path.join(dir, name)
261     else:
262         out = os.path.realpath(p)
263     #log('realpathing:%r,%r\n' % (p, out))
264     return out
265
266
267 def detect_fakeroot():
268     "Return True if we appear to be running under fakeroot."
269     return os.getenv("FAKEROOTKEY") != None
270
271
272 def is_superuser():
273     if sys.platform.startswith('cygwin'):
274         import ctypes
275         return ctypes.cdll.shell32.IsUserAnAdmin()
276     else:
277         return os.geteuid() == 0
278
279
280 def _cache_key_value(get_value, key, cache):
281     """Return (value, was_cached).  If there is a value in the cache
282     for key, use that, otherwise, call get_value(key) which should
283     throw a KeyError if there is no value -- in which case the cached
284     and returned value will be None.
285     """
286     try: # Do we already have it (or know there wasn't one)?
287         value = cache[key]
288         return value, True
289     except KeyError:
290         pass
291     value = None
292     try:
293         cache[key] = value = get_value(key)
294     except KeyError:
295         cache[key] = None
296     return value, False
297
298
299 _uid_to_pwd_cache = {}
300 _name_to_pwd_cache = {}
301
302 def pwd_from_uid(uid):
303     """Return password database entry for uid (may be a cached value).
304     Return None if no entry is found.
305     """
306     global _uid_to_pwd_cache, _name_to_pwd_cache
307     entry, cached = _cache_key_value(pwd.getpwuid, uid, _uid_to_pwd_cache)
308     if entry and not cached:
309         _name_to_pwd_cache[entry.pw_name] = entry
310     return entry
311
312
313 def pwd_from_name(name):
314     """Return password database entry for name (may be a cached value).
315     Return None if no entry is found.
316     """
317     global _uid_to_pwd_cache, _name_to_pwd_cache
318     entry, cached = _cache_key_value(pwd.getpwnam, name, _name_to_pwd_cache)
319     if entry and not cached:
320         _uid_to_pwd_cache[entry.pw_uid] = entry
321     return entry
322
323
324 _gid_to_grp_cache = {}
325 _name_to_grp_cache = {}
326
327 def grp_from_gid(gid):
328     """Return password database entry for gid (may be a cached value).
329     Return None if no entry is found.
330     """
331     global _gid_to_grp_cache, _name_to_grp_cache
332     entry, cached = _cache_key_value(grp.getgrgid, gid, _gid_to_grp_cache)
333     if entry and not cached:
334         _name_to_grp_cache[entry.gr_name] = entry
335     return entry
336
337
338 def grp_from_name(name):
339     """Return password database entry for name (may be a cached value).
340     Return None if no entry is found.
341     """
342     global _gid_to_grp_cache, _name_to_grp_cache
343     entry, cached = _cache_key_value(grp.getgrnam, name, _name_to_grp_cache)
344     if entry and not cached:
345         _gid_to_grp_cache[entry.gr_gid] = entry
346     return entry
347
348
349 _username = None
350 def username():
351     """Get the user's login name."""
352     global _username
353     if not _username:
354         uid = os.getuid()
355         _username = pwd_from_uid(uid)[0] or 'user%d' % uid
356     return _username
357
358
359 _userfullname = None
360 def userfullname():
361     """Get the user's full name."""
362     global _userfullname
363     if not _userfullname:
364         uid = os.getuid()
365         entry = pwd_from_uid(uid)
366         if entry:
367             _userfullname = entry[4].split(',')[0] or entry[0]
368         if not _userfullname:
369             _userfullname = 'user%d' % uid
370     return _userfullname
371
372
373 _hostname = None
374 def hostname():
375     """Get the FQDN of this machine."""
376     global _hostname
377     if not _hostname:
378         _hostname = socket.getfqdn()
379     return _hostname
380
381
382 _resource_path = None
383 def resource_path(subdir=''):
384     global _resource_path
385     if not _resource_path:
386         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
387     return os.path.join(_resource_path, subdir)
388
389 def format_filesize(size):
390     unit = 1024.0
391     size = float(size)
392     if size < unit:
393         return "%d" % (size)
394     exponent = int(math.log(size) / math.log(unit))
395     size_prefix = "KMGTPE"[exponent - 1]
396     return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
397
398
399 class NotOk(Exception):
400     pass
401
402
403 class BaseConn:
404     def __init__(self, outp):
405         self.outp = outp
406
407     def close(self):
408         while self._read(65536): pass
409
410     def read(self, size):
411         """Read 'size' bytes from input stream."""
412         self.outp.flush()
413         return self._read(size)
414
415     def readline(self):
416         """Read from input stream until a newline is found."""
417         self.outp.flush()
418         return self._readline()
419
420     def write(self, data):
421         """Write 'data' to output stream."""
422         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
423         self.outp.write(data)
424
425     def has_input(self):
426         """Return true if input stream is readable."""
427         raise NotImplemented("Subclasses must implement has_input")
428
429     def ok(self):
430         """Indicate end of output from last sent command."""
431         self.write('\nok\n')
432
433     def error(self, s):
434         """Indicate server error to the client."""
435         s = re.sub(r'\s+', ' ', str(s))
436         self.write('\nerror %s\n' % s)
437
438     def _check_ok(self, onempty):
439         self.outp.flush()
440         rl = ''
441         for rl in linereader(self):
442             #log('%d got line: %r\n' % (os.getpid(), rl))
443             if not rl:  # empty line
444                 continue
445             elif rl == 'ok':
446                 return None
447             elif rl.startswith('error '):
448                 #log('client: error: %s\n' % rl[6:])
449                 return NotOk(rl[6:])
450             else:
451                 onempty(rl)
452         raise Exception('server exited unexpectedly; see errors above')
453
454     def drain_and_check_ok(self):
455         """Remove all data for the current command from input stream."""
456         def onempty(rl):
457             pass
458         return self._check_ok(onempty)
459
460     def check_ok(self):
461         """Verify that server action completed successfully."""
462         def onempty(rl):
463             raise Exception('expected "ok", got %r' % rl)
464         return self._check_ok(onempty)
465
466
467 class Conn(BaseConn):
468     def __init__(self, inp, outp):
469         BaseConn.__init__(self, outp)
470         self.inp = inp
471
472     def _read(self, size):
473         return self.inp.read(size)
474
475     def _readline(self):
476         return self.inp.readline()
477
478     def has_input(self):
479         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
480         if rl:
481             assert(rl[0] == self.inp.fileno())
482             return True
483         else:
484             return None
485
486
487 def checked_reader(fd, n):
488     while n > 0:
489         rl, _, _ = select.select([fd], [], [])
490         assert(rl[0] == fd)
491         buf = os.read(fd, n)
492         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
493         yield buf
494         n -= len(buf)
495
496
497 MAX_PACKET = 128 * 1024
498 def mux(p, outfd, outr, errr):
499     try:
500         fds = [outr, errr]
501         while p.poll() is None:
502             rl, _, _ = select.select(fds, [], [])
503             for fd in rl:
504                 if fd == outr:
505                     buf = os.read(outr, MAX_PACKET)
506                     if not buf: break
507                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
508                 elif fd == errr:
509                     buf = os.read(errr, 1024)
510                     if not buf: break
511                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
512     finally:
513         os.write(outfd, struct.pack('!IB', 0, 3))
514
515
516 class DemuxConn(BaseConn):
517     """A helper class for bup's client-server protocol."""
518     def __init__(self, infd, outp):
519         BaseConn.__init__(self, outp)
520         # Anything that comes through before the sync string was not
521         # multiplexed and can be assumed to be debug/log before mux init.
522         tail = ''
523         while tail != 'BUPMUX':
524             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
525             if not b:
526                 raise IOError('demux: unexpected EOF during initialization')
527             tail += b
528             sys.stderr.write(tail[:-6])  # pre-mux log messages
529             tail = tail[-6:]
530         self.infd = infd
531         self.reader = None
532         self.buf = None
533         self.closed = False
534
535     def write(self, data):
536         self._load_buf(0)
537         BaseConn.write(self, data)
538
539     def _next_packet(self, timeout):
540         if self.closed: return False
541         rl, wl, xl = select.select([self.infd], [], [], timeout)
542         if not rl: return False
543         assert(rl[0] == self.infd)
544         ns = ''.join(checked_reader(self.infd, 5))
545         n, fdw = struct.unpack('!IB', ns)
546         assert(n <= MAX_PACKET)
547         if fdw == 1:
548             self.reader = checked_reader(self.infd, n)
549         elif fdw == 2:
550             for buf in checked_reader(self.infd, n):
551                 sys.stderr.write(buf)
552         elif fdw == 3:
553             self.closed = True
554             debug2("DemuxConn: marked closed\n")
555         return True
556
557     def _load_buf(self, timeout):
558         if self.buf is not None:
559             return True
560         while not self.closed:
561             while not self.reader:
562                 if not self._next_packet(timeout):
563                     return False
564             try:
565                 self.buf = self.reader.next()
566                 return True
567             except StopIteration:
568                 self.reader = None
569         return False
570
571     def _read_parts(self, ix_fn):
572         while self._load_buf(None):
573             assert(self.buf is not None)
574             i = ix_fn(self.buf)
575             if i is None or i == len(self.buf):
576                 yv = self.buf
577                 self.buf = None
578             else:
579                 yv = self.buf[:i]
580                 self.buf = self.buf[i:]
581             yield yv
582             if i is not None:
583                 break
584
585     def _readline(self):
586         def find_eol(buf):
587             try:
588                 return buf.index('\n')+1
589             except ValueError:
590                 return None
591         return ''.join(self._read_parts(find_eol))
592
593     def _read(self, size):
594         csize = [size]
595         def until_size(buf): # Closes on csize
596             if len(buf) < csize[0]:
597                 csize[0] -= len(buf)
598                 return None
599             else:
600                 return csize[0]
601         return ''.join(self._read_parts(until_size))
602
603     def has_input(self):
604         return self._load_buf(0)
605
606
607 def linereader(f):
608     """Generate a list of input lines from 'f' without terminating newlines."""
609     while 1:
610         line = f.readline()
611         if not line:
612             break
613         yield line[:-1]
614
615
616 def chunkyreader(f, count = None):
617     """Generate a list of chunks of data read from 'f'.
618
619     If count is None, read until EOF is reached.
620
621     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
622     reached while reading, raise IOError.
623     """
624     if count != None:
625         while count > 0:
626             b = f.read(min(count, 65536))
627             if not b:
628                 raise IOError('EOF with %d bytes remaining' % count)
629             yield b
630             count -= len(b)
631     else:
632         while 1:
633             b = f.read(65536)
634             if not b: break
635             yield b
636
637
638 @contextmanager
639 def atomically_replaced_file(name, mode='w', buffering=-1):
640     """Yield a file that will be atomically renamed name when leaving the block.
641
642     This contextmanager yields an open file object that is backed by a
643     temporary file which will be renamed (atomically) to the target
644     name if everything succeeds.
645
646     The mode and buffering arguments are handled exactly as with open,
647     and the yielded file will have very restrictive permissions, as
648     per mkstemp.
649
650     E.g.::
651
652         with atomically_replaced_file('foo.txt', 'w') as f:
653             f.write('hello jack.')
654
655     """
656
657     (ffd, tempname) = tempfile.mkstemp(dir=os.path.dirname(name),
658                                        text=('b' not in mode))
659     try:
660         try:
661             f = os.fdopen(ffd, mode, buffering)
662         except:
663             os.close(ffd)
664             raise
665         try:
666             yield f
667         finally:
668             f.close()
669         os.rename(tempname, name)
670     finally:
671         unlink(tempname)  # nonexistant file is ignored
672
673
674 def slashappend(s):
675     """Append "/" to 's' if it doesn't aleady end in "/"."""
676     if s and not s.endswith('/'):
677         return s + '/'
678     else:
679         return s
680
681
682 def _mmap_do(f, sz, flags, prot, close):
683     if not sz:
684         st = os.fstat(f.fileno())
685         sz = st.st_size
686     if not sz:
687         # trying to open a zero-length map gives an error, but an empty
688         # string has all the same behaviour of a zero-length map, ie. it has
689         # no elements :)
690         return ''
691     map = mmap.mmap(f.fileno(), sz, flags, prot)
692     if close:
693         f.close()  # map will persist beyond file close
694     return map
695
696
697 def mmap_read(f, sz = 0, close=True):
698     """Create a read-only memory mapped region on file 'f'.
699     If sz is 0, the region will cover the entire file.
700     """
701     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
702
703
704 def mmap_readwrite(f, sz = 0, close=True):
705     """Create a read-write memory mapped region on file 'f'.
706     If sz is 0, the region will cover the entire file.
707     """
708     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
709                     close)
710
711
712 def mmap_readwrite_private(f, sz = 0, close=True):
713     """Create a read-write memory mapped region on file 'f'.
714     If sz is 0, the region will cover the entire file.
715     The map is private, which means the changes are never flushed back to the
716     file.
717     """
718     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
719                     close)
720
721
722 _mincore = getattr(_helpers, 'mincore', None)
723 if _mincore:
724     # ./configure ensures that we're on Linux if MINCORE_INCORE isn't defined.
725     MINCORE_INCORE = getattr(_helpers, 'MINCORE_INCORE', 1)
726
727     _fmincore_chunk_size = None
728     def _set_fmincore_chunk_size():
729         global _fmincore_chunk_size
730         pref_chunk_size = 64 * 1024 * 1024
731         chunk_size = sc_page_size
732         if (sc_page_size < pref_chunk_size):
733             chunk_size = sc_page_size * (pref_chunk_size / sc_page_size)
734         _fmincore_chunk_size = chunk_size
735
736     def fmincore(fd):
737         """Return the mincore() data for fd as a bytearray whose values can be
738         tested via MINCORE_INCORE"""
739         st = os.fstat(fd)
740         if (st.st_size == 0):
741             return bytearray(0)
742         if not _fmincore_chunk_size:
743             _set_fmincore_chunk_size()
744         pages_per_chunk = _fmincore_chunk_size / sc_page_size;
745         page_count = (st.st_size + sc_page_size - 1) / sc_page_size;
746         chunk_count = page_count / _fmincore_chunk_size
747         if chunk_count < 1:
748             chunk_count = 1
749         result = bytearray(page_count)
750         for ci in xrange(chunk_count):
751             pos = _fmincore_chunk_size * ci;
752             msize = min(_fmincore_chunk_size, st.st_size - pos)
753             m = mmap.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
754             _mincore(m, msize, 0, result, ci * pages_per_chunk);
755         return result
756
757
758 def parse_timestamp(epoch_str):
759     """Return the number of nanoseconds since the epoch that are described
760 by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
761 throw a ValueError that may contain additional information."""
762     ns_per = {'s' :  1000000000,
763               'ms' : 1000000,
764               'us' : 1000,
765               'ns' : 1}
766     match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
767     if not match:
768         if re.match(r'^([-+]?[0-9]+)$', epoch_str):
769             raise ValueError('must include units, i.e. 100ns, 100ms, ...')
770         raise ValueError()
771     (n, units) = match.group(1, 2)
772     if not n:
773         n = 1
774     n = int(n)
775     return n * ns_per[units]
776
777
778 def parse_num(s):
779     """Parse data size information into a float number.
780
781     Here are some examples of conversions:
782         199.2k means 203981 bytes
783         1GB means 1073741824 bytes
784         2.1 tb means 2199023255552 bytes
785     """
786     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
787     if not g:
788         raise ValueError("can't parse %r as a number" % s)
789     (val, unit) = g.groups()
790     num = float(val)
791     unit = unit.lower()
792     if unit in ['t', 'tb']:
793         mult = 1024*1024*1024*1024
794     elif unit in ['g', 'gb']:
795         mult = 1024*1024*1024
796     elif unit in ['m', 'mb']:
797         mult = 1024*1024
798     elif unit in ['k', 'kb']:
799         mult = 1024
800     elif unit in ['', 'b']:
801         mult = 1
802     else:
803         raise ValueError("invalid unit %r in number %r" % (unit, s))
804     return int(num*mult)
805
806
807 def count(l):
808     """Count the number of elements in an iterator. (consumes the iterator)"""
809     return reduce(lambda x,y: x+1, l)
810
811
812 saved_errors = []
813 def add_error(e):
814     """Append an error message to the list of saved errors.
815
816     Once processing is able to stop and output the errors, the saved errors are
817     accessible in the module variable helpers.saved_errors.
818     """
819     saved_errors.append(e)
820     log('%-70s\n' % e)
821
822
823 def clear_errors():
824     global saved_errors
825     saved_errors = []
826
827
828 def handle_ctrl_c():
829     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
830
831     The new exception handler will make sure that bup will exit without an ugly
832     stacktrace when Ctrl-C is hit.
833     """
834     oldhook = sys.excepthook
835     def newhook(exctype, value, traceback):
836         if exctype == KeyboardInterrupt:
837             log('\nInterrupted.\n')
838         else:
839             return oldhook(exctype, value, traceback)
840     sys.excepthook = newhook
841
842
843 def columnate(l, prefix):
844     """Format elements of 'l' in columns with 'prefix' leading each line.
845
846     The number of columns is determined automatically based on the string
847     lengths.
848     """
849     if not l:
850         return ""
851     l = l[:]
852     clen = max(len(s) for s in l)
853     ncols = (tty_width() - len(prefix)) / (clen + 2)
854     if ncols <= 1:
855         ncols = 1
856         clen = 0
857     cols = []
858     while len(l) % ncols:
859         l.append('')
860     rows = len(l)/ncols
861     for s in range(0, len(l), rows):
862         cols.append(l[s:s+rows])
863     out = ''
864     for row in zip(*cols):
865         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
866     return out
867
868
869 def parse_date_or_fatal(str, fatal):
870     """Parses the given date or calls Option.fatal().
871     For now we expect a string that contains a float."""
872     try:
873         date = float(str)
874     except ValueError, e:
875         raise fatal('invalid date format (should be a float): %r' % e)
876     else:
877         return date
878
879
880 def parse_excludes(options, fatal):
881     """Traverse the options and extract all excludes, or call Option.fatal()."""
882     excluded_paths = []
883
884     for flag in options:
885         (option, parameter) = flag
886         if option == '--exclude':
887             excluded_paths.append(realpath(parameter))
888         elif option == '--exclude-from':
889             try:
890                 f = open(realpath(parameter))
891             except IOError, e:
892                 raise fatal("couldn't read %s" % parameter)
893             for exclude_path in f.readlines():
894                 # FIXME: perhaps this should be rstrip('\n')
895                 exclude_path = realpath(exclude_path.strip())
896                 if exclude_path:
897                     excluded_paths.append(exclude_path)
898     return sorted(frozenset(excluded_paths))
899
900
901 def parse_rx_excludes(options, fatal):
902     """Traverse the options and extract all rx excludes, or call
903     Option.fatal()."""
904     excluded_patterns = []
905
906     for flag in options:
907         (option, parameter) = flag
908         if option == '--exclude-rx':
909             try:
910                 excluded_patterns.append(re.compile(parameter))
911             except re.error, ex:
912                 fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
913         elif option == '--exclude-rx-from':
914             try:
915                 f = open(realpath(parameter))
916             except IOError, e:
917                 raise fatal("couldn't read %s" % parameter)
918             for pattern in f.readlines():
919                 spattern = pattern.rstrip('\n')
920                 if not spattern:
921                     continue
922                 try:
923                     excluded_patterns.append(re.compile(spattern))
924                 except re.error, ex:
925                     fatal('invalid --exclude-rx pattern (%s): %s' % (spattern, ex))
926     return excluded_patterns
927
928
929 def should_rx_exclude_path(path, exclude_rxs):
930     """Return True if path matches a regular expression in exclude_rxs."""
931     for rx in exclude_rxs:
932         if rx.search(path):
933             debug1('Skipping %r: excluded by rx pattern %r.\n'
934                    % (path, rx.pattern))
935             return True
936     return False
937
938
939 # FIXME: Carefully consider the use of functions (os.path.*, etc.)
940 # that resolve against the current filesystem in the strip/graft
941 # functions for example, but elsewhere as well.  I suspect bup's not
942 # always being careful about that.  For some cases, the contents of
943 # the current filesystem should be irrelevant, and consulting it might
944 # produce the wrong result, perhaps via unintended symlink resolution,
945 # for example.
946
947 def path_components(path):
948     """Break path into a list of pairs of the form (name,
949     full_path_to_name).  Path must start with '/'.
950     Example:
951       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
952     if not path.startswith('/'):
953         raise Exception, 'path must start with "/": %s' % path
954     # Since we assume path startswith('/'), we can skip the first element.
955     result = [('', '/')]
956     norm_path = os.path.abspath(path)
957     if norm_path == '/':
958         return result
959     full_path = ''
960     for p in norm_path.split('/')[1:]:
961         full_path += '/' + p
962         result.append((p, full_path))
963     return result
964
965
966 def stripped_path_components(path, strip_prefixes):
967     """Strip any prefix in strip_prefixes from path and return a list
968     of path components where each component is (name,
969     none_or_full_fs_path_to_name).  Assume path startswith('/').
970     See thelpers.py for examples."""
971     normalized_path = os.path.abspath(path)
972     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
973     for bp in sorted_strip_prefixes:
974         normalized_bp = os.path.abspath(bp)
975         if normalized_bp == '/':
976             continue
977         if normalized_path.startswith(normalized_bp):
978             prefix = normalized_path[:len(normalized_bp)]
979             result = []
980             for p in normalized_path[len(normalized_bp):].split('/'):
981                 if p: # not root
982                     prefix += '/'
983                 prefix += p
984                 result.append((p, prefix))
985             return result
986     # Nothing to strip.
987     return path_components(path)
988
989
990 def grafted_path_components(graft_points, path):
991     # Create a result that consists of some number of faked graft
992     # directories before the graft point, followed by all of the real
993     # directories from path that are after the graft point.  Arrange
994     # for the directory at the graft point in the result to correspond
995     # to the "orig" directory in --graft orig=new.  See t/thelpers.py
996     # for some examples.
997
998     # Note that given --graft orig=new, orig and new have *nothing* to
999     # do with each other, even if some of their component names
1000     # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
1001     # equivalent to --graft /foo/bar/baz=/x/y/z, or even
1002     # /foo/bar/baz=/x.
1003
1004     # FIXME: This can't be the best solution...
1005     clean_path = os.path.abspath(path)
1006     for graft_point in graft_points:
1007         old_prefix, new_prefix = graft_point
1008         # Expand prefixes iff not absolute paths.
1009         old_prefix = os.path.normpath(old_prefix)
1010         new_prefix = os.path.normpath(new_prefix)
1011         if clean_path.startswith(old_prefix):
1012             escaped_prefix = re.escape(old_prefix)
1013             grafted_path = re.sub(r'^' + escaped_prefix, new_prefix, clean_path)
1014             # Handle /foo=/ (at least) -- which produces //whatever.
1015             grafted_path = '/' + grafted_path.lstrip('/')
1016             clean_path_components = path_components(clean_path)
1017             # Count the components that were stripped.
1018             strip_count = 0 if old_prefix == '/' else old_prefix.count('/')
1019             new_prefix_parts = new_prefix.split('/')
1020             result_prefix = grafted_path.split('/')[:new_prefix.count('/')]
1021             result = [(p, None) for p in result_prefix] \
1022                 + clean_path_components[strip_count:]
1023             # Now set the graft point name to match the end of new_prefix.
1024             graft_point = len(result_prefix)
1025             result[graft_point] = \
1026                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
1027             if new_prefix == '/': # --graft ...=/ is a special case.
1028                 return result[1:]
1029             return result
1030     return path_components(clean_path)
1031
1032
1033 Sha1 = hashlib.sha1
1034
1035
1036 _localtime = getattr(_helpers, 'localtime', None)
1037
1038 if _localtime:
1039     bup_time = namedtuple('bup_time', ['tm_year', 'tm_mon', 'tm_mday',
1040                                        'tm_hour', 'tm_min', 'tm_sec',
1041                                        'tm_wday', 'tm_yday',
1042                                        'tm_isdst', 'tm_gmtoff', 'tm_zone'])
1043
1044 # Define a localtime() that returns bup_time when possible.  Note:
1045 # this means that any helpers.localtime() results may need to be
1046 # passed through to_py_time() before being passed to python's time
1047 # module, which doesn't appear willing to ignore the extra items.
1048 if _localtime:
1049     def localtime(time):
1050         return bup_time(*_helpers.localtime(time))
1051     def utc_offset_str(t):
1052         'Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.'
1053         off = localtime(t).tm_gmtoff
1054         hrs = off / 60 / 60
1055         return "%+03d%02d" % (hrs, abs(off - (hrs * 60 * 60)))
1056     def to_py_time(x):
1057         if isinstance(x, time.struct_time):
1058             return x
1059         return time.struct_time(x[:9])
1060 else:
1061     localtime = time.localtime
1062     def utc_offset_str(t):
1063         return time.strftime('%z', localtime(t))
1064     def to_py_time(x):
1065         return x