root/GratiaWeb/db.py

Revision 1268, 10.9 kB (checked in by brian, 2 years ago)

Various framework updates.

Line 
1
2 import cherrypy, re, types
3
class IncorrectSchemaError(Exception):
    """Signals that the database does not hold the schema version required.

    Raised by ``DBConn.verify_schema`` when the requested version is not
    recorded for a schema name and the caller did not ask for it to be created.
    """
6
class SqlTranslator(object):
    """Registry of named SQL statements and versioned schemas, keyed by dialect.

    Statements and schemas registered without a dialect (``dialect=None``) act
    as defaults that are used whenever no dialect-specific entry exists.
    """

    # A schema may be handed over as one statement string or as a list of
    # statements.  ``basestring`` (Python 2) covers both ``str`` and
    # ``unicode``; where that name does not exist, plain ``str`` is used.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def __init__(self):
        # query name -> SQL text, for queries registered without a dialect
        self.default_sql = {}
        # dialect -> {query name -> SQL text}
        self.dialect_sql = {}
        # dialect (or None) -> {schema name -> {version -> [statements]}}
        self.schema = {}

    def register_sql(self, query_name, sql_str, dialect=None):
        """Register ``sql_str`` under ``query_name``.

        With ``dialect=None`` the statement becomes the default for that name;
        otherwise it is stored for the given dialect only.

        Raises ValueError if the name is already registered in the same scope.
        """
        if dialect is None:
            if query_name in self.default_sql:
                raise ValueError("Query name `%s` already registered." %
                                 query_name)
            self.default_sql[query_name] = sql_str
            return
        queries = self.dialect_sql.setdefault(dialect, {})
        if query_name in queries:
            raise ValueError("Query name `%s` already registered for "
                             "dialect `%s`." % (query_name, dialect))
        queries[query_name] = sql_str

    def register_schema(self, name, schema, version, dialect=None):
        """Store ``schema`` for (``dialect``, ``name``, ``version``).

        ``schema`` is either a single SQL statement string or a sequence of
        statements; a single string is wrapped in a one-element list so that
        consumers can always iterate over statements.  Registering the same
        key again silently replaces the previous entry.
        """
        versions = self.schema.setdefault(dialect, {}).setdefault(name, {})
        if isinstance(schema, self._string_types):
            schema = [schema]
        versions[version] = schema

    def get_schema(self, name, version, dialect=None):
        """Return the statement list for ``name``/``version``.

        The dialect-specific entry is preferred; the dialect-less default is
        used as a fallback.  Raises ValueError if neither exists.
        """
        schema = self.schema.get(dialect, {}).get(name, {}).get(version, None)
        if schema is None:
            schema = self.schema.get(None, {}).get(name, {}).get(version, None)
        if schema is None:
            raise ValueError("Schema name %s, version %s, dialect %s not "
                             "found!" % (name, version, dialect))
        return schema

    def get_sql(self, query_name, dialect=None):
        """Return the SQL text for ``query_name``.

        A statement registered for ``dialect`` wins; otherwise the default
        statement is returned.  Raises KeyError if no default exists either.
        """
        if dialect is not None:
            queries = self.dialect_sql.get(dialect, {})
            if query_name in queries:
                return queries[query_name]
        return self.default_sql[query_name]
53
# Module-wide SqlTranslator instance; DBConn delegates its SQL and schema
# registration/lookup to this object.
translator = SqlTranslator()
55
def get_config():
    """Return the configuration mapping to consult for the current request.

    The per-request configuration (``cherrypy.request.config``) is preferred
    when it is non-empty; otherwise the global ``cherrypy.config`` is returned.
    """
    request_config = cherrypy.request.config
    if not request_config:
        return cherrypy.config
    return request_config
61
62
def find_db():
    """Select the database for the current request and attach it to the request.

    The ``db.connections`` config entry maps regular expressions to connection
    strings; every expression is searched against the request's ``path_info``
    and the last one that matches (in dict iteration order) supplies the
    connection string.  If none matches, the ``db.connection`` entry is used.

    The resulting connection object is stored on ``cherrypy.request.db``.

    Raises ValueError if no connection string can be determined.
    """
    config = get_config()
    conn_str = None
    for pattern, candidate in config.get('db.connections', {}).items():
        if re.search(pattern, cherrypy.request.path_info):
            conn_str = candidate
    if conn_str is None:
        conn_str = config.get('db.connection', None)
    if conn_str is None:
        raise ValueError("DB Connection string not provided")
    # This used to call ``db.get_db(conn_str)``, but neither DBHandler nor
    # DBConn defines ``get_db``, so that call could only raise AttributeError.
    # Use the handler's per-connection-string cache instead, so the same
    # DBConn is shared with DBHandler.find_db.
    if conn_str not in db.db_conns:
        db.db_conns[conn_str] = DBConn(conn_str)
    cherrypy.request.db = db.db_conns[conn_str]
