]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Merge branch 'master' into config
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time, platform
5 from bup import _version, _helpers
6 import bup._helpers as _helpers
7
8 # This function should really be in helpers, not in bup.options.  But we
9 # want options.py to be standalone so people can include it in other projects.
10 from bup.options import _tty_width
11 tty_width = _tty_width
12
13
14 def atoi(s):
15     """Convert the string 's' to an integer. Return 0 if s is not a number."""
16     try:
17         return int(s or '0')
18     except ValueError:
19         return 0
20
21
22 def atof(s):
23     """Convert the string 's' to a float. Return 0 if s is not a number."""
24     try:
25         return float(s or '0')
26     except ValueError:
27         return 0
28
29
30 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
31
32
33 # Write (blockingly) to sockets that may or may not be in blocking mode.
34 # We need this because our stderr is sometimes eaten by subprocesses
35 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
36 # leading to race conditions.  Ick.  We'll do it the hard way.
37 def _hard_write(fd, buf):
38     while buf:
39         (r,w,x) = select.select([], [fd], [], None)
40         if not w:
41             raise IOError('select(fd) returned without being writable')
42         try:
43             sz = os.write(fd, buf)
44         except OSError, e:
45             if e.errno != errno.EAGAIN:
46                 raise
47         assert(sz >= 0)
48         buf = buf[sz:]
49
50
51 _last_prog = 0
52 def log(s):
53     """Print a log message to stderr."""
54     global _last_prog
55     sys.stdout.flush()
56     _hard_write(sys.stderr.fileno(), s)
57     _last_prog = 0
58
59
60 def debug1(s):
61     if buglvl >= 1:
62         log(s)
63
64
65 def debug2(s):
66     if buglvl >= 2:
67         log(s)
68
69
70 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
71 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
72 _last_progress = ''
73 def progress(s):
74     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
75     global _last_progress
76     if istty2:
77         log(s)
78         _last_progress = s
79
80
81 def qprogress(s):
82     """Calls progress() only if we haven't printed progress in a while.
83     
84     This avoids overloading the stderr buffer with excess junk.
85     """
86     global _last_prog
87     now = time.time()
88     if now - _last_prog > 0.1:
89         progress(s)
90         _last_prog = now
91
92
93 def reprogress():
94     """Calls progress() to redisplay the most recent progress message.
95
96     Useful after you've printed some other message that wipes out the
97     progress line.
98     """
99     if _last_progress and _last_progress.endswith('\r'):
100         progress(_last_progress)
101
102
103 def mkdirp(d, mode=None):
104     """Recursively create directories on path 'd'.
105
106     Unlike os.makedirs(), it doesn't raise an exception if the last element of
107     the path already exists.
108     """
109     try:
110         if mode:
111             os.makedirs(d, mode)
112         else:
113             os.makedirs(d)
114     except OSError, e:
115         if e.errno == errno.EEXIST:
116             pass
117         else:
118             raise
119
120
121 def next(it):
122     """Get the next item from an iterator, None if we reached the end."""
123     try:
124         return it.next()
125     except StopIteration:
126         return None
127
128
129 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
130     if key:
131         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
132     else:
133         samekey = operator.eq
134     count = 0
135     total = sum(len(it) for it in iters)
136     iters = (iter(it) for it in iters)
137     heap = ((next(it),it) for it in iters)
138     heap = [(e,it) for e,it in heap if e]
139
140     heapq.heapify(heap)
141     pe = None
142     while heap:
143         if not count % pfreq:
144             pfunc(count, total)
145         e, it = heap[0]
146         if not samekey(e, pe):
147             pe = e
148             yield e
149         count += 1
150         try:
151             e = it.next() # Don't use next() function, it's too expensive
152         except StopIteration:
153             heapq.heappop(heap) # remove current
154         else:
155             heapq.heapreplace(heap, (e, it)) # shift current to new location
156     pfinal(count, total)
157
158
159 def unlink(f):
160     """Delete a file at path 'f' if it currently exists.
161
162     Unlike os.unlink(), does not throw an exception if the file didn't already
163     exist.
164     """
165     try:
166         os.unlink(f)
167     except OSError, e:
168         if e.errno == errno.ENOENT:
169             pass  # it doesn't exist, that's what you asked for
170
171
172 def readpipe(argv):
173     """Run a subprocess and return its output."""
174     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
175     r = p.stdout.read()
176     p.wait()
177     return r
178
179
180 def realpath(p):
181     """Get the absolute path of a file.
182
183     Behaves like os.path.realpath, but doesn't follow a symlink for the last
184     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
185     will follow symlinks in p's directory)
186     """
187     try:
188         st = os.lstat(p)
189     except OSError:
190         st = None
191     if st and stat.S_ISLNK(st.st_mode):
192         (dir, name) = os.path.split(p)
193         dir = os.path.realpath(dir)
194         out = os.path.join(dir, name)
195     else:
196         out = os.path.realpath(p)
197     #log('realpathing:%r,%r\n' % (p, out))
198     return out
199
200
201 def detect_fakeroot():
202     "Return True if we appear to be running under fakeroot."
203     return os.getenv("FAKEROOTKEY") != None
204
205
206 def is_superuser():
207     if platform.system().startswith('CYGWIN'):
208         import ctypes
209         return ctypes.cdll.shell32.IsUserAnAdmin()
210     else:
211         return os.geteuid() == 0
212
213
214 _username = None
215 def username():
216     """Get the user's login name."""
217     global _username
218     if not _username:
219         uid = os.getuid()
220         try:
221             _username = pwd.getpwuid(uid)[0]
222         except KeyError:
223             _username = 'user%d' % uid
224     return _username
225
226
227 _userfullname = None
228 def userfullname():
229     """Get the user's full name."""
230     global _userfullname
231     if not _userfullname:
232         uid = os.getuid()
233         try:
234             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
235         except KeyError:
236             _userfullname = 'user%d' % uid
237     return _userfullname
238
239
240 _hostname = None
241 def hostname():
242     """Get the FQDN of this machine."""
243     global _hostname
244     if not _hostname:
245         _hostname = socket.getfqdn()
246     return _hostname
247
248
249 _resource_path = None
250 def resource_path(subdir=''):
251     global _resource_path
252     if not _resource_path:
253         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
254     return os.path.join(_resource_path, subdir)
255
256
257 class NotOk(Exception):
258     pass
259
260
261 class BaseConn:
262     def __init__(self, outp):
263         self.outp = outp
264
265     def close(self):
266         while self._read(65536): pass
267
268     def read(self, size):
269         """Read 'size' bytes from input stream."""
270         self.outp.flush()
271         return self._read(size)
272
273     def readline(self):
274         """Read from input stream until a newline is found."""
275         self.outp.flush()
276         return self._readline()
277
278     def write(self, data):
279         """Write 'data' to output stream."""
280         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
281         self.outp.write(data)
282
283     def has_input(self):
284         """Return true if input stream is readable."""
285         raise NotImplemented("Subclasses must implement has_input")
286
287     def ok(self):
288         """Indicate end of output from last sent command."""
289         self.write('\nok\n')
290
291     def error(self, s):
292         """Indicate server error to the client."""
293         s = re.sub(r'\s+', ' ', str(s))
294         self.write('\nerror %s\n' % s)
295
296     def _check_ok(self, onempty):
297         self.outp.flush()
298         rl = ''
299         for rl in linereader(self):
300             #log('%d got line: %r\n' % (os.getpid(), rl))
301             if not rl:  # empty line
302                 continue
303             elif rl == 'ok':
304                 return None
305             elif rl.startswith('error '):
306                 #log('client: error: %s\n' % rl[6:])
307                 return NotOk(rl[6:])
308             else:
309                 onempty(rl)
310         raise Exception('server exited unexpectedly; see errors above')
311
312     def drain_and_check_ok(self):
313         """Remove all data for the current command from input stream."""
314         def onempty(rl):
315             pass
316         return self._check_ok(onempty)
317
318     def check_ok(self):
319         """Verify that server action completed successfully."""
320         def onempty(rl):
321             raise Exception('expected "ok", got %r' % rl)
322         return self._check_ok(onempty)
323
324
325 class Conn(BaseConn):
326     def __init__(self, inp, outp):
327         BaseConn.__init__(self, outp)
328         self.inp = inp
329
330     def _read(self, size):
331         return self.inp.read(size)
332
333     def _readline(self):
334         return self.inp.readline()
335
336     def has_input(self):
337         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
338         if rl:
339             assert(rl[0] == self.inp.fileno())
340             return True
341         else:
342             return None
343
344
345 def checked_reader(fd, n):
346     while n > 0:
347         rl, _, _ = select.select([fd], [], [])
348         assert(rl[0] == fd)
349         buf = os.read(fd, n)
350         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
351         yield buf
352         n -= len(buf)
353
354
355 MAX_PACKET = 128 * 1024
356 def mux(p, outfd, outr, errr):
357     try:
358         fds = [outr, errr]
359         while p.poll() is None:
360             rl, _, _ = select.select(fds, [], [])
361             for fd in rl:
362                 if fd == outr:
363                     buf = os.read(outr, MAX_PACKET)
364                     if not buf: break
365                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
366                 elif fd == errr:
367                     buf = os.read(errr, 1024)
368                     if not buf: break
369                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
370     finally:
371         os.write(outfd, struct.pack('!IB', 0, 3))
372
373
374 class DemuxConn(BaseConn):
375     """A helper class for bup's client-server protocol."""
376     def __init__(self, infd, outp):
377         BaseConn.__init__(self, outp)
378         # Anything that comes through before the sync string was not
379         # multiplexed and can be assumed to be debug/log before mux init.
380         tail = ''
381         while tail != 'BUPMUX':
382             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
383             if not b:
384                 raise IOError('demux: unexpected EOF during initialization')
385             tail += b
386             sys.stderr.write(tail[:-6])  # pre-mux log messages
387             tail = tail[-6:]
388         self.infd = infd
389         self.reader = None
390         self.buf = None
391         self.closed = False
392
393     def write(self, data):
394         self._load_buf(0)
395         BaseConn.write(self, data)
396
397     def _next_packet(self, timeout):
398         if self.closed: return False
399         rl, wl, xl = select.select([self.infd], [], [], timeout)
400         if not rl: return False
401         assert(rl[0] == self.infd)
402         ns = ''.join(checked_reader(self.infd, 5))
403         n, fdw = struct.unpack('!IB', ns)
404         assert(n <= MAX_PACKET)
405         if fdw == 1:
406             self.reader = checked_reader(self.infd, n)
407         elif fdw == 2:
408             for buf in checked_reader(self.infd, n):
409                 sys.stderr.write(buf)
410         elif fdw == 3:
411             self.closed = True
412             debug2("DemuxConn: marked closed\n")
413         return True
414
415     def _load_buf(self, timeout):
416         if self.buf is not None:
417             return True
418         while not self.closed:
419             while not self.reader:
420                 if not self._next_packet(timeout):
421                     return False
422             try:
423                 self.buf = self.reader.next()
424                 return True
425             except StopIteration:
426                 self.reader = None
427         return False
428
429     def _read_parts(self, ix_fn):
430         while self._load_buf(None):
431             assert(self.buf is not None)
432             i = ix_fn(self.buf)
433             if i is None or i == len(self.buf):
434                 yv = self.buf
435                 self.buf = None
436             else:
437                 yv = self.buf[:i]
438                 self.buf = self.buf[i:]
439             yield yv
440             if i is not None:
441                 break
442
443     def _readline(self):
444         def find_eol(buf):
445             try:
446                 return buf.index('\n')+1
447             except ValueError:
448                 return None
449         return ''.join(self._read_parts(find_eol))
450
451     def _read(self, size):
452         csize = [size]
453         def until_size(buf): # Closes on csize
454             if len(buf) < csize[0]:
455                 csize[0] -= len(buf)
456                 return None
457             else:
458                 return csize[0]
459         return ''.join(self._read_parts(until_size))
460
461     def has_input(self):
462         return self._load_buf(0)
463
464
465 def linereader(f):
466     """Generate a list of input lines from 'f' without terminating newlines."""
467     while 1:
468         line = f.readline()
469         if not line:
470             break
471         yield line[:-1]
472
473
474 def chunkyreader(f, count = None):
475     """Generate a list of chunks of data read from 'f'.
476
477     If count is None, read until EOF is reached.
478
479     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
480     reached while reading, raise IOError.
481     """
482     if count != None:
483         while count > 0:
484             b = f.read(min(count, 65536))
485             if not b:
486                 raise IOError('EOF with %d bytes remaining' % count)
487             yield b
488             count -= len(b)
489     else:
490         while 1:
491             b = f.read(65536)
492             if not b: break
493             yield b
494
495
496 def slashappend(s):
497     """Append "/" to 's' if it doesn't aleady end in "/"."""
498     if s and not s.endswith('/'):
499         return s + '/'
500     else:
501         return s
502
503
504 def _mmap_do(f, sz, flags, prot, close):
505     if not sz:
506         st = os.fstat(f.fileno())
507         sz = st.st_size
508     if not sz:
509         # trying to open a zero-length map gives an error, but an empty
510         # string has all the same behaviour of a zero-length map, ie. it has
511         # no elements :)
512         return ''
513     map = mmap.mmap(f.fileno(), sz, flags, prot)
514     if close:
515         f.close()  # map will persist beyond file close
516     return map
517
518
519 def mmap_read(f, sz = 0, close=True):
520     """Create a read-only memory mapped region on file 'f'.
521     If sz is 0, the region will cover the entire file.
522     """
523     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
524
525
526 def mmap_readwrite(f, sz = 0, close=True):
527     """Create a read-write memory mapped region on file 'f'.
528     If sz is 0, the region will cover the entire file.
529     """
530     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
531                     close)
532
533
534 def mmap_readwrite_private(f, sz = 0, close=True):
535     """Create a read-write memory mapped region on file 'f'.
536     If sz is 0, the region will cover the entire file.
537     The map is private, which means the changes are never flushed back to the
538     file.
539     """
540     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
541                     close)
542
543
544 def parse_num(s):
545     """Parse data size information into a float number.
546
547     Here are some examples of conversions:
548         199.2k means 203981 bytes
549         1GB means 1073741824 bytes
550         2.1 tb means 2199023255552 bytes
551     """
552     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
553     if not g:
554         raise ValueError("can't parse %r as a number" % s)
555     (val, unit) = g.groups()
556     num = float(val)
557     unit = unit.lower()
558     if unit in ['t', 'tb']:
559         mult = 1024*1024*1024*1024
560     elif unit in ['g', 'gb']:
561         mult = 1024*1024*1024
562     elif unit in ['m', 'mb']:
563         mult = 1024*1024
564     elif unit in ['k', 'kb']:
565         mult = 1024
566     elif unit in ['', 'b']:
567         mult = 1
568     else:
569         raise ValueError("invalid unit %r in number %r" % (unit, s))
570     return int(num*mult)
571
572
573 def count(l):
574     """Count the number of elements in an iterator. (consumes the iterator)"""
575     return reduce(lambda x,y: x+1, l)
576
577
578 saved_errors = []
579 def add_error(e):
580     """Append an error message to the list of saved errors.
581
582     Once processing is able to stop and output the errors, the saved errors are
583     accessible in the module variable helpers.saved_errors.
584     """
585     saved_errors.append(e)
586     log('%-70s\n' % e)
587
588
589 def clear_errors():
590     global saved_errors
591     saved_errors = []
592
593
594 def handle_ctrl_c():
595     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
596
597     The new exception handler will make sure that bup will exit without an ugly
598     stacktrace when Ctrl-C is hit.
599     """
600     oldhook = sys.excepthook
601     def newhook(exctype, value, traceback):
602         if exctype == KeyboardInterrupt:
603             log('Interrupted.\n')
604         else:
605             return oldhook(exctype, value, traceback)
606     sys.excepthook = newhook
607
608
609 def columnate(l, prefix):
610     """Format elements of 'l' in columns with 'prefix' leading each line.
611
612     The number of columns is determined automatically based on the string
613     lengths.
614     """
615     if not l:
616         return ""
617     l = l[:]
618     clen = max(len(s) for s in l)
619     ncols = (tty_width() - len(prefix)) / (clen + 2)
620     if ncols <= 1:
621         ncols = 1
622         clen = 0
623     cols = []
624     while len(l) % ncols:
625         l.append('')
626     rows = len(l)/ncols
627     for s in range(0, len(l), rows):
628         cols.append(l[s:s+rows])
629     out = ''
630     for row in zip(*cols):
631         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
632     return out
633
634
635 def parse_date_or_fatal(str, fatal):
636     """Parses the given date or calls Option.fatal().
637     For now we expect a string that contains a float."""
638     try:
639         date = atof(str)
640     except ValueError, e:
641         raise fatal('invalid date format (should be a float): %r' % e)
642     else:
643         return date
644
645
646 def strip_path(prefix, path):
647     """Strips a given prefix from a path.
648
649     First both paths are normalized.
650
651     Raises an Exception if no prefix is given.
652     """
653     if prefix == None:
654         raise Exception('no path given')
655
656     normalized_prefix = os.path.realpath(prefix)
657     debug2("normalized_prefix: %s\n" % normalized_prefix)
658     normalized_path = os.path.realpath(path)
659     debug2("normalized_path: %s\n" % normalized_path)
660     if normalized_path.startswith(normalized_prefix):
661         return normalized_path[len(normalized_prefix):]
662     else:
663         return path
664
665
666 def strip_base_path(path, base_paths):
667     """Strips the base path from a given path.
668
669
670     Determines the base path for the given string and then strips it
671     using strip_path().
672     Iterates over all base_paths from long to short, to prevent that
673     a too short base_path is removed.
674     """
675     normalized_path = os.path.realpath(path)
676     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
677     for bp in sorted_base_paths:
678         if normalized_path.startswith(os.path.realpath(bp)):
679             return strip_path(bp, normalized_path)
680     return path
681
682
683 def graft_path(graft_points, path):
684     normalized_path = os.path.realpath(path)
685     for graft_point in graft_points:
686         old_prefix, new_prefix = graft_point
687         if normalized_path.startswith(old_prefix):
688             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
689     return normalized_path
690
691
692 # hashlib is only available in python 2.5 or higher, but the 'sha' module
693 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
694 # python 2.4 and above without any stupid warnings, so let's try using hashlib
695 # first, and downgrade if it fails.
696 try:
697     import hashlib
698 except ImportError:
699     import sha
700     Sha1 = sha.sha
701 else:
702     Sha1 = hashlib.sha1
703
704
705 def version_date():
706     """Format bup's version date string for output."""
707     return _version.DATE.split(' ')[0]
708
709
710 def version_commit():
711     """Get the commit hash of bup's current version."""
712     return _version.COMMIT
713
714
715 def version_tag():
716     """Format bup's version tag (the official version number).
717
718     When generated from a commit other than one pointed to with a tag, the
719     returned string will be "unknown-" followed by the first seven positions of
720     the commit hash.
721     """
722     names = _version.NAMES.strip()
723     assert(names[0] == '(')
724     assert(names[-1] == ')')
725     names = names[1:-1]
726     l = [n.strip() for n in names.split(',')]
727     for n in l:
728         if n.startswith('tag: bup-'):
729             return n[9:]
730     return 'unknown-%s' % _version.COMMIT[:7]