Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

postgres-specific code removed #123

Merged
merged 8 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 48 additions & 35 deletions cacholote/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _decode_kwargs(**kwargs: Any) -> dict[str, Any]:
def _cached_sessionmaker(
url: str, **kwargs: Any
) -> sa.orm.sessionmaker[sa.orm.Session]:
engine = sa.create_engine(url, **_decode_kwargs(**kwargs))
engine = init_database(url, **_decode_kwargs(**kwargs))
Base.metadata.create_all(engine)
return sa.orm.sessionmaker(engine)

Expand All @@ -120,44 +120,57 @@ def cached_sessionmaker(url: str, **kwargs: Any) -> sa.orm.sessionmaker[sa.orm.S
return _cached_sessionmaker(url, **_encode_kwargs(**kwargs))


def init_database(connection_string: str, force: bool = False) -> sa.engine.Engine:
def init_database(
connection_string: str, force: bool = False, **kwargs: Any
) -> sa.engine.Engine:
"""
Make sure the db located at URI `connection_string` exists updated and return the engine object.

:param connection_string: something like 'postgresql://user:password@netloc:port/dbname'
:param force: if True, drop the database structure and build again from scratch
Parameters
----------
connection_string: str
Something like 'postgresql://user:password@netloc:port/dbname'
force: bool
if True, drop the database structure and build again from scratch
kwargs: Any
Keyword arguments for create_engine

Returns
-------
engine: Engine
"""
engine = sa.create_engine(connection_string)
engine = sa.create_engine(connection_string, **kwargs)
migration_directory = os.path.abspath(os.path.join(__file__, ".."))
os.chdir(migration_directory)
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
alembic_cfg = alembic.config.Config(alembic_config_path)
for option in ["drivername", "username", "password", "host", "port", "database"]:
value = getattr(engine.url, option)
if value is None:
value = ""
alembic_cfg.set_main_option(option, str(value))
if not sqlalchemy_utils.database_exists(engine.url):
sqlalchemy_utils.create_database(engine.url)
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
else:
# check the structure is empty or incomplete
query = sa.text(
"SELECT table_name FROM information_schema.tables WHERE table_schema='public'"
)
conn = engine.connect()
if "cache_entries" not in conn.execute(query).scalars().all():
with utils.change_working_dir(migration_directory):
alembic_config_path = os.path.join(migration_directory, "alembic.ini")
alembic_cfg = alembic.config.Config(alembic_config_path)
for option in [
"drivername",
"username",
"password",
"host",
"port",
"database",
]:
value = getattr(engine.url, option)
if value is None:
value = ""
alembic_cfg.set_main_option(option, str(value))
if not sqlalchemy_utils.database_exists(engine.url):
sqlalchemy_utils.create_database(engine.url)
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
elif "cache_entries" not in sa.inspect(engine).get_table_names():
# db structure is empty or incomplete
force = True
conn.close()
if force:
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
else:
# update db structure
alembic.command.upgrade(alembic_cfg, "head")
if force:
# cleanup and create the schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
alembic.command.stamp(alembic_cfg, "head")
else:
# update db structure
alembic.command.upgrade(alembic_cfg, "head")
return engine
14 changes: 13 additions & 1 deletion cacholote/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
# limitations under the License.import hashlib
from __future__ import annotations

import contextlib
import dataclasses
import datetime
import functools
import hashlib
import io
import os
import time
import warnings
from types import TracebackType
from typing import Any
from typing import Any, Iterator

import fsspec

Expand Down Expand Up @@ -129,3 +131,13 @@ def __exit__(
def utcnow() -> datetime.datetime:
"""See https://discuss.python.org/t/deprecating-utcnow-and-utcfromtimestamp/26221."""
return datetime.datetime.now(tz=datetime.timezone.utc)


@contextlib.contextmanager
def change_working_dir(working_dir: str) -> Iterator[str]:
old_dir = os.getcwd()
os.chdir(working_dir)
try:
yield os.getcwd()
finally:
os.chdir(old_dir)
8 changes: 8 additions & 0 deletions tests/test_02_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
import pathlib

import fsspec
Expand Down Expand Up @@ -29,3 +30,10 @@ def test_copy_buffered_file(tmp_path: pathlib.Path) -> None:
with open(src, "rb") as f_src, open(dst, "wb") as f_dst:
utils.copy_buffered_file(f_src, f_dst)
assert open(src, "rb").read() == open(dst, "rb").read() == b"test"


def test_change_working_dir(tmp_path: pathlib.Path) -> None:
old_cwd = os.getcwd()
with utils.change_working_dir(str(tmp_path)) as actual:
assert actual == os.getcwd() == str(tmp_path.resolve())
assert os.getcwd() == old_cwd
22 changes: 22 additions & 0 deletions tests/test_70_alembic.py

Large diffs are not rendered by default.