Last active
January 22, 2026 01:11
-
-
Save smothiki/bd6474446b4490464cd4f6412780b90b to your computer and use it in GitHub Desktop.
Air-gap model download/upload script (Windows version)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import argparse | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| import logging | |
| import yaml | |
| from typing import List, Dict, Optional, Tuple | |
| # courtesy : https://stackoverflow.com/questions/43765849/pyyaml-load-and-dump-yaml-file-and-preserve-tags-customtag | |
class SafeUnknownConstructor(yaml.constructor.SafeConstructor):
    """SafeConstructor that keeps unknown YAML tags instead of raising.

    A node with an unrecognized tag is loaded as a dynamically created
    subclass of its natural Python type (named ``TagWrap_<type>``) that
    carries the original tag, so it can be re-emitted unchanged on dump
    (see SafeUnknownRepresenter).
    """
    def __init__(self):
        yaml.constructor.SafeConstructor.__init__(self)
    def construct_undefined(self, node):
        # Build the plain value with the standard constructor for the node's
        # kind, e.g. construct_mapping for a MappingNode (node.id is
        # 'scalar' / 'sequence' / 'mapping').
        data = getattr(self, 'construct_' + node.id)(node)
        datatype = type(data)
        # Create a throwaway subclass of the value's type so attributes can be
        # attached (plain str/list/dict instances reject setattr).
        wraptype = type('TagWrap_'+datatype.__name__, (datatype,), {})
        wrapdata = wraptype(data)
        # NOTE(review): these lambda assignments look like leftovers from the
        # original snippet — only wrapTag/wrapType below are read elsewhere.
        wrapdata.tag = lambda: None
        wrapdata.datatype = lambda: None
        # Preserve the original tag and the unwrapped type for the dumper.
        setattr(wrapdata, "wrapTag", node.tag)
        setattr(wrapdata, "wrapType", datatype)
        return wrapdata
class SafeUnknownLoader(SafeUnknownConstructor, yaml.loader.SafeLoader):
    """SafeLoader variant that tolerates unknown tags via SafeUnknownConstructor."""
    def __init__(self, stream):
        # Initialize both bases explicitly: the constructor mixin first, then
        # SafeLoader, which wires up the reader/scanner/parser chain for `stream`.
        SafeUnknownConstructor.__init__(self)
        yaml.loader.SafeLoader.__init__(self, stream)
class SafeUnknownRepresenter(yaml.representer.SafeRepresenter):
    """SafeRepresenter that restores the preserved tag on TagWrap_* values."""
    def represent_data(self, wrapdata):
        # `tag` stays False for ordinary values; for wrapped values it holds
        # the original YAML tag captured at load time.
        tag = False
        if type(wrapdata).__name__.startswith('TagWrap_'):
            # Unwrap back to the natural type so the base representer can
            # serialize it, then re-apply the preserved tag afterwards.
            datatype = getattr(wrapdata, "wrapType")
            tag = getattr(wrapdata, "wrapTag")
            data = datatype(wrapdata)
        else:
            data = wrapdata
        node = super(SafeUnknownRepresenter, self).represent_data(data)
        if tag:
            node.tag = tag
        return node
class SafeUnknownDumper(SafeUnknownRepresenter, yaml.dumper.SafeDumper):
    """SafeDumper that re-emits the original tags preserved by SafeUnknownLoader.

    The signature mirrors yaml.dumper.SafeDumper so this class can be passed
    directly as the ``Dumper=`` argument to yaml.dump().
    """
    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        # Initialize the representer side first (tag restoration), then the
        # full dumper chain (emitter/serializer/resolver).
        SafeUnknownRepresenter.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        yaml.dumper.SafeDumper.__init__(self, stream,
                default_style=default_style,
                default_flow_style=default_flow_style,
                canonical=canonical,
                indent=indent,
                width=width,
                allow_unicode=allow_unicode,
                line_break=line_break,
                encoding=encoding,
                explicit_start=explicit_start,
                explicit_end=explicit_end,
                version=version,
                tags=tags,
                sort_keys=sort_keys)
