-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #916 from pranayasinghcsmpl/hf_cli4
Add Huggingface Integration
- Loading branch information
Showing
10 changed files
with
1,051 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from huggingface_hub import HfApi, snapshot_download, ModelCardData, ModelCard | ||
from typing import List, Union | ||
from GANDLF import version | ||
from pathlib import Path | ||
from GANDLF.utils import get_git_hash | ||
import re | ||
|
||
|
||
def validate_model_card(file_path: str): | ||
""" | ||
Validate that the required fields in the model card are not null, empty, or set to 'REQUIRED_FOR_GANDLF'. | ||
The fields must contain valid alphabetic or alphanumeric values. | ||
Args: | ||
file_path (str): The path to the Markdown file to validate. | ||
Raises: | ||
AssertionError: If any required field is missing, empty, null, or contains 'REQUIRED_FOR_GANDLF'. | ||
""" | ||
# Read the Markdown file | ||
path = Path(file_path) | ||
with path.open("r") as file: | ||
template_str = file.read() | ||
|
||
# Define required fields and their regex patterns to capture the values | ||
patterns = { | ||
"Developed by": re.compile( | ||
r'\*\*Developed by:\*\*\s*\{\{\s*developers\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}', | ||
re.MULTILINE, | ||
), | ||
"License": re.compile( | ||
r'\*\*License:\*\*\s*\{\{\s*license\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}', | ||
re.MULTILINE, | ||
), | ||
"Primary Organization": re.compile( | ||
r'\*\*Primary Organization:\*\*\s*\{\{\s*primary_organization\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}', | ||
re.MULTILINE, | ||
), | ||
"Commercial use policy": re.compile( | ||
r'\*\*Commercial use policy:\*\*\s*\{\{\s*commercial_use\s*\|\s*default\("(.+?)",\s*true\)\s*\}\}', | ||
re.MULTILINE, | ||
), | ||
} | ||
|
||
# Iterate through the required fields and validate | ||
for field, pattern in patterns.items(): | ||
match = pattern.search(template_str) | ||
|
||
# Ensure the field is present and does not contain 'REQUIRED_FOR_GANDLF' | ||
assert match, f"Field '{field}' is missing or not found in the file." | ||
|
||
extract_value = match.group(1) | ||
|
||
# Get the field value | ||
value = ( | ||
re.search(r"\[([^\]]+)\]", extract_value).group(1) | ||
if re.search(r"\[([^\]]+)\]", extract_value) | ||
else None | ||
) | ||
|
||
# Ensure the field is not set to 'REQUIRED_FOR_GANDLF' or empty | ||
assert ( | ||
value != "REQUIRED_FOR_GANDLF" | ||
), f"The value for '{field}' is set to the default placeholder '[REQUIRED_FOR_GANDLF]'. It must be a valid value." | ||
assert value, f"The value for '{field}' is empty or null." | ||
|
||
# Ensure the value contains only alphabetic or alphanumeric characters | ||
assert re.match( | ||
r"^[a-zA-Z0-9]+$", value | ||
), f"The value for '{field}' must be alphabetic or alphanumeric, but got: '{value}'" | ||
|
||
print( | ||
"All required fields are valid, non-empty, properly filled, and do not contain '[REQUIRED_FOR_GANDLF]'." | ||
) | ||
|
||
# Example usage | ||
return template_str | ||
|
||
|
||
def push_to_model_hub( | ||
repo_id: str, | ||
folder_path: str, | ||
hf_template: str, | ||
path_in_repo: Union[str, None] = None, | ||
commit_message: Union[str, None] = None, | ||
commit_description: Union[str, None] = None, | ||
token: Union[str, None] = None, | ||
repo_type: Union[str, None] = None, | ||
revision: Union[str, None] = None, | ||
allow_patterns: Union[List[str], str, None] = None, | ||
ignore_patterns: Union[List[str], str, None] = None, | ||
delete_patterns: Union[List[str], str, None] = None, | ||
): | ||
api = HfApi(token=token) | ||
|
||
try: | ||
repo_id = api.create_repo(repo_id).repo_id | ||
except Exception as e: | ||
print(f"Error: {e}") | ||
|
||
tags = ["v" + version] | ||
|
||
git_hash = get_git_hash() | ||
|
||
if not git_hash == "None": | ||
tags += [git_hash] | ||
|
||
readme_template = validate_model_card(hf_template) | ||
|
||
card_data = ModelCardData(library_name="GaNDLF", tags=tags) | ||
card = ModelCard.from_template(card_data, template_str=readme_template) | ||
|
||
card.save(Path(folder_path, "README.md")) | ||
|
||
api.upload_folder( | ||
repo_id=repo_id, | ||
folder_path=folder_path, | ||
repo_type="model", | ||
revision=revision, | ||
allow_patterns=allow_patterns, | ||
ignore_patterns=ignore_patterns, | ||
delete_patterns=delete_patterns, | ||
) | ||
print("Model Sucessfully Uploded") | ||
|
||
|
||
def download_from_hub( | ||
repo_id: str, | ||
revision: Union[str, None] = None, | ||
cache_dir: Union[str, None] = None, | ||
local_dir: Union[str, None] = None, | ||
force_download: bool = False, | ||
token: Union[str, None] = None, | ||
): | ||
snapshot_download( | ||
repo_id=repo_id, | ||
revision=revision, | ||
cache_dir=cache_dir, | ||
local_dir=local_dir, | ||
force_download=force_download, | ||
token=token, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import click | ||
from GANDLF.entrypoints import append_copyright_to_help | ||
from GANDLF.cli.huggingface_hub_handler import push_to_model_hub, download_from_hub | ||
from pathlib import Path | ||
|
||
huggingfaceDir_ = Path(__file__).parent.absolute() | ||
|
||
huggingfaceDir = huggingfaceDir_.parent | ||
|
||
# Huggingface template by default Path for the Model Deployment | ||
huggingface_file_path = huggingfaceDir / "hugging_face.md" | ||
|
||
|
||
@click.command() | ||
@click.option( | ||
"--upload/--download", | ||
"-u/-d", | ||
required=True, | ||
help="Upload or download to/from a Huggingface Repo", | ||
) | ||
@click.option( | ||
"--repo-id", | ||
"-rid", | ||
required=True, | ||
help="Downloading/Uploading: A user or an organization name and a repo name separated by a /", | ||
) | ||
@click.option( | ||
"--token", | ||
"-tk", | ||
help="Downloading/Uploading: A token to be used for the download/upload", | ||
) | ||
@click.option( | ||
"--revision", | ||
"-rv", | ||
help="Downloading/Uploading: git revision id which can be a branch name, a tag, or a commit hash", | ||
) | ||
@click.option( | ||
"--cache-dir", | ||
"-cdir", | ||
help="Downloading: path to the folder where cached files are stored", | ||
type=click.Path(exists=True, file_okay=False, dir_okay=True), | ||
) | ||
@click.option( | ||
"--local-dir", | ||
"-ldir", | ||
help="Downloading: if provided, the downloaded file will be placed under this directory", | ||
type=click.Path(exists=True, file_okay=False, dir_okay=True), | ||
) | ||
@click.option( | ||
"--force-download", | ||
"-fd", | ||
is_flag=True, | ||
help="Downloading: Whether the file should be downloaded even if it already exists in the local cache", | ||
) | ||
@click.option( | ||
"--folder-path", | ||
"-fp", | ||
help="Uploading: Path to the folder to upload on the local file system", | ||
type=click.Path(exists=True, file_okay=False, dir_okay=True), | ||
) | ||
@click.option( | ||
"--path-in-repo", | ||
"-pir", | ||
help="Uploading: Relative path of the directory in the repo. Will default to the root folder of the repository", | ||
) | ||
@click.option( | ||
"--commit-message", | ||
"-cr", | ||
help='Uploading: The summary / title / first line of the generated commit. Defaults to: f"Upload {path_in_repo} with huggingface_hub"', | ||
) | ||
@click.option( | ||
"--commit-description", | ||
"-cd", | ||
help="Uploading: The description of the generated commit", | ||
) | ||
@click.option( | ||
"--repo-type", | ||
"-rt", | ||
help='Uploading: Set to "dataset" or "space" if uploading to a dataset or space, "model" if uploading to a model. Default is model', | ||
) | ||
@click.option( | ||
"--allow-patterns", | ||
"-ap", | ||
help="Uploading: If provided, only files matching at least one pattern are uploaded.", | ||
) | ||
@click.option( | ||
"--ignore-patterns", | ||
"-ip", | ||
help="Uploading: If provided, files matching any of the patterns are not uploaded.", | ||
) | ||
@click.option( | ||
"--delete-patterns", | ||
"-dp", | ||
help="Uploading: If provided, remote files matching any of the patterns will be deleted from the repo while committing new files. This is useful if you don't know which files have already been uploaded.", | ||
) | ||
@click.option( | ||
"--hf-template", | ||
"-hft", | ||
help="Adding the template path for the model card it is Required during Uploaing a model", | ||
default=huggingface_file_path, | ||
type=click.Path(exists=True, file_okay=True, dir_okay=False), | ||
) | ||
@append_copyright_to_help | ||
def new_way( | ||
upload: bool, | ||
repo_id: str, | ||
token: str, | ||
hf_template: str, | ||
revision: str, | ||
cache_dir: str, | ||
local_dir: str, | ||
force_download: bool, | ||
folder_path: str, | ||
path_in_repo: str, | ||
commit_message: str, | ||
commit_description: str, | ||
repo_type: str, | ||
allow_patterns: str, | ||
ignore_patterns: str, | ||
delete_patterns: str, | ||
): | ||
# """Manages model transfers to and from the Hugging Face Hub""" | ||
# """Manages model transfers to and from the Hugging Face Hub""" | ||
|
||
# # Ensure the hf_template is being passed and loaded correctly | ||
# template_path = Path(hf_template) | ||
|
||
# # Check if file exists and is readable | ||
# if not template_path.exists(): | ||
# raise FileNotFoundError(f"Model card template file '{hf_template}' not found.") | ||
|
||
# with template_path.open('r') as f: | ||
# hf_template = f.read() | ||
|
||
# # Debug print the content to ensure it's being read | ||
# print(f"Template content: {type(hf_template)}...") # Print the first 100 chars as a preview | ||
|
||
if upload: | ||
push_to_model_hub( | ||
repo_id, | ||
folder_path, | ||
hf_template, | ||
path_in_repo, | ||
commit_message, | ||
commit_description, | ||
token, | ||
repo_type, | ||
revision, | ||
allow_patterns, | ||
ignore_patterns, | ||
delete_patterns, | ||
) | ||
else: | ||
download_from_hub( | ||
repo_id, revision, cache_dir, local_dir, force_download, token | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.