]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
DemuxConn.__init__: abort the loop if read() returns EOF.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49 def log(s):
50     """Print a log message to stderr."""
51     sys.stdout.flush()
52     _hard_write(sys.stderr.fileno(), s)
53
54
55 def debug1(s):
56     if buglvl >= 1:
57         log(s)
58
59
60 def debug2(s):
61     if buglvl >= 2:
62         log(s)
63
64
65 def mkdirp(d, mode=None):
66     """Recursively create directories on path 'd'.
67
68     Unlike os.makedirs(), it doesn't raise an exception if the last element of
69     the path already exists.
70     """
71     try:
72         if mode:
73             os.makedirs(d, mode)
74         else:
75             os.makedirs(d)
76     except OSError, e:
77         if e.errno == errno.EEXIST:
78             pass
79         else:
80             raise
81
82
83 def next(it):
84     """Get the next item from an iterator, None if we reached the end."""
85     try:
86         return it.next()
87     except StopIteration:
88         return None
89
90
91 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
92     if key:
93         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
94     else:
95         samekey = operator.eq
96     count = 0
97     total = sum(len(it) for it in iters)
98     iters = (iter(it) for it in iters)
99     heap = ((next(it),it) for it in iters)
100     heap = [(e,it) for e,it in heap if e]
101
102     heapq.heapify(heap)
103     pe = None
104     while heap:
105         if not count % pfreq:
106             pfunc(count, total)
107         e, it = heap[0]
108         if not samekey(e, pe):
109             pe = e
110             yield e
111         count += 1
112         try:
113             e = it.next() # Don't use next() function, it's too expensive
114         except StopIteration:
115             heapq.heappop(heap) # remove current
116         else:
117             heapq.heapreplace(heap, (e, it)) # shift current to new location
118     pfinal(count, total)
119
120
121 def unlink(f):
122     """Delete a file at path 'f' if it currently exists.
123
124     Unlike os.unlink(), does not throw an exception if the file didn't already
125     exist.
126     """
127     try:
128         os.unlink(f)
129     except OSError, e:
130         if e.errno == errno.ENOENT:
131             pass  # it doesn't exist, that's what you asked for
132
133
134 def readpipe(argv):
135     """Run a subprocess and return its output."""
136     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
137     r = p.stdout.read()
138     p.wait()
139     return r
140
141
142 def realpath(p):
143     """Get the absolute path of a file.
144
145     Behaves like os.path.realpath, but doesn't follow a symlink for the last
146     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
147     will follow symlinks in p's directory)
148     """
149     try:
150         st = os.lstat(p)
151     except OSError:
152         st = None
153     if st and stat.S_ISLNK(st.st_mode):
154         (dir, name) = os.path.split(p)
155         dir = os.path.realpath(dir)
156         out = os.path.join(dir, name)
157     else:
158         out = os.path.realpath(p)
159     #log('realpathing:%r,%r\n' % (p, out))
160     return out
161
162
163 _username = None
164 def username():
165     """Get the user's login name."""
166     global _username
167     if not _username:
168         uid = os.getuid()
169         try:
170             _username = pwd.getpwuid(uid)[0]
171         except KeyError:
172             _username = 'user%d' % uid
173     return _username
174
175
176 _userfullname = None
177 def userfullname():
178     """Get the user's full name."""
179     global _userfullname
180     if not _userfullname:
181         uid = os.getuid()
182         try:
183             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
184         except KeyError:
185             _userfullname = 'user%d' % uid
186     return _userfullname
187
188
189 _hostname = None
190 def hostname():
191     """Get the FQDN of this machine."""
192     global _hostname
193     if not _hostname:
194         _hostname = socket.getfqdn()
195     return _hostname
196
197
198 _resource_path = None
199 def resource_path(subdir=''):
200     global _resource_path
201     if not _resource_path:
202         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
203     return os.path.join(_resource_path, subdir)
204
205
206 class NotOk(Exception):
207     pass
208
209
210 class BaseConn:
211     def __init__(self, outp):
212         self.outp = outp
213
214     def close(self):
215         while self._read(65536): pass
216
217     def read(self, size):
218         """Read 'size' bytes from input stream."""
219         self.outp.flush()
220         return self._read(size)
221
222     def readline(self):
223         """Read from input stream until a newline is found."""
224         self.outp.flush()
225         return self._readline()
226
227     def write(self, data):
228         """Write 'data' to output stream."""
229         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
230         self.outp.write(data)
231
232     def has_input(self):
233         """Return true if input stream is readable."""
234         raise NotImplemented("Subclasses must implement has_input")
235
236     def ok(self):
237         """Indicate end of output from last sent command."""
238         self.write('\nok\n')
239
240     def error(self, s):
241         """Indicate server error to the client."""
242         s = re.sub(r'\s+', ' ', str(s))
243         self.write('\nerror %s\n' % s)
244
245     def _check_ok(self, onempty):
246         self.outp.flush()
247         rl = ''
248         for rl in linereader(self):
249             #log('%d got line: %r\n' % (os.getpid(), rl))
250             if not rl:  # empty line
251                 continue
252             elif rl == 'ok':
253                 return None
254             elif rl.startswith('error '):
255                 #log('client: error: %s\n' % rl[6:])
256                 return NotOk(rl[6:])
257             else:
258                 onempty(rl)
259         raise Exception('server exited unexpectedly; see errors above')
260
261     def drain_and_check_ok(self):
262         """Remove all data for the current command from input stream."""
263         def onempty(rl):
264             pass
265         return self._check_ok(onempty)
266
267     def check_ok(self):
268         """Verify that server action completed successfully."""
269         def onempty(rl):
270             raise Exception('expected "ok", got %r' % rl)
271         return self._check_ok(onempty)
272
273
274 class Conn(BaseConn):
275     def __init__(self, inp, outp):
276         BaseConn.__init__(self, outp)
277         self.inp = inp
278
279     def _read(self, size):
280         return self.inp.read(size)
281
282     def _readline(self):
283         return self.inp.readline()
284
285     def has_input(self):
286         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
287         if rl:
288             assert(rl[0] == self.inp.fileno())
289             return True
290         else:
291             return None
292
293
294 def checked_reader(fd, n):
295     while n > 0:
296         rl, _, _ = select.select([fd], [], [])
297         assert(rl[0] == fd)
298         buf = os.read(fd, n)
299         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
300         yield buf
301         n -= len(buf)
302
303
304 MAX_PACKET = 128 * 1024
305 def mux(p, outfd, outr, errr):
306     try:
307         fds = [outr, errr]
308         while p.poll() is None:
309             rl, _, _ = select.select(fds, [], [])
310             for fd in rl:
311                 if fd == outr:
312                     buf = os.read(outr, MAX_PACKET)
313                     if not buf: break
314                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
315                 elif fd == errr:
316                     buf = os.read(errr, 1024)
317                     if not buf: break
318                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
319     finally:
320         os.write(outfd, struct.pack('!IB', 0, 3))
321
322
323 class DemuxConn(BaseConn):
324     """A helper class for bup's client-server protocol."""
325     def __init__(self, infd, outp):
326         BaseConn.__init__(self, outp)
327         # Anything that comes through before the sync string was not
328         # multiplexed and can be assumed to be debug/log before mux init.
329         tail = ''
330         while tail != 'BUPMUX':
331             b = os.read(infd, 1024)
332             if not b:
333                 raise IOError('demux: unexpected EOF during initialization')
334             tail += b
335             buf = tail[:-6]
336             tail = tail[-6:]
337             sys.stderr.write(buf)
338         self.infd = infd
339         self.reader = None
340         self.buf = None
341         self.closed = False
342
343     def write(self, data):
344         self._load_buf(0)
345         BaseConn.write(self, data)
346
347     def _next_packet(self, timeout):
348         if self.closed: return False
349         rl, wl, xl = select.select([self.infd], [], [], timeout)
350         if not rl: return False
351         assert(rl[0] == self.infd)
352         ns = ''.join(checked_reader(self.infd, 5))
353         n, fdw = struct.unpack('!IB', ns)
354         assert(n<=MAX_PACKET)
355         if fdw == 1:
356             self.reader = checked_reader(self.infd, n)
357         elif fdw == 2:
358             for buf in checked_reader(self.infd, n):
359                 sys.stderr.write(buf)
360         elif fdw == 3:
361             self.closed = True
362             debug2("DemuxConn: marked closed\n")
363         return True
364
365     def _load_buf(self, timeout):
366         if self.buf is not None:
367             return True
368         while not self.closed:
369             while not self.reader:
370                 if not self._next_packet(timeout):
371                     return False
372             try:
373                 self.buf = self.reader.next()
374                 return True
375             except StopIteration:
376                 self.reader = None
377         return False
378
379     def _read_parts(self, ix_fn):
380         while self._load_buf(None):
381             assert(self.buf is not None)
382             i = ix_fn(self.buf)
383             if i is None or i == len(self.buf):
384                 yv = self.buf
385                 self.buf = None
386             else:
387                 yv = self.buf[:i]
388                 self.buf = self.buf[i:]
389             yield yv
390             if i is not None:
391                 break
392
393     def _readline(self):
394         def find_eol(buf):
395             try:
396                 return buf.index('\n')+1
397             except ValueError:
398                 return None
399         return ''.join(self._read_parts(find_eol))
400
401     def _read(self, size):
402         csize = [size]
403         def until_size(buf): # Closes on csize
404             if len(buf) < csize[0]:
405                 csize[0] -= len(buf)
406                 return None
407             else:
408                 return csize[0]
409         return ''.join(self._read_parts(until_size))
410
411     def has_input(self):
412         return self._load_buf(0)
413
414
415 def linereader(f):
416     """Generate a list of input lines from 'f' without terminating newlines."""
417     while 1:
418         line = f.readline()
419         if not line:
420             break
421         yield line[:-1]
422
423
424 def chunkyreader(f, count = None):
425     """Generate a list of chunks of data read from 'f'.
426
427     If count is None, read until EOF is reached.
428
429     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
430     reached while reading, raise IOError.
431     """
432     if count != None:
433         while count > 0:
434             b = f.read(min(count, 65536))
435             if not b:
436                 raise IOError('EOF with %d bytes remaining' % count)
437             yield b
438             count -= len(b)
439     else:
440         while 1:
441             b = f.read(65536)
442             if not b: break
443             yield b
444
445
446 def slashappend(s):
447     """Append "/" to 's' if it doesn't aleady end in "/"."""
448     if s and not s.endswith('/'):
449         return s + '/'
450     else:
451         return s
452
453
454 def _mmap_do(f, sz, flags, prot):
455     if not sz:
456         st = os.fstat(f.fileno())
457         sz = st.st_size
458     if not sz:
459         # trying to open a zero-length map gives an error, but an empty
460         # string has all the same behaviour of a zero-length map, ie. it has
461         # no elements :)
462         return ''
463     map = mmap.mmap(f.fileno(), sz, flags, prot)
464     f.close()  # map will persist beyond file close
465     return map
466
467
468 def mmap_read(f, sz = 0):
469     """Create a read-only memory mapped region on file 'f'.
470
471     If sz is 0, the region will cover the entire file.
472     """
473     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
474
475
476 def mmap_readwrite(f, sz = 0):
477     """Create a read-write memory mapped region on file 'f'.
478
479     If sz is 0, the region will cover the entire file.
480     """
481     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
482
483
484 def parse_num(s):
485     """Parse data size information into a float number.
486
487     Here are some examples of conversions:
488         199.2k means 203981 bytes
489         1GB means 1073741824 bytes
490         2.1 tb means 2199023255552 bytes
491     """
492     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
493     if not g:
494         raise ValueError("can't parse %r as a number" % s)
495     (val, unit) = g.groups()
496     num = float(val)
497     unit = unit.lower()
498     if unit in ['t', 'tb']:
499         mult = 1024*1024*1024*1024
500     elif unit in ['g', 'gb']:
501         mult = 1024*1024*1024
502     elif unit in ['m', 'mb']:
503         mult = 1024*1024
504     elif unit in ['k', 'kb']:
505         mult = 1024
506     elif unit in ['', 'b']:
507         mult = 1
508     else:
509         raise ValueError("invalid unit %r in number %r" % (unit, s))
510     return int(num*mult)
511
512
513 def count(l):
514     """Count the number of elements in an iterator. (consumes the iterator)"""
515     return reduce(lambda x,y: x+1, l)
516
517
518 saved_errors = []
519 def add_error(e):
520     """Append an error message to the list of saved errors.
521
522     Once processing is able to stop and output the errors, the saved errors are
523     accessible in the module variable helpers.saved_errors.
524     """
525     saved_errors.append(e)
526     log('%-70s\n' % e)
527
528
529 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
530 def progress(s):
531     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
532     if istty:
533         log(s)
534
535
536 def handle_ctrl_c():
537     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
538
539     The new exception handler will make sure that bup will exit without an ugly
540     stacktrace when Ctrl-C is hit.
541     """
542     oldhook = sys.excepthook
543     def newhook(exctype, value, traceback):
544         if exctype == KeyboardInterrupt:
545             log('Interrupted.\n')
546         else:
547             return oldhook(exctype, value, traceback)
548     sys.excepthook = newhook
549
550
551 def columnate(l, prefix):
552     """Format elements of 'l' in columns with 'prefix' leading each line.
553
554     The number of columns is determined automatically based on the string
555     lengths.
556     """
557     if not l:
558         return ""
559     l = l[:]
560     clen = max(len(s) for s in l)
561     ncols = (tty_width() - len(prefix)) / (clen + 2)
562     if ncols <= 1:
563         ncols = 1
564         clen = 0
565     cols = []
566     while len(l) % ncols:
567         l.append('')
568     rows = len(l)/ncols
569     for s in range(0, len(l), rows):
570         cols.append(l[s:s+rows])
571     out = ''
572     for row in zip(*cols):
573         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
574     return out
575
576
577 def parse_date_or_fatal(str, fatal):
578     """Parses the given date or calls Option.fatal().
579     For now we expect a string that contains a float."""
580     try:
581         date = atof(str)
582     except ValueError, e:
583         raise fatal('invalid date format (should be a float): %r' % e)
584     else:
585         return date
586
587
588 def strip_path(prefix, path):
589     """Strips a given prefix from a path.
590
591     First both paths are normalized.
592
593     Raises an Exception if no prefix is given.
594     """
595     if prefix == None:
596         raise Exception('no path given')
597
598     normalized_prefix = os.path.realpath(prefix)
599     debug2("normalized_prefix: %s\n" % normalized_prefix)
600     normalized_path = os.path.realpath(path)
601     debug2("normalized_path: %s\n" % normalized_path)
602     if normalized_path.startswith(normalized_prefix):
603         return normalized_path[len(normalized_prefix):]
604     else:
605         return path
606
607
608 def strip_base_path(path, base_paths):
609     """Strips the base path from a given path.
610
611
612     Determines the base path for the given string and then strips it
613     using strip_path().
614     Iterates over all base_paths from long to short, to prevent that
615     a too short base_path is removed.
616     """
617     normalized_path = os.path.realpath(path)
618     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
619     for bp in sorted_base_paths:
620         if normalized_path.startswith(os.path.realpath(bp)):
621             return strip_path(bp, normalized_path)
622     return path
623
624
625 def graft_path(graft_points, path):
626     normalized_path = os.path.realpath(path)
627     for graft_point in graft_points:
628         old_prefix, new_prefix = graft_point
629         if normalized_path.startswith(old_prefix):
630             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
631     return normalized_path
632
633
634 # hashlib is only available in python 2.5 or higher, but the 'sha' module
635 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
636 # python 2.4 and above without any stupid warnings, so let's try using hashlib
637 # first, and downgrade if it fails.
638 try:
639     import hashlib
640 except ImportError:
641     import sha
642     Sha1 = sha.sha
643 else:
644     Sha1 = hashlib.sha1
645
646
647 def version_date():
648     """Format bup's version date string for output."""
649     return _version.DATE.split(' ')[0]
650
651
652 def version_commit():
653     """Get the commit hash of bup's current version."""
654     return _version.COMMIT
655
656
657 def version_tag():
658     """Format bup's version tag (the official version number).
659
660     When generated from a commit other than one pointed to with a tag, the
661     returned string will be "unknown-" followed by the first seven positions of
662     the commit hash.
663     """
664     names = _version.NAMES.strip()
665     assert(names[0] == '(')
666     assert(names[-1] == ')')
667     names = names[1:-1]
668     l = [n.strip() for n in names.split(',')]
669     for n in l:
670         if n.startswith('tag: bup-'):
671             return n[9:]
672     return 'unknown-%s' % _version.COMMIT[:7]