]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Merge branch 'master' into meta
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time
5 from bup import _version, _helpers
6 import bup._helpers as _helpers
7
8 # This function should really be in helpers, not in bup.options.  But we
9 # want options.py to be standalone so people can include it in other projects.
10 from bup.options import _tty_width
11 tty_width = _tty_width
12
13
14 def atoi(s):
15     """Convert the string 's' to an integer. Return 0 if s is not a number."""
16     try:
17         return int(s or '0')
18     except ValueError:
19         return 0
20
21
22 def atof(s):
23     """Convert the string 's' to a float. Return 0 if s is not a number."""
24     try:
25         return float(s or '0')
26     except ValueError:
27         return 0
28
29
30 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
31
32
33 # Write (blockingly) to sockets that may or may not be in blocking mode.
34 # We need this because our stderr is sometimes eaten by subprocesses
35 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
36 # leading to race conditions.  Ick.  We'll do it the hard way.
37 def _hard_write(fd, buf):
38     while buf:
39         (r,w,x) = select.select([], [fd], [], None)
40         if not w:
41             raise IOError('select(fd) returned without being writable')
42         try:
43             sz = os.write(fd, buf)
44         except OSError, e:
45             if e.errno != errno.EAGAIN:
46                 raise
47         assert(sz >= 0)
48         buf = buf[sz:]
49
50
51 _last_prog = 0
52 def log(s):
53     """Print a log message to stderr."""
54     global _last_prog
55     sys.stdout.flush()
56     _hard_write(sys.stderr.fileno(), s)
57     _last_prog = 0
58
59
60 def debug1(s):
61     if buglvl >= 1:
62         log(s)
63
64
65 def debug2(s):
66     if buglvl >= 2:
67         log(s)
68
69
70 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
71 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
72 _last_progress = ''
73 def progress(s):
74     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
75     global _last_progress
76     if istty2:
77         log(s)
78         _last_progress = s
79
80
81 def qprogress(s):
82     """Calls progress() only if we haven't printed progress in a while.
83     
84     This avoids overloading the stderr buffer with excess junk.
85     """
86     global _last_prog
87     now = time.time()
88     if now - _last_prog > 0.1:
89         progress(s)
90         _last_prog = now
91
92
93 def reprogress():
94     """Calls progress() to redisplay the most recent progress message.
95
96     Useful after you've printed some other message that wipes out the
97     progress line.
98     """
99     if _last_progress and _last_progress.endswith('\r'):
100         progress(_last_progress)
101
102
103 def mkdirp(d, mode=None):
104     """Recursively create directories on path 'd'.
105
106     Unlike os.makedirs(), it doesn't raise an exception if the last element of
107     the path already exists.
108     """
109     try:
110         if mode:
111             os.makedirs(d, mode)
112         else:
113             os.makedirs(d)
114     except OSError, e:
115         if e.errno == errno.EEXIST:
116             pass
117         else:
118             raise
119
120
121 def next(it):
122     """Get the next item from an iterator, None if we reached the end."""
123     try:
124         return it.next()
125     except StopIteration:
126         return None
127
128
129 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
130     if key:
131         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
132     else:
133         samekey = operator.eq
134     count = 0
135     total = sum(len(it) for it in iters)
136     iters = (iter(it) for it in iters)
137     heap = ((next(it),it) for it in iters)
138     heap = [(e,it) for e,it in heap if e]
139
140     heapq.heapify(heap)
141     pe = None
142     while heap:
143         if not count % pfreq:
144             pfunc(count, total)
145         e, it = heap[0]
146         if not samekey(e, pe):
147             pe = e
148             yield e
149         count += 1
150         try:
151             e = it.next() # Don't use next() function, it's too expensive
152         except StopIteration:
153             heapq.heappop(heap) # remove current
154         else:
155             heapq.heapreplace(heap, (e, it)) # shift current to new location
156     pfinal(count, total)
157
158
159 def unlink(f):
160     """Delete a file at path 'f' if it currently exists.
161
162     Unlike os.unlink(), does not throw an exception if the file didn't already
163     exist.
164     """
165     try:
166         os.unlink(f)
167     except OSError, e:
168         if e.errno == errno.ENOENT:
169             pass  # it doesn't exist, that's what you asked for
170
171
172 def readpipe(argv):
173     """Run a subprocess and return its output."""
174     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
175     r = p.stdout.read()
176     p.wait()
177     return r
178
179
180 def realpath(p):
181     """Get the absolute path of a file.
182
183     Behaves like os.path.realpath, but doesn't follow a symlink for the last
184     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
185     will follow symlinks in p's directory)
186     """
187     try:
188         st = os.lstat(p)
189     except OSError:
190         st = None
191     if st and stat.S_ISLNK(st.st_mode):
192         (dir, name) = os.path.split(p)
193         dir = os.path.realpath(dir)
194         out = os.path.join(dir, name)
195     else:
196         out = os.path.realpath(p)
197     #log('realpathing:%r,%r\n' % (p, out))
198     return out
199
200
201 def detect_fakeroot():
202     "Return True if we appear to be running under fakeroot."
203     return os.getenv("FAKEROOTKEY") != None
204
205
206 _username = None
207 def username():
208     """Get the user's login name."""
209     global _username
210     if not _username:
211         uid = os.getuid()
212         try:
213             _username = pwd.getpwuid(uid)[0]
214         except KeyError:
215             _username = 'user%d' % uid
216     return _username
217
218
219 _userfullname = None
220 def userfullname():
221     """Get the user's full name."""
222     global _userfullname
223     if not _userfullname:
224         uid = os.getuid()
225         try:
226             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
227         except KeyError:
228             _userfullname = 'user%d' % uid
229     return _userfullname
230
231
232 _hostname = None
233 def hostname():
234     """Get the FQDN of this machine."""
235     global _hostname
236     if not _hostname:
237         _hostname = socket.getfqdn()
238     return _hostname
239
240
241 _resource_path = None
242 def resource_path(subdir=''):
243     global _resource_path
244     if not _resource_path:
245         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
246     return os.path.join(_resource_path, subdir)
247
248
249 class NotOk(Exception):
250     pass
251
252
253 class BaseConn:
254     def __init__(self, outp):
255         self.outp = outp
256
257     def close(self):
258         while self._read(65536): pass
259
260     def read(self, size):
261         """Read 'size' bytes from input stream."""
262         self.outp.flush()
263         return self._read(size)
264
265     def readline(self):
266         """Read from input stream until a newline is found."""
267         self.outp.flush()
268         return self._readline()
269
270     def write(self, data):
271         """Write 'data' to output stream."""
272         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
273         self.outp.write(data)
274
275     def has_input(self):
276         """Return true if input stream is readable."""
277         raise NotImplemented("Subclasses must implement has_input")
278
279     def ok(self):
280         """Indicate end of output from last sent command."""
281         self.write('\nok\n')
282
283     def error(self, s):
284         """Indicate server error to the client."""
285         s = re.sub(r'\s+', ' ', str(s))
286         self.write('\nerror %s\n' % s)
287
288     def _check_ok(self, onempty):
289         self.outp.flush()
290         rl = ''
291         for rl in linereader(self):
292             #log('%d got line: %r\n' % (os.getpid(), rl))
293             if not rl:  # empty line
294                 continue
295             elif rl == 'ok':
296                 return None
297             elif rl.startswith('error '):
298                 #log('client: error: %s\n' % rl[6:])
299                 return NotOk(rl[6:])
300             else:
301                 onempty(rl)
302         raise Exception('server exited unexpectedly; see errors above')
303
304     def drain_and_check_ok(self):
305         """Remove all data for the current command from input stream."""
306         def onempty(rl):
307             pass
308         return self._check_ok(onempty)
309
310     def check_ok(self):
311         """Verify that server action completed successfully."""
312         def onempty(rl):
313             raise Exception('expected "ok", got %r' % rl)
314         return self._check_ok(onempty)
315
316
317 class Conn(BaseConn):
318     def __init__(self, inp, outp):
319         BaseConn.__init__(self, outp)
320         self.inp = inp
321
322     def _read(self, size):
323         return self.inp.read(size)
324
325     def _readline(self):
326         return self.inp.readline()
327
328     def has_input(self):
329         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
330         if rl:
331             assert(rl[0] == self.inp.fileno())
332             return True
333         else:
334             return None
335
336
337 def checked_reader(fd, n):
338     while n > 0:
339         rl, _, _ = select.select([fd], [], [])
340         assert(rl[0] == fd)
341         buf = os.read(fd, n)
342         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
343         yield buf
344         n -= len(buf)
345
346
347 MAX_PACKET = 128 * 1024
348 def mux(p, outfd, outr, errr):
349     try:
350         fds = [outr, errr]
351         while p.poll() is None:
352             rl, _, _ = select.select(fds, [], [])
353             for fd in rl:
354                 if fd == outr:
355                     buf = os.read(outr, MAX_PACKET)
356                     if not buf: break
357                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
358                 elif fd == errr:
359                     buf = os.read(errr, 1024)
360                     if not buf: break
361                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
362     finally:
363         os.write(outfd, struct.pack('!IB', 0, 3))
364
365
366 class DemuxConn(BaseConn):
367     """A helper class for bup's client-server protocol."""
368     def __init__(self, infd, outp):
369         BaseConn.__init__(self, outp)
370         # Anything that comes through before the sync string was not
371         # multiplexed and can be assumed to be debug/log before mux init.
372         tail = ''
373         while tail != 'BUPMUX':
374             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
375             if not b:
376                 raise IOError('demux: unexpected EOF during initialization')
377             tail += b
378             sys.stderr.write(tail[:-6])  # pre-mux log messages
379             tail = tail[-6:]
380         self.infd = infd
381         self.reader = None
382         self.buf = None
383         self.closed = False
384
385     def write(self, data):
386         self._load_buf(0)
387         BaseConn.write(self, data)
388
389     def _next_packet(self, timeout):
390         if self.closed: return False
391         rl, wl, xl = select.select([self.infd], [], [], timeout)
392         if not rl: return False
393         assert(rl[0] == self.infd)
394         ns = ''.join(checked_reader(self.infd, 5))
395         n, fdw = struct.unpack('!IB', ns)
396         assert(n <= MAX_PACKET)
397         if fdw == 1:
398             self.reader = checked_reader(self.infd, n)
399         elif fdw == 2:
400             for buf in checked_reader(self.infd, n):
401                 sys.stderr.write(buf)
402         elif fdw == 3:
403             self.closed = True
404             debug2("DemuxConn: marked closed\n")
405         return True
406
407     def _load_buf(self, timeout):
408         if self.buf is not None:
409             return True
410         while not self.closed:
411             while not self.reader:
412                 if not self._next_packet(timeout):
413                     return False
414             try:
415                 self.buf = self.reader.next()
416                 return True
417             except StopIteration:
418                 self.reader = None
419         return False
420
421     def _read_parts(self, ix_fn):
422         while self._load_buf(None):
423             assert(self.buf is not None)
424             i = ix_fn(self.buf)
425             if i is None or i == len(self.buf):
426                 yv = self.buf
427                 self.buf = None
428             else:
429                 yv = self.buf[:i]
430                 self.buf = self.buf[i:]
431             yield yv
432             if i is not None:
433                 break
434
435     def _readline(self):
436         def find_eol(buf):
437             try:
438                 return buf.index('\n')+1
439             except ValueError:
440                 return None
441         return ''.join(self._read_parts(find_eol))
442
443     def _read(self, size):
444         csize = [size]
445         def until_size(buf): # Closes on csize
446             if len(buf) < csize[0]:
447                 csize[0] -= len(buf)
448                 return None
449             else:
450                 return csize[0]
451         return ''.join(self._read_parts(until_size))
452
453     def has_input(self):
454         return self._load_buf(0)
455
456
457 def linereader(f):
458     """Generate a list of input lines from 'f' without terminating newlines."""
459     while 1:
460         line = f.readline()
461         if not line:
462             break
463         yield line[:-1]
464
465
466 def chunkyreader(f, count = None):
467     """Generate a list of chunks of data read from 'f'.
468
469     If count is None, read until EOF is reached.
470
471     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
472     reached while reading, raise IOError.
473     """
474     if count != None:
475         while count > 0:
476             b = f.read(min(count, 65536))
477             if not b:
478                 raise IOError('EOF with %d bytes remaining' % count)
479             yield b
480             count -= len(b)
481     else:
482         while 1:
483             b = f.read(65536)
484             if not b: break
485             yield b
486
487
488 def slashappend(s):
489     """Append "/" to 's' if it doesn't aleady end in "/"."""
490     if s and not s.endswith('/'):
491         return s + '/'
492     else:
493         return s
494
495
496 def _mmap_do(f, sz, flags, prot, close):
497     if not sz:
498         st = os.fstat(f.fileno())
499         sz = st.st_size
500     if not sz:
501         # trying to open a zero-length map gives an error, but an empty
502         # string has all the same behaviour of a zero-length map, ie. it has
503         # no elements :)
504         return ''
505     map = mmap.mmap(f.fileno(), sz, flags, prot)
506     if close:
507         f.close()  # map will persist beyond file close
508     return map
509
510
511 def mmap_read(f, sz = 0, close=True):
512     """Create a read-only memory mapped region on file 'f'.
513     If sz is 0, the region will cover the entire file.
514     """
515     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
516
517
518 def mmap_readwrite(f, sz = 0, close=True):
519     """Create a read-write memory mapped region on file 'f'.
520     If sz is 0, the region will cover the entire file.
521     """
522     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
523                     close)
524
525
526 def mmap_readwrite_private(f, sz = 0, close=True):
527     """Create a read-write memory mapped region on file 'f'.
528     If sz is 0, the region will cover the entire file.
529     The map is private, which means the changes are never flushed back to the
530     file.
531     """
532     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
533                     close)
534
535
536 def parse_num(s):
537     """Parse data size information into a float number.
538
539     Here are some examples of conversions:
540         199.2k means 203981 bytes
541         1GB means 1073741824 bytes
542         2.1 tb means 2199023255552 bytes
543     """
544     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
545     if not g:
546         raise ValueError("can't parse %r as a number" % s)
547     (val, unit) = g.groups()
548     num = float(val)
549     unit = unit.lower()
550     if unit in ['t', 'tb']:
551         mult = 1024*1024*1024*1024
552     elif unit in ['g', 'gb']:
553         mult = 1024*1024*1024
554     elif unit in ['m', 'mb']:
555         mult = 1024*1024
556     elif unit in ['k', 'kb']:
557         mult = 1024
558     elif unit in ['', 'b']:
559         mult = 1
560     else:
561         raise ValueError("invalid unit %r in number %r" % (unit, s))
562     return int(num*mult)
563
564
565 def count(l):
566     """Count the number of elements in an iterator. (consumes the iterator)"""
567     return reduce(lambda x,y: x+1, l)
568
569
570 saved_errors = []
571 def add_error(e):
572     """Append an error message to the list of saved errors.
573
574     Once processing is able to stop and output the errors, the saved errors are
575     accessible in the module variable helpers.saved_errors.
576     """
577     saved_errors.append(e)
578     log('%-70s\n' % e)
579
580
581 def clear_errors():
582     global saved_errors
583     saved_errors = []
584
585
586 def handle_ctrl_c():
587     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
588
589     The new exception handler will make sure that bup will exit without an ugly
590     stacktrace when Ctrl-C is hit.
591     """
592     oldhook = sys.excepthook
593     def newhook(exctype, value, traceback):
594         if exctype == KeyboardInterrupt:
595             log('Interrupted.\n')
596         else:
597             return oldhook(exctype, value, traceback)
598     sys.excepthook = newhook
599
600
601 def columnate(l, prefix):
602     """Format elements of 'l' in columns with 'prefix' leading each line.
603
604     The number of columns is determined automatically based on the string
605     lengths.
606     """
607     if not l:
608         return ""
609     l = l[:]
610     clen = max(len(s) for s in l)
611     ncols = (tty_width() - len(prefix)) / (clen + 2)
612     if ncols <= 1:
613         ncols = 1
614         clen = 0
615     cols = []
616     while len(l) % ncols:
617         l.append('')
618     rows = len(l)/ncols
619     for s in range(0, len(l), rows):
620         cols.append(l[s:s+rows])
621     out = ''
622     for row in zip(*cols):
623         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
624     return out
625
626
627 def parse_date_or_fatal(str, fatal):
628     """Parses the given date or calls Option.fatal().
629     For now we expect a string that contains a float."""
630     try:
631         date = atof(str)
632     except ValueError, e:
633         raise fatal('invalid date format (should be a float): %r' % e)
634     else:
635         return date
636
637
638 def strip_path(prefix, path):
639     """Strips a given prefix from a path.
640
641     First both paths are normalized.
642
643     Raises an Exception if no prefix is given.
644     """
645     if prefix == None:
646         raise Exception('no path given')
647
648     normalized_prefix = os.path.realpath(prefix)
649     debug2("normalized_prefix: %s\n" % normalized_prefix)
650     normalized_path = os.path.realpath(path)
651     debug2("normalized_path: %s\n" % normalized_path)
652     if normalized_path.startswith(normalized_prefix):
653         return normalized_path[len(normalized_prefix):]
654     else:
655         return path
656
657
658 def strip_base_path(path, base_paths):
659     """Strips the base path from a given path.
660
661
662     Determines the base path for the given string and then strips it
663     using strip_path().
664     Iterates over all base_paths from long to short, to prevent that
665     a too short base_path is removed.
666     """
667     normalized_path = os.path.realpath(path)
668     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
669     for bp in sorted_base_paths:
670         if normalized_path.startswith(os.path.realpath(bp)):
671             return strip_path(bp, normalized_path)
672     return path
673
674
675 def graft_path(graft_points, path):
676     normalized_path = os.path.realpath(path)
677     for graft_point in graft_points:
678         old_prefix, new_prefix = graft_point
679         if normalized_path.startswith(old_prefix):
680             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
681     return normalized_path
682
683
684 # hashlib is only available in python 2.5 or higher, but the 'sha' module
685 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
686 # python 2.4 and above without any stupid warnings, so let's try using hashlib
687 # first, and downgrade if it fails.
688 try:
689     import hashlib
690 except ImportError:
691     import sha
692     Sha1 = sha.sha
693 else:
694     Sha1 = hashlib.sha1
695
696
697 def version_date():
698     """Format bup's version date string for output."""
699     return _version.DATE.split(' ')[0]
700
701
702 def version_commit():
703     """Get the commit hash of bup's current version."""
704     return _version.COMMIT
705
706
707 def version_tag():
708     """Format bup's version tag (the official version number).
709
710     When generated from a commit other than one pointed to with a tag, the
711     returned string will be "unknown-" followed by the first seven positions of
712     the commit hash.
713     """
714     names = _version.NAMES.strip()
715     assert(names[0] == '(')
716     assert(names[-1] == ')')
717     names = names[1:-1]
718     l = [n.strip() for n in names.split(',')]
719     for n in l:
720         if n.startswith('tag: bup-'):
721             return n[9:]
722     return 'unknown-%s' % _version.COMMIT[:7]