diff --git a/kdoorweb/kdoorweb/db.py b/kdoorweb/kdoorweb/db.py index 8a614e1..aac736d 100644 --- a/kdoorweb/kdoorweb/db.py +++ b/kdoorweb/kdoorweb/db.py @@ -51,7 +51,10 @@ class DB: @staticmethod def create_db(dbfile): - db = sqlite3.connect(dbfile) + if type(dbfile) is DB: + db = dbfile.db + else: + db = sqlite3.connect(dbfile) db.executescript(""" create table versions ( version integer, diff --git a/kdoorweb/tests/test_db.py b/kdoorweb/tests/test_db.py index cf213db..cbfb067 100644 --- a/kdoorweb/tests/test_db.py +++ b/kdoorweb/tests/test_db.py @@ -4,9 +4,15 @@ from kdoorweb.db import DB class TestDB(unittest.TestCase): + def setUp(self) -> None: + self.db = DB() + def test_create_db_in_memory(self): DB.create_db(dbfile=":memory:") + def test_create_db_in_connection(self): + DB.create_db(dbfile=self.db) + if __name__ == '__main__': unittest.main()