# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""A lightweight wrapper around MySQLdb."""
import copy
import logging

import MySQLdb.constants
import MySQLdb.converters
import MySQLdb.cursors
class Connection(object):
    """A lightweight wrapper around MySQLdb DB-API connections.

    The main value we provide is wrapping rows in a dict/object so that
    columns can be accessed by name. Typical usage:

        db = database.Connection("localhost", "mydatabase")
        for article in db.query("SELECT * FROM articles"):
            print(article.title)

    Cursors are hidden by the implementation, but other than that, the methods
    are very similar to the DB-API.

    We explicitly set the timezone to UTC and the character encoding to
    UTF-8 on all connections to avoid time zone and encoding errors.

    Connections are put in autocommit mode (see reconnect), so each statement
    is committed as soon as it runs.
    """

    def __init__(self, host, database, user=None, password=None):
        """Configure the connection parameters and try to connect.

        Args:
            host: a path to a MySQL unix socket file (any string containing
                "/"), or a "hostname" / "hostname:port" string.
            database: name of the database to select.
            user: optional MySQL user name.
            password: optional MySQL password.

        A failure to connect here is logged, not raised: the object is still
        constructed, and the next query attempts to connect again.
        """
        self.host = host
        self.database = database

        args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
                    db=database, init_command='SET time_zone = "+0:00"',
                    sql_mode="TRADITIONAL")
        if user is not None:
            args["user"] = user
        if password is not None:
            args["passwd"] = password

        # We accept a path to a MySQL socket file or a host(:port) string
        if "/" in host:
            args["unix_socket"] = host
        else:
            pair = host.split(":")
            if len(pair) == 2:
                args["host"] = pair[0]
                args["port"] = int(pair[1])
            else:
                args["host"] = host
                args["port"] = 3306  # MySQL's standard TCP port

        self._db = None
        self._db_args = args
        try:
            self.reconnect()
        except Exception:
            # Best effort: _db stays None, so _cursor()/iter() retry later.
            logging.error("Cannot connect to MySQL on %s", self.host,
                          exc_info=True)

    def close(self):
        """Closes this database connection."""
        # getattr guards against a partially-initialised instance.
        if getattr(self, "_db", None) is not None:
            self._db.close()
            self._db = None

    def reconnect(self):
        """Closes the existing database connection and re-opens it."""
        self.close()
        self._db = MySQLdb.connect(**self._db_args)
        self._db.autocommit(True)

    def iter(self, query, *parameters):
        """Returns an iterator for the given query and parameters.

        Rows are streamed through a MySQLdb SSCursor instead of being
        collected into a list. This is a generator, so the query only runs
        once the first row is requested.
        """
        if self._db is None:
            self.reconnect()
        cursor = MySQLdb.cursors.SSCursor(self._db)
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            for row in cursor:
                yield Row(zip(column_names, row))
        finally:
            # Also runs if the caller stops iterating early (generator close).
            cursor.close()

    def query(self, query, *parameters):
        """Returns a row list for the given query and parameters."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            # Builtin zip (rather than Python-2-only itertools.izip) keeps
            # this consistent with iter() and portable.
            return [Row(zip(column_names, row)) for row in cursor]
        finally:
            cursor.close()

    def get(self, query, *parameters):
        """Returns the first row returned for the given query.

        Returns None when the query yields no rows and raises Exception when
        it yields more than one.
        """
        rows = self.query(query, *parameters)
        if not rows:
            return None
        if len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        return rows[0]

    def execute(self, query, *parameters):
        """Executes the given query, returning the lastrowid from the query."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def executemany(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        cursor = self._cursor()
        try:
            cursor.executemany(query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def _cursor(self):
        """Return a new cursor, (re)connecting first if there is no connection."""
        if self._db is None:
            self.reconnect()
        return self._db.cursor()

    def _execute(self, cursor, query, parameters):
        """Run query on cursor; on OperationalError drop the connection and re-raise."""
        try:
            return cursor.execute(query, parameters)
        except OperationalError:
            logging.error("Error connecting to MySQL on %s", self.host)
            # Closing sets _db to None, so the next call reconnects; the
            # original error still propagates to the caller.
            self.close()
            raise
class Row(dict):
    """A dict that allows for object-like property access syntax.

    ``row.name`` returns ``row["name"]``. A missing key raises AttributeError
    rather than KeyError, so ``hasattr(row, name)`` and
    ``getattr(row, name, default)`` work as expected. Because __getattr__ is
    only consulted after normal lookup fails, dict methods (e.g. ``keys``)
    shadow columns with the same name; use item access for those.
    """
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)
# Fix the access conversions to properly recognize unicode/binary.
#
# Work on a deep copy of MySQLdb's shared conversion table, so the changes
# below only affect connections that pass conv=CONVERSIONS.
CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions)
FLAG = MySQLdb.constants.FLAG
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE

# String-like column types. VARCHAR is included only when this MySQLdb's
# FIELD_TYPE defines it.
field_types = ([FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING] +
               ([FIELD_TYPE.VARCHAR] if 'VARCHAR' in vars(FIELD_TYPE) else []))

# Prepend a (FLAG.BINARY, str) rule to each type's converter list, ahead of
# the existing entries. Presumably MySQLdb tries these (flag, converter) pairs
# in list order, so BINARY-flagged columns go through `str` -- confirm against
# the installed MySQLdb version.
for field_type in field_types:
    CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str))
# Alias some common MySQL exceptions at module level so they can be caught
# via this module. OperationalError is also the error Connection._execute
# catches.
IntegrityError, OperationalError = (MySQLdb.IntegrityError,
                                    MySQLdb.OperationalError)