Skip to content

Commit

Permalink
Merge pull request #916 from pranayasinghcsmpl/hf_cli4
Browse files Browse the repository at this point in the history
Add Huggingface Integration
  • Loading branch information
sarthakpati authored Oct 16, 2024
2 parents bcffff8 + c245a6d commit 301c188
Show file tree
Hide file tree
Showing 10 changed files with 1,051 additions and 12 deletions.
142 changes: 142 additions & 0 deletions GANDLF/cli/huggingface_hub_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from huggingface_hub import HfApi, snapshot_download, ModelCardData, ModelCard
from typing import List, Union
from GANDLF import version
from pathlib import Path
from GANDLF.utils import get_git_hash
import re


def validate_model_card(file_path: str):
"""
Validate that the required fields in the model card are not null, empty, or set to 'REQUIRED_FOR_GANDLF'.
The fields must contain valid alphabetic or alphanumeric values.
Args:
file_path (str): The path to the Markdown file to validate.
Raises:
AssertionError: If any required field is missing, empty, null, or contains 'REQUIRED_FOR_GANDLF'.
"""
# Read the Markdown file
path = Path(file_path)
with path.open("r") as file:
template_str = file.read()

# Define required fields and their regex patterns to capture the values
patterns = {
"Developed by": re.compile(
r'\*\*Developed by:\*\*\s*\{\{\s*developers\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}',
re.MULTILINE,
),
"License": re.compile(
r'\*\*License:\*\*\s*\{\{\s*license\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}',
re.MULTILINE,
),
"Primary Organization": re.compile(
r'\*\*Primary Organization:\*\*\s*\{\{\s*primary_organization\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}',
re.MULTILINE,
),
"Commercial use policy": re.compile(
r'\*\*Commercial use policy:\*\*\s*\{\{\s*commercial_use\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}',
re.MULTILINE,
),
}

# Iterate through the required fields and validate
for field, pattern in patterns.items():
match = pattern.search(template_str)

# Ensure the field is present and does not contain 'REQUIRED_FOR_GANDLF'
assert match, f"Field '{field}' is missing or not found in the file."

extract_value = match.group(1)

# Get the field value
value = (
re.search(r"\[([^\]]+)\]", extract_value).group(1)
if re.search(r"\[([^\]]+)\]", extract_value)
else None
)

# Ensure the field is not set to 'REQUIRED_FOR_GANDLF' or empty
assert (
value != "REQUIRED_FOR_GANDLF"
), f"The value for '{field}' is set to the default placeholder '[REQUIRED_FOR_GANDLF]'. It must be a valid value."
assert value, f"The value for '{field}' is empty or null."

# Ensure the value contains only alphabetic or alphanumeric characters
assert re.match(
r"^[a-zA-Z0-9]+$", value
), f"The value for '{field}' must be alphabetic or alphanumeric, but got: '{value}'"

print(
"All required fields are valid, non-empty, properly filled, and do not contain '[REQUIRED_FOR_GANDLF]'."
)

# Example usage
return template_str


def push_to_model_hub(
repo_id: str,
folder_path: str,
hf_template: str,
path_in_repo: Union[str, None] = None,
commit_message: Union[str, None] = None,
commit_description: Union[str, None] = None,
token: Union[str, None] = None,
repo_type: Union[str, None] = None,
revision: Union[str, None] = None,
allow_patterns: Union[List[str], str, None] = None,
ignore_patterns: Union[List[str], str, None] = None,
delete_patterns: Union[List[str], str, None] = None,
):
api = HfApi(token=token)

try:
repo_id = api.create_repo(repo_id).repo_id
except Exception as e:
print(f"Error: {e}")

tags = ["v" + version]

git_hash = get_git_hash()

if not git_hash == "None":
tags += [git_hash]

readme_template = validate_model_card(hf_template)

card_data = ModelCardData(library_name="GaNDLF", tags=tags)
card = ModelCard.from_template(card_data, template_str=readme_template)

card.save(Path(folder_path, "README.md"))

api.upload_folder(
repo_id=repo_id,
folder_path=folder_path,
repo_type="model",
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
delete_patterns=delete_patterns,
)
print("Model Sucessfully Uploded")


def download_from_hub(
repo_id: str,
revision: Union[str, None] = None,
cache_dir: Union[str, None] = None,
local_dir: Union[str, None] = None,
force_download: bool = False,
token: Union[str, None] = None,
):
snapshot_download(
repo_id=repo_id,
revision=revision,
cache_dir=cache_dir,
local_dir=local_dir,
force_download=force_download,
token=token,
)
156 changes: 156 additions & 0 deletions GANDLF/entrypoints/hf_hub_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import click
from GANDLF.entrypoints import append_copyright_to_help
from GANDLF.cli.huggingface_hub_handler import push_to_model_hub, download_from_hub
from pathlib import Path