74
75
class DBHandler(object):
    """Facade that resolves, creates and caches DBConn objects.

    * ``handler[module]`` returns the connection configured for ``module``.
    * ``handler.some_attr`` is forwarded to the default connection, so the
      handler can be used as if it were that connection.
    """

    def __init__(self):
        # connection string -> DBConn, so each database is opened only once
        self.db_conns = {}

    def __getitem__(self, attr_name):
        """Return the connection configured for the module ``attr_name``."""
        return self.find_db(attr_name)

    def __getattr__(self, attr_name):
        """Forward unknown attribute lookups to the default connection."""
        return getattr(self.find_db(), attr_name)

    def find_db(self, module=None):
        """Return the DBConn to use for ``module`` in the current request.

        Config keys are looked up under ``db.<module>.`` (or ``db.`` when
        ``module`` is None).  ``<prefix>connections`` maps regular expressions
        to connection strings; each is searched against the request's
        ``path_info`` and the last match in dict iteration order wins.
        Without a match, ``<prefix>connection`` is used.  If a module has no
        configuration of its own, the default (module-less) connection is
        returned instead.

        Raises ValueError if no connection string is configured at all.
        """
        if module is not None:
            prefix = 'db.' + module + '.'
        else:
            prefix = 'db.'
        config = get_config()
        conn_str = None
        for pattern, candidate in config.get(prefix + 'connections', {}).items():
            if re.search(pattern, cherrypy.request.path_info):
                conn_str = candidate
        if conn_str is None:
            conn_str = config.get(prefix + 'connection', None)
        if conn_str is None:
            if module is None:
                # Nothing left to fall back to.  Recursing here, as the code
                # once did, only ran into the interpreter's recursion limit
                # before the same error was reported.
                raise ValueError("DB Connection string not provided")
            # Module-specific settings are absent: use the default database.
            # Errors from opening it propagate unchanged rather than being
            # reported as a missing connection string.
            return self.find_db()
        if conn_str not in self.db_conns:
            self.db_conns[conn_str] = DBConn(conn_str)
        return self.db_conns[conn_str]
110  
# Module-wide DBHandler; attribute access on it is forwarded to the default
# connection resolved for the current request.
db = DBHandler()

