]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
bloom: avoid kernel disk flushes when we dirty a lot of pages.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49 def log(s):
50     """Print a log message to stderr."""
51     sys.stdout.flush()
52     _hard_write(sys.stderr.fileno(), s)
53
54
55 def debug1(s):
56     if buglvl >= 1:
57         log(s)
58
59
60 def debug2(s):
61     if buglvl >= 2:
62         log(s)
63
64
65 def mkdirp(d, mode=None):
66     """Recursively create directories on path 'd'.
67
68     Unlike os.makedirs(), it doesn't raise an exception if the last element of
69     the path already exists.
70     """
71     try:
72         if mode:
73             os.makedirs(d, mode)
74         else:
75             os.makedirs(d)
76     except OSError, e:
77         if e.errno == errno.EEXIST:
78             pass
79         else:
80             raise
81
82
83 def next(it):
84     """Get the next item from an iterator, None if we reached the end."""
85     try:
86         return it.next()
87     except StopIteration:
88         return None
89
90
91 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
92     if key:
93         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
94     else:
95         samekey = operator.eq
96     count = 0
97     total = sum(len(it) for it in iters)
98     iters = (iter(it) for it in iters)
99     heap = ((next(it),it) for it in iters)
100     heap = [(e,it) for e,it in heap if e]
101
102     heapq.heapify(heap)
103     pe = None
104     while heap:
105         if not count % pfreq:
106             pfunc(count, total)
107         e, it = heap[0]
108         if not samekey(e, pe):
109             pe = e
110             yield e
111         count += 1
112         try:
113             e = it.next() # Don't use next() function, it's too expensive
114         except StopIteration:
115             heapq.heappop(heap) # remove current
116         else:
117             heapq.heapreplace(heap, (e, it)) # shift current to new location
118     pfinal(count, total)
119
120
121 def unlink(f):
122     """Delete a file at path 'f' if it currently exists.
123
124     Unlike os.unlink(), does not throw an exception if the file didn't already
125     exist.
126     """
127     try:
128         os.unlink(f)
129     except OSError, e:
130         if e.errno == errno.ENOENT:
131             pass  # it doesn't exist, that's what you asked for
132
133
134 def readpipe(argv):
135     """Run a subprocess and return its output."""
136     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
137     r = p.stdout.read()
138     p.wait()
139     return r
140
141
142 def realpath(p):
143     """Get the absolute path of a file.
144
145     Behaves like os.path.realpath, but doesn't follow a symlink for the last
146     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
147     will follow symlinks in p's directory)
148     """
149     try:
150         st = os.lstat(p)
151     except OSError:
152         st = None
153     if st and stat.S_ISLNK(st.st_mode):
154         (dir, name) = os.path.split(p)
155         dir = os.path.realpath(dir)
156         out = os.path.join(dir, name)
157     else:
158         out = os.path.realpath(p)
159     #log('realpathing:%r,%r\n' % (p, out))
160     return out
161
162
163 _username = None
164 def username():
165     """Get the user's login name."""
166     global _username
167     if not _username:
168         uid = os.getuid()
169         try:
170             _username = pwd.getpwuid(uid)[0]
171         except KeyError:
172             _username = 'user%d' % uid
173     return _username
174
175
176 _userfullname = None
177 def userfullname():
178     """Get the user's full name."""
179     global _userfullname
180     if not _userfullname:
181         uid = os.getuid()
182         try:
183             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
184         except KeyError:
185             _userfullname = 'user%d' % uid
186     return _userfullname
187
188
189 _hostname = None
190 def hostname():
191     """Get the FQDN of this machine."""
192     global _hostname
193     if not _hostname:
194         _hostname = socket.getfqdn()
195     return _hostname
196
197
198 _resource_path = None
199 def resource_path(subdir=''):
200     global _resource_path
201     if not _resource_path:
202         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
203     return os.path.join(_resource_path, subdir)
204
205
206 class NotOk(Exception):
207     pass
208
209
210 class BaseConn:
211     def __init__(self, outp):
212         self.outp = outp
213
214     def close(self):
215         while self._read(65536): pass
216
217     def read(self, size):
218         """Read 'size' bytes from input stream."""
219         self.outp.flush()
220         return self._read(size)
221
222     def readline(self):
223         """Read from input stream until a newline is found."""
224         self.outp.flush()
225         return self._readline()
226
227     def write(self, data):
228         """Write 'data' to output stream."""
229         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
230         self.outp.write(data)
231
232     def has_input(self):
233         """Return true if input stream is readable."""
234         raise NotImplemented("Subclasses must implement has_input")
235
236     def ok(self):
237         """Indicate end of output from last sent command."""
238         self.write('\nok\n')
239
240     def error(self, s):
241         """Indicate server error to the client."""
242         s = re.sub(r'\s+', ' ', str(s))
243         self.write('\nerror %s\n' % s)
244
245     def _check_ok(self, onempty):
246         self.outp.flush()
247         rl = ''
248         for rl in linereader(self):
249             #log('%d got line: %r\n' % (os.getpid(), rl))
250             if not rl:  # empty line
251                 continue
252             elif rl == 'ok':
253                 return None
254             elif rl.startswith('error '):
255                 #log('client: error: %s\n' % rl[6:])
256                 return NotOk(rl[6:])
257             else:
258                 onempty(rl)
259         raise Exception('server exited unexpectedly; see errors above')
260
261     def drain_and_check_ok(self):
262         """Remove all data for the current command from input stream."""
263         def onempty(rl):
264             pass
265         return self._check_ok(onempty)
266
267     def check_ok(self):
268         """Verify that server action completed successfully."""
269         def onempty(rl):
270             raise Exception('expected "ok", got %r' % rl)
271         return self._check_ok(onempty)
272
273
274 class Conn(BaseConn):
275     def __init__(self, inp, outp):
276         BaseConn.__init__(self, outp)
277         self.inp = inp
278
279     def _read(self, size):
280         return self.inp.read(size)
281
282     def _readline(self):
283         return self.inp.readline()
284
285     def has_input(self):
286         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
287         if rl:
288             assert(rl[0] == self.inp.fileno())
289             return True
290         else:
291             return None
292
293
294 def checked_reader(fd, n):
295     while n > 0:
296         rl, _, _ = select.select([fd], [], [])
297         assert(rl[0] == fd)
298         buf = os.read(fd, n)
299         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
300         yield buf
301         n -= len(buf)
302
303
304 MAX_PACKET = 128 * 1024
305 def mux(p, outfd, outr, errr):
306     try:
307         fds = [outr, errr]
308         while p.poll() is None:
309             rl, _, _ = select.select(fds, [], [])
310             for fd in rl:
311                 if fd == outr:
312                     buf = os.read(outr, MAX_PACKET)
313                     if not buf: break
314                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
315                 elif fd == errr:
316                     buf = os.read(errr, 1024)
317                     if not buf: break
318                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
319     finally:
320         os.write(outfd, struct.pack('!IB', 0, 3))
321
322
323 class DemuxConn(BaseConn):
324     """A helper class for bup's client-server protocol."""
325     def __init__(self, infd, outp):
326         BaseConn.__init__(self, outp)
327         # Anything that comes through before the sync string was not
328         # multiplexed and can be assumed to be debug/log before mux init.
329         tail = ''
330         while tail != 'BUPMUX':
331             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
332             if not b:
333                 raise IOError('demux: unexpected EOF during initialization')
334             tail += b
335             sys.stderr.write(tail[:-6])  # pre-mux log messages
336             tail = tail[-6:]
337         self.infd = infd
338         self.reader = None
339         self.buf = None
340         self.closed = False
341
342     def write(self, data):
343         self._load_buf(0)
344         BaseConn.write(self, data)
345
346     def _next_packet(self, timeout):
347         if self.closed: return False
348         rl, wl, xl = select.select([self.infd], [], [], timeout)
349         if not rl: return False
350         assert(rl[0] == self.infd)
351         ns = ''.join(checked_reader(self.infd, 5))
352         n, fdw = struct.unpack('!IB', ns)
353         assert(n <= MAX_PACKET)
354         if fdw == 1:
355             self.reader = checked_reader(self.infd, n)
356         elif fdw == 2:
357             for buf in checked_reader(self.infd, n):
358                 sys.stderr.write(buf)
359         elif fdw == 3:
360             self.closed = True
361             debug2("DemuxConn: marked closed\n")
362         return True
363
364     def _load_buf(self, timeout):
365         if self.buf is not None:
366             return True
367         while not self.closed:
368             while not self.reader:
369                 if not self._next_packet(timeout):
370                     return False
371             try:
372                 self.buf = self.reader.next()
373                 return True
374             except StopIteration:
375                 self.reader = None
376         return False
377
378     def _read_parts(self, ix_fn):
379         while self._load_buf(None):
380             assert(self.buf is not None)
381             i = ix_fn(self.buf)
382             if i is None or i == len(self.buf):
383                 yv = self.buf
384                 self.buf = None
385             else:
386                 yv = self.buf[:i]
387                 self.buf = self.buf[i:]
388             yield yv
389             if i is not None:
390                 break
391
392     def _readline(self):
393         def find_eol(buf):
394             try:
395                 return buf.index('\n')+1
396             except ValueError:
397                 return None
398         return ''.join(self._read_parts(find_eol))
399
400     def _read(self, size):
401         csize = [size]
402         def until_size(buf): # Closes on csize
403             if len(buf) < csize[0]:
404                 csize[0] -= len(buf)
405                 return None
406             else:
407                 return csize[0]
408         return ''.join(self._read_parts(until_size))
409
410     def has_input(self):
411         return self._load_buf(0)
412
413
414 def linereader(f):
415     """Generate a list of input lines from 'f' without terminating newlines."""
416     while 1:
417         line = f.readline()
418         if not line:
419             break
420         yield line[:-1]
421
422
423 def chunkyreader(f, count = None):
424     """Generate a list of chunks of data read from 'f'.
425
426     If count is None, read until EOF is reached.
427
428     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
429     reached while reading, raise IOError.
430     """
431     if count != None:
432         while count > 0:
433             b = f.read(min(count, 65536))
434             if not b:
435                 raise IOError('EOF with %d bytes remaining' % count)
436             yield b
437             count -= len(b)
438     else:
439         while 1:
440             b = f.read(65536)
441             if not b: break
442             yield b
443
444
445 def slashappend(s):
446     """Append "/" to 's' if it doesn't aleady end in "/"."""
447     if s and not s.endswith('/'):
448         return s + '/'
449     else:
450         return s
451
452
453 def _mmap_do(f, sz, flags, prot, close):
454     if not sz:
455         st = os.fstat(f.fileno())
456         sz = st.st_size
457     if not sz:
458         # trying to open a zero-length map gives an error, but an empty
459         # string has all the same behaviour of a zero-length map, ie. it has
460         # no elements :)
461         return ''
462     map = mmap.mmap(f.fileno(), sz, flags, prot)
463     if close:
464         f.close()  # map will persist beyond file close
465     return map
466
467
468 def mmap_read(f, sz = 0, close=True):
469     """Create a read-only memory mapped region on file 'f'.
470     If sz is 0, the region will cover the entire file.
471     """
472     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
473
474
475 def mmap_readwrite(f, sz = 0, close=True):
476     """Create a read-write memory mapped region on file 'f'.
477     If sz is 0, the region will cover the entire file.
478     """
479     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
480                     close)
481
482
483 def mmap_readwrite_private(f, sz = 0, close=True):
484     """Create a read-write memory mapped region on file 'f'.
485     If sz is 0, the region will cover the entire file.
486     The map is private, which means the changes are never flushed back to the
487     file.
488     """
489     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
490                     close)
491
492
493 def parse_num(s):
494     """Parse data size information into a float number.
495
496     Here are some examples of conversions:
497         199.2k means 203981 bytes
498         1GB means 1073741824 bytes
499         2.1 tb means 2199023255552 bytes
500     """
501     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
502     if not g:
503         raise ValueError("can't parse %r as a number" % s)
504     (val, unit) = g.groups()
505     num = float(val)
506     unit = unit.lower()
507     if unit in ['t', 'tb']:
508         mult = 1024*1024*1024*1024
509     elif unit in ['g', 'gb']:
510         mult = 1024*1024*1024
511     elif unit in ['m', 'mb']:
512         mult = 1024*1024
513     elif unit in ['k', 'kb']:
514         mult = 1024
515     elif unit in ['', 'b']:
516         mult = 1
517     else:
518         raise ValueError("invalid unit %r in number %r" % (unit, s))
519     return int(num*mult)
520
521
522 def count(l):
523     """Count the number of elements in an iterator. (consumes the iterator)"""
524     return reduce(lambda x,y: x+1, l)
525
526
527 saved_errors = []
528 def add_error(e):
529     """Append an error message to the list of saved errors.
530
531     Once processing is able to stop and output the errors, the saved errors are
532     accessible in the module variable helpers.saved_errors.
533     """
534     saved_errors.append(e)
535     log('%-70s\n' % e)
536
537
538 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
539 def progress(s):
540     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
541     if istty:
542         log(s)
543
544
545 def handle_ctrl_c():
546     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
547
548     The new exception handler will make sure that bup will exit without an ugly
549     stacktrace when Ctrl-C is hit.
550     """
551     oldhook = sys.excepthook
552     def newhook(exctype, value, traceback):
553         if exctype == KeyboardInterrupt:
554             log('Interrupted.\n')
555         else:
556             return oldhook(exctype, value, traceback)
557     sys.excepthook = newhook
558
559
560 def columnate(l, prefix):
561     """Format elements of 'l' in columns with 'prefix' leading each line.
562
563     The number of columns is determined automatically based on the string
564     lengths.
565     """
566     if not l:
567         return ""
568     l = l[:]
569     clen = max(len(s) for s in l)
570     ncols = (tty_width() - len(prefix)) / (clen + 2)
571     if ncols <= 1:
572         ncols = 1
573         clen = 0
574     cols = []
575     while len(l) % ncols:
576         l.append('')
577     rows = len(l)/ncols
578     for s in range(0, len(l), rows):
579         cols.append(l[s:s+rows])
580     out = ''
581     for row in zip(*cols):
582         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
583     return out
584
585
586 def parse_date_or_fatal(str, fatal):
587     """Parses the given date or calls Option.fatal().
588     For now we expect a string that contains a float."""
589     try:
590         date = atof(str)
591     except ValueError, e:
592         raise fatal('invalid date format (should be a float): %r' % e)
593     else:
594         return date
595
596
597 def strip_path(prefix, path):
598     """Strips a given prefix from a path.
599
600     First both paths are normalized.
601
602     Raises an Exception if no prefix is given.
603     """
604     if prefix == None:
605         raise Exception('no path given')
606
607     normalized_prefix = os.path.realpath(prefix)
608     debug2("normalized_prefix: %s\n" % normalized_prefix)
609     normalized_path = os.path.realpath(path)
610     debug2("normalized_path: %s\n" % normalized_path)
611     if normalized_path.startswith(normalized_prefix):
612         return normalized_path[len(normalized_prefix):]
613     else:
614         return path
615
616
617 def strip_base_path(path, base_paths):
618     """Strips the base path from a given path.
619
620
621     Determines the base path for the given string and then strips it
622     using strip_path().
623     Iterates over all base_paths from long to short, to prevent that
624     a too short base_path is removed.
625     """
626     normalized_path = os.path.realpath(path)
627     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
628     for bp in sorted_base_paths:
629         if normalized_path.startswith(os.path.realpath(bp)):
630             return strip_path(bp, normalized_path)
631     return path
632
633
634 def graft_path(graft_points, path):
635     normalized_path = os.path.realpath(path)
636     for graft_point in graft_points:
637         old_prefix, new_prefix = graft_point
638         if normalized_path.startswith(old_prefix):
639             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
640     return normalized_path
641
642
643 # hashlib is only available in python 2.5 or higher, but the 'sha' module
644 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
645 # python 2.4 and above without any stupid warnings, so let's try using hashlib
646 # first, and downgrade if it fails.
647 try:
648     import hashlib
649 except ImportError:
650     import sha
651     Sha1 = sha.sha
652 else:
653     Sha1 = hashlib.sha1
654
655
656 def version_date():
657     """Format bup's version date string for output."""
658     return _version.DATE.split(' ')[0]
659
660
661 def version_commit():
662     """Get the commit hash of bup's current version."""
663     return _version.COMMIT
664
665
666 def version_tag():
667     """Format bup's version tag (the official version number).
668
669     When generated from a commit other than one pointed to with a tag, the
670     returned string will be "unknown-" followed by the first seven positions of
671     the commit hash.
672     """
673     names = _version.NAMES.strip()
674     assert(names[0] == '(')
675     assert(names[-1] == ')')
676     names = names[1:-1]
677     l = [n.strip() for n in names.split(',')]
678     for n in l:
679         if n.startswith('tag: bup-'):
680             return n[9:]
681     return 'unknown-%s' % _version.COMMIT[:7]