]> arthur.barton.de Git - bup.git/blobdiff - lib/bup/client.py
Check that all context managed objects are properly closed
[bup.git] / lib / bup / client.py
index 07b184cdcfe0fbca35cb9c8ad8a65db15f5c37a5..7c284c4af4aa613db536cf5db637382988e1db02 100644 (file)
@@ -3,15 +3,16 @@ from __future__ import print_function
 
 from __future__ import absolute_import
 from binascii import hexlify, unhexlify
 
 from __future__ import absolute_import
 from binascii import hexlify, unhexlify
-import errno, os, re, struct, sys, time, zlib
+import os, re, struct, time, zlib
+import socket
 
 from bup import git, ssh, vfs
 
 from bup import git, ssh, vfs
-from bup.compat import environ, range, reraise
+from bup.compat import environ, pending_raise, range, reraise
 from bup.helpers import (Conn, atomically_replaced_file, chunkyreader, debug1,
                          debug2, linereader, lines_until_sentinel,
 from bup.helpers import (Conn, atomically_replaced_file, chunkyreader, debug1,
                          debug2, linereader, lines_until_sentinel,
-                         mkdirp, progress, qprogress, DemuxConn, atoi)
+                         mkdirp, nullcontext_if_not, progress, qprogress, DemuxConn)
 from bup.io import path_msg
 from bup.io import path_msg
-from bup.vint import read_bvec, read_vuint, write_bvec
+from bup.vint import write_bvec
 
 
 bwlimit = None
 
 
 bwlimit = None
@@ -68,6 +69,7 @@ def parse_remote(remote):
 
 class Client:
     def __init__(self, remote, create=False):
 
 class Client:
     def __init__(self, remote, create=False):
+        self.closed = False
         self._busy = self.conn = None
         self.sock = self.p = self.pout = self.pin = None
         is_reverse = environ.get(b'BUP_SERVER_REVERSE')
         self._busy = self.conn = None
         self.sock = self.p = self.pout = self.pin = None
         is_reverse = environ.get(b'BUP_SERVER_REVERSE')
@@ -100,7 +102,8 @@ class Client:
                     reraise(ClientError('connect: %s' % e))
             elif self.protocol == b'bup':
                 self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                     reraise(ClientError('connect: %s' % e))
             elif self.protocol == b'bup':
                 self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-                self.sock.connect((self.host, atoi(self.port) or 1982))
+                self.sock.connect((self.host,
+                                   1982 if self.port is None else int(self.port)))
                 self.sockw = self.sock.makefile('wb')
                 self.conn = DemuxConn(self.sock.fileno(), self.sockw)
         self._available_commands = self._get_available_commands()
                 self.sockw = self.sock.makefile('wb')
                 self.conn = DemuxConn(self.sock.fileno(), self.sockw)
         self._available_commands = self._get_available_commands()
@@ -115,16 +118,8 @@ class Client:
             self.check_ok()
         self.sync_indexes()
 
             self.check_ok()
         self.sync_indexes()
 
-    def __del__(self):
-        try:
-            self.close()
-        except IOError as e:
-            if e.errno == errno.EPIPE:
-                pass
-            else:
-                raise
-
     def close(self):
     def close(self):
+        self.closed = True
         if self.conn and not self._busy:
             self.conn.write(b'quit\n')
         if self.pin:
         if self.conn and not self._busy:
             self.conn.write(b'quit\n')
         if self.pin:
@@ -146,6 +141,16 @@ class Client:
         self.conn = None
         self.sock = self.p = self.pin = self.pout = None
 
         self.conn = None
         self.sock = self.p = self.pin = self.pout = None
 
+    def __del__(self):
+        assert self.closed
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, type, value, traceback):
+        with pending_raise(value, rethrow=False):
+            self.close()
+
     def check_ok(self):
         if self.p:
             rv = self.p.poll()
     def check_ok(self):
         if self.p:
             rv = self.p.poll()
@@ -156,15 +161,17 @@ class Client:
             return self.conn.check_ok()
         except Exception as e:
             reraise(ClientError(e))
             return self.conn.check_ok()
         except Exception as e:
             reraise(ClientError(e))
+            # reraise doesn't return
+            return None
 
     def check_busy(self):
         if self._busy:
             raise ClientError('already busy with command %r' % self._busy)
 
     def check_busy(self):
         if self._busy:
             raise ClientError('already busy with command %r' % self._busy)
