]> arthur.barton.de Git - bup.git/blob - lib/bup/helpers.py
Add helpers.fstat and _helpers._have_ns_fs_timestamps.
[bup.git] / lib / bup / helpers.py
1 """Helper functions and classes for bup."""
2 import sys, os, pwd, subprocess, errno, socket, select, mmap, stat, re
3 from bup import _version
4 import bup._helpers as _helpers
5
6 # This function should really be in helpers, not in bup.options.  But we
7 # want options.py to be standalone so people can include it in other projects.
8 from bup.options import _tty_width
9 tty_width = _tty_width
10
11
12 def atoi(s):
13     """Convert the string 's' to an integer. Return 0 if s is not a number."""
14     try:
15         return int(s or '0')
16     except ValueError:
17         return 0
18
19
20 def atof(s):
21     """Convert the string 's' to a float. Return 0 if s is not a number."""
22     try:
23         return float(s or '0')
24     except ValueError:
25         return 0
26
27
28 buglvl = atoi(os.environ.get('BUP_DEBUG', 0))
29
30
31 # Write (blockingly) to sockets that may or may not be in blocking mode.
32 # We need this because our stderr is sometimes eaten by subprocesses
33 # (probably ssh) that sometimes make it nonblocking, if only temporarily,
34 # leading to race conditions.  Ick.  We'll do it the hard way.
35 def _hard_write(fd, buf):
36     while buf:
37         (r,w,x) = select.select([], [fd], [], None)
38         if not w:
39             raise IOError('select(fd) returned without being writable')
40         try:
41             sz = os.write(fd, buf)
42         except OSError, e:
43             if e.errno != errno.EAGAIN:
44                 raise
45         assert(sz >= 0)
46         buf = buf[sz:]
47
48 def log(s):
49     """Print a log message to stderr."""
50     sys.stdout.flush()
51     _hard_write(sys.stderr.fileno(), s)
52
53
54 def debug1(s):
55     if buglvl >= 1:
56         log(s)
57
58
59 def debug2(s):
60     if buglvl >= 2:
61         log(s)
62
63
64 def mkdirp(d, mode=None):
65     """Recursively create directories on path 'd'.
66
67     Unlike os.makedirs(), it doesn't raise an exception if the last element of
68     the path already exists.
69     """
70     try:
71         if mode:
72             os.makedirs(d, mode)
73         else:
74             os.makedirs(d)
75     except OSError, e:
76         if e.errno == errno.EEXIST:
77             pass
78         else:
79             raise
80
81
82 def next(it):
83     """Get the next item from an iterator, None if we reached the end."""
84     try:
85         return it.next()
86     except StopIteration:
87         return None
88
89
90 def unlink(f):
91     """Delete a file at path 'f' if it currently exists.
92
93     Unlike os.unlink(), does not throw an exception if the file didn't already
94     exist.
95     """
96     try:
97         os.unlink(f)
98     except OSError, e:
99         if e.errno == errno.ENOENT:
100             pass  # it doesn't exist, that's what you asked for
101
102
103 def readpipe(argv):
104     """Run a subprocess and return its output."""
105     p = subprocess.Popen(argv, stdout=subprocess.PIPE)
106     r = p.stdout.read()
107     p.wait()
108     return r
109
110
111 def realpath(p):
112     """Get the absolute path of a file.
113
114     Behaves like os.path.realpath, but doesn't follow a symlink for the last
115     element. (ie. if 'p' itself is a symlink, this one won't follow it, but it
116     will follow symlinks in p's directory)
117     """
118     try:
119         st = os.lstat(p)
120     except OSError:
121         st = None
122     if st and stat.S_ISLNK(st.st_mode):
123         (dir, name) = os.path.split(p)
124         dir = os.path.realpath(dir)
125         out = os.path.join(dir, name)
126     else:
127         out = os.path.realpath(p)
128     #log('realpathing:%r,%r\n' % (p, out))
129     return out
130
131
132 def detect_fakeroot():
133     "Return True if we appear to be running under fakeroot."
134     return os.getenv("FAKEROOTKEY") != None
135
136
137 _username = None
138 def username():
139     """Get the user's login name."""
140     global _username
141     if not _username:
142         uid = os.getuid()
143         try:
144             _username = pwd.getpwuid(uid)[0]
145         except KeyError:
146             _username = 'user%d' % uid
147     return _username
148
149
150 _userfullname = None
151 def userfullname():
152     """Get the user's full name."""
153     global _userfullname
154     if not _userfullname:
155         uid = os.getuid()
156         try:
157             _userfullname = pwd.getpwuid(uid)[4].split(',')[0]
158         except KeyError:
159             _userfullname = 'user%d' % uid
160     return _userfullname
161
162
163 _hostname = None
164 def hostname():
165     """Get the FQDN of this machine."""
166     global _hostname
167     if not _hostname:
168         _hostname = socket.getfqdn()
169     return _hostname
170
171
172 _resource_path = None
173 def resource_path(subdir=''):
174     global _resource_path
175     if not _resource_path:
176         _resource_path = os.environ.get('BUP_RESOURCE_PATH') or '.'
177     return os.path.join(_resource_path, subdir)
178
179 class NotOk(Exception):
180     pass
181
182 class Conn:
183     """A helper class for bup's client-server protocol."""
184     def __init__(self, inp, outp):
185         self.inp = inp
186         self.outp = outp
187
188     def read(self, size):
189         """Read 'size' bytes from input stream."""
190         self.outp.flush()
191         return self.inp.read(size)
192
193     def readline(self):
194         """Read from input stream until a newline is found."""
195         self.outp.flush()
196         return self.inp.readline()
197
198     def write(self, data):
199         """Write 'data' to output stream."""
200         #log('%d writing: %d bytes\n' % (os.getpid(), len(data)))
201         self.outp.write(data)
202
203     def has_input(self):
204         """Return true if input stream is readable."""
205         [rl, wl, xl] = select.select([self.inp.fileno()], [], [], 0)
206         if rl:
207             assert(rl[0] == self.inp.fileno())
208             return True
209         else:
210             return None
211
212     def ok(self):
213         """Indicate end of output from last sent command."""
214         self.write('\nok\n')
215
216     def error(self, s):
217         """Indicate server error to the client."""
218         s = re.sub(r'\s+', ' ', str(s))
219         self.write('\nerror %s\n' % s)
220
221     def _check_ok(self, onempty):
222         self.outp.flush()
223         rl = ''
224         for rl in linereader(self.inp):
225             #log('%d got line: %r\n' % (os.getpid(), rl))
226             if not rl:  # empty line
227                 continue
228             elif rl == 'ok':
229                 return None
230             elif rl.startswith('error '):
231                 #log('client: error: %s\n' % rl[6:])
232                 return NotOk(rl[6:])
233             else:
234                 onempty(rl)
235         raise Exception('server exited unexpectedly; see errors above')
236
237     def drain_and_check_ok(self):
238         """Remove all data for the current command from input stream."""
239         def onempty(rl):
240             pass
241         return self._check_ok(onempty)
242
243     def check_ok(self):
244         """Verify that server action completed successfully."""
245         def onempty(rl):
246             raise Exception('expected "ok", got %r' % rl)
247         return self._check_ok(onempty)
248
249
250 def linereader(f):
251     """Generate a list of input lines from 'f' without terminating newlines."""
252     while 1:
253         line = f.readline()
254         if not line:
255             break
256         yield line[:-1]
257
258
259 def chunkyreader(f, count = None):
260     """Generate a list of chunks of data read from 'f'.
261
262     If count is None, read until EOF is reached.
263
264     If count is a positive integer, read 'count' bytes from 'f'. If EOF is
265     reached while reading, raise IOError.
266     """
267     if count != None:
268         while count > 0:
269             b = f.read(min(count, 65536))
270             if not b:
271                 raise IOError('EOF with %d bytes remaining' % count)
272             yield b
273             count -= len(b)
274     else:
275         while 1:
276             b = f.read(65536)
277             if not b: break
278             yield b
279
280
281 def slashappend(s):
282     """Append "/" to 's' if it doesn't aleady end in "/"."""
283     if s and not s.endswith('/'):
284         return s + '/'
285     else:
286         return s
287
288
289 def _mmap_do(f, sz, flags, prot):
290     if not sz:
291         st = os.fstat(f.fileno())
292         sz = st.st_size
293     map = mmap.mmap(f.fileno(), sz, flags, prot)
294     f.close()  # map will persist beyond file close
295     return map
296
297
298 def mmap_read(f, sz = 0):
299     """Create a read-only memory mapped region on file 'f'.
300
301     If sz is 0, the region will cover the entire file.
302     """
303     return _mmap_do(f, sz, mmap.MAP_PRIVATE, mmap.PROT_READ)
304
305
306 def mmap_readwrite(f, sz = 0):
307     """Create a read-write memory mapped region on file 'f'.
308
309     If sz is 0, the region will cover the entire file.
310     """
311     return _mmap_do(f, sz, mmap.MAP_SHARED, mmap.PROT_READ|mmap.PROT_WRITE)
312
313
314 def parse_num(s):
315     """Parse data size information into a float number.
316
317     Here are some examples of conversions:
318         199.2k means 203981 bytes
319         1GB means 1073741824 bytes
320         2.1 tb means 2199023255552 bytes
321     """
322     g = re.match(r'([-+\d.e]+)\s*(\w*)', str(s))
323     if not g:
324         raise ValueError("can't parse %r as a number" % s)
325     (val, unit) = g.groups()
326     num = float(val)
327     unit = unit.lower()
328     if unit in ['t', 'tb']:
329         mult = 1024*1024*1024*1024
330     elif unit in ['g', 'gb']:
331         mult = 1024*1024*1024
332     elif unit in ['m', 'mb']:
333         mult = 1024*1024
334     elif unit in ['k', 'kb']:
335         mult = 1024
336     elif unit in ['', 'b']:
337         mult = 1
338     else:
339         raise ValueError("invalid unit %r in number %r" % (unit, s))
340     return int(num*mult)
341
342
343 def count(l):
344     """Count the number of elements in an iterator. (consumes the iterator)"""
345     return reduce(lambda x,y: x+1, l)
346
347
348 saved_errors = []
349 def add_error(e):
350     """Append an error message to the list of saved errors.
351
352     Once processing is able to stop and output the errors, the saved errors are
353     accessible in the module variable helpers.saved_errors.
354     """
355     saved_errors.append(e)
356     log('%-70s\n' % e)
357
358 istty = os.isatty(2) or atoi(os.environ.get('BUP_FORCE_TTY'))
359 def progress(s):
360     """Calls log(s) if stderr is a TTY.  Does nothing otherwise."""
361     if istty:
362         log(s)
363
364
365 def handle_ctrl_c():
366     """Replace the default exception handler for KeyboardInterrupt (Ctrl-C).
367
368     The new exception handler will make sure that bup will exit without an ugly
369     stacktrace when Ctrl-C is hit.
370     """
371     oldhook = sys.excepthook
372     def newhook(exctype, value, traceback):
373         if exctype == KeyboardInterrupt:
374             log('Interrupted.\n')
375         else:
376             return oldhook(exctype, value, traceback)
377     sys.excepthook = newhook
378
379
380 def columnate(l, prefix):
381     """Format elements of 'l' in columns with 'prefix' leading each line.
382
383     The number of columns is determined automatically based on the string
384     lengths.
385     """
386     if not l:
387         return ""
388     l = l[:]
389     clen = max(len(s) for s in l)
390     ncols = (tty_width() - len(prefix)) / (clen + 2)
391     if ncols <= 1:
392         ncols = 1
393         clen = 0
394     cols = []
395     while len(l) % ncols:
396         l.append('')
397     rows = len(l)/ncols
398     for s in range(0, len(l), rows):
399         cols.append(l[s:s+rows])
400     out = ''
401     for row in zip(*cols):
402         out += prefix + ''.join(('%-*s' % (clen+2, s)) for s in row) + '\n'
403     return out
404
405 def parse_date_or_fatal(str, fatal):
406     """Parses the given date or calls Option.fatal().
407     For now we expect a string that contains a float."""
408     try:
409         date = atof(str)
410     except ValueError, e:
411         raise fatal('invalid date format (should be a float): %r' % e)
412     else:
413         return date
414
415
416 class FSTime():
417     # Class to represent filesystem timestamps.  Use integer
418     # nanoseconds on platforms where we have the higher resolution
419     # lstat.  Use the native python stat representation (floating
420     # point seconds) otherwise.
421
422     def __cmp__(self, x):
423         return self._value.__cmp__(x._value)
424
425     def to_timespec(self):
426         """Return (s, ns) where ns is always non-negative
427         and t = s + ns / 10e8""" # metadata record rep (and libc rep)
428         s_ns = self.secs_nsecs()
429         if s_ns[0] > 0 or s_ns[1] >= 0:
430             return s_ns
431         return (s_ns[0] - 1, 10**9 + s_ns[1]) # ns is negative
432
433     if _helpers.lstat: # Use integer nanoseconds.
434
435         @staticmethod
436         def from_secs(secs):
437             ts = FSTime()
438             ts._value = int(secs * 10**9)
439             return ts
440
441         @staticmethod
442         def from_timespec(timespec):
443             ts = FSTime()
444             ts._value = timespec[0] * 10**9 + timespec[1]
445             return ts
446
447         @staticmethod
448         def from_stat_time(stat_time):
449             return FSTime.from_timespec(stat_time)
450
451         def approx_secs(self):
452             return self._value / 10e8;
453
454         def secs_nsecs(self):
455             "Return a (s, ns) pair: -1.5s -> (-1, -10**9 / 2)."
456             if self._value >= 0:
457                 return (self._value / 10**9, self._value % 10**9)
458             abs_val = -self._value
459             return (- (abs_val / 10**9), - (abs_val % 10**9))
460
461     else: # Use python default floating-point seconds.
462
463         @staticmethod
464         def from_secs(secs):
465             ts = FSTime()
466             ts._value = secs
467             return ts
468
469         @staticmethod
470         def from_timespec(timespec):
471             ts = FSTime()
472             ts._value = timespec[0] + (timespec[1] / 10e8)
473             return ts
474
475         @staticmethod
476         def from_stat_time(stat_time):
477             ts = FSTime()
478             ts._value = stat_time
479             return ts
480
481         def approx_secs(self):
482             return self._value
483
484         def secs_nsecs(self):
485             "Return a (s, ns) pair: -1.5s -> (-1, -5**9)."
486             x = math.modf(self._value)
487             return (x[1], x[0] * 10**9)
488
489
490 def lutime(path, times):
491     if _helpers.utimensat:
492         atime = times[0].to_timespec()
493         mtime = times[1].to_timespec()
494         return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime),
495                                   _helpers.AT_SYMLINK_NOFOLLOW)
496     else:
497         return None
498
499
500 def utime(path, times):
501     if _helpers.utimensat:
502         atime = times[0].to_timespec()
503         mtime = times[1].to_timespec()
504         return _helpers.utimensat(_helpers.AT_FDCWD, path, (atime, mtime), 0)
505     else:
506         atime = times[0].approx_secs()
507         mtime = times[1].approx_secs()
508         os.utime(path, (atime, mtime))
509
510
511 class stat_result():
512
513     @staticmethod
514     def from_stat_rep(st):
515         result = stat_result()
516         if _helpers._have_ns_fs_timestamps:
517             (result.st_mode,
518              result.st_ino,
519              result.st_dev,
520              result.st_nlink,
521              result.st_uid,
522              result.st_gid,
523              result.st_rdev,
524              result.st_size,
525              atime,
526              mtime,
527              ctime) = st
528         else:
529             result.st_mode = st.st_mode
530             result.st_ino = st.st_ino
531             result.st_dev = st.st_dev
532             result.st_nlink = st.st_nlink
533             result.st_uid = st.st_uid
534             result.st_gid = st.st_gid
535             result.st_rdev = st.st_rdev
536             result.st_size = st.st_size
537             atime = FSTime.from_stat_time(st.st_atime)
538             mtime = FSTime.from_stat_time(st.st_mtime)
539             ctime = FSTime.from_stat_time(st.st_ctime)
540         result.st_atime = FSTime.from_stat_time(atime)
541         result.st_mtime = FSTime.from_stat_time(mtime)
542         result.st_ctime = FSTime.from_stat_time(ctime)
543         return result
544
545
546 def fstat(path):
547     if _helpers.fstat:
548         st = _helpers.fstat(path)
549     else:
550         st = os.fstat(path)
551     return stat_result.from_stat_rep(st)
552
553
554 def lstat(path):
555     if _helpers.lstat:
556         st = _helpers.lstat(path)
557     else:
558         st = os.lstat(path)
559     return stat_result.from_stat_rep(st)
560
561
562 # hashlib is only available in python 2.5 or higher, but the 'sha' module
563 # produces a DeprecationWarning in python 2.6 or higher.  We want to support
564 # python 2.4 and above without any stupid warnings, so let's try using hashlib
565 # first, and downgrade if it fails.
566 try:
567     import hashlib
568 except ImportError:
569     import sha
570     Sha1 = sha.sha
571 else:
572     Sha1 = hashlib.sha1
573
574
575 def version_date():
576     """Format bup's version date string for output."""
577     return _version.DATE.split(' ')[0]
578
579 def version_commit():
580     """Get the commit hash of bup's current version."""
581     return _version.COMMIT
582
583 def version_tag():
584     """Format bup's version tag (the official version number).
585
586     When generated from a commit other than one pointed to with a tag, the
587     returned string will be "unknown-" followed by the first seven positions of
588     the commit hash.
589     """
590     names = _version.NAMES.strip()
591     assert(names[0] == '(')
592     assert(names[-1] == ')')
593     names = names[1:-1]
594     l = [n.strip() for n in names.split(',')]
595     for n in l:
596         if n.startswith('tag: bup-'):
597             return n[9:]
598     return 'unknown-%s' % _version.COMMIT[:7]