]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Merge branch 'bl/bloomcheck' into ap/cleanups
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time
5 from bup import _version, _helpers
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49
50 _last_prog = 0
51 def log(s):
52     """Print a log message to stderr."""
53     global _last_prog
54     sys.stdout.flush()
55     _hard_write(sys.stderr.fileno(), s)
56     _last_prog = 0
57
58
59 def debug1(s):
60     if buglvl >= 1:
61         log(s)
62
63
64 def debug2(s):
65     if buglvl >= 2:
66         log(s)
67
68
69 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
70 def progress(s):
71     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
72     if istty:
73         log(s)
74
75
76 def qprogress(s):
77     """Calls progress() only if we haven't printed progress in a while.
78     
79     This avoids overloading the stderr buffer with excess junk."""
80     global _last_prog
81     now = time.time()
82     if now - _last_prog > 0.1:
83         progress(s)
84         _last_prog = now
85
86
87 def mkdirp(d, mode=None):
88     """Recursively create directories on path 'd'.
89
90     Unlike os.makedirs(), it doesn't raise an exception if the last element of
91     the path already exists.
92     """
93     try:
94         if mode:
95             os.makedirs(d, mode)
96         else:
97             os.makedirs(d)
98     except OSError, e:
99         if e.errno == errno.EEXIST:
100             pass
101         else:
102             raise
103
104
105 def next(it):
106     """Get the next item from an iterator, None if we reached the end."""
107     try:
108         return it.next()
109     except StopIteration:
110         return None
111
112
113 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
114     if key:
115         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
116     else:
117         samekey = operator.eq
118     count = 0
119     total = sum(len(it) for it in iters)
120     iters = (iter(it) for it in iters)
121     heap = ((next(it),it) for it in iters)
122     heap = [(e,it) for e,it in heap if e]
123
124     heapq.heapify(heap)
125     pe = None
126     while heap:
127         if not count % pfreq:
128             pfunc(count, total)
129         e, it = heap[0]
130         if not samekey(e, pe):
131             pe = e
132             yield e
133         count += 1
134         try:
135             e = it.next() # Don't use next() function, it's too expensive
136         except StopIteration:
137             heapq.heappop(heap) # remove current
138         else:
139             heapq.heapreplace(heap, (e, it)) # shift current to new location
140     pfinal(count, total)
141
142
143 def unlink(f):
144     """Delete a file at path 'f' if it currently exists.
145
146     Unlike os.unlink(), does not throw an exception if the file didn't already
147     exist.
148     """
149     try:
150         os.unlink(f)
151     except OSError, e:
152         if e.errno == errno.ENOENT:
153             pass  # it doesn't exist, that's what you asked for
154
155
156 def readpipe(argv):
157     """Run a subprocess and return its output."""
158     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
159     r = p.stdout.read()
160     p.wait()
161     return r
162
163
164 def realpath(p):
165     """Get the absolute path of a file.
166
167     Behaves like os.path.realpath, but doesn't follow a symlink for the last
168     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
169     will follow symlinks in p's directory)
170     """
171     try:
172         st = os.lstat(p)
173     except OSError:
174         st = None
175     if st and stat.S_ISLNK(st.st_mode):
176         (dir, name) = os.path.split(p)
177         dir = os.path.realpath(dir)
178         out = os.path.join(dir, name)
179     else:
180         out = os.path.realpath(p)
181     #log('realpathing:%r,%r\n' % (p, out))
182     return out
183
184
185 _username = None
186 def username():
187     """Get the user's login name."""
188     global _username
189     if not _username:
190         uid = os.getuid()
191         try:
192             _username = pwd.getpwuid(uid)[0]
193         except KeyError:
194             _username = 'user%d' % uid
195     return _username
196
197
198 _userfullname = None
199 def userfullname():
200     """Get the user's full name."""
201     global _userfullname
202     if not _userfullname:
203         uid = os.getuid()
204         try:
205             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
206         except KeyError:
207             _userfullname = 'user%d' % uid
208     return _userfullname
209
210
211 _hostname = None
212 def hostname():
213     """Get the FQDN of this machine."""
214     global _hostname
215     if not _hostname:
216         _hostname = socket.getfqdn()
217     return _hostname
218
219
220 _resource_path = None
221 def resource_path(subdir=''):
222     global _resource_path
223     if not _resource_path:
224         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
225     return os.path.join(_resource_path, subdir)
226
227
228 class NotOk(Exception):
229     pass
230
231
232 class BaseConn:
233     def __init__(self, outp):
234         self.outp = outp
235
236     def close(self):
237         while self._read(65536): pass
238
239     def read(self, size):
240         """Read 'size' bytes from input stream."""
241         self.outp.flush()
242         return self._read(size)
243
244     def readline(self):
245         """Read from input stream until a newline is found."""
246         self.outp.flush()
247         return self._readline()
248
249     def write(self, data):
250         """Write 'data' to output stream."""
251         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
252         self.outp.write(data)
253
254     def has_input(self):
255         """Return true if input stream is readable."""
256         raise NotImplemented("Subclasses must implement has_input")
257
258     def ok(self):
259         """Indicate end of output from last sent command."""
260         self.write('\nok\n')
261
262     def error(self, s):
263         """Indicate server error to the client."""
264         s = re.sub(r'\s+', ' ', str(s))
265         self.write('\nerror %s\n' % s)
266
267     def _check_ok(self, onempty):
268         self.outp.flush()
269         rl = ''
270         for rl in linereader(self):
271             #log('%d got line: %r\n' % (os.getpid(), rl))
272             if not rl:  # empty line
273                 continue
274             elif rl == 'ok':
275                 return None
276             elif rl.startswith('error '):
277                 #log('client: error: %s\n' % rl[6:])
278                 return NotOk(rl[6:])
279             else:
280                 onempty(rl)
281         raise Exception('server exited unexpectedly; see errors above')
282
283     def drain_and_check_ok(self):
284         """Remove all data for the current command from input stream."""
285         def onempty(rl):
286             pass
287         return self._check_ok(onempty)
288
289     def check_ok(self):
290         """Verify that server action completed successfully."""
291         def onempty(rl):
292             raise Exception('expected "ok", got %r' % rl)
293         return self._check_ok(onempty)
294
295
296 class Conn(BaseConn):
297     def __init__(self, inp, outp):
298         BaseConn.__init__(self, outp)
299         self.inp = inp
300
301     def _read(self, size):
302         return self.inp.read(size)
303
304     def _readline(self):
305         return self.inp.readline()
306
307     def has_input(self):
308         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
309         if rl:
310             assert(rl[0] == self.inp.fileno())
311             return True
312         else:
313             return None
314
315
316 def checked_reader(fd, n):
317     while n > 0:
318         rl, _, _ = select.select([fd], [], [])
319         assert(rl[0] == fd)
320         buf = os.read(fd, n)
321         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
322         yield buf
323         n -= len(buf)
324
325
326 MAX_PACKET = 128 * 1024
327 def mux(p, outfd, outr, errr):
328     try:
329         fds = [outr, errr]
330         while p.poll() is None:
331             rl, _, _ = select.select(fds, [], [])
332             for fd in rl:
333                 if fd == outr:
334                     buf = os.read(outr, MAX_PACKET)
335                     if not buf: break
336                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
337                 elif fd == errr:
338                     buf = os.read(errr, 1024)
339                     if not buf: break
340                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
341     finally:
342         os.write(outfd, struct.pack('!IB', 0, 3))
343
344
345 class DemuxConn(BaseConn):
346     """A helper class for bup's client-server protocol."""
347     def __init__(self, infd, outp):
348         BaseConn.__init__(self, outp)
349         # Anything that comes through before the sync string was not
350         # multiplexed and can be assumed to be debug/log before mux init.
351         tail = ''
352         while tail != 'BUPMUX':
353             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
354             if not b:
355                 raise IOError('demux: unexpected EOF during initialization')
356             tail += b
357             sys.stderr.write(tail[:-6])  # pre-mux log messages
358             tail = tail[-6:]
359         self.infd = infd
360         self.reader = None
361         self.buf = None
362         self.closed = False
363
364     def write(self, data):
365         self._load_buf(0)
366         BaseConn.write(self, data)
367
368     def _next_packet(self, timeout):
369         if self.closed: return False
370         rl, wl, xl = select.select([self.infd], [], [], timeout)
371         if not rl: return False
372         assert(rl[0] == self.infd)
373         ns = ''.join(checked_reader(self.infd, 5))
374         n, fdw = struct.unpack('!IB', ns)
375         assert(n <= MAX_PACKET)
376         if fdw == 1:
377             self.reader = checked_reader(self.infd, n)
378         elif fdw == 2:
379             for buf in checked_reader(self.infd, n):
380                 sys.stderr.write(buf)
381         elif fdw == 3:
382             self.closed = True
383             debug2("DemuxConn: marked closed\n")
384         return True
385
386     def _load_buf(self, timeout):
387         if self.buf is not None:
388             return True
389         while not self.closed:
390             while not self.reader:
391                 if not self._next_packet(timeout):
392                     return False
393             try:
394                 self.buf = self.reader.next()
395                 return True
396             except StopIteration:
397                 self.reader = None
398         return False
399
400     def _read_parts(self, ix_fn):
401         while self._load_buf(None):
402             assert(self.buf is not None)
403             i = ix_fn(self.buf)
404             if i is None or i == len(self.buf):
405                 yv = self.buf
406                 self.buf = None
407             else:
408                 yv = self.buf[:i]
409                 self.buf = self.buf[i:]
410             yield yv
411             if i is not None:
412                 break
413
414     def _readline(self):
415         def find_eol(buf):
416             try:
417                 return buf.index('\n')+1
418             except ValueError:
419                 return None
420         return ''.join(self._read_parts(find_eol))
421
422     def _read(self, size):
423         csize = [size]
424         def until_size(buf): # Closes on csize
425             if len(buf) < csize[0]:
426                 csize[0] -= len(buf)
427                 return None
428             else:
429                 return csize[0]
430         return ''.join(self._read_parts(until_size))
431
432     def has_input(self):
433         return self._load_buf(0)
434
435
436 def linereader(f):
437     """Generate a list of input lines from 'f' without terminating newlines."""
438     while 1:
439         line = f.readline()
440         if not line:
441             break
442         yield line[:-1]
443
444
445 def chunkyreader(f, count = None):
446     """Generate a list of chunks of data read from 'f'.
447
448     If count is None, read until EOF is reached.
449
450     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
451     reached while reading, raise IOError.
452     """
453     if count != None:
454         while count > 0:
455             b = f.read(min(count, 65536))
456             if not b:
457                 raise IOError('EOF with %d bytes remaining' % count)
458             yield b
459             count -= len(b)
460     else:
461         while 1:
462             b = f.read(65536)
463             if not b: break
464             yield b
465
466
467 def slashappend(s):
468     """Append "/" to 's' if it doesn't aleady end in "/"."""
469     if s and not s.endswith('/'):
470         return s + '/'
471     else:
472         return s
473
474
475 def _mmap_do(f, sz, flags, prot, close):
476     if not sz:
477         st = os.fstat(f.fileno())
478         sz = st.st_size
479     if not sz:
480         # trying to open a zero-length map gives an error, but an empty
481         # string has all the same behaviour of a zero-length map, ie. it has
482         # no elements :)
483         return ''
484     map = mmap.mmap(f.fileno(), sz, flags, prot)
485     if close:
486         f.close()  # map will persist beyond file close
487     return map
488
489
490 def mmap_read(f, sz = 0, close=True):
491     """Create a read-only memory mapped region on file 'f'.
492     If sz is 0, the region will cover the entire file.
493     """
494     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
495
496
497 def mmap_readwrite(f, sz = 0, close=True):
498     """Create a read-write memory mapped region on file 'f'.
499     If sz is 0, the region will cover the entire file.
500     """
501     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
502                     close)
503
504
505 def mmap_readwrite_private(f, sz = 0, close=True):
506     """Create a read-write memory mapped region on file 'f'.
507     If sz is 0, the region will cover the entire file.
508     The map is private, which means the changes are never flushed back to the
509     file.
510     """
511     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
512                     close)
513
514
515 def parse_num(s):
516     """Parse data size information into a float number.
517
518     Here are some examples of conversions:
519         199.2k means 203981 bytes
520         1GB means 1073741824 bytes
521         2.1 tb means 2199023255552 bytes
522     """
523     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
524     if not g:
525         raise ValueError("can't parse %r as a number" % s)
526     (val, unit) = g.groups()
527     num = float(val)
528     unit = unit.lower()
529     if unit in ['t', 'tb']:
530         mult = 1024*1024*1024*1024
531     elif unit in ['g', 'gb']:
532         mult = 1024*1024*1024
533     elif unit in ['m', 'mb']:
534         mult = 1024*1024
535     elif unit in ['k', 'kb']:
536         mult = 1024
537     elif unit in ['', 'b']:
538         mult = 1
539     else:
540         raise ValueError("invalid unit %r in number %r" % (unit, s))
541     return int(num*mult)
542
543
544 def count(l):
545     """Count the number of elements in an iterator. (consumes the iterator)"""
546     return reduce(lambda x,y: x+1, l)
547
548
549 saved_errors = []
550 def add_error(e):
551     """Append an error message to the list of saved errors.
552
553     Once processing is able to stop and output the errors, the saved errors are
554     accessible in the module variable helpers.saved_errors.
555     """
556     saved_errors.append(e)
557     log('%-70s\n' % e)
558
559
560 def handle_ctrl_c():
561     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
562
563     The new exception handler will make sure that bup will exit without an ugly
564     stacktrace when Ctrl-C is hit.
565     """
566     oldhook = sys.excepthook
567     def newhook(exctype, value, traceback):
568         if exctype == KeyboardInterrupt:
569             log('Interrupted.\n')
570         else:
571             return oldhook(exctype, value, traceback)
572     sys.excepthook = newhook
573
574
575 def columnate(l, prefix):
576     """Format elements of 'l' in columns with 'prefix' leading each line.
577
578     The number of columns is determined automatically based on the string
579     lengths.
580     """
581     if not l:
582         return ""
583     l = l[:]
584     clen = max(len(s) for s in l)
585     ncols = (tty_width() - len(prefix)) / (clen + 2)
586     if ncols <= 1:
587         ncols = 1
588         clen = 0
589     cols = []
590     while len(l) % ncols:
591         l.append('')
592     rows = len(l)/ncols
593     for s in range(0, len(l), rows):
594         cols.append(l[s:s+rows])
595     out = ''
596     for row in zip(*cols):
597         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
598     return out
599
600
601 def parse_date_or_fatal(str, fatal):
602     """Parses the given date or calls Option.fatal().
603     For now we expect a string that contains a float."""
604     try:
605         date = atof(str)
606     except ValueError, e:
607         raise fatal('invalid date format (should be a float): %r' % e)
608     else:
609         return date
610
611
612 def strip_path(prefix, path):
613     """Strips a given prefix from a path.
614
615     First both paths are normalized.
616
617     Raises an Exception if no prefix is given.
618     """
619     if prefix == None:
620         raise Exception('no path given')
621
622     normalized_prefix = os.path.realpath(prefix)
623     debug2("normalized_prefix: %s\n" % normalized_prefix)
624     normalized_path = os.path.realpath(path)
625     debug2("normalized_path: %s\n" % normalized_path)
626     if normalized_path.startswith(normalized_prefix):
627         return normalized_path[len(normalized_prefix):]
628     else:
629         return path
630
631
632 def strip_base_path(path, base_paths):
633     """Strips the base path from a given path.
634
635
636     Determines the base path for the given string and then strips it
637     using strip_path().
638     Iterates over all base_paths from long to short, to prevent that
639     a too short base_path is removed.
640     """
641     normalized_path = os.path.realpath(path)
642     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
643     for bp in sorted_base_paths:
644         if normalized_path.startswith(os.path.realpath(bp)):
645             return strip_path(bp, normalized_path)
646     return path
647
648
649 def graft_path(graft_points, path):
650     normalized_path = os.path.realpath(path)
651     for graft_point in graft_points:
652         old_prefix, new_prefix = graft_point
653         if normalized_path.startswith(old_prefix):
654             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
655     return normalized_path
656
657
658 # hashlib is only available in python 2.5 or higher, but the 'sha' module
659 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
660 # python 2.4 and above without any stupid warnings, so let's try using hashlib
661 # first, and downgrade if it fails.
662 try:
663     import hashlib
664 except ImportError:
665     import sha
666     Sha1 = sha.sha
667 else:
668     Sha1 = hashlib.sha1
669
670
671 def version_date():
672     """Format bup's version date string for output."""
673     return _version.DATE.split(' ')[0]
674
675
676 def version_commit():
677     """Get the commit hash of bup's current version."""
678     return _version.COMMIT
679
680
681 def version_tag():
682     """Format bup's version tag (the official version number).
683
684     When generated from a commit other than one pointed to with a tag, the
685     returned string will be "unknown-" followed by the first seven positions of
686     the commit hash.
687     """
688     names = _version.NAMES.strip()
689     assert(names[0] == '(')
690     assert(names[-1] == ')')
691     names = names[1:-1]
692     l = [n.strip() for n in names.split(',')]
693     for n in l:
694         if n.startswith('tag: bup-'):
695             return n[9:]
696     return 'unknown-%s' % _version.COMMIT[:7]