]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Avoid partial writes of config/config.h.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time, platform
5 from bup import _version, _helpers
6 import bup._helpers as _helpers
7
8 # This function should really be in helpers, not in bup.options.  But we
9 # want options.py to be standalone so people can include it in other projects.
10 from bup.options import _tty_width
11 tty_width = _tty_width
12
13
14 def atoi(s):
15     """Convert the string 's' to an integer. Return 0 if s is not a number."""
16     try:
17         return int(s or '0')
18     except ValueError:
19         return 0
20
21
22 def atof(s):
23     """Convert the string 's' to a float. Return 0 if s is not a number."""
24     try:
25         return float(s or '0')
26     except ValueError:
27         return 0
28
29
30 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
31
32
33 # Write (blockingly) to sockets that may or may not be in blocking mode.
34 # We need this because our stderr is sometimes eaten by subprocesses
35 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
36 # leading to race conditions.  Ick.  We'll do it the hard way.
37 def _hard_write(fd, buf):
38     while buf:
39         (r,w,x) = select.select([], [fd], [], None)
40         if not w:
41             raise IOError('select(fd) returned without being writable')
42         try:
43             sz = os.write(fd, buf)
44         except OSError, e:
45             if e.errno != errno.EAGAIN:
46                 raise
47         assert(sz >= 0)
48         buf = buf[sz:]
49
50
51 _last_prog = 0
52 def log(s):
53     """Print a log message to stderr."""
54     global _last_prog
55     sys.stdout.flush()
56     _hard_write(sys.stderr.fileno(), s)
57     _last_prog = 0
58
59
60 def debug1(s):
61     if buglvl >= 1:
62         log(s)
63
64
65 def debug2(s):
66     if buglvl >= 2:
67         log(s)
68
69
70 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
71 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
72 _last_progress = ''
73 def progress(s):
74     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
75     global _last_progress
76     if istty2:
77         log(s)
78         _last_progress = s
79
80
81 def qprogress(s):
82     """Calls progress() only if we haven't printed progress in a while.
83     
84     This avoids overloading the stderr buffer with excess junk.
85     """
86     global _last_prog
87     now = time.time()
88     if now - _last_prog > 0.1:
89         progress(s)
90         _last_prog = now
91
92
93 def reprogress():
94     """Calls progress() to redisplay the most recent progress message.
95
96     Useful after you've printed some other message that wipes out the
97     progress line.
98     """
99     if _last_progress and _last_progress.endswith('\r'):
100         progress(_last_progress)
101
102
103 def mkdirp(d, mode=None):
104     """Recursively create directories on path 'd'.
105
106     Unlike os.makedirs(), it doesn't raise an exception if the last element of
107     the path already exists.
108     """
109     try:
110         if mode:
111             os.makedirs(d, mode)
112         else:
113             os.makedirs(d)
114     except OSError, e:
115         if e.errno == errno.EEXIST:
116             pass
117         else:
118             raise
119
120
121 def next(it):
122     """Get the next item from an iterator, None if we reached the end."""
123     try:
124         return it.next()
125     except StopIteration:
126         return None
127
128
129 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
130     if key:
131         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
132     else:
133         samekey = operator.eq
134     count = 0
135     total = sum(len(it) for it in iters)
136     iters = (iter(it) for it in iters)
137     heap = ((next(it),it) for it in iters)
138     heap = [(e,it) for e,it in heap if e]
139
140     heapq.heapify(heap)
141     pe = None
142     while heap:
143         if not count % pfreq:
144             pfunc(count, total)
145         e, it = heap[0]
146         if not samekey(e, pe):
147             pe = e
148             yield e
149         count += 1
150         try:
151             e = it.next() # Don't use next() function, it's too expensive
152         except StopIteration:
153             heapq.heappop(heap) # remove current
154         else:
155             heapq.heapreplace(heap, (e, it)) # shift current to new location
156     pfinal(count, total)
157
158
159 def unlink(f):
160     """Delete a file at path 'f' if it currently exists.
161
162     Unlike os.unlink(), does not throw an exception if the file didn't already
163     exist.
164     """
165     try:
166         os.unlink(f)
167     except OSError, e:
168         if e.errno == errno.ENOENT:
169             pass  # it doesn't exist, that's what you asked for
170
171
172 def readpipe(argv):
173     """Run a subprocess and return its output."""
174     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
175     r = p.stdout.read()
176     p.wait()
177     return r
178
179
180 def realpath(p):
181     """Get the absolute path of a file.
182
183     Behaves like os.path.realpath, but doesn't follow a symlink for the last
184     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
185     will follow symlinks in p's directory)
186     """
187     try:
188         st = os.lstat(p)
189     except OSError:
190         st = None
191     if st and stat.S_ISLNK(st.st_mode):
192         (dir, name) = os.path.split(p)
193         dir = os.path.realpath(dir)
194         out = os.path.join(dir, name)
195     else:
196         out = os.path.realpath(p)
197     #log('realpathing:%r,%r\n' % (p, out))
198     return out
199
200
201 def detect_fakeroot():
202     "Return True if we appear to be running under fakeroot."
203     return os.getenv("FAKEROOTKEY") != None
204
205
206 def is_superuser():
207     if platform.system().startswith('CYGWIN'):
208         import ctypes
209         return ctypes.cdll.shell32.IsUserAnAdmin()
210     else:
211         return os.geteuid() == 0
212
213
214 _username = None
215 def username():
216     """Get the user's login name."""
217     global _username
218     if not _username:
219         uid = os.getuid()
220         try:
221             _username = pwd.getpwuid(uid)[0]
222         except KeyError:
223             _username = 'user%d' % uid
224     return _username
225
226
227 _userfullname = None
228 def userfullname():
229     """Get the user's full name."""
230     global _userfullname
231     if not _userfullname:
232         uid = os.getuid()
233         try:
234             entry = pwd.getpwuid(uid)
235             _userfullname = entry[4].split(',')[0] or entry[0]
236         except KeyError:
237             pass
238         finally:
239             if not _userfullname:
240               _userfullname = 'user%d' % uid
241     return _userfullname
242
243
244 _hostname = None
245 def hostname():
246     """Get the FQDN of this machine."""
247     global _hostname
248     if not _hostname:
249         _hostname = socket.getfqdn()
250     return _hostname
251
252
253 _resource_path = None
254 def resource_path(subdir=''):
255     global _resource_path
256     if not _resource_path:
257         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
258     return os.path.join(_resource_path, subdir)
259
260
261 class NotOk(Exception):
262     pass
263
264
265 class BaseConn:
266     def __init__(self, outp):
267         self.outp = outp
268
269     def close(self):
270         while self._read(65536): pass
271
272     def read(self, size):
273         """Read 'size' bytes from input stream."""
274         self.outp.flush()
275         return self._read(size)
276
277     def readline(self):
278         """Read from input stream until a newline is found."""
279         self.outp.flush()
280         return self._readline()
281
282     def write(self, data):
283         """Write 'data' to output stream."""
284         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
285         self.outp.write(data)
286
287     def has_input(self):
288         """Return true if input stream is readable."""
289         raise NotImplemented("Subclasses must implement has_input")
290
291     def ok(self):
292         """Indicate end of output from last sent command."""
293         self.write('\nok\n')
294
295     def error(self, s):
296         """Indicate server error to the client."""
297         s = re.sub(r'\s+', ' ', str(s))
298         self.write('\nerror %s\n' % s)
299
300     def _check_ok(self, onempty):
301         self.outp.flush()
302         rl = ''
303         for rl in linereader(self):
304             #log('%d got line: %r\n' % (os.getpid(), rl))
305             if not rl:  # empty line
306                 continue
307             elif rl == 'ok':
308                 return None
309             elif rl.startswith('error '):
310                 #log('client: error: %s\n' % rl[6:])
311                 return NotOk(rl[6:])
312             else:
313                 onempty(rl)
314         raise Exception('server exited unexpectedly; see errors above')
315
316     def drain_and_check_ok(self):
317         """Remove all data for the current command from input stream."""
318         def onempty(rl):
319             pass
320         return self._check_ok(onempty)
321
322     def check_ok(self):
323         """Verify that server action completed successfully."""
324         def onempty(rl):
325             raise Exception('expected "ok", got %r' % rl)
326         return self._check_ok(onempty)
327
328
329 class Conn(BaseConn):
330     def __init__(self, inp, outp):
331         BaseConn.__init__(self, outp)
332         self.inp = inp
333
334     def _read(self, size):
335         return self.inp.read(size)
336
337     def _readline(self):
338         return self.inp.readline()
339
340     def has_input(self):
341         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
342         if rl:
343             assert(rl[0] == self.inp.fileno())
344             return True
345         else:
346             return None
347
348
349 def checked_reader(fd, n):
350     while n > 0:
351         rl, _, _ = select.select([fd], [], [])
352         assert(rl[0] == fd)
353         buf = os.read(fd, n)
354         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
355         yield buf
356         n -= len(buf)
357
358
359 MAX_PACKET = 128 * 1024
360 def mux(p, outfd, outr, errr):
361     try:
362         fds = [outr, errr]
363         while p.poll() is None:
364             rl, _, _ = select.select(fds, [], [])
365             for fd in rl:
366                 if fd == outr:
367                     buf = os.read(outr, MAX_PACKET)
368                     if not buf: break
369                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
370                 elif fd == errr:
371                     buf = os.read(errr, 1024)
372                     if not buf: break
373                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
374     finally:
375         os.write(outfd, struct.pack('!IB', 0, 3))
376
377
378 class DemuxConn(BaseConn):
379     """A helper class for bup's client-server protocol."""
380     def __init__(self, infd, outp):
381         BaseConn.__init__(self, outp)
382         # Anything that comes through before the sync string was not
383         # multiplexed and can be assumed to be debug/log before mux init.
384         tail = ''
385         while tail != 'BUPMUX':
386             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
387             if not b:
388                 raise IOError('demux: unexpected EOF during initialization')
389             tail += b
390             sys.stderr.write(tail[:-6])  # pre-mux log messages
391             tail = tail[-6:]
392         self.infd = infd
393         self.reader = None
394         self.buf = None
395         self.closed = False
396
397     def write(self, data):
398         self._load_buf(0)
399         BaseConn.write(self, data)
400
401     def _next_packet(self, timeout):
402         if self.closed: return False
403         rl, wl, xl = select.select([self.infd], [], [], timeout)
404         if not rl: return False
405         assert(rl[0] == self.infd)
406         ns = ''.join(checked_reader(self.infd, 5))
407         n, fdw = struct.unpack('!IB', ns)
408         assert(n <= MAX_PACKET)
409         if fdw == 1:
410             self.reader = checked_reader(self.infd, n)
411         elif fdw == 2:
412             for buf in checked_reader(self.infd, n):
413                 sys.stderr.write(buf)
414         elif fdw == 3:
415             self.closed = True
416             debug2("DemuxConn: marked closed\n")
417         return True
418
419     def _load_buf(self, timeout):
420         if self.buf is not None:
421             return True
422         while not self.closed:
423             while not self.reader:
424                 if not self._next_packet(timeout):
425                     return False
426             try:
427                 self.buf = self.reader.next()
428                 return True
429             except StopIteration:
430                 self.reader = None
431         return False
432
433     def _read_parts(self, ix_fn):
434         while self._load_buf(None):
435             assert(self.buf is not None)
436             i = ix_fn(self.buf)
437             if i is None or i == len(self.buf):
438                 yv = self.buf
439                 self.buf = None
440             else:
441                 yv = self.buf[:i]
442                 self.buf = self.buf[i:]
443             yield yv
444             if i is not None:
445                 break
446
447     def _readline(self):
448         def find_eol(buf):
449             try:
450                 return buf.index('\n')+1
451             except ValueError:
452                 return None
453         return ''.join(self._read_parts(find_eol))
454
455     def _read(self, size):
456         csize = [size]
457         def until_size(buf): # Closes on csize
458             if len(buf) < csize[0]:
459                 csize[0] -= len(buf)
460                 return None
461             else:
462                 return csize[0]
463         return ''.join(self._read_parts(until_size))
464
465     def has_input(self):
466         return self._load_buf(0)
467
468
469 def linereader(f):
470     """Generate a list of input lines from 'f' without terminating newlines."""
471     while 1:
472         line = f.readline()
473         if not line:
474             break
475         yield line[:-1]
476
477
478 def chunkyreader(f, count = None):
479     """Generate a list of chunks of data read from 'f'.
480
481     If count is None, read until EOF is reached.
482
483     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
484     reached while reading, raise IOError.
485     """
486     if count != None:
487         while count > 0:
488             b = f.read(min(count, 65536))
489             if not b:
490                 raise IOError('EOF with %d bytes remaining' % count)
491             yield b
492             count -= len(b)
493     else:
494         while 1:
495             b = f.read(65536)
496             if not b: break
497             yield b
498
499
500 def slashappend(s):
501     """Append "/" to 's' if it doesn't aleady end in "/"."""
502     if s and not s.endswith('/'):
503         return s + '/'
504     else:
505         return s
506
507
508 def _mmap_do(f, sz, flags, prot, close):
509     if not sz:
510         st = os.fstat(f.fileno())
511         sz = st.st_size
512     if not sz:
513         # trying to open a zero-length map gives an error, but an empty
514         # string has all the same behaviour of a zero-length map, ie. it has
515         # no elements :)
516         return ''
517     map = mmap.mmap(f.fileno(), sz, flags, prot)
518     if close:
519         f.close()  # map will persist beyond file close
520     return map
521
522
523 def mmap_read(f, sz = 0, close=True):
524     """Create a read-only memory mapped region on file 'f'.
525     If sz is 0, the region will cover the entire file.
526     """
527     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
528
529
530 def mmap_readwrite(f, sz = 0, close=True):
531     """Create a read-write memory mapped region on file 'f'.
532     If sz is 0, the region will cover the entire file.
533     """
534     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
535                     close)
536
537
538 def mmap_readwrite_private(f, sz = 0, close=True):
539     """Create a read-write memory mapped region on file 'f'.
540     If sz is 0, the region will cover the entire file.
541     The map is private, which means the changes are never flushed back to the
542     file.
543     """
544     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
545                     close)
546
547
548 def parse_num(s):
549     """Parse data size information into a float number.
550
551     Here are some examples of conversions:
552         199.2k means 203981 bytes
553         1GB means 1073741824 bytes
554         2.1 tb means 2199023255552 bytes
555     """
556     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
557     if not g:
558         raise ValueError("can't parse %r as a number" % s)
559     (val, unit) = g.groups()
560     num = float(val)
561     unit = unit.lower()
562     if unit in ['t', 'tb']:
563         mult = 1024*1024*1024*1024
564     elif unit in ['g', 'gb']:
565         mult = 1024*1024*1024
566     elif unit in ['m', 'mb']:
567         mult = 1024*1024
568     elif unit in ['k', 'kb']:
569         mult = 1024
570     elif unit in ['', 'b']:
571         mult = 1
572     else:
573         raise ValueError("invalid unit %r in number %r" % (unit, s))
574     return int(num*mult)
575
576
577 def count(l):
578     """Count the number of elements in an iterator. (consumes the iterator)"""
579     return reduce(lambda x,y: x+1, l)
580
581
582 saved_errors = []
583 def add_error(e):
584     """Append an error message to the list of saved errors.
585
586     Once processing is able to stop and output the errors, the saved errors are
587     accessible in the module variable helpers.saved_errors.
588     """
589     saved_errors.append(e)
590     log('%-70s\n' % e)
591
592
593 def clear_errors():
594     global saved_errors
595     saved_errors = []
596
597
598 def handle_ctrl_c():
599     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
600
601     The new exception handler will make sure that bup will exit without an ugly
602     stacktrace when Ctrl-C is hit.
603     """
604     oldhook = sys.excepthook
605     def newhook(exctype, value, traceback):
606         if exctype == KeyboardInterrupt:
607             log('Interrupted.\n')
608         else:
609             return oldhook(exctype, value, traceback)
610     sys.excepthook = newhook
611
612
613 def columnate(l, prefix):
614     """Format elements of 'l' in columns with 'prefix' leading each line.
615
616     The number of columns is determined automatically based on the string
617     lengths.
618     """
619     if not l:
620         return ""
621     l = l[:]
622     clen = max(len(s) for s in l)
623     ncols = (tty_width() - len(prefix)) / (clen + 2)
624     if ncols <= 1:
625         ncols = 1
626         clen = 0
627     cols = []
628     while len(l) % ncols:
629         l.append('')
630     rows = len(l)/ncols
631     for s in range(0, len(l), rows):
632         cols.append(l[s:s+rows])
633     out = ''
634     for row in zip(*cols):
635         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
636     return out
637
638
639 def parse_date_or_fatal(str, fatal):
640     """Parses the given date or calls Option.fatal().
641     For now we expect a string that contains a float."""
642     try:
643         date = atof(str)
644     except ValueError, e:
645         raise fatal('invalid date format (should be a float): %r' % e)
646     else:
647         return date
648
649
650 def path_components(path):
651     """Break path into a list of pairs of the form (name,
652     full_path_to_name).  Path must start with '/'.
653     Example:
654       '/home/foo' -> [('', '/'), ('home', '/home'), ('foo', '/home/foo')]"""
655     assert(path.startswith('/'))
656     # Since we assume path startswith('/'), we can skip the first element.
657     result = [('', '/')]
658     norm_path = os.path.abspath(path)
659     if norm_path == '/':
660         return result
661     full_path = ''
662     for p in norm_path.split('/')[1:]:
663         full_path += '/' + p
664         result.append((p, full_path))
665     return result
666
667
668 def stripped_path_components(path, strip_prefixes):
669     """Strip any prefix in strip_prefixes from path and return a list
670     of path components where each component is (name,
671     none_or_full_fs_path_to_name).  Assume path startswith('/').
672     See thelpers.py for examples."""
673     normalized_path = os.path.abspath(path)
674     sorted_strip_prefixes = sorted(strip_prefixes, key=len, reverse=True)
675     for bp in sorted_strip_prefixes:
676         normalized_bp = os.path.abspath(bp)
677         if normalized_path.startswith(normalized_bp):
678             prefix = normalized_path[:len(normalized_bp)]
679             result = []
680             for p in normalized_path[len(normalized_bp):].split('/'):
681                 if p: # not root
682                     prefix += '/'
683                 prefix += p
684                 result.append((p, prefix))
685             return result
686     # Nothing to strip.
687     return path_components(path)
688
689
690 def grafted_path_components(graft_points, path):
691     # Find the first '/' after the graft prefix, match that to the
692     # original source base dir, then move on.
693     clean_path = os.path.abspath(path)
694     for graft_point in graft_points:
695         old_prefix, new_prefix = graft_point
696         if clean_path.startswith(old_prefix):
697             grafted_path = re.sub(r'^' + old_prefix, new_prefix,
698                                   clean_path)
699             result = [(p, None) for p in grafted_path.split('/')]
700             result[-1] = (result[-1][0], clean_path)
701             return result
702     return path_components(clean_path)
703
704
705 # hashlib is only available in python 2.5 or higher, but the 'sha' module
706 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
707 # python 2.4 and above without any stupid warnings, so let's try using hashlib
708 # first, and downgrade if it fails.
709 try:
710     import hashlib
711 except ImportError:
712     import sha
713     Sha1 = sha.sha
714 else:
715     Sha1 = hashlib.sha1
716
717
718 def version_date():
719     """Format bup's version date string for output."""
720     return _version.DATE.split(' ')[0]
721
722
723 def version_commit():
724     """Get the commit hash of bup's current version."""
725     return _version.COMMIT
726
727
728 def version_tag():
729     """Format bup's version tag (the official version number).
730
731     When generated from a commit other than one pointed to with a tag, the
732     returned string will be "unknown-" followed by the first seven positions of
733     the commit hash.
734     """
735     names = _version.NAMES.strip()
736     assert(names[0] == '(')
737     assert(names[-1] == ')')
738     names = names[1:-1]
739     l = [n.strip() for n in names.split(',')]
740     for n in l:
741         if n.startswith('tag: bup-'):
742             return n[9:]
743     return 'unknown-%s' % _version.COMMIT[:7]