]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
atomically_replaced_file: respect umask/sgid/etc. via tmpdir
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 from __future__ import absolute_import, division
4 from collections import namedtuple
5 from contextlib import ExitStack
6 from ctypes import sizeof, c_void_p
7 from math import floor
8 from os import environ
9 from subprocess import PIPE, Popen
10 from tempfile import mkdtemp
11 from shutil import rmtree
12 import sys, os, subprocess, errno, select, mmap, stat, re, struct
13 import hashlib, heapq, math, operator, time
14
15 from bup import _helpers
16 from bup import io
17 from bup.compat import argv_bytes, byte_int, nullcontext, pending_raise
18 from bup.io import byte_stream, path_msg
19 # This function should really be in helpers, not in bup.options.  But we
20 # want options.py to be standalone so people can include it in other projects.
21 from bup.options import _tty_width as tty_width
22
23
24 buglvl = int(os.environ.get('BUP_DEBUG', 0))
25
26
27 class Nonlocal:
28     """Helper to deal with Python scoping issues"""
29     pass
30
31
32 def nullcontext_if_not(manager):
33     return manager if manager is not None else nullcontext()
34
35
36 class finalized:
37     def __init__(self, enter_result=None, finalize=None):
38         assert finalize
39         self.finalize = finalize
40         self.enter_result = enter_result
41     def __enter__(self):
42         return self.enter_result
43     def __exit__(self, exc_type, exc_value, traceback):
44         self.finalize(self.enter_result)
45
46 def temp_dir(*args, **kwargs):
47     # This is preferable to tempfile.TemporaryDirectory because the
48     # latter uses @contextmanager, and so will always eventually be
49     # deleted if it's handed to an ExitStack, whenever the stack is
50     # gc'ed, even if you pop_all() (the new stack will also trigger
51     # the deletion) because
52     # https://github.com/python/cpython/issues/88458
53     return finalized(mkdtemp(*args, **kwargs), lambda x: rmtree(x))
54
55 sc_page_size = os.sysconf('SC_PAGE_SIZE')
56 assert(sc_page_size > 0)
57
58 sc_arg_max = os.sysconf('SC_ARG_MAX')
59 if sc_arg_max == -1:  # "no definite limit" - let's choose 2M
60     sc_arg_max = 2 * 1024 * 1024
61
62 def last(iterable):
63     result = None
64     for result in iterable:
65         pass
66     return result
67
68 try:
69     _fdatasync = os.fdatasync
70 except AttributeError:
71     _fdatasync = os.fsync
72
73 if sys.platform.startswith('darwin'):
74     # Apparently os.fsync on OS X doesn't guarantee to sync all the way down
75     import fcntl
76     def fdatasync(fd):
77         try:
78             return fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
79         except IOError as e:
80             # Fallback for file systems (SMB) that do not support F_FULLFSYNC
81             if e.errno == errno.ENOTSUP:
82                 return _fdatasync(fd)
83             else:
84                 raise
85 else:
86     fdatasync = _fdatasync
87
88
89 def partition(predicate, stream):
90     """Returns (leading_matches_it, rest_it), where leading_matches_it
91     must be completely exhausted before traversing rest_it.
92
93     """
94     stream = iter(stream)
95     ns = Nonlocal()
96     ns.first_nonmatch = None
97     def leading_matches():
98         for x in stream:
99             if predicate(x):
100                 yield x
101             else:
102                 ns.first_nonmatch = (x,)
103                 break
104     def rest():
105         if ns.first_nonmatch:
106             yield ns.first_nonmatch[0]
107             for x in stream:
108                 yield x
109     return (leading_matches(), rest())
110
111
112 def merge_dict(*xs):
113     result = {}
114     for x in xs:
115         result.update(x)
116     return result
117
118
119 def lines_until_sentinel(f, sentinel, ex_type):
120     # sentinel must end with \n and must contain only one \n
121     while True:
122         line = f.readline()
123         if not (line and line.endswith(b'\n')):
124             raise ex_type('Hit EOF while reading line')
125         if line == sentinel:
126             return
127         yield line
128
129
130 def stat_if_exists(path):
131     try:
132         return os.stat(path)
133     except OSError as e:
134         if e.errno != errno.ENOENT:
135             raise
136     return None
137
138
139 # Write (blockingly) to sockets that may or may not be in blocking mode.
140 # We need this because our stderr is sometimes eaten by subprocesses
141 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
142 # leading to race conditions.  Ick.  We'll do it the hard way.
143 def _hard_write(fd, buf):
144     while buf:
145         (r,w,x) = select.select([], [fd], [], None)
146         if not w:
147             raise IOError('select(fd) returned without being writable')
148         try:
149             sz = os.write(fd, buf)
150         except OSError as e:
151             if e.errno != errno.EAGAIN:
152                 raise
153         assert(sz >= 0)
154         buf = buf[sz:]
155
156
157 _last_prog = 0
158 def log(s):
159     """Print a log message to stderr."""
160     global _last_prog
161     sys.stdout.flush()
162     _hard_write(sys.stderr.fileno(), s if isinstance(s, bytes) else s.encode())
163     _last_prog = 0
164
165
166 def debug1(s):
167     if buglvl >= 1:
168         log(s)
169
170
171 def debug2(s):
172     if buglvl >= 2:
173         log(s)
174
175
176 istty1 = os.isatty(1) or (int(os.environ.get('BUP_FORCE_TTY', 0)) & 1)
177 istty2 = os.isatty(2) or (int(os.environ.get('BUP_FORCE_TTY', 0)) & 2)
178 _last_progress = ''
179 def progress(s):
180     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
181     global _last_progress
182     if istty2:
183         log(s)
184         _last_progress = s
185
186
187 def qprogress(s):
188     """Calls progress() only if we haven't printed progress in a while.
189
190     This avoids overloading the stderr buffer with excess junk.
191     """
192     global _last_prog
193     now = time.time()
194     if now - _last_prog > 0.1:
195         progress(s)
196         _last_prog = now
197
198
199 def reprogress():
200     """Calls progress() to redisplay the most recent progress message.
201
202     Useful after you've printed some other message that wipes out the
203     progress line.
204     """
205     if _last_progress and _last_progress.endswith('\r'):
206         progress(_last_progress)
207
208
209 def mkdirp(d, mode=None):
210     """Recursively create directories on path 'd'.
211
212     Unlike os.makedirs(), it doesn't raise an exception if the last element of
213     the path already exists.
214     """
215     try:
216         if mode:
217             os.makedirs(d, mode)
218         else:
219             os.makedirs(d)
220     except OSError as e:
221         if e.errno == errno.EEXIST:
222             pass
223         else:
224             raise
225
226
227 class MergeIterItem:
228     def __init__(self, entry, read_it):
229         self.entry = entry
230         self.read_it = read_it
231     def __lt__(self, x):
232         return self.entry < x.entry
233
234 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
235     if key:
236         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
237     else:
238         samekey = operator.eq
239     count = 0
240     total = sum(len(it) for it in iters)
241     iters = (iter(it) for it in iters)
242     heap = ((next(it, None),it) for it in iters)
243     heap = [MergeIterItem(e, it) for e, it in heap if e]
244
245     heapq.heapify(heap)
246     pe = None
247     while heap:
248         if not count % pfreq:
249             pfunc(count, total)
250         e, it = heap[0].entry, heap[0].read_it
251         if not samekey(e, pe):
252             pe = e
253             yield e
254         count += 1
255         try:
256             e = next(it)
257         except StopIteration:
258             heapq.heappop(heap) # remove current
259         else:
260             # shift current to new location
261             heapq.heapreplace(heap, MergeIterItem(e, it))
262     pfinal(count, total)
263
264
265 def unlink(f):
266     """Delete a file at path 'f' if it currently exists.
267
268     Unlike os.unlink(), does not throw an exception if the file didn't already
269     exist.
270     """
271     try:
272         os.unlink(f)
273     except OSError as e:
274         if e.errno != errno.ENOENT:
275             raise
276
277
278 _bq_simple_id_rx = re.compile(br'^[-_./a-zA-Z0-9]+$')
279 _sq_simple_id_rx = re.compile(r'^[-_./a-zA-Z0-9]+$')
280
281 def bquote(x):
282     if x == b'':
283         return b"''"
284     if _bq_simple_id_rx.match(x):
285         return x
286     return b"'%s'" % x.replace(b"'", b"'\"'\"'")
287
288 def squote(x):
289     if x == '':
290         return "''"
291     if _sq_simple_id_rx.match(x):
292         return x
293     return "'%s'" % x.replace("'", "'\"'\"'")
294
295 def quote(x):
296     if isinstance(x, bytes):
297         return bquote(x)
298     if isinstance(x, str):
299         return squote(x)
300     assert False
301     # some versions of pylint get confused
302     return None
303
304 def shstr(cmd):
305     """Return a shell quoted string for cmd if it's a sequence, else cmd.
306
307     cmd must be a string, bytes, or a sequence of one or the other,
308     and the assumption is that if cmd is a string or bytes, then it's
309     already quoted (because it's what's actually being passed to
310     call() and friends.  e.g. log(shstr(cmd)); call(cmd)
311
312     """
313     if isinstance(cmd, (bytes, str)):
314         return cmd
315     elif all(isinstance(x, bytes) for x in cmd):
316         return b' '.join(map(bquote, cmd))
317     elif all(isinstance(x, str) for x in cmd):
318         return ' '.join(map(squote, cmd))
319     raise TypeError('unsupported shstr argument: ' + repr(cmd))
320
321
322 exc = subprocess.check_call
323
324 def exo(cmd,
325         input=None,
326         stdin=None,
327         stderr=None,
328         shell=False,
329         check=True,
330         preexec_fn=None,
331         close_fds=True):
332     if input:
333         assert stdin in (None, PIPE)
334         stdin = PIPE
335     p = Popen(cmd,
336               stdin=stdin, stdout=PIPE, stderr=stderr,
337               shell=shell,
338               preexec_fn=preexec_fn,
339               close_fds=close_fds)
340     out, err = p.communicate(input)
341     if check and p.returncode != 0:
342         raise Exception('subprocess %r failed with status %d%s'
343                         % (b' '.join(map(quote, cmd)), p.returncode,
344                            ', stderr: %r' % err if err else ''))
345     return out, err, p
346
347 def readpipe(argv, preexec_fn=None, shell=False):
348     """Run a subprocess and return its output."""
349     return exo(argv, preexec_fn=preexec_fn, shell=shell)[0]
350
351
352 def _argmax_base(command):
353     base_size = 2048
354     for c in command:
355         base_size += len(command) + 1
356     for k, v in environ.items():
357         base_size += len(k) + len(v) + 2 + sizeof(c_void_p)
358     return base_size
359
360
361 def _argmax_args_size(args):
362     return sum(len(x) + 1 + sizeof(c_void_p) for x in args)
363
364
365 def batchpipe(command, args, preexec_fn=None, arg_max=sc_arg_max):
366     """If args is not empty, yield the output produced by calling the
367 command list with args as a sequence of strings (It may be necessary
368 to return multiple strings in order to respect ARG_MAX)."""
369     # The optional arg_max arg is a workaround for an issue with the
370     # current wvtest behavior.
371     base_size = _argmax_base(command)
372     while args:
373         room = arg_max - base_size
374         i = 0
375         while i < len(args):
376             next_size = _argmax_args_size(args[i:i+1])
377             if room - next_size < 0:
378                 break
379             room -= next_size
380             i += 1
381         sub_args = args[:i]
382         args = args[i:]
383         assert(len(sub_args))
384         yield readpipe(command + sub_args, preexec_fn=preexec_fn)
385
386
387 def resolve_parent(p):
388     """Return the absolute path of a file without following any final symlink.
389
390     Behaves like os.path.realpath, but doesn't follow a symlink for the last
391     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
392     will follow symlinks in p's directory)
393     """
394     try:
395         st = os.lstat(p)
396     except OSError:
397         st = None
398     if st and stat.S_ISLNK(st.st_mode):
399         (dir, name) = os.path.split(p)
400         dir = os.path.realpath(dir)
401         out = os.path.join(dir, name)
402     else:
403         out = os.path.realpath(p)
404     #log('realpathing:%r,%r\n' % (p, out))
405     return out
406
407
408 def detect_fakeroot():
409     "Return True if we appear to be running under fakeroot."
410     return os.getenv("FAKEROOTKEY") != None
411
412
413 if sys.platform.startswith('cygwin'):
414     def is_superuser():
415         # https://cygwin.com/ml/cygwin/2015-02/msg00057.html
416         groups = os.getgroups()
417         return 544 in groups or 0 in groups
418 else:
419     def is_superuser():
420         return os.geteuid() == 0
421
422
423 def cache_key_value(get_value, key, cache):
424     """Return (value, was_cached).  If there is a value in the cache
425     for key, use that, otherwise, call get_value(key) which should
426     throw a KeyError if there is no value -- in which case the cached
427     and returned value will be None.
428     """
429     try: # Do we already have it (or know there wasn't one)?
430         value = cache[key]
431         return value, True
432     except KeyError:
433         pass
434     value = None
435     try:
436         cache[key] = value = get_value(key)
437     except KeyError:
438         cache[key] = None
439     return value, False
440
441
442 _hostname = None
443 def hostname():
444     """Get the FQDN of this machine."""
445     global _hostname
446     if not _hostname:
447         _hostname = _helpers.gethostname()
448     return _hostname
449
450
451 def format_filesize(size):
452     unit = 1024.0
453     size = float(size)
454     if size < unit:
455         return "%d" % (size)
456     exponent = int(math.log(size) // math.log(unit))
457     size_prefix = "KMGTPE"[exponent - 1]
458     return "%.1f%s" % (size / math.pow(unit, exponent), size_prefix)
459
460
461 class NotOk(Exception):
462     pass
463
464
465 class BaseConn:
466     def __init__(self, outp):
467         self._base_closed = False
468         self.outp = outp
469
470     def close(self):
471         self._base_closed = True
472
473     def __enter__(self):
474         return self
475
476     def __exit__(self, exc_type, exc_value, tb):
477         with pending_raise(exc_value, rethrow=False):
478             self.close()
479
480     def __del__(self):
481         assert self._base_closed
482
483     def _read(self, size):
484         raise NotImplementedError("Subclasses must implement _read")
485
486     def read(self, size):
487         """Read 'size' bytes from input stream."""
488         self.outp.flush()
489         return self._read(size)
490
491     def _readline(self, size):
492         raise NotImplementedError("Subclasses must implement _readline")
493
494     def readline(self):
495         """Read from input stream until a newline is found."""
496         self.outp.flush()
497         return self._readline()
498
499     def write(self, data):
500         """Write 'data' to output stream."""
501         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
502         self.outp.write(data)
503
504     def has_input(self):
505         """Return true if input stream is readable."""
506         raise NotImplementedError("Subclasses must implement has_input")
507
508     def ok(self):
509         """Indicate end of output from last sent command."""
510         self.write(b'\nok\n')
511
512     def error(self, s):
513         """Indicate server error to the client."""
514         s = re.sub(br'\s+', b' ', s)
515         self.write(b'\nerror %s\n' % s)
516
517     def _check_ok(self, onempty):
518         self.outp.flush()
519         rl = b''
520         for rl in linereader(self):
521             #log('%d got line: %r\n' % (os.getpid(), rl))
522             if not rl:  # empty line
523                 continue
524             elif rl == b'ok':
525                 return None
526             elif rl.startswith(b'error '):
527                 #log('client: error: %s\n' % rl[6:])
528                 return NotOk(rl[6:])
529             else:
530                 onempty(rl)
531         raise Exception('server exited unexpectedly; see errors above')
532
533     def drain_and_check_ok(self):
534         """Remove all data for the current command from input stream."""
535         def onempty(rl):
536             pass
537         return self._check_ok(onempty)
538
539     def check_ok(self):
540         """Verify that server action completed successfully."""
541         def onempty(rl):
542             raise Exception('expected "ok", got %r' % rl)
543         return self._check_ok(onempty)
544
545
546 class Conn(BaseConn):
547     def __init__(self, inp, outp):
548         BaseConn.__init__(self, outp)
549         self.inp = inp
550
551     def _read(self, size):
552         return self.inp.read(size)
553
554     def _readline(self):
555         return self.inp.readline()
556
557     def has_input(self):
558         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
559         if rl:
560             assert(rl[0] == self.inp.fileno())
561             return True
562         else:
563             return None
564
565
566 def checked_reader(fd, n):
567     while n > 0:
568         rl, _, _ = select.select([fd], [], [])
569         assert(rl[0] == fd)
570         buf = os.read(fd, n)
571         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
572         yield buf
573         n -= len(buf)
574
575
576 MAX_PACKET = 128 * 1024
577 def mux(p, outfd, outr, errr):
578     try:
579         fds = [outr, errr]
580         while p.poll() is None:
581             rl, _, _ = select.select(fds, [], [])
582             for fd in rl:
583                 if fd == outr:
584                     buf = os.read(outr, MAX_PACKET)
585                     if not buf: break
586                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
587                 elif fd == errr:
588                     buf = os.read(errr, 1024)
589                     if not buf: break
590                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
591     finally:
592         os.write(outfd, struct.pack('!IB', 0, 3))
593
594
595 class DemuxConn(BaseConn):
596     """A helper class for bup's client-server protocol."""
597     def __init__(self, infd, outp):
598         BaseConn.__init__(self, outp)
599         # Anything that comes through before the sync string was not
600         # multiplexed and can be assumed to be debug/log before mux init.
601         tail = b''
602         stderr = byte_stream(sys.stderr)
603         while tail != b'BUPMUX':
604             # Make sure to write all pre-BUPMUX output to stderr
605             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
606             if not b:
607                 ex = IOError('demux: unexpected EOF during initialization')
608                 with pending_raise(ex):
609                     stderr.write(tail)
610                     stderr.flush()
611             tail += b
612             stderr.write(tail[:-6])
613             tail = tail[-6:]
614         stderr.flush()
615         self.infd = infd
616         self.reader = None
617         self.buf = None
618         self.closed = False
619
620     def write(self, data):
621         self._load_buf(0)
622         BaseConn.write(self, data)
623
624     def _next_packet(self, timeout):
625         if self.closed: return False
626         rl, wl, xl = select.select([self.infd], [], [], timeout)
627         if not rl: return False
628         assert(rl[0] == self.infd)
629         ns = b''.join(checked_reader(self.infd, 5))
630         n, fdw = struct.unpack('!IB', ns)
631         if n > MAX_PACKET:
632             # assume that something went wrong and print stuff
633             ns += os.read(self.infd, 1024)
634             stderr = byte_stream(sys.stderr)
635             stderr.write(ns)
636             stderr.flush()
637             raise Exception("Connection broken")
638         if fdw == 1:
639             self.reader = checked_reader(self.infd, n)
640         elif fdw == 2:
641             for buf in checked_reader(self.infd, n):
642                 byte_stream(sys.stderr).write(buf)
643         elif fdw == 3:
644             self.closed = True
645             debug2("DemuxConn: marked closed\n")
646         return True
647
648     def _load_buf(self, timeout):
649         if self.buf is not None:
650             return True
651         while not self.closed:
652             while not self.reader:
653                 if not self._next_packet(timeout):
654                     return False
655             try:
656                 self.buf = next(self.reader)
657                 return True
658             except StopIteration:
659                 self.reader = None
660         return False
661
662     def _read_parts(self, ix_fn):
663         while self._load_buf(None):
664             assert(self.buf is not None)
665             i = ix_fn(self.buf)
666             if i is None or i == len(self.buf):
667                 yv = self.buf
668                 self.buf = None
669             else:
670                 yv = self.buf[:i]
671                 self.buf = self.buf[i:]
672             yield yv
673             if i is not None:
674                 break
675
676     def _readline(self):
677         def find_eol(buf):
678             try:
679                 return buf.index(b'\n')+1
680             except ValueError:
681                 return None
682         return b''.join(self._read_parts(find_eol))
683
684     def _read(self, size):
685         csize = [size]
686         def until_size(buf): # Closes on csize
687             if len(buf) < csize[0]:
688                 csize[0] -= len(buf)
689                 return None
690             else:
691                 return csize[0]
692         return b''.join(self._read_parts(until_size))
693
694     def has_input(self):
695         return self._load_buf(0)
696
697
698 def linereader(f):
699     """Generate a list of input lines from 'f' without terminating newlines."""
700     while 1:
701         line = f.readline()
702         if not line:
703             break
704         yield line[:-1]
705
706
707 def chunkyreader(f, count = None):
708     """Generate a list of chunks of data read from 'f'.
709
710     If count is None, read until EOF is reached.
711
712     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
713     reached while reading, raise IOError.
714     """
715     if count != None:
716         while count > 0:
717             b = f.read(min(count, 65536))
718             if not b:
719                 raise IOError('EOF with %d bytes remaining' % count)
720             yield b
721             count -= len(b)
722     else:
723         while 1:
724             b = f.read(65536)
725             if not b: break
726             yield b
727
728
729 class atomically_replaced_file:
730     def __init__(self, path, mode='w', buffering=-1):
731         """Return a context manager supporting the atomic replacement of a file.
732
733         The context manager yields an open file object that has been
734         created in a mkdtemp-style temporary directory in the same
735         directory as the path.  The temporary file will be renamed to
736         the target path (atomically if the platform allows it) if
737         there are no exceptions, and the temporary directory will
738         always be removed.  Calling cancel() will prevent the
739         replacement.
740
741         The file object will have a name attribute containing the
742         file's path, and the mode and buffering arguments will be
743         handled exactly as with open().  The resulting permissions
744         will also match those produced by open().
745
746         E.g.::
747
748           with atomically_replaced_file('foo.txt', 'w') as f:
749               f.write('hello jack.')
750
751         """
752         assert 'w' in mode
753         self.path = path
754         self.mode = mode
755         self.buffering = buffering
756         self.canceled = False
757         self.tmp_path = None
758         self.cleanup = ExitStack()
759     def __enter__(self):
760         with self.cleanup:
761             parent, name = os.path.split(self.path)
762             tmpdir = self.cleanup.enter_context(temp_dir(dir=parent,
763                                                          prefix=name + b'-'))
764             self.tmp_path = tmpdir + b'/pending'
765             f = open(self.tmp_path, mode=self.mode, buffering=self.buffering)
766             f = self.cleanup.enter_context(f)
767             self.cleanup = self.cleanup.pop_all()
768             return f
769     def __exit__(self, exc_type, exc_value, traceback):
770         with self.cleanup:
771             if not (self.canceled or exc_type):
772                 os.rename(self.tmp_path, self.path)
773     def cancel(self):
774         self.canceled = True
775
776
777 def slashappend(s):
778     """Append "/" to 's' if it doesn't aleady end in "/"."""
779     assert isinstance(s, bytes)
780     if s and not s.endswith(b'/'):
781         return s + b'/'
782     else:
783         return s
784
785
786 def _mmap_do(f, sz, flags, prot, close):
787     if not sz:
788         st = os.fstat(f.fileno())
789         sz = st.st_size
790     if not sz:
791         # trying to open a zero-length map gives an error, but an empty
792         # string has all the same behaviour of a zero-length map, ie. it has
793         # no elements :)
794         return ''
795     map = io.mmap(f.fileno(), sz, flags, prot)
796     if close:
797         f.close()  # map will persist beyond file close
798     return map
799
800
801 def mmap_read(f, sz = 0, close=True):
802     """Create a read-only memory mapped region on file 'f'.
803     If sz is 0, the region will cover the entire file.
804     """
805     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
806
807
808 def mmap_readwrite(f, sz = 0, close=True):
809     """Create a read-write memory mapped region on file 'f'.
810     If sz is 0, the region will cover the entire file.
811     """
812     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
813                     close)
814
815
816 def mmap_readwrite_private(f, sz = 0, close=True):
817     """Create a read-write memory mapped region on file 'f'.
818     If sz is 0, the region will cover the entire file.
819     The map is private, which means the changes are never flushed back to the
820     file.
821     """
822     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
823                     close)
824
825
826 _mincore = getattr(_helpers, 'mincore', None)
827 if _mincore:
828     # ./configure ensures that we're on Linux if MINCORE_INCORE isn't defined.
829     MINCORE_INCORE = getattr(_helpers, 'MINCORE_INCORE', 1)
830
831     _fmincore_chunk_size = None
832     def _set_fmincore_chunk_size():
833         global _fmincore_chunk_size
834         pref_chunk_size = 64 * 1024 * 1024
835         chunk_size = sc_page_size
836         if (sc_page_size < pref_chunk_size):
837             chunk_size = sc_page_size * (pref_chunk_size // sc_page_size)
838         _fmincore_chunk_size = chunk_size
839
840     def fmincore(fd):
841         """Return the mincore() data for fd as a bytearray whose values can be
842         tested via MINCORE_INCORE, or None if fd does not fully
843         support the operation."""
844         st = os.fstat(fd)
845         if (st.st_size == 0):
846             return bytearray(0)
847         if not _fmincore_chunk_size:
848             _set_fmincore_chunk_size()
849         pages_per_chunk = _fmincore_chunk_size // sc_page_size;
850         page_count = (st.st_size + sc_page_size - 1) // sc_page_size;
851         chunk_count = (st.st_size + _fmincore_chunk_size - 1) // _fmincore_chunk_size
852         result = bytearray(page_count)
853         for ci in range(chunk_count):
854             pos = _fmincore_chunk_size * ci;
855             msize = min(_fmincore_chunk_size, st.st_size - pos)
856             try:
857                 m = io.mmap(fd, msize, mmap.MAP_PRIVATE, 0, 0, pos)
858             except mmap.error as ex:
859                 if ex.errno == errno.EINVAL or ex.errno == errno.ENODEV:
860                     # Perhaps the file was a pipe, i.e. "... | bup split ..."
861                     return None
862                 raise ex
863             with m:
864                 try:
865                     _mincore(m, msize, 0, result, ci * pages_per_chunk)
866                 except OSError as ex:
867                     if ex.errno == errno.ENOSYS:
868                         return None
869                     raise
870         return result
871
872
873 def parse_timestamp(epoch_str):
874     """Return the number of nanoseconds since the epoch that are described
875 by epoch_str (100ms, 100ns, ...); when epoch_str cannot be parsed,
876 throw a ValueError that may contain additional information."""
877     ns_per = {'s' :  1000000000,
878               'ms' : 1000000,
879               'us' : 1000,
880               'ns' : 1}
881     match = re.match(r'^((?:[-+]?[0-9]+)?)(s|ms|us|ns)$', epoch_str)
882     if not match:
883         if re.match(r'^([-+]?[0-9]+)$', epoch_str):
884             raise ValueError('must include units, i.e. 100ns, 100ms, ...')
885         raise ValueError()
886     (n, units) = match.group(1, 2)
887     if not n:
888         n = 1
889     n = int(n)
890     return n * ns_per[units]
891
892
893 def parse_num(s):
894     """Parse string or bytes as a possibly unit suffixed number.
895
896     For example:
897         199.2k means 203981 bytes
898         1GB means 1073741824 bytes
899         2.1 tb means 2199023255552 bytes
900     """
901     if isinstance(s, bytes):
902         # FIXME: should this raise a ValueError for UnicodeDecodeError
903         # (perhaps with the latter as the context).
904         s = s.decode('ascii')
905     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
906     if not g:
907         raise ValueError("can't parse %r as a number" % s)
908     (val, unit) = g.groups()
909     num = float(val)
910     unit = unit.lower()
911     if unit in ['t', 'tb']:
912         mult = 1024*1024*1024*1024
913     elif unit in ['g', 'gb']:
914         mult = 1024*1024*1024
915     elif unit in ['m', 'mb']:
916         mult = 1024*1024
917     elif unit in ['k', 'kb']:
918         mult = 1024
919     elif unit in ['', 'b']:
920         mult = 1
921     else:
922         raise ValueError("invalid unit %r in number %r" % (unit, s))
923     return int(num*mult)
924
925
926 saved_errors = []
927 def add_error(e):
928     """Append an error message to the list of saved errors.
929
930     Once processing is able to stop and output the errors, the saved errors are
931     accessible in the module variable helpers.saved_errors.
932     """
933     saved_errors.append(e)
934     log('%-70s\n' % e)
935
936
937 def clear_errors():
938     global saved_errors
939     saved_errors = []
940
941
942 def die_if_errors(msg=None, status=1):
943     global saved_errors
944     if saved_errors:
945         if not msg:
946             msg = 'warning: %d errors encountered\n' % len(saved_errors)
947         log(msg)
948         sys.exit(status)
949
950
951 def handle_ctrl_c():
952     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
953
954     The new exception handler will make sure that bup will exit without an ugly
955     stacktrace when Ctrl-C is hit.
956     """
957     oldhook = sys.excepthook
958     def newhook(exctype, value, traceback):
959         if exctype == KeyboardInterrupt:
960             log('\nInterrupted.\n')
961         else:
962             oldhook(exctype, value, traceback)
963     sys.excepthook = newhook
964
965
966 def columnate(l, prefix):
967     """Format elements of 'l' in columns with 'prefix' leading each line.
968
969     The number of columns is determined automatically based on the string
970     lengths.
971     """
972     binary = isinstance(prefix, bytes)
973     nothing = b'' if binary else ''
974     nl = b'\n' if binary else '\n'
975     if not l:
976         return nothing
977     l = l[:]
978     clen = max(len(s) for s in l)
979     ncols = (tty_width() - len(prefix)) // (clen + 2)
980     if ncols <= 1:
981         ncols = 1
982         clen = 0
983     cols = []
984     while len(l) % ncols:
985         l.append(nothing)
986     rows = len(l) // ncols
987     for s in range(0, len(l), rows):
988         cols.append(l[s:s+rows])
989     out = nothing
990     fmt = b'%-*s' if binary else '%-*s'
991     for row in zip(*cols):
992         out += prefix + nothing.join((fmt % (clen+2, s)) for s in row) + nl
993     return out
994
995
996 def parse_date_or_fatal(str, fatal):
997     """Parses the given date or calls Option.fatal().
998     For now we expect a string that contains a float."""
999     try:
1000         date = float(str)
1001     except ValueError as e:
1002         raise fatal('invalid date format (should be a float): %r' % e)
1003     else:
1004         return date
1005
1006
1007 def parse_excludes(options, fatal):
1008     """Traverse the options and extract all excludes, or call Option.fatal()."""
1009     excluded_paths = []
1010
1011     for flag in options:
1012         (option, parameter) = flag
1013         if option == '--exclude':
1014             excluded_paths.append(resolve_parent(argv_bytes(parameter)))
1015         elif option == '--exclude-from':
1016             try:
1017                 f = open(resolve_parent(argv_bytes(parameter)), 'rb')
1018             except IOError as e:
1019                 raise fatal("couldn't read %r" % parameter)
1020             for exclude_path in f.readlines():
1021                 # FIXME: perhaps this should be rstrip('\n')
1022                 exclude_path = resolve_parent(exclude_path.strip())
1023                 if exclude_path:
1024                     excluded_paths.append(exclude_path)
1025     return sorted(frozenset(excluded_paths))
1026
1027
1028 def parse_rx_excludes(options, fatal):
1029     """Traverse the options and extract all rx excludes, or call
1030     Option.fatal()."""
1031     excluded_patterns = []
1032
1033     for flag in options:
1034         (option, parameter) = flag
1035         if option == '--exclude-rx':
1036             try:
1037                 excluded_patterns.append(re.compile(argv_bytes(parameter)))
1038             except re.error as ex:
1039                 fatal('invalid --exclude-rx pattern (%r): %s' % (parameter, ex))
1040         elif option == '--exclude-rx-from':
1041             try:
1042                 f = open(resolve_parent(parameter), 'rb')
1043             except IOError as e:
1044                 raise fatal("couldn't read %r" % parameter)
1045             for pattern in f.readlines():
1046                 spattern = pattern.rstrip(b'\n')
1047                 if not spattern:
1048                     continue
1049                 try:
1050                     excluded_patterns.append(re.compile(spattern))
1051                 except re.error as ex:
1052                     fatal('invalid --exclude-rx pattern (%r): %s' % (spattern, ex))
1053     return excluded_patterns
1054
1055
1056 def should_rx_exclude_path(path, exclude_rxs):
1057     """Return True if path matches a regular expression in exclude_rxs."""
1058     for rx in exclude_rxs:
1059         if rx.search(path):
1060             debug1('Skipping %r: excluded by rx pattern %r.\n'
1061                    % (path, rx.pattern))
1062             return True
1063     return False
1064
1065
1066 # FIXME: Carefully consider the use of functions (os.path.*, etc.)
1067 # that resolve against the current filesystem in the strip/graft
1068 # functions for example, but elsewhere as well.  I suspect bup's not
1069 # always being careful about that.  For some cases, the contents of
1070 # the current filesystem should be irrelevant, and consulting it might
1071 # produce the wrong result, perhaps via unintended symlink resolution,
1072 # for example.
1073
1074 def path_components(path):
1075     """Break path into a list of pairs of the form (name,
1076     full_path_to_name).  Path must start with '/'.
1077     Example:
1078       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
1079     if not path.startswith(b'/'):
1080         raise Exception('path must start with "/": %s' % path_msg(path))
1081     # Since we assume path startswith('/'), we can skip the first element.
1082     result = [(b'', b'/')]
1083     norm_path = os.path.abspath(path)
1084     if norm_path == b'/':
1085         return result
1086     full_path = b''
1087     for p in norm_path.split(b'/')[1:]:
1088         full_path += b'/' + p
1089         result.append((p, full_path))
1090     return result
1091
1092
1093 def stripped_path_components(path, strip_prefixes):
1094     """Strip any prefix in strip_prefixes from path and return a list
1095     of path components where each component is (name,
1096     none_or_full_fs_path_to_name).  Assume path startswith('/').
1097     See thelpers.py for examples."""
1098     normalized_path = os.path.abspath(path)
1099     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
1100     for bp in sorted_strip_prefixes:
1101         normalized_bp = os.path.abspath(bp)
1102         if normalized_bp == b'/':
1103             continue
1104         if normalized_path.startswith(normalized_bp):
1105             prefix = normalized_path[:len(normalized_bp)]
1106             result = []
1107             for p in normalized_path[len(normalized_bp):].split(b'/'):
1108                 if p: # not root
1109                     prefix += b'/'
1110                 prefix += p
1111                 result.append((p, prefix))
1112             return result
1113     # Nothing to strip.
1114     return path_components(path)
1115
1116
1117 def grafted_path_components(graft_points, path):
1118     # Create a result that consists of some number of faked graft
1119     # directories before the graft point, followed by all of the real
1120     # directories from path that are after the graft point.  Arrange
1121     # for the directory at the graft point in the result to correspond
1122     # to the "orig" directory in --graft orig=new.  See t/thelpers.py
1123     # for some examples.
1124
1125     # Note that given --graft orig=new, orig and new have *nothing* to
1126     # do with each other, even if some of their component names
1127     # match. i.e. --graft /foo/bar/baz=/foo/bar/bax is semantically
1128     # equivalent to --graft /foo/bar/baz=/x/y/z, or even
1129     # /foo/bar/baz=/x.
1130
1131     # FIXME: This can't be the best solution...
1132     clean_path = os.path.abspath(path)
1133     for graft_point in graft_points:
1134         old_prefix, new_prefix = graft_point
1135         # Expand prefixes iff not absolute paths.
1136         old_prefix = os.path.normpath(old_prefix)
1137         new_prefix = os.path.normpath(new_prefix)
1138         if clean_path.startswith(old_prefix):
1139             escaped_prefix = re.escape(old_prefix)
1140             grafted_path = re.sub(br'^' + escaped_prefix, new_prefix, clean_path)
1141             # Handle /foo=/ (at least) -- which produces //whatever.
1142             grafted_path = b'/' + grafted_path.lstrip(b'/')
1143             clean_path_components = path_components(clean_path)
1144             # Count the components that were stripped.
1145             strip_count = 0 if old_prefix == b'/' else old_prefix.count(b'/')
1146             new_prefix_parts = new_prefix.split(b'/')
1147             result_prefix = grafted_path.split(b'/')[:new_prefix.count(b'/')]
1148             result = [(p, None) for p in result_prefix] \
1149                 + clean_path_components[strip_count:]
1150             # Now set the graft point name to match the end of new_prefix.
1151             graft_point = len(result_prefix)
1152             result[graft_point] = \
1153                 (new_prefix_parts[-1], clean_path_components[strip_count][1])
1154             if new_prefix == b'/': # --graft ...=/ is a special case.
1155                 return result[1:]
1156             return result
1157     return path_components(clean_path)
1158
1159
1160 Sha1 = hashlib.sha1
1161
1162
1163 _localtime = getattr(_helpers, 'localtime', None)
1164
1165 if _localtime:
1166     bup_time = namedtuple('bup_time', ['tm_year', 'tm_mon', 'tm_mday',
1167                                        'tm_hour', 'tm_min', 'tm_sec',
1168                                        'tm_wday', 'tm_yday',
1169                                        'tm_isdst', 'tm_gmtoff', 'tm_zone'])
1170
1171 # Define a localtime() that returns bup_time when possible.  Note:
1172 # this means that any helpers.localtime() results may need to be
1173 # passed through to_py_time() before being passed to python's time
1174 # module, which doesn't appear willing to ignore the extra items.
1175 if _localtime:
1176     def localtime(time):
1177         return bup_time(*_helpers.localtime(int(floor(time))))
1178     def utc_offset_str(t):
1179         """Return the local offset from UTC as "+hhmm" or "-hhmm" for time t.
1180         If the current UTC offset does not represent an integer number
1181         of minutes, the fractional component will be truncated."""
1182         off = localtime(t).tm_gmtoff
1183         # Note: // doesn't truncate like C for negative values, it rounds down.
1184         offmin = abs(off) // 60
1185         m = offmin % 60
1186         h = (offmin - m) // 60
1187         return b'%+03d%02d' % (-h if off < 0 else h, m)
1188     def to_py_time(x):
1189         if isinstance(x, time.struct_time):
1190             return x
1191         return time.struct_time(x[:9])
1192 else:
1193     localtime = time.localtime
1194     def utc_offset_str(t):
1195         return time.strftime(b'%z', localtime(t))
1196     def to_py_time(x):
1197         return x
1198
1199
1200 _some_invalid_save_parts_rx = re.compile(br'[\[ ~^:?*\\]|\.\.|//|@{')
1201
1202 def valid_save_name(name):
1203     # Enforce a superset of the restrictions in git-check-ref-format(1)
1204     if name == b'@' \
1205        or name.startswith(b'/') or name.endswith(b'/') \
1206        or name.endswith(b'.'):
1207         return False
1208     if _some_invalid_save_parts_rx.search(name):
1209         return False
1210     for c in name:
1211         if byte_int(c) < 0x20 or byte_int(c) == 0x7f:
1212             return False
1213     for part in name.split(b'/'):
1214         if part.startswith(b'.') or part.endswith(b'.lock'):
1215             return False
1216     return True
1217
1218
1219 _period_rx = re.compile(br'^([0-9]+)(s|min|h|d|w|m|y)$')
1220
1221 def period_as_secs(s):
1222     if s == b'forever':
1223         return float('inf')
1224     match = _period_rx.match(s)
1225     if not match:
1226         return None
1227     mag = int(match.group(1))
1228     scale = match.group(2)
1229     return mag * {b's': 1,
1230                   b'min': 60,
1231                   b'h': 60 * 60,
1232                   b'd': 60 * 60 * 24,
1233                   b'w': 60 * 60 * 24 * 7,
1234                   b'm': 60 * 60 * 24 * 31,
1235                   b'y': 60 * 60 * 24 * 366}[scale]