]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
helpers: separately determine if stdout and stderr are ttys.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time
5 from bup import _version, _helpers
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49
50 _last_prog = 0
51 def log(s):
52     """Print a log message to stderr."""
53     global _last_prog
54     sys.stdout.flush()
55     _hard_write(sys.stderr.fileno(), s)
56     _last_prog = 0
57
58
59 def debug1(s):
60     if buglvl >= 1:
61         log(s)
62
63
64 def debug2(s):
65     if buglvl >= 2:
66         log(s)
67
68
69 istty1 = os.isatty(1) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 1)
70 istty2 = os.isatty(2) or (atoi(os.environ.get('BUP_FORCE_TTY')) & 2)
71 _last_progress = ''
72 def progress(s):
73     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
74     global _last_progress
75     if istty2:
76         log(s)
77         _last_progress = s
78
79
80 def qprogress(s):
81     """Calls progress() only if we haven't printed progress in a while.
82     
83     This avoids overloading the stderr buffer with excess junk.
84     """
85     global _last_prog
86     now = time.time()
87     if now - _last_prog > 0.1:
88         progress(s)
89         _last_prog = now
90
91
92 def reprogress():
93     """Calls progress() to redisplay the most recent progress message.
94
95     Useful after you've printed some other message that wipes out the
96     progress line.
97     """
98     if _last_progress and _last_progress.endswith('\r'):
99         progress(_last_progress)
100
101
102 def mkdirp(d, mode=None):
103     """Recursively create directories on path 'd'.
104
105     Unlike os.makedirs(), it doesn't raise an exception if the last element of
106     the path already exists.
107     """
108     try:
109         if mode:
110             os.makedirs(d, mode)
111         else:
112             os.makedirs(d)
113     except OSError, e:
114         if e.errno == errno.EEXIST:
115             pass
116         else:
117             raise
118
119
120 def next(it):
121     """Get the next item from an iterator, None if we reached the end."""
122     try:
123         return it.next()
124     except StopIteration:
125         return None
126
127
128 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
129     if key:
130         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
131     else:
132         samekey = operator.eq
133     count = 0
134     total = sum(len(it) for it in iters)
135     iters = (iter(it) for it in iters)
136     heap = ((next(it),it) for it in iters)
137     heap = [(e,it) for e,it in heap if e]
138
139     heapq.heapify(heap)
140     pe = None
141     while heap:
142         if not count % pfreq:
143             pfunc(count, total)
144         e, it = heap[0]
145         if not samekey(e, pe):
146             pe = e
147             yield e
148         count += 1
149         try:
150             e = it.next() # Don't use next() function, it's too expensive
151         except StopIteration:
152             heapq.heappop(heap) # remove current
153         else:
154             heapq.heapreplace(heap, (e, it)) # shift current to new location
155     pfinal(count, total)
156
157
158 def unlink(f):
159     """Delete a file at path 'f' if it currently exists.
160
161     Unlike os.unlink(), does not throw an exception if the file didn't already
162     exist.
163     """
164     try:
165         os.unlink(f)
166     except OSError, e:
167         if e.errno == errno.ENOENT:
168             pass  # it doesn't exist, that's what you asked for
169
170
171 def readpipe(argv):
172     """Run a subprocess and return its output."""
173     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
174     r = p.stdout.read()
175     p.wait()
176     return r
177
178
179 def realpath(p):
180     """Get the absolute path of a file.
181
182     Behaves like os.path.realpath, but doesn't follow a symlink for the last
183     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
184     will follow symlinks in p's directory)
185     """
186     try:
187         st = os.lstat(p)
188     except OSError:
189         st = None
190     if st and stat.S_ISLNK(st.st_mode):
191         (dir, name) = os.path.split(p)
192         dir = os.path.realpath(dir)
193         out = os.path.join(dir, name)
194     else:
195         out = os.path.realpath(p)
196     #log('realpathing:%r,%r\n' % (p, out))
197     return out
198
199
200 _username = None
201 def username():
202     """Get the user's login name."""
203     global _username
204     if not _username:
205         uid = os.getuid()
206         try:
207             _username = pwd.getpwuid(uid)[0]
208         except KeyError:
209             _username = 'user%d' % uid
210     return _username
211
212
213 _userfullname = None
214 def userfullname():
215     """Get the user's full name."""
216     global _userfullname
217     if not _userfullname:
218         uid = os.getuid()
219         try:
220             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
221         except KeyError:
222             _userfullname = 'user%d' % uid
223     return _userfullname
224
225
226 _hostname = None
227 def hostname():
228     """Get the FQDN of this machine."""
229     global _hostname
230     if not _hostname:
231         _hostname = socket.getfqdn()
232     return _hostname
233
234
235 _resource_path = None
236 def resource_path(subdir=''):
237     global _resource_path
238     if not _resource_path:
239         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
240     return os.path.join(_resource_path, subdir)
241
242
243 class NotOk(Exception):
244     pass
245
246
247 class BaseConn:
248     def __init__(self, outp):
249         self.outp = outp
250
251     def close(self):
252         while self._read(65536): pass
253
254     def read(self, size):
255         """Read 'size' bytes from input stream."""
256         self.outp.flush()
257         return self._read(size)
258
259     def readline(self):
260         """Read from input stream until a newline is found."""
261         self.outp.flush()
262         return self._readline()
263
264     def write(self, data):
265         """Write 'data' to output stream."""
266         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
267         self.outp.write(data)
268
269     def has_input(self):
270         """Return true if input stream is readable."""
271         raise NotImplemented("Subclasses must implement has_input")
272
273     def ok(self):
274         """Indicate end of output from last sent command."""
275         self.write('\nok\n')
276
277     def error(self, s):
278         """Indicate server error to the client."""
279         s = re.sub(r'\s+', ' ', str(s))
280         self.write('\nerror %s\n' % s)
281
282     def _check_ok(self, onempty):
283         self.outp.flush()
284         rl = ''
285         for rl in linereader(self):
286             #log('%d got line: %r\n' % (os.getpid(), rl))
287             if not rl:  # empty line
288                 continue
289             elif rl == 'ok':
290                 return None
291             elif rl.startswith('error '):
292                 #log('client: error: %s\n' % rl[6:])
293                 return NotOk(rl[6:])
294             else:
295                 onempty(rl)
296         raise Exception('server exited unexpectedly; see errors above')
297
298     def drain_and_check_ok(self):
299         """Remove all data for the current command from input stream."""
300         def onempty(rl):
301             pass
302         return self._check_ok(onempty)
303
304     def check_ok(self):
305         """Verify that server action completed successfully."""
306         def onempty(rl):
307             raise Exception('expected "ok", got %r' % rl)
308         return self._check_ok(onempty)
309
310
311 class Conn(BaseConn):
312     def __init__(self, inp, outp):
313         BaseConn.__init__(self, outp)
314         self.inp = inp
315
316     def _read(self, size):
317         return self.inp.read(size)
318
319     def _readline(self):
320         return self.inp.readline()
321
322     def has_input(self):
323         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
324         if rl:
325             assert(rl[0] == self.inp.fileno())
326             return True
327         else:
328             return None
329
330
331 def checked_reader(fd, n):
332     while n > 0:
333         rl, _, _ = select.select([fd], [], [])
334         assert(rl[0] == fd)
335         buf = os.read(fd, n)
336         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
337         yield buf
338         n -= len(buf)
339
340
341 MAX_PACKET = 128 * 1024
342 def mux(p, outfd, outr, errr):
343     try:
344         fds = [outr, errr]
345         while p.poll() is None:
346             rl, _, _ = select.select(fds, [], [])
347             for fd in rl:
348                 if fd == outr:
349                     buf = os.read(outr, MAX_PACKET)
350                     if not buf: break
351                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
352                 elif fd == errr:
353                     buf = os.read(errr, 1024)
354                     if not buf: break
355                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
356     finally:
357         os.write(outfd, struct.pack('!IB', 0, 3))
358
359
360 class DemuxConn(BaseConn):
361     """A helper class for bup's client-server protocol."""
362     def __init__(self, infd, outp):
363         BaseConn.__init__(self, outp)
364         # Anything that comes through before the sync string was not
365         # multiplexed and can be assumed to be debug/log before mux init.
366         tail = ''
367         while tail != 'BUPMUX':
368             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
369             if not b:
370                 raise IOError('demux: unexpected EOF during initialization')
371             tail += b
372             sys.stderr.write(tail[:-6])  # pre-mux log messages
373             tail = tail[-6:]
374         self.infd = infd
375         self.reader = None
376         self.buf = None
377         self.closed = False
378
379     def write(self, data):
380         self._load_buf(0)
381         BaseConn.write(self, data)
382
383     def _next_packet(self, timeout):
384         if self.closed: return False
385         rl, wl, xl = select.select([self.infd], [], [], timeout)
386         if not rl: return False
387         assert(rl[0] == self.infd)
388         ns = ''.join(checked_reader(self.infd, 5))
389         n, fdw = struct.unpack('!IB', ns)
390         assert(n <= MAX_PACKET)
391         if fdw == 1:
392             self.reader = checked_reader(self.infd, n)
393         elif fdw == 2:
394             for buf in checked_reader(self.infd, n):
395                 sys.stderr.write(buf)
396         elif fdw == 3:
397             self.closed = True
398             debug2("DemuxConn: marked closed\n")
399         return True
400
401     def _load_buf(self, timeout):
402         if self.buf is not None:
403             return True
404         while not self.closed:
405             while not self.reader:
406                 if not self._next_packet(timeout):
407                     return False
408             try:
409                 self.buf = self.reader.next()
410                 return True
411             except StopIteration:
412                 self.reader = None
413         return False
414
415     def _read_parts(self, ix_fn):
416         while self._load_buf(None):
417             assert(self.buf is not None)
418             i = ix_fn(self.buf)
419             if i is None or i == len(self.buf):
420                 yv = self.buf
421                 self.buf = None
422             else:
423                 yv = self.buf[:i]
424                 self.buf = self.buf[i:]
425             yield yv
426             if i is not None:
427                 break
428
429     def _readline(self):
430         def find_eol(buf):
431             try:
432                 return buf.index('\n')+1
433             except ValueError:
434                 return None
435         return ''.join(self._read_parts(find_eol))
436
437     def _read(self, size):
438         csize = [size]
439         def until_size(buf): # Closes on csize
440             if len(buf) < csize[0]:
441                 csize[0] -= len(buf)
442                 return None
443             else:
444                 return csize[0]
445         return ''.join(self._read_parts(until_size))
446
447     def has_input(self):
448         return self._load_buf(0)
449
450
451 def linereader(f):
452     """Generate a list of input lines from 'f' without terminating newlines."""
453     while 1:
454         line = f.readline()
455         if not line:
456             break
457         yield line[:-1]
458
459
460 def chunkyreader(f, count = None):
461     """Generate a list of chunks of data read from 'f'.
462
463     If count is None, read until EOF is reached.
464
465     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
466     reached while reading, raise IOError.
467     """
468     if count != None:
469         while count > 0:
470             b = f.read(min(count, 65536))
471             if not b:
472                 raise IOError('EOF with %d bytes remaining' % count)
473             yield b
474             count -= len(b)
475     else:
476         while 1:
477             b = f.read(65536)
478             if not b: break
479             yield b
480
481
482 def slashappend(s):
483     """Append "/" to 's' if it doesn't aleady end in "/"."""
484     if s and not s.endswith('/'):
485         return s + '/'
486     else:
487         return s
488
489
490 def _mmap_do(f, sz, flags, prot, close):
491     if not sz:
492         st = os.fstat(f.fileno())
493         sz = st.st_size
494     if not sz:
495         # trying to open a zero-length map gives an error, but an empty
496         # string has all the same behaviour of a zero-length map, ie. it has
497         # no elements :)
498         return ''
499     map = mmap.mmap(f.fileno(), sz, flags, prot)
500     if close:
501         f.close()  # map will persist beyond file close
502     return map
503
504
505 def mmap_read(f, sz = 0, close=True):
506     """Create a read-only memory mapped region on file 'f'.
507     If sz is 0, the region will cover the entire file.
508     """
509     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
510
511
512 def mmap_readwrite(f, sz = 0, close=True):
513     """Create a read-write memory mapped region on file 'f'.
514     If sz is 0, the region will cover the entire file.
515     """
516     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
517                     close)
518
519
520 def mmap_readwrite_private(f, sz = 0, close=True):
521     """Create a read-write memory mapped region on file 'f'.
522     If sz is 0, the region will cover the entire file.
523     The map is private, which means the changes are never flushed back to the
524     file.
525     """
526     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
527                     close)
528
529
530 def parse_num(s):
531     """Parse data size information into a float number.
532
533     Here are some examples of conversions:
534         199.2k means 203981 bytes
535         1GB means 1073741824 bytes
536         2.1 tb means 2199023255552 bytes
537     """
538     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
539     if not g:
540         raise ValueError("can't parse %r as a number" % s)
541     (val, unit) = g.groups()
542     num = float(val)
543     unit = unit.lower()
544     if unit in ['t', 'tb']:
545         mult = 1024*1024*1024*1024
546     elif unit in ['g', 'gb']:
547         mult = 1024*1024*1024
548     elif unit in ['m', 'mb']:
549         mult = 1024*1024
550     elif unit in ['k', 'kb']:
551         mult = 1024
552     elif unit in ['', 'b']:
553         mult = 1
554     else:
555         raise ValueError("invalid unit %r in number %r" % (unit, s))
556     return int(num*mult)
557
558
559 def count(l):
560     """Count the number of elements in an iterator. (consumes the iterator)"""
561     return reduce(lambda x,y: x+1, l)
562
563
564 saved_errors = []
565 def add_error(e):
566     """Append an error message to the list of saved errors.
567
568     Once processing is able to stop and output the errors, the saved errors are
569     accessible in the module variable helpers.saved_errors.
570     """
571     saved_errors.append(e)
572     log('%-70s\n' % e)
573
574
575 def handle_ctrl_c():
576     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
577
578     The new exception handler will make sure that bup will exit without an ugly
579     stacktrace when Ctrl-C is hit.
580     """
581     oldhook = sys.excepthook
582     def newhook(exctype, value, traceback):
583         if exctype == KeyboardInterrupt:
584             log('Interrupted.\n')
585         else:
586             return oldhook(exctype, value, traceback)
587     sys.excepthook = newhook
588
589
590 def columnate(l, prefix):
591     """Format elements of 'l' in columns with 'prefix' leading each line.
592
593     The number of columns is determined automatically based on the string
594     lengths.
595     """
596     if not l:
597         return ""
598     l = l[:]
599     clen = max(len(s) for s in l)
600     ncols = (tty_width() - len(prefix)) / (clen + 2)
601     if ncols <= 1:
602         ncols = 1
603         clen = 0
604     cols = []
605     while len(l) % ncols:
606         l.append('')
607     rows = len(l)/ncols
608     for s in range(0, len(l), rows):
609         cols.append(l[s:s+rows])
610     out = ''
611     for row in zip(*cols):
612         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
613     return out
614
615
616 def parse_date_or_fatal(str, fatal):
617     """Parses the given date or calls Option.fatal().
618     For now we expect a string that contains a float."""
619     try:
620         date = atof(str)
621     except ValueError, e:
622         raise fatal('invalid date format (should be a float): %r' % e)
623     else:
624         return date
625
626
627 def strip_path(prefix, path):
628     """Strips a given prefix from a path.
629
630     First both paths are normalized.
631
632     Raises an Exception if no prefix is given.
633     """
634     if prefix == None:
635         raise Exception('no path given')
636
637     normalized_prefix = os.path.realpath(prefix)
638     debug2("normalized_prefix: %s\n" % normalized_prefix)
639     normalized_path = os.path.realpath(path)
640     debug2("normalized_path: %s\n" % normalized_path)
641     if normalized_path.startswith(normalized_prefix):
642         return normalized_path[len(normalized_prefix):]
643     else:
644         return path
645
646
647 def strip_base_path(path, base_paths):
648     """Strips the base path from a given path.
649
650
651     Determines the base path for the given string and then strips it
652     using strip_path().
653     Iterates over all base_paths from long to short, to prevent that
654     a too short base_path is removed.
655     """
656     normalized_path = os.path.realpath(path)
657     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
658     for bp in sorted_base_paths:
659         if normalized_path.startswith(os.path.realpath(bp)):
660             return strip_path(bp, normalized_path)
661     return path
662
663
664 def graft_path(graft_points, path):
665     normalized_path = os.path.realpath(path)
666     for graft_point in graft_points:
667         old_prefix, new_prefix = graft_point
668         if normalized_path.startswith(old_prefix):
669             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
670     return normalized_path
671
672
673 # hashlib is only available in python 2.5 or higher, but the 'sha' module
674 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
675 # python 2.4 and above without any stupid warnings, so let's try using hashlib
676 # first, and downgrade if it fails.
677 try:
678     import hashlib
679 except ImportError:
680     import sha
681     Sha1 = sha.sha
682 else:
683     Sha1 = hashlib.sha1
684
685
686 def version_date():
687     """Format bup's version date string for output."""
688     return _version.DATE.split(' ')[0]
689
690
691 def version_commit():
692     """Get the commit hash of bup's current version."""
693     return _version.COMMIT
694
695
696 def version_tag():
697     """Format bup's version tag (the official version number).
698
699     When generated from a commit other than one pointed to with a tag, the
700     returned string will be "unknown-" followed by the first seven positions of
701     the commit hash.
702     """
703     names = _version.NAMES.strip()
704     assert(names[0] == '(')
705     assert(names[-1] == ')')
706     names = names[1:-1]
707     l = [n.strip() for n in names.split(',')]
708     for n in l:
709         if n.startswith('tag: bup-'):
710             return n[9:]
711     return 'unknown-%s' % _version.COMMIT[:7]