# Loader alias used throughout this script: preserves unknown YAML tags
# instead of failing on them.
MySafeLoader = SafeUnknownLoader
# Register the catch-all constructor (tag=None matches any tag without an
# explicit constructor). NOTE: this mutates yaml's SafeConstructor class
# globally for the whole process, affecting all Safe loaders.
yaml.constructor.SafeConstructor.add_constructor(None, SafeUnknownConstructor.construct_undefined)
class ModelParser:
    """Read the NGC model spec YAML and answer queries about models,
    variants and optimization profiles."""

    def __init__(self, yaml_file_path: str):
        """Remember the spec path and eagerly load its contents."""
        self.yaml_file_path = yaml_file_path
        self.models_data = self._load_yaml()

    def _load_yaml(self) -> Dict:
        """Load the YAML file; print the problem and return {} on errors."""
        try:
            with open(self.yaml_file_path, "r") as handle:
                raw = handle.read()
            return yaml.load(raw, Loader=MySafeLoader)
        except FileNotFoundError:
            print(f"Error: File {self.yaml_file_path} not found.")
            return {}
        except yaml.YAMLError as exc:
            print(f"Error parsing YAML: {exc}")
            return {}

    def get_all_models(self) -> List[Dict[str, str]]:
        """Return the basic summary fields for every model in the spec."""
        summary_fields = ('name', 'displayName', 'modelHubID', 'category', 'description')
        collected: List[Dict[str, str]] = []
        if 'models' in self.models_data:
            collected = [
                {field: entry.get(field, '') for field in summary_fields}
                for entry in self.models_data['models']
            ]
        return collected

    def get_model_variant_ids(self, model_name: str) -> List[str]:
        """
        Fetch all model variant IDs for a given model name.
        Args:
            model_name: The name of the model to search for
        Returns:
            List of variant IDs for the specified model
        """
        wanted = model_name.lower()
        if 'models' in self.models_data:
            for entry in self.models_data['models']:
                if entry.get('name', '').lower() != wanted:
                    continue
                # Only the first matching model is consulted.
                return [
                    variant.get('variantId', '')
                    for variant in entry.get('modelVariants', [])
                    if variant.get('variantId', '')
                ]
        return []

    def get_optimization_profile_ids(self, model_name: str, variant_id: str = None) -> List[str]:
        """
        Fetch all optimization profile IDs for a given model name and optional variant ID.
        Args:
            model_name: The name of the model
            variant_id: Optional variant ID. If None, gets profiles from all variants
        Returns:
            List of optimization profile IDs
        """
        wanted = model_name.lower()
        if 'models' in self.models_data:
            for entry in self.models_data['models']:
                if entry.get('name', '').lower() != wanted:
                    continue
                collected = []
                for variant in entry.get('modelVariants', []):
                    # Restrict to one variant when a variant_id was supplied.
                    if variant_id and variant.get('variantId', '').lower() != variant_id.lower():
                        continue
                    for profile in variant.get('optimizationProfiles', []):
                        pid = profile.get('profileId', '')
                        if pid:
                            collected.append(pid)
                return collected
        return []

    def get_detailed_model_info(self, model_name: str) -> Optional[Dict]:
        """
        Get detailed information about a specific model including all variants and profiles.
        Args:
            model_name: The name of the model
        Returns:
            Dictionary containing detailed model information
        """
        wanted = model_name.lower()
        if 'models' in self.models_data:
            for entry in self.models_data['models']:
                if entry.get('name', '').lower() == wanted:
                    return entry
        return None
def extract_last_part(input_string):
    """Return the substring after the final '/', or the whole string if none."""
    return input_string.rsplit('/', 1)[-1]
def run_ngc_info_command(ID):
    """Run `ngc registry model info <ID>` and parse its JSON output.

    Returns:
        tuple: (metadata dict, None) on success, (None, exception) on failure.
    """
    logging.info(f"Start Fetching model info from NGC: {ID}")
    cmd = ["ngc", "registry", "model", "info", ID, "--format_type", "json"]
    # The command itself carries no credentials, so it is safe to log.
    logging.info(cmd)
    try:
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=False,
            check=True
        )
        logging.info(f"Finish Fetching NGC model repo {ID}")
        # json.loads accepts the raw bytes from stdout directly.
        return json.loads(completed.stdout), None
    except json.JSONDecodeError as exc:
        logging.error(f"Error while Unmarshalling the NGC info command to map: {str(exc)}")
        return None, exc
    except subprocess.CalledProcessError as exc:
        logging.error(f"Error Fetching model repo: {ID}")
        logging.error(f"Error: {str(exc)}")
        return None, exc
