]> arthur.barton.de Git - bup.git/blob - lib/bup/repo.py
index.Reader.__del__: replace with context management
[bup.git] / lib / bup / repo.py
1
2 from __future__ import absolute_import
3 from os.path import realpath
4 from functools import partial
5
6 from bup import client, git, vfs
7 from bup.compat import pending_raise
8
# Process-wide registry mapping a repository key (local realpath or
# remote address) to a small integer id.
_next_repo_id = 0
_repo_ids = {}

def _repo_id(key):
    """Return the integer id for key, allocating a new one on first use.

    Ids are handed out sequentially starting at 1, and a key keeps the
    same id for the rest of the process (entries are never removed).
    """
    global _next_repo_id  # rebound below; _repo_ids is only mutated
    repo_id = _repo_ids.get(key)
    if repo_id is not None:
        return repo_id
    _next_repo_id += 1
    _repo_ids[key] = _next_repo_id
    return _next_repo_id
20
class LocalRepo:
    """Repository interface backed by a git repository on local disk."""

    def __init__(self, repo_dir=None):
        # Use the default repository when none is given; canonicalizing
        # with realpath means paths naming the same directory share an id.
        path = realpath(repo_dir or git.repo())
        self.repo_dir = path
        self._cp = git.cp(path)
        self.update_ref = partial(git.update_ref, repo_dir=path)
        self.rev_list = partial(git.rev_list, repo_dir=path)
        self._id = _repo_id(path)

    def close(self):
        """No-op; lets a LocalRepo be closed and used in a with-statement
        just like a RemoteRepo."""

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        with pending_raise(value, rethrow=False):
            self.close()

    def id(self):
        """Return a value that is distinct from the id of every other
        repository that does not share this one's repository-specific
        information (refs, tags, etc.)."""
        return self._id

    def is_remote(self):
        """Local repositories are never remote."""
        return False

    def new_packwriter(self, compression_level=1,
                       max_pack_size=None, max_pack_objects=None):
        """Return a git.PackWriter that writes into this repository."""
        options = dict(compression_level=compression_level,
                       max_pack_size=max_pack_size,
                       max_pack_objects=max_pack_objects)
        return git.PackWriter(repo_dir=self.repo_dir, **options)

    def cat(self, ref):
        """Generate the object named by ref.

        The first item is an (oidx, type, size) triple, which is
        (None, None, None) when ref does not exist.  When it does
        exist, the object's data chunks follow.
        """
        chunks = self._cp.get(ref)
        header = next(chunks)
        oidx, _, _ = header
        yield header
        if oidx:
            for chunk in chunks:
                yield chunk
        # Nothing may remain in the underlying iterator at this point.
        assert not next(chunks, None)

    def join(self, ref):
        """Delegate to the cat pipe's join() for ref."""
        return self._cp.join(ref)

    def refs(self, patterns=None, limit_to_heads=False, limit_to_tags=False):
        """Generate the refs git.list_refs() reports for this repository."""
        listing = git.list_refs(patterns=patterns,
                                limit_to_heads=limit_to_heads,
                                limit_to_tags=limit_to_tags,
                                repo_dir=self.repo_dir)
        for ref in listing:
            yield ref

    # The vfs itself must not call this, since it hands straight back
    # to vfs.resolve().
    def resolve(self, path, parent=None, want_meta=True, follow=True):
        """Resolve path in this repository via vfs.resolve()."""
        ## FIXME: mode_only=?
        return vfs.resolve(self, path,
                           parent=parent, want_meta=want_meta, follow=follow)
87
88
class RemoteRepo:
    """Repository interface that talks to a bup server via client.Client."""

    def __init__(self, address):
        self.address = address
        # Set before connecting so that close() (and therefore __del__)
        # is safe even if client.Client() raises and leaves the object
        # only partially initialized.
        self.client = None
        self.client = client.Client(address)
        # These are bound methods of the client; the repo simply
        # re-exports them.
        self.new_packwriter = self.client.new_packwriter
        self.update_ref = self.client.update_ref
        self.rev_list = self.client.rev_list
        self._id = _repo_id(self.address)

    def close(self):
        """Close the client connection; further calls are no-ops."""
        # getattr() also covers instances whose __init__ never ran.
        conn = getattr(self, 'client', None)
        if conn is not None:
            conn.close()
            self.client = None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        with pending_raise(value, rethrow=False):
            self.close()

    def id(self):
        """Return an identifier that differs from any other repository that
        doesn't share the same repository-specific information
        (e.g. refs, tags, etc.)."""
        return self._id

    def is_remote(self):
        """Remote repositories are always remote."""
        return True

    def cat(self, ref):
        """If ref does not exist, yield (None, None, None).  Otherwise yield
        (oidx, type, size), and then all of the data associated with
        ref.

        """
        # Yield all the data here so that we don't finish the
        # cat_batch iterator (triggering its cleanup) until all of the
        # data has been read.  Otherwise we'd be out of sync with the
        # server.
        items = self.client.cat_batch((ref,))
        oidx, _, _, data = info = next(items)
        yield info[:-1]
        if oidx:
            for chunk in data:
                yield chunk
        assert not next(items, None)

    def join(self, ref):
        """Delegate to the client's join() for ref."""
        return self.client.join(ref)

    def refs(self, patterns=None, limit_to_heads=False, limit_to_tags=False):
        """Generate the refs the server reports for the given filters."""
        for ref in self.client.refs(patterns=patterns,
                                    limit_to_heads=limit_to_heads,
                                    limit_to_tags=limit_to_tags):
            yield ref

    def resolve(self, path, parent=None, want_meta=True, follow=True):
        """Resolve path on the server via the client."""
        ## FIXME: mode_only=?
        return self.client.resolve(path, parent=parent, want_meta=want_meta,
                                   follow=follow)