]> arthur.barton.de Git - bup.git/blob - lib/tornado/database.py
Always publish (l)utimes in helpers when available and fix type conversions.
[bup.git] / lib / tornado / database.py
1 #!/usr/bin/env python
2 #
3 # Copyright 2009 Facebook
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License"); you may
6 # not use this file except in compliance with the License. You may obtain
7 # a copy of the License at
8 #
9 #     http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14 # License for the specific language governing permissions and limitations
15 # under the License.
16
17 """A lightweight wrapper around MySQLdb."""
18
19 import copy
20 import MySQLdb.constants
21 import MySQLdb.converters
22 import MySQLdb.cursors
23 import itertools
24 import logging
25
26 class Connection(object):
27     """A lightweight wrapper around MySQLdb DB-API connections.
28
29     The main value we provide is wrapping rows in a dict/object so that
30     columns can be accessed by name. Typical usage:
31
32         db = database.Connection("localhost", "mydatabase")
33         for article in db.query("SELECT * FROM articles"):
34             print article.title
35
36     Cursors are hidden by the implementation, but other than that, the methods
37     are very similar to the DB-API.
38
39     We explicitly set the timezone to UTC and the character encoding to
40     UTF-8 on all connections to avoid time zone and encoding errors.
41     """
42     def __init__(self, host, database, user=None, password=None):
43         self.host = host
44         self.database = database
45
46         args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
47                     db=database, init_command='SET time_zone = "+0:00"',
48                     sql_mode="TRADITIONAL")
49         if user is not None:
50             args["user"] = user
51         if password is not None:
52             args["passwd"] = password
53
54         # We accept a path to a MySQL socket file or a host(:port) string
55         if "/" in host:
56             args["unix_socket"] = host
57         else:
58             self.socket = None
59             pair = host.split(":")
60             if len(pair) == 2:
61                 args["host"] = pair[0]
62                 args["port"] = int(pair[1])
63             else:
64                 args["host"] = host
65                 args["port"] = 3306
66
67         self._db = None
68         self._db_args = args
69         try:
70             self.reconnect()
71         except:
72             logging.error("Cannot connect to MySQL on %s", self.host,
73                           exc_info=True)
74
75     def __del__(self):
76         self.close()
77
78     def close(self):
79         """Closes this database connection."""
80         if getattr(self, "_db", None) is not None:
81             self._db.close()
82             self._db = None
83
84     def reconnect(self):
85         """Closes the existing database connection and re-opens it."""
86         self.close()
87         self._db = MySQLdb.connect(**self._db_args)
88         self._db.autocommit(True)
89
90     def iter(self, query, *parameters):
91         """Returns an iterator for the given query and parameters."""
92         if self._db is None: self.reconnect()
93         cursor = MySQLdb.cursors.SSCursor(self._db)
94         try:
95             self._execute(cursor, query, parameters)
96             column_names = [d[0] for d in cursor.description]
97             for row in cursor:
98                 yield Row(zip(column_names, row))
99         finally:
100             cursor.close()
101
102     def query(self, query, *parameters):
103         """Returns a row list for the given query and parameters."""
104         cursor = self._cursor()
105         try:
106             self._execute(cursor, query, parameters)
107             column_names = [d[0] for d in cursor.description]
108             return [Row(itertools.izip(column_names, row)) for row in cursor]
109         finally:
110             cursor.close()
111
112     def get(self, query, *parameters):
113         """Returns the first row returned for the given query."""
114         rows = self.query(query, *parameters)
115         if not rows:
116             return None
117         elif len(rows) > 1:
118             raise Exception("Multiple rows returned for Database.get() query")
119         else:
120             return rows[0]
121
122     def execute(self, query, *parameters):
123         """Executes the given query, returning the lastrowid from the query."""
124         cursor = self._cursor()
125         try:
126             self._execute(cursor, query, parameters)
127             return cursor.lastrowid
128         finally:
129             cursor.close()
130
131     def executemany(self, query, parameters):
132         """Executes the given query against all the given param sequences.
133
134         We return the lastrowid from the query.
135         """
136         cursor = self._cursor()
137         try:
138             cursor.executemany(query, parameters)
139             return cursor.lastrowid
140         finally:
141             cursor.close()
142
143     def _cursor(self):
144         if self._db is None: self.reconnect()
145         return self._db.cursor()
146
147     def _execute(self, cursor, query, parameters):
148         try:
149             return cursor.execute(query, parameters)
150         except OperationalError:
151             logging.error("Error connecting to MySQL on %s", self.host)
152             self.close()
153             raise
154
155
156 class Row(dict):
157     """A dict that allows for object-like property access syntax."""
158     def __getattr__(self, name):
159         try:
160             return self[name]
161         except KeyError:
162             raise AttributeError(name)
163
164
165 # Fix the access conversions to properly recognize unicode/binary
166 FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
167 FLAG = MySQLdb.constants.FLAG
168 CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions)
169
170 field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
171 if 'VARCHAR' in vars(FIELD_TYPE):
172     field_types.append(FIELD_TYPE.VARCHAR)
173
174 for field_type in field_types:
175     CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str))
176
177
178 # Alias some common MySQL exceptions
179 IntegrityError = MySQLdb.IntegrityError
180 OperationalError = MySQLdb.OperationalError