-        
+
     def ensure_busy(self):
         if not self._busy:
             raise ClientError('expected to be busy, but not busy?!')
     def ensure_busy(self):
         if not self._busy:
             raise ClientError('expected to be busy, but not busy?!')
-        
+
     def _not_busy(self):
         self._busy = None
 
     def _not_busy(self):
         self._busy = None
 
@@ -406,7 +413,7 @@ class Client:
             raise not_ok
         self._not_busy()
 
             raise not_ok
         self._not_busy()
 
-    def rev_list(self, refs, count=None, parse=None, format=None):
+    def rev_list(self, refs, parse=None, format=None):
         """See git.rev_list for the general semantics, but note that with the
         current interface, the parse function must be able to handle
         (consume) any blank lines produced by the format because the
         """See git.rev_list for the general semantics, but note that with the
         current interface, the parse function must be able to handle
         (consume) any blank lines produced by the format because the
@@ -415,7 +422,6 @@ class Client:
 
         """
         self._require_command(b'rev-list')
 
         """
         self._require_command(b'rev-list')
-        assert (count is None) or (isinstance(count, Integral))
         if format:
             assert b'\n' not in format
             assert parse
         if format:
             assert b'\n' not in format
             assert parse
@@ -426,8 +432,6 @@ class Client:
         self._busy = b'rev-list'
         conn = self.conn
         conn.write(b'rev-list\n')
         self._busy = b'rev-list'
         conn = self.conn
         conn.write(b'rev-list\n')
-        if count is not None:
-            conn.write(b'%d' % count)
         conn.write(b'\n')
         if format:
             conn.write(format)
         conn.write(b'\n')
         if format:
             conn.write(format)
@@ -481,7 +485,9 @@ class Client:
         return result
 
 
         return result
 
 
+# FIXME: disentangle this (stop inheriting) from PackWriter
 class PackWriter_Remote(git.PackWriter):
 class PackWriter_Remote(git.PackWriter):
+
     def __init__(self, conn, objcache_maker, suggest_packs,
                  onopen, onclose,
                  ensure_busy,
     def __init__(self, conn, objcache_maker, suggest_packs,
                  onopen, onclose,
                  ensure_busy,
@@ -493,6 +499,7 @@ class PackWriter_Remote(git.PackWriter):
                                 compression_level=compression_level,
                                 max_pack_size=max_pack_size,
                                 max_pack_objects=max_pack_objects)
                                 compression_level=compression_level,
                                 max_pack_size=max_pack_size,
                                 max_pack_objects=max_pack_objects)
+        self.remote_closed = False
         self.file = conn
         self.filename = b'remote socket'
         self.suggest_packs = suggest_packs
         self.file = conn
         self.filename = b'remote socket'
         self.suggest_packs = suggest_packs
@@ -503,25 +510,39 @@ class PackWriter_Remote(git.PackWriter):
         self._bwcount = 0
         self._bwtime = time.time()
 
         self._bwcount = 0
         self._bwtime = time.time()
 
+    # __enter__ and __exit__ are inherited
+
     def _open(self):
         if not self._packopen:
             self.onopen()
             self._packopen = True
 
     def _end(self, run_midx=True):
     def _open(self):
         if not self._packopen:
             self.onopen()
             self._packopen = True
 
     def _end(self, run_midx=True):
+        # Called by other PackWriter methods like breakpoint().
+        # Must not close the connection (self.file)
         assert(run_midx)  # We don't support this via remote yet
         assert(run_midx)  # We don't support this via remote yet
-        if self._packopen and self.file:
+        self.objcache, objcache = None, self.objcache
+        with nullcontext_if_not(objcache):
+            if not (self._packopen and self.file):
+                return None
             self.file.write(b'\0\0\0\0')
             self._packopen = False
             self.onclose() # Unbusy
             self.file.write(b'\0\0\0\0')
             self._packopen = False
             self.onclose() # Unbusy
-            self.objcache = None
+            if objcache is not None:
+                objcache.close()
             return self.suggest_packs() # Returns last idx received
 
     def close(self):
             return self.suggest_packs() # Returns last idx received
 
     def close(self):
+        # Called by inherited __exit__
+        self.remote_closed = True
         id = self._end()
         self.file = None
         return id
 
         id = self._end()
         self.file = None
         return id
 
+    def __del__(self):
+        assert self.remote_closed
+        super(PackWriter_Remote, self).__del__()
+
     def abort(self):
         raise ClientError("don't know how to abort remote pack writing")
 
     def abort(self):
         raise ClientError("don't know how to abort remote pack writing")