arthur.barton.de Git - bup.git/blobdiff - lib/bup/cmd/split.py
Remove Client __del__ in favor of context management
[bup.git] / lib / bup / cmd / split.py
index 87ad88752ab854603578d1d4495a0a14164205e2..1ffe44d3cdd7f4fc4919f7fa0634ef1048e195d8 100755 (executable)
@@ -4,8 +4,8 @@ from binascii import hexlify
 import sys, time
 
 from bup import compat, hashsplit, git, options, client
-from bup.compat import argv_bytes, environ
-from bup.helpers import (add_error, handle_ctrl_c, hostname, log, parse_num,
+from bup.compat import argv_bytes, environ, nullcontext
+from bup.helpers import (add_error, hostname, log, parse_num,
                          qprogress, reprogress, saved_errors,
                          valid_save_name,
                          parse_date_or_fatal)
@@ -41,9 +41,26 @@ bwlimit=   maximum bytes/sec to transmit to server
 #,compress=  set compression level to # (0-9, 9 is highest) [1]
 """
 
-def main(argv):
+
+class NoOpPackWriter:
+    def __init__(self):
+        pass
+    def __enter__(self):
+        return self
+    def __exit__(self, type, value, traceback):
+        return None  # since close() does nothing
+    def close(self):
+        return None
+    def new_blob(self, content):
+        return git.calc_hash(b'blob', content)
+    def new_tree(self, shalist):
+        return git.calc_hash(b'tree', git.tree_encode(shalist))
+
+def opts_from_cmdline(argv):
     o = options.Options(optspec)
     opt, flags, extra = o.parse_bytes(argv[1:])
+    opt.sources = extra
+
     if opt.name: opt.name = argv_bytes(opt.name)
     if opt.remote: opt.remote = argv_bytes(opt.remote)
     if opt.verbose is None: opt.verbose = 0
@@ -59,29 +76,32 @@ def main(argv):
         o.fatal('-b is incompatible with -t, -c, -n')
     if extra and opt.git_ids:
         o.fatal("don't provide filenames when using --git-ids")
-
     if opt.verbose >= 2:
         git.verbose = opt.verbose - 1
         opt.bench = 1
-
-    max_pack_size = None
     if opt.max_pack_size:
-        max_pack_size = parse_num(opt.max_pack_size)
-    max_pack_objects = None
+        opt.max_pack_size = parse_num(opt.max_pack_size)
     if opt.max_pack_objects:
-        max_pack_objects = parse_num(opt.max_pack_objects)
-
+        opt.max_pack_objects = parse_num(opt.max_pack_objects)
     if opt.fanout:
-        hashsplit.fanout = parse_num(opt.fanout)
-    if opt.blobs:
-        hashsplit.fanout = 0
+        opt.fanout = parse_num(opt.fanout)
     if opt.bwlimit:
-        client.bwlimit = parse_num(opt.bwlimit)
+        opt.bwlimit = parse_num(opt.bwlimit)
     if opt.date:
-        date = parse_date_or_fatal(opt.date, o.fatal)
+        opt.date = parse_date_or_fatal(opt.date, o.fatal)
     else:
-        date = time.time()
+        opt.date = time.time()
+
+    opt.is_reverse = environ.get(b'BUP_SERVER_REVERSE')
+    if opt.is_reverse and opt.remote:
+        o.fatal("don't use -r in reverse mode; it's automatic")
+
+    if opt.name and not valid_save_name(opt.name):
+        o.fatal("'%r' is not a valid branch name." % opt.name)
+
+    return opt
 
+def split(opt, files, parent, out, pack_writer):
     # Hack around lack of nonlocal vars in python 2
     total_bytes = [0]
     def prog(filenum, nbytes):
@@ -92,34 +112,74 @@ def main(argv):
         else:
             qprogress('Splitting: %d kbytes\r' % (total_bytes[0] // 1024))
 
+    new_blob = pack_writer.new_blob
+    new_tree = pack_writer.new_tree
+    if opt.blobs:
+        shalist = hashsplit.split_to_blobs(new_blob, files,
+                                           keep_boundaries=opt.keep_boundaries,
+                                           progress=prog)
+        for sha, size, level in shalist:
+            out.write(hexlify(sha) + b'\n')
+            reprogress()
+    elif opt.tree or opt.commit or opt.name:
+        if opt.name: # insert dummy_name which may be used as a restore target
+            mode, sha = \
+                hashsplit.split_to_blob_or_tree(new_blob, new_tree, files,
+                                                keep_boundaries=opt.keep_boundaries,
+                                                progress=prog)
+            splitfile_name = git.mangle_name(b'data', hashsplit.GIT_MODE_FILE, mode)
+            shalist = [(mode, splitfile_name, sha)]
+        else:
+            shalist = \
+                hashsplit.split_to_shalist(new_blob, new_tree, files,
+                                           keep_boundaries=opt.keep_boundaries,
+                                           progress=prog)
+        tree = new_tree(shalist)
+    else:
+        last = 0
+        it = hashsplit.hashsplit_iter(files,
+                                      keep_boundaries=opt.keep_boundaries,
+                                      progress=prog)
+        for blob, level in it:
+            hashsplit.total_split += len(blob)
+            if opt.copy:
+                sys.stdout.write(str(blob))
+            megs = hashsplit.total_split // 1024 // 1024
+            if not opt.quiet and last != megs:
+                last = megs
 
-    is_reverse = environ.get(b'BUP_SERVER_REVERSE')
-    if is_reverse and opt.remote:
-        o.fatal("don't use -r in reverse mode; it's automatic")
-    start_time = time.time()
+    if opt.verbose:
+        log('\n')
+    if opt.tree:
+        out.write(hexlify(tree) + b'\n')
 
-    if opt.name and not valid_save_name(opt.name):
-        o.fatal("'%r' is not a valid branch name." % opt.name)
-    refname = opt.name and b'refs/heads/%s' % opt.name or None
+    commit = None
+    if opt.commit or opt.name:
+        msg = b'bup split\n\nGenerated by command:\n%r\n' % compat.get_argvb()
+        userline = b'%s <%s@%s>' % (userfullname(), username(), hostname())
+        commit = pack_writer.new_commit(tree, parent, userline, opt.date,
+                                        None, userline, opt.date, None, msg)
+        if opt.commit:
+            out.write(hexlify(commit) + b'\n')
 
-    if opt.noop or opt.copy:
-        cli = pack_writer = oldref = None
-    elif opt.remote or is_reverse:
-        git.check_repo_or_die()
-        cli = client.Client(opt.remote)
-        oldref = refname and cli.read_ref(refname) or None
-        pack_writer = cli.new_packwriter(compression_level=opt.compress,
-                                         max_pack_size=max_pack_size,
-                                         max_pack_objects=max_pack_objects)
-    else:
-        git.check_repo_or_die()
-        cli = None
-        oldref = refname and git.read_ref(refname) or None
-        pack_writer = git.PackWriter(compression_level=opt.compress,
-                                     max_pack_size=max_pack_size,
-                                     max_pack_objects=max_pack_objects)
+    return commit
 
-    input = byte_stream(sys.stdin)
+def main(argv):
+    opt = opts_from_cmdline(argv)
+    if opt.verbose >= 2:
+        git.verbose = opt.verbose - 1
+    if opt.fanout:
+        hashsplit.fanout = opt.fanout
+    if opt.blobs:
+        hashsplit.fanout = 0
+    if opt.bwlimit:
+        client.bwlimit = opt.bwlimit
+
+    start_time = time.time()
+
+    sys.stdout.flush()
+    out = byte_stream(sys.stdout)
+    stdin = byte_stream(sys.stdin)
 
     if opt.git_ids:
         # the input is actually a series of git object ids that we should retrieve
@@ -139,7 +199,7 @@ def main(argv):
                 return v or b''
         def read_ids():
             while 1:
-                line = input.readline()
+                line = stdin.readline()
                 if not line:
                     break
                 if line:
@@ -154,76 +214,50 @@ def main(argv):
         files = read_ids()
     else:
         # the input either comes from a series of files or from stdin.
-        files = extra and (open(argv_bytes(fn), 'rb') for fn in extra) or [input]
+        if opt.sources:
+            files = (open(argv_bytes(fn), 'rb') for fn in opt.sources)
+        else:
+            files = [stdin]
 
-    if pack_writer:
-        new_blob = pack_writer.new_blob
-        new_tree = pack_writer.new_tree
-    elif opt.blobs or opt.tree:
-        # --noop mode
-        new_blob = lambda content: git.calc_hash(b'blob', content)
-        new_tree = lambda shalist: git.calc_hash(b'tree', git.tree_encode(shalist))
+    writing = not (opt.noop or opt.copy)
+    remote_dest = opt.remote or opt.is_reverse
 
-    sys.stdout.flush()
-    out = byte_stream(sys.stdout)
+    if writing:
+        git.check_repo_or_die()
 
-    if opt.blobs:
-        shalist = hashsplit.split_to_blobs(new_blob, files,
-                                           keep_boundaries=opt.keep_boundaries,
-                                           progress=prog)
-        for (sha, size, level) in shalist:
-            out.write(hexlify(sha) + b'\n')
-            reprogress()
-    elif opt.tree or opt.commit or opt.name:
-        if opt.name: # insert dummy_name which may be used as a restore target
-            mode, sha = \
-                hashsplit.split_to_blob_or_tree(new_blob, new_tree, files,
-                                                keep_boundaries=opt.keep_boundaries,
-                                                progress=prog)
-            splitfile_name = git.mangle_name(b'data', hashsplit.GIT_MODE_FILE, mode)
-            shalist = [(mode, splitfile_name, sha)]
-        else:
-            shalist = hashsplit.split_to_shalist(
-                          new_blob, new_tree, files,
-                          keep_boundaries=opt.keep_boundaries, progress=prog)
-        tree = new_tree(shalist)
+    if remote_dest and writing:
+        cli = repo = client.Client(opt.remote)
     else:
-        last = 0
-        it = hashsplit.hashsplit_iter(files,
-                                      keep_boundaries=opt.keep_boundaries,
-                                      progress=prog)
-        for (blob, level) in it:
-            hashsplit.total_split += len(blob)
-            if opt.copy:
-                sys.stdout.write(str(blob))
-            megs = hashsplit.total_split // 1024 // 1024
-            if not opt.quiet and last != megs:
-                last = megs
-
-    if opt.verbose:
-        log('\n')
-    if opt.tree:
-        out.write(hexlify(tree) + b'\n')
-    if opt.commit or opt.name:
-        msg = b'bup split\n\nGenerated by command:\n%r\n' % compat.get_argvb()
-        ref = opt.name and (b'refs/heads/%s' % opt.name) or None
-        userline = b'%s <%s@%s>' % (userfullname(), username(), hostname())
-        commit = pack_writer.new_commit(tree, oldref, userline, date, None,
-                                        userline, date, None, msg)
-        if opt.commit:
-            out.write(hexlify(commit) + b'\n')
+        cli = nullcontext()
+        repo = git
 
-    if pack_writer:
-        pack_writer.close()  # must close before we can update the ref
+    # cli creation must be last nontrivial command in each if clause above
+    with cli:
+        if opt.name and writing:
+            refname = opt.name and b'refs/heads/%s' % opt.name
+            oldref = repo.read_ref(refname)
+        else:
+            refname = oldref = None
 
-    if opt.name:
-        if cli:
-            cli.update_ref(refname, commit, oldref)
+        if not writing:
+            pack_writer = NoOpPackWriter()
+        elif not remote_dest:
+            pack_writer = git.PackWriter(compression_level=opt.compress,
+                                         max_pack_size=opt.max_pack_size,
+                                         max_pack_objects=opt.max_pack_objects)
         else:
-            git.update_ref(refname, commit, oldref)
+            pack_writer = cli.new_packwriter(compression_level=opt.compress,
+                                             max_pack_size=opt.max_pack_size,
+                                             max_pack_objects=opt.max_pack_objects)
+
+        commit = split(opt, files, oldref, out, pack_writer)
+
+        if pack_writer:
+            pack_writer.close()
 
-    if cli:
-        cli.close()
+        # pack_writer must be closed before we can update the ref
+        if refname:
+            repo.update_ref(refname, commit, oldref)
 
     secs = time.time() - start_time
     size = hashsplit.total_split