]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
helpers.py: always use two blank lines between functions/classes.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49 def log(s):
50     """Print a log message to stderr."""
51     sys.stdout.flush()
52     _hard_write(sys.stderr.fileno(), s)
53
54
55 def debug1(s):
56     if buglvl >= 1:
57         log(s)
58
59
60 def debug2(s):
61     if buglvl >= 2:
62         log(s)
63
64
65 def mkdirp(d, mode=None):
66     """Recursively create directories on path 'd'.
67
68     Unlike os.makedirs(), it doesn't raise an exception if the last element of
69     the path already exists.
70     """
71     try:
72         if mode:
73             os.makedirs(d, mode)
74         else:
75             os.makedirs(d)
76     except OSError, e:
77         if e.errno == errno.EEXIST:
78             pass
79         else:
80             raise
81
82
83 def next(it):
84     """Get the next item from an iterator, None if we reached the end."""
85     try:
86         return it.next()
87     except StopIteration:
88         return None
89
90
91 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
92     if key:
93         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
94     else:
95         samekey = operator.eq
96     count = 0
97     total = sum(len(it) for it in iters)
98     iters = (iter(it) for it in iters)
99     heap = ((next(it),it) for it in iters)
100     heap = [(e,it) for e,it in heap if e]
101
102     heapq.heapify(heap)
103     pe = None
104     while heap:
105         if not count % pfreq:
106             pfunc(count, total)
107         e, it = heap[0]
108         if not samekey(e, pe):
109             pe = e
110             yield e
111         count += 1
112         try:
113             e = it.next() # Don't use next() function, it's too expensive
114         except StopIteration:
115             heapq.heappop(heap) # remove current
116         else:
117             heapq.heapreplace(heap, (e, it)) # shift current to new location
118     pfinal(count, total)
119
120
121 def unlink(f):
122     """Delete a file at path 'f' if it currently exists.
123
124     Unlike os.unlink(), does not throw an exception if the file didn't already
125     exist.
126     """
127     try:
128         os.unlink(f)
129     except OSError, e:
130         if e.errno == errno.ENOENT:
131             pass  # it doesn't exist, that's what you asked for
132
133
134 def readpipe(argv):
135     """Run a subprocess and return its output."""
136     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
137     r = p.stdout.read()
138     p.wait()
139     return r
140
141
142 def realpath(p):
143     """Get the absolute path of a file.
144
145     Behaves like os.path.realpath, but doesn't follow a symlink for the last
146     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
147     will follow symlinks in p's directory)
148     """
149     try:
150         st = os.lstat(p)
151     except OSError:
152         st = None
153     if st and stat.S_ISLNK(st.st_mode):
154         (dir, name) = os.path.split(p)
155         dir = os.path.realpath(dir)
156         out = os.path.join(dir, name)
157     else:
158         out = os.path.realpath(p)
159     #log('realpathing:%r,%r\n' % (p, out))
160     return out
161
162
163 _username = None
164 def username():
165     """Get the user's login name."""
166     global _username
167     if not _username:
168         uid = os.getuid()
169         try:
170             _username = pwd.getpwuid(uid)[0]
171         except KeyError:
172             _username = 'user%d' % uid
173     return _username
174
175
176 _userfullname = None
177 def userfullname():
178     """Get the user's full name."""
179     global _userfullname
180     if not _userfullname:
181         uid = os.getuid()
182         try:
183             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
184         except KeyError:
185             _userfullname = 'user%d' % uid
186     return _userfullname
187
188
189 _hostname = None
190 def hostname():
191     """Get the FQDN of this machine."""
192     global _hostname
193     if not _hostname:
194         _hostname = socket.getfqdn()
195     return _hostname
196
197
198 _resource_path = None
199 def resource_path(subdir=''):
200     global _resource_path
201     if not _resource_path:
202         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
203     return os.path.join(_resource_path, subdir)
204
205
206 class NotOk(Exception):
207     pass
208
209
210 class BaseConn:
211     def __init__(self, outp):
212         self.outp = outp
213
214     def close(self):
215         while self._read(65536): pass
216
217     def read(self, size):
218         """Read 'size' bytes from input stream."""
219         self.outp.flush()
220         return self._read(size)
221
222     def readline(self):
223         """Read from input stream until a newline is found."""
224         self.outp.flush()
225         return self._readline()
226
227     def write(self, data):
228         """Write 'data' to output stream."""
229         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
230         self.outp.write(data)
231
232     def has_input(self):
233         """Return true if input stream is readable."""
234         raise NotImplemented("Subclasses must implement has_input")
235
236     def ok(self):
237         """Indicate end of output from last sent command."""
238         self.write('\nok\n')
239
240     def error(self, s):
241         """Indicate server error to the client."""
242         s = re.sub(r'\s+', ' ', str(s))
243         self.write('\nerror %s\n' % s)
244
245     def _check_ok(self, onempty):
246         self.outp.flush()
247         rl = ''
248         for rl in linereader(self):
249             #log('%d got line: %r\n' % (os.getpid(), rl))
250             if not rl:  # empty line
251                 continue
252             elif rl == 'ok':
253                 return None
254             elif rl.startswith('error '):
255                 #log('client: error: %s\n' % rl[6:])
256                 return NotOk(rl[6:])
257             else:
258                 onempty(rl)
259         raise Exception('server exited unexpectedly; see errors above')
260
261     def drain_and_check_ok(self):
262         """Remove all data for the current command from input stream."""
263         def onempty(rl):
264             pass
265         return self._check_ok(onempty)
266
267     def check_ok(self):
268         """Verify that server action completed successfully."""
269         def onempty(rl):
270             raise Exception('expected "ok", got %r' % rl)
271         return self._check_ok(onempty)
272
273
274 class Conn(BaseConn):
275     def __init__(self, inp, outp):
276         BaseConn.__init__(self, outp)
277         self.inp = inp
278
279     def _read(self, size):
280         return self.inp.read(size)
281
282     def _readline(self):
283         return self.inp.readline()
284
285     def has_input(self):
286         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
287         if rl:
288             assert(rl[0] == self.inp.fileno())
289             return True
290         else:
291             return None
292
293
294 def checked_reader(fd, n):
295     while n > 0:
296         rl, _, _ = select.select([fd], [], [])
297         assert(rl[0] == fd)
298         buf = os.read(fd, n)
299         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
300         yield buf
301         n -= len(buf)
302
303
304 MAX_PACKET = 128 * 1024
305 def mux(p, outfd, outr, errr):
306     try:
307         fds = [outr, errr]
308         while p.poll() is None:
309             rl, _, _ = select.select(fds, [], [])
310             for fd in rl:
311                 if fd == outr:
312                     buf = os.read(outr, MAX_PACKET)
313                     if not buf: break
314                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
315                 elif fd == errr:
316                     buf = os.read(errr, 1024)
317                     if not buf: break
318                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
319     finally:
320         os.write(outfd, struct.pack('!IB', 0, 3))
321
322
323 class DemuxConn(BaseConn):
324     """A helper class for bup's client-server protocol."""
325     def __init__(self, infd, outp):
326         BaseConn.__init__(self, outp)
327         # Anything that comes through before the sync string was not
328         # multiplexed and can be assumed to be debug/log before mux init.
329         tail = ''
330         while tail != 'BUPMUX':
331             tail += os.read(infd, 1024)
332             buf = tail[:-6]
333             tail = tail[-6:]
334             sys.stderr.write(buf)
335         self.infd = infd
336         self.reader = None
337         self.buf = None
338         self.closed = False
339
340     def write(self, data):
341         self._load_buf(0)
342         BaseConn.write(self, data)
343
344     def _next_packet(self, timeout):
345         if self.closed: return False
346         rl, wl, xl = select.select([self.infd], [], [], timeout)
347         if not rl: return False
348         assert(rl[0] == self.infd)
349         ns = ''.join(checked_reader(self.infd, 5))
350         n, fdw = struct.unpack('!IB', ns)
351         assert(n<=MAX_PACKET)
352         if fdw == 1:
353             self.reader = checked_reader(self.infd, n)
354         elif fdw == 2:
355             for buf in checked_reader(self.infd, n):
356                 sys.stderr.write(buf)
357         elif fdw == 3:
358             self.closed = True
359             debug2("DemuxConn: marked closed\n")
360         return True
361
362     def _load_buf(self, timeout):
363         if self.buf is not None:
364             return True
365         while not self.closed:
366             while not self.reader:
367                 if not self._next_packet(timeout):
368                     return False
369             try:
370                 self.buf = self.reader.next()
371                 return True
372             except StopIteration:
373                 self.reader = None
374         return False
375
376     def _read_parts(self, ix_fn):
377         while self._load_buf(None):
378             assert(self.buf is not None)
379             i = ix_fn(self.buf)
380             if i is None or i == len(self.buf):
381                 yv = self.buf
382                 self.buf = None
383             else:
384                 yv = self.buf[:i]
385                 self.buf = self.buf[i:]
386             yield yv
387             if i is not None:
388                 break
389
390     def _readline(self):
391         def find_eol(buf):
392             try:
393                 return buf.index('\n')+1
394             except ValueError:
395                 return None
396         return ''.join(self._read_parts(find_eol))
397
398     def _read(self, size):
399         csize = [size]
400         def until_size(buf): # Closes on csize
401             if len(buf) < csize[0]:
402                 csize[0] -= len(buf)
403                 return None
404             else:
405                 return csize[0]
406         return ''.join(self._read_parts(until_size))
407
408     def has_input(self):
409         return self._load_buf(0)
410
411
412 def linereader(f):
413     """Generate a list of input lines from 'f' without terminating newlines."""
414     while 1:
415         line = f.readline()
416         if not line:
417             break
418         yield line[:-1]
419
420
421 def chunkyreader(f, count = None):
422     """Generate a list of chunks of data read from 'f'.
423
424     If count is None, read until EOF is reached.
425
426     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
427     reached while reading, raise IOError.
428     """
429     if count != None:
430         while count > 0:
431             b = f.read(min(count, 65536))
432             if not b:
433                 raise IOError('EOF with %d bytes remaining' % count)
434             yield b
435             count -= len(b)
436     else:
437         while 1:
438             b = f.read(65536)
439             if not b: break
440             yield b
441
442
443 def slashappend(s):
444     """Append "/" to 's' if it doesn't aleady end in "/"."""
445     if s and not s.endswith('/'):
446         return s + '/'
447     else:
448         return s
449
450
451 def _mmap_do(f, sz, flags, prot):
452     if not sz:
453         st = os.fstat(f.fileno())
454         sz = st.st_size
455     if not sz:
456         # trying to open a zero-length map gives an error, but an empty
457         # string has all the same behaviour of a zero-length map, ie. it has
458         # no elements :)
459         return ''
460     map = mmap.mmap(f.fileno(), sz, flags, prot)
461     f.close()  # map will persist beyond file close
462     return map
463
464
465 def mmap_read(f, sz = 0):
466     """Create a read-only memory mapped region on file 'f'.
467
468     If sz is 0, the region will cover the entire file.
469     """
470     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
471
472
473 def mmap_readwrite(f, sz = 0):
474     """Create a read-write memory mapped region on file 'f'.
475
476     If sz is 0, the region will cover the entire file.
477     """
478     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
479
480
481 def parse_num(s):
482     """Parse data size information into a float number.
483
484     Here are some examples of conversions:
485         199.2k means 203981 bytes
486         1GB means 1073741824 bytes
487         2.1 tb means 2199023255552 bytes
488     """
489     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
490     if not g:
491         raise ValueError("can't parse %r as a number" % s)
492     (val, unit) = g.groups()
493     num = float(val)
494     unit = unit.lower()
495     if unit in ['t', 'tb']:
496         mult = 1024*1024*1024*1024
497     elif unit in ['g', 'gb']:
498         mult = 1024*1024*1024
499     elif unit in ['m', 'mb']:
500         mult = 1024*1024
501     elif unit in ['k', 'kb']:
502         mult = 1024
503     elif unit in ['', 'b']:
504         mult = 1
505     else:
506         raise ValueError("invalid unit %r in number %r" % (unit, s))
507     return int(num*mult)
508
509
510 def count(l):
511     """Count the number of elements in an iterator. (consumes the iterator)"""
512     return reduce(lambda x,y: x+1, l)
513
514
515 saved_errors = []
516 def add_error(e):
517     """Append an error message to the list of saved errors.
518
519     Once processing is able to stop and output the errors, the saved errors are
520     accessible in the module variable helpers.saved_errors.
521     """
522     saved_errors.append(e)
523     log('%-70s\n' % e)
524
525
526 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
527 def progress(s):
528     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
529     if istty:
530         log(s)
531
532
533 def handle_ctrl_c():
534     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
535
536     The new exception handler will make sure that bup will exit without an ugly
537     stacktrace when Ctrl-C is hit.
538     """
539     oldhook = sys.excepthook
540     def newhook(exctype, value, traceback):
541         if exctype == KeyboardInterrupt:
542             log('Interrupted.\n')
543         else:
544             return oldhook(exctype, value, traceback)
545     sys.excepthook = newhook
546
547
548 def columnate(l, prefix):
549     """Format elements of 'l' in columns with 'prefix' leading each line.
550
551     The number of columns is determined automatically based on the string
552     lengths.
553     """
554     if not l:
555         return ""
556     l = l[:]
557     clen = max(len(s) for s in l)
558     ncols = (tty_width() - len(prefix)) / (clen + 2)
559     if ncols <= 1:
560         ncols = 1
561         clen = 0
562     cols = []
563     while len(l) % ncols:
564         l.append('')
565     rows = len(l)/ncols
566     for s in range(0, len(l), rows):
567         cols.append(l[s:s+rows])
568     out = ''
569     for row in zip(*cols):
570         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
571     return out
572
573
574 def parse_date_or_fatal(str, fatal):
575     """Parses the given date or calls Option.fatal().
576     For now we expect a string that contains a float."""
577     try:
578         date = atof(str)
579     except ValueError, e:
580         raise fatal('invalid date format (should be a float): %r' % e)
581     else:
582         return date
583
584
585 def strip_path(prefix, path):
586     """Strips a given prefix from a path.
587
588     First both paths are normalized.
589
590     Raises an Exception if no prefix is given.
591     """
592     if prefix == None:
593         raise Exception('no path given')
594
595     normalized_prefix = os.path.realpath(prefix)
596     debug2("normalized_prefix: %s\n" % normalized_prefix)
597     normalized_path = os.path.realpath(path)
598     debug2("normalized_path: %s\n" % normalized_path)
599     if normalized_path.startswith(normalized_prefix):
600         return normalized_path[len(normalized_prefix):]
601     else:
602         return path
603
604
605 def strip_base_path(path, base_paths):
606     """Strips the base path from a given path.
607
608
609     Determines the base path for the given string and then strips it
610     using strip_path().
611     Iterates over all base_paths from long to short, to prevent that
612     a too short base_path is removed.
613     """
614     normalized_path = os.path.realpath(path)
615     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
616     for bp in sorted_base_paths:
617         if normalized_path.startswith(os.path.realpath(bp)):
618             return strip_path(bp, normalized_path)
619     return path
620
621
622 def graft_path(graft_points, path):
623     normalized_path = os.path.realpath(path)
624     for graft_point in graft_points:
625         old_prefix, new_prefix = graft_point
626         if normalized_path.startswith(old_prefix):
627             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
628     return normalized_path
629
630
631 # hashlib is only available in python 2.5 or higher, but the 'sha' module
632 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
633 # python 2.4 and above without any stupid warnings, so let's try using hashlib
634 # first, and downgrade if it fails.
635 try:
636     import hashlib
637 except ImportError:
638     import sha
639     Sha1 = sha.sha
640 else:
641     Sha1 = hashlib.sha1
642
643
644 def version_date():
645     """Format bup's version date string for output."""
646     return _version.DATE.split(' ')[0]
647
648
649 def version_commit():
650     """Get the commit hash of bup's current version."""
651     return _version.COMMIT
652
653
654 def version_tag():
655     """Format bup's version tag (the official version number).
656
657     When generated from a commit other than one pointed to with a tag, the
658     returned string will be "unknown-" followed by the first seven positions of
659     the commit hash.
660     """
661     names = _version.NAMES.strip()
662     assert(names[0] == '(')
663     assert(names[-1] == ')')
664     names = names[1:-1]
665     l = [n.strip() for n in names.split(',')]
666     for n in l:
667         if n.startswith('tag: bup-'):
668             return n[9:]
669     return 'unknown-%s' % _version.COMMIT[:7]