]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
f911b046151123c8c31d85c6b41d9dae7dd97e7e
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
3 from bup import _version
4
5 # This function should really be in helpers, not in bup.options.  But we
6 # want options.py to be standalone so people can include it in other projects.
7 from bup.options import _tty_width
8 tty_width = _tty_width
9
10
11 def atoi(s):
12     """Convert the string 's' to an integer. Return 0 if s is not a number."""
13     try:
14         return int(s or '0')
15     except ValueError:
16         return 0
17
18
19 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
20
21
22 # Write (blockingly) to sockets that may or may not be in blocking mode.
23 # We need this because our stderr is sometimes eaten by subprocesses
24 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
25 # leading to race conditions.  Ick.  We'll do it the hard way.
26 def _hard_write(fd, buf):
27     while buf:
28         (r,w,x) = select.select([], [fd], [], None)
29         if not w:
30             raise IOError('select(fd) returned without being writable')
31         try:
32             sz = os.write(fd, buf)
33         except OSError, e:
34             if e.errno != errno.EAGAIN:
35                 raise
36         assert(sz >= 0)
37         buf = buf[sz:]
38
39 def log(s):
40     """Print a log message to stderr."""
41     sys.stdout.flush()
42     _hard_write(sys.stderr.fileno(), s)
43
44
45 def debug1(s):
46     if buglvl >= 1:
47         log(s)
48
49
50 def debug2(s):
51     if buglvl >= 2:
52         log(s)
53
54
55 def mkdirp(d, mode=None):
56     """Recursively create directories on path 'd'.
57
58     Unlike os.makedirs(), it doesn't raise an exception if the last element of
59     the path already exists.
60     """
61     try:
62         if mode:
63             os.makedirs(d, mode)
64         else:
65             os.makedirs(d)
66     except OSError, e:
67         if e.errno == errno.EEXIST:
68             pass
69         else:
70             raise
71
72
73 def next(it):
74     """Get the next item from an iterator, None if we reached the end."""
75     try:
76         return it.next()
77     except StopIteration:
78         return None
79
80
81 def unlink(f):
82     """Delete a file at path 'f' if it currently exists.
83
84     Unlike os.unlink(), does not throw an exception if the file didn't already
85     exist.
86     """
87     try:
88         os.unlink(f)
89     except OSError, e:
90         if e.errno == errno.ENOENT:
91             pass  # it doesn't exist, that's what you asked for
92
93
94 def readpipe(argv):
95     """Run a subprocess and return its output."""
96     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
97     r = p.stdout.read()
98     p.wait()
99     return r
100
101
102 def realpath(p):
103     """Get the absolute path of a file.
104
105     Behaves like os.path.realpath, but doesn't follow a symlink for the last
106     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
107     will follow symlinks in p's directory)
108     """
109     try:
110         st = os.lstat(p)
111     except OSError:
112         st = None
113     if st and stat.S_ISLNK(st.st_mode):
114         (dir, name) = os.path.split(p)
115         dir = os.path.realpath(dir)
116         out = os.path.join(dir, name)
117     else:
118         out = os.path.realpath(p)
119     #log('realpathing:%r,%r\n' % (p, out))
120     return out
121
122
123 _username = None
124 def username():
125     """Get the user's login name."""
126     global _username
127     if not _username:
128         uid = os.getuid()
129         try:
130             _username = pwd.getpwuid(uid)[0]
131         except KeyError:
132             _username = 'user%d' % uid
133     return _username
134
135
136 _userfullname = None
137 def userfullname():
138     """Get the user's full name."""
139     global _userfullname
140     if not _userfullname:
141         uid = os.getuid()
142         try:
143             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
144         except KeyError:
145             _userfullname = 'user%d' % uid
146     return _userfullname
147
148
149 _hostname = None
150 def hostname():
151     """Get the FQDN of this machine."""
152     global _hostname
153     if not _hostname:
154         _hostname = socket.getfqdn()
155     return _hostname
156
157
158 _resource_path = None
159 def resource_path(subdir=''):
160     global _resource_path
161     if not _resource_path:
162         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
163     return os.path.join(_resource_path, subdir)
164
165 class NotOk(Exception):
166     pass
167
168 class Conn:
169     """A helper class for bup's client-server protocol."""
170     def __init__(self, inp, outp):
171         self.inp = inp
172         self.outp = outp
173
174     def read(self, size):
175         """Read 'size' bytes from input stream."""
176         self.outp.flush()
177         return self.inp.read(size)
178
179     def readline(self):
180         """Read from input stream until a newline is found."""
181         self.outp.flush()
182         return self.inp.readline()
183
184     def write(self, data):
185         """Write 'data' to output stream."""
186         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
187         self.outp.write(data)
188
189     def has_input(self):
190         """Return true if input stream is readable."""
191         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
192         if rl:
193             assert(rl[0] == self.inp.fileno())
194             return True
195         else:
196             return None
197
198     def ok(self):
199         """Indicate end of output from last sent command."""
200         self.write('\nok\n')
201
202     def error(self, s):
203         """Indicate server error to the client."""
204         s = re.sub(r'\s+', ' ', str(s))
205         self.write('\nerror %s\n' % s)
206
207     def _check_ok(self, onempty):
208         self.outp.flush()
209         rl = ''
210         for rl in linereader(self.inp):
211             #log('%d got line: %r\n' % (os.getpid(), rl))
212             if not rl:  # empty line
213                 continue
214             elif rl == 'ok':
215                 return None
216             elif rl.startswith('error '):
217                 #log('client: error: %s\n' % rl[6:])
218                 return NotOk(rl[6:])
219             else:
220                 onempty(rl)
221         raise Exception('server exited unexpectedly; see errors above')
222
223     def drain_and_check_ok(self):
224         """Remove all data for the current command from input stream."""
225         def onempty(rl):
226             pass
227         return self._check_ok(onempty)
228
229     def check_ok(self):
230         """Verify that server action completed successfully."""
231         def onempty(rl):
232             raise Exception('expected "ok", got %r' % rl)
233         return self._check_ok(onempty)
234
235
236 def linereader(f):
237     """Generate a list of input lines from 'f' without terminating newlines."""
238     while 1:
239         line = f.readline()
240         if not line:
241             break
242         yield line[:-1]
243
244
245 def chunkyreader(f, count = None):
246     """Generate a list of chunks of data read from 'f'.
247
248     If count is None, read until EOF is reached.
249
250     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
251     reached while reading, raise IOError.
252     """
253     if count != None:
254         while count > 0:
255             b = f.read(min(count, 65536))
256             if not b:
257                 raise IOError('EOF with %d bytes remaining' % count)
258             yield b
259             count -= len(b)
260     else:
261         while 1:
262             b = f.read(65536)
263             if not b: break
264             yield b
265
266
267 def slashappend(s):
268     """Append "/" to 's' if it doesn't aleady end in "/"."""
269     if s and not s.endswith('/'):
270         return s + '/'
271     else:
272         return s
273
274
275 def _mmap_do(f, sz, flags, prot):
276     if not sz:
277         st = os.fstat(f.fileno())
278         sz = st.st_size
279     map = mmap.mmap(f.fileno(), sz, flags, prot)
280     f.close()  # map will persist beyond file close
281     return map
282
283
284 def mmap_read(f, sz = 0):
285     """Create a read-only memory mapped region on file 'f'.
286
287     If sz is 0, the region will cover the entire file.
288     """
289     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
290
291
292 def mmap_readwrite(f, sz = 0):
293     """Create a read-write memory mapped region on file 'f'.
294
295     If sz is 0, the region will cover the entire file.
296     """
297     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
298
299
300 def parse_num(s):
301     """Parse data size information into a float number.
302
303     Here are some examples of conversions:
304         199.2k means 203981 bytes
305         1GB means 1073741824 bytes
306         2.1 tb means 2199023255552 bytes
307     """
308     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
309     if not g:
310         raise ValueError("can't parse %r as a number" % s)
311     (val, unit) = g.groups()
312     num = float(val)
313     unit = unit.lower()
314     if unit in ['t', 'tb']:
315         mult = 1024*1024*1024*1024
316     elif unit in ['g', 'gb']:
317         mult = 1024*1024*1024
318     elif unit in ['m', 'mb']:
319         mult = 1024*1024
320     elif unit in ['k', 'kb']:
321         mult = 1024
322     elif unit in ['', 'b']:
323         mult = 1
324     else:
325         raise ValueError("invalid unit %r in number %r" % (unit, s))
326     return int(num*mult)
327
328
329 def count(l):
330     """Count the number of elements in an iterator. (consumes the iterator)"""
331     return reduce(lambda x,y: x+1, l)
332
333
334 saved_errors = []
335 def add_error(e):
336     """Append an error message to the list of saved errors.
337
338     Once processing is able to stop and output the errors, the saved errors are
339     accessible in the module variable helpers.saved_errors.
340     """
341     saved_errors.append(e)
342     log('%-70s\n' % e)
343
344 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
345 def progress(s):
346     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
347     if istty:
348         log(s)
349
350
351 def handle_ctrl_c():
352     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
353
354     The new exception handler will make sure that bup will exit without an ugly
355     stacktrace when Ctrl-C is hit.
356     """
357     oldhook = sys.excepthook
358     def newhook(exctype, value, traceback):
359         if exctype == KeyboardInterrupt:
360             log('Interrupted.\n')
361         else:
362             return oldhook(exctype, value, traceback)
363     sys.excepthook = newhook
364
365
366 def columnate(l, prefix):
367     """Format elements of 'l' in columns with 'prefix' leading each line.
368
369     The number of columns is determined automatically based on the string
370     lengths.
371     """
372     if not l:
373         return ""
374     l = l[:]
375     clen = max(len(s) for s in l)
376     ncols = (tty_width() - len(prefix)) / (clen + 2)
377     if ncols <= 1:
378         ncols = 1
379         clen = 0
380     cols = []
381     while len(l) % ncols:
382         l.append('')
383     rows = len(l)/ncols
384     for s in range(0, len(l), rows):
385         cols.append(l[s:s+rows])
386     out = ''
387     for row in zip(*cols):
388         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
389     return out
390
391
392 # hashlib is only available in python 2.5 or higher, but the 'sha' module
393 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
394 # python 2.4 and above without any stupid warnings, so let's try using hashlib
395 # first, and downgrade if it fails.
396 try:
397     import hashlib
398 except ImportError:
399     import sha
400     Sha1 = sha.sha
401 else:
402     Sha1 = hashlib.sha1
403
404
405 def version_date():
406     """Format bup's version date string for output."""
407     return _version.DATE.split(' ')[0]
408
409 def version_commit():
410     """Get the commit hash of bup's current version."""
411     return _version.COMMIT
412
413 def version_tag():
414     """Format bup's version tag (the official version number).
415
416     When generated from a commit other than one pointed to with a tag, the
417     returned string will be "unknown-" followed by the first seven positions of
418     the commit hash.
419     """
420     names = _version.NAMES.strip()
421     assert(names[0] == '(')
422     assert(names[-1] == ')')
423     names = names[1:-1]
424     l = [n.strip() for n in names.split(',')]
425     for n in l:
426         if n.startswith('tag: bup-'):
427             return n[9:]
428     return 'unknown-%s' % _version.COMMIT[:7]