]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Add (sec, ns) timestamps and extended stat, lstat, utime, and lutime.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
3 from bup import _version
4 import bup._helpers as _helpers
5
6 # This function should really be in helpers, not in bup.options.  But we
7 # want options.py to be standalone so people can include it in other projects.
8 from bup.options import _tty_width
9 tty_width = _tty_width
10
11
12 def atoi(s):
13     """Convert the string 's' to an integer. Return 0 if s is not a number."""
14     try:
15         return int(s or '0')
16     except ValueError:
17         return 0
18
19
20 def atof(s):
21     """Convert the string 's' to a float. Return 0 if s is not a number."""
22     try:
23         return float(s or '0')
24     except ValueError:
25         return 0
26
27
28 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
29
30
31 # Write (blockingly) to sockets that may or may not be in blocking mode.
32 # We need this because our stderr is sometimes eaten by subprocesses
33 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
34 # leading to race conditions.  Ick.  We'll do it the hard way.
35 def _hard_write(fd, buf):
36     while buf:
37         (r,w,x) = select.select([], [fd], [], None)
38         if not w:
39             raise IOError('select(fd) returned without being writable')
40         try:
41             sz = os.write(fd, buf)
42         except OSError, e:
43             if e.errno != errno.EAGAIN:
44                 raise
45         assert(sz >= 0)
46         buf = buf[sz:]
47
48 def log(s):
49     """Print a log message to stderr."""
50     sys.stdout.flush()
51     _hard_write(sys.stderr.fileno(), s)
52
53
54 def debug1(s):
55     if buglvl >= 1:
56         log(s)
57
58
59 def debug2(s):
60     if buglvl >= 2:
61         log(s)
62
63
64 def mkdirp(d, mode=None):
65     """Recursively create directories on path 'd'.
66
67     Unlike os.makedirs(), it doesn't raise an exception if the last element of
68     the path already exists.
69     """
70     try:
71         if mode:
72             os.makedirs(d, mode)
73         else:
74             os.makedirs(d)
75     except OSError, e:
76         if e.errno == errno.EEXIST:
77             pass
78         else:
79             raise
80
81
82 def next(it):
83     """Get the next item from an iterator, None if we reached the end."""
84     try:
85         return it.next()
86     except StopIteration:
87         return None
88
89
90 def unlink(f):
91     """Delete a file at path 'f' if it currently exists.
92
93     Unlike os.unlink(), does not throw an exception if the file didn't already
94     exist.
95     """
96     try:
97         os.unlink(f)
98     except OSError, e:
99         if e.errno == errno.ENOENT:
100             pass  # it doesn't exist, that's what you asked for
101
102
103 def readpipe(argv):
104     """Run a subprocess and return its output."""
105     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
106     r = p.stdout.read()
107     p.wait()
108     return r
109
110
111 def realpath(p):
112     """Get the absolute path of a file.
113
114     Behaves like os.path.realpath, but doesn't follow a symlink for the last
115     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
116     will follow symlinks in p's directory)
117     """
118     try:
119         st = os.lstat(p)
120     except OSError:
121         st = None
122     if st and stat.S_ISLNK(st.st_mode):
123         (dir, name) = os.path.split(p)
124         dir = os.path.realpath(dir)
125         out = os.path.join(dir, name)
126     else:
127         out = os.path.realpath(p)
128     #log('realpathing:%r,%r\n' % (p, out))
129     return out
130
131
132 _username = None
133 def username():
134     """Get the user's login name."""
135     global _username
136     if not _username:
137         uid = os.getuid()
138         try:
139             _username = pwd.getpwuid(uid)[0]
140         except KeyError:
141             _username = 'user%d' % uid
142     return _username
143
144
145 _userfullname = None
146 def userfullname():
147     """Get the user's full name."""
148     global _userfullname
149     if not _userfullname:
150         uid = os.getuid()
151         try:
152             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
153         except KeyError:
154             _userfullname = 'user%d' % uid
155     return _userfullname
156
157
158 _hostname = None
159 def hostname():
160     """Get the FQDN of this machine."""
161     global _hostname
162     if not _hostname:
163         _hostname = socket.getfqdn()
164     return _hostname
165
166
167 _resource_path = None
168 def resource_path(subdir=''):
169     global _resource_path
170     if not _resource_path:
171         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
172     return os.path.join(_resource_path, subdir)
173
174 class NotOk(Exception):
175     pass
176
177 class Conn:
178     """A helper class for bup's client-server protocol."""
179     def __init__(self, inp, outp):
180         self.inp = inp
181         self.outp = outp
182
183     def read(self, size):
184         """Read 'size' bytes from input stream."""
185         self.outp.flush()
186         return self.inp.read(size)
187
188     def readline(self):
189         """Read from input stream until a newline is found."""
190         self.outp.flush()
191         return self.inp.readline()
192
193     def write(self, data):
194         """Write 'data' to output stream."""
195         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
196         self.outp.write(data)
197
198     def has_input(self):
199         """Return true if input stream is readable."""
200         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
201         if rl:
202             assert(rl[0] == self.inp.fileno())
203             return True
204         else:
205             return None
206
207     def ok(self):
208         """Indicate end of output from last sent command."""
209         self.write('\nok\n')
210
211     def error(self, s):
212         """Indicate server error to the client."""
213         s = re.sub(r'\s+', ' ', str(s))
214         self.write('\nerror %s\n' % s)
215
216     def _check_ok(self, onempty):
217         self.outp.flush()
218         rl = ''
219         for rl in linereader(self.inp):
220             #log('%d got line: %r\n' % (os.getpid(), rl))
221             if not rl:  # empty line
222                 continue
223             elif rl == 'ok':
224                 return None
225             elif rl.startswith('error '):
226                 #log('client: error: %s\n' % rl[6:])
227                 return NotOk(rl[6:])
228             else:
229                 onempty(rl)
230         raise Exception('server exited unexpectedly; see errors above')
231
232     def drain_and_check_ok(self):
233         """Remove all data for the current command from input stream."""
234         def onempty(rl):
235             pass
236         return self._check_ok(onempty)
237
238     def check_ok(self):
239         """Verify that server action completed successfully."""
240         def onempty(rl):
241             raise Exception('expected "ok", got %r' % rl)
242         return self._check_ok(onempty)
243
244
245 def linereader(f):
246     """Generate a list of input lines from 'f' without terminating newlines."""
247     while 1:
248         line = f.readline()
249         if not line:
250             break
251         yield line[:-1]
252
253
254 def chunkyreader(f, count = None):
255     """Generate a list of chunks of data read from 'f'.
256
257     If count is None, read until EOF is reached.
258
259     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
260     reached while reading, raise IOError.
261     """
262     if count != None:
263         while count > 0:
264             b = f.read(min(count, 65536))
265             if not b:
266                 raise IOError('EOF with %d bytes remaining' % count)
267             yield b
268             count -= len(b)
269     else:
270         while 1:
271             b = f.read(65536)
272             if not b: break
273             yield b
274
275
276 def slashappend(s):
277     """Append "/" to 's' if it doesn't aleady end in "/"."""
278     if s and not s.endswith('/'):
279         return s + '/'
280     else:
281         return s
282
283
284 def _mmap_do(f, sz, flags, prot):
285     if not sz:
286         st = os.fstat(f.fileno())
287         sz = st.st_size
288     map = mmap.mmap(f.fileno(), sz, flags, prot)
289     f.close()  # map will persist beyond file close
290     return map
291
292
293 def mmap_read(f, sz = 0):
294     """Create a read-only memory mapped region on file 'f'.
295
296     If sz is 0, the region will cover the entire file.
297     """
298     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
299
300
301 def mmap_readwrite(f, sz = 0):
302     """Create a read-write memory mapped region on file 'f'.
303
304     If sz is 0, the region will cover the entire file.
305     """
306     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
307
308
309 def parse_num(s):
310     """Parse data size information into a float number.
311
312     Here are some examples of conversions:
313         199.2k means 203981 bytes
314         1GB means 1073741824 bytes
315         2.1 tb means 2199023255552 bytes
316     """
317     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
318     if not g:
319         raise ValueError("can't parse %r as a number" % s)
320     (val, unit) = g.groups()
321     num = float(val)
322     unit = unit.lower()
323     if unit in ['t', 'tb']:
324         mult = 1024*1024*1024*1024
325     elif unit in ['g', 'gb']:
326         mult = 1024*1024*1024
327     elif unit in ['m', 'mb']:
328         mult = 1024*1024
329     elif unit in ['k', 'kb']:
330         mult = 1024
331     elif unit in ['', 'b']:
332         mult = 1
333     else:
334         raise ValueError("invalid unit %r in number %r" % (unit, s))
335     return int(num*mult)
336
337
338 def count(l):
339     """Count the number of elements in an iterator. (consumes the iterator)"""
340     return reduce(lambda x,y: x+1, l)
341
342
343 saved_errors = []
344 def add_error(e):
345     """Append an error message to the list of saved errors.
346
347     Once processing is able to stop and output the errors, the saved errors are
348     accessible in the module variable helpers.saved_errors.
349     """
350     saved_errors.append(e)
351     log('%-70s\n' % e)
352
353 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
354 def progress(s):
355     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
356     if istty:
357         log(s)
358
359
360 def handle_ctrl_c():
361     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
362
363     The new exception handler will make sure that bup will exit without an ugly
364     stacktrace when Ctrl-C is hit.
365     """
366     oldhook = sys.excepthook
367     def newhook(exctype, value, traceback):
368         if exctype == KeyboardInterrupt:
369             log('Interrupted.\n')
370         else:
371             return oldhook(exctype, value, traceback)
372     sys.excepthook = newhook
373
374
375 def columnate(l, prefix):
376     """Format elements of 'l' in columns with 'prefix' leading each line.
377
378     The number of columns is determined automatically based on the string
379     lengths.
380     """
381     if not l:
382         return ""
383     l = l[:]
384     clen = max(len(s) for s in l)
385     ncols = (tty_width() - len(prefix)) / (clen + 2)
386     if ncols <= 1:
387         ncols = 1
388         clen = 0
389     cols = []
390     while len(l) % ncols:
391         l.append('')
392     rows = len(l)/ncols
393     for s in range(0, len(l), rows):
394         cols.append(l[s:s+rows])
395     out = ''
396     for row in zip(*cols):
397         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
398     return out
399
400 def parse_date_or_fatal(str, fatal):
401     """Parses the given date or calls Option.fatal().
402     For now we expect a string that contains a float."""
403     try:
404         date = atof(str)
405     except ValueError, e:
406         raise fatal('invalid date format (should be a float): %r' % e)
407     else:
408         return date
409
410
411 def lutime(path, times):
412     if _helpers.utimensat:
413         atime = times[0]
414         mtime = times[1]
415         return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime),
416                                   _helpers.AT_SYMLINK_NOFOLLOW)
417     else:
418         return None
419
420
421 def utime(path, times):
422     atime = times[0]
423     mtime = times[1]
424     if _helpers.utimensat:
425         return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime),
426                                   0)
427     else:
428         os.utime(path, (atime[0] + atime[1] / 10e9,
429                         mtime[0] + mtime[1] / 10e9))
430
431
432 class stat_result():
433     pass
434
435
436 def lstat(path):
437     result = stat_result()
438     if _helpers.lstat:
439         st = _helpers.lstat(path)
440         (result.st_mode,
441          result.st_ino,
442          result.st_dev,
443          result.st_nlink,
444          result.st_uid,
445          result.st_gid,
446          result.st_rdev,
447          result.st_size,
448          result.st_atime,
449          result.st_mtime,
450          result.st_ctime) = st
451     else:
452         st = os.lstat(path)
453         result.st_mode = st.st_mode
454         result.st_ino = st.st_ino
455         result.st_dev = st.st_dev
456         result.st_nlink = st.st_nlink
457         result.st_uid = st.st_uid
458         result.st_gid = st.st_gid
459         result.st_rdev = st.st_rdev
460         result.st_size = st.st_size
461         result.st_atime = (math.trunc(st.st_atime),
462                            math.trunc(math.fmod(st.st_atime, 1) * 10**9))
463         result.st_mtime = (math.trunc(st.st_mtime),
464                            math.trunc(math.fmod(st.st_mtime, 1) * 10**9))
465         result.st_ctime = (math.trunc(st.st_ctime),
466                            math.trunc(math.fmod(st.st_ctime, 1) * 10**9))
467     return result
468
469
470 # hashlib is only available in python 2.5 or higher, but the 'sha' module
471 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
472 # python 2.4 and above without any stupid warnings, so let's try using hashlib
473 # first, and downgrade if it fails.
474 try:
475     import hashlib
476 except ImportError:
477     import sha
478     Sha1 = sha.sha
479 else:
480     Sha1 = hashlib.sha1
481
482
483 def version_date():
484     """Format bup's version date string for output."""
485     return _version.DATE.split(' ')[0]
486
487 def version_commit():
488     """Get the commit hash of bup's current version."""
489     return _version.COMMIT
490
491 def version_tag():
492     """Format bup's version tag (the official version number).
493
494     When generated from a commit other than one pointed to with a tag, the
495     returned string will be "unknown-" followed by the first seven positions of
496     the commit hash.
497     """
498     names = _version.NAMES.strip()
499     assert(names[0] == '(')
500     assert(names[-1] == ')')
501     names = names[1:-1]
502     l = [n.strip() for n in names.split(',')]
503     for n in l:
504         if n.startswith('tag: bup-'):
505             return n[9:]
506     return 'unknown-%s' % _version.COMMIT[:7]