huggingfaceDir_ = Path(__file__).parent.absolute()

huggingfaceDir = huggingfaceDir_.parent

# Huggingface template by default Path for the Model Deployment
huggingface_file_path = huggingfaceDir / "hugging_face.md"


@click.command()
@click.option(
"--upload/--download",
"-u/-d",
required=True,
help="Upload or download to/from a Huggingface Repo",
)
@click.option(
"--repo-id",
"-rid",
required=True,
help="Downloading/Uploading: A user or an organization name and a repo name separated by a /",
)
@click.option(
"--token",
"-tk",
help="Downloading/Uploading: A token to be used for the download/upload",
)
@click.option(
"--revision",
"-rv",
help="Downloading/Uploading: git revision id which can be a branch name, a tag, or a commit hash",
)
@click.option(
"--cache-dir",
"-cdir",
help="Downloading: path to the folder where cached files are stored",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
)
@click.option(
"--local-dir",
"-ldir",
help="Downloading: if provided, the downloaded file will be placed under this directory",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
)
@click.option(
"--force-download",
"-fd",
is_flag=True,
help="Downloading: Whether the file should be downloaded even if it already exists in the local cache",
)
@click.option(
"--folder-path",
"-fp",
help="Uploading: Path to the folder to upload on the local file system",
type=click.Path(exists=True, file_okay=False, dir_okay=True),
)
@click.option(
"--path-in-repo",
"-pir",
help="Uploading: Relative path of the directory in the repo. Will default to the root folder of the repository",
)
@click.option(
"--commit-message",
"-cr",
help='Uploading: The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"',
)
@click.option(
"--commit-description",
"-cd",
help="Uploading: The description of the generated commit",
)
@click.option(
"--repo-type",
"-rt",
help='Uploading: Set to "dataset" or "space" if uploading to a dataset or space, "model" if uploading to a model. Default is model',
)
@click.option(
"--allow-patterns",
"-ap",
help="Uploading: If provided, only files matching at least one pattern are uploaded.",
)
@click.option(
"--ignore-patterns",
"-ip",
help="Uploading: If provided, files matching any of the patterns are not uploaded.",
)
@click.option(
"--delete-patterns",
"-dp",
help="Uploading: If provided, remote files matching any of the patterns will be deleted from the repo while committing new files. This is useful if you don't know which files have already been uploaded.",
)
@click.option(
"--hf-template",
"-hft",
help="Adding the template path for the model card it is Required during Uploaing a model",
default=huggingface_file_path,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@append_copyright_to_help
def new_way(
upload: bool,
repo_id: str,
token: str,
hf_template: str,
revision: str,
cache_dir: str,
local_dir: str,
force_download: bool,
folder_path: str,
path_in_repo: str,
commit_message: str,
commit_description: str,
repo_type: str,
allow_patterns: str,
ignore_patterns: str,
delete_patterns: str,
):
# """Manages model transfers to and from the Hugging Face Hub"""
# """Manages model transfers to and from the Hugging Face Hub"""

# # Ensure the hf_template is being passed and loaded correctly
# template_path = Path(hf_template)

# # Check if file exists and is readable
# if not template_path.exists():
# raise FileNotFoundError(f"Model card template file '{hf_template}' not found.")

# with template_path.open('r') as f:
# hf_template = f.read()

# # Debug print the content to ensure it's being read
# print(f"Template content: {type(hf_template)}...") # Print the first 100 chars as a preview

if upload:
push_to_model_hub(
repo_id,
folder_path,
hf_template,
path_in_repo,
commit_message,
commit_description,
token,
repo_type,
revision,
allow_patterns,
ignore_patterns,
delete_patterns,
)
else:
download_from_hub(
repo_id, revision, cache_dir, local_dir, force_download, token
)
2 changes: 2 additions & 0 deletions GANDLF/entrypoints/subcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from GANDLF.entrypoints.generate_metrics import new_way as generate_metrics_command
from GANDLF.entrypoints.debug_info import new_way as debug_info_command
from GANDLF.entrypoints.split_csv import new_way as split_csv_command
from GANDLF.entrypoints.hf_hub_integration import new_way as hf_command


cli_subcommands = {
Expand All @@ -29,4 +30,5 @@
"generate-metrics": generate_metrics_command,
"debug-info": debug_info_command,
"split-csv": split_csv_command,
"hf": hf_command,
}
Loading

0 comments on commit 301c188

Please sign in to comment.