]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
DemuxConn.__init__: you can't assume the *last* 6 bytes are BUPMUX.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49 def log(s):
50     """Print a log message to stderr."""
51     sys.stdout.flush()
52     _hard_write(sys.stderr.fileno(), s)
53
54
55 def debug1(s):
56     if buglvl >= 1:
57         log(s)
58
59
60 def debug2(s):
61     if buglvl >= 2:
62         log(s)
63
64
65 def mkdirp(d, mode=None):
66     """Recursively create directories on path 'd'.
67
68     Unlike os.makedirs(), it doesn't raise an exception if the last element of
69     the path already exists.
70     """
71     try:
72         if mode:
73             os.makedirs(d, mode)
74         else:
75             os.makedirs(d)
76     except OSError, e:
77         if e.errno == errno.EEXIST:
78             pass
79         else:
80             raise
81
82
83 def next(it):
84     """Get the next item from an iterator, None if we reached the end."""
85     try:
86         return it.next()
87     except StopIteration:
88         return None
89
90
91 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
92     if key:
93         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
94     else:
95         samekey = operator.eq
96     count = 0
97     total = sum(len(it) for it in iters)
98     iters = (iter(it) for it in iters)
99     heap = ((next(it),it) for it in iters)
100     heap = [(e,it) for e,it in heap if e]
101
102     heapq.heapify(heap)
103     pe = None
104     while heap:
105         if not count % pfreq:
106             pfunc(count, total)
107         e, it = heap[0]
108         if not samekey(e, pe):
109             pe = e
110             yield e
111         count += 1
112         try:
113             e = it.next() # Don't use next() function, it's too expensive
114         except StopIteration:
115             heapq.heappop(heap) # remove current
116         else:
117             heapq.heapreplace(heap, (e, it)) # shift current to new location
118     pfinal(count, total)
119
120
121 def unlink(f):
122     """Delete a file at path 'f' if it currently exists.
123
124     Unlike os.unlink(), does not throw an exception if the file didn't already
125     exist.
126     """
127     try:
128         os.unlink(f)
129     except OSError, e:
130         if e.errno == errno.ENOENT:
131             pass  # it doesn't exist, that's what you asked for
132
133
134 def readpipe(argv):
135     """Run a subprocess and return its output."""
136     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
137     r = p.stdout.read()
138     p.wait()
139     return r
140
141
142 def realpath(p):
143     """Get the absolute path of a file.
144
145     Behaves like os.path.realpath, but doesn't follow a symlink for the last
146     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
147     will follow symlinks in p's directory)
148     """
149     try:
150         st = os.lstat(p)
151     except OSError:
152         st = None
153     if st and stat.S_ISLNK(st.st_mode):
154         (dir, name) = os.path.split(p)
155         dir = os.path.realpath(dir)
156         out = os.path.join(dir, name)
157     else:
158         out = os.path.realpath(p)
159     #log('realpathing:%r,%r\n' % (p, out))
160     return out
161
162
163 _username = None
164 def username():
165     """Get the user's login name."""
166     global _username
167     if not _username:
168         uid = os.getuid()
169         try:
170             _username = pwd.getpwuid(uid)[0]
171         except KeyError:
172             _username = 'user%d' % uid
173     return _username
174
175
176 _userfullname = None
177 def userfullname():
178     """Get the user's full name."""
179     global _userfullname
180     if not _userfullname:
181         uid = os.getuid()
182         try:
183             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
184         except KeyError:
185             _userfullname = 'user%d' % uid
186     return _userfullname
187
188
189 _hostname = None
190 def hostname():
191     """Get the FQDN of this machine."""
192     global _hostname
193     if not _hostname:
194         _hostname = socket.getfqdn()
195     return _hostname
196
197
198 _resource_path = None
199 def resource_path(subdir=''):
200     global _resource_path
201     if not _resource_path:
202         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
203     return os.path.join(_resource_path, subdir)
204
205
206 class NotOk(Exception):
207     pass
208
209
210 class BaseConn:
211     def __init__(self, outp):
212         self.outp = outp
213
214     def close(self):
215         while self._read(65536): pass
216
217     def read(self, size):
218         """Read 'size' bytes from input stream."""
219         self.outp.flush()
220         return self._read(size)
221
222     def readline(self):
223         """Read from input stream until a newline is found."""
224         self.outp.flush()
225         return self._readline()
226
227     def write(self, data):
228         """Write 'data' to output stream."""
229         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
230         self.outp.write(data)
231
232     def has_input(self):
233         """Return true if input stream is readable."""
234         raise NotImplemented("Subclasses must implement has_input")
235
236     def ok(self):
237         """Indicate end of output from last sent command."""
238         self.write('\nok\n')
239
240     def error(self, s):
241         """Indicate server error to the client."""
242         s = re.sub(r'\s+', ' ', str(s))
243         self.write('\nerror %s\n' % s)
244
245     def _check_ok(self, onempty):
246         self.outp.flush()
247         rl = ''
248         for rl in linereader(self):
249             #log('%d got line: %r\n' % (os.getpid(), rl))
250             if not rl:  # empty line
251                 continue
252             elif rl == 'ok':
253                 return None
254             elif rl.startswith('error '):
255                 #log('client: error: %s\n' % rl[6:])
256                 return NotOk(rl[6:])
257             else:
258                 onempty(rl)
259         raise Exception('server exited unexpectedly; see errors above')
260
261     def drain_and_check_ok(self):
262         """Remove all data for the current command from input stream."""
263         def onempty(rl):
264             pass
265         return self._check_ok(onempty)
266
267     def check_ok(self):
268         """Verify that server action completed successfully."""
269         def onempty(rl):
270             raise Exception('expected "ok", got %r' % rl)
271         return self._check_ok(onempty)
272
273
274 class Conn(BaseConn):
275     def __init__(self, inp, outp):
276         BaseConn.__init__(self, outp)
277         self.inp = inp
278
279     def _read(self, size):
280         return self.inp.read(size)
281
282     def _readline(self):
283         return self.inp.readline()
284
285     def has_input(self):
286         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
287         if rl:
288             assert(rl[0] == self.inp.fileno())
289             return True
290         else:
291             return None
292
293
294 def checked_reader(fd, n):
295     while n > 0:
296         rl, _, _ = select.select([fd], [], [])
297         assert(rl[0] == fd)
298         buf = os.read(fd, n)
299         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
300         yield buf
301         n -= len(buf)
302
303
304 MAX_PACKET = 128 * 1024
305 def mux(p, outfd, outr, errr):
306     try:
307         fds = [outr, errr]
308         while p.poll() is None:
309             rl, _, _ = select.select(fds, [], [])
310             for fd in rl:
311                 if fd == outr:
312                     buf = os.read(outr, MAX_PACKET)
313                     if not buf: break
314                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
315                 elif fd == errr:
316                     buf = os.read(errr, 1024)
317                     if not buf: break
318                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
319     finally:
320         os.write(outfd, struct.pack('!IB', 0, 3))
321
322
323 class DemuxConn(BaseConn):
324     """A helper class for bup's client-server protocol."""
325     def __init__(self, infd, outp):
326         BaseConn.__init__(self, outp)
327         # Anything that comes through before the sync string was not
328         # multiplexed and can be assumed to be debug/log before mux init.
329         tail = ''
330         while tail != 'BUPMUX':
331             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
332             if not b:
333                 raise IOError('demux: unexpected EOF during initialization')
334             tail += b
335             sys.stderr.write(tail[:-6])  # pre-mux log messages
336             tail = tail[-6:]
337         self.infd = infd
338         self.reader = None
339         self.buf = None
340         self.closed = False
341
342     def write(self, data):
343         self._load_buf(0)
344         BaseConn.write(self, data)
345
346     def _next_packet(self, timeout):
347         if self.closed: return False
348         rl, wl, xl = select.select([self.infd], [], [], timeout)
349         if not rl: return False
350         assert(rl[0] == self.infd)
351         ns = ''.join(checked_reader(self.infd, 5))
352         n, fdw = struct.unpack('!IB', ns)
353         assert(n <= MAX_PACKET)
354         if fdw == 1:
355             self.reader = checked_reader(self.infd, n)
356         elif fdw == 2:
357             for buf in checked_reader(self.infd, n):
358                 sys.stderr.write(buf)
359         elif fdw == 3:
360             self.closed = True
361             debug2("DemuxConn: marked closed\n")
362         return True
363
364     def _load_buf(self, timeout):
365         if self.buf is not None:
366             return True
367         while not self.closed:
368             while not self.reader:
369                 if not self._next_packet(timeout):
370                     return False
371             try:
372                 self.buf = self.reader.next()
373                 return True
374             except StopIteration:
375                 self.reader = None
376         return False
377
378     def _read_parts(self, ix_fn):
379         while self._load_buf(None):
380             assert(self.buf is not None)
381             i = ix_fn(self.buf)
382             if i is None or i == len(self.buf):
383                 yv = self.buf
384                 self.buf = None
385             else:
386                 yv = self.buf[:i]
387                 self.buf = self.buf[i:]
388             yield yv
389             if i is not None:
390                 break
391
392     def _readline(self):
393         def find_eol(buf):
394             try:
395                 return buf.index('\n')+1
396             except ValueError:
397                 return None
398         return ''.join(self._read_parts(find_eol))
399
400     def _read(self, size):
401         csize = [size]
402         def until_size(buf): # Closes on csize
403             if len(buf) < csize[0]:
404                 csize[0] -= len(buf)
405                 return None
406             else:
407                 return csize[0]
408         return ''.join(self._read_parts(until_size))
409
410     def has_input(self):
411         return self._load_buf(0)
412
413
414 def linereader(f):
415     """Generate a list of input lines from 'f' without terminating newlines."""
416     while 1:
417         line = f.readline()
418         if not line:
419             break
420         yield line[:-1]
421
422
423 def chunkyreader(f, count = None):
424     """Generate a list of chunks of data read from 'f'.
425
426     If count is None, read until EOF is reached.
427
428     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
429     reached while reading, raise IOError.
430     """
431     if count != None:
432         while count > 0:
433             b = f.read(min(count, 65536))
434             if not b:
435                 raise IOError('EOF with %d bytes remaining' % count)
436             yield b
437             count -= len(b)
438     else:
439         while 1:
440             b = f.read(65536)
441             if not b: break
442             yield b
443
444
445 def slashappend(s):
446     """Append "/" to 's' if it doesn't aleady end in "/"."""
447     if s and not s.endswith('/'):
448         return s + '/'
449     else:
450         return s
451
452
453 def _mmap_do(f, sz, flags, prot):
454     if not sz:
455         st = os.fstat(f.fileno())
456         sz = st.st_size
457     if not sz:
458         # trying to open a zero-length map gives an error, but an empty
459         # string has all the same behaviour of a zero-length map, ie. it has
460         # no elements :)
461         return ''
462     map = mmap.mmap(f.fileno(), sz, flags, prot)
463     f.close()  # map will persist beyond file close
464     return map
465
466
467 def mmap_read(f, sz = 0):
468     """Create a read-only memory mapped region on file 'f'.
469
470     If sz is 0, the region will cover the entire file.
471     """
472     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
473
474
475 def mmap_readwrite(f, sz = 0):
476     """Create a read-write memory mapped region on file 'f'.
477
478     If sz is 0, the region will cover the entire file.
479     """
480     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
481
482
483 def parse_num(s):
484     """Parse data size information into a float number.
485
486     Here are some examples of conversions:
487         199.2k means 203981 bytes
488         1GB means 1073741824 bytes
489         2.1 tb means 2199023255552 bytes
490     """
491     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
492     if not g:
493         raise ValueError("can't parse %r as a number" % s)
494     (val, unit) = g.groups()
495     num = float(val)
496     unit = unit.lower()
497     if unit in ['t', 'tb']:
498         mult = 1024*1024*1024*1024
499     elif unit in ['g', 'gb']:
500         mult = 1024*1024*1024
501     elif unit in ['m', 'mb']:
502         mult = 1024*1024
503     elif unit in ['k', 'kb']:
504         mult = 1024
505     elif unit in ['', 'b']:
506         mult = 1
507     else:
508         raise ValueError("invalid unit %r in number %r" % (unit, s))
509     return int(num*mult)
510
511
512 def count(l):
513     """Count the number of elements in an iterator. (consumes the iterator)"""
514     return reduce(lambda x,y: x+1, l)
515
516
517 saved_errors = []
518 def add_error(e):
519     """Append an error message to the list of saved errors.
520
521     Once processing is able to stop and output the errors, the saved errors are
522     accessible in the module variable helpers.saved_errors.
523     """
524     saved_errors.append(e)
525     log('%-70s\n' % e)
526
527
528 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
529 def progress(s):
530     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
531     if istty:
532         log(s)
533
534
535 def handle_ctrl_c():
536     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
537
538     The new exception handler will make sure that bup will exit without an ugly
539     stacktrace when Ctrl-C is hit.
540     """
541     oldhook = sys.excepthook
542     def newhook(exctype, value, traceback):
543         if exctype == KeyboardInterrupt:
544             log('Interrupted.\n')
545         else:
546             return oldhook(exctype, value, traceback)
547     sys.excepthook = newhook
548
549
550 def columnate(l, prefix):
551     """Format elements of 'l' in columns with 'prefix' leading each line.
552
553     The number of columns is determined automatically based on the string
554     lengths.
555     """
556     if not l:
557         return ""
558     l = l[:]
559     clen = max(len(s) for s in l)
560     ncols = (tty_width() - len(prefix)) / (clen + 2)
561     if ncols <= 1:
562         ncols = 1
563         clen = 0
564     cols = []
565     while len(l) % ncols:
566         l.append('')
567     rows = len(l)/ncols
568     for s in range(0, len(l), rows):
569         cols.append(l[s:s+rows])
570     out = ''
571     for row in zip(*cols):
572         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
573     return out
574
575
576 def parse_date_or_fatal(str, fatal):
577     """Parses the given date or calls Option.fatal().
578     For now we expect a string that contains a float."""
579     try:
580         date = atof(str)
581     except ValueError, e:
582         raise fatal('invalid date format (should be a float): %r' % e)
583     else:
584         return date
585
586
587 def strip_path(prefix, path):
588     """Strips a given prefix from a path.
589
590     First both paths are normalized.
591
592     Raises an Exception if no prefix is given.
593     """
594     if prefix == None:
595         raise Exception('no path given')
596
597     normalized_prefix = os.path.realpath(prefix)
598     debug2("normalized_prefix: %s\n" % normalized_prefix)
599     normalized_path = os.path.realpath(path)
600     debug2("normalized_path: %s\n" % normalized_path)
601     if normalized_path.startswith(normalized_prefix):
602         return normalized_path[len(normalized_prefix):]
603     else:
604         return path
605
606
607 def strip_base_path(path, base_paths):
608     """Strips the base path from a given path.
609
610
611     Determines the base path for the given string and then strips it
612     using strip_path().
613     Iterates over all base_paths from long to short, to prevent that
614     a too short base_path is removed.
615     """
616     normalized_path = os.path.realpath(path)
617     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
618     for bp in sorted_base_paths:
619         if normalized_path.startswith(os.path.realpath(bp)):
620             return strip_path(bp, normalized_path)
621     return path
622
623
624 def graft_path(graft_points, path):
625     normalized_path = os.path.realpath(path)
626     for graft_point in graft_points:
627         old_prefix, new_prefix = graft_point
628         if normalized_path.startswith(old_prefix):
629             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
630     return normalized_path
631
632
633 # hashlib is only available in python 2.5 or higher, but the 'sha' module
634 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
635 # python 2.4 and above without any stupid warnings, so let's try using hashlib
636 # first, and downgrade if it fails.
637 try:
638     import hashlib
639 except ImportError:
640     import sha
641     Sha1 = sha.sha
642 else:
643     Sha1 = hashlib.sha1
644
645
646 def version_date():
647     """Format bup's version date string for output."""
648     return _version.DATE.split(' ')[0]
649
650
651 def version_commit():
652     """Get the commit hash of bup's current version."""
653     return _version.COMMIT
654
655
656 def version_tag():
657     """Format bup's version tag (the official version number).
658
659     When generated from a commit other than one pointed to with a tag, the
660     returned string will be "unknown-" followed by the first seven positions of
661     the commit hash.
662     """
663     names = _version.NAMES.strip()
664     assert(names[0] == '(')
665     assert(names[-1] == ')')
666     names = names[1:-1]
667     l = [n.strip() for n in names.split(',')]
668     for n in l:
669         if n.startswith('tag: bup-'):
670             return n[9:]
671     return 'unknown-%s' % _version.COMMIT[:7]