]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
mmap: Make closing source file optional
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49 def log(s):
50     """Print a log message to stderr."""
51     sys.stdout.flush()
52     _hard_write(sys.stderr.fileno(), s)
53
54
55 def debug1(s):
56     if buglvl >= 1:
57         log(s)
58
59
60 def debug2(s):
61     if buglvl >= 2:
62         log(s)
63
64
65 def mkdirp(d, mode=None):
66     """Recursively create directories on path 'd'.
67
68     Unlike os.makedirs(), it doesn't raise an exception if the last element of
69     the path already exists.
70     """
71     try:
72         if mode:
73             os.makedirs(d, mode)
74         else:
75             os.makedirs(d)
76     except OSError, e:
77         if e.errno == errno.EEXIST:
78             pass
79         else:
80             raise
81
82
83 def next(it):
84     """Get the next item from an iterator, None if we reached the end."""
85     try:
86         return it.next()
87     except StopIteration:
88         return None
89
90
91 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
92     if key:
93         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
94     else:
95         samekey = operator.eq
96     count = 0
97     total = sum(len(it) for it in iters)
98     iters = (iter(it) for it in iters)
99     heap = ((next(it),it) for it in iters)
100     heap = [(e,it) for e,it in heap if e]
101
102     heapq.heapify(heap)
103     pe = None
104     while heap:
105         if not count % pfreq:
106             pfunc(count, total)
107         e, it = heap[0]
108         if not samekey(e, pe):
109             pe = e
110             yield e
111         count += 1
112         try:
113             e = it.next() # Don't use next() function, it's too expensive
114         except StopIteration:
115             heapq.heappop(heap) # remove current
116         else:
117             heapq.heapreplace(heap, (e, it)) # shift current to new location
118     pfinal(count, total)
119
120
121 def unlink(f):
122     """Delete a file at path 'f' if it currently exists.
123
124     Unlike os.unlink(), does not throw an exception if the file didn't already
125     exist.
126     """
127     try:
128         os.unlink(f)
129     except OSError, e:
130         if e.errno == errno.ENOENT:
131             pass  # it doesn't exist, that's what you asked for
132
133
134 def readpipe(argv):
135     """Run a subprocess and return its output."""
136     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
137     r = p.stdout.read()
138     p.wait()
139     return r
140
141
142 def realpath(p):
143     """Get the absolute path of a file.
144
145     Behaves like os.path.realpath, but doesn't follow a symlink for the last
146     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
147     will follow symlinks in p's directory)
148     """
149     try:
150         st = os.lstat(p)
151     except OSError:
152         st = None
153     if st and stat.S_ISLNK(st.st_mode):
154         (dir, name) = os.path.split(p)
155         dir = os.path.realpath(dir)
156         out = os.path.join(dir, name)
157     else:
158         out = os.path.realpath(p)
159     #log('realpathing:%r,%r\n' % (p, out))
160     return out
161
162
163 _username = None
164 def username():
165     """Get the user's login name."""
166     global _username
167     if not _username:
168         uid = os.getuid()
169         try:
170             _username = pwd.getpwuid(uid)[0]
171         except KeyError:
172             _username = 'user%d' % uid
173     return _username
174
175
176 _userfullname = None
177 def userfullname():
178     """Get the user's full name."""
179     global _userfullname
180     if not _userfullname:
181         uid = os.getuid()
182         try:
183             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
184         except KeyError:
185             _userfullname = 'user%d' % uid
186     return _userfullname
187
188
189 _hostname = None
190 def hostname():
191     """Get the FQDN of this machine."""
192     global _hostname
193     if not _hostname:
194         _hostname = socket.getfqdn()
195     return _hostname
196
197
198 _resource_path = None
199 def resource_path(subdir=''):
200     global _resource_path
201     if not _resource_path:
202         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
203     return os.path.join(_resource_path, subdir)
204
205
206 class NotOk(Exception):
207     pass
208
209
210 class BaseConn:
211     def __init__(self, outp):
212         self.outp = outp
213
214     def close(self):
215         while self._read(65536): pass
216
217     def read(self, size):
218         """Read 'size' bytes from input stream."""
219         self.outp.flush()
220         return self._read(size)
221
222     def readline(self):
223         """Read from input stream until a newline is found."""
224         self.outp.flush()
225         return self._readline()
226
227     def write(self, data):
228         """Write 'data' to output stream."""
229         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
230         self.outp.write(data)
231
232     def has_input(self):
233         """Return true if input stream is readable."""
234         raise NotImplemented("Subclasses must implement has_input")
235
236     def ok(self):
237         """Indicate end of output from last sent command."""
238         self.write('\nok\n')
239
240     def error(self, s):
241         """Indicate server error to the client."""
242         s = re.sub(r'\s+', ' ', str(s))
243         self.write('\nerror %s\n' % s)
244
245     def _check_ok(self, onempty):
246         self.outp.flush()
247         rl = ''
248         for rl in linereader(self):
249             #log('%d got line: %r\n' % (os.getpid(), rl))
250             if not rl:  # empty line
251                 continue
252             elif rl == 'ok':
253                 return None
254             elif rl.startswith('error '):
255                 #log('client: error: %s\n' % rl[6:])
256                 return NotOk(rl[6:])
257             else:
258                 onempty(rl)
259         raise Exception('server exited unexpectedly; see errors above')
260
261     def drain_and_check_ok(self):
262         """Remove all data for the current command from input stream."""
263         def onempty(rl):
264             pass
265         return self._check_ok(onempty)
266
267     def check_ok(self):
268         """Verify that server action completed successfully."""
269         def onempty(rl):
270             raise Exception('expected "ok", got %r' % rl)
271         return self._check_ok(onempty)
272
273
274 class Conn(BaseConn):
275     def __init__(self, inp, outp):
276         BaseConn.__init__(self, outp)
277         self.inp = inp
278
279     def _read(self, size):
280         return self.inp.read(size)
281
282     def _readline(self):
283         return self.inp.readline()
284
285     def has_input(self):
286         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
287         if rl:
288             assert(rl[0] == self.inp.fileno())
289             return True
290         else:
291             return None
292
293
294 def checked_reader(fd, n):
295     while n > 0:
296         rl, _, _ = select.select([fd], [], [])
297         assert(rl[0] == fd)
298         buf = os.read(fd, n)
299         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
300         yield buf
301         n -= len(buf)
302
303
304 MAX_PACKET = 128 * 1024
305 def mux(p, outfd, outr, errr):
306     try:
307         fds = [outr, errr]
308         while p.poll() is None:
309             rl, _, _ = select.select(fds, [], [])
310             for fd in rl:
311                 if fd == outr:
312                     buf = os.read(outr, MAX_PACKET)
313                     if not buf: break
314                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
315                 elif fd == errr:
316                     buf = os.read(errr, 1024)
317                     if not buf: break
318                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
319     finally:
320         os.write(outfd, struct.pack('!IB', 0, 3))
321
322
323 class DemuxConn(BaseConn):
324     """A helper class for bup's client-server protocol."""
325     def __init__(self, infd, outp):
326         BaseConn.__init__(self, outp)
327         # Anything that comes through before the sync string was not
328         # multiplexed and can be assumed to be debug/log before mux init.
329         tail = ''
330         while tail != 'BUPMUX':
331             b = os.read(infd, (len(tail) < 6) and (6-len(tail)) or 1)
332             if not b:
333                 raise IOError('demux: unexpected EOF during initialization')
334             tail += b
335             sys.stderr.write(tail[:-6])  # pre-mux log messages
336             tail = tail[-6:]
337         self.infd = infd
338         self.reader = None
339         self.buf = None
340         self.closed = False
341
342     def write(self, data):
343         self._load_buf(0)
344         BaseConn.write(self, data)
345
346     def _next_packet(self, timeout):
347         if self.closed: return False
348         rl, wl, xl = select.select([self.infd], [], [], timeout)
349         if not rl: return False
350         assert(rl[0] == self.infd)
351         ns = ''.join(checked_reader(self.infd, 5))
352         n, fdw = struct.unpack('!IB', ns)
353         assert(n <= MAX_PACKET)
354         if fdw == 1:
355             self.reader = checked_reader(self.infd, n)
356         elif fdw == 2:
357             for buf in checked_reader(self.infd, n):
358                 sys.stderr.write(buf)
359         elif fdw == 3:
360             self.closed = True
361             debug2("DemuxConn: marked closed\n")
362         return True
363
364     def _load_buf(self, timeout):
365         if self.buf is not None:
366             return True
367         while not self.closed:
368             while not self.reader:
369                 if not self._next_packet(timeout):
370                     return False
371             try:
372                 self.buf = self.reader.next()
373                 return True
374             except StopIteration:
375                 self.reader = None
376         return False
377
378     def _read_parts(self, ix_fn):
379         while self._load_buf(None):
380             assert(self.buf is not None)
381             i = ix_fn(self.buf)
382             if i is None or i == len(self.buf):
383                 yv = self.buf
384                 self.buf = None
385             else:
386                 yv = self.buf[:i]
387                 self.buf = self.buf[i:]
388             yield yv
389             if i is not None:
390                 break
391
392     def _readline(self):
393         def find_eol(buf):
394             try:
395                 return buf.index('\n')+1
396             except ValueError:
397                 return None
398         return ''.join(self._read_parts(find_eol))
399
400     def _read(self, size):
401         csize = [size]
402         def until_size(buf): # Closes on csize
403             if len(buf) < csize[0]:
404                 csize[0] -= len(buf)
405                 return None
406             else:
407                 return csize[0]
408         return ''.join(self._read_parts(until_size))
409
410     def has_input(self):
411         return self._load_buf(0)
412
413
414 def linereader(f):
415     """Generate a list of input lines from 'f' without terminating newlines."""
416     while 1:
417         line = f.readline()
418         if not line:
419             break
420         yield line[:-1]
421
422
423 def chunkyreader(f, count = None):
424     """Generate a list of chunks of data read from 'f'.
425
426     If count is None, read until EOF is reached.
427
428     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
429     reached while reading, raise IOError.
430     """
431     if count != None:
432         while count > 0:
433             b = f.read(min(count, 65536))
434             if not b:
435                 raise IOError('EOF with %d bytes remaining' % count)
436             yield b
437             count -= len(b)
438     else:
439         while 1:
440             b = f.read(65536)
441             if not b: break
442             yield b
443
444
445 def slashappend(s):
446     """Append "/" to 's' if it doesn't aleady end in "/"."""
447     if s and not s.endswith('/'):
448         return s + '/'
449     else:
450         return s
451
452
453 def _mmap_do(f, sz, flags, prot, close):
454     if not sz:
455         st = os.fstat(f.fileno())
456         sz = st.st_size
457     if not sz:
458         # trying to open a zero-length map gives an error, but an empty
459         # string has all the same behaviour of a zero-length map, ie. it has
460         # no elements :)
461         return ''
462     map = mmap.mmap(f.fileno(), sz, flags, prot)
463     if close:
464         f.close()  # map will persist beyond file close
465     return map
466
467
468 def mmap_read(f, sz = 0, close=True):
469     """Create a read-only memory mapped region on file 'f'.
470
471     If sz is 0, the region will cover the entire file.
472     """
473     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ, close)
474
475
476 def mmap_readwrite(f, sz = 0, close=True):
477     """Create a read-write memory mapped region on file 'f'.
478
479     If sz is 0, the region will cover the entire file.
480     """
481     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE,
482                     close)
483
484
485 def parse_num(s):
486     """Parse data size information into a float number.
487
488     Here are some examples of conversions:
489         199.2k means 203981 bytes
490         1GB means 1073741824 bytes
491         2.1 tb means 2199023255552 bytes
492     """
493     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
494     if not g:
495         raise ValueError("can't parse %r as a number" % s)
496     (val, unit) = g.groups()
497     num = float(val)
498     unit = unit.lower()
499     if unit in ['t', 'tb']:
500         mult = 1024*1024*1024*1024
501     elif unit in ['g', 'gb']:
502         mult = 1024*1024*1024
503     elif unit in ['m', 'mb']:
504         mult = 1024*1024
505     elif unit in ['k', 'kb']:
506         mult = 1024
507     elif unit in ['', 'b']:
508         mult = 1
509     else:
510         raise ValueError("invalid unit %r in number %r" % (unit, s))
511     return int(num*mult)
512
513
514 def count(l):
515     """Count the number of elements in an iterator. (consumes the iterator)"""
516     return reduce(lambda x,y: x+1, l)
517
518
519 saved_errors = []
520 def add_error(e):
521     """Append an error message to the list of saved errors.
522
523     Once processing is able to stop and output the errors, the saved errors are
524     accessible in the module variable helpers.saved_errors.
525     """
526     saved_errors.append(e)
527     log('%-70s\n' % e)
528
529
530 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
531 def progress(s):
532     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
533     if istty:
534         log(s)
535
536
537 def handle_ctrl_c():
538     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
539
540     The new exception handler will make sure that bup will exit without an ugly
541     stacktrace when Ctrl-C is hit.
542     """
543     oldhook = sys.excepthook
544     def newhook(exctype, value, traceback):
545         if exctype == KeyboardInterrupt:
546             log('Interrupted.\n')
547         else:
548             return oldhook(exctype, value, traceback)
549     sys.excepthook = newhook
550
551
552 def columnate(l, prefix):
553     """Format elements of 'l' in columns with 'prefix' leading each line.
554
555     The number of columns is determined automatically based on the string
556     lengths.
557     """
558     if not l:
559         return ""
560     l = l[:]
561     clen = max(len(s) for s in l)
562     ncols = (tty_width() - len(prefix)) / (clen + 2)
563     if ncols <= 1:
564         ncols = 1
565         clen = 0
566     cols = []
567     while len(l) % ncols:
568         l.append('')
569     rows = len(l)/ncols
570     for s in range(0, len(l), rows):
571         cols.append(l[s:s+rows])
572     out = ''
573     for row in zip(*cols):
574         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
575     return out
576
577
578 def parse_date_or_fatal(str, fatal):
579     """Parses the given date or calls Option.fatal().
580     For now we expect a string that contains a float."""
581     try:
582         date = atof(str)
583     except ValueError, e:
584         raise fatal('invalid date format (should be a float): %r' % e)
585     else:
586         return date
587
588
589 def strip_path(prefix, path):
590     """Strips a given prefix from a path.
591
592     First both paths are normalized.
593
594     Raises an Exception if no prefix is given.
595     """
596     if prefix == None:
597         raise Exception('no path given')
598
599     normalized_prefix = os.path.realpath(prefix)
600     debug2("normalized_prefix: %s\n" % normalized_prefix)
601     normalized_path = os.path.realpath(path)
602     debug2("normalized_path: %s\n" % normalized_path)
603     if normalized_path.startswith(normalized_prefix):
604         return normalized_path[len(normalized_prefix):]
605     else:
606         return path
607
608
609 def strip_base_path(path, base_paths):
610     """Strips the base path from a given path.
611
612
613     Determines the base path for the given string and then strips it
614     using strip_path().
615     Iterates over all base_paths from long to short, to prevent that
616     a too short base_path is removed.
617     """
618     normalized_path = os.path.realpath(path)
619     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
620     for bp in sorted_base_paths:
621         if normalized_path.startswith(os.path.realpath(bp)):
622             return strip_path(bp, normalized_path)
623     return path
624
625
626 def graft_path(graft_points, path):
627     normalized_path = os.path.realpath(path)
628     for graft_point in graft_points:
629         old_prefix, new_prefix = graft_point
630         if normalized_path.startswith(old_prefix):
631             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
632     return normalized_path
633
634
635 # hashlib is only available in python 2.5 or higher, but the 'sha' module
636 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
637 # python 2.4 and above without any stupid warnings, so let's try using hashlib
638 # first, and downgrade if it fails.
639 try:
640     import hashlib
641 except ImportError:
642     import sha
643     Sha1 = sha.sha
644 else:
645     Sha1 = hashlib.sha1
646
647
648 def version_date():
649     """Format bup's version date string for output."""
650     return _version.DATE.split(' ')[0]
651
652
653 def version_commit():
654     """Get the commit hash of bup's current version."""
655     return _version.COMMIT
656
657
658 def version_tag():
659     """Format bup's version tag (the official version number).
660
661     When generated from a commit other than one pointed to with a tag, the
662     returned string will be "unknown-" followed by the first seven positions of
663     the commit hash.
664     """
665     names = _version.NAMES.strip()
666     assert(names[0] == '(')
667     assert(names[-1] == ')')
668     names = names[1:-1]
669     l = [n.strip() for n in names.split(',')]
670     for n in l:
671         if n.startswith('tag: bup-'):
672             return n[9:]
673     return 'unknown-%s' % _version.COMMIT[:7]