]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
84b1b978682f375c6bbeaeaa4b7fbde53380f51e
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 from __future__ import absolute_import, division
4 from collections import namedtuple
5 from contextlib import contextmanager
6 from ctypes import sizeof, c_void_p
7 from math import floor
8 from os import environ
9 from pipes import quote
10 from subprocess import PIPE, Popen
11 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
12 import hashlib, heapq, math, operator, time, grp, tempfile
13
14 from bup import _helpers
15 from bup import compat
16 from bup.compat import byte_int
17 from bup.io import path_msg
18 # This function should really be in helpers, not in bup.options.  But we
19 # want options.py to be standalone so people can include it in other projects.
20 from bup.options import _tty_width as tty_width
21
22
23 class Nonlocal:
24     """Helper to deal with Python scoping issues"""
25     pass
26
27
28 sc_page_size = os.sysconf('SC_PAGE_SIZE')
29 assert(sc_page_size > 0)
30
31 sc_arg_max = os.sysconf('SC_ARG_MAX')
32 if sc_arg_max == -1:  # "no definite limit" - let's choose 2M
33     sc_arg_max = 2 * 1024 * 1024
34
35 def last(iterable):
36     result = None
37     for result in iterable:
38         pass
39     return result
40
41
42 def atoi(s):
43     """Convert s (ascii bytes) to an integer. Return 0 if s is not a number."""
44     try:
45         return int(s or b'0')
46     except ValueError:
47         return 0
48
49
50 def atof(s):
51     """Convert s (ascii bytes) to a float. Return 0 if s is not a number."""
52     try:
53         return float(s or b'0')
54     except ValueError:
55         return 0
56
57
58 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
59
60
61 try:
62     _fdatasync = os.fdatasync
63 except AttributeError:
64     _fdatasync = os.fsync
65
66 if sys.platform.startswith('darwin'):
67     # Apparently os.fsync on OS X doesn't guarantee to sync all the way down
68     import fcntl
69     def fdatasync(fd):
70         try:
71             return fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
72         except IOError as e:
73             # Fallback for file systems (SMB) that do not support F_FULLFSYNC
74             if e.errno == errno.ENOTSUP:
75                 return _fdatasync(fd)
76             else:
77                 raise
78 else:
79     fdatasync = _fdatasync
80
81
82 def partition(predicate, stream):
83     """Returns (leading_matches_it, rest_it), where leading_matches_it
84     must be completely exhausted before traversing rest_it.
85
86     """
87     stream = iter(stream)
88     ns = Nonlocal()
89     ns.first_nonmatch = None
90     def leading_matches():
91         for x in stream:
92             if predicate(x):
93                 yield x
94             else:
95                 ns.first_nonmatch = (x,)
96                 break
97     def rest():
98         if ns.first_nonmatch:
99             yield ns.first_nonmatch[0]
100             for x in stream:
101                 yield x
102     return (leading_matches(), rest())
103
104
105 def merge_dict(*xs):
106     result = {}
107     for x in xs:
108         result.update(x)
109     return result
110
111
112 def lines_until_sentinel(f, sentinel, ex_type):
113     # sentinel must end with \n and must contain only one \n
114     while True:
115         line = f.readline()
116         if not (line and line.endswith('\n')):
117             raise ex_type('Hit EOF while reading line')
118         if line == sentinel:
119             return
120         yield line
121
122
123 def stat_if_exists(path):
124     try:
125         return os.stat(path)
126     except OSError as e:
127         if e.errno != errno.ENOENT:
128             raise
129     return None
130
131
132 # Write (blockingly) to sockets that may or may not be in blocking mode.
133 # We need this because our stderr is sometimes eaten by subprocesses
134 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
135 # leading to race conditions.  Ick.  We'll do it the hard way.
136 def _hard_write(fd, buf):
137     while buf:
138         (r,w,x) = select.select([], [fd], [], None)
139         if not w:
140             raise IOError('select(fd) returned without being writable')
141         try:
142             sz = os.write(fd, buf)
143         except OSError as e:
144             if e.errno != errno.EAGAIN:
145                 raise
146         assert(sz >= 0)
147         buf = buf[sz:]
148
149
150 _last_prog = 0
151 def log(s):
152     """Print a log message to stderr."""
153     global _last_prog
154     sys.stdout.flush()
155     _hard_write(sys.stderr.fileno(), s if isinstance(s, bytes) else s.encode())
156     _last_prog = 0
157
158
159 def debug1(s):
160     if buglvl >= 1:
161         log(s)
162
163
164 def debug2(s):
165     if buglvl >= 2:
166         log(s)
167
168
169 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
170 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
171 _last_progress = ''
172 def progress(s):
173     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
174     global _last_progress
175     if istty2:
176         log(s)
177         _last_progress = s
178
179
180 def qprogress(s):
181     """Calls progress() only if we haven't printed progress in a while.
182     
183     This avoids overloading the stderr buffer with excess junk.
184     """
185     global _last_prog
186     now = time.time()
187     if now - _last_prog > 0.1:
188         progress(s)
189         _last_prog = now
190
191
192 def reprogress():
193     """Calls progress() to redisplay the most recent progress message.
194
195     Useful after you've printed some other message that wipes out the
196     progress line.
197     """
198     if _last_progress and _last_progress.endswith('\r'):
199         progress(_last_progress)
200
201
202 def mkdirp(d, mode=None):
203     """Recursively create directories on path 'd'.
204
205     Unlike os.makedirs(), it doesn't raise an exception if the last element of
206     the path already exists.
207     """
208     try:
209         if mode:
210             os.makedirs(d, mode)
211         else:
212             os.makedirs(d)
213     except OSError as e:
214         if e.errno == errno.EEXIST:
215             pass
216         else:
217             raise
218
219
220 class MergeIterItem:
221     def __init__(self, entry, read_it):
222         self.entry = entry
223         self.read_it = read_it
224     def __lt__(self, x):
225         return self.entry < x.entry
226
227 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
228     if key:
229         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
230     else:
231         samekey = operator.eq
232     count = 0
233     total = sum(len(it) for it in iters)
234     iters = (iter(it) for it in iters)
235     heap = ((next(it, None),it) for it in iters)
236     heap = [MergeIterItem(e, it) for e, it in heap if e]
237
238     heapq.heapify(heap)
239     pe = None
240     while heap:
241         if not count % pfreq:
242             pfunc(count, total)
243         e, it = heap[0].entry, heap[0].read_it
244         if not samekey(e, pe):
245             pe = e
246             yield e
247         count += 1
248         try:
249             e = next(it)
250         except StopIteration:
251             heapq.heappop(heap) # remove current
252         else:
253             # shift current to new location
254             heapq.heapreplace(heap, MergeIterItem(e, it))
255     pfinal(count, total)
256
257
258 def unlink(f):
259     """Delete a file at path 'f' if it currently exists.
260
261     Unlike os.unlink(), does not throw an exception if the file didn't already
262     exist.
263     """
264     try:
265         os.unlink(f)
266     except OSError as e:
267         if e.errno != errno.ENOENT:
268             raise
269
270
271 def shstr(cmd):
272     if isinstance(cmd, compat.str_type):
273         return cmd
274     else:
275         return ' '.join(map(quote, cmd))
276
277 exc = subprocess.check_call
278
279 def exo(cmd,
280         input=None,
281         stdin=None,
282         stderr=None,
283         shell=False,
284         check=True,
285         preexec_fn=None):
286     if input:
287         assert stdin in (None, PIPE)
288         stdin = PIPE
289     p = Popen(cmd,
290               stdin=stdin, stdout=PIPE, stderr=stderr,
291               shell=shell,
292               preexec_fn=preexec_fn)
293     out, err = p.communicate(input)
294     if check and p.returncode != 0:
295         raise Exception('subprocess %r failed with status %d, stderr: %r'
296                         % (b' '.join(map(quote, cmd)), p.returncode, err))
297     return out, err, p
298
299 def readpipe(argv, preexec_fn=None, shell=False):
300     """Run a subprocess and return its output."""
301     p = subprocess.Popen(argv, stdout=subprocess.PIPE, preexec_fn=preexec_fn,
302                          shell=shell)
303     out, err = p.communicate()
304     if p.returncode != 0:
305         raise Exception('subprocess %r failed with status %d'
306                         % (b' '.join(argv), p.returncode))
307     return out
308
309
310 def _argmax_base(command):
311     base_size = 2048
312     for c in command:
313         base_size += len(command) + 1
314     for k, v in compat.items(environ):
315         base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
316     return base_size
317
318
319 def _argmax_args_size(args):
320     return sum(len(x) + 1 + sizeof(c_void_p) for x in args)
321
322
323 def batchpipe(command, args, preexec_fn=None, arg_max=sc_arg_max):
324     """If args is not empty, yield the output produced by calling the
325 command list with args as a sequence of strings (It may be necessary
326 to return multiple strings in order to respect ARG_MAX)."""
327     # The optional arg_max arg is a workaround for an issue with the
328     # current wvtest behavior.
329     base_size = _argmax_base(command)
330     while args:
331         room = arg_max - base_size
332         i = 0
333         while i < len(args):
334             next_size = _argmax_args_size(args[i:i+1])
335             if room - next_size < 0:
336                 break
337             room -= next_size
338             i += 1
339         sub_args = args[:i]
340         args = args[i:]
341         assert(len(sub_args))
342         yield readpipe(command + sub_args, preexec_fn=preexec_fn)
343
344
345 def resolve_parent(p):
346     """Return the absolute path of a file without following any final symlink.
347
348     Behaves like os.path.realpath, but doesn't follow a symlink for the last
349     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
350     will follow symlinks in p's directory)
351     """
352     try:
353         st = os.lstat(p)
354     except OSError:
355         st = None
356     if st and stat.S_ISLNK(st.st_mode):
357         (dir, name) = os.path.split(p)
358         dir = os.path.realpath(dir)
359         out = os.path.join(dir, name)
360     else:
361         out = os.path.realpath(p)
362     #log('realpathing:%r,%r\n' % (p, out))
363     return out
364
365
366 def detect_fakeroot():
367     "Return True if we appear to be running under fakeroot."
368     return os.getenv("FAKEROOTKEY") != None
369
370
371 if sys.platform.startswith('cygwin'):
372     def is_superuser():
373         # https://cygwin.com/ml/cygwin/2015-02/msg00057.html
374         groups = os.getgroups()
375         return 544 in groups or 0 in groups
376 else:
377     def is_superuser():
378         return os.geteuid() == 0
379
380
381 def cache_key_value(get_value, key, cache):
382     """Return (value, was_cached).  If there is a value in the cache
383     for key, use that, otherwise, call get_value(key) which should
384     throw a KeyError if there is no value -- in which case the cached
385     and returned value will be None.
386     """
387     try: # Do we already have it (or know there wasn't one)?
388         value = cache[key]
389         return value, True
390     except KeyError:
391         pass
392     value = None
393     try:
394         cache[key] = value = get_value(key)
395     except KeyError:
396         cache[key] = None
397     return value, False
398
399
400 _hostname = None
401 def hostname():
402     """Get the FQDN of this machine."""
403     global _hostname
404     if not _hostname:
405         _hostname = socket.getfqdn().encode('iso-8859-1')
406     return _hostname
407
408
409 def format_filesize(size):
410     unit = 1024.0
411     size = float(size)
412     if size < unit:
413         return "%d" % (size)
414     exponent = int(math.log(size) // math.log(unit))
415     size_prefix = "KMGTPE"[exponent - 1]
416     return "%.1f%s" % (size // math.pow(unit, exponent), size_prefix)
417
418
419 class NotOk(Exception):
420     pass
421
422
423 class BaseConn:
424     def __init__(self, outp):
425         self.outp = outp
426
427     def close(self):
428         while self._read(65536): pass
429
430     def read(self, size):
431         """Read 'size' bytes from input stream."""
432         self.outp.flush()
433         return self._read(size)
434
435     def readline(self):
436         """Read from input stream until a newline is found."""
437         self.outp.flush()
438         return self._readline()
439
440     def write(self, data):
441         """Write 'data' to output stream."""
442         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
443         self.outp.write(data)
444
445     def has_input(self):
446         """Return true if input stream is readable."""
447         raise NotImplemented("Subclasses must implement has_input")
448
449     def ok(self):
450         """Indicate end of output from last sent command."""
451         self.write('\nok\n')
452
453     def error(self, s):
454         """Indicate server error to the client."""
455         s = re.sub(r'\s+', ' ', str(s))
456         self.write('\nerror %s\n' % s)
457
458     def _check_ok(self, onempty):
459         self.outp.flush()
460         rl = ''
461         for rl in linereader(self):
462             #log('%d got line: %r\n' % (os.getpid(), rl))
463             if not rl:  # empty line
464                 continue
465             elif rl == 'ok':
466                 return None
467             elif rl.startswith('error '):
468                 #log('client: error: %s\n' % rl[6:])
469                 return NotOk(rl[6:])
470             else:
471                 onempty(rl)
472         raise Exception('server exited unexpectedly; see errors above')
473
474     def drain_and_check_ok(self):
475         """Remove all data for the current command from input stream."""
476         def onempty(rl):
477             pass
478         return self._check_ok(onempty)
479
480     def check_ok(self):
481         """Verify that server action completed successfully."""
482         def onempty(rl):
483             raise Exception('expected "ok", got %r' % rl)
484         return self._check_ok(onempty)
485
486
487 class Conn(BaseConn):
488     def __init__(self, inp, outp):
489         BaseConn.__init__(self, outp)
490         self.inp = inp
491
492     def _read(self, size):
493         return self.inp.read(size)
494
495     def _readline(self):
496         return self.inp.readline()
497
498     def has_input(self):
499         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
500         if rl:
501             assert(rl[0] == self.inp.fileno())
502             return True
503         else:
504             return None
505
506
507 def checked_reader(fd, n):
508     while n > 0:
509         rl, _, _ = select.select([fd], [], [])
510         assert(rl[0] == fd)
511         buf = os.read(fd, n)
512         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
513         yield buf
514         n -= len(buf)
515
516
517 MAX_PACKET = 128 * 1024
518 def mux(p, outfd, outr, errr):
519     try:
520         fds = [outr, errr]
521         while p.poll() is None:
522             rl, _, _ = select.select(fds, [], [])
523             for fd in rl:
524                 if fd == outr:
525                     buf = os.read(outr, MAX_PACKET)
526                     if not buf: break
527                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
528                 elif fd == errr:
529                     buf = os.read(errr, 1024)
530                     if not buf: break
531                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
532     finally:
533         os.write(outfd, struct.pack('!IB', 0, 3))
534
535
536 class DemuxConn(BaseConn):
537     """A helper class for bup's client-server protocol."""
538     def __init__(self, infd, outp):
539         BaseConn.__init__(self, outp)
540         # Anything that comes through before the sync string was not
541         # multiplexed and can be assumed to be debug/log before mux init.
542         tail = ''
543         while tail != 'BUPMUX':
544             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
545             if not b:
546                 raise IOError('demux: unexpected EOF during initialization')
547             tail += b
548             sys.stderr.write(tail[:-6])  # pre-mux log messages
549             tail = tail[-6:]
550         self.infd = infd
551         self.reader = None
552         self.buf = None
553         self.closed = False
554
555     def write(self, data):
556         self._load_buf(0)
557         BaseConn.write(self, data)
558
559     def _next_packet(self, timeout):
560         if self.closed: return False
561         rl, wl, xl = select.select([self.infd], [], [], timeout)
562         if not rl: return False
563         assert(rl[0] == self.infd)
564         ns = ''.join(checked_reader(self.infd, 5))
565         n, fdw = struct.unpack('!IB', ns)
566         assert(n <= MAX_PACKET)
567         if fdw == 1:
568             self.reader = checked_reader(self.infd, n)
569         elif fdw == 2:
570             for buf in checked_reader(self.infd, n):
571                 sys.stderr.write(buf)
572         elif fdw == 3:
573             self.closed = True
574             debug2("DemuxConn: marked closed\n")
575         return True
576
577     def _load_buf(self, timeout):
578         if self.buf is not None:
579             return True
580         while not self.closed:
581             while not self.reader:
582                 if not self._next_packet(timeout):
583                     return False
584             try:
585                 self.buf = next(self.reader)
586                 return True
587             except StopIteration:
588                 self.reader = None
589         return False
590
591     def _read_parts(self, ix_fn):
592         while self._load_buf(None):
593             assert(self.buf is not None)
594             i = ix_fn(self.buf)
595             if i is None or i == len(self.buf):
596                 yv = self.buf
597                 self.buf = None
598             else:
599                 yv = self.buf[:i]
600                 self.buf = self.buf[i:]
601             yield yv
602             if i is not None:
603                 break
604
605     def _readline(self):
606         def find_eol(buf):
607             try:
608                 return buf.index('\n')+1
609             except ValueError:
610                 return None
611         return ''.join(self._read_parts(find_eol))
612
613     def _read(self, size):
614         csize = [size]
615         def until_size(buf): # Closes on csize
616             if len(buf) < csize[0]:
617                 csize[0] -= len(buf)
618                 return None
619             else:
620                 return csize[0]
621         return ''.join(self._read_parts(until_size))
622
623     def has_input(self):
624         return self._load_buf(0)
625
626
627 def linereader(f):
628     """Generate a list of input lines from 'f' without terminating newlines."""
629     while 1:
630         line = f.readline()
631         if not line:
632             break
633         yield line[:-1]
634
635
636 def chunkyreader(f, count = None):
637     """Generate a list of chunks of data read from 'f'.
638
639     If count is None, read until EOF is reached.
640
641     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
642     reached while reading, raise IOError.
643     """
644     if count != None:
645         while count > 0:
646             b = f.read(min(count, 65536))
647             if not b:
648                 raise IOError('EOF with %d bytes remaining' % count)
649             yield b
650             count -= len(b)
651     else:
652         while 1:
653             b = f.read(65536)
654             if not b: break
655             yield b
656
657
658 @contextmanager
659 def atomically_replaced_file(name, mode='w', buffering=-1):
660     """Yield a file that will be atomically renamed name when leaving the block.
661
662     This contextmanager yields an open file object that is backed by a
663     temporary file which will be renamed (atomically) to the target
664     name if everything succeeds.
665
666     The mode and buffering arguments are handled exactly as with open,
667     and the yielded file will have very restrictive permissions, as
668     per mkstemp.
669
670     E.g.::
671
672         with atomically_replaced_file('foo.txt', 'w') as f:
673             f.write('hello jack.')
674
675     """
676
677     (ffd, tempname) = tempfile.mkstemp(dir=os.path.dirname(name),
678                                        text=('b' not in mode))
679     try:
680         try:
681             f = os.fdopen(ffd, mode, buffering)
682         except:
683             os.close(ffd)
684             raise
685         try:
686             yield f
687         finally:
688             f.close()
689         os.rename(tempname, name)
690     finally:
691         unlink(tempname)  # nonexistant file is ignored
692
693
694 def slashappend(s):
695     """Append "/" to 's' if it doesn't aleady end in "/"."""
696     assert isinstance(s, bytes)
697     if s and not s.endswith(b'/'):
698         return s + b'/'
699     else:
700         return s
701
702
703 def _mmap_do(f, sz, flags, prot, close):
704     if not sz:
705         st = os.fstat(f.fileno())
706         sz = st.st_size
707     if not sz:
708         # trying to open a zero-length map gives an error, but an empty
709         # string has all the same behaviour of a zero-length map, ie. it has
710         # no elements :)
711         return ''
712     map = mmap.mmap(f.fileno(), sz, flags, prot)
713     if close:
714         f.close()  # map will persist beyond file close
715     return map
716
717
718 def mmap_read(f, sz = 0, close=True):
719     """Create a read-only memory mapped region on file 'f'.
720     If sz is 0, the region will cover the entire file.
721     """
722     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
723
724
725 def mmap_readwrite(f, sz = 0, close=True):
726     """Create a read-write memory mapped region on file 'f'.
727     If sz is 0, the region will cover the entire file.
728     """
729     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
730                     close)
731
732
733 def mmap_readwrite_private(f, sz = 0, close=True):
734     """Create a read-write memory mapped region on file 'f'.
735     If sz is 0, the region will cover the entire file.
736     The map is private, which means the changes are never flushed back to the
737     file.
738     """
739     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
740                     close)
741
742
743 _mincore = getattr(_helpers, 'mincore', None)
744 if _mincore:
745     # ./configure ensures that we're on Linux if MINCORE_INCORE isn't defined.
746     MINCORE_INCORE = getattr(_helpers, 'MINCORE_INCORE', 1)
747
748     _fmincore_chunk_size = None
749     def _set_fmincore_chunk_size():
750         global _fmincore_chunk_size
751         pref_chunk_size = 64 * 1024 * 1024
752         chunk_size = sc_page_size
753         if (sc_page_size < pref_chunk_size):
754             chunk_size = sc_page_size * (pref_chunk_size // sc_page_size)
755         _fmincore_chunk_size = chunk_size
756
757     def fmincore(fd):
758         """Return the mincore() data for fd as a bytearray whose values can be
759         tested via MINCORE_INCORE, or None if fd does not fully
760         support the operation."""
761         st = os.fstat(fd)
762         if (st.st_size == 0):
763             return bytearray(0)
764         if not _fmincore_chunk_size:
765             _set_fmincore_chunk_size()
766         pages_per_chunk = _fmincore_chunk_size // sc_page_size;
767         page_count = (st.st_size + sc_page_size - 1) // sc_page_size;
768         chunk_count = page_count // _fmincore_chunk_size
769         if chunk_count < 1:
770             chunk_count = 1
771         result = bytearray(page_count)
772         for ci in compat.range(chunk_count):
773             pos = _fmincore_chunk_size * ci;
774             msize = min(_fmincore_chunk_size, st.st_size - pos)
775             try:
776                 m = mmap.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
777             except mmap.error as ex:
778                 if ex.errno == errno.EINVAL or ex.errno == errno.ENODEV:
779                     # Perhaps the file was a pipe, i.e. "... | bup split ..."
780                     return None
781                 raise ex
782             try:
783                 _mincore(m, msize, 0, result, ci * pages_per_chunk)
784             except OSError as ex:
785                 if ex.errno == errno.ENOSYS:
786                     return None
787                 raise
788         return result
789
790
791 def parse_timestamp(epoch_str):
792     """Return the number of nanoseconds since the epoch that are described
793 by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
794 throw a ValueError that may contain additional information."""
795     ns_per = {'s' :  1000000000,
796               'ms' : 1000000,
797               'us' : 1000,
798               'ns' : 1}
799     match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
800     if not match:
801         if re.match(r'^([-+]?[0-9]+)$', epoch_str):
802             raise ValueError('must include units, i.e. 100ns, 100ms, ...')
803         raise ValueError()
804     (n, units) = match.group(1, 2)
805     if not n:
806         n = 1
807     n = int(n)
808     return n * ns_per[units]
809
810
811 def parse_num(s):
812     """Parse string or bytes as a possibly unit suffixed number.
813
814     For example:
815         199.2k means 203981 bytes
816         1GB means 1073741824 bytes
817         2.1 tb means 2199023255552 bytes
818     """
819     if isinstance(s, bytes):
820         # FIXME: should this raise a ValueError for UnicodeDecodeError
821         # (perhaps with the latter as the context).
822         s = s.decode('ascii')
823     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
824     if not g:
825         raise ValueError("can't parse %r as a number" % s)
826     (val, unit) = g.groups()
827     num = float(val)
828     unit = unit.lower()
829     if unit in ['t', 'tb']:
830         mult = 1024*1024*1024*1024
831     elif unit in ['g', 'gb']:
832         mult = 1024*1024*1024
833     elif unit in ['m', 'mb']:
834         mult = 1024*1024
835     elif unit in ['k', 'kb']:
836         mult = 1024
837     elif unit in ['', 'b']:
838         mult = 1
839     else:
840         raise ValueError("invalid unit %r in number %r" % (unit, s))
841     return int(num*mult)
842
843
844 def count(l):
845     """Count the number of elements in an iterator. (consumes the iterator)"""
846     return reduce(lambda x,y: x+1, l)
847
848
849 saved_errors = []
850 def add_error(e):
851     """Append an error message to the list of saved errors.
852
853     Once processing is able to stop and output the errors, the saved errors are
854     accessible in the module variable helpers.saved_errors.
855     """
856     saved_errors.append(e)
857     log('%-70s\n' % e)
858
859
860 def clear_errors():
861     global saved_errors
862     saved_errors = []
863
864
865 def die_if_errors(msg=None, status=1):
866     global saved_errors
867     if saved_errors:
868         if not msg:
869             msg = 'warning: %d errors encountered\n' % len(saved_errors)
870         log(msg)
871         sys.exit(status)
872
873
874 def handle_ctrl_c():
875     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
876
877     The new exception handler will make sure that bup will exit without an ugly
878     stacktrace when Ctrl-C is hit.
879     """
880     oldhook = sys.excepthook
881     def newhook(exctype, value, traceback):
882         if exctype == KeyboardInterrupt:
883             log('\nInterrupted.\n')
884         else:
885             return oldhook(exctype, value, traceback)
886     sys.excepthook = newhook
887
888
889 def columnate(l, prefix):
890     """Format elements of 'l' in columns with 'prefix' leading each line.
891
892     The number of columns is determined automatically based on the string
893     lengths.
894     """
895     if not l:
896         return ""
897     l = l[:]
898     clen = max(len(s) for s in l)
899     ncols = (tty_width() - len(prefix)) // (clen + 2)
900     if ncols <= 1:
901         ncols = 1
902         clen = 0
903     cols = []
904     while len(l) % ncols:
905         l.append('')
906     rows = len(l) // ncols
907     for s in compat.range(0, len(l), rows):
908         cols.append(l[s:s+rows])
909     out = ''
910     for row in zip(*cols):
911         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
912     return out
913
914
915 def parse_date_or_fatal(str, fatal):
916     """Parses the given date or calls Option.fatal().
917     For now we expect a string that contains a float."""
918     try:
919         date = float(str)
920     except ValueError as e:
921         raise fatal('invalid date format (should be a float): %r' % e)
922     else:
923         return date
924
925
926 def parse_excludes(options, fatal):
927     """Traverse the options and extract all excludes, or call Option.fatal()."""
928     excluded_paths = []
929
930     for flag in options:
931         (option, parameter) = flag
932         if option == '--exclude':
933             excluded_paths.append(resolve_parent(parameter))
934         elif option == '--exclude-from':
935             try:
936                 f = open(resolve_parent(parameter))
937             except IOError as e:
938                 raise fatal("couldn't read %s" % parameter)
939             for exclude_path in f.readlines():
940                 # FIXME: perhaps this should be rstrip('\n')
941                 exclude_path = resolve_parent(exclude_path.strip())
942                 if exclude_path:
943                     excluded_paths.append(exclude_path)
944     return sorted(frozenset(excluded_paths))
945
946
947 def parse_rx_excludes(options, fatal):
948     """Traverse the options and extract all rx excludes, or call
949     Option.fatal()."""
950     excluded_patterns = []
951
952     for flag in options:
953         (option, parameter) = flag
954         if option == '--exclude-rx':
955             try:
956                 excluded_patterns.append(re.compile(parameter))
957             except re.error as ex:
958                 fatal('invalid --exclude-rx pattern (%s): %s' % (parameter, ex))
959         elif option == '--exclude-rx-from':
960             try:
961                 f = open(resolve_parent(parameter))
962             except IOError as e:
963                 raise fatal("couldn't read %s" % parameter)
964             for pattern in f.readlines():
965                 spattern = pattern.rstrip('\n')
966                 if not spattern:
967                     continue
968                 try:
969                     excluded_patterns.append(re.compile(spattern))
970                 except re.error as ex:
971                     fatal('invalid --exclude-rx pattern (%s): %s' % (spattern, ex))
972     return excluded_patterns
973
974
975 def should_rx_exclude_path(path, exclude_rxs):
976     """Return True if path matches a regular expression in exclude_rxs."""
977     for rx in exclude_rxs:
978         if rx.search(path):
979             debug1('Skipping %r: excluded by rx pattern %r.\n'
980                    % (path, rx.pattern))
981             return True
982     return False
983
984
985 # FIXME: Carefully consider the use of functions (os.path.*, etc.)
986 # that resolve against the current filesystem in the strip/graft
987 # functions for example, but elsewhere as well.  I suspect bup's not
988 # always being careful about that.  For some cases, the contents of
989 # the current filesystem should be irrelevant, and consulting it might
990 # produce the wrong result, perhaps via unintended symlink resolution,
991 # for example.
992
993 def path_components(path):
994     """Break path into a list of pairs of the form (name,
995     full_path_to_name).  Path must start with '/'.
996     Example:
997       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
998     if not path.startswith(b'/'):
999         raise Exception('path must start with "/": %s' % path_msg(path))
1000     # Since we assume path startswith('/'), we can skip the first element.
1001     result = [(b'', b'/')]
1002     norm_path = os.path.abspath(path)
1003     if norm_path == b'/':
1004         return result
1005     full_path = b''
1006     for p in norm_path.split(b'/')[1:]:
1007         full_path += b'/' + p
1008         result.append((p, full_path))
1009     return result
1010
1011
1012 def stripped_path_components(path, strip_prefixes):
1013     """Strip any prefix in strip_prefixes from path and return a list
1014     of path components where each component is (name,
1015     none_or_full_fs_path_to_name).  Assume path startswith('/').
1016     See thelpers.py for examples."""
1017     normalized_path = os.path.abspath(path)
1018     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
1019     for bp in sorted_strip_prefixes:
1020         normalized_bp = os.path.abspath(bp)
1021         if normalized_bp == b'/':
1022             continue
1023         if normalized_path.startswith(normalized_bp):
1024             prefix = normalized_path[:len(normalized_bp)]
1025             result = []
1026             for p in normalized_path[len(normalized_bp):].split(b'/'):
1027                 if p: # not root
1028                     prefix += b'/'
1029                 prefix += p
1030                 result.append((p, prefix))
1031             return result
1032     # Nothing to strip.
1033     return path_components(path)
1034
1035
1036 def grafted_path_components(graft_points, path):
1037     # Create a result that consists of some number of faked graft
1038     # directories before the graft point, followed by all of the real
1039     # directories from path that are after the graft point.  Arrange
1040     # for the directory at the graft point in the result to correspond
1041     # to the "orig" directory in --graft orig=new.  See t/thelpers.py
1042     # for some examples.
1043
1044     # Note that given --graft orig=new, orig and new have *nothing* to
1045     # do with each other, even if some of their component names
1046     # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
1047     # equivalent to --graft /foo/bar/baz=/x/y/z, or even
1048     # /foo/bar/baz=/x.
1049
1050     # FIXME: This can't be the best solution...
1051     clean_path = os.path.abspath(path)
1052     for graft_point in graft_points:
1053         old_prefix, new_prefix = graft_point
1054         # Expand prefixes iff not absolute paths.
1055         old_prefix = os.path.normpath(old_prefix)
1056         new_prefix = os.path.normpath(new_prefix)
1057         if clean_path.startswith(old_prefix):
1058             escaped_prefix = re.escape(old_prefix)
1059             grafted_path = re.sub(br'^' + escaped_prefix, new_prefix, clean_path)
1060             # Handle /foo=/ (at least) -- which produces //whatever.
1061             grafted_path = b'/' + grafted_path.lstrip(b'/')
1062             clean_path_components = path_components(clean_path)
1063             # Count the components that were stripped.
1064             strip_count = 0 if old_prefix == b'/' else old_prefix.count(b'/')
1065             new_prefix_parts = new_prefix.split(b'/')
1066             result_prefix = grafted_path.split(b'/')[:new_prefix.count(b'/')]
1067             result = [(p, None) for p in result_prefix] \
1068                 + clean_path_components[strip_count:]
1069             # Now set the graft point name to match the end of new_prefix.
1070             graft_point = len(result_prefix)
1071             result[graft_point] = \
1072                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
1073             if new_prefix == b'/': # --graft ...=/ is a special case.
1074                 return result[1:]
1075             return result
1076     return path_components(clean_path)
1077
1078
1079 Sha1 = hashlib.sha1
1080
1081
1082 _localtime = getattr(_helpers, 'localtime', None)
1083
1084 if _localtime:
1085     bup_time = namedtuple('bup_time', ['tm_year', 'tm_mon', 'tm_mday',
1086                                        'tm_hour', 'tm_min', 'tm_sec',
1087                                        'tm_wday', 'tm_yday',
1088                                        'tm_isdst', 'tm_gmtoff', 'tm_zone'])
1089
1090 # Define a localtime() that returns bup_time when possible.  Note:
1091 # this means that any helpers.localtime() results may need to be
1092 # passed through to_py_time() before being passed to python's time
1093 # module, which doesn't appear willing to ignore the extra items.
1094 if _localtime:
1095     def localtime(time):
1096         return bup_time(*_helpers.localtime(floor(time)))
1097     def utc_offset_str(t):
1098         """Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.
1099         If the current UTC offset does not represent an integer number
1100         of minutes, the fractional component will be truncated."""
1101         off = localtime(t).tm_gmtoff
1102         # Note: // doesn't truncate like C for negative values, it rounds down.
1103         offmin = abs(off) // 60
1104         m = offmin % 60
1105         h = (offmin - m) // 60
1106         return b'%+03d%02d' % (-h if off < 0 else h, m)
1107     def to_py_time(x):
1108         if isinstance(x, time.struct_time):
1109             return x
1110         return time.struct_time(x[:9])
1111 else:
1112     localtime = time.localtime
1113     def utc_offset_str(t):
1114         return time.strftime(b'%z', localtime(t))
1115     def to_py_time(x):
1116         return x
1117
1118
1119 _some_invalid_save_parts_rx = re.compile(br'[\[ ~^:?*\\]|\.\.|//|@{')
1120
1121 def valid_save_name(name):
1122     # Enforce a superset of the restrictions in git-check-ref-format(1)
1123     if name == b'@' \
1124        or name.startswith(b'/') or name.endswith(b'/') \
1125        or name.endswith(b'.'):
1126         return False
1127     if _some_invalid_save_parts_rx.search(name):
1128         return False
1129     for c in name:
1130         if byte_int(c) < 0x20 or byte_int(c) == 0x7f:
1131             return False
1132     for part in name.split(b'/'):
1133         if part.startswith(b'.') or part.endswith(b'.lock'):
1134             return False
1135     return True
1136
1137
1138 _period_rx = re.compile(r'^([0-9]+)(s|min|h|d|w|m|y)$')
1139
1140 def period_as_secs(s):
1141     if s == 'forever':
1142         return float('inf')
1143     match = _period_rx.match(s)
1144     if not match:
1145         return None
1146     mag = int(match.group(1))
1147     scale = match.group(2)
1148     return mag * {'s': 1,
1149                   'min': 60,
1150                   'h': 60 * 60,
1151                   'd': 60 * 60 * 24,
1152                   'w': 60 * 60 * 24 * 7,
1153                   'm': 60 * 60 * 24 * 31,
1154                   'y': 60 * 60 * 24 * 366}[scale]