Skip to content

Commit

Permalink
postgres-specific code removed (#123)
Browse files Browse the repository at this point in the history
* postgres-specific code removed

* add unit tests

* remove debug print

* qa

* fix test

* better handling of wdir

* qa

---------

Co-authored-by: Mattia Almansi <[email protected]>
  • Loading branch information
alex75 and malmans2 authored Jul 31, 2024
1 parent aaa38c2 commit b84ca42
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 36 deletions.
83 changes: 48 additions & 35 deletions cacholote/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _decode_kwargs(**kwargs: Any) -> dict[str, Any]:
def _cached_sessionmaker(
url: str, **kwargs: Any
) -> sa.orm.sessionmaker[sa.orm.Session]:
engine = sa.create_engine(url, **_decode_kwargs(**kwargs))
engine = init_database(url, **_decode_kwargs(**kwargs))
Base.metadata.create_all(engine)
return sa.orm.sessionmaker(engine)

Expand All @@ -120,44 +120,57 @@ def cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker[sa.orm.S
return _cached_sessionmaker(url, **_encode_kwargs(**kwargs))


def init_database(connection_string: str, force: bool = False) -> sa.engine.Engine:
def init_database(
connection_string: str, force: bool = False, **kwargs: Any
) -> sa.engine.Engine:
"""
Make sure the db located at URI `connection_string` exists updated and return the engine object.
:param connection_string: something like 'postgresql://user:password@netloc:port/dbname'
:param force: if True, drop the database structure and build again from scratch
Parameters
----------
connection_string: str
Something like 'postgresql://user:password@netloc:port/dbname'
force: bool
if True, drop the database structure and build again from scratch
kwargs: Any
Keyword arguments for create_engine
Returns
-------
engine: Engine
"""
engine = sa.create_engine(connection_string)
engine = sa.create_engine(connection_string, **kwargs)
migration_directory = os.path.abspath(os.path.join(__file__, ".."))
os.chdir(migration_directory)
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
alembic_cfg = alembic.config.Config(alembic_config_path)
for option in ["drivername", "username", "password", "host", "port", "database"]:
value = getattr(engine.url, option)
if value is None:
value = ""
alembic_cfg.set_main_option(option, str(value))
if not sqlalchemy_utils.database_exists(engine.url):
sqlalchemy_utils.create_database(engine.url)
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
else:
# check the structure is empty or incomplete
query = sa.text(
"SELECT table_name FROM information_schema.tables WHERE table_schema='public'"
)
conn = engine.connect()
if "cache_entries" not in conn.execute(query).scalars().all():
with utils.change_working_dir(migration_directory):
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
alembic_cfg = alembic.config.Config(alembic_config_path)
for option in [
"drivername",
"username",
"password",
"host",
"port",
"database",
]:
value = getattr(engine.url, option)
if value is None:
value = ""
alembic_cfg.set_main_option(option, str(value))
if not sqlalchemy_utils.database_exists(engine.url):
sqlalchemy_utils.create_database(engine.url)
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
elif "cache_entries" not in sa.inspect(engine).get_table_names():
# db structure is empty or incomplete
force = True
conn.close()
if force:
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
else:
# update db structure
alembic.command.upgrade(alembic_cfg, "head")
if force:
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
else:
# update db structure
alembic.command.upgrade(alembic_cfg, "head")
return engine
14 changes: 13 additions & 1 deletion cacholote/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
# limitations under the License.import hashlib
from __future__ import annotations

import contextlib
import dataclasses
import datetime
import functools
import hashlib
import io
import os
import time
import warnings
from types import TracebackType
from typing import Any
from typing import Any, Iterator

import fsspec

Expand Down Expand Up @@ -129,3 +131,13 @@ def __exit__(
def utcnow() -> datetime.datetime:
"""See https://discuss.python.org/t/deprecating-utcnow-and-utcfromtimestamp/26221."""
return datetime.datetime.now(tz=datetime.timezone.utc)


@contextlib.contextmanager
def change_working_dir(working_dir: str) -> Iterator[str]:
old_dir = os.getcwd()
os.chdir(working_dir)
try:
yield os.getcwd()
finally:
os.chdir(old_dir)
8 changes: 8 additions & 0 deletions tests/test_02_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
import pathlib

import fsspec
Expand Down Expand Up @@ -29,3 +30,10 @@ def test_copy_buffered_file(tmp_path: pathlib.Path) -> None:
with open(src, "rb") as f_src, open(dst, "wb") as f_dst:
utils.copy_buffered_file(f_src, f_dst)
assert open(src, "rb").read() == open(dst, "rb").read() == b"test"


def test_change_working_dir(tmp_path: pathlib.Path) -> None:
old_cwd = os.getcwd()
with utils.change_working_dir(str(tmp_path)) as actual:
assert actual == os.getcwd() == str(tmp_path.resolve())
assert os.getcwd() == old_cwd
22 changes: 22 additions & 0 deletions tests/test_70_alembic.py

Large diffs are not rendered by default.

0 comments on commit b84ca42

Please sign in to comment.