]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Merge remote branch 'origin/master' into meta
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6 import bup._helpers as _helpers
7
8 # This function should really be in helpers, not in bup.options.  But we
9 # want options.py to be standalone so people can include it in other projects.
10 from bup.options import _tty_width
11 tty_width = _tty_width
12
13
14 def atoi(s):
15     """Convert the string 's' to an integer. Return 0 if s is not a number."""
16     try:
17         return int(s or '0')
18     except ValueError:
19         return 0
20
21
22 def atof(s):
23     """Convert the string 's' to a float. Return 0 if s is not a number."""
24     try:
25         return float(s or '0')
26     except ValueError:
27         return 0
28
29
30 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
31
32
33 # Write (blockingly) to sockets that may or may not be in blocking mode.
34 # We need this because our stderr is sometimes eaten by subprocesses
35 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
36 # leading to race conditions.  Ick.  We'll do it the hard way.
37 def _hard_write(fd, buf):
38     while buf:
39         (r,w,x) = select.select([], [fd], [], None)
40         if not w:
41             raise IOError('select(fd) returned without being writable')
42         try:
43             sz = os.write(fd, buf)
44         except OSError, e:
45             if e.errno != errno.EAGAIN:
46                 raise
47         assert(sz >= 0)
48         buf = buf[sz:]
49
50 def log(s):
51     """Print a log message to stderr."""
52     sys.stdout.flush()
53     _hard_write(sys.stderr.fileno(), s)
54
55
56 def debug1(s):
57     if buglvl >= 1:
58         log(s)
59
60
61 def debug2(s):
62     if buglvl >= 2:
63         log(s)
64
65
66 def mkdirp(d, mode=None):
67     """Recursively create directories on path 'd'.
68
69     Unlike os.makedirs(), it doesn't raise an exception if the last element of
70     the path already exists.
71     """
72     try:
73         if mode:
74             os.makedirs(d, mode)
75         else:
76             os.makedirs(d)
77     except OSError, e:
78         if e.errno == errno.EEXIST:
79             pass
80         else:
81             raise
82
83
84 def next(it):
85     """Get the next item from an iterator, None if we reached the end."""
86     try:
87         return it.next()
88     except StopIteration:
89         return None
90
91
92 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
93     if key:
94         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
95     else:
96         samekey = operator.eq
97     count = 0
98     total = sum(len(it) for it in iters)
99     iters = (iter(it) for it in iters)
100     heap = ((next(it),it) for it in iters)
101     heap = [(e,it) for e,it in heap if e]
102
103     heapq.heapify(heap)
104     pe = None
105     while heap:
106         if not count % pfreq:
107             pfunc(count, total)
108         e, it = heap[0]
109         if not samekey(e, pe):
110             pe = e
111             yield e
112         count += 1
113         try:
114             e = it.next() # Don't use next() function, it's too expensive
115         except StopIteration:
116             heapq.heappop(heap) # remove current
117         else:
118             heapq.heapreplace(heap, (e, it)) # shift current to new location
119     pfinal(count, total)
120
121
122 def unlink(f):
123     """Delete a file at path 'f' if it currently exists.
124
125     Unlike os.unlink(), does not throw an exception if the file didn't already
126     exist.
127     """
128     try:
129         os.unlink(f)
130     except OSError, e:
131         if e.errno == errno.ENOENT:
132             pass  # it doesn't exist, that's what you asked for
133
134
135 def readpipe(argv):
136     """Run a subprocess and return its output."""
137     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
138     r = p.stdout.read()
139     p.wait()
140     return r
141
142
143 def realpath(p):
144     """Get the absolute path of a file.
145
146     Behaves like os.path.realpath, but doesn't follow a symlink for the last
147     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
148     will follow symlinks in p's directory)
149     """
150     try:
151         st = os.lstat(p)
152     except OSError:
153         st = None
154     if st and stat.S_ISLNK(st.st_mode):
155         (dir, name) = os.path.split(p)
156         dir = os.path.realpath(dir)
157         out = os.path.join(dir, name)
158     else:
159         out = os.path.realpath(p)
160     #log('realpathing:%r,%r\n' % (p, out))
161     return out
162
163
164 def detect_fakeroot():
165     "Return True if we appear to be running under fakeroot."
166     return os.getenv("FAKEROOTKEY") != None
167
168
169 _username = None
170 def username():
171     """Get the user's login name."""
172     global _username
173     if not _username:
174         uid = os.getuid()
175         try:
176             _username = pwd.getpwuid(uid)[0]
177         except KeyError:
178             _username = 'user%d' % uid
179     return _username
180
181
182 _userfullname = None
183 def userfullname():
184     """Get the user's full name."""
185     global _userfullname
186     if not _userfullname:
187         uid = os.getuid()
188         try:
189             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
190         except KeyError:
191             _userfullname = 'user%d' % uid
192     return _userfullname
193
194
195 _hostname = None
196 def hostname():
197     """Get the FQDN of this machine."""
198     global _hostname
199     if not _hostname:
200         _hostname = socket.getfqdn()
201     return _hostname
202
203
204 _resource_path = None
205 def resource_path(subdir=''):
206     global _resource_path
207     if not _resource_path:
208         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
209     return os.path.join(_resource_path, subdir)
210
211
212 class NotOk(Exception):
213     pass
214
215
216 class BaseConn:
217     def __init__(self, outp):
218         self.outp = outp
219
220     def close(self):
221         while self._read(65536): pass
222
223     def read(self, size):
224         """Read 'size' bytes from input stream."""
225         self.outp.flush()
226         return self._read(size)
227
228     def readline(self):
229         """Read from input stream until a newline is found."""
230         self.outp.flush()
231         return self._readline()
232
233     def write(self, data):
234         """Write 'data' to output stream."""
235         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
236         self.outp.write(data)
237
238     def has_input(self):
239         """Return true if input stream is readable."""
240         raise NotImplemented("Subclasses must implement has_input")
241
242     def ok(self):
243         """Indicate end of output from last sent command."""
244         self.write('\nok\n')
245
246     def error(self, s):
247         """Indicate server error to the client."""
248         s = re.sub(r'\s+', ' ', str(s))
249         self.write('\nerror %s\n' % s)
250
251     def _check_ok(self, onempty):
252         self.outp.flush()
253         rl = ''
254         for rl in linereader(self):
255             #log('%d got line: %r\n' % (os.getpid(), rl))
256             if not rl:  # empty line
257                 continue
258             elif rl == 'ok':
259                 return None
260             elif rl.startswith('error '):
261                 #log('client: error: %s\n' % rl[6:])
262                 return NotOk(rl[6:])
263             else:
264                 onempty(rl)
265         raise Exception('server exited unexpectedly; see errors above')
266
267     def drain_and_check_ok(self):
268         """Remove all data for the current command from input stream."""
269         def onempty(rl):
270             pass
271         return self._check_ok(onempty)
272
273     def check_ok(self):
274         """Verify that server action completed successfully."""
275         def onempty(rl):
276             raise Exception('expected "ok", got %r' % rl)
277         return self._check_ok(onempty)
278
279
280 class Conn(BaseConn):
281     def __init__(self, inp, outp):
282         BaseConn.__init__(self, outp)
283         self.inp = inp
284
285     def _read(self, size):
286         return self.inp.read(size)
287
288     def _readline(self):
289         return self.inp.readline()
290
291     def has_input(self):
292         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
293         if rl:
294             assert(rl[0] == self.inp.fileno())
295             return True
296         else:
297             return None
298
299
300 def checked_reader(fd, n):
301     while n > 0:
302         rl, _, _ = select.select([fd], [], [])
303         assert(rl[0] == fd)
304         buf = os.read(fd, n)
305         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
306         yield buf
307         n -= len(buf)
308
309
310 MAX_PACKET = 128 * 1024
311 def mux(p, outfd, outr, errr):
312     try:
313         fds = [outr, errr]
314         while p.poll() is None:
315             rl, _, _ = select.select(fds, [], [])
316             for fd in rl:
317                 if fd == outr:
318                     buf = os.read(outr, MAX_PACKET)
319                     if not buf: break
320                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
321                 elif fd == errr:
322                     buf = os.read(errr, 1024)
323                     if not buf: break
324                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
325     finally:
326         os.write(outfd, struct.pack('!IB', 0, 3))
327
328
329 class DemuxConn(BaseConn):
330     """A helper class for bup's client-server protocol."""
331     def __init__(self, infd, outp):
332         BaseConn.__init__(self, outp)
333         # Anything that comes through before the sync string was not
334         # multiplexed and can be assumed to be debug/log before mux init.
335         tail = ''
336         while tail != 'BUPMUX':
337             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
338             if not b:
339                 raise IOError('demux: unexpected EOF during initialization')
340             tail += b
341             sys.stderr.write(tail[:-6])  # pre-mux log messages
342             tail = tail[-6:]
343         self.infd = infd
344         self.reader = None
345         self.buf = None
346         self.closed = False
347
348     def write(self, data):
349         self._load_buf(0)
350         BaseConn.write(self, data)
351
352     def _next_packet(self, timeout):
353         if self.closed: return False
354         rl, wl, xl = select.select([self.infd], [], [], timeout)
355         if not rl: return False
356         assert(rl[0] == self.infd)
357         ns = ''.join(checked_reader(self.infd, 5))
358         n, fdw = struct.unpack('!IB', ns)
359         assert(n <= MAX_PACKET)
360         if fdw == 1:
361             self.reader = checked_reader(self.infd, n)
362         elif fdw == 2:
363             for buf in checked_reader(self.infd, n):
364                 sys.stderr.write(buf)
365         elif fdw == 3:
366             self.closed = True
367             debug2("DemuxConn: marked closed\n")
368         return True
369
370     def _load_buf(self, timeout):
371         if self.buf is not None:
372             return True
373         while not self.closed:
374             while not self.reader:
375                 if not self._next_packet(timeout):
376                     return False
377             try:
378                 self.buf = self.reader.next()
379                 return True
380             except StopIteration:
381                 self.reader = None
382         return False
383
384     def _read_parts(self, ix_fn):
385         while self._load_buf(None):
386             assert(self.buf is not None)
387             i = ix_fn(self.buf)
388             if i is None or i == len(self.buf):
389                 yv = self.buf
390                 self.buf = None
391             else:
392                 yv = self.buf[:i]
393                 self.buf = self.buf[i:]
394             yield yv
395             if i is not None:
396                 break
397
398     def _readline(self):
399         def find_eol(buf):
400             try:
401                 return buf.index('\n')+1
402             except ValueError:
403                 return None
404         return ''.join(self._read_parts(find_eol))
405
406     def _read(self, size):
407         csize = [size]
408         def until_size(buf): # Closes on csize
409             if len(buf) < csize[0]:
410                 csize[0] -= len(buf)
411                 return None
412             else:
413                 return csize[0]
414         return ''.join(self._read_parts(until_size))
415
416     def has_input(self):
417         return self._load_buf(0)
418
419
420 def linereader(f):
421     """Generate a list of input lines from 'f' without terminating newlines."""
422     while 1:
423         line = f.readline()
424         if not line:
425             break
426         yield line[:-1]
427
428
429 def chunkyreader(f, count = None):
430     """Generate a list of chunks of data read from 'f'.
431
432     If count is None, read until EOF is reached.
433
434     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
435     reached while reading, raise IOError.
436     """
437     if count != None:
438         while count > 0:
439             b = f.read(min(count, 65536))
440             if not b:
441                 raise IOError('EOF with %d bytes remaining' % count)
442             yield b
443             count -= len(b)
444     else:
445         while 1:
446             b = f.read(65536)
447             if not b: break
448             yield b
449
450
451 def slashappend(s):
452     """Append "/" to 's' if it doesn't aleady end in "/"."""
453     if s and not s.endswith('/'):
454         return s + '/'
455     else:
456         return s
457
458
459 def _mmap_do(f, sz, flags, prot, close):
460     if not sz:
461         st = os.fstat(f.fileno())
462         sz = st.st_size
463     if not sz:
464         # trying to open a zero-length map gives an error, but an empty
465         # string has all the same behaviour of a zero-length map, ie. it has
466         # no elements :)
467         return ''
468     map = mmap.mmap(f.fileno(), sz, flags, prot)
469     if close:
470         f.close()  # map will persist beyond file close
471     return map
472
473
474 def mmap_read(f, sz = 0, close=True):
475     """Create a read-only memory mapped region on file 'f'.
476     If sz is 0, the region will cover the entire file.
477     """
478     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
479
480
481 def mmap_readwrite(f, sz = 0, close=True):
482     """Create a read-write memory mapped region on file 'f'.
483     If sz is 0, the region will cover the entire file.
484     """
485     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
486                     close)
487
488
489 def mmap_readwrite_private(f, sz = 0, close=True):
490     """Create a read-write memory mapped region on file 'f'.
491     If sz is 0, the region will cover the entire file.
492     The map is private, which means the changes are never flushed back to the
493     file.
494     """
495     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ|mmap.PROT_WRITE,
496                     close)
497
498
499 def parse_num(s):
500     """Parse data size information into a float number.
501
502     Here are some examples of conversions:
503         199.2k means 203981 bytes
504         1GB means 1073741824 bytes
505         2.1 tb means 2199023255552 bytes
506     """
507     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
508     if not g:
509         raise ValueError("can't parse %r as a number" % s)
510     (val, unit) = g.groups()
511     num = float(val)
512     unit = unit.lower()
513     if unit in ['t', 'tb']:
514         mult = 1024*1024*1024*1024
515     elif unit in ['g', 'gb']:
516         mult = 1024*1024*1024
517     elif unit in ['m', 'mb']:
518         mult = 1024*1024
519     elif unit in ['k', 'kb']:
520         mult = 1024
521     elif unit in ['', 'b']:
522         mult = 1
523     else:
524         raise ValueError("invalid unit %r in number %r" % (unit, s))
525     return int(num*mult)
526
527
528 def count(l):
529     """Count the number of elements in an iterator. (consumes the iterator)"""
530     return reduce(lambda x,y: x+1, l)
531
532
533 saved_errors = []
534 def add_error(e):
535     """Append an error message to the list of saved errors.
536
537     Once processing is able to stop and output the errors, the saved errors are
538     accessible in the module variable helpers.saved_errors.
539     """
540     saved_errors.append(e)
541     log('%-70s\n' % e)
542
543
544 def clear_errors():
545     global saved_errors
546     saved_errors = []
547
548
549 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
550 def progress(s):
551     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
552     if istty:
553         log(s)
554
555
556 def handle_ctrl_c():
557     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
558
559     The new exception handler will make sure that bup will exit without an ugly
560     stacktrace when Ctrl-C is hit.
561     """
562     oldhook = sys.excepthook
563     def newhook(exctype, value, traceback):
564         if exctype == KeyboardInterrupt:
565             log('Interrupted.\n')
566         else:
567             return oldhook(exctype, value, traceback)
568     sys.excepthook = newhook
569
570
571 def columnate(l, prefix):
572     """Format elements of 'l' in columns with 'prefix' leading each line.
573
574     The number of columns is determined automatically based on the string
575     lengths.
576     """
577     if not l:
578         return ""
579     l = l[:]
580     clen = max(len(s) for s in l)
581     ncols = (tty_width() - len(prefix)) / (clen + 2)
582     if ncols <= 1:
583         ncols = 1
584         clen = 0
585     cols = []
586     while len(l) % ncols:
587         l.append('')
588     rows = len(l)/ncols
589     for s in range(0, len(l), rows):
590         cols.append(l[s:s+rows])
591     out = ''
592     for row in zip(*cols):
593         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
594     return out
595
596
597 def parse_date_or_fatal(str, fatal):
598     """Parses the given date or calls Option.fatal().
599     For now we expect a string that contains a float."""
600     try:
601         date = atof(str)
602     except ValueError, e:
603         raise fatal('invalid date format (should be a float): %r' % e)
604     else:
605         return date
606
607
608 def strip_path(prefix, path):
609     """Strips a given prefix from a path.
610
611     First both paths are normalized.
612
613     Raises an Exception if no prefix is given.
614     """
615     if prefix == None:
616         raise Exception('no path given')
617
618     normalized_prefix = os.path.realpath(prefix)
619     debug2("normalized_prefix: %s\n" % normalized_prefix)
620     normalized_path = os.path.realpath(path)
621     debug2("normalized_path: %s\n" % normalized_path)
622     if normalized_path.startswith(normalized_prefix):
623         return normalized_path[len(normalized_prefix):]
624     else:
625         return path
626
627
628 def strip_base_path(path, base_paths):
629     """Strips the base path from a given path.
630
631
632     Determines the base path for the given string and then strips it
633     using strip_path().
634     Iterates over all base_paths from long to short, to prevent that
635     a too short base_path is removed.
636     """
637     normalized_path = os.path.realpath(path)
638     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
639     for bp in sorted_base_paths:
640         if normalized_path.startswith(os.path.realpath(bp)):
641             return strip_path(bp, normalized_path)
642     return path
643
644
645 def graft_path(graft_points, path):
646     normalized_path = os.path.realpath(path)
647     for graft_point in graft_points:
648         old_prefix, new_prefix = graft_point
649         if normalized_path.startswith(old_prefix):
650             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
651     return normalized_path
652
653
654 # hashlib is only available in python 2.5 or higher, but the 'sha' module
655 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
656 # python 2.4 and above without any stupid warnings, so let's try using hashlib
657 # first, and downgrade if it fails.
658 try:
659     import hashlib
660 except ImportError:
661     import sha
662     Sha1 = sha.sha
663 else:
664     Sha1 = hashlib.sha1
665
666
667 def version_date():
668     """Format bup's version date string for output."""
669     return _version.DATE.split(' ')[0]
670
671
672 def version_commit():
673     """Get the commit hash of bup's current version."""
674     return _version.COMMIT
675
676
677 def version_tag():
678     """Format bup's version tag (the official version number).
679
680     When generated from a commit other than one pointed to with a tag, the
681     returned string will be "unknown-" followed by the first seven positions of
682     the commit hash.
683     """
684     names = _version.NAMES.strip()
685     assert(names[0] == '(')
686     assert(names[-1] == ')')
687     names = names[1:-1]
688     l = [n.strip() for n in names.split(',')]
689     for n in l:
690         if n.startswith('tag: bup-'):
691             return n[9:]
692     return 'unknown-%s' % _version.COMMIT[:7]