def get_ngc_model_info(model_id, tag):
    """Fetch NGC metadata for a model repo and one of its version tags.

    Args:
        model_id: NGC model id (org/team/name).
        tag: version tag to query alongside the base model.

    Returns:
        The model metadata dict on success, "" when either NGC lookup fails
        (callers treat a falsy result as "no model metadata").
    """
    model_metadata_map, err = run_ngc_info_command(model_id)
    if err:
        logging.error(f"Error Fetching model repo: {model_id}")
        logging.error(f"Error: {str(err)}")
        return ""
    version_metadata_map, err = run_ngc_info_command(model_id + ":" + tag)
    if err:
        logging.error(f"Error Fetching model repo: {model_id}")
        logging.error(f"Error: {str(err)}")
        return ""
    # NOTE(review): the original code read totalSizeInBytes from the version
    # metadata and re-assigned model_metadata_map["versionId"] to itself —
    # both were no-ops and have been removed. Presumably the intent was to
    # copy version details (e.g. totalSizeInBytes) into the model metadata;
    # confirm the intended field and implement if needed.
    return model_metadata_map
def load_ngc_spec(spec_file):
    """Load the NGC spec YAML file, preserving any unknown tags."""
    with open(spec_file, "r") as handle:
        raw = handle.read()
    return yaml.load(raw, Loader=MySafeLoader)
def execute_nim_download_command(repo_id, folder_location, ngc_spec, profile_sha, version):
    """
    Execute the nimcli download command to download model files.

    Args:
        repo_id (str): Repository ID, optionally with a ':<tag>' suffix
        folder_location (str): Model cache destination directory
        ngc_spec (str): Path to the NGC spec file (its directory holds the manifests)
        profile_sha (str): Profile SHA passed to --profiles
        version (str): Registry version used to locate the manifest

    Returns:
        tuple: (folder_location, error) — error is None on success.

    Raises:
        ValueError: if the model name is not of the form org/team/name.
    """
    model_name = repo_id.split(":")[0]
    count = model_name.count('/')
    if count != 2:
        # BUG FIX: the old message claimed 3 '/' were expected while the check
        # requires exactly 2 (i.e. three path parts: org/team/name).
        raise ValueError(f"Expected 2 '/' characters, but found {count} in model name")
    # Manifests live next to the spec file: <spec dir>/manifests/<version>/<model>.yaml
    ngc_spec_abs = os.path.dirname(ngc_spec)
    manifest_path = f"{ngc_spec_abs}/manifests/{version}/{model_name}.yaml"
    cmd = [
        "nimcli", "download", "--profiles", profile_sha, "--manifest-file",
        manifest_path, "--model-cache-path", folder_location
    ]
    print(cmd)
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        logging.error(f"Error download model repo: {repo_id}")
        logging.error(f"Error: {str(e)}")
        return folder_location, e
    # BUG FIX: the success path previously fell off the end returning None,
    # unlike the documented (folder_location, error) contract.
    return folder_location, None
def execute_ngc_download_command(repo_id, folder_location, files=None):
    """
    Execute NGC download command to download model files.

    Args:
        repo_id (str): Repository ID
        folder_location (str): Destination folder
        files (list, optional): List of files to download. If None or empty,
            download the entire repo.

    Returns:
        tuple: (folder_location, error) — error is None on success.
    """
    if files and len(files) > 0:
        # Download specific files, one ngc invocation per file.
        for file in files:
            cmd = [
                "ngc", "registry", "model", "download-version",
                repo_id, "--file", file, "--dest", folder_location,
                "--format_type", "json"
            ]
            logging.info(cmd)
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError as e:
                logging.error(f"Error download model repo: {repo_id}")
                logging.error(f"Error: {str(e)}")
                # BUG FIX: subprocess.run without output capture leaves
                # e.output as None; the old unconditional e.output.decode()
                # raised AttributeError inside the error handler.
                output = e.output.decode() if e.output else ""
                logging.error(f"Command output: {output}")
                return folder_location, e
    else:
        # Download the entire repository in a single invocation.
        cmd = [
            "ngc", "registry", "model", "download-version",
            repo_id, "--dest", folder_location, "--format_type", "json"
        ]
        logging.info(cmd)
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logging.error(f"Error download model repo: {repo_id}")
            logging.error(f"Error: {str(e)}")
            output = e.output.decode() if e.output else ""
            logging.error(f"Command output: {output}")
            return folder_location, Exception(output)
    return folder_location, None
