]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Merge branches 'gf/ls', 'gf/tag', 'zz/import-rsnapshot' and 'bl/selfindex'
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
3 from bup import _version
4
5 # This function should really be in helpers, not in bup.options.  But we
6 # want options.py to be standalone so people can include it in other projects.
7 from bup.options import _tty_width
8 tty_width = _tty_width
9
10
11 def atoi(s):
12     """Convert the string 's' to an integer. Return 0 if s is not a number."""
13     try:
14         return int(s or '0')
15     except ValueError:
16         return 0
17
18
19 def atof(s):
20     """Convert the string 's' to a float. Return 0 if s is not a number."""
21     try:
22         return float(s or '0')
23     except ValueError:
24         return 0
25
26
27 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
28
29
30 # Write (blockingly) to sockets that may or may not be in blocking mode.
31 # We need this because our stderr is sometimes eaten by subprocesses
32 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
33 # leading to race conditions.  Ick.  We'll do it the hard way.
34 def _hard_write(fd, buf):
35     while buf:
36         (r,w,x) = select.select([], [fd], [], None)
37         if not w:
38             raise IOError('select(fd) returned without being writable')
39         try:
40             sz = os.write(fd, buf)
41         except OSError, e:
42             if e.errno != errno.EAGAIN:
43                 raise
44         assert(sz >= 0)
45         buf = buf[sz:]
46
47 def log(s):
48     """Print a log message to stderr."""
49     sys.stdout.flush()
50     _hard_write(sys.stderr.fileno(), s)
51
52
53 def debug1(s):
54     if buglvl >= 1:
55         log(s)
56
57
58 def debug2(s):
59     if buglvl >= 2:
60         log(s)
61
62
63 def mkdirp(d, mode=None):
64     """Recursively create directories on path 'd'.
65
66     Unlike os.makedirs(), it doesn't raise an exception if the last element of
67     the path already exists.
68     """
69     try:
70         if mode:
71             os.makedirs(d, mode)
72         else:
73             os.makedirs(d)
74     except OSError, e:
75         if e.errno == errno.EEXIST:
76             pass
77         else:
78             raise
79
80
81 def next(it):
82     """Get the next item from an iterator, None if we reached the end."""
83     try:
84         return it.next()
85     except StopIteration:
86         return None
87
88
89 def unlink(f):
90     """Delete a file at path 'f' if it currently exists.
91
92     Unlike os.unlink(), does not throw an exception if the file didn't already
93     exist.
94     """
95     try:
96         os.unlink(f)
97     except OSError, e:
98         if e.errno == errno.ENOENT:
99             pass  # it doesn't exist, that's what you asked for
100
101
102 def readpipe(argv):
103     """Run a subprocess and return its output."""
104     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
105     r = p.stdout.read()
106     p.wait()
107     return r
108
109
110 def realpath(p):
111     """Get the absolute path of a file.
112
113     Behaves like os.path.realpath, but doesn't follow a symlink for the last
114     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
115     will follow symlinks in p's directory)
116     """
117     try:
118         st = os.lstat(p)
119     except OSError:
120         st = None
121     if st and stat.S_ISLNK(st.st_mode):
122         (dir, name) = os.path.split(p)
123         dir = os.path.realpath(dir)
124         out = os.path.join(dir, name)
125     else:
126         out = os.path.realpath(p)
127     #log('realpathing:%r,%r\n' % (p, out))
128     return out
129
130
131 _username = None
132 def username():
133     """Get the user's login name."""
134     global _username
135     if not _username:
136         uid = os.getuid()
137         try:
138             _username = pwd.getpwuid(uid)[0]
139         except KeyError:
140             _username = 'user%d' % uid
141     return _username
142
143
144 _userfullname = None
145 def userfullname():
146     """Get the user's full name."""
147     global _userfullname
148     if not _userfullname:
149         uid = os.getuid()
150         try:
151             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
152         except KeyError:
153             _userfullname = 'user%d' % uid
154     return _userfullname
155
156
157 _hostname = None
158 def hostname():
159     """Get the FQDN of this machine."""
160     global _hostname
161     if not _hostname:
162         _hostname = socket.getfqdn()
163     return _hostname
164
165
166 _resource_path = None
167 def resource_path(subdir=''):
168     global _resource_path
169     if not _resource_path:
170         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
171     return os.path.join(_resource_path, subdir)
172
173 class NotOk(Exception):
174     pass
175
176 class Conn:
177     """A helper class for bup's client-server protocol."""
178     def __init__(self, inp, outp):
179         self.inp = inp
180         self.outp = outp
181
182     def read(self, size):
183         """Read 'size' bytes from input stream."""
184         self.outp.flush()
185         return self.inp.read(size)
186
187     def readline(self):
188         """Read from input stream until a newline is found."""
189         self.outp.flush()
190         return self.inp.readline()
191
192     def write(self, data):
193         """Write 'data' to output stream."""
194         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
195         self.outp.write(data)
196
197     def has_input(self):
198         """Return true if input stream is readable."""
199         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
200         if rl:
201             assert(rl[0] == self.inp.fileno())
202             return True
203         else:
204             return None
205
206     def ok(self):
207         """Indicate end of output from last sent command."""
208         self.write('\nok\n')
209
210     def error(self, s):
211         """Indicate server error to the client."""
212         s = re.sub(r'\s+', ' ', str(s))
213         self.write('\nerror %s\n' % s)
214
215     def _check_ok(self, onempty):
216         self.outp.flush()
217         rl = ''
218         for rl in linereader(self.inp):
219             #log('%d got line: %r\n' % (os.getpid(), rl))
220             if not rl:  # empty line
221                 continue
222             elif rl == 'ok':
223                 return None
224             elif rl.startswith('error '):
225                 #log('client: error: %s\n' % rl[6:])
226                 return NotOk(rl[6:])
227             else:
228                 onempty(rl)
229         raise Exception('server exited unexpectedly; see errors above')
230
231     def drain_and_check_ok(self):
232         """Remove all data for the current command from input stream."""
233         def onempty(rl):
234             pass
235         return self._check_ok(onempty)
236
237     def check_ok(self):
238         """Verify that server action completed successfully."""
239         def onempty(rl):
240             raise Exception('expected "ok", got %r' % rl)
241         return self._check_ok(onempty)
242
243
244 def linereader(f):
245     """Generate a list of input lines from 'f' without terminating newlines."""
246     while 1:
247         line = f.readline()
248         if not line:
249             break
250         yield line[:-1]
251
252
253 def chunkyreader(f, count = None):
254     """Generate a list of chunks of data read from 'f'.
255
256     If count is None, read until EOF is reached.
257
258     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
259     reached while reading, raise IOError.
260     """
261     if count != None:
262         while count > 0:
263             b = f.read(min(count, 65536))
264             if not b:
265                 raise IOError('EOF with %d bytes remaining' % count)
266             yield b
267             count -= len(b)
268     else:
269         while 1:
270             b = f.read(65536)
271             if not b: break
272             yield b
273
274
275 def slashappend(s):
276     """Append "/" to 's' if it doesn't aleady end in "/"."""
277     if s and not s.endswith('/'):
278         return s + '/'
279     else:
280         return s
281
282
283 def _mmap_do(f, sz, flags, prot):
284     if not sz:
285         st = os.fstat(f.fileno())
286         sz = st.st_size
287     if not sz:
288         # trying to open a zero-length map gives an error, but an empty
289         # string has all the same behaviour of a zero-length map, ie. it has
290         # no elements :)
291         return ''
292     map = mmap.mmap(f.fileno(), sz, flags, prot)
293     f.close()  # map will persist beyond file close
294     return map
295
296
297 def mmap_read(f, sz = 0):
298     """Create a read-only memory mapped region on file 'f'.
299
300     If sz is 0, the region will cover the entire file.
301     """
302     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
303
304
305 def mmap_readwrite(f, sz = 0):
306     """Create a read-write memory mapped region on file 'f'.
307
308     If sz is 0, the region will cover the entire file.
309     """
310     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
311
312
313 def parse_num(s):
314     """Parse data size information into a float number.
315
316     Here are some examples of conversions:
317         199.2k means 203981 bytes
318         1GB means 1073741824 bytes
319         2.1 tb means 2199023255552 bytes
320     """
321     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
322     if not g:
323         raise ValueError("can't parse %r as a number" % s)
324     (val, unit) = g.groups()
325     num = float(val)
326     unit = unit.lower()
327     if unit in ['t', 'tb']:
328         mult = 1024*1024*1024*1024
329     elif unit in ['g', 'gb']:
330         mult = 1024*1024*1024
331     elif unit in ['m', 'mb']:
332         mult = 1024*1024
333     elif unit in ['k', 'kb']:
334         mult = 1024
335     elif unit in ['', 'b']:
336         mult = 1
337     else:
338         raise ValueError("invalid unit %r in number %r" % (unit, s))
339     return int(num*mult)
340
341
342 def count(l):
343     """Count the number of elements in an iterator. (consumes the iterator)"""
344     return reduce(lambda x,y: x+1, l)
345
346
347 saved_errors = []
348 def add_error(e):
349     """Append an error message to the list of saved errors.
350
351     Once processing is able to stop and output the errors, the saved errors are
352     accessible in the module variable helpers.saved_errors.
353     """
354     saved_errors.append(e)
355     log('%-70s\n' % e)
356
357 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
358 def progress(s):
359     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
360     if istty:
361         log(s)
362
363
364 def handle_ctrl_c():
365     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
366
367     The new exception handler will make sure that bup will exit without an ugly
368     stacktrace when Ctrl-C is hit.
369     """
370     oldhook = sys.excepthook
371     def newhook(exctype, value, traceback):
372         if exctype == KeyboardInterrupt:
373             log('Interrupted.\n')
374         else:
375             return oldhook(exctype, value, traceback)
376     sys.excepthook = newhook
377
378
379 def columnate(l, prefix):
380     """Format elements of 'l' in columns with 'prefix' leading each line.
381
382     The number of columns is determined automatically based on the string
383     lengths.
384     """
385     if not l:
386         return ""
387     l = l[:]
388     clen = max(len(s) for s in l)
389     ncols = (tty_width() - len(prefix)) / (clen + 2)
390     if ncols <= 1:
391         ncols = 1
392         clen = 0
393     cols = []
394     while len(l) % ncols:
395         l.append('')
396     rows = len(l)/ncols
397     for s in range(0, len(l), rows):
398         cols.append(l[s:s+rows])
399     out = ''
400     for row in zip(*cols):
401         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
402     return out
403
404 def parse_date_or_fatal(str, fatal):
405     """Parses the given date or calls Option.fatal().
406     For now we expect a string that contains a float."""
407     try:
408         date = atof(str)
409     except ValueError, e:
410         raise fatal('invalid date format (should be a float): %r' % e)
411     else:
412         return date
413
414 def strip_path(prefix, path):
415     """Strips a given prefix from a path.
416
417     First both paths are normalized.
418
419     Raises an Exception if no prefix is given.
420     """
421     if prefix == None:
422         raise Exception('no path given')
423
424     normalized_prefix = realpath(prefix)
425     print "normalized_prefix: " + normalized_prefix
426     normalized_path = realpath(path)
427     print "normalized_path: " + normalized_path
428     if normalized_path.startswith(normalized_prefix):
429         return normalized_path[len(normalized_prefix):]
430     else:
431         return path
432
433 def strip_base_path(path, base_paths):
434     """Strips the base path from a given path.
435
436     Determines the base path for the given string and the strips it
437     using strip_path().
438     Iterates over all base_paths from long to short, to prevent that
439     a too short base_path is removed.
440     """
441     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
442     for bp in sorted_base_paths:
443         if path.startswith(realpath(bp)):
444             return strip_path(bp, path)
445     return path
446
447
448 # hashlib is only available in python 2.5 or higher, but the 'sha' module
449 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
450 # python 2.4 and above without any stupid warnings, so let's try using hashlib
451 # first, and downgrade if it fails.
452 try:
453     import hashlib
454 except ImportError:
455     import sha
456     Sha1 = sha.sha
457 else:
458     Sha1 = hashlib.sha1
459
460
461 def version_date():
462     """Format bup's version date string for output."""
463     return _version.DATE.split(' ')[0]
464
465 def version_commit():
466     """Get the commit hash of bup's current version."""
467     return _version.COMMIT
468
469 def version_tag():
470     """Format bup's version tag (the official version number).
471
472     When generated from a commit other than one pointed to with a tag, the
473     returned string will be "unknown-" followed by the first seven positions of
474     the commit hash.
475     """
476     names = _version.NAMES.strip()
477     assert(names[0] == '(')
478     assert(names[-1] == ')')
479     names = names[1:-1]
480     l = [n.strip() for n in names.split(',')]
481     for n in l:
482         if n.startswith('tag: bup-'):
483             return n[9:]
484     return 'unknown-%s' % _version.COMMIT[:7]