# Expose the handler's bound ``find_db`` method (not the module-level
# ``find_db`` function) as a CherryPy tool on the 'before_handler' hook.
cherrypy.tools.find_db = cherrypy.Tool('before_handler', db.find_db)
114
class DBConn(object):
    """A database connection selected by a ``<dialect>://<rest>`` string.

    The dialect picks the driver module and the ``connect``/``execute``
    implementations, which are bound as instance attributes in ``__init__``.
    Queries are referred to by the name under which they were registered with
    the module-level ``translator`` and use ``:name`` style bind variables.
    """

    # ``<dialect>://<dialect-specific part>``
    conn_re = re.compile(r'(\w*)://(.*)')
    # ``host[:port]/dbname?key=val&key=val``; the ``?...`` part is required.
    _mysql_re = re.compile(r'(.*)/(.*)(\?.*)')

    def __init__(self, conn_str):
        """Parse ``conn_str``, load the driver for its dialect and connect.

        Raises ValueError if ``conn_str`` is None or malformed, and
        NotImplementedError for dialects other than ``mysql`` and ``sqlite``.
        """
        if conn_str is None:
            raise ValueError("`db.connection` config parameter not passed!")
        m = self.conn_re.match(conn_str)
        if not m:
            raise ValueError("Unable to parse connection string `%s`" %
                             conn_str)
        proto, conn_str = m.groups()
        self.dialect = proto
        self.conn_str = conn_str
        if proto == 'mysql':
            # Validate the string before importing the driver so a malformed
            # string is reported as such even when MySQLdb is unavailable.
            self._parse_mysql(conn_str)
            self.execute = self._mysql_execute
            self.connect = self._mysql_connect
            self.module = __import__('MySQLdb')
        elif proto == 'sqlite':
            self.execute = self._sqlite_execute
            self.connect = self._sqlite_connect
            # Try the available SQLite bindings in the original order.
            try:
                import pysqlite2.dbapi2 as sqlite
            except ImportError:
                try:
                    import sqlite3 as sqlite
                except ImportError:
                    import sqlite
            self.module = sqlite
            self._get_conn = self._sqlite_get_conn
        else:
            # Covers 'oracle' as well as any unknown dialect; without this an
            # unknown dialect failed later with an AttributeError on
            # ``self.connect``.
            raise NotImplementedError("Unsupported database dialect `%s`" %
                                      proto)
        self.connect(conn_str)
        # schema name -> version already verified on this connection
        self._schema_vers = {}
        self.execute_sql = self.execute

    def register_sql(self, query_name, sql_str, dialect=None):
        """Register a named SQL statement with the shared translator."""
        translator.register_sql(query_name, sql_str, dialect)

    def _parse_mysql(self, conn_str):
        """Split a MySQL connection string into (host, dbname, params).

        ``params`` keeps its leading ``?``.  Raises ValueError when the string
        does not have the ``host/dbname?params`` shape.
        """
        m = self._mysql_re.match(conn_str)
        if not m:
            raise ValueError("String `%s` does not match format for MySQL"
                             % conn_str)
        return m.groups()

    def _mysql_connect(self, conn_str):
        """Open the MySQL connection described by ``conn_str``.

        ``key=val`` pairs after the ``?`` are passed to the driver's
        ``connect`` as keyword arguments; ``password`` is renamed to
        ``passwd``.
        """
        host, dbname, params = self._parse_mysql(conn_str)
        conn_args = {}
        if host.find(':') >= 0:
            host, port = host.split(':')
            conn_args['port'] = int(port)
        conn_args['host'] = host
        conn_args['db'] = dbname
        for kw in params[1:].split('&'):
            if not kw:
                # Tolerate a bare '?' and stray '&' separators.
                continue
            # Split on the first '=' only so values may contain '='.
            key, val = kw.split('=', 1)
            conn_args[key] = val
        if 'password' in conn_args:
            conn_args['passwd'] = conn_args['password']
            del conn_args['password']
        self._conn = self.module.connect(**conn_args)

    def _sqlite_connect(self, conn_str):
        """Open an SQLite connection and keep it on ``cherrypy.thread_data``."""
        cherrypy.thread_data._conn = self.module.connect(conn_str)

    def _sqlite_get_conn(self):
        """Return the connection held on ``cherrypy.thread_data``.

        A new connection is opened first when ``thread_data`` has none yet.
        """
        if not hasattr(cherrypy.thread_data, '_conn'):
            self.connect(self.conn_str)
        return cherrypy.thread_data._conn

    def _get_conn(self):
        """Return the connection object (overridden per instance for SQLite)."""
        return self._conn

    def _get_cursor(self):
        """Return a new cursor on the current connection."""
        return self._get_conn().cursor()

    def execute(self, query_name, bind_vars):
        """Placeholder; ``__init__`` binds the dialect-specific implementation."""
        raise NotImplementedError()

    @staticmethod
    def _bind_named(sql, sql_vars, marker):
        """Rewrite ``:name`` bind variables into positional placeholders.

        Every occurrence of ``:<key>`` for a key of ``sql_vars`` is replaced by
        ``marker``.  Returns ``(rewritten_sql, args)`` where ``args`` is a
        tuple holding the value for each placeholder in textual order.

        The rewrite is done in a single left-to-right pass.  Offsets recorded
        while the string was being edited could come out in the wrong order or
        collide, and a name that is a prefix of another (``:name`` versus
        ``:name2``) was replaced inside the longer one.
        """
        if not sql_vars:
            return sql, ()
        names = list(sql_vars.keys())
        # Longest first, so ':name2' is never matched as ':name' + '2'.
        names.sort(key=len, reverse=True)
        pattern = re.compile(
            ':(' + '|'.join([re.escape(name) for name in names]) + ')')
        args = []

        def _swap(match):
            args.append(sql_vars[match.group(1)])
            return marker

        return pattern.sub(_swap, sql), tuple(args)

    def _run_query(self, query_name, sql_vars, marker):
        """Look up, bind and execute a named query; return the cursor."""
        sql = translator.get_sql(query_name, self.dialect)
        sql, args = self._bind_named(sql, dict(sql_vars), marker)
        curs = self._get_cursor()
        curs.arraysize = 500
        curs.execute(sql, args)
        return curs

    def _mysql_execute(self, query_name, sql_vars):
        """Execute a named query using MySQLdb's ``%s`` placeholders."""
        return self._run_query(query_name, sql_vars, '%s')

    def _sqlite_execute(self, query_name, sql_vars):
        """Execute a named query using SQLite's ``?`` placeholders."""
        return self._run_query(query_name, sql_vars, '?')

    def verify_schema(self, name, version, create=False):
        """Ensure schema ``name`` is present at ``version`` on this connection.

        Returns True on success and remembers the result, so later calls for
        the same name and version return immediately.  If the bookkeeping
        table itself is missing (the driver reports "no such table"), the
        'schema_verify' schema is created first.

        With ``create=True`` a missing version is created; otherwise
        IncorrectSchemaError is raised, naming the version currently recorded
        when there is one.
        """
        if name in self._schema_vers and self._schema_vers[name] == version:
            return True
        # The lookup runs on this connection.  Going through the global
        # handler would query whichever database is the request's default,
        # while the schema is created on ``self``.
        try:
            curs = self.execute_sql("get_schema", {'name': name})
        except Exception:
            # sys.exc_info() instead of an ``except ... as`` clause keeps this
            # valid on both old Python 2 releases and Python 3.
            import sys
            err = sys.exc_info()[1]
            if str(err).find('no such table') >= 0:
                self.create_schema('schema_verify', '0.1', dialect=self.dialect)
                curs = self.execute_sql("get_schema", {'name': name})
            else:
                raise
        has_schema = False
        has_vers = ''
        for row in curs.fetchall():
            if row[0] == version:
                has_schema = True
                break
            # Remember a recorded version that differs, for the error message.
            has_vers = row[0]
        if not has_schema:
            if create:
                self.create_schema(name, version, dialect=self.dialect)
            elif has_vers != '':
                raise IncorrectSchemaError("Schema for %s is currently %s;"
                    " required version is %s." % (name, has_vers, version))
            else:
                raise IncorrectSchemaError("No schema for module %s found!"
                    % name)
        self._schema_vers[name] = version
        return True

    def create_schema(self, name, version, dialect=None):
        """Run the registered statements for a schema and record its version.

        Each statement is echoed to stdout before it is executed; the version
        is then stored via the 'update_schema' query and the work committed.
        """
        schema = self.get_schema(name, version, dialect)
        curs = self._get_cursor()
        for statement in schema:
            # Single-argument call form prints the same text on Python 2 and 3.
            print(statement)
            curs.execute(statement)
        self.execute_sql('update_schema', {'name': name, 'version': version})
        self.commit()

    def register_schema(self, name, schema, version, dialect=None):
        """Register a schema definition with the shared translator."""
        translator.register_schema(name, schema, version, dialect)

    def get_schema(self, name, version, dialect=None):
        """Fetch a schema definition from the shared translator."""
        return translator.get_schema(name, version, dialect)

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self._get_conn().commit()
295
# Statements that create the bookkeeping table used by DBConn.verify_schema:
# one (NAME, VERSION) row per recorded schema version, kept unique by an index.
# NOTE: this module-level list shares its name with the DBConn.verify_schema
# method but is unrelated to it.
verify_schema = ["""
CREATE TABLE Schema (
  NAME varchar(255),
  VERSION varchar(255)
)
""", """
CREATE UNIQUE INDEX Schema_unique on Schema (NAME, VERSION)
"""]
# Registered for the 'sqlite' dialect only, as schema 'schema_verify' v0.1.
translator.register_schema('schema_verify', verify_schema, '0.1', 'sqlite')

# Records a (name, version) pair; run by DBConn.create_schema after the
# schema's statements.  Registered for 'sqlite' only.
update_schema_sql = """
INSERT OR REPLACE INTO Schema(NAME, VERSION) VALUES
    (:name, :version)
"""
translator.register_sql('update_schema', update_schema_sql, 'sqlite')

# Fetches the recorded versions for a schema name; used by
# DBConn.verify_schema.  Registered for 'sqlite' only.
get_schema_sql = """
SELECT VERSION from Schema where NAME=:name
"""
translator.register_sql('get_schema', get_schema_sql, 'sqlite')
316
Note: See TracBrowser for help on using the browser.