def extract_profile_components(data, target_profile_id):
    """
    Extract repo IDs and source files from the components of a given profile.

    Args:
        data (dict): Parsed spec data (models -> modelVariants -> optimizationProfiles)
        target_profile_id (str): The profileId to search for

    Returns:
        dict with keys:
            'found' (bool): whether the profile was located
            'components' (list): dicts with 'destination', 'repo_id' and
                optionally 'files' for each workspace component
            'ngcMetadata': the profile's raw ngcMetadata mapping (or None)
    """
    result = {
        'found': False,
        'components': [],
        'ngcMetadata': None,
    }
    # Flatten the models -> variants -> profiles nesting into one stream.
    all_profiles = (
        profile
        for model in data.get('models', [])
        for variant in model.get('modelVariants', [])
        for profile in variant.get('optimizationProfiles', [])
    )
    for profile in all_profiles:
        if profile.get('profileId') != target_profile_id:
            continue
        result['found'] = True
        result['ngcMetadata'] = profile.get('ngcMetadata', None)
        # The workspace components live under each SHA key of ngcMetadata.
        for _sha, metadata in profile.get('ngcMetadata', {}).items():
            if 'workspace' in metadata and 'components' in metadata['workspace']:
                for component in metadata['workspace']['components']:
                    component_info = {
                        'destination': component.get('dst', ''),
                        'repo_id': component.get('src', {}).get('repo_id', '')
                    }
                    source_files = component.get('src', {}).get('files', [])
                    if source_files:
                        # Entries may be plain strings or dicts carrying a
                        # preserved '!name' tag key.
                        names = []
                        for item in source_files:
                            if isinstance(item, str):
                                names.append(item)
                            elif isinstance(item, dict) and item.get('!name'):
                                names.append(item.get('!name'))
                        component_info['files'] = names
                    result['components'].append(component_info)
        # Stop at the first matching profile, as before.
        return result
    return result
def show_help():
    """Display help information and exit with status 1."""
    # NOTE: this text is printed verbatim; the script's real CLI surface is
    # the argparse parser in main() — keep the two in sync.
    help_text = """
Description: Fetches information about a model from the Hugging Face Hub and optionally downloads it,
or uploads model artifacts to cloud storage.
Examples:
python script.py gpt2
python script.py --token YOUR_HF_TOKEN --repo-id facebook/bart-large
python script.py --download --path ~/my_models --repo-id bert-base-uncased
python script.py --repo-type dataset --download --repo-id mnist
python script.py --cloud aws --src /path/to/models/ --dst s3://bucket/path --recursive
python script.py --cloud pvc --src /path/to/model/hf/meta/llama3.1 --dst s3://bucket/secured-models/hf/meta/llama3.1
python script.py --cloud azure --src /path/to/model/hf/meta/llama3.1 --account cloudera-customer1 --container data --dst modelregistry/secured-models/hf/meta/llama3.1
"""
    print(help_text)
    sys.exit(1)
