]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Add DemuxConn and `bup mux` for client-server
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re, struct
4 import heapq, operator
5 from bup import _version
6
7 # This function should really be in helpers, not in bup.options.  But we
8 # want options.py to be standalone so people can include it in other projects.
9 from bup.options import _tty_width
10 tty_width = _tty_width
11
12
13 def atoi(s):
14     """Convert the string 's' to an integer. Return 0 if s is not a number."""
15     try:
16         return int(s or '0')
17     except ValueError:
18         return 0
19
20
21 def atof(s):
22     """Convert the string 's' to a float. Return 0 if s is not a number."""
23     try:
24         return float(s or '0')
25     except ValueError:
26         return 0
27
28
29 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
30
31
32 # Write (blockingly) to sockets that may or may not be in blocking mode.
33 # We need this because our stderr is sometimes eaten by subprocesses
34 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
35 # leading to race conditions.  Ick.  We'll do it the hard way.
36 def _hard_write(fd, buf):
37     while buf:
38         (r,w,x) = select.select([], [fd], [], None)
39         if not w:
40             raise IOError('select(fd) returned without being writable')
41         try:
42             sz = os.write(fd, buf)
43         except OSError, e:
44             if e.errno != errno.EAGAIN:
45                 raise
46         assert(sz >= 0)
47         buf = buf[sz:]
48
49 def log(s):
50     """Print a log message to stderr."""
51     sys.stdout.flush()
52     _hard_write(sys.stderr.fileno(), s)
53
54
55 def debug1(s):
56     if buglvl >= 1:
57         log(s)
58
59
60 def debug2(s):
61     if buglvl >= 2:
62         log(s)
63
64
65 def mkdirp(d, mode=None):
66     """Recursively create directories on path 'd'.
67
68     Unlike os.makedirs(), it doesn't raise an exception if the last element of
69     the path already exists.
70     """
71     try:
72         if mode:
73             os.makedirs(d, mode)
74         else:
75             os.makedirs(d)
76     except OSError, e:
77         if e.errno == errno.EEXIST:
78             pass
79         else:
80             raise
81
82
83 def next(it):
84     """Get the next item from an iterator, None if we reached the end."""
85     try:
86         return it.next()
87     except StopIteration:
88         return None
89
90
91 def merge_iter(iters, pfreq, pfunc, pfinal, key=None):
92     if key:
93         samekey = lambda e, pe: getattr(e, key) == getattr(pe, key, None)
94     else:
95         samekey = operator.eq
96     count = 0
97     total = sum(len(it) for it in iters)
98     iters = (iter(it) for it in iters)
99     heap = ((next(it),it) for it in iters)
100     heap = [(e,it) for e,it in heap if e]
101
102     heapq.heapify(heap)
103     pe = None
104     while heap:
105         if not count % pfreq:
106             pfunc(count, total)
107         e, it = heap[0]
108         if not samekey(e, pe):
109             pe = e
110             yield e
111         count += 1
112         try:
113             e = it.next() # Don't use next() function, it's too expensive
114         except StopIteration:
115             heapq.heappop(heap) # remove current
116         else:
117             heapq.heapreplace(heap, (e, it)) # shift current to new location
118     pfinal(count, total)
119
120
121 def unlink(f):
122     """Delete a file at path 'f' if it currently exists.
123
124     Unlike os.unlink(), does not throw an exception if the file didn't already
125     exist.
126     """
127     try:
128         os.unlink(f)
129     except OSError, e:
130         if e.errno == errno.ENOENT:
131             pass  # it doesn't exist, that's what you asked for
132
133
134 def readpipe(argv):
135     """Run a subprocess and return its output."""
136     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
137     r = p.stdout.read()
138     p.wait()
139     return r
140
141
142 def realpath(p):
143     """Get the absolute path of a file.
144
145     Behaves like os.path.realpath, but doesn't follow a symlink for the last
146     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
147     will follow symlinks in p's directory)
148     """
149     try:
150         st = os.lstat(p)
151     except OSError:
152         st = None
153     if st and stat.S_ISLNK(st.st_mode):
154         (dir, name) = os.path.split(p)
155         dir = os.path.realpath(dir)
156         out = os.path.join(dir, name)
157     else:
158         out = os.path.realpath(p)
159     #log('realpathing:%r,%r\n' % (p, out))
160     return out
161
162
163 _username = None
164 def username():
165     """Get the user's login name."""
166     global _username
167     if not _username:
168         uid = os.getuid()
169         try:
170             _username = pwd.getpwuid(uid)[0]
171         except KeyError:
172             _username = 'user%d' % uid
173     return _username
174
175
176 _userfullname = None
177 def userfullname():
178     """Get the user's full name."""
179     global _userfullname
180     if not _userfullname:
181         uid = os.getuid()
182         try:
183             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
184         except KeyError:
185             _userfullname = 'user%d' % uid
186     return _userfullname
187
188
189 _hostname = None
190 def hostname():
191     """Get the FQDN of this machine."""
192     global _hostname
193     if not _hostname:
194         _hostname = socket.getfqdn()
195     return _hostname
196
197
198 _resource_path = None
199 def resource_path(subdir=''):
200     global _resource_path
201     if not _resource_path:
202         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
203     return os.path.join(_resource_path, subdir)
204
205 class NotOk(Exception):
206     pass
207
208 class BaseConn:
209     def __init__(self, outp):
210         self.outp = outp
211
212     def close(self):
213         while self._read(65536): pass
214
215     def read(self, size):
216         """Read 'size' bytes from input stream."""
217         self.outp.flush()
218         return self._read(size)
219
220     def readline(self):
221         """Read from input stream until a newline is found."""
222         self.outp.flush()
223         return self._readline()
224
225     def write(self, data):
226         """Write 'data' to output stream."""
227         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
228         self.outp.write(data)
229
230     def has_input(self):
231         """Return true if input stream is readable."""
232         raise NotImplemented("Subclasses must implement has_input")
233
234     def ok(self):
235         """Indicate end of output from last sent command."""
236         self.write('\nok\n')
237
238     def error(self, s):
239         """Indicate server error to the client."""
240         s = re.sub(r'\s+', ' ', str(s))
241         self.write('\nerror %s\n' % s)
242
243     def _check_ok(self, onempty):
244         self.outp.flush()
245         rl = ''
246         for rl in linereader(self):
247             #log('%d got line: %r\n' % (os.getpid(), rl))
248             if not rl:  # empty line
249                 continue
250             elif rl == 'ok':
251                 return None
252             elif rl.startswith('error '):
253                 #log('client: error: %s\n' % rl[6:])
254                 return NotOk(rl[6:])
255             else:
256                 onempty(rl)
257         raise Exception('server exited unexpectedly; see errors above')
258
259     def drain_and_check_ok(self):
260         """Remove all data for the current command from input stream."""
261         def onempty(rl):
262             pass
263         return self._check_ok(onempty)
264
265     def check_ok(self):
266         """Verify that server action completed successfully."""
267         def onempty(rl):
268             raise Exception('expected "ok", got %r' % rl)
269         return self._check_ok(onempty)
270
271 class Conn(BaseConn):
272     def __init__(self, inp, outp):
273         BaseConn.__init__(self, outp)
274         self.inp = inp
275
276     def _read(self, size):
277         return self.inp.read(size)
278
279     def _readline(self):
280         return self.inp.readline()
281
282     def has_input(self):
283         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
284         if rl:
285             assert(rl[0] == self.inp.fileno())
286             return True
287         else:
288             return None
289
290 def checked_reader(fd, n):
291     while n > 0:
292         rl, _, _ = select.select([fd], [], [])
293         assert(rl[0] == fd)
294         buf = os.read(fd, n)
295         if not buf: raise Exception("Unexpected EOF reading %d more bytes" % n)
296         yield buf
297         n -= len(buf)
298
299 MAX_PACKET = 128 * 1024
300 def mux(p, outfd, outr, errr):
301     try:
302         fds = [outr, errr]
303         while p.poll() is None:
304             rl, _, _ = select.select(fds, [], [])
305             for fd in rl:
306                 if fd == outr:
307                     buf = os.read(outr, MAX_PACKET)
308                     if not buf: break
309                     os.write(outfd, struct.pack('!IB', len(buf), 1) + buf)
310                 elif fd == errr:
311                     buf = os.read(errr, 1024)
312                     if not buf: break
313                     os.write(outfd, struct.pack('!IB', len(buf), 2) + buf)
314     finally:
315         os.write(outfd, struct.pack('!IB', 0, 3))
316
317 class DemuxConn(BaseConn):
318     """A helper class for bup's client-server protocol."""
319     def __init__(self, infd, outp):
320         BaseConn.__init__(self, outp)
321         # Anything that comes through before the sync string was not
322         # multiplexed and can be assumed to be debug/log before mux init.
323         tail = ''
324         while tail != 'BUPMUX':
325             tail += os.read(infd, 1024)
326             buf = tail[:-6]
327             tail = tail[-6:]
328             sys.stderr.write(buf)
329         self.infd = infd
330         self.reader = None
331         self.buf = None
332         self.closed = False
333
334     def write(self, data):
335         self._load_buf(0)
336         BaseConn.write(self, data)
337
338     def _next_packet(self, timeout):
339         if self.closed: return False
340         rl, wl, xl = select.select([self.infd], [], [], timeout)
341         if not rl: return False
342         assert(rl[0] == self.infd)
343         ns = ''.join(checked_reader(self.infd, 5))
344         n, fdw = struct.unpack('!IB', ns)
345         assert(n<=MAX_PACKET)
346         if fdw == 1:
347             self.reader = checked_reader(self.infd, n)
348         elif fdw == 2:
349             for buf in checked_reader(self.infd, n):
350                 sys.stderr.write(buf)
351         elif fdw == 3:
352             self.closed = True
353             debug2("DemuxConn: marked closed\n")
354         return True
355
356     def _load_buf(self, timeout):
357         if self.buf is not None:
358             return True
359         while not self.closed:
360             while not self.reader:
361                 if not self._next_packet(timeout):
362                     return False
363             try:
364                 self.buf = self.reader.next()
365                 return True
366             except StopIteration:
367                 self.reader = None
368         return False
369
370     def _read_parts(self, ix_fn):
371         while self._load_buf(None):
372             assert(self.buf is not None)
373             i = ix_fn(self.buf)
374             if i is None or i == len(self.buf):
375                 yv = self.buf
376                 self.buf = None
377             else:
378                 yv = self.buf[:i]
379                 self.buf = self.buf[i:]
380             yield yv
381             if i is not None:
382                 break
383
384     def _readline(self):
385         def find_eol(buf):
386             try:
387                 return buf.index('\n')+1
388             except ValueError:
389                 return None
390         return ''.join(self._read_parts(find_eol))
391
392     def _read(self, size):
393         csize = [size]
394         def until_size(buf): # Closes on csize
395             if len(buf) < csize[0]:
396                 csize[0] -= len(buf)
397                 return None
398             else:
399                 return csize[0]
400         return ''.join(self._read_parts(until_size))
401
402     def has_input(self):
403         return self._load_buf(0)
404
405 def linereader(f):
406     """Generate a list of input lines from 'f' without terminating newlines."""
407     while 1:
408         line = f.readline()
409         if not line:
410             break
411         yield line[:-1]
412
413
414 def chunkyreader(f, count = None):
415     """Generate a list of chunks of data read from 'f'.
416
417     If count is None, read until EOF is reached.
418
419     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
420     reached while reading, raise IOError.
421     """
422     if count != None:
423         while count > 0:
424             b = f.read(min(count, 65536))
425             if not b:
426                 raise IOError('EOF with %d bytes remaining' % count)
427             yield b
428             count -= len(b)
429     else:
430         while 1:
431             b = f.read(65536)
432             if not b: break
433             yield b
434
435
436 def slashappend(s):
437     """Append "/" to 's' if it doesn't aleady end in "/"."""
438     if s and not s.endswith('/'):
439         return s + '/'
440     else:
441         return s
442
443
444 def _mmap_do(f, sz, flags, prot):
445     if not sz:
446         st = os.fstat(f.fileno())
447         sz = st.st_size
448     if not sz:
449         # trying to open a zero-length map gives an error, but an empty
450         # string has all the same behaviour of a zero-length map, ie. it has
451         # no elements :)
452         return ''
453     map = mmap.mmap(f.fileno(), sz, flags, prot)
454     f.close()  # map will persist beyond file close
455     return map
456
457
458 def mmap_read(f, sz = 0):
459     """Create a read-only memory mapped region on file 'f'.
460
461     If sz is 0, the region will cover the entire file.
462     """
463     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
464
465
466 def mmap_readwrite(f, sz = 0):
467     """Create a read-write memory mapped region on file 'f'.
468
469     If sz is 0, the region will cover the entire file.
470     """
471     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
472
473
474 def parse_num(s):
475     """Parse data size information into a float number.
476
477     Here are some examples of conversions:
478         199.2k means 203981 bytes
479         1GB means 1073741824 bytes
480         2.1 tb means 2199023255552 bytes
481     """
482     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
483     if not g:
484         raise ValueError("can't parse %r as a number" % s)
485     (val, unit) = g.groups()
486     num = float(val)
487     unit = unit.lower()
488     if unit in ['t', 'tb']:
489         mult = 1024*1024*1024*1024
490     elif unit in ['g', 'gb']:
491         mult = 1024*1024*1024
492     elif unit in ['m', 'mb']:
493         mult = 1024*1024
494     elif unit in ['k', 'kb']:
495         mult = 1024
496     elif unit in ['', 'b']:
497         mult = 1
498     else:
499         raise ValueError("invalid unit %r in number %r" % (unit, s))
500     return int(num*mult)
501
502
503 def count(l):
504     """Count the number of elements in an iterator. (consumes the iterator)"""
505     return reduce(lambda x,y: x+1, l)
506
507
508 saved_errors = []
509 def add_error(e):
510     """Append an error message to the list of saved errors.
511
512     Once processing is able to stop and output the errors, the saved errors are
513     accessible in the module variable helpers.saved_errors.
514     """
515     saved_errors.append(e)
516     log('%-70s\n' % e)
517
518 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
519 def progress(s):
520     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
521     if istty:
522         log(s)
523
524
525 def handle_ctrl_c():
526     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
527
528     The new exception handler will make sure that bup will exit without an ugly
529     stacktrace when Ctrl-C is hit.
530     """
531     oldhook = sys.excepthook
532     def newhook(exctype, value, traceback):
533         if exctype == KeyboardInterrupt:
534             log('Interrupted.\n')
535         else:
536             return oldhook(exctype, value, traceback)
537     sys.excepthook = newhook
538
539
540 def columnate(l, prefix):
541     """Format elements of 'l' in columns with 'prefix' leading each line.
542
543     The number of columns is determined automatically based on the string
544     lengths.
545     """
546     if not l:
547         return ""
548     l = l[:]
549     clen = max(len(s) for s in l)
550     ncols = (tty_width() - len(prefix)) / (clen + 2)
551     if ncols <= 1:
552         ncols = 1
553         clen = 0
554     cols = []
555     while len(l) % ncols:
556         l.append('')
557     rows = len(l)/ncols
558     for s in range(0, len(l), rows):
559         cols.append(l[s:s+rows])
560     out = ''
561     for row in zip(*cols):
562         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
563     return out
564
565 def parse_date_or_fatal(str, fatal):
566     """Parses the given date or calls Option.fatal().
567     For now we expect a string that contains a float."""
568     try:
569         date = atof(str)
570     except ValueError, e:
571         raise fatal('invalid date format (should be a float): %r' % e)
572     else:
573         return date
574
575 def strip_path(prefix, path):
576     """Strips a given prefix from a path.
577
578     First both paths are normalized.
579
580     Raises an Exception if no prefix is given.
581     """
582     if prefix == None:
583         raise Exception('no path given')
584
585     normalized_prefix = os.path.realpath(prefix)
586     debug2("normalized_prefix: %s\n" % normalized_prefix)
587     normalized_path = os.path.realpath(path)
588     debug2("normalized_path: %s\n" % normalized_path)
589     if normalized_path.startswith(normalized_prefix):
590         return normalized_path[len(normalized_prefix):]
591     else:
592         return path
593
594 def strip_base_path(path, base_paths):
595     """Strips the base path from a given path.
596
597
598     Determines the base path for the given string and then strips it
599     using strip_path().
600     Iterates over all base_paths from long to short, to prevent that
601     a too short base_path is removed.
602     """
603     normalized_path = os.path.realpath(path)
604     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
605     for bp in sorted_base_paths:
606         if normalized_path.startswith(os.path.realpath(bp)):
607             return strip_path(bp, normalized_path)
608     return path
609
610 def graft_path(graft_points, path):
611     normalized_path = os.path.realpath(path)
612     for graft_point in graft_points:
613         old_prefix, new_prefix = graft_point
614         if normalized_path.startswith(old_prefix):
615             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
616     return normalized_path
617
618
619 # hashlib is only available in python 2.5 or higher, but the 'sha' module
620 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
621 # python 2.4 and above without any stupid warnings, so let's try using hashlib
622 # first, and downgrade if it fails.
623 try:
624     import hashlib
625 except ImportError:
626     import sha
627     Sha1 = sha.sha
628 else:
629     Sha1 = hashlib.sha1
630
631
632 def version_date():
633     """Format bup's version date string for output."""
634     return _version.DATE.split(' ')[0]
635
636 def version_commit():
637     """Get the commit hash of bup's current version."""
638     return _version.COMMIT
639
640 def version_tag():
641     """Format bup's version tag (the official version number).
642
643     When generated from a commit other than one pointed to with a tag, the
644     returned string will be "unknown-" followed by the first seven positions of
645     the commit hash.
646     """
647     names = _version.NAMES.strip()
648     assert(names[0] == '(')
649     assert(names[-1] == ')')
650     names = names[1:-1]
651     l = [n.strip() for n in names.split(',')]
652     for n in l:
653         if n.startswith('tag: bup-'):
654             return n[9:]
655     return 'unknown-%s' % _version.COMMIT[:7]