]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
5da2b1a2844aff372fb136332b221416e6a3c27f
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 from collections import namedtuple
4 from ctypes import sizeof, c_void_p
5 from os import environ
6 from contextlib import contextmanager
7 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
8 import hashlib, heapq, math, operator, time, grp, tempfile
9
10 from bup import _helpers
11
12 sc_page_size = os.sysconf('SC_PAGE_SIZE')
13 assert(sc_page_size > 0)
14
15 sc_arg_max = os.sysconf('SC_ARG_MAX')
16 if sc_arg_max == -1:  # "no definite limit" - let's choose 2M
17     sc_arg_max = 2 * 1024 * 1024
18
19 # This function should really be in helpers, not in bup.options.  But we
20 # want options.py to be standalone so people can include it in other projects.
21 from bup.options import _tty_width
22 tty_width = _tty_width
23
24
25 def atoi(s):
26     """Convert the string 's' to an integer. Return 0 if s is not a number."""
27     try:
28         return int(s or '0')
29     except ValueError:
30         return 0
31
32
33 def atof(s):
34     """Convert the string 's' to a float. Return 0 if s is not a number."""
35     try:
36         return float(s or '0')
37     except ValueError:
38         return 0
39
40
41 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
42
43
44 try:
45     _fdatasync = os.fdatasync
46 except AttributeError:
47     _fdatasync = os.fsync
48
49 if sys.platform.startswith('darwin'):
50     # Apparently os.fsync on OS X doesn't guarantee to sync all the way down
51     import fcntl
52     def fdatasync(fd):
53         try:
54             return fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
55         except IOError as e:
56             # Fallback for file systems (SMB) that do not support F_FULLFSYNC
57             if e.errno == errno.ENOTSUP:
58                 return _fdatasync(fd)
59             else:
60                 raise
61 else:
62     fdatasync = _fdatasync
63
64
65 # Write (blockingly) to sockets that may or may not be in blocking mode.
66 # We need this because our stderr is sometimes eaten by subprocesses
67 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
68 # leading to race conditions.  Ick.  We'll do it the hard way.
69 def _hard_write(fd, buf):
70     while buf:
71         (r,w,x) = select.select([], [fd], [], None)
72         if not w:
73             raise IOError('select(fd) returned without being writable')
74         try:
75             sz = os.write(fd, buf)
76         except OSError as e:
77             if e.errno != errno.EAGAIN:
78                 raise
79         assert(sz >= 0)
80         buf = buf[sz:]
81
82
83 _last_prog = 0
84 def log(s):
85     """Print a log message to stderr."""
86     global _last_prog
87     sys.stdout.flush()
88     _hard_write(sys.stderr.fileno(), s)
89     _last_prog = 0
90
91
92 def debug1(s):
93     if buglvl >= 1:
94         log(s)
95
96
97 def debug2(s):
98     if buglvl >= 2:
99         log(s)
100
101
102 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
103 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
104 _last_progress = ''
105 def progress(s):
106     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
107     global _last_progress
108     if istty2:
109         log(s)
110         _last_progress = s
111
112
113 def qprogress(s):
114     """Calls progress() only if we haven't printed progress in a while.
115     
116     This avoids overloading the stderr buffer with excess junk.
117     """
118     global _last_prog
119     now = time.time()
120     if now - _last_prog > 0.1:
121         progress(s)
122         _last_prog = now
123
124
125 def reprogress():
126     """Calls progress() to redisplay the most recent progress message.
127
128     Useful after you've printed some other message that wipes out the
129     progress line.
130     """
131     if _last_progress and _last_progress.endswith('\r'):
132         progress(_last_progress)
133
134
135 def mkdirp(d, mode=None):
136     """Recursively create directories on path 'd'.
137
138     Unlike os.makedirs(), it doesn't raise an exception if the last element of
139     the path already exists.
140     """
141     try:
142         if mode:
143             os.makedirs(d, mode)
144         else:
145             os.makedirs(d)
146     except OSError as e:
147         if e.errno == errno.EEXIST:
148             pass
149         else:
150             raise
151
152
153 _unspecified_next_default = object()
154
155 def _fallback_next(it, default=_unspecified_next_default):
156     """Retrieve the next item from the iterator by calling its
157     next() method. If default is given, it is returned if the
158     iterator is exhausted, otherwise StopIteration is raised."""
159
160     if default is _unspecified_next_default:
161         return it.next()
162     else:
163         try:
164             return it.next()
165         except StopIteration:
166             return default
167
168 if sys.version_info < (2, 6):
169     next =  _fallback_next
170
171
172 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
173     if key:
174         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
175     else:
176         samekey = operator.eq
177     count = 0
178     total = sum(len(it) for it in iters)
179     iters = (iter(it) for it in iters)
180     heap = ((next(it, None),it) for it in iters)
181     heap = [(e,it) for e,it in heap if e]
182
183     heapq.heapify(heap)
184     pe = None
185     while heap:
186         if not count % pfreq:
187             pfunc(count, total)
188         e, it = heap[0]
189         if not samekey(e, pe):
190             pe = e
191             yield e
192         count += 1
193         try:
194             e = it.next() # Don't use next() function, it's too expensive
195         except StopIteration:
196             heapq.heappop(heap) # remove current
197         else:
198             heapq.heapreplace(heap, (e, it)) # shift current to new location
199     pfinal(count, total)
200
201
202 def unlink(f):
203     """Delete a file at path 'f' if it currently exists.
204
205     Unlike os.unlink(), does not throw an exception if the file didn't already
206     exist.
207     """
208     try:
209         os.unlink(f)
210     except OSError as e:
211         if e.errno != errno.ENOENT:
212             raise
213
214
215 def readpipe(argv, preexec_fn=None):
216     """Run a subprocess and return its output."""
217     p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn)
218     out, err = p.communicate()
219     if p.returncode != 0:
220         raise Exception('subprocess %r failed with status %d'
221                         % (' '.join(argv), p.returncode))
222     return out
223
224
225 def _argmax_base(command):
226     base_size = 2048
227     for c in command:
228         base_size += len(command) + 1
229     for k, v in environ.iteritems():
230         base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
231     return base_size
232
233
234 def _argmax_args_size(args):
235     return sum(len(x) + 1 + sizeof(c_void_p) for x in args)
236
237
238 def batchpipe(command, args, preexec_fn=None, arg_max=sc_arg_max):
239     """If args is not empty, yield the output produced by calling the
240 command list with args as a sequence of strings (It may be necessary
241 to return multiple strings in order to respect ARG_MAX)."""
242     # The optional arg_max arg is a workaround for an issue with the
243     # current wvtest behavior.
244     base_size = _argmax_base(command)
245     while args:
246         room = arg_max - base_size
247         i = 0
248         while i < len(args):
249             next_size = _argmax_args_size(args[i:i+1])
250             if room - next_size < 0:
251                 break
252             room -= next_size
253             i += 1
254         sub_args = args[:i]
255         args = args[i:]
256         assert(len(sub_args))
257         yield readpipe(command + sub_args, preexec_fn=preexec_fn)
258
259
260 def resolve_parent(p):
261     """Return the absolute path of a file without following any final symlink.
262
263     Behaves like os.path.realpath, but doesn't follow a symlink for the last
264     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
265     will follow symlinks in p's directory)
266     """
267     try:
268         st = os.lstat(p)
269     except OSError:
270         st = None
271     if st and stat.S_ISLNK(st.st_mode):
272         (dir, name) = os.path.split(p)
273         dir = os.path.realpath(dir)
274         out = os.path.join(dir, name)
275     else:
276         out = os.path.realpath(p)
277     #log('realpathing:%r,%r\n' % (p, out))
278     return out
279
280
281 def detect_fakeroot():
282     "Return True if we appear to be running under fakeroot."
283     return os.getenv("FAKEROOTKEY") != None
284
285
286 def is_superuser():
287     if sys.platform.startswith('cygwin'):
288         import ctypes
289         return ctypes.cdll.shell32.IsUserAnAdmin()
290     else:
291         return os.geteuid() == 0
292
293
294 def _cache_key_value(get_value, key, cache):
295     """Return (value, was_cached).  If there is a value in the cache
296     for key, use that, otherwise, call get_value(key) which should
297     throw a KeyError if there is no value -- in which case the cached
298     and returned value will be None.
299     """
300     try: # Do we already have it (or know there wasn't one)?
301         value = cache[key]
302         return value, True
303     except KeyError:
304         pass
305     value = None
306     try:
307         cache[key] = value = get_value(key)
308     except KeyError:
309         cache[key] = None
310     return value, False
311
312
313 _uid_to_pwd_cache = {}
314 _name_to_pwd_cache = {}
315
316 def pwd_from_uid(uid):
317     """Return password database entry for uid (may be a cached value).
318     Return None if no entry is found.
319     """
320     global _uid_to_pwd_cache, _name_to_pwd_cache
321     entry, cached = _cache_key_value(pwd.getpwuid, uid, _uid_to_pwd_cache)
322     if entry and not cached:
323         _name_to_pwd_cache[entry.pw_name] = entry
324     return entry
325
326
327 def pwd_from_name(name):
328     """Return password database entry for name (may be a cached value).
329     Return None if no entry is found.
330     """
331     global _uid_to_pwd_cache, _name_to_pwd_cache
332     entry, cached = _cache_key_value(pwd.getpwnam, name, _name_to_pwd_cache)
333     if entry and not cached:
334         _uid_to_pwd_cache[entry.pw_uid] = entry
335     return entry
336
337
338 _gid_to_grp_cache = {}
339 _name_to_grp_cache = {}
340
341 def grp_from_gid(gid):
342     """Return password database entry for gid (may be a cached value).
343     Return None if no entry is found.
344     """
345     global _gid_to_grp_cache, _name_to_grp_cache
346     entry, cached = _cache_key_value(grp.getgrgid, gid, _gid_to_grp_cache)
347     if entry and not cached:
348         _name_to_grp_cache[entry.gr_name] = entry
349     return entry
350
351
352 def grp_from_name(name):
353     """Return password database entry for name (may be a cached value).
354     Return None if no entry is found.
355     """
356     global _gid_to_grp_cache, _name_to_grp_cache
357     entry, cached = _cache_key_value(grp.getgrnam, name, _name_to_grp_cache)
358     if entry and not cached:
359         _gid_to_grp_cache[entry.gr_gid] = entry
360     return entry
361
362
363 _username = None
364 def username():
365     """Get the user's login name."""
366     global _username
367     if not _username:
368         uid = os.getuid()
369         _username = pwd_from_uid(uid)[0] or 'user%d' % uid
370     return _username
371
372
373 _userfullname = None
374 def userfullname():
375     """Get the user's full name."""
376     global _userfullname
377     if not _userfullname:
378         uid = os.getuid()
379         entry = pwd_from_uid(uid)
380         if entry:
381             _userfullname = entry[4].split(',')[0] or entry[0]
382         if not _userfullname:
383             _userfullname = 'user%d' % uid
384     return _userfullname
385
386
387 _hostname = None
388 def hostname():
389     """Get the FQDN of this machine."""
390     global _hostname
391     if not _hostname:
392         _hostname = socket.getfqdn()
393     return _hostname
394
395
396 _resource_path = None
397 def resource_path(subdir=''):
398     global _resource_path
399     if not _resource_path:
400         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
401     return os.path.join(_resource_path, subdir)
402
403 def format_filesize(size):
404     unit = 1024.0
405     size = float(size)
406     if size < unit:
407         return "%d" % (size)
408     exponent = int(math.log(size) / math.log(unit))
409     size_prefix = "KMGTPE"[exponent - 1]
410     return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
411
412
413 class NotOk(Exception):
414     pass
415
416
417 class BaseConn:
418     def __init__(self, outp):
419         self.outp = outp
420
421     def close(self):
422         while self._read(65536): pass
423
424     def read(self, size):
425         """Read 'size' bytes from input stream."""
426         self.outp.flush()
427         return self._read(size)
428
429     def readline(self):
430         """Read from input stream until a newline is found."""
431         self.outp.flush()
432         return self._readline()
433
434     def write(self, data):
435         """Write 'data' to output stream."""
436         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
437         self.outp.write(data)
438
439     def has_input(self):
440         """Return true if input stream is readable."""
441         raise NotImplemented("Subclasses must implement has_input")
442
443     def ok(self):
444         """Indicate end of output from last sent command."""
445         self.write('\nok\n')
446
447     def error(self, s):
448         """Indicate server error to the client."""
449         s = re.sub(r'\s+', ' ', str(s))
450         self.write('\nerror %s\n' % s)
451
452     def _check_ok(self, onempty):
453         self.outp.flush()
454         rl = ''
455         for rl in linereader(self):
456             #log('%d got line: %r\n' % (os.getpid(), rl))
457             if not rl:  # empty line
458                 continue
459             elif rl == 'ok':
460                 return None
461             elif rl.startswith('error '):
462                 #log('client: error: %s\n' % rl[6:])
463                 return NotOk(rl[6:])
464             else:
465                 onempty(rl)
466         raise Exception('server exited unexpectedly; see errors above')
467
468     def drain_and_check_ok(self):
469         """Remove all data for the current command from input stream."""
470         def onempty(rl):
471             pass
472         return self._check_ok(onempty)
473
474     def check_ok(self):
475         """Verify that server action completed successfully."""
476         def onempty(rl):
477             raise Exception('expected "ok", got %r' % rl)
478         return self._check_ok(onempty)
479
480
481 class Conn(BaseConn):
482     def __init__(self, inp, outp):
483         BaseConn.__init__(self, outp)
484         self.inp = inp
485
486     def _read(self, size):
487         return self.inp.read(size)
488
489     def _readline(self):
490         return self.inp.readline()
491
492     def has_input(self):
493         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
494         if rl:
495             assert(rl[0] == self.inp.fileno())
496             return True
497         else:
498             return None
499
500
501 def checked_reader(fd, n):
502     while n > 0:
503         rl, _, _ = select.select([fd], [], [])
504         assert(rl[0] == fd)
505         buf = os.read(fd, n)
506         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
507         yield buf
508         n -= len(buf)
509
510
511 MAX_PACKET = 128 * 1024
512 def mux(p, outfd, outr, errr):
513     try:
514         fds = [outr, errr]
515         while p.poll() is None:
516             rl, _, _ = select.select(fds, [], [])
517             for fd in rl:
518                 if fd == outr:
519                     buf = os.read(outr, MAX_PACKET)
520                     if not buf: break
521                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
522                 elif fd == errr:
523                     buf = os.read(errr, 1024)
524                     if not buf: break
525                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
526     finally:
527         os.write(outfd, struct.pack('!IB', 0, 3))
528
529
530 class DemuxConn(BaseConn):
531     """A helper class for bup's client-server protocol."""
532     def __init__(self, infd, outp):
533         BaseConn.__init__(self, outp)
534         # Anything that comes through before the sync string was not
535         # multiplexed and can be assumed to be debug/log before mux init.
536         tail = ''
537         while tail != 'BUPMUX':
538             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
539             if not b:
540                 raise IOError('demux: unexpected EOF during initialization')
541             tail += b
542             sys.stderr.write(tail[:-6])  # pre-mux log messages
543             tail = tail[-6:]
544         self.infd = infd
545         self.reader = None
546         self.buf = None
547         self.closed = False
548
549     def write(self, data):
550         self._load_buf(0)
551         BaseConn.write(self, data)
552
553     def _next_packet(self, timeout):
554         if self.closed: return False
555         rl, wl, xl = select.select([self.infd], [], [], timeout)
556         if not rl: return False
557         assert(rl[0] == self.infd)
558         ns = ''.join(checked_reader(self.infd, 5))
559         n, fdw = struct.unpack('!IB', ns)
560         assert(n <= MAX_PACKET)
561         if fdw == 1:
562             self.reader = checked_reader(self.infd, n)
563         elif fdw == 2:
564             for buf in checked_reader(self.infd, n):
565                 sys.stderr.write(buf)
566         elif fdw == 3:
567             self.closed = True
568             debug2("DemuxConn: marked closed\n")
569         return True
570
571     def _load_buf(self, timeout):
572         if self.buf is not None:
573             return True
574         while not self.closed:
575             while not self.reader:
576                 if not self._next_packet(timeout):
577                     return False
578             try:
579                 self.buf = self.reader.next()
580                 return True
581             except StopIteration:
582                 self.reader = None
583         return False
584
585     def _read_parts(self, ix_fn):
586         while self._load_buf(None):
587             assert(self.buf is not None)
588             i = ix_fn(self.buf)
589             if i is None or i == len(self.buf):
590                 yv = self.buf
591                 self.buf = None
592             else:
593                 yv = self.buf[:i]
594                 self.buf = self.buf[i:]
595             yield yv
596             if i is not None:
597                 break
598
599     def _readline(self):
600         def find_eol(buf):
601             try:
602                 return buf.index('\n')+1
603             except ValueError:
604                 return None
605         return ''.join(self._read_parts(find_eol))
606
607     def _read(self, size):
608         csize = [size]
609         def until_size(buf): # Closes on csize
610             if len(buf) < csize[0]:
611                 csize[0] -= len(buf)
612                 return None
613             else:
614                 return csize[0]
615         return ''.join(self._read_parts(until_size))
616
617     def has_input(self):
618         return self._load_buf(0)
619
620
621 def linereader(f):
622     """Generate a list of input lines from 'f' without terminating newlines."""
623     while 1:
624         line = f.readline()
625         if not line:
626             break
627         yield line[:-1]
628
629
630 def chunkyreader(f, count = None):
631     """Generate a list of chunks of data read from 'f'.
632
633     If count is None, read until EOF is reached.
634
635     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
636     reached while reading, raise IOError.
637     """
638     if count != None:
639         while count > 0:
640             b = f.read(min(count, 65536))
641             if not b:
642                 raise IOError('EOF with %d bytes remaining' % count)
643             yield b
644             count -= len(b)
645     else:
646         while 1:
647             b = f.read(65536)
648             if not b: break
649             yield b
650
651
652 @contextmanager
653 def atomically_replaced_file(name, mode='w', buffering=-1):
654     """Yield a file that will be atomically renamed name when leaving the block.
655
656     This contextmanager yields an open file object that is backed by a
657     temporary file which will be renamed (atomically) to the target
658     name if everything succeeds.
659
660     The mode and buffering arguments are handled exactly as with open,
661     and the yielded file will have very restrictive permissions, as
662     per mkstemp.
663
664     E.g.::
665
666         with atomically_replaced_file('foo.txt', 'w') as f:
667             f.write('hello jack.')
668
669     """
670
671     (ffd, tempname) = tempfile.mkstemp(dir=os.path.dirname(name),
672                                        text=('b' not in mode))
673     try:
674         try:
675             f = os.fdopen(ffd, mode, buffering)
676         except:
677             os.close(ffd)
678             raise
679         try:
680             yield f
681         finally:
682             f.close()
683         os.rename(tempname, name)
684     finally:
685         unlink(tempname)  # nonexistant file is ignored
686
687
688 def slashappend(s):
689     """Append "/" to 's' if it doesn't aleady end in "/"."""
690     if s and not s.endswith('/'):
691         return s + '/'
692     else:
693         return s
694
695
696 def _mmap_do(f, sz, flags, prot, close):
697     if not sz:
698         st = os.fstat(f.fileno())
699         sz = st.st_size
700     if not sz:
701         # trying to open a zero-length map gives an error, but an empty
702         # string has all the same behaviour of a zero-length map, ie. it has
703         # no elements :)
704         return ''
705     map = mmap.mmap(f.fileno(), sz, flags, prot)
706     if close:
707         f.close()  # map will persist beyond file close
708     return map
709
710
711 def mmap_read(f, sz = 0, close=True):
712     """Create a read-only memory mapped region on file 'f'.
713     If sz is 0, the region will cover the entire file.
714     """
715     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
716
717
718 def mmap_readwrite(f, sz = 0, close=True):
719     """Create a read-write memory mapped region on file 'f'.
720     If sz is 0, the region will cover the entire file.
721     """
722     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
723                     close)
724
725
726 def mmap_readwrite_private(f, sz = 0, close=True):
727     """Create a read-write memory mapped region on file 'f'.
728     If sz is 0, the region will cover the entire file.
729     The map is private, which means the changes are never flushed back to the
730     file.
731     """
732     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
733                     close)
734
735
736 _mincore = getattr(_helpers, 'mincore', None)
737 if _mincore:
738     # ./configure ensures that we're on Linux if MINCORE_INCORE isn't defined.
739     MINCORE_INCORE = getattr(_helpers, 'MINCORE_INCORE', 1)
740
741     _fmincore_chunk_size = None
742     def _set_fmincore_chunk_size():
743         global _fmincore_chunk_size
744         pref_chunk_size = 64 * 1024 * 1024
745         chunk_size = sc_page_size
746         if (sc_page_size < pref_chunk_size):
747             chunk_size = sc_page_size * (pref_chunk_size / sc_page_size)
748         _fmincore_chunk_size = chunk_size
749
750     def fmincore(fd):
751         """Return the mincore() data for fd as a bytearray whose values can be
752         tested via MINCORE_INCORE, or None if fd does not fully
753         support the operation."""
754         st = os.fstat(fd)
755         if (st.st_size == 0):
756             return bytearray(0)
757         if not _fmincore_chunk_size:
758             _set_fmincore_chunk_size()
759         pages_per_chunk = _fmincore_chunk_size / sc_page_size;
760         page_count = (st.st_size + sc_page_size - 1) / sc_page_size;
761         chunk_count = page_count / _fmincore_chunk_size
762         if chunk_count < 1:
763             chunk_count = 1
764         result = bytearray(page_count)
765         for ci in xrange(chunk_count):
766             pos = _fmincore_chunk_size * ci;
767             msize = min(_fmincore_chunk_size, st.st_size - pos)
768             try:
769                 m = mmap.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
770             except mmap.error as ex:
771                 if ex.errno == errno.EINVAL or ex.errno == errno.ENODEV:
772                     # Perhaps the file was a pipe, i.e. "... | bup split ..."
773                     return None
774                 raise ex
775             _mincore(m, msize, 0, result, ci * pages_per_chunk);
776         return result
777
778
779 def parse_timestamp(epoch_str):
780     """Return the number of nanoseconds since the epoch that are described
781 by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
782 throw a ValueError that may contain additional information."""
783     ns_per = {'s' :  1000000000,
784               'ms' : 1000000,
785               'us' : 1000,
786               'ns' : 1}
787     match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
788     if not match:
789         if re.match(r'^([-+]?[0-9]+)$', epoch_str):
790             raise ValueError('must include units, i.e. 100ns, 100ms, ...')
791         raise ValueError()
792     (n, units) = match.group(1, 2)
793     if not n:
794         n = 1
795     n = int(n)
796     return n * ns_per[units]
797
798
799 def parse_num(s):
800     """Parse data size information into a float number.
801
802     Here are some examples of conversions:
803         199.2k means 203981 bytes
804         1GB means 1073741824 bytes
805         2.1 tb means 2199023255552 bytes
806     """
807     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
808     if not g:
809         raise ValueError("can't parse %r as a number" % s)
810     (val, unit) = g.groups()
811     num = float(val)
812     unit = unit.lower()
813     if unit in ['t', 'tb']:
814         mult = 1024*1024*1024*1024
815     elif unit in ['g', 'gb']:
816         mult = 1024*1024*1024
817     elif unit in ['m', 'mb']:
818         mult = 1024*1024
819     elif unit in ['k', 'kb']:
820         mult = 1024
821     elif unit in ['', 'b']:
822         mult = 1
823     else:
824         raise ValueError("invalid unit %r in number %r" % (unit, s))
825     return int(num*mult)
826
827
828 def count(l):
829     """Count the number of elements in an iterator. (consumes the iterator)"""
830     return reduce(lambda x,y: x+1, l)
831
832
833 saved_errors = []
834 def add_error(e):
835     """Append an error message to the list of saved errors.
836
837     Once processing is able to stop and output the errors, the saved errors are
838     accessible in the module variable helpers.saved_errors.
839     """
840     saved_errors.append(e)
841     log('%-70s\n' % e)
842
843
844 def clear_errors():
845     global saved_errors
846     saved_errors = []
847
848
849 def handle_ctrl_c():
850     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
851
852     The new exception handler will make sure that bup will exit without an ugly
853     stacktrace when Ctrl-C is hit.
854     """
855     oldhook = sys.excepthook
856     def newhook(exctype, value, traceback):
857         if exctype == KeyboardInterrupt:
858             log('\nInterrupted.\n')
859         else:
860             return oldhook(exctype, value, traceback)
861     sys.excepthook = newhook
862
863
864 def columnate(l, prefix):
865     """Format elements of 'l' in columns with 'prefix' leading each line.
866
867     The number of columns is determined automatically based on the string
868     lengths.
869     """
870     if not l:
871         return ""
872     l = l[:]
873     clen = max(len(s) for s in l)
874     ncols = (tty_width() - len(prefix)) / (clen + 2)
875     if ncols <= 1:
876         ncols = 1
877         clen = 0
878     cols = []
879     while len(l) % ncols:
880         l.append('')
881     rows = len(l)/ncols
882     for s in range(0, len(l), rows):
883         cols.append(l[s:s+rows])
884     out = ''
885     for row in zip(*cols):
886         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
887     return out
888
889
890 def parse_date_or_fatal(str, fatal):
891     """Parses the given date or calls Option.fatal().
892     For now we expect a string that contains a float."""
893     try:
894         date = float(str)
895     except ValueError as e:
896         raise fatal('invalid date format (should be a float): %r' % e)
897     else:
898         return date
899
900
901 def parse_excludes(options, fatal):
902     """Traverse the options and extract all excludes, or call Option.fatal()."""
903     excluded_paths = []
904
905     for flag in options:
906         (option, parameter) = flag
907         if option == '--exclude':
908             excluded_paths.append(resolve_parent(parameter))
909         elif option == '--exclude-from':
910             try:
911                 f = open(resolve_parent(parameter))
912             except IOError as e:
913                 raise fatal("couldn't read %s" % parameter)
914             for exclude_path in f.readlines():
915                 # FIXME: perhaps this should be rstrip('\n')
916                 exclude_path = resolve_parent(exclude_path.strip())
917                 if exclude_path:
918                     excluded_paths.append(exclude_path)
919     return sorted(frozenset(excluded_paths))
920
921
922 def parse_rx_excludes(options, fatal):
923     """Traverse the options and extract all rx excludes, or call
924     Option.fatal()."""
925     excluded_patterns = []
926
927     for flag in options:
928         (option, parameter) = flag
929         if option == '--exclude-rx':
930             try:
931                 excluded_patterns.append(re.compile(parameter))
932             except re.error as ex:
933                 fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
934         elif option == '--exclude-rx-from':
935             try:
936                 f = open(resolve_parent(parameter))
937             except IOError as e:
938                 raise fatal("couldn't read %s" % parameter)
939             for pattern in f.readlines():
940                 spattern = pattern.rstrip('\n')
941                 if not spattern:
942                     continue
943                 try:
944                     excluded_patterns.append(re.compile(spattern))
945                 except re.error as ex:
946                     fatal('invalid --exclude-rx pattern (%s): %s' % (spattern, ex))
947     return excluded_patterns
948
949
950 def should_rx_exclude_path(path, exclude_rxs):
951     """Return True if path matches a regular expression in exclude_rxs."""
952     for rx in exclude_rxs:
953         if rx.search(path):
954             debug1('Skipping %r: excluded by rx pattern %r.\n'
955                    % (path, rx.pattern))
956             return True
957     return False
958
959
960 # FIXME: Carefully consider the use of functions (os.path.*, etc.)
961 # that resolve against the current filesystem in the strip/graft
962 # functions for example, but elsewhere as well.  I suspect bup's not
963 # always being careful about that.  For some cases, the contents of
964 # the current filesystem should be irrelevant, and consulting it might
965 # produce the wrong result, perhaps via unintended symlink resolution,
966 # for example.
967
968 def path_components(path):
969     """Break path into a list of pairs of the form (name,
970     full_path_to_name).  Path must start with '/'.
971     Example:
972       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
973     if not path.startswith('/'):
974         raise Exception, 'path must start with "/": %s' % path
975     # Since we assume path startswith('/'), we can skip the first element.
976     result = [('', '/')]
977     norm_path = os.path.abspath(path)
978     if norm_path == '/':
979         return result
980     full_path = ''
981     for p in norm_path.split('/')[1:]:
982         full_path += '/' + p
983         result.append((p, full_path))
984     return result
985
986
987 def stripped_path_components(path, strip_prefixes):
988     """Strip any prefix in strip_prefixes from path and return a list
989     of path components where each component is (name,
990     none_or_full_fs_path_to_name).  Assume path startswith('/').
991     See thelpers.py for examples."""
992     normalized_path = os.path.abspath(path)
993     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
994     for bp in sorted_strip_prefixes:
995         normalized_bp = os.path.abspath(bp)
996         if normalized_bp == '/':
997             continue
998         if normalized_path.startswith(normalized_bp):
999             prefix = normalized_path[:len(normalized_bp)]
1000             result = []
1001             for p in normalized_path[len(normalized_bp):].split('/'):
1002                 if p: # not root
1003                     prefix += '/'
1004                 prefix += p
1005                 result.append((p, prefix))
1006             return result
1007     # Nothing to strip.
1008     return path_components(path)
1009
1010
1011 def grafted_path_components(graft_points, path):
1012     # Create a result that consists of some number of faked graft
1013     # directories before the graft point, followed by all of the real
1014     # directories from path that are after the graft point.  Arrange
1015     # for the directory at the graft point in the result to correspond
1016     # to the "orig" directory in --graft orig=new.  See t/thelpers.py
1017     # for some examples.
1018
1019     # Note that given --graft orig=new, orig and new have *nothing* to
1020     # do with each other, even if some of their component names
1021     # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
1022     # equivalent to --graft /foo/bar/baz=/x/y/z, or even
1023     # /foo/bar/baz=/x.
1024
1025     # FIXME: This can't be the best solution...
1026     clean_path = os.path.abspath(path)
1027     for graft_point in graft_points:
1028         old_prefix, new_prefix = graft_point
1029         # Expand prefixes iff not absolute paths.
1030         old_prefix = os.path.normpath(old_prefix)
1031         new_prefix = os.path.normpath(new_prefix)
1032         if clean_path.startswith(old_prefix):
1033             escaped_prefix = re.escape(old_prefix)
1034             grafted_path = re.sub(r'^' + escaped_prefix, new_prefix, clean_path)
1035             # Handle /foo=/ (at least) -- which produces //whatever.
1036             grafted_path = '/' + grafted_path.lstrip('/')
1037             clean_path_components = path_components(clean_path)
1038             # Count the components that were stripped.
1039             strip_count = 0 if old_prefix == '/' else old_prefix.count('/')
1040             new_prefix_parts = new_prefix.split('/')
1041             result_prefix = grafted_path.split('/')[:new_prefix.count('/')]
1042             result = [(p, None) for p in result_prefix] \
1043                 + clean_path_components[strip_count:]
1044             # Now set the graft point name to match the end of new_prefix.
1045             graft_point = len(result_prefix)
1046             result[graft_point] = \
1047                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
1048             if new_prefix == '/': # --graft ...=/ is a special case.
1049                 return result[1:]
1050             return result
1051     return path_components(clean_path)
1052
1053
1054 Sha1 = hashlib.sha1
1055
1056
1057 _localtime = getattr(_helpers, 'localtime', None)
1058
1059 if _localtime:
1060     bup_time = namedtuple('bup_time', ['tm_year', 'tm_mon', 'tm_mday',
1061                                        'tm_hour', 'tm_min', 'tm_sec',
1062                                        'tm_wday', 'tm_yday',
1063                                        'tm_isdst', 'tm_gmtoff', 'tm_zone'])
1064
1065 # Define a localtime() that returns bup_time when possible.  Note:
1066 # this means that any helpers.localtime() results may need to be
1067 # passed through to_py_time() before being passed to python's time
1068 # module, which doesn't appear willing to ignore the extra items.
1069 if _localtime:
1070     def localtime(time):
1071         return bup_time(*_helpers.localtime(time))
1072     def utc_offset_str(t):
1073         """Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.
1074         If the current UTC offset does not represent an integer number
1075         of minutes, the fractional component will be truncated."""
1076         off = localtime(t).tm_gmtoff
1077         # Note: // doesn't truncate like C for negative values, it rounds down.
1078         offmin = abs(off) // 60
1079         m = offmin % 60
1080         h = (offmin - m) // 60
1081         return "%+03d%02d" % (-h if off < 0 else h, m)
1082     def to_py_time(x):
1083         if isinstance(x, time.struct_time):
1084             return x
1085         return time.struct_time(x[:9])
1086 else:
1087     localtime = time.localtime
1088     def utc_offset_str(t):
1089         return time.strftime('%z', localtime(t))
1090     def to_py_time(x):
1091         return x
1092
1093
1094 _some_invalid_save_parts_rx = re.compile(r'[[ ~^:?*\\]|\.\.|//|@{')
1095
1096 def valid_save_name(name):
1097     # Enforce a superset of the restrictions in git-check-ref-format(1)
1098     if name == '@' \
1099        or name.startswith('/') or name.endswith('/') \
1100        or name.endswith('.'):
1101         return False
1102     if _some_invalid_save_parts_rx.search(name):
1103         return False
1104     for c in name:
1105         if ord(c) < 0x20 or ord(c) == 0x7f:
1106             return False
1107     for part in name.split('/'):
1108         if part.startswith('.') or part.endswith('.lock'):
1109             return False
1110     return True