]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Add helpers.detect_fakeroot() and use it in relevant metadata tests.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
3 from bup import _version
4 import bup._helpers as _helpers
5
6 # This function should really be in helpers, not in bup.options.  But we
7 # want options.py to be standalone so people can include it in other projects.
8 from bup.options import _tty_width
9 tty_width = _tty_width
10
11
12 def atoi(s):
13     """Convert the string 's' to an integer. Return 0 if s is not a number."""
14     try:
15         return int(s or '0')
16     except ValueError:
17         return 0
18
19
20 def atof(s):
21     """Convert the string 's' to a float. Return 0 if s is not a number."""
22     try:
23         return float(s or '0')
24     except ValueError:
25         return 0
26
27
28 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
29
30
31 # Write (blockingly) to sockets that may or may not be in blocking mode.
32 # We need this because our stderr is sometimes eaten by subprocesses
33 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
34 # leading to race conditions.  Ick.  We'll do it the hard way.
35 def _hard_write(fd, buf):
36     while buf:
37         (r,w,x) = select.select([], [fd], [], None)
38         if not w:
39             raise IOError('select(fd) returned without being writable')
40         try:
41             sz = os.write(fd, buf)
42         except OSError, e:
43             if e.errno != errno.EAGAIN:
44                 raise
45         assert(sz >= 0)
46         buf = buf[sz:]
47
48 def log(s):
49     """Print a log message to stderr."""
50     sys.stdout.flush()
51     _hard_write(sys.stderr.fileno(), s)
52
53
54 def debug1(s):
55     if buglvl >= 1:
56         log(s)
57
58
59 def debug2(s):
60     if buglvl >= 2:
61         log(s)
62
63
64 def mkdirp(d, mode=None):
65     """Recursively create directories on path 'd'.
66
67     Unlike os.makedirs(), it doesn't raise an exception if the last element of
68     the path already exists.
69     """
70     try:
71         if mode:
72             os.makedirs(d, mode)
73         else:
74             os.makedirs(d)
75     except OSError, e:
76         if e.errno == errno.EEXIST:
77             pass
78         else:
79             raise
80
81
82 def next(it):
83     """Get the next item from an iterator, None if we reached the end."""
84     try:
85         return it.next()
86     except StopIteration:
87         return None
88
89
90 def unlink(f):
91     """Delete a file at path 'f' if it currently exists.
92
93     Unlike os.unlink(), does not throw an exception if the file didn't already
94     exist.
95     """
96     try:
97         os.unlink(f)
98     except OSError, e:
99         if e.errno == errno.ENOENT:
100             pass  # it doesn't exist, that's what you asked for
101
102
103 def readpipe(argv):
104     """Run a subprocess and return its output."""
105     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
106     r = p.stdout.read()
107     p.wait()
108     return r
109
110
111 def realpath(p):
112     """Get the absolute path of a file.
113
114     Behaves like os.path.realpath, but doesn't follow a symlink for the last
115     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
116     will follow symlinks in p's directory)
117     """
118     try:
119         st = os.lstat(p)
120     except OSError:
121         st = None
122     if st and stat.S_ISLNK(st.st_mode):
123         (dir, name) = os.path.split(p)
124         dir = os.path.realpath(dir)
125         out = os.path.join(dir, name)
126     else:
127         out = os.path.realpath(p)
128     #log('realpathing:%r,%r\n' % (p, out))
129     return out
130
131
132 def detect_fakeroot():
133     "Return True if we appear to be running under fakeroot."
134     return os.getenv("FAKEROOTKEY") != None
135
136
137 _username = None
138 def username():
139     """Get the user's login name."""
140     global _username
141     if not _username:
142         uid = os.getuid()
143         try:
144             _username = pwd.getpwuid(uid)[0]
145         except KeyError:
146             _username = 'user%d' % uid
147     return _username
148
149
150 _userfullname = None
151 def userfullname():
152     """Get the user's full name."""
153     global _userfullname
154     if not _userfullname:
155         uid = os.getuid()
156         try:
157             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
158         except KeyError:
159             _userfullname = 'user%d' % uid
160     return _userfullname
161
162
163 _hostname = None
164 def hostname():
165     """Get the FQDN of this machine."""
166     global _hostname
167     if not _hostname:
168         _hostname = socket.getfqdn()
169     return _hostname
170
171
172 _resource_path = None
173 def resource_path(subdir=''):
174     global _resource_path
175     if not _resource_path:
176         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
177     return os.path.join(_resource_path, subdir)
178
179 class NotOk(Exception):
180     pass
181
182 class Conn:
183     """A helper class for bup's client-server protocol."""
184     def __init__(self, inp, outp):
185         self.inp = inp
186         self.outp = outp
187
188     def read(self, size):
189         """Read 'size' bytes from input stream."""
190         self.outp.flush()
191         return self.inp.read(size)
192
193     def readline(self):
194         """Read from input stream until a newline is found."""
195         self.outp.flush()
196         return self.inp.readline()
197
198     def write(self, data):
199         """Write 'data' to output stream."""
200         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
201         self.outp.write(data)
202
203     def has_input(self):
204         """Return true if input stream is readable."""
205         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
206         if rl:
207             assert(rl[0] == self.inp.fileno())
208             return True
209         else:
210             return None
211
212     def ok(self):
213         """Indicate end of output from last sent command."""
214         self.write('\nok\n')
215
216     def error(self, s):
217         """Indicate server error to the client."""
218         s = re.sub(r'\s+', ' ', str(s))
219         self.write('\nerror %s\n' % s)
220
221     def _check_ok(self, onempty):
222         self.outp.flush()
223         rl = ''
224         for rl in linereader(self.inp):
225             #log('%d got line: %r\n' % (os.getpid(), rl))
226             if not rl:  # empty line
227                 continue
228             elif rl == 'ok':
229                 return None
230             elif rl.startswith('error '):
231                 #log('client: error: %s\n' % rl[6:])
232                 return NotOk(rl[6:])
233             else:
234                 onempty(rl)
235         raise Exception('server exited unexpectedly; see errors above')
236
237     def drain_and_check_ok(self):
238         """Remove all data for the current command from input stream."""
239         def onempty(rl):
240             pass
241         return self._check_ok(onempty)
242
243     def check_ok(self):
244         """Verify that server action completed successfully."""
245         def onempty(rl):
246             raise Exception('expected "ok", got %r' % rl)
247         return self._check_ok(onempty)
248
249
250 def linereader(f):
251     """Generate a list of input lines from 'f' without terminating newlines."""
252     while 1:
253         line = f.readline()
254         if not line:
255             break
256         yield line[:-1]
257
258
259 def chunkyreader(f, count = None):
260     """Generate a list of chunks of data read from 'f'.
261
262     If count is None, read until EOF is reached.
263
264     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
265     reached while reading, raise IOError.
266     """
267     if count != None:
268         while count > 0:
269             b = f.read(min(count, 65536))
270             if not b:
271                 raise IOError('EOF with %d bytes remaining' % count)
272             yield b
273             count -= len(b)
274     else:
275         while 1:
276             b = f.read(65536)
277             if not b: break
278             yield b
279
280
281 def slashappend(s):
282     """Append "/" to 's' if it doesn't aleady end in "/"."""
283     if s and not s.endswith('/'):
284         return s + '/'
285     else:
286         return s
287
288
289 def _mmap_do(f, sz, flags, prot):
290     if not sz:
291         st = os.fstat(f.fileno())
292         sz = st.st_size
293     map = mmap.mmap(f.fileno(), sz, flags, prot)
294     f.close()  # map will persist beyond file close
295     return map
296
297
298 def mmap_read(f, sz = 0):
299     """Create a read-only memory mapped region on file 'f'.
300
301     If sz is 0, the region will cover the entire file.
302     """
303     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
304
305
306 def mmap_readwrite(f, sz = 0):
307     """Create a read-write memory mapped region on file 'f'.
308
309     If sz is 0, the region will cover the entire file.
310     """
311     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
312
313
314 def parse_num(s):
315     """Parse data size information into a float number.
316
317     Here are some examples of conversions:
318         199.2k means 203981 bytes
319         1GB means 1073741824 bytes
320         2.1 tb means 2199023255552 bytes
321     """
322     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
323     if not g:
324         raise ValueError("can't parse %r as a number" % s)
325     (val, unit) = g.groups()
326     num = float(val)
327     unit = unit.lower()
328     if unit in ['t', 'tb']:
329         mult = 1024*1024*1024*1024
330     elif unit in ['g', 'gb']:
331         mult = 1024*1024*1024
332     elif unit in ['m', 'mb']:
333         mult = 1024*1024
334     elif unit in ['k', 'kb']:
335         mult = 1024
336     elif unit in ['', 'b']:
337         mult = 1
338     else:
339         raise ValueError("invalid unit %r in number %r" % (unit, s))
340     return int(num*mult)
341
342
343 def count(l):
344     """Count the number of elements in an iterator. (consumes the iterator)"""
345     return reduce(lambda x,y: x+1, l)
346
347
348 saved_errors = []
349 def add_error(e):
350     """Append an error message to the list of saved errors.
351
352     Once processing is able to stop and output the errors, the saved errors are
353     accessible in the module variable helpers.saved_errors.
354     """
355     saved_errors.append(e)
356     log('%-70s\n' % e)
357
358 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
359 def progress(s):
360     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
361     if istty:
362         log(s)
363
364
365 def handle_ctrl_c():
366     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
367
368     The new exception handler will make sure that bup will exit without an ugly
369     stacktrace when Ctrl-C is hit.
370     """
371     oldhook = sys.excepthook
372     def newhook(exctype, value, traceback):
373         if exctype == KeyboardInterrupt:
374             log('Interrupted.\n')
375         else:
376             return oldhook(exctype, value, traceback)
377     sys.excepthook = newhook
378
379
380 def columnate(l, prefix):
381     """Format elements of 'l' in columns with 'prefix' leading each line.
382
383     The number of columns is determined automatically based on the string
384     lengths.
385     """
386     if not l:
387         return ""
388     l = l[:]
389     clen = max(len(s) for s in l)
390     ncols = (tty_width() - len(prefix)) / (clen + 2)
391     if ncols <= 1:
392         ncols = 1
393         clen = 0
394     cols = []
395     while len(l) % ncols:
396         l.append('')
397     rows = len(l)/ncols
398     for s in range(0, len(l), rows):
399         cols.append(l[s:s+rows])
400     out = ''
401     for row in zip(*cols):
402         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
403     return out
404
405 def parse_date_or_fatal(str, fatal):
406     """Parses the given date or calls Option.fatal().
407     For now we expect a string that contains a float."""
408     try:
409         date = atof(str)
410     except ValueError, e:
411         raise fatal('invalid date format (should be a float): %r' % e)
412     else:
413         return date
414
415
416 def lutime(path, times):
417     if _helpers.utimensat:
418         atime = times[0]
419         mtime = times[1]
420         return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime),
421                                   _helpers.AT_SYMLINK_NOFOLLOW)
422     else:
423         return None
424
425
426 def utime(path, times):
427     atime = times[0]
428     mtime = times[1]
429     if _helpers.utimensat:
430         return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime),
431                                   0)
432     else:
433         os.utime(path, (atime[0] + atime[1] / 10e9,
434                         mtime[0] + mtime[1] / 10e9))
435
436
437 class stat_result():
438     pass
439
440
441 def lstat(path):
442     result = stat_result()
443     if _helpers.lstat:
444         st = _helpers.lstat(path)
445         (result.st_mode,
446          result.st_ino,
447          result.st_dev,
448          result.st_nlink,
449          result.st_uid,
450          result.st_gid,
451          result.st_rdev,
452          result.st_size,
453          result.st_atime,
454          result.st_mtime,
455          result.st_ctime) = st
456     else:
457         st = os.lstat(path)
458         result.st_mode = st.st_mode
459         result.st_ino = st.st_ino
460         result.st_dev = st.st_dev
461         result.st_nlink = st.st_nlink
462         result.st_uid = st.st_uid
463         result.st_gid = st.st_gid
464         result.st_rdev = st.st_rdev
465         result.st_size = st.st_size
466         result.st_atime = (math.trunc(st.st_atime),
467                            math.trunc(math.fmod(st.st_atime, 1) * 10**9))
468         result.st_mtime = (math.trunc(st.st_mtime),
469                            math.trunc(math.fmod(st.st_mtime, 1) * 10**9))
470         result.st_ctime = (math.trunc(st.st_ctime),
471                            math.trunc(math.fmod(st.st_ctime, 1) * 10**9))
472     return result
473
474
475 # hashlib is only available in python 2.5 or higher, but the 'sha' module
476 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
477 # python 2.4 and above without any stupid warnings, so let's try using hashlib
478 # first, and downgrade if it fails.
479 try:
480     import hashlib
481 except ImportError:
482     import sha
483     Sha1 = sha.sha
484 else:
485     Sha1 = hashlib.sha1
486
487
488 def version_date():
489     """Format bup's version date string for output."""
490     return _version.DATE.split(' ')[0]
491
492 def version_commit():
493     """Get the commit hash of bup's current version."""
494     return _version.COMMIT
495
496 def version_tag():
497     """Format bup's version tag (the official version number).
498
499     When generated from a commit other than one pointed to with a tag, the
500     returned string will be "unknown-" followed by the first seven positions of
501     the commit hash.
502     """
503     names = _version.NAMES.strip()
504     assert(names[0] == '(')
505     assert(names[-1] == ')')
506     names = names[1:-1]
507     l = [n.strip() for n in names.split(',')]
508     for n in l:
509         if n.startswith('tag: bup-'):
510             return n[9:]
511     return 'unknown-%s' % _version.COMMIT[:7]