mirror of
https://github.com/laurivosandi/certidude
synced 2024-11-04 20:38:12 +00:00
104 lines
3.0 KiB
Python
104 lines
3.0 KiB
Python
|
|
|
|
import click
|
|
import re
|
|
import os
|
|
from urlparse import urlparse
|
|
|
|
SCRIPTS = {}
|
|
|
|
class RelationalMixin(object):
|
|
"""
|
|
Thin wrapper around SQLite and MySQL database connectors
|
|
"""
|
|
|
|
SQL_CREATE_TABLES = ""
|
|
|
|
def __init__(self, uri):
|
|
self.uri = urlparse(uri)
|
|
if self.SQL_CREATE_TABLES and self.SQL_CREATE_TABLES not in SCRIPTS:
|
|
conn = self.sql_connect()
|
|
cur = conn.cursor()
|
|
with open(self.sql_resolve_script(self.SQL_CREATE_TABLES)) as fh:
|
|
click.echo("Executing: %s" % fh.name)
|
|
if self.uri.scheme == "sqlite":
|
|
cur.executescript(fh.read())
|
|
else:
|
|
cur.execute(fh.read(), multi=True)
|
|
conn.commit()
|
|
cur.close()
|
|
conn.close()
|
|
|
|
|
|
def sql_connect(self):
|
|
if self.uri.scheme == "mysql":
|
|
import mysql.connector
|
|
return mysql.connector.connect(
|
|
user=self.uri.username,
|
|
password=self.uri.password,
|
|
host=self.uri.hostname,
|
|
database=self.uri.path[1:])
|
|
elif self.uri.scheme == "sqlite":
|
|
if self.uri.netloc:
|
|
raise ValueError("Malformed database URI %s" % self.uri)
|
|
import sqlite3
|
|
return sqlite3.connect(self.uri.path)
|
|
else:
|
|
raise NotImplementedError("Unsupported database scheme %s, currently only mysql://user:pass@host/database or sqlite:///path/to/database.sqlite is supported" % o.scheme)
|
|
|
|
|
|
def sql_resolve_script(self, filename):
|
|
return os.path.realpath(os.path.join(os.path.dirname(__file__),
|
|
"sql", self.uri.scheme, filename))
|
|
|
|
|
|
def sql_load(self, filename):
|
|
if filename in SCRIPTS:
|
|
return SCRIPTS[filename]
|
|
|
|
fh = open(self.sql_resolve_script(filename))
|
|
click.echo("Caching SQL script: %s" % fh.name)
|
|
buf = re.sub("\s*\n\s*", " ", fh.read())
|
|
SCRIPTS[filename] = buf
|
|
fh.close()
|
|
return buf
|
|
|
|
|
|
def sql_execute(self, script, *args):
|
|
conn = self.sql_connect()
|
|
cursor = conn.cursor()
|
|
click.echo("Executing %s with %s" % (script, args))
|
|
cursor.execute(self.sql_load(script), args)
|
|
rowid = cursor.lastrowid
|
|
conn.commit()
|
|
cursor.close()
|
|
conn.close()
|
|
return rowid
|
|
|
|
|
|
def iterfetch(self, query, *args):
|
|
conn = self.sql_connect()
|
|
cursor = conn.cursor()
|
|
cursor.execute(query, args)
|
|
cols = [j[0] for j in cursor.description]
|
|
def g():
|
|
for row in cursor:
|
|
yield dict(zip(cols, row))
|
|
cursor.close()
|
|
conn.close()
|
|
return tuple(g())
|
|
|
|
|
|
def sql_fetchone(self, query, *args):
|
|
conn = self.sql_connect()
|
|
cursor = conn.cursor()
|
|
cursor.execute(query, args)
|
|
cols = [j[0] for j in cursor.description]
|
|
|
|
for row in cursor:
|
|
r = dict(zip(cols, row))
|
|
cursor.close()
|
|
conn.close()
|
|
return r
|
|
return None
|