arthur.barton.de Git - bup.git/blob - lib/bup/repo.py
Replace LocalRepo/RemoteRepo __del__ with context management
[bup.git] / lib / bup / repo.py
1
2 from __future__ import absolute_import
3 from os.path import realpath
4 from functools import partial
5
6 from bup import client, git, vfs
7 from bup.compat import pending_raise
8
9
# Registry shared by every repository object in this process: maps a
# repository key (a resolved local path or a remote address string) to a
# small integer id.  _next_repo_id holds the most recently assigned id.
_next_repo_id = 0
_repo_ids = {}

def _repo_id(key):
    """Return the integer id for *key*.

    The first time a key is seen it is assigned the next id in sequence
    (starting at 1) and remembered; later calls with the same key return
    that same id.
    """
    global _next_repo_id
    try:
        return _repo_ids[key]
    except KeyError:
        pass
    _next_repo_id += 1
    _repo_ids[key] = _next_repo_id
    return _next_repo_id
21
class LocalRepo:
    """A repository stored in a directory on the local filesystem.

    Object data is read through the cat-pipe returned by git.cp().
    update_ref, rev_list and pack writing go to the bup.git helpers,
    each bound to this repository's directory.
    """

    def __init__(self, repo_dir=None):
        # Fall back to git.repo() when no directory is given.  The path is
        # passed through realpath, so the id assigned below is keyed on
        # the resolved path rather than on how the caller spelled it.
        path = realpath(repo_dir or git.repo())
        self.repo_dir = path
        self._cp = git.cp(path)
        self.update_ref = partial(git.update_ref, repo_dir=path)
        self.rev_list = partial(git.rev_list, repo_dir=path)
        self._id = _repo_id(path)

    def close(self):
        """Release repository resources.

        This is currently a no-op; nothing held by this object (including
        self._cp) is closed here.
        """
        return None

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close on exit from a with-block.  Any in-flight exception is
        # handed to bup.compat.pending_raise.  With rethrow=False it
        # presumably leaves propagation of that exception to the with
        # statement itself -- see bup.compat for the exact semantics.
        guard = pending_raise(value, rethrow=False)
        with guard:
            self.close()

    def id(self):
        """Return this repository's identifier.

        The value is distinct from that of any other repository that does
        not share the same repository-specific information (refs, tags,
        etc.)."""
        return self._id

    def is_remote(self):
        """Report whether this is a remote repository (never, here)."""
        return False

    def new_packwriter(self, compression_level=1,
                       max_pack_size=None, max_pack_objects=None):
        """Create a git.PackWriter that writes into this repository."""
        options = dict(compression_level=compression_level,
                       max_pack_size=max_pack_size,
                       max_pack_objects=max_pack_objects)
        return git.PackWriter(repo_dir=self.repo_dir, **options)

    def cat(self, ref):
        """Generate the header and contents of the object named by ref.

        The first item is always the (oidx, type, size) header.  When ref
        does not exist the header is (None, None, None) and nothing else
        follows.  Otherwise the object's data chunks follow the header.

        """
        pieces = self._cp.get(ref)
        header = next(pieces)
        oidx, _typ, _size = header
        yield header
        if oidx:
            for chunk in pieces:
                yield chunk
        # The cat-pipe iterator must be exhausted at this point.
        assert not next(pieces, None)

    def join(self, ref):
        """Return whatever the cat-pipe's join() produces for ref."""
        return self._cp.join(ref)

    def refs(self, patterns=None, limit_to_heads=False, limit_to_tags=False):
        """Generate this repository's refs via git.list_refs, filtered by
        the given patterns / head-only / tag-only options."""
        listing = git.list_refs(patterns=patterns,
                                limit_to_heads=limit_to_heads,
                                limit_to_tags=limit_to_tags,
                                repo_dir=self.repo_dir)
        for item in listing:
            yield item

    # This just forwards to vfs.resolve with self as the repository, so
    # the vfs must never call back into it.
    def resolve(self, path, parent=None, want_meta=True, follow=True):
        ## FIXME: mode_only=?
        opts = dict(parent=parent, want_meta=want_meta, follow=follow)
        return vfs.resolve(self, path, **opts)
85
86
class RemoteRepo:
    """A repository reached through a bup client connection.

    new_packwriter, update_ref and rev_list are the client.Client's own
    bound methods.  The remaining operations forward to that client.
    """

    def __init__(self, address):
        self.address = address
        conn = client.Client(address)
        self.client = conn
        # Expose these client operations directly as repository methods.
        self.new_packwriter = conn.new_packwriter
        self.update_ref = conn.update_ref
        self.rev_list = conn.rev_list
        self._id = _repo_id(address)

    def close(self):
        """Close the client connection.

        After the first successful call self.client is None, and further
        calls do nothing.
        """
        conn = self.client
        if not conn:
            return
        conn.close()
        self.client = None

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close on exit from a with-block.  Any in-flight exception is
        # handed to bup.compat.pending_raise.  With rethrow=False it
        # presumably leaves propagation of that exception to the with
        # statement itself -- see bup.compat for the exact semantics.
        guard = pending_raise(value, rethrow=False)
        with guard:
            self.close()

    def id(self):
        """Return this repository's identifier.

        The value is distinct from that of any other repository that does
        not share the same repository-specific information (refs, tags,
        etc.)."""
        return self._id

    def is_remote(self):
        """Report whether this is a remote repository (always, here)."""
        return True

    def cat(self, ref):
        """Generate the header and contents of the object named by ref.

        The first item is always the (oidx, type, size) header.  When ref
        does not exist the header is (None, None, None) and nothing else
        follows.  Otherwise the object's data chunks follow the header.

        """
        # Stay suspended inside the cat_batch iterator until every data
        # chunk has been handed out.  Letting cat_batch finish earlier
        # would trigger its cleanup before the data was read, leaving us
        # out of sync with the server.
        batch = self.client.cat_batch((ref,))
        entry = next(batch)
        oidx, _typ, _size, chunks = entry
        yield entry[:-1]
        if oidx:
            for chunk in chunks:
                yield chunk
        # A single-ref batch must produce exactly one entry.
        assert not next(batch, None)

    def join(self, ref):
        """Return whatever the client's join() produces for ref."""
        return self.client.join(ref)

    def refs(self, patterns=None, limit_to_heads=False, limit_to_tags=False):
        """Generate the remote repository's refs as reported by the
        client, filtered by the given patterns / head-only / tag-only
        options."""
        listing = self.client.refs(patterns=patterns,
                                   limit_to_heads=limit_to_heads,
                                   limit_to_tags=limit_to_tags)
        for item in listing:
            yield item

    def resolve(self, path, parent=None, want_meta=True, follow=True):
        ## FIXME: mode_only=?
        opts = dict(parent=parent, want_meta=want_meta, follow=follow)
        return self.client.resolve(path, **opts)