]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Adds --graft option to bup save.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2
3 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
4 from bup import _version
5
6 # This function should really be in helpers, not in bup.options.  But we
7 # want options.py to be standalone so people can include it in other projects.
8 from bup.options import _tty_width
9 tty_width = _tty_width
10
11
12 def atoi(s):
13     """Convert the string 's' to an integer. Return 0 if s is not a number."""
14     try:
15         return int(s or '0')
16     except ValueError:
17         return 0
18
19
20 def atof(s):
21     """Convert the string 's' to a float. Return 0 if s is not a number."""
22     try:
23         return float(s or '0')
24     except ValueError:
25         return 0
26
27
28 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
29
30
31 # Write (blockingly) to sockets that may or may not be in blocking mode.
32 # We need this because our stderr is sometimes eaten by subprocesses
33 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
34 # leading to race conditions.  Ick.  We'll do it the hard way.
35 def _hard_write(fd, buf):
36     while buf:
37         (r,w,x) = select.select([], [fd], [], None)
38         if not w:
39             raise IOError('select(fd) returned without being writable')
40         try:
41             sz = os.write(fd, buf)
42         except OSError, e:
43             if e.errno != errno.EAGAIN:
44                 raise
45         assert(sz >= 0)
46         buf = buf[sz:]
47
48 def log(s):
49     """Print a log message to stderr."""
50     sys.stdout.flush()
51     _hard_write(sys.stderr.fileno(), s)
52
53
54 def debug1(s):
55     if buglvl >= 1:
56         log(s)
57
58
59 def debug2(s):
60     if buglvl >= 2:
61         log(s)
62
63
64 def mkdirp(d, mode=None):
65     """Recursively create directories on path 'd'.
66
67     Unlike os.makedirs(), it doesn't raise an exception if the last element of
68     the path already exists.
69     """
70     try:
71         if mode:
72             os.makedirs(d, mode)
73         else:
74             os.makedirs(d)
75     except OSError, e:
76         if e.errno == errno.EEXIST:
77             pass
78         else:
79             raise
80
81
82 def next(it):
83     """Get the next item from an iterator, None if we reached the end."""
84     try:
85         return it.next()
86     except StopIteration:
87         return None
88
89
90 def unlink(f):
91     """Delete a file at path 'f' if it currently exists.
92
93     Unlike os.unlink(), does not throw an exception if the file didn't already
94     exist.
95     """
96     try:
97         os.unlink(f)
98     except OSError, e:
99         if e.errno == errno.ENOENT:
100             pass  # it doesn't exist, that's what you asked for
101
102
103 def readpipe(argv):
104     """Run a subprocess and return its output."""
105     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
106     r = p.stdout.read()
107     p.wait()
108     return r
109
110
111 def realpath(p):
112     """Get the absolute path of a file.
113
114     Behaves like os.path.realpath, but doesn't follow a symlink for the last
115     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
116     will follow symlinks in p's directory)
117     """
118     try:
119         st = os.lstat(p)
120     except OSError:
121         st = None
122     if st and stat.S_ISLNK(st.st_mode):
123         (dir, name) = os.path.split(p)
124         dir = os.path.realpath(dir)
125         out = os.path.join(dir, name)
126     else:
127         out = os.path.realpath(p)
128     #log('realpathing:%r,%r\n' % (p, out))
129     return out
130
131
132 _username = None
133 def username():
134     """Get the user's login name."""
135     global _username
136     if not _username:
137         uid = os.getuid()
138         try:
139             _username = pwd.getpwuid(uid)[0]
140         except KeyError:
141             _username = 'user%d' % uid
142     return _username
143
144
145 _userfullname = None
146 def userfullname():
147     """Get the user's full name."""
148     global _userfullname
149     if not _userfullname:
150         uid = os.getuid()
151         try:
152             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
153         except KeyError:
154             _userfullname = 'user%d' % uid
155     return _userfullname
156
157
158 _hostname = None
159 def hostname():
160     """Get the FQDN of this machine."""
161     global _hostname
162     if not _hostname:
163         _hostname = socket.getfqdn()
164     return _hostname
165
166
167 _resource_path = None
168 def resource_path(subdir=''):
169     global _resource_path
170     if not _resource_path:
171         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
172     return os.path.join(_resource_path, subdir)
173
174 class NotOk(Exception):
175     pass
176
177 class Conn:
178     """A helper class for bup's client-server protocol."""
179     def __init__(self, inp, outp):
180         self.inp = inp
181         self.outp = outp
182
183     def read(self, size):
184         """Read 'size' bytes from input stream."""
185         self.outp.flush()
186         return self.inp.read(size)
187
188     def readline(self):
189         """Read from input stream until a newline is found."""
190         self.outp.flush()
191         return self.inp.readline()
192
193     def write(self, data):
194         """Write 'data' to output stream."""
195         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
196         self.outp.write(data)
197
198     def has_input(self):
199         """Return true if input stream is readable."""
200         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
201         if rl:
202             assert(rl[0] == self.inp.fileno())
203             return True
204         else:
205             return None
206
207     def ok(self):
208         """Indicate end of output from last sent command."""
209         self.write('\nok\n')
210
211     def error(self, s):
212         """Indicate server error to the client."""
213         s = re.sub(r'\s+', ' ', str(s))
214         self.write('\nerror %s\n' % s)
215
216     def _check_ok(self, onempty):
217         self.outp.flush()
218         rl = ''
219         for rl in linereader(self.inp):
220             #log('%d got line: %r\n' % (os.getpid(), rl))
221             if not rl:  # empty line
222                 continue
223             elif rl == 'ok':
224                 return None
225             elif rl.startswith('error '):
226                 #log('client: error: %s\n' % rl[6:])
227                 return NotOk(rl[6:])
228             else:
229                 onempty(rl)
230         raise Exception('server exited unexpectedly; see errors above')
231
232     def drain_and_check_ok(self):
233         """Remove all data for the current command from input stream."""
234         def onempty(rl):
235             pass
236         return self._check_ok(onempty)
237
238     def check_ok(self):
239         """Verify that server action completed successfully."""
240         def onempty(rl):
241             raise Exception('expected "ok", got %r' % rl)
242         return self._check_ok(onempty)
243
244
245 def linereader(f):
246     """Generate a list of input lines from 'f' without terminating newlines."""
247     while 1:
248         line = f.readline()
249         if not line:
250             break
251         yield line[:-1]
252
253
254 def chunkyreader(f, count = None):
255     """Generate a list of chunks of data read from 'f'.
256
257     If count is None, read until EOF is reached.
258
259     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
260     reached while reading, raise IOError.
261     """
262     if count != None:
263         while count > 0:
264             b = f.read(min(count, 65536))
265             if not b:
266                 raise IOError('EOF with %d bytes remaining' % count)
267             yield b
268             count -= len(b)
269     else:
270         while 1:
271             b = f.read(65536)
272             if not b: break
273             yield b
274
275
276 def slashappend(s):
277     """Append "/" to 's' if it doesn't aleady end in "/"."""
278     if s and not s.endswith('/'):
279         return s + '/'
280     else:
281         return s
282
283
284 def _mmap_do(f, sz, flags, prot):
285     if not sz:
286         st = os.fstat(f.fileno())
287         sz = st.st_size
288     if not sz:
289         # trying to open a zero-length map gives an error, but an empty
290         # string has all the same behaviour of a zero-length map, ie. it has
291         # no elements :)
292         return ''
293     map = mmap.mmap(f.fileno(), sz, flags, prot)
294     f.close()  # map will persist beyond file close
295     return map
296
297
298 def mmap_read(f, sz = 0):
299     """Create a read-only memory mapped region on file 'f'.
300
301     If sz is 0, the region will cover the entire file.
302     """
303     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
304
305
306 def mmap_readwrite(f, sz = 0):
307     """Create a read-write memory mapped region on file 'f'.
308
309     If sz is 0, the region will cover the entire file.
310     """
311     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
312
313
314 def parse_num(s):
315     """Parse data size information into a float number.
316
317     Here are some examples of conversions:
318         199.2k means 203981 bytes
319         1GB means 1073741824 bytes
320         2.1 tb means 2199023255552 bytes
321     """
322     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
323     if not g:
324         raise ValueError("can't parse %r as a number" % s)
325     (val, unit) = g.groups()
326     num = float(val)
327     unit = unit.lower()
328     if unit in ['t', 'tb']:
329         mult = 1024*1024*1024*1024
330     elif unit in ['g', 'gb']:
331         mult = 1024*1024*1024
332     elif unit in ['m', 'mb']:
333         mult = 1024*1024
334     elif unit in ['k', 'kb']:
335         mult = 1024
336     elif unit in ['', 'b']:
337         mult = 1
338     else:
339         raise ValueError("invalid unit %r in number %r" % (unit, s))
340     return int(num*mult)
341
342
343 def count(l):
344     """Count the number of elements in an iterator. (consumes the iterator)"""
345     return reduce(lambda x,y: x+1, l)
346
347
348 saved_errors = []
349 def add_error(e):
350     """Append an error message to the list of saved errors.
351
352     Once processing is able to stop and output the errors, the saved errors are
353     accessible in the module variable helpers.saved_errors.
354     """
355     saved_errors.append(e)
356     log('%-70s\n' % e)
357
358 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
359 def progress(s):
360     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
361     if istty:
362         log(s)
363
364
365 def handle_ctrl_c():
366     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
367
368     The new exception handler will make sure that bup will exit without an ugly
369     stacktrace when Ctrl-C is hit.
370     """
371     oldhook = sys.excepthook
372     def newhook(exctype, value, traceback):
373         if exctype == KeyboardInterrupt:
374             log('Interrupted.\n')
375         else:
376             return oldhook(exctype, value, traceback)
377     sys.excepthook = newhook
378
379
380 def columnate(l, prefix):
381     """Format elements of 'l' in columns with 'prefix' leading each line.
382
383     The number of columns is determined automatically based on the string
384     lengths.
385     """
386     if not l:
387         return ""
388     l = l[:]
389     clen = max(len(s) for s in l)
390     ncols = (tty_width() - len(prefix)) / (clen + 2)
391     if ncols <= 1:
392         ncols = 1
393         clen = 0
394     cols = []
395     while len(l) % ncols:
396         l.append('')
397     rows = len(l)/ncols
398     for s in range(0, len(l), rows):
399         cols.append(l[s:s+rows])
400     out = ''
401     for row in zip(*cols):
402         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
403     return out
404
405 def parse_date_or_fatal(str, fatal):
406     """Parses the given date or calls Option.fatal().
407     For now we expect a string that contains a float."""
408     try:
409         date = atof(str)
410     except ValueError, e:
411         raise fatal('invalid date format (should be a float): %r' % e)
412     else:
413         return date
414
415 def strip_path(prefix, path):
416     """Strips a given prefix from a path.
417
418     First both paths are normalized.
419
420     Raises an Exception if no prefix is given.
421     """
422     if prefix == None:
423         raise Exception('no path given')
424
425     normalized_prefix = realpath(prefix)
426     debug2("normalized_prefix: %s\n" % normalized_prefix)
427     normalized_path = realpath(path)
428     debug2("normalized_path: %s\n" % normalized_path)
429     if normalized_path.startswith(normalized_prefix):
430         return normalized_path[len(normalized_prefix):]
431     else:
432         return path
433
434 def strip_base_path(path, base_paths):
435     """Strips the base path from a given path.
436
437
438     Determines the base path for the given string and the strips it
439     using strip_path().
440     Iterates over all base_paths from long to short, to prevent that
441     a too short base_path is removed.
442     """
443     sorted_base_paths = sorted(base_paths, key=len, reverse=True)
444     for bp in sorted_base_paths:
445         if path.startswith(realpath(bp)):
446             return strip_path(bp, path)
447     return path
448
449 def graft_path(graft_points, path):
450     normalized_path = realpath(path)
451     for graft_point in graft_points:
452         old_prefix, new_prefix = graft_point
453         if normalized_path.startswith(old_prefix):
454             return re.sub(r'^' + old_prefix, new_prefix, normalized_path)
455     return normalized_path
456
457
458 # hashlib is only available in python 2.5 or higher, but the 'sha' module
459 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
460 # python 2.4 and above without any stupid warnings, so let's try using hashlib
461 # first, and downgrade if it fails.
462 try:
463     import hashlib
464 except ImportError:
465     import sha
466     Sha1 = sha.sha
467 else:
468     Sha1 = hashlib.sha1
469
470
471 def version_date():
472     """Format bup's version date string for output."""
473     return _version.DATE.split(' ')[0]
474
475 def version_commit():
476     """Get the commit hash of bup's current version."""
477     return _version.COMMIT
478
479 def version_tag():
480     """Format bup's version tag (the official version number).
481
482     When generated from a commit other than one pointed to with a tag, the
483     returned string will be "unknown-" followed by the first seven positions of
484     the commit hash.
485     """
486     names = _version.NAMES.strip()
487     assert(names[0] == '(')
488     assert(names[-1] == ')')
489     names = names[1:-1]
490     l = [n.strip() for n in names.split(',')]
491     for n in l:
492         if n.startswith('tag: bup-'):
493             return n[9:]
494     return 'unknown-%s' % _version.COMMIT[:7]