allow getting disk usage from database
malmans2 committed Sep 4, 2024
1 parent 43d9e61 commit 8359c93
Showing 3 changed files with 155 additions and 103 deletions.
2 changes: 1 addition & 1 deletion cacholote/cache.py
@@ -82,7 +82,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return _decode_and_update(session, cache_entry, settings)
except decode.DecodeError as ex:
warnings.warn(str(ex), UserWarning)
clean._delete_cache_entry(session, cache_entry)
clean._delete_cache_entries(session, cache_entry)

result = func(*args, **kwargs)
cache_entry = database.CacheEntry(
191 changes: 102 additions & 89 deletions cacholote/clean.py
@@ -20,6 +20,7 @@
import posixpath
from typing import Any, Callable, Literal, Optional

import fsspec
import pydantic
import sqlalchemy as sa
import sqlalchemy.orm
@@ -35,7 +36,9 @@
)


def _get_files_from_cache_entry(cache_entry: database.CacheEntry) -> dict[str, str]:
def _get_files_from_cache_entry(
cache_entry: database.CacheEntry, key: str | None
) -> dict[str, Any]:
result = cache_entry.result
if not isinstance(result, (list, tuple, set)):
result = [result]
@@ -48,27 +51,57 @@ def _get_files_from_cache_entry(cache_entry: database.CacheEntry) -> dict[str, str]:
and obj["callable"] in FILE_RESULT_CALLABLES
):
fs, urlpath = extra_encoders._get_fs_and_urlpath(*obj["args"][:2])
files[fs.unstrip_protocol(urlpath)] = obj["args"][0]["type"]
value = obj["args"][0]
if key is not None:
value = value[key]
files[fs.unstrip_protocol(urlpath)] = value
return files
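
For orientation, the snippet below mocks the per-file metadata that a file-returning cache entry carries — "type" and "file:size" are the keys read elsewhere in this module, while the urlpath and values are invented — and spells out what the reworked helper returns for each choice of `key`.

```python
# Hypothetical payload of one file result; only "type" and "file:size" are
# keys actually read elsewhere in clean.py -- the values here are invented.
file_metadata = {"type": "application/vnd+zarr", "file:size": 20_971_520}
urlpath = "s3://bucket/cache/ab/1234.zarr"  # invented

# What _get_files_from_cache_entry(entry, key=...) would map urlpath to:
by_type = {urlpath: file_metadata["type"]}       # key="type"       (deletion branch)
by_size = {urlpath: file_metadata["file:size"]}  # key="file:size"  (disk usage)
whole = {urlpath: file_metadata}                 # key=None         (full metadata)
print(by_type, by_size, whole, sep="\n")
```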


def _delete_cache_entry(
session: sa.orm.Session, cache_entry: database.CacheEntry
def _remove_files(
fs: fsspec.AbstractFileSystem,
files: list[str],
max_tries: int = 10,
**kwargs: Any,
) -> None:
fs, _ = utils.get_cache_files_fs_dirname()
files_to_delete = _get_files_from_cache_entry(cache_entry)
logger = config.get().logger
assert max_tries >= 1
if not files:
return

config.get().logger.info("deleting files", n_files_to_delete=len(files), **kwargs)

n_tries = 0
while files:
n_tries += 1
try:
fs.rm(files, **kwargs)
return
except FileNotFoundError:
# Another concurrent process might have deleted files
if n_tries >= max_tries:
raise
files = [file for file in files if fs.exists(file)]

# First, delete database entry
logger.info("deleting cache entry", cache_entry=cache_entry)
session.delete(cache_entry)

def _delete_cache_entries(
session: sa.orm.Session, *cache_entries: database.CacheEntry
) -> None:
fs, _ = utils.get_cache_files_fs_dirname()
files_to_delete = []
dirs_to_delete = []
for cache_entry in cache_entries:
session.delete(cache_entry)

files = _get_files_from_cache_entry(cache_entry, key="type")
for file, file_type in files.items():
if file_type == "application/vnd+zarr":
dirs_to_delete.append(file)
else:
files_to_delete.append(file)
database._commit_or_rollback(session)

# Then, delete files
for urlpath, file_type in files_to_delete.items():
if fs.exists(urlpath):
logger.info("deleting cache file", urlpath=urlpath)
fs.rm(urlpath, recursive=file_type == "application/vnd+zarr")
_remove_files(fs, files_to_delete, recursive=False)
_remove_files(fs, dirs_to_delete, recursive=True)
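
As a rough usage sketch (not part of the commit), the new module-level `_remove_files` helper can be exercised directly against a local fsspec filesystem. The temporary file below is invented, and the default cacholote configuration is assumed so that `config.get().logger` resolves inside the helper.

```python
import tempfile

import fsspec

from cacholote import clean

fs = fsspec.filesystem("file")
tmpdir = tempfile.mkdtemp()
stale = f"{tmpdir}/stale.nc"
fs.touch(stale)

# Plain files go in one non-recursive call; zarr stores (directories) would go
# in a second call with recursive=True, mirroring _delete_cache_entries above.
clean._remove_files(fs, [stale], recursive=False)
print(fs.exists(stale))  # False -- a FileNotFoundError raised by a concurrent
                         # cleaner would have been retried up to max_tries times
```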


def delete(func_to_del: str | Callable[..., Any], *args: Any, **kwargs: Any) -> None:
@@ -88,25 +121,25 @@ def delete(func_to_del: str | Callable[..., Any], *args: Any, **kwargs: Any) -> None:
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(database.CacheEntry.key == hexdigest)
):
_delete_cache_entry(session, cache_entry)
_delete_cache_entries(session, cache_entry)


class _Cleaner:
def __init__(self) -> None:
def __init__(self, depth: int, use_database: bool) -> None:
self.logger = config.get().logger
self.fs, self.dirname = utils.get_cache_files_fs_dirname()

urldir = self.fs.unstrip_protocol(self.dirname)
self.urldir = self.fs.unstrip_protocol(self.dirname)

self.logger.info("getting disk usage")
self.file_sizes: dict[str, int] = collections.defaultdict(int)
for path, size in self.fs.du(self.dirname, total=False).items():
du = self.known_files if use_database else self.fs.du(self.dirname, total=False)
for path, size in du.items():
# Group dirs
urlpath = self.fs.unstrip_protocol(path)
basename, *_ = urlpath.replace(urldir, "", 1).strip("/").split("/")
if basename:
self.file_sizes[posixpath.join(urldir, basename)] += size

parts = urlpath.replace(self.urldir, "", 1).strip("/").split("/")
if parts:
self.file_sizes[posixpath.join(self.urldir, *parts[:depth])] += size
self.disk_usage = sum(self.file_sizes.values())
self.log_disk_usage()
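
A worked example of the new grouping logic, using invented paths: the loop body mirrors the code above, so the only assumption is the made-up cache layout.

```python
import collections
import posixpath

urldir = "s3://bucket/cache"  # invented cache location
du = {
    "s3://bucket/cache/store.zarr/x/0.0": 100,
    "s3://bucket/cache/store.zarr/x/0.1": 150,
    "s3://bucket/cache/data.nc": 200,
}

for depth in (1, 2):
    file_sizes: dict[str, int] = collections.defaultdict(int)
    for urlpath, size in du.items():
        # Group paths by their first `depth` components under urldir.
        parts = urlpath.replace(urldir, "", 1).strip("/").split("/")
        file_sizes[posixpath.join(urldir, *parts[:depth])] += size
    print(depth, dict(file_sizes))
# depth=1 -> {'s3://bucket/cache/store.zarr': 250, 's3://bucket/cache/data.nc': 200}
# depth=2 -> {'s3://bucket/cache/store.zarr/x': 250, 's3://bucket/cache/data.nc': 200}
```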

@@ -121,6 +154,16 @@ def log_disk_usage(self) -> None:
def stop_cleaning(self, maxsize: int) -> bool:
return self.disk_usage <= maxsize

@property
def known_files(self) -> dict[str, int]:
known_files: dict[str, int] = {}
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(sa.select(database.CacheEntry)):
known_files.update(
_get_files_from_cache_entry(cache_entry, key="file:size")
)
return known_files
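
Because `known_files` is assembled purely from each entry's "file:size" metadata, a cleaner built with `use_database=True` never calls `fs.du`; the toy aggregation below (invented entries and sizes) shows the idea — anything not recorded in the database is simply not counted.

```python
# Invented per-entry listings, shaped like the dicts returned by
# _get_files_from_cache_entry(entry, key="file:size").
entry_a = {"s3://bucket/cache/ab/first.nc": 1_000_000}
entry_b = {"s3://bucket/cache/cd/second.zarr": 25_000_000}

known_files: dict[str, int] = {}
for files in (entry_a, entry_b):
    known_files.update(files)

disk_usage = sum(known_files.values())
print(disk_usage)  # 26000000 -- computed without touching the filesystem
```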

def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:
self.logger.info("getting unknown files")

@@ -138,25 +181,15 @@ def get_unknown_files(self, lock_validity_period: float | None) -> set[str]:
locked_files.add(urlpath)
locked_files.add(urlpath.rsplit(".lock", 1)[0])

if unknown_files := (set(self.file_sizes) - locked_files):
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(sa.select(database.CacheEntry)):
for known_file in _get_files_from_cache_entry(cache_entry):
unknown_files.discard(known_file)
if not unknown_files:
break
return unknown_files
return set(self.file_sizes) - locked_files - set(self.known_files)

def delete_unknown_files(
self, lock_validity_period: float | None, recursive: bool
) -> None:
unknown_files = self.get_unknown_files(lock_validity_period)
for urlpath in unknown_files:
self.pop_file_size(urlpath)
self.remove_files(
list(unknown_files),
recursive=recursive,
)
_remove_files(self.fs, list(unknown_files), recursive=recursive)
self.log_disk_usage()

@staticmethod
@@ -208,30 +241,6 @@ def _get_method_sorters(
sorters.append(database.CacheEntry.expiration)
return sorters

def remove_files(
self,
files: list[str],
max_tries: int = 10,
**kwargs: Any,
) -> None:
assert max_tries >= 1
if not files:
return

self.logger.info("deleting files", n_files_to_delete=len(files), **kwargs)

n_tries = 0
while files:
n_tries += 1
try:
self.fs.rm(files, **kwargs)
return
except FileNotFoundError:
# Another concurrent process might have deleted files
if n_tries >= max_tries:
raise
files = [file for file in files if self.fs.exists(file)]

def delete_cache_files(
self,
maxsize: int,
@@ -245,37 +254,27 @@ def delete_cache_files(
if self.stop_cleaning(maxsize):
return

files_to_delete = []
dirs_to_delete = []
entries_to_delete = []
self.logger.info("getting cache entries to delete")
n_entries_to_delete = 0
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(*filters).order_by(*sorters)
):
files = _get_files_from_cache_entry(cache_entry)
if files:
n_entries_to_delete += 1
session.delete(cache_entry)

for file, file_type in files.items():
self.pop_file_size(file)
if file_type == "application/vnd+zarr":
dirs_to_delete.append(file)
else:
files_to_delete.append(file)
files = _get_files_from_cache_entry(cache_entry, key="file:size")
if any(file.startswith(self.urldir) for file in files):
entries_to_delete.append(cache_entry)
for file in files:
self.pop_file_size(file)

if self.stop_cleaning(maxsize):
break

if n_entries_to_delete:
if entries_to_delete:
self.logger.info(
"deleting cache entries", n_entries_to_delete=n_entries_to_delete
"deleting cache entries", n_entries_to_delete=len(entries_to_delete)
)
database._commit_or_rollback(session)
_delete_cache_entries(session, *entries_to_delete)

self.remove_files(files_to_delete, recursive=False)
self.remove_files(dirs_to_delete, recursive=True)
self.log_disk_usage()

if not self.stop_cleaning(maxsize):
@@ -296,6 +295,8 @@ def clean_cache_files(
lock_validity_period: float | None = None,
tags_to_clean: list[str | None] | None = None,
tags_to_keep: list[str | None] | None = None,
depth: int = 1,
use_database: bool = False,
) -> None:
"""Clean cache files.
@@ -316,8 +317,17 @@
Tags to clean/keep. If None, delete all cache entries.
To delete/keep untagged entries, add None in the list (e.g., [None, 'tag1', ...]).
tags_to_clean and tags_to_keep are mutually exclusive.
depth: int, default: 1
depth for grouping cache files
use_database: bool, default: False
Whether to infer disk usage from the cacholote database
"""
cleaner = _Cleaner()
if use_database and delete_unknown_files:
raise ValueError(
"'use_database' and 'delete_unknown_files' are mutually exclusive"
)

cleaner = _Cleaner(depth=depth, use_database=use_database)

if delete_unknown_files:
cleaner.delete_unknown_files(lock_validity_period, recursive)
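
The new options slot into `clean_cache_files` as in the sketch below. It assumes caching has already been configured elsewhere, that `maxsize` is the byte budget handed down to `_Cleaner.delete_cache_files` (the parameter name is not visible in this hunk), and the values are illustrative.

```python
from cacholote import clean

# Fast path: size the cache from the database instead of walking the store,
# grouping usage two levels below the cache directory.
clean.clean_cache_files(maxsize=10 * 1024**3, use_database=True, depth=2)

# Incompatible options: sizing from the database cannot also scan the store
# for files the database does not know about.
try:
    clean.clean_cache_files(maxsize=0, use_database=True, delete_unknown_files=True)
except ValueError as ex:
    print(ex)  # 'use_database' and 'delete_unknown_files' are mutually exclusive
```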
@@ -350,21 +360,22 @@ def clean_invalid_cache_entries(
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(*filters)
):
_delete_cache_entry(session, cache_entry)
_delete_cache_entries(session, cache_entry)

if try_decode:
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(sa.select(database.CacheEntry)):
try:
decode.loads(cache_entry._result_as_string)
except decode.DecodeError:
_delete_cache_entry(session, cache_entry)
_delete_cache_entries(session, cache_entry)


def expire_cache_entries(
tags: list[str] | None = None,
before: datetime.datetime | None = None,
after: datetime.date | None = None,
delete: bool = False,
) -> int:
now = utils.utcnow()

@@ -376,12 +387,14 @@
if after is not None:
filters.append(database.CacheEntry.created_at > after)

count = 0
with config.get().instantiated_sessionmaker() as session:
for cache_entry in session.scalars(
sa.select(database.CacheEntry).filter(*filters)
):
count += 1
cache_entry.expiration = now
database._commit_or_rollback(session)
return count
cache_entries = list(
session.scalars(sa.select(database.CacheEntry).filter(*filters))
)
if delete:
_delete_cache_entries(session, *cache_entries)
else:
for cache_entry in cache_entries:
cache_entry.expiration = now
database._commit_or_rollback(session)
return len(cache_entries)
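
A usage sketch for the new `delete` flag (the tag name is invented): without it the matched entries are only marked as expired, while `delete=True` now also drops the rows and their files through `_delete_cache_entries`.

```python
from cacholote import clean

# Mark matching entries as expired; they stay in the database and on disk.
n_expired = clean.expire_cache_entries(tags=["daily-forecast"])

# Expire and purge in one go: rows are deleted and their files removed.
n_deleted = clean.expire_cache_entries(tags=["daily-forecast"], delete=True)

print(n_expired, n_deleted)  # each call returns the number of matched entries
```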