]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Merge remote branch 'origin/master' into meta
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator, time
5 from bup import _version, _helpers
6 import bup._helpers as _helpers
7
8 # This function should really be in helpers, not in bup.options.  But we
9 # want options.py to be standalone so people can include it in other projects.
10 from bup.options import _tty_width
11 tty_width = _tty_width
12
13
14 def atoi(s):
15     """Convert the string 's' to an integer. Return 0 if s is not a number."""
16     try:
17         return int(s or '0')
18     except ValueError:
19         return 0
20
21
22 def atof(s):
23     """Convert the string 's' to a float. Return 0 if s is not a number."""
24     try:
25         return float(s or '0')
26     except ValueError:
27         return 0
28
29
30 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
31
32
33 # Write (blockingly) to sockets that may or may not be in blocking mode.
34 # We need this because our stderr is sometimes eaten by subprocesses
35 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
36 # leading to race conditions.  Ick.  We'll do it the hard way.
37 def _hard_write(fd, buf):
38     while buf:
39         (r,w,x) = select.select([], [fd], [], None)
40         if not w:
41             raise IOError('select(fd) returned without being writable')
42         try:
43             sz = os.write(fd, buf)
44         except OSError, e:
45             if e.errno != errno.EAGAIN:
46                 raise
47         assert(sz >= 0)
48         buf = buf[sz:]
49
50
51 _last_prog = 0
52 def log(s):
53     """Print a log message to stderr."""
54     global _last_prog
55     sys.stdout.flush()
56     _hard_write(sys.stderr.fileno(), s)
57     _last_prog = 0
58
59
60 def debug1(s):
61     if buglvl >= 1:
62         log(s)
63
64
65 def debug2(s):
66     if buglvl >= 2:
67         log(s)
68
69
70 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
71 def progress(s):
72     """Calls log() if stderr is a TTY.  Does nothing otherwise."""
73     if istty:
74         log(s)
75
76
77 def qprogress(s):
78     """Calls progress() only if we haven't printed progress in a while.
79     
80     This avoids overloading the stderr buffer with excess junk."""
81     global _last_prog
82     now = time.time()
83     if now - _last_prog > 0.1:
84         progress(s)
85         _last_prog = now
86
87
88 def mkdirp(d, mode=None):
89     """Recursively create directories on path 'd'.
90
91     Unlike os.makedirs(), it doesn't raise an exception if the last element of
92     the path already exists.
93     """
94     try:
95         if mode:
96             os.makedirs(d, mode)
97         else:
98             os.makedirs(d)
99     except OSError, e:
100         if e.errno == errno.EEXIST:
101             pass
102         else:
103             raise
104
105
106 def next(it):
107     """Get the next item from an iterator, None if we reached the end."""
108     try:
109         return it.next()
110     except StopIteration:
111         return None
112
113
114 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
115     if key:
116         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
117     else:
118         samekey = operator.eq
119     count = 0
120     total = sum(len(it) for it in iters)
121     iters = (iter(it) for it in iters)
122     heap = ((next(it),it) for it in iters)
123     heap = [(e,it) for e,it in heap if e]
124
125     heapq.heapify(heap)
126     pe = None
127     while heap:
128         if not count % pfreq:
129             pfunc(count, total)
130         e, it = heap[0]
131         if not samekey(e, pe):
132             pe = e
133             yield e
134         count += 1
135         try:
136             e = it.next() # Don't use next() function, it's too expensive
137         except StopIteration:
138             heapq.heappop(heap) # remove current
139         else:
140             heapq.heapreplace(heap, (e, it)) # shift current to new location
141     pfinal(count, total)
142
143
144 def unlink(f):
145     """Delete a file at path 'f' if it currently exists.
146
147     Unlike os.unlink(), does not throw an exception if the file didn't already
148     exist.
149     """
150     try:
151         os.unlink(f)
152     except OSError, e:
153         if e.errno == errno.ENOENT:
154             pass  # it doesn't exist, that's what you asked for
155
156
157 def readpipe(argv):
158     """Run a subprocess and return its output."""
159     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
160     r = p.stdout.read()
161     p.wait()
162     return r
163
164
165 def realpath(p):
166     """Get the absolute path of a file.
167
168     Behaves like os.path.realpath, but doesn't follow a symlink for the last
169     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
170     will follow symlinks in p's directory)
171     """
172     try:
173         st = os.lstat(p)
174     except OSError:
175         st = None
176     if st and stat.S_ISLNK(st.st_mode):
177         (dir, name) = os.path.split(p)
178         dir = os.path.realpath(dir)
179         out = os.path.join(dir, name)
180     else:
181         out = os.path.realpath(p)
182     #log('realpathing:%r,%r\n' % (p, out))
183     return out
184
185
186 def detect_fakeroot():
187     "Return True if we appear to be running under fakeroot."
188     return os.getenv("FAKEROOTKEY") != None
189
190
191 _username = None
192 def username():
193     """Get the user's login name."""
194     global _username
195     if not _username:
196         uid = os.getuid()
197         try:
198             _username = pwd.getpwuid(uid)[0]
199         except KeyError:
200             _username = 'user%d' % uid
201     return _username
202
203
204 _userfullname = None
205 def userfullname():
206     """Get the user's full name."""
207     global _userfullname
208     if not _userfullname:
209         uid = os.getuid()
210         try:
211             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
212         except KeyError:
213             _userfullname = 'user%d' % uid
214     return _userfullname
215
216
217 _hostname = None
218 def hostname():
219     """Get the FQDN of this machine."""
220     global _hostname
221     if not _hostname:
222         _hostname = socket.getfqdn()
223     return _hostname
224
225
226 _resource_path = None
227 def resource_path(subdir=''):
228     global _resource_path
229     if not _resource_path:
230         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
231     return os.path.join(_resource_path, subdir)
232
233
234 class NotOk(Exception):
235     pass
236
237
238 class BaseConn:
239     def __init__(self, outp):
240         self.outp = outp
241
242     def close(self):
243         while self._read(65536): pass
244
245     def read(self, size):
246         """Read 'size' bytes from input stream."""
247         self.outp.flush()
248         return self._read(size)
249
250     def readline(self):
251         """Read from input stream until a newline is found."""
252         self.outp.flush()
253         return self._readline()
254
255     def write(self, data):
256         """Write 'data' to output stream."""
257         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
258         self.outp.write(data)
259
260     def has_input(self):
261         """Return true if input stream is readable."""
262         raise NotImplemented("Subclasses must implement has_input")
263
264     def ok(self):
265         """Indicate end of output from last sent command."""
266         self.write('\nok\n')
267
268     def error(self, s):
269         """Indicate server error to the client."""
270         s = re.sub(r'\s+', ' ', str(s))
271         self.write('\nerror %s\n' % s)
272
273     def _check_ok(self, onempty):
274         self.outp.flush()
275         rl = ''
276         for rl in linereader(self):
277             #log('%d got line: %r\n' % (os.getpid(), rl))
278             if not rl:  # empty line
279                 continue
280             elif rl == 'ok':
281                 return None
282             elif rl.startswith('error '):
283                 #log('client: error: %s\n' % rl[6:])
284                 return NotOk(rl[6:])
285             else:
286                 onempty(rl)
287         raise Exception('server exited unexpectedly; see errors above')
288
289     def drain_and_check_ok(self):
290         """Remove all data for the current command from input stream."""
291         def onempty(rl):
292             pass
293         return self._check_ok(onempty)
294
295     def check_ok(self):
296         """Verify that server action completed successfully."""
297         def onempty(rl):
298             raise Exception('expected "ok", got %r' % rl)
299         return self._check_ok(onempty)
300
301
302 class Conn(BaseConn):
303     def __init__(self, inp, outp):
304         BaseConn.__init__(self, outp)
305         self.inp = inp
306
307     def _read(self, size):
308         return self.inp.read(size)
309
310     def _readline(self):
311         return self.inp.readline()
312
313     def has_input(self):
314         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
315         if rl:
316             assert(rl[0] == self.inp.fileno())
317             return True
318         else:
319             return None
320
321
322 def checked_reader(fd, n):
323     while n > 0:
324         rl, _, _ = select.select([fd], [], [])
325         assert(rl[0] == fd)
326         buf = os.read(fd, n)
327         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
328         yield buf
329         n -= len(buf)
330
331
332 MAX_PACKET = 128 * 1024
333 def mux(p, outfd, outr, errr):
334     try:
335         fds = [outr, errr]
336         while p.poll() is None:
337             rl, _, _ = select.select(fds, [], [])
338             for fd in rl:
339                 if fd == outr:
340                     buf = os.read(outr, MAX_PACKET)
341                     if not buf: break
342                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
343                 elif fd == errr:
344                     buf = os.read(errr, 1024)
345                     if not buf: break
346                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
347     finally:
348         os.write(outfd, struct.pack('!IB', 0, 3))
349
350
351 class DemuxConn(BaseConn):
352     """A helper class for bup's client-server protocol."""
353     def __init__(self, infd, outp):
354         BaseConn.__init__(self, outp)
355         # Anything that comes through before the sync string was not
356         # multiplexed and can be assumed to be debug/log before mux init.
357         tail = ''
358         while tail != 'BUPMUX':
359             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
360             if not b:
361                 raise IOError('demux: unexpected EOF during initialization')
362             tail += b
363             sys.stderr.write(tail[:-6])  # pre-mux log messages
364             tail = tail[-6:]
365         self.infd = infd
366         self.reader = None
367         self.buf = None
368         self.closed = False
369
370     def write(self, data):
371         self._load_buf(0)
372         BaseConn.write(self, data)
373
374     def _next_packet(self, timeout):
375         if self.closed: return False
376         rl, wl, xl = select.select([self.infd], [], [], timeout)
377         if not rl: return False
378         assert(rl[0] == self.infd)
379         ns = ''.join(checked_reader(self.infd, 5))
380         n, fdw = struct.unpack('!IB', ns)
381         assert(n <= MAX_PACKET)
382         if fdw == 1:
383             self.reader = checked_reader(self.infd, n)
384         elif fdw == 2:
385             for buf in checked_reader(self.infd, n):
386                 sys.stderr.write(buf)
387         elif fdw == 3:
388             self.closed = True
389             debug2("DemuxConn: marked closed\n")
390         return True
391
392     def _load_buf(self, timeout):
393         if self.buf is not None:
394             return True
395         while not self.closed:
396             while not self.reader:
397                 if not self._next_packet(timeout):
398                     return False
399             try:
400                 self.buf = self.reader.next()
401                 return True
402             except StopIteration:
403                 self.reader = None
404         return False
405
406     def _read_parts(self, ix_fn):
407         while self._load_buf(None):
408             assert(self.buf is not None)
409             i = ix_fn(self.buf)
410             if i is None or i == len(self.buf):
411                 yv = self.buf
412                 self.buf = None
413             else:
414                 yv = self.buf[:i]
415                 self.buf = self.buf[i:]
416             yield yv
417             if i is not None:
418                 break
419
420     def _readline(self):
421         def find_eol(buf):
422             try:
423                 return buf.index('\n')+1
424             except ValueError:
425                 return None
426         return ''.join(self._read_parts(find_eol))
427
428     def _read(self, size):
429         csize = [size]
430         def until_size(buf): # Closes on csize
431             if len(buf) < csize[0]:
432                 csize[0] -= len(buf)
433                 return None
434             else:
435                 return csize[0]
436         return ''.join(self._read_parts(until_size))
437
438     def has_input(self):
439         return self._load_buf(0)
440
441
442 def linereader(f):
443     """Generate a list of input lines from 'f' without terminating newlines."""
444     while 1:
445         line = f.readline()
446         if not line:
447             break
448         yield line[:-1]
449
450
451 def chunkyreader(f, count = None):
452     """Generate a list of chunks of data read from 'f'.
453
454     If count is None, read until EOF is reached.
455
456     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
457     reached while reading, raise IOError.
458     """
459     if count != None:
460         while count > 0:
461             b = f.read(min(count, 65536))
462             if not b:
463                 raise IOError('EOF with %d bytes remaining' % count)
464             yield b
465             count -= len(b)
466     else:
467         while 1:
468             b = f.read(65536)
469             if not b: break
470             yield b
471
472
473 def slashappend(s):
474     """Append "/" to 's' if it doesn't aleady end in "/"."""
475     if s and not s.endswith('/'):
476         return s + '/'
477     else:
478         return s
479
480
481 def _mmap_do(f, sz, flags, prot, close):
482     if not sz:
483         st = os.fstat(f.fileno())
484         sz = st.st_size
485     if not sz:
486         # trying to open a zero-length map gives an error, but an empty
487         # string has all the same behaviour of a zero-length map, ie. it has
488         # no elements :)
489         return ''
490     map = mmap.mmap(f.fileno(), sz, flags, prot)
491     if close:
492         f.close()  # map will persist beyond file close
493     return map
494
495
496 def mmap_read(f, sz = 0, close=True):
497     """Create a read-only memory mapped region on file 'f'.
498     If sz is 0, the region will cover the entire file.
499     """
500     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
501
502
503 def mmap_readwrite(f, sz = 0, close=True):
504     """Create a read-write memory mapped region on file 'f'.
505     If sz is 0, the region will cover the entire file.
506     """
507     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
508                     close)
509
510
511 def mmap_readwrite_private(f, sz = 0, close=True):
512     """Create a read-write memory mapped region on file 'f'.
513     If sz is 0, the region will cover the entire file.
514     The map is private, which means the changes are never flushed back to the
515     file.
516     """
517     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
518                     close)
519
520
521 def parse_num(s):
522     """Parse data size information into a float number.
523
524     Here are some examples of conversions:
525         199.2k means 203981 bytes
526         1GB means 1073741824 bytes
527         2.1 tb means 2199023255552 bytes
528     """
529     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
530     if not g:
531         raise ValueError("can't parse %r as a number" % s)
532     (val, unit) = g.groups()
533     num = float(val)
534     unit = unit.lower()
535     if unit in ['t', 'tb']:
536         mult = 1024*1024*1024*1024
537     elif unit in ['g', 'gb']:
538         mult = 1024*1024*1024
539     elif unit in ['m', 'mb']:
540         mult = 1024*1024
541     elif unit in ['k', 'kb']:
542         mult = 1024
543     elif unit in ['', 'b']:
544         mult = 1
545     else:
546         raise ValueError("invalid unit %r in number %r" % (unit, s))
547     return int(num*mult)
548
549
550 def count(l):
551     """Count the number of elements in an iterator. (consumes the iterator)"""
552     return reduce(lambda x,y: x+1, l)
553
554
555 saved_errors = []
556 def add_error(e):
557     """Append an error message to the list of saved errors.
558
559     Once processing is able to stop and output the errors, the saved errors are
560     accessible in the module variable helpers.saved_errors.
561     """
562     saved_errors.append(e)
563     log('%-70s\n' % e)
564
565
566 def clear_errors():
567     global saved_errors
568     saved_errors = []
569
570
571 def handle_ctrl_c():
572     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
573
574     The new exception handler will make sure that bup will exit without an ugly
575     stacktrace when Ctrl-C is hit.
576     """
577     oldhook = sys.excepthook
578     def newhook(exctype, value, traceback):
579         if exctype == KeyboardInterrupt:
580             log('Interrupted.\n')
581         else:
582             return oldhook(exctype, value, traceback)
583     sys.excepthook = newhook
584
585
586 def columnate(l, prefix):
587     """Format elements of 'l' in columns with 'prefix' leading each line.
588
589     The number of columns is determined automatically based on the string
590     lengths.
591     """
592     if not l:
593         return ""
594     l = l[:]
595     clen = max(len(s) for s in l)
596     ncols = (tty_width() - len(prefix)) / (clen + 2)
597     if ncols <= 1:
598         ncols = 1
599         clen = 0
600     cols = []
601     while len(l) % ncols:
602         l.append('')
603     rows = len(l)/ncols
604     for s in range(0, len(l), rows):
605         cols.append(l[s:s+rows])
606     out = ''
607     for row in zip(*cols):
608         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
609     return out
610
611
612 def parse_date_or_fatal(str, fatal):
613     """Parses the given date or calls Option.fatal().
614     For now we expect a string that contains a float."""
615     try:
616         date = atof(str)
617     except ValueError, e:
618         raise fatal('invalid date format (should be a float): %r' % e)
619     else:
620         return date
621
622
623 def strip_path(prefix, path):
624     """Strips a given prefix from a path.
625
626     First both paths are normalized.
627
628     Raises an Exception if no prefix is given.
629     """
630     if prefix == None:
631         raise Exception('no path given')
632
633     normalized_prefix = os.path.realpath(prefix)
634     debug2("normalized_prefix: %s\n" % normalized_prefix)
635     normalized_path = os.path.realpath(path)
636     debug2("normalized_path: %s\n" % normalized_path)
637     if normalized_path.startswith(normalized_prefix):
638         return normalized_path[len(normalized_prefix):]
639     else:
640         return path
641
642
643 def strip_base_path(path, base_paths):
644     """Strips the base path from a given path.
645
646
647     Determines the base path for the given string and then strips it
648     using strip_path().
649     Iterates over all base_paths from long to short, to prevent that
650     a too short base_path is removed.
651     """
652     normalized_path = os.path.realpath(path)
653     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
654     for bp in sorted_base_paths:
655         if normalized_path.startswith(os.path.realpath(bp)):
656             return strip_path(bp, normalized_path)
657     return path
658
659
660 def graft_path(graft_points, path):
661     normalized_path = os.path.realpath(path)
662     for graft_point in graft_points:
663         old_prefix, new_prefix = graft_point
664         if normalized_path.startswith(old_prefix):
665             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
666     return normalized_path
667
668
669 # hashlib is only available in python 2.5 or higher, but the 'sha' module
670 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
671 # python 2.4 and above without any stupid warnings, so let's try using hashlib
672 # first, and downgrade if it fails.
673 try:
674     import hashlib
675 except ImportError:
676     import sha
677     Sha1 = sha.sha
678 else:
679     Sha1 = hashlib.sha1
680
681
682 def version_date():
683     """Format bup's version date string for output."""
684     return _version.DATE.split(' ')[0]
685
686
687 def version_commit():
688     """Get the commit hash of bup's current version."""
689     return _version.COMMIT
690
691
692 def version_tag():
693     """Format bup's version tag (the official version number).
694
695     When generated from a commit other than one pointed to with a tag, the
696     returned string will be "unknown-" followed by the first seven positions of
697     the commit hash.
698     """
699     names = _version.NAMES.strip()
700     assert(names[0] == '(')
701     assert(names[-1] == ')')
702     names = names[1:-1]
703     l = [n.strip() for n in names.split(',')]
704     for n in l:
705         if n.startswith('tag: bup-'):
706             return n[9:]
707     return 'unknown-%s' % _version.COMMIT[:7]