]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
cmd/split: fixup progress message, and print -b output incrementally.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time
5 from bup import _version, _helpers
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49
50 _last_prog = 0
51 def log(s):
52     """Print a log message to stderr."""
53     global _last_prog
54     sys.stdout.flush()
55     _hard_write(sys.stderr.fileno(), s)
56     _last_prog = 0
57
58
59 def debug1(s):
60     if buglvl >= 1:
61         log(s)
62
63
64 def debug2(s):
65     if buglvl >= 2:
66         log(s)
67
68
69 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
70 _last_progress = ''
71 def progress(s):
72     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
73     global _last_progress
74     if istty:
75         log(s)
76         _last_progress = s
77
78
79 def qprogress(s):
80     """Calls progress() only if we haven't printed progress in a while.
81     
82     This avoids overloading the stderr buffer with excess junk.
83     """
84     global _last_prog
85     now = time.time()
86     if now - _last_prog > 0.1:
87         progress(s)
88         _last_prog = now
89
90
91 def reprogress():
92     """Calls progress() to redisplay the most recent progress message.
93
94     Useful after you've printed some other message that wipes out the
95     progress line.
96     """
97     if _last_progress and _last_progress.endswith('\r'):
98         progress(_last_progress)
99
100
101 def mkdirp(d, mode=None):
102     """Recursively create directories on path 'd'.
103
104     Unlike os.makedirs(), it doesn't raise an exception if the last element of
105     the path already exists.
106     """
107     try:
108         if mode:
109             os.makedirs(d, mode)
110         else:
111             os.makedirs(d)
112     except OSError, e:
113         if e.errno == errno.EEXIST:
114             pass
115         else:
116             raise
117
118
119 def next(it):
120     """Get the next item from an iterator, None if we reached the end."""
121     try:
122         return it.next()
123     except StopIteration:
124         return None
125
126
127 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
128     if key:
129         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
130     else:
131         samekey = operator.eq
132     count = 0
133     total = sum(len(it) for it in iters)
134     iters = (iter(it) for it in iters)
135     heap = ((next(it),it) for it in iters)
136     heap = [(e,it) for e,it in heap if e]
137
138     heapq.heapify(heap)
139     pe = None
140     while heap:
141         if not count % pfreq:
142             pfunc(count, total)
143         e, it = heap[0]
144         if not samekey(e, pe):
145             pe = e
146             yield e
147         count += 1
148         try:
149             e = it.next() # Don't use next() function, it's too expensive
150         except StopIteration:
151             heapq.heappop(heap) # remove current
152         else:
153             heapq.heapreplace(heap, (e, it)) # shift current to new location
154     pfinal(count, total)
155
156
157 def unlink(f):
158     """Delete a file at path 'f' if it currently exists.
159
160     Unlike os.unlink(), does not throw an exception if the file didn't already
161     exist.
162     """
163     try:
164         os.unlink(f)
165     except OSError, e:
166         if e.errno == errno.ENOENT:
167             pass  # it doesn't exist, that's what you asked for
168
169
170 def readpipe(argv):
171     """Run a subprocess and return its output."""
172     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
173     r = p.stdout.read()
174     p.wait()
175     return r
176
177
178 def realpath(p):
179     """Get the absolute path of a file.
180
181     Behaves like os.path.realpath, but doesn't follow a symlink for the last
182     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
183     will follow symlinks in p's directory)
184     """
185     try:
186         st = os.lstat(p)
187     except OSError:
188         st = None
189     if st and stat.S_ISLNK(st.st_mode):
190         (dir, name) = os.path.split(p)
191         dir = os.path.realpath(dir)
192         out = os.path.join(dir, name)
193     else:
194         out = os.path.realpath(p)
195     #log('realpathing:%r,%r\n' % (p, out))
196     return out
197
198
199 _username = None
200 def username():
201     """Get the user's login name."""
202     global _username
203     if not _username:
204         uid = os.getuid()
205         try:
206             _username = pwd.getpwuid(uid)[0]
207         except KeyError:
208             _username = 'user%d' % uid
209     return _username
210
211
212 _userfullname = None
213 def userfullname():
214     """Get the user's full name."""
215     global _userfullname
216     if not _userfullname:
217         uid = os.getuid()
218         try:
219             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
220         except KeyError:
221             _userfullname = 'user%d' % uid
222     return _userfullname
223
224
225 _hostname = None
226 def hostname():
227     """Get the FQDN of this machine."""
228     global _hostname
229     if not _hostname:
230         _hostname = socket.getfqdn()
231     return _hostname
232
233
234 _resource_path = None
235 def resource_path(subdir=''):
236     global _resource_path
237     if not _resource_path:
238         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
239     return os.path.join(_resource_path, subdir)
240
241
242 class NotOk(Exception):
243     pass
244
245
246 class BaseConn:
247     def __init__(self, outp):
248         self.outp = outp
249
250     def close(self):
251         while self._read(65536): pass
252
253     def read(self, size):
254         """Read 'size' bytes from input stream."""
255         self.outp.flush()
256         return self._read(size)
257
258     def readline(self):
259         """Read from input stream until a newline is found."""
260         self.outp.flush()
261         return self._readline()
262
263     def write(self, data):
264         """Write 'data' to output stream."""
265         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
266         self.outp.write(data)
267
268     def has_input(self):
269         """Return true if input stream is readable."""
270         raise NotImplemented("Subclasses must implement has_input")
271
272     def ok(self):
273         """Indicate end of output from last sent command."""
274         self.write('\nok\n')
275
276     def error(self, s):
277         """Indicate server error to the client."""
278         s = re.sub(r'\s+', ' ', str(s))
279         self.write('\nerror %s\n' % s)
280
281     def _check_ok(self, onempty):
282         self.outp.flush()
283         rl = ''
284         for rl in linereader(self):
285             #log('%d got line: %r\n' % (os.getpid(), rl))
286             if not rl:  # empty line
287                 continue
288             elif rl == 'ok':
289                 return None
290             elif rl.startswith('error '):
291                 #log('client: error: %s\n' % rl[6:])
292                 return NotOk(rl[6:])
293             else:
294                 onempty(rl)
295         raise Exception('server exited unexpectedly; see errors above')
296
297     def drain_and_check_ok(self):
298         """Remove all data for the current command from input stream."""
299         def onempty(rl):
300             pass
301         return self._check_ok(onempty)
302
303     def check_ok(self):
304         """Verify that server action completed successfully."""
305         def onempty(rl):
306             raise Exception('expected "ok", got %r' % rl)
307         return self._check_ok(onempty)
308
309
310 class Conn(BaseConn):
311     def __init__(self, inp, outp):
312         BaseConn.__init__(self, outp)
313         self.inp = inp
314
315     def _read(self, size):
316         return self.inp.read(size)
317
318     def _readline(self):
319         return self.inp.readline()
320
321     def has_input(self):
322         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
323         if rl:
324             assert(rl[0] == self.inp.fileno())
325             return True
326         else:
327             return None
328
329
330 def checked_reader(fd, n):
331     while n > 0:
332         rl, _, _ = select.select([fd], [], [])
333         assert(rl[0] == fd)
334         buf = os.read(fd, n)
335         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
336         yield buf
337         n -= len(buf)
338
339
340 MAX_PACKET = 128 * 1024
341 def mux(p, outfd, outr, errr):
342     try:
343         fds = [outr, errr]
344         while p.poll() is None:
345             rl, _, _ = select.select(fds, [], [])
346             for fd in rl:
347                 if fd == outr:
348                     buf = os.read(outr, MAX_PACKET)
349                     if not buf: break
350                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
351                 elif fd == errr:
352                     buf = os.read(errr, 1024)
353                     if not buf: break
354                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
355     finally:
356         os.write(outfd, struct.pack('!IB', 0, 3))
357
358
359 class DemuxConn(BaseConn):
360     """A helper class for bup's client-server protocol."""
361     def __init__(self, infd, outp):
362         BaseConn.__init__(self, outp)
363         # Anything that comes through before the sync string was not
364         # multiplexed and can be assumed to be debug/log before mux init.
365         tail = ''
366         while tail != 'BUPMUX':
367             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
368             if not b:
369                 raise IOError('demux: unexpected EOF during initialization')
370             tail += b
371             sys.stderr.write(tail[:-6])  # pre-mux log messages
372             tail = tail[-6:]
373         self.infd = infd
374         self.reader = None
375         self.buf = None
376         self.closed = False
377
378     def write(self, data):
379         self._load_buf(0)
380         BaseConn.write(self, data)
381
382     def _next_packet(self, timeout):
383         if self.closed: return False
384         rl, wl, xl = select.select([self.infd], [], [], timeout)
385         if not rl: return False
386         assert(rl[0] == self.infd)
387         ns = ''.join(checked_reader(self.infd, 5))
388         n, fdw = struct.unpack('!IB', ns)
389         assert(n <= MAX_PACKET)
390         if fdw == 1:
391             self.reader = checked_reader(self.infd, n)
392         elif fdw == 2:
393             for buf in checked_reader(self.infd, n):
394                 sys.stderr.write(buf)
395         elif fdw == 3:
396             self.closed = True
397             debug2("DemuxConn: marked closed\n")
398         return True
399
400     def _load_buf(self, timeout):
401         if self.buf is not None:
402             return True
403         while not self.closed:
404             while not self.reader:
405                 if not self._next_packet(timeout):
406                     return False
407             try:
408                 self.buf = self.reader.next()
409                 return True
410             except StopIteration:
411                 self.reader = None
412         return False
413
414     def _read_parts(self, ix_fn):
415         while self._load_buf(None):
416             assert(self.buf is not None)
417             i = ix_fn(self.buf)
418             if i is None or i == len(self.buf):
419                 yv = self.buf
420                 self.buf = None
421             else:
422                 yv = self.buf[:i]
423                 self.buf = self.buf[i:]
424             yield yv
425             if i is not None:
426                 break
427
428     def _readline(self):
429         def find_eol(buf):
430             try:
431                 return buf.index('\n')+1
432             except ValueError:
433                 return None
434         return ''.join(self._read_parts(find_eol))
435
436     def _read(self, size):
437         csize = [size]
438         def until_size(buf): # Closes on csize
439             if len(buf) < csize[0]:
440                 csize[0] -= len(buf)
441                 return None
442             else:
443                 return csize[0]
444         return ''.join(self._read_parts(until_size))
445
446     def has_input(self):
447         return self._load_buf(0)
448
449
450 def linereader(f):
451     """Generate a list of input lines from 'f' without terminating newlines."""
452     while 1:
453         line = f.readline()
454         if not line:
455             break
456         yield line[:-1]
457
458
459 def chunkyreader(f, count = None):
460     """Generate a list of chunks of data read from 'f'.
461
462     If count is None, read until EOF is reached.
463
464     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
465     reached while reading, raise IOError.
466     """
467     if count != None:
468         while count > 0:
469             b = f.read(min(count, 65536))
470             if not b:
471                 raise IOError('EOF with %d bytes remaining' % count)
472             yield b
473             count -= len(b)
474     else:
475         while 1:
476             b = f.read(65536)
477             if not b: break
478             yield b
479
480
481 def slashappend(s):
482     """Append "/" to 's' if it doesn't aleady end in "/"."""
483     if s and not s.endswith('/'):
484         return s + '/'
485     else:
486         return s
487
488
489 def _mmap_do(f, sz, flags, prot, close):
490     if not sz:
491         st = os.fstat(f.fileno())
492         sz = st.st_size
493     if not sz:
494         # trying to open a zero-length map gives an error, but an empty
495         # string has all the same behaviour of a zero-length map, ie. it has
496         # no elements :)
497         return ''
498     map = mmap.mmap(f.fileno(), sz, flags, prot)
499     if close:
500         f.close()  # map will persist beyond file close
501     return map
502
503
504 def mmap_read(f, sz = 0, close=True):
505     """Create a read-only memory mapped region on file 'f'.
506     If sz is 0, the region will cover the entire file.
507     """
508     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
509
510
511 def mmap_readwrite(f, sz = 0, close=True):
512     """Create a read-write memory mapped region on file 'f'.
513     If sz is 0, the region will cover the entire file.
514     """
515     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
516                     close)
517
518
519 def mmap_readwrite_private(f, sz = 0, close=True):
520     """Create a read-write memory mapped region on file 'f'.
521     If sz is 0, the region will cover the entire file.
522     The map is private, which means the changes are never flushed back to the
523     file.
524     """
525     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
526                     close)
527
528
529 def parse_num(s):
530     """Parse data size information into a float number.
531
532     Here are some examples of conversions:
533         199.2k means 203981 bytes
534         1GB means 1073741824 bytes
535         2.1 tb means 2199023255552 bytes
536     """
537     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
538     if not g:
539         raise ValueError("can't parse %r as a number" % s)
540     (val, unit) = g.groups()
541     num = float(val)
542     unit = unit.lower()
543     if unit in ['t', 'tb']:
544         mult = 1024*1024*1024*1024
545     elif unit in ['g', 'gb']:
546         mult = 1024*1024*1024
547     elif unit in ['m', 'mb']:
548         mult = 1024*1024
549     elif unit in ['k', 'kb']:
550         mult = 1024
551     elif unit in ['', 'b']:
552         mult = 1
553     else:
554         raise ValueError("invalid unit %r in number %r" % (unit, s))
555     return int(num*mult)
556
557
558 def count(l):
559     """Count the number of elements in an iterator. (consumes the iterator)"""
560     return reduce(lambda x,y: x+1, l)
561
562
563 saved_errors = []
564 def add_error(e):
565     """Append an error message to the list of saved errors.
566
567     Once processing is able to stop and output the errors, the saved errors are
568     accessible in the module variable helpers.saved_errors.
569     """
570     saved_errors.append(e)
571     log('%-70s\n' % e)
572
573
574 def handle_ctrl_c():
575     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
576
577     The new exception handler will make sure that bup will exit without an ugly
578     stacktrace when Ctrl-C is hit.
579     """
580     oldhook = sys.excepthook
581     def newhook(exctype, value, traceback):
582         if exctype == KeyboardInterrupt:
583             log('Interrupted.\n')
584         else:
585             return oldhook(exctype, value, traceback)
586     sys.excepthook = newhook
587
588
589 def columnate(l, prefix):
590     """Format elements of 'l' in columns with 'prefix' leading each line.
591
592     The number of columns is determined automatically based on the string
593     lengths.
594     """
595     if not l:
596         return ""
597     l = l[:]
598     clen = max(len(s) for s in l)
599     ncols = (tty_width() - len(prefix)) / (clen + 2)
600     if ncols <= 1:
601         ncols = 1
602         clen = 0
603     cols = []
604     while len(l) % ncols:
605         l.append('')
606     rows = len(l)/ncols
607     for s in range(0, len(l), rows):
608         cols.append(l[s:s+rows])
609     out = ''
610     for row in zip(*cols):
611         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
612     return out
613
614
615 def parse_date_or_fatal(str, fatal):
616     """Parses the given date or calls Option.fatal().
617     For now we expect a string that contains a float."""
618     try:
619         date = atof(str)
620     except ValueError, e:
621         raise fatal('invalid date format (should be a float): %r' % e)
622     else:
623         return date
624
625
626 def strip_path(prefix, path):
627     """Strips a given prefix from a path.
628
629     First both paths are normalized.
630
631     Raises an Exception if no prefix is given.
632     """
633     if prefix == None:
634         raise Exception('no path given')
635
636     normalized_prefix = os.path.realpath(prefix)
637     debug2("normalized_prefix: %s\n" % normalized_prefix)
638     normalized_path = os.path.realpath(path)
639     debug2("normalized_path: %s\n" % normalized_path)
640     if normalized_path.startswith(normalized_prefix):
641         return normalized_path[len(normalized_prefix):]
642     else:
643         return path
644
645
646 def strip_base_path(path, base_paths):
647     """Strips the base path from a given path.
648
649
650     Determines the base path for the given string and then strips it
651     using strip_path().
652     Iterates over all base_paths from long to short, to prevent that
653     a too short base_path is removed.
654     """
655     normalized_path = os.path.realpath(path)
656     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
657     for bp in sorted_base_paths:
658         if normalized_path.startswith(os.path.realpath(bp)):
659             return strip_path(bp, normalized_path)
660     return path
661
662
663 def graft_path(graft_points, path):
664     normalized_path = os.path.realpath(path)
665     for graft_point in graft_points:
666         old_prefix, new_prefix = graft_point
667         if normalized_path.startswith(old_prefix):
668             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
669     return normalized_path
670
671
672 # hashlib is only available in python 2.5 or higher, but the 'sha' module
673 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
674 # python 2.4 and above without any stupid warnings, so let's try using hashlib
675 # first, and downgrade if it fails.
676 try:
677     import hashlib
678 except ImportError:
679     import sha
680     Sha1 = sha.sha
681 else:
682     Sha1 = hashlib.sha1
683
684
685 def version_date():
686     """Format bup's version date string for output."""
687     return _version.DATE.split(' ')[0]
688
689
690 def version_commit():
691     """Get the commit hash of bup's current version."""
692     return _version.COMMIT
693
694
695 def version_tag():
696     """Format bup's version tag (the official version number).
697
698     When generated from a commit other than one pointed to with a tag, the
699     returned string will be "unknown-" followed by the first seven positions of
700     the commit hash.
701     """
702     names = _version.NAMES.strip()
703     assert(names[0] == '(')
704     assert(names[-1] == ')')
705     names = names[1:-1]
706     l = [n.strip() for n in names.split(',')]
707     for n in l:
708         if n.startswith('tag: bup-'):
709             return n[9:]
710     return 'unknown-%s' % _version.COMMIT[:7]