def check_requirements(download_model, cloud):
    """Verify that the external CLI tools needed for the requested operations exist.

    Args:
        download_model: truthy when a download was requested (needs hf + ngc CLIs).
        cloud: target cloud provider name, or None when no upload is requested.

    Returns:
        bool: True when every required tool is available.
    """
    if download_model:
        download_probes = (
            (["hf", "version"],
             ("Error: huggingface-cli is required for downloading but not installed.",
              "Please install it using pip:",
              " pip install huggingface_hub")),
            (["ngc", "version", "info"],
             ("Error: NGC CLI is required for downloading but not installed.",
              "Please install it ngc cli : https://org.ngc.nvidia.com/setup/installers/cli")),
        )
        for probe_cmd, missing_messages in download_probes:
            try:
                subprocess.run(probe_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except (subprocess.SubprocessError, FileNotFoundError):
                for message in missing_messages:
                    print(message)
                return False
    if cloud:
        if cloud in ("aws", "pvc"):
            try:
                # NOTE(review): shell=True alongside a list is unusual; presumably
                # kept for Windows .cmd resolution — confirm before removing.
                subprocess.run(["aws", "--version"], check=True, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except (subprocess.SubprocessError, FileNotFoundError):
                print(f"Error: aws-cli is required for uploading to {cloud} but not installed.")
                print("Please install it using pip:")
                print(" pip install awscli")
                return False
        elif cloud == "azure":
            try:
                subprocess.run(["az", "--version"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            except (subprocess.SubprocessError, FileNotFoundError):
                print("Error: Azure CLI is required for uploading to Azure Blob Storage but not installed.")
                print("Please install it using pip:")
                print(" pip install azure-cli")
                return False
        else:
            print(f"Unsupported cloud provider: {cloud}")
            return False
    return True
def validate_repo_type(repo_type):
    """Return True when repo_type is one of the supported types ('hf', 'ngc')."""
    valid_types = ["hf", "ngc"]
    if repo_type in valid_types:
        return True
    print(f"Error: Invalid repository type '{repo_type}'")
    print(f"Valid types are: {', '.join(valid_types)}")
    return False
def download_repo_hf(repo_id, token, download_path):
    """Download a Hugging Face repository with the `hf` CLI.

    Artifacts land in <download_path>/hf/<repo_id>/artifacts.

    Returns:
        bool: True on success, False when the CLI invocation fails.
    """
    print(f"Downloading repository: {repo_id} to {download_path}")
    # Ensure the artifacts directory exists before invoking the CLI.
    os.makedirs(os.path.join(download_path, "hf", repo_id, "artifacts"), exist_ok=True)
    print("Starting download with huggingface-cli...")
    command = ["hf", "download", repo_id, "--local-dir", f"{download_path}/hf/{repo_id}/artifacts"]
    if token:
        command += ["--token", token]
    try:
        subprocess.run(command, check=True)
    except subprocess.SubprocessError:
        print(f"Error: Failed to download {repo_id}")
        return False
    print(f"Download completed successfully for {repo_id}")
    return True
def canusenimcli(metadata):
    """
    Decide whether the nimcli downloader can be used for this profile.

    Args:
        metadata (dict): profile metadata holding an 'ngcMetadata' mapping
            keyed by profile SHA (as produced by extract_profile_components).

    Returns:
        bool: True when the profile's release is >= 1.3.0, or when the model
        is nvidia/nemoretriever-parse (supported regardless of release);
        False when there is no release key or the metadata cannot be read.
    """
    try:
        # Only the first (and in practice only) SHA entry is inspected.
        ngcmetadata = metadata['ngcMetadata']
        hash_key = next(iter(ngcmetadata))
        release_info = ngcmetadata[hash_key]
        if "release" not in release_info:
            return False
        # nemoretriever-parse is handled by nimcli irrespective of release.
        if release_info.get("model") == "nvidia/nemoretriever-parse":
            return True
        # Lazy third-party import: only needed on this comparison path.
        from packaging.version import Version
        return Version(release_info["release"]) >= Version("1.3.0")
    except Exception as e:
        print("error Failed to parse YAML:", e)
        return False
def download_repo(repo_id, token, download_path, repo_type, metadata, ngc_spec):
    """Download a model repository via the appropriate backend.

    'hf' repos go through the hf CLI; 'ngc' repos use nimcli when the profile
    release supports it, otherwise per-component ngc CLI downloads.

    Returns:
        bool: True when the ngc path completes (per-command errors are logged
        by the helpers), or the hf helper's success flag. Any other repo_type
        returns None (callers validate the type beforehand).
    """
    print(f"start downloading {repo_type} repository: {repo_id}")
    if repo_type == 'hf':
        return download_repo_hf(repo_id, token, download_path)
    elif repo_type == 'ngc':
        # Strip the scheme prefix if present.
        ngcPrefix = 'ngc://'
        if repo_id.startswith(ngcPrefix):
            repo_id = repo_id[len(ngcPrefix):]
        repo_path = extract_last_part(repo_id)
        download_path = os.path.join(download_path, "ngc", repo_path, "artifacts")
        os.makedirs(download_path, exist_ok=True)
        nimcli = canusenimcli(metadata)
        ngcmetadata = metadata['ngcMetadata']
        profile_sha = next(iter(ngcmetadata))
        spec = load_ngc_spec(ngc_spec)
        version = ''
        if 'registryVersion' in spec:
            version = spec['registryVersion']
        if nimcli:
            model_name = repo_id.split(":")[0]
            model_name_parts = model_name.split("/")
            if len(model_name_parts) != 3:
                raise NameError("Model name should have three parts " + model_name)
            execute_nim_download_command(repo_id, download_path, ngc_spec, profile_sha, version)
        else:
            for component in metadata['components']:
                print(f"Repo ID: {component['repo_id']}")
                print(f"Destination: {component['destination']}")
                execute_ngc_download_command(component['repo_id'], download_path, component.get('files'))
        print("Finish downloading artifacts")
        # BUG FIX: previously this branch implicitly returned None while the
        # hf branch returned a bool; return True for a consistent contract.
        return True
def get_repo_info_hf(repo_id, token):
    """Fetch Hugging Face model metadata from the HF API.

    Args:
        repo_id: HF repository id, e.g. 'facebook/bart-large'.
        token: optional HF access token for gated/private repos.

    Returns:
        dict parsed from the API response, or None on any request/parse error.
    """
    import requests
    # Prepare headers for API request
    headers = {"Accept": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    url = f"https://huggingface.co/api/models/{repo_id}"
    try:
        # BUG FIX: without a timeout a stalled connection hangs the script
        # forever; a timeout raises requests.Timeout, which is a
        # RequestException and is handled below.
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()  # Raise exception for HTTP errors
        return response.json()
    except (requests.RequestException, json.JSONDecodeError) as e:
        print(f"Failed to fetch huggingface repo metadata for the model: {str(e)}")
        return None
def get_repo_info_ngc(repo_id, spec_file):
    """Get NGC repository metadata from the local spec plus the NGC registry.

    NOTE(review): assumes repo_id contains a ':' separating model id and tag;
    an id without ':' raises IndexError below — confirm callers always pass
    a tagged id.
    """
    # Profile components come from the local spec file, not the registry.
    spec = load_ngc_spec(spec_file)
    ngcMetadata = extract_profile_components(spec, repo_id)
    repo_id=repo_id.split(':')
    # Registry lookup uses the model id and tag halves separately.
    modelMetadata = get_ngc_model_info(repo_id[0], repo_id[1])
    return ngcMetadata, modelMetadata
def get_repo_info(repo_id, token, repo_type, download_path, ngc_spec):
    """Get repository metadata and save it to a file.

    For 'hf' repos the metadata is fetched from the HF API and written as
    JSON; for 'ngc' repos the profile metadata is extracted from the local
    spec and written as YAML (plus modelmetadata.json when the NGC registry
    lookup succeeds).

    Returns:
        The metadata object (API dict for hf; extract_profile_components()
        result for ngc), consumed later by download_repo().
    """
    print(f"Fetching information for {repo_type} repository: {repo_id}")
    print(f"Download path: {download_path}")
    if repo_type == 'hf':
        metadata_path = os.path.join(download_path, "hf", repo_id, "metadata")
        metadata = get_repo_info_hf(repo_id, token)
        os.makedirs(metadata_path, exist_ok=True)
        output_file = os.path.join(metadata_path, "metadata.json")
        with open(output_file, 'w') as f:
            json.dump(metadata, f, indent=2)
    elif repo_type == 'ngc':
        # For NGC, repo_id is a profile id; only its last path segment names
        # the on-disk folder.
        repo_path = extract_last_part(repo_id)
        metadata_path = os.path.join(download_path, "ngc", repo_path, "metadata")
        metadata, modelmetadata = get_repo_info_ngc(repo_id, ngc_spec)
        # Persist only the raw ngcMetadata mapping; the custom dumper keeps
        # any unknown YAML tags intact.
        metadataToFile = metadata['ngcMetadata']
        os.makedirs(metadata_path, exist_ok=True)
        output_file = os.path.join(metadata_path, "metadata.yaml")
        with open(output_file, 'w') as f:
            yaml.dump(metadataToFile, f, default_flow_style=False, sort_keys=False, allow_unicode=True, Dumper=SafeUnknownDumper)
        if modelmetadata:
            output_file = os.path.join(metadata_path, "modelmetadata.json")
            with open(output_file, 'w') as f:
                json.dump(modelmetadata, f, indent=2)
        print("finish downloading metadata file")
    # NOTE(review): for any other repo_type, output_file/metadata are unbound
    # and the lines below raise NameError — callers validate the type first.
    print(f"Saved metadata to {output_file}")
    return metadata
def upload_to_cloud(src, dst, cloud, token=None, recursive=False, endpoint=None,
    insecure=False, ca_bundle=None, account=None, container=None,
    repo_id=None, repo_type=None):
    """Upload files to cloud storage.

    Supports 'aws' (S3), 'azure' (blob upload-batch) and 'pvc' (an
    S3-compatible gateway reached via --endpoint). For 'ngc' repos, src and
    dst are suffixed with '<repo_type>/<last segment of repo_id>'.

    Returns:
        bool: True on success, False on failure or unsupported provider.

    NOTE(review): `recursive` is accepted but never consulted — aws/pvc always
    pass --recursive and azure's upload-batch is inherently recursive.
    """
    print(f"Start uploading {repo_type} repository: {repo_id}")
    try:
        if repo_type == "ngc":
            # Mirror the local layout produced by download_repo/get_repo_info.
            repo_id = extract_last_part(repo_id)
            src = os.path.join(src, repo_type, repo_id)
            dst = dst +"/"+ repo_type +"/"+ repo_id
        if cloud == "aws":
            cmd = ["aws", "s3", "cp", src, f"{dst}/", "--recursive"]
            print(" ".join(cmd))
            # NOTE(review): shell=True with a list is unusual; presumably
            # needed on Windows to resolve aws.cmd — confirm before changing.
            subprocess.run(cmd, check=True, shell=True)
        elif cloud == "azure":
            cmd = [
                "az", "storage", "blob", "upload-batch",
                "--account-name", account,
                "--destination", container,
                "--destination-path", dst,
                "--source", src
            ]
            if token:
                # For Azure, `token` is interpreted as a SAS token.
                cmd.extend(["--sas-token", token])
            subprocess.run(cmd, check=True)
        elif cloud == "pvc":
            # Private-cloud S3 gateway: optional custom endpoint and TLS settings.
            cmd = ["aws", "s3"]
            if endpoint:
                cmd.extend(["--endpoint", endpoint])
            if insecure:
                cmd.append("--no-verify-ssl")
            elif ca_bundle:
                cmd.extend(["--ca-bundle", ca_bundle])
            cmd.extend(["cp", src, dst, "--recursive"])
            subprocess.run(cmd, check=True)
        else:
            print(f"Unsupported cloud provider: {cloud}")
            return False
        print(f"Uploaded: {src} -> {dst}")
        print(f"Finish uploading {repo_type} repository: {repo_id} to {cloud}")
        return True
    except subprocess.SubprocessError as e:
        print(f"Error during upload: {str(e)}")
        return False
def print_models(models: List[Dict[str, str]], title: str = "Models"):
    """Pretty-print a numbered list of model summaries under a section header."""
    print(f"\n=== {title.upper()} ===")
    if not models:
        print("No models found.")
        return
    for position, entry in enumerate(models, 1):
        print(f"{position}. {entry['name']}")
        print(f" Display Name: {entry['displayName']}")
        print(f" Category: {entry['category']}")
        print(f" Hub ID: {entry['modelHubID']}")
        description = entry['description']
        if description:
            # Truncate long descriptions to keep the listing compact.
            if len(description) > 100:
                description = description[:100] + "..."
            print(f" Description: {description}")
        print()
def print_list(items: List[str], title: str):
    """Print numbered items under an upper-cased section header."""
    print(f"\n=== {title.upper()} ===")
    if not items:
        print(f"No {title.lower()} found.")
        return
    for position, item in enumerate(items, 1):
        print(f"{position}. {item}")
def print_detailed_model(model_data: Dict, model_name: str):
    """Print a full report for one model: core fields, labels, variants, profiles."""
    print(f"\n=== DETAILED INFO FOR '{model_name}' ===")
    if not model_data:
        print(f"Model '{model_name}' not found.")
        return
    # Core scalar fields, each falling back to 'N/A' when absent.
    for label, key in (("Name", 'name'), ("Display Name", 'displayName'),
                       ("Model Hub ID", 'modelHubID'), ("Category", 'category'),
                       ("Type", 'type'), ("Description", 'description'),
                       ("License", 'license')):
        print(f"{label}: {model_data.get(key, 'N/A')}")
    if 'labels' in model_data:
        print(f"Labels: {', '.join(model_data['labels'])}")
    if 'modelVariants' not in model_data:
        return
    variants = model_data['modelVariants']
    print(f"\nModel Variants ({len(variants)}):")
    for variant_index, variant in enumerate(variants, 1):
        print(f" {variant_index}. {variant.get('variantId', 'N/A')}")
        if 'optimizationProfiles' not in variant:
            continue
        profiles = variant['optimizationProfiles']
        print(f" Optimization Profiles ({len(profiles)}):")
        for profile_index, profile in enumerate(profiles, 1):
            print(f" {profile_index}. {profile.get('profileId', 'N/A')}")
            print(f" Display Name: {profile.get('displayName', 'N/A')}")
            print(f" Framework: {profile.get('framework', 'N/A')}")
def main():
    """CLI entry point: parse arguments and dispatch to query, download or upload."""
    parser = argparse.ArgumentParser(description="Hugging Face model management script")
    # Download options
    parser.add_argument("-t", "--token", default="", help="Token for authentication")
    parser.add_argument("-j", "--json", action="store_true", help="Output raw JSON")
    parser.add_argument("-do", "--download", action="store_true", help="Download the model repository")
    parser.add_argument("-p", "--path", default="./models", help="Path to download model files")
    parser.add_argument("-ri", "--repo-id", help="Repository ID to download")
    parser.add_argument("-rt", "--repo-type", default="hf", help="Repository type (default: hf)")
    # parser.add_argument("-sha", "--profile-sha", help="Sha of the p rofile of the repoID")
    # Upload options
    parser.add_argument("-c", "--cloud", default="aws", help="Cloud provider (aws, gcp, azure, pvc)")
    parser.add_argument("-s", "--src", help="Source directory for upload")
    parser.add_argument("-d", "--dst", help="Destination path in object storage")
    parser.add_argument("-r", "--recursive", action="store_true", help="Recursively upload folders")
    parser.add_argument("-e", "--endpoint", help="S3 gateway endpoint (Private cloud only)")
    parser.add_argument("-i", "--insecure", action="store_true", help="Allow insecure SSL connections")
    parser.add_argument("-ca", "--ca-bundle", help="Path to custom CA bundle file")
    parser.add_argument("-ac", "--account", help="Account for Azure uploads")
    parser.add_argument("-cn", "--container", help="Container name for Azure uploads")
    parser.add_argument("-ns", "--ngc-spec", help="NGC spec folder path")
    # Query options (read the NGC spec file only; no network)
    parser.add_argument('-m', '--model-name', help='Name of the model to query')
    parser.add_argument('-vid', '--variant-id', help='Variant ID (used with model name for specific queries)')
    parser.add_argument('--list-all', action='store_true',help='List all available models')
    parser.add_argument('--list-variants', action='store_true',help='List all variant IDs for the specified model')
    parser.add_argument('--list-profiles', action='store_true',help='List all optimization profile IDs for the specified model (and variant if provided)')
    args = parser.parse_args()
    # Check requirements only for the operations actually requested: the
    # hf/ngc CLIs when downloading, the cloud CLI when --src implies upload.
    if not check_requirements(args.download, args.cloud if args.src else None):
        sys.exit(1)
    # Query mode: list/inspect models from the spec, then return.
    if args.list_all:
        ngc_spec_file = args.ngc_spec
        parser = ModelParser(ngc_spec_file)  # NOTE: rebinds `parser` (was the argparse parser)
        models = parser.get_all_models()
        print_models(models, "All Models")
        return
    elif args.model_name:
        ngc_spec_file = args.ngc_spec
        parser = ModelParser(ngc_spec_file)
        if args.list_variants:
            variants = parser.get_model_variant_ids(args.model_name)
            print_list(variants, f"Variants for '{args.model_name}'")
        elif args.list_profiles:
            profiles = parser.get_optimization_profile_ids(args.model_name, args.variant_id)
            if args.variant_id:
                title = f"Optimization Profiles for '{args.model_name}' variant '{args.variant_id}'"
            else:
                title = f"Optimization Profiles for '{args.model_name}'"
            print_list(profiles, title)
        else:
            # Default: show basic info about the model
            model_data = parser.get_detailed_model_info(args.model_name)
            if model_data:
                model_info = {
                    'name': model_data.get('name', ''),
                    'displayName': model_data.get('displayName', ''),
                    'modelHubID': model_data.get('modelHubID', ''),
                    'category': model_data.get('category', ''),
                    'description': model_data.get('description', '')
                }
                print_models([model_info], f"Model '{args.model_name}'")
            else:
                print(f"Model '{args.model_name}' not found.")
        return
    # Handle download use case
    if args.download:
        if not args.repo_id:
            print("Error: --repo-id is required for download")
            sys.exit(1)
        if not validate_repo_type(args.repo_type):
            sys.exit(1)
        # Get repository info (also persisted to disk), then download artifacts.
        metadata = get_repo_info(args.repo_id, args.token, args.repo_type, args.path, args.ngc_spec)
        if metadata is not None:
            download_repo(args.repo_id, args.token, args.path, args.repo_type, metadata, args.ngc_spec)
        else:
            print("Error: Failed to get repository metadata")
            sys.exit(1)
        sys.exit(0)
    # Handle upload use case
    if args.src and args.dst:
        if not upload_to_cloud(
            args.src, args.dst, args.cloud, args.token, args.recursive,
            args.endpoint, args.insecure, args.ca_bundle, args.account, args.container,
            args.repo_id, args.repo_type
        ):
            sys.exit(1)
        print("Upload completed.")
    else:
        print("Error: Missing required parameters.")
        show_help()
# Script entry point.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment