]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Add docstrings to lib/bup/helpers.py
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
3 from bup import _version
4
5
6 # Write (blockingly) to sockets that may or may not be in blocking mode.
7 # We need this because our stderr is sometimes eaten by subprocesses
8 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
9 # leading to race conditions.  Ick.  We'll do it the hard way.
10 def _hard_write(fd, buf):
11     while buf:
12         (r,w,x) = select.select([], [fd], [], None)
13         if not w:
14             raise IOError('select(fd) returned without being writable')
15         try:
16             sz = os.write(fd, buf)
17         except OSError, e:
18             if e.errno != errno.EAGAIN:
19                 raise
20         assert(sz >= 0)
21         buf = buf[sz:]
22
23 def log(s):
24     """Print a log message to stderr."""
25     sys.stdout.flush()
26     _hard_write(sys.stderr.fileno(), s)
27
28
29 def mkdirp(d):
30     """Recursively create directories on path 'd'.
31
32     Unlike os.makedirs(), it doesn't raise an exception if the last element of
33     the path already exists.
34     """
35     try:
36         os.makedirs(d)
37     except OSError, e:
38         if e.errno == errno.EEXIST:
39             pass
40         else:
41             raise
42
43
44 def next(it):
45     """Get the next item from an iterator, None if we reached the end."""
46     try:
47         return it.next()
48     except StopIteration:
49         return None
50
51
52 def unlink(f):
53     """Delete a file at path 'f' if it currently exists.
54
55     Unlike os.unlink(), does not throw an exception if the file didn't already
56     exist.
57     """
58     try:
59         os.unlink(f)
60     except OSError, e:
61         if e.errno == errno.ENOENT:
62             pass  # it doesn't exist, that's what you asked for
63
64
65 def readpipe(argv):
66     """Run a subprocess and return its output."""
67     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
68     r = p.stdout.read()
69     p.wait()
70     return r
71
72
73 def realpath(p):
74     """Get the absolute path of a file.
75
76     Behaves like os.path.realpath, but doesn't follow a symlink for the last
77     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
78     will follow symlinks in p's directory)
79     """
80     try:
81         st = os.lstat(p)
82     except OSError:
83         st = None
84     if st and stat.S_ISLNK(st.st_mode):
85         (dir, name) = os.path.split(p)
86         dir = os.path.realpath(dir)
87         out = os.path.join(dir, name)
88     else:
89         out = os.path.realpath(p)
90     #log('realpathing:%r,%r\n' % (p, out))
91     return out
92
93
94 _username = None
95 def username():
96     """Get the user's login name."""
97     global _username
98     if not _username:
99         uid = os.getuid()
100         try:
101             _username = pwd.getpwuid(uid)[0]
102         except KeyError:
103             _username = 'user%d' % uid
104     return _username
105
106
107 _userfullname = None
108 def userfullname():
109     """Get the user's full name."""
110     global _userfullname
111     if not _userfullname:
112         uid = os.getuid()
113         try:
114             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
115         except KeyError:
116             _userfullname = 'user%d' % uid
117     return _userfullname
118
119
120 _hostname = None
121 def hostname():
122     """Get the FQDN of this machine."""
123     global _hostname
124     if not _hostname:
125         _hostname = socket.getfqdn()
126     return _hostname
127
128
129 class NotOk(Exception):
130     pass
131
132 class Conn:
133     """A helper class for bup's client-server protocol."""
134     def __init__(self, inp, outp):
135         self.inp = inp
136         self.outp = outp
137
138     def read(self, size):
139         """Read 'size' bytes from input stream."""
140         self.outp.flush()
141         return self.inp.read(size)
142
143     def readline(self):
144         """Read from input stream until a newline is found."""
145         self.outp.flush()
146         return self.inp.readline()
147
148     def write(self, data):
149         """Write 'data' to output stream."""
150         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
151         self.outp.write(data)
152
153     def has_input(self):
154         """Return true if input stream is readable."""
155         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
156         if rl:
157             assert(rl[0] == self.inp.fileno())
158             return True
159         else:
160             return None
161
162     def ok(self):
163         """Indicate end of output from last sent command."""
164         self.write('\nok\n')
165
166     def error(self, s):
167         """Indicate server error to the client."""
168         s = re.sub(r'\s+', ' ', str(s))
169         self.write('\nerror %s\n' % s)
170
171     def _check_ok(self, onempty):
172         self.outp.flush()
173         rl = ''
174         for rl in linereader(self.inp):
175             #log('%d got line: %r\n' % (os.getpid(), rl))
176             if not rl:  # empty line
177                 continue
178             elif rl == 'ok':
179                 return None
180             elif rl.startswith('error '):
181                 #log('client: error: %s\n' % rl[6:])
182                 return NotOk(rl[6:])
183             else:
184                 onempty(rl)
185         raise Exception('server exited unexpectedly; see errors above')
186
187     def drain_and_check_ok(self):
188         """Remove all data for the current command from input stream."""
189         def onempty(rl):
190             pass
191         return self._check_ok(onempty)
192
193     def check_ok(self):
194         """Verify that server action completed successfully."""
195         def onempty(rl):
196             raise Exception('expected "ok", got %r' % rl)
197         return self._check_ok(onempty)
198
199
200 def linereader(f):
201     """Generate a list of input lines from 'f' without terminating newlines."""
202     while 1:
203         line = f.readline()
204         if not line:
205             break
206         yield line[:-1]
207
208
209 def chunkyreader(f, count = None):
210     """Generate a list of chunks of data read from 'f'.
211
212     If count is None, read until EOF is reached.
213
214     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
215     reached while reading, raise IOError.
216     """
217     if count != None:
218         while count > 0:
219             b = f.read(min(count, 65536))
220             if not b:
221                 raise IOError('EOF with %d bytes remaining' % count)
222             yield b
223             count -= len(b)
224     else:
225         while 1:
226             b = f.read(65536)
227             if not b: break
228             yield b
229
230
231 def slashappend(s):
232     """Append "/" to 's' if it doesn't aleady end in "/"."""
233     if s and not s.endswith('/'):
234         return s + '/'
235     else:
236         return s
237
238
239 def _mmap_do(f, sz, flags, prot):
240     if not sz:
241         st = os.fstat(f.fileno())
242         sz = st.st_size
243     map = mmap.mmap(f.fileno(), sz, flags, prot)
244     f.close()  # map will persist beyond file close
245     return map
246
247
248 def mmap_read(f, sz = 0):
249     """Create a read-only memory mapped region on file 'f'.
250
251     If sz is 0, the region will cover the entire file.
252     """
253     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
254
255
256 def mmap_readwrite(f, sz = 0):
257     """Create a read-write memory mapped region on file 'f'.
258
259     If sz is 0, the region will cover the entire file.
260     """
261     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
262
263
264 def parse_num(s):
265     """Parse data size information into a float number.
266
267     Here are some examples of conversions:
268         199.2k means 203981 bytes
269         1GB means 1073741824 bytes
270         2.1 tb means 2199023255552 bytes
271     """
272     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
273     if not g:
274         raise ValueError("can't parse %r as a number" % s)
275     (val, unit) = g.groups()
276     num = float(val)
277     unit = unit.lower()
278     if unit in ['t', 'tb']:
279         mult = 1024*1024*1024*1024
280     elif unit in ['g', 'gb']:
281         mult = 1024*1024*1024
282     elif unit in ['m', 'mb']:
283         mult = 1024*1024
284     elif unit in ['k', 'kb']:
285         mult = 1024
286     elif unit in ['', 'b']:
287         mult = 1
288     else:
289         raise ValueError("invalid unit %r in number %r" % (unit, s))
290     return int(num*mult)
291
292
293 def count(l):
294     """Count the number of elements in an iterator. (consumes the iterator)"""
295     return reduce(lambda x,y: x+1, l)
296
297
298 def atoi(s):
299     """Convert the string 's' to an integer. Return 0 if s is not a number."""
300     try:
301         return int(s or '0')
302     except ValueError:
303         return 0
304
305
306 saved_errors = []
307 def add_error(e):
308     """Append an error message to the list of saved errors.
309
310     Once processing is able to stop and output the errors, the saved errors are
311     accessible in the module variable helpers.saved_errors.
312     """
313     saved_errors.append(e)
314     log('%-70s\n' % e)
315
316 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
317 def progress(s):
318     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
319     if istty:
320         log(s)
321
322
323 def handle_ctrl_c():
324     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
325
326     The new exception handler will make sure that bup will exit without an ugly
327     stacktrace when Ctrl-C is hit.
328     """
329     oldhook = sys.excepthook
330     def newhook(exctype, value, traceback):
331         if exctype == KeyboardInterrupt:
332             log('Interrupted.\n')
333         else:
334             return oldhook(exctype, value, traceback)
335     sys.excepthook = newhook
336
337
338 def columnate(l, prefix):
339     """Format elements of 'l' in columns with 'prefix' leading each line.
340
341     The number of columns is determined automatically based on the string
342     lengths.
343     """
344     l = l[:]
345     clen = max(len(s) for s in l)
346     ncols = (78 - len(prefix)) / (clen + 2)
347     if ncols <= 1:
348         ncols = 1
349         clen = 0
350     cols = []
351     while len(l) % ncols:
352         l.append('')
353     rows = len(l)/ncols
354     for s in range(0, len(l), rows):
355         cols.append(l[s:s+rows])
356     out = ''
357     for row in zip(*cols):
358         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
359     return out
360
361
362 # hashlib is only available in python 2.5 or higher, but the 'sha' module
363 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
364 # python 2.4 and above without any stupid warnings, so let's try using hashlib
365 # first, and downgrade if it fails.
366 try:
367     import hashlib
368 except ImportError:
369     import sha
370     Sha1 = sha.sha
371 else:
372     Sha1 = hashlib.sha1
373
374
375 def version_date():
376     """Format bup's version date string for output."""
377     return _version.DATE.split(' ')[0]
378
379 def version_commit():
380     """Get the commit hash of bup's current version."""
381     return _version.COMMIT
382
383 def version_tag():
384     """Format bup's version tag (the official version number).
385
386     When generated from a commit other than one pointed to with a tag, the
387     returned string will be "unknown-" followed by the first seven positions of
388     the commit hash.
389     """
390     names = _version.NAMES.strip()
391     assert(names[0] == '(')
392     assert(names[-1] == ')')
393     names = names[1:-1]
394     l = [n.strip() for n in names.split(',')]
395     for n in l:
396         if n.startswith('tag: bup-'):
397             return n[9:]
398     return 'unknown-%s' % _version.COMMIT[:7]