Skip to content

Instantly share code, notes, and snippets.

@smothiki
Last active January 22, 2026 01:11
Show Gist options
  • Select an option

  • Save smothiki/bd6474446b4490464cd4f6412780b90b to your computer and use it in GitHub Desktop.

Select an option

Save smothiki/bd6474446b4490464cd4f6412780b90b to your computer and use it in GitHub Desktop.
airgap script windows version
#!/usr/bin/env python3
import argparse
import json
import os
import subprocess
import sys
import logging
import yaml
from typing import List, Dict, Optional, Tuple
# courtesy : https://stackoverflow.com/questions/43765849/pyyaml-load-and-dump-yaml-file-and-preserve-tags-customtag
class SafeUnknownConstructor(yaml.constructor.SafeConstructor):
    """SafeConstructor that preserves unknown YAML tags instead of erroring.

    Unknown-tagged nodes are constructed as their underlying plain type
    (scalar/sequence/mapping), then wrapped in a dynamically created subclass
    that records the original tag so a matching representer can re-emit it.
    Recipe from:
    https://stackoverflow.com/questions/43765849/pyyaml-load-and-dump-yaml-file-and-preserve-tags-customtag
    """

    def __init__(self):
        yaml.constructor.SafeConstructor.__init__(self)

    def construct_undefined(self, node):
        # Dispatch to the constructor matching the node kind: node.id is one
        # of 'scalar', 'sequence', 'mapping', selecting construct_scalar etc.
        data = getattr(self, 'construct_' + node.id)(node)
        datatype = type(data)
        # Plain builtins (dict/list/str/...) cannot carry extra attributes,
        # so create a throwaway subclass to hang the tag info on.
        wraptype = type('TagWrap_' + datatype.__name__, (datatype,), {})
        wrapdata = wraptype(data)
        # NOTE(review): these lambda placeholders appear to pre-create the
        # attribute slots; the meaningful values are set via setattr below.
        wrapdata.tag = lambda: None
        wrapdata.datatype = lambda: None
        # Remember the original tag and underlying type for the representer.
        setattr(wrapdata, "wrapTag", node.tag)
        setattr(wrapdata, "wrapType", datatype)
        return wrapdata
class SafeUnknownLoader(SafeUnknownConstructor, yaml.loader.SafeLoader):
    """SafeLoader wired to the tag-preserving SafeUnknownConstructor."""

    def __init__(self, stream):
        # Initialize both bases explicitly: the constructor mixin takes no
        # arguments, while SafeLoader needs the input stream.
        SafeUnknownConstructor.__init__(self)
        yaml.loader.SafeLoader.__init__(self, stream)
class SafeUnknownRepresenter(yaml.representer.SafeRepresenter):
    """SafeRepresenter that restores the original tag on wrapped values.

    Counterpart to SafeUnknownConstructor: values wrapped in a dynamically
    created ``TagWrap_*`` subclass are unwrapped back to their plain type,
    represented normally, and then re-tagged with the preserved YAML tag.
    """

    def represent_data(self, wrapdata):
        wrap_tag = None
        payload = wrapdata
        # Wrapped values are instances of a generated TagWrap_<type> class.
        if type(wrapdata).__name__.startswith('TagWrap_'):
            payload = getattr(wrapdata, "wrapType")(wrapdata)
            wrap_tag = getattr(wrapdata, "wrapTag")
        node = super().represent_data(payload)
        if wrap_tag:
            node.tag = wrap_tag
        return node
class SafeUnknownDumper(SafeUnknownRepresenter, yaml.dumper.SafeDumper):
    """SafeDumper that re-emits the tags preserved by SafeUnknownLoader."""

    def __init__(self, stream,
                 default_style=None, default_flow_style=False,
                 canonical=None, indent=None, width=None,
                 allow_unicode=None, line_break=None,
                 encoding=None, explicit_start=None, explicit_end=None,
                 version=None, tags=None, sort_keys=True):
        # Initialize the representer mixin first so the overridden
        # represent_data is active, then the dumper machinery proper,
        # forwarding every formatting option unchanged.
        SafeUnknownRepresenter.__init__(self, default_style=default_style,
                                        default_flow_style=default_flow_style, sort_keys=sort_keys)
        yaml.dumper.SafeDumper.__init__(self, stream,
                                        default_style=default_style,
                                        default_flow_style=default_flow_style,
                                        canonical=canonical,
                                        indent=indent,
                                        width=width,
                                        allow_unicode=allow_unicode,
                                        line_break=line_break,
                                        encoding=encoding,
                                        explicit_start=explicit_start,
                                        explicit_end=explicit_end,
                                        version=version,
                                        tags=tags,
                                        sort_keys=sort_keys)
# Loader alias used throughout this script for tag-preserving YAML parsing.
MySafeLoader = SafeUnknownLoader
# Register the catch-all constructor (tag=None matches any tag without an
# explicit constructor) so unknown tags are wrapped instead of rejected.
# NOTE(review): this registers on PyYAML's shared SafeConstructor class, so
# it affects every SafeLoader in the process, not only MySafeLoader —
# confirm that global side effect is intended.
yaml.constructor.SafeConstructor.add_constructor(None, SafeUnknownConstructor.construct_undefined)
class ModelParser:
    """Parses the NGC model spec YAML and answers queries about its models."""

    def __init__(self, yaml_file_path: str):
        """Remember the spec path and eagerly load its contents."""
        self.yaml_file_path = yaml_file_path
        self.models_data = self._load_yaml()

    def _load_yaml(self) -> Dict:
        """Load the spec file; return {} when it is missing or malformed."""
        try:
            with open(self.yaml_file_path, "r") as handle:
                raw = handle.read()
            return yaml.load(raw, Loader=MySafeLoader)
        except FileNotFoundError:
            print(f"Error: File {self.yaml_file_path} not found.")
            return {}
        except yaml.YAMLError as e:
            print(f"Error parsing YAML: {e}")
            return {}

    def get_all_models(self) -> List[Dict[str, str]]:
        """Return a basic-info summary dict for every model in the spec."""
        if 'models' not in self.models_data:
            return []
        return [
            {
                'name': entry.get('name', ''),
                'displayName': entry.get('displayName', ''),
                'modelHubID': entry.get('modelHubID', ''),
                'category': entry.get('category', ''),
                'description': entry.get('description', ''),
            }
            for entry in self.models_data['models']
        ]

    def get_model_variant_ids(self, model_name: str) -> List[str]:
        """Return all non-empty variant IDs for the given model.

        The name match is case-insensitive and only the first matching
        model entry is inspected.
        """
        if 'models' not in self.models_data:
            return []
        wanted = model_name.lower()
        for entry in self.models_data['models']:
            if entry.get('name', '').lower() != wanted:
                continue
            return [
                variant.get('variantId', '')
                for variant in entry.get('modelVariants', [])
                if variant.get('variantId', '')
            ]
        return []

    def get_optimization_profile_ids(self, model_name: str, variant_id: str = None) -> List[str]:
        """Return optimization profile IDs for a model (optionally one variant).

        Matching is case-insensitive; when variant_id is None, profiles from
        every variant of the first matching model are collected.
        """
        if 'models' not in self.models_data:
            return []
        wanted = model_name.lower()
        collected = []
        for entry in self.models_data['models']:
            if entry.get('name', '').lower() != wanted:
                continue
            for variant in entry.get('modelVariants', []):
                # Skip variants that do not match an explicitly requested id.
                if variant_id and variant.get('variantId', '').lower() != variant_id.lower():
                    continue
                for profile in variant.get('optimizationProfiles', []):
                    pid = profile.get('profileId', '')
                    if pid:
                        collected.append(pid)
            break
        return collected

    def get_detailed_model_info(self, model_name: str) -> Optional[Dict]:
        """Return the full spec entry for the model, or None when absent."""
        wanted = model_name.lower()
        for entry in self.models_data.get('models', []):
            if entry.get('name', '').lower() == wanted:
                return entry
        return None
def extract_last_part(input_string):
    """Return the text after the final '/' (the whole string if none)."""
    return input_string.rsplit('/', 1)[-1]
def run_ngc_info_command(ID):
    """Run `ngc registry model info` for *ID* and parse its JSON output.

    Returns:
        (metadata_dict, None) on success, or (None, error) when the CLI
        call fails or its output is not valid JSON.
    """
    logging.info(f"Start Fetching model info from NGC: {ID}")
    cmd = ["ngc", "registry", "model", "info", ID, "--format_type", "json"]
    # The command contains no secrets; credentials come from the environment.
    logging.info(cmd)
    try:
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=False,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logging.error(f"Error Fetching model repo: {ID}")
        logging.error(f"Error: {str(e)}")
        return None, e
    logging.info(f"Finish Fetching NGC model repo {ID}")
    # stdout is bytes (text=False); json.loads accepts bytes directly.
    try:
        return json.loads(completed.stdout), None
    except json.JSONDecodeError as e:
        logging.error(f"Error while Unmarshalling the NGC info command to map: {str(e)}")
        return None, e
def get_ngc_model_info(model_id, tag):
    """Fetch NGC model metadata enriched with the tagged version's size.

    Queries NGC twice: once for the model-level metadata and once for the
    metadata of the specific version *tag*.

    Args:
        model_id: NGC model id (org/team/model).
        tag: Version tag to query.

    Returns:
        The model metadata dict with 'totalSizeInBytes' copied in from the
        version metadata when available, or "" on any fetch failure.
    """
    model_metadata_map, err = run_ngc_info_command(model_id)
    if err:
        logging.error(f"Error Fetching model repo: {model_id}")
        logging.error(f"Error: {str(err)}")
        return ""
    versionMetadataMap, err = run_ngc_info_command(model_id + ":" + tag)
    if err:
        logging.error(f"Error Fetching model repo: {model_id}")
        logging.error(f"Error: {str(err)}")
        return ""
    # Bug fix: the version-level size was fetched but never stored — the old
    # code read it into an unused local and then re-assigned
    # model_metadata_map['versionId'] to itself (a no-op). Copy the version's
    # total size into the model metadata so callers can persist it.
    total_size = versionMetadataMap.get("totalSizeInBytes", "")
    if total_size != "":
        model_metadata_map["totalSizeInBytes"] = total_size
    return model_metadata_map
def load_ngc_spec(spec_file):
    """Load an NGC spec YAML file, preserving unknown tags via MySafeLoader."""
    with open(spec_file, "r") as handle:
        contents = handle.read()
    return yaml.load(contents, Loader=MySafeLoader)
def execute_nim_download_command(repo_id, folder_location, ngc_spec, profile_sha, version):
    """Download model files for one optimization profile via the nim CLI.

    Args:
        repo_id: Repository ID, optionally suffixed with ":tag".
        folder_location: Local model-cache destination directory.
        ngc_spec: Path to the NGC spec file; manifests live alongside it.
        profile_sha: Profile SHA passed to `nimcli download --profiles`.
        version: Registry version used to locate the manifest file.

    Returns:
        (folder_location, error) where error is None on success.

    Raises:
        ValueError: if the model name is not of the org/team/model form.
    """
    model_name = repo_id.split(":")[0]
    count = model_name.count('/')
    if count != 2:
        # Bug fix: the message previously claimed 3 '/' characters were
        # expected, but an org/team/model name contains exactly 2.
        raise ValueError(f"Expected 2 '/' characters (org/team/model), but found {count} in model name")
    # Manifests are stored next to the spec file, keyed by registry version.
    ngc_spec_abs = os.path.dirname(ngc_spec)
    manifest_path = f"{ngc_spec_abs}/manifests/{version}/{model_name}.yaml"
    cmd = [
        "nimcli", "download", "--profiles", profile_sha, "--manifest-file",
        manifest_path, "--model-cache-path", folder_location
    ]
    print(cmd)
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        logging.error(f"Error download model repo: {repo_id}")
        logging.error(f"Error: {str(e)}")
        return folder_location, e
    # Bug fix: the success path previously fell through and returned None,
    # inconsistent with the (folder_location, error) contract of the
    # sibling execute_ngc_download_command.
    return folder_location, None
def execute_ngc_download_command(repo_id, folder_location, files=None):
    """Execute NGC download command to download model files.

    Args:
        repo_id (str): Repository ID (model:version).
        folder_location (str): Destination folder.
        files (list, optional): Files to download; None downloads the whole repo.

    Returns:
        tuple: (folder_location, error) — error is None on success.
    """
    # Credentials come from the caller's environment (NGC_CLI_API_KEY etc.).
    base_cmd = ["ngc", "registry", "model", "download-version", repo_id]
    if files:
        # Download the requested files one at a time.
        for file in files:
            cmd = base_cmd + ["--file", file, "--dest", folder_location,
                              "--format_type", "json"]
            logging.info(cmd)
            try:
                subprocess.run(cmd, check=True)
            except subprocess.CalledProcessError as e:
                logging.error(f"Error download model repo: {repo_id}")
                logging.error(f"Error: {str(e)}")
                # Bug fix: output is not captured here, so e.output is None
                # and the old unconditional e.output.decode() raised
                # AttributeError, masking the real download error.
                if e.output:
                    logging.error(f"Command output: {e.output.decode()}")
                return folder_location, e
    else:
        # Download the entire repository.
        cmd = base_cmd + ["--dest", folder_location, "--format_type", "json"]
        logging.info(cmd)
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logging.error(f"Error download model repo: {repo_id}")
            logging.error(f"Error: {str(e)}")
            if e.output:
                logging.error(f"Command output: {e.output.decode()}")
            # Bug fix: return the CalledProcessError itself, consistent with
            # the per-file branch; wrapping e.output (None here) in
            # Exception(...) previously crashed with AttributeError.
            return folder_location, e
    return folder_location, None
def extract_profile_components(data, target_profile_id):
    """Extract repo IDs and source files for a given profile ID.

    Args:
        data (dict): Parsed NGC spec (the 'models' tree).
        target_profile_id (str): The profileId to search for.

    Returns:
        dict with keys:
            'found' (bool): whether the profile was located,
            'components' (list): per-component dicts with 'destination',
                'repo_id', and optionally 'files',
            'ngcMetadata': the profile's raw ngcMetadata mapping (or None).
    """
    result = {
        'found': False,
        'components': [],
        'ngcMetadata': None,
    }
    for model in data.get('models', []):
        for variant in model.get('modelVariants', []):
            for profile in variant.get('optimizationProfiles', []):
                if profile.get('profileId') != target_profile_id:
                    continue
                result['found'] = True
                result['ngcMetadata'] = profile.get('ngcMetadata', None)
                # ngcMetadata maps profile SHAs to workspace definitions.
                for _sha, meta in profile.get('ngcMetadata', {}).items():
                    if 'workspace' not in meta or 'components' not in meta['workspace']:
                        continue
                    for component in meta['workspace']['components']:
                        src = component.get('src', {})
                        info = {
                            'destination': component.get('dst', ''),
                            'repo_id': src.get('repo_id', ''),
                        }
                        files = src.get('files', [])
                        if files:
                            # Entries may be plain strings or dicts carrying
                            # a preserved '!name' tag key.
                            names = []
                            for entry in files:
                                if isinstance(entry, str):
                                    names.append(entry)
                                elif isinstance(entry, dict) and entry.get('!name'):
                                    names.append(entry.get('!name'))
                            info['files'] = names
                        result['components'].append(info)
                # Stop at the first matching profile.
                return result
    return result
def show_help():
    """Print usage examples and exit with status 1."""
    usage = """
Description: Fetches information about a model from the Hugging Face Hub and optionally downloads it,
or uploads model artifacts to cloud storage.
Examples:
python script.py gpt2
python script.py --token YOUR_HF_TOKEN --repo-id facebook/bart-large
python script.py --download --path ~/my_models --repo-id bert-base-uncased
python script.py --repo-type dataset --download --repo-id mnist
python script.py --cloud aws --src /path/to/models/ --dst s3://bucket/path --recursive
python script.py --cloud pvc --src /path/to/model/hf/meta/llama3.1 --dst s3://bucket/secured-models/hf/meta/llama3.1
python script.py --cloud azure --src /path/to/model/hf/meta/llama3.1 --account cloudera-customer1 --container data --dst modelregistry/secured-models/hf/meta/llama3.1
"""
    print(usage)
    sys.exit(1)
def check_requirements(download_model, cloud):
    """Check that the external CLIs needed for the requested actions exist.

    Args:
        download_model: truthy when a model download was requested.
        cloud: cloud provider name to validate tooling for, or None/falsy
            to skip the upload-tooling checks.

    Returns:
        True when every required tool responds; False otherwise.
    """
    def _probe(cmd, use_shell=False):
        # Returns True when the probe command runs and exits successfully.
        try:
            subprocess.run(cmd, check=True, shell=use_shell,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            return True
        except (subprocess.SubprocessError, FileNotFoundError):
            return False

    if download_model:
        if not _probe(["hf", "version"]):
            print("Error: huggingface-cli is required for downloading but not installed.")
            print("Please install it using pip:")
            print(" pip install huggingface_hub")
            return False
        if not _probe(["ngc", "version", "info"]):
            print("Error: NGC CLI is required for downloading but not installed.")
            print("Please install it ngc cli : https://org.ngc.nvidia.com/setup/installers/cli")
            return False
    if cloud:
        if cloud == "aws" or cloud == "pvc":
            # NOTE(review): shell=True with a list looks Windows-specific
            # (this gist targets Windows, where aws may be a .cmd shim); on
            # POSIX it would effectively run bare "aws" — confirm intent.
            if not _probe(["aws", "--version"], use_shell=True):
                print(f"Error: aws-cli is required for uploading to {cloud} but not installed.")
                print("Please install it using pip:")
                print(" pip install awscli")
                return False
        elif cloud == "azure":
            if not _probe(["az", "--version"]):
                print("Error: Azure CLI is required for uploading to Azure Blob Storage but not installed.")
                print("Please install it using pip:")
                print(" pip install azure-cli")
                return False
        else:
            print(f"Unsupported cloud provider: {cloud}")
            return False
    return True
def validate_repo_type(repo_type):
    """Return True when *repo_type* is one of the supported types."""
    valid_types = ["hf", "ngc"]
    if repo_type in valid_types:
        return True
    print(f"Error: Invalid repository type '{repo_type}'")
    print(f"Valid types are: {', '.join(valid_types)}")
    return False
def download_repo_hf(repo_id, token, download_path):
    """Download a Hugging Face repository using the hf CLI.

    Artifacts land under <download_path>/hf/<repo_id>/artifacts.
    Returns True on success, False when the CLI call fails.
    """
    print(f"Downloading repository: {repo_id} to {download_path}")
    os.makedirs(os.path.join(download_path, "hf", repo_id, "artifacts"), exist_ok=True)
    print("Starting download with huggingface-cli...")
    cmd = ["hf", "download", repo_id, "--local-dir", f"{download_path}/hf/{repo_id}/artifacts"]
    if token:
        cmd += ["--token", token]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.SubprocessError:
        print(f"Error: Failed to download {repo_id}")
        return False
    print(f"Download completed successfully for {repo_id}")
    return True
def canusenimcli(metadata):
    """Decide whether the nim CLI can handle this profile's metadata.

    Args:
        metadata: dict carrying an 'ngcMetadata' mapping of profile SHA to
            release info.

    Returns:
        True when the profile's release is >= 1.3.0 (or the model is
        nvidia/nemoretriever-parse); False otherwise or on any parse error.
    """
    try:
        ngcmetadata = metadata['ngcMetadata']
        print(ngcmetadata)
        # There is a single entry keyed by the profile SHA.
        release_info = ngcmetadata[next(iter(ngcmetadata))]
        print(release_info)
        if "release" not in release_info:
            return False
        if release_info.get("model") == "nvidia/nemoretriever-parse":
            return True
        from packaging.version import Version
        return Version(release_info["release"]) >= Version("1.3.0")
    except Exception as e:
        print("error Failed to parse YAML:", e)
        return False
def download_repo(repo_id, token, download_path, repo_type, metadata, ngc_spec):
    """Dispatch an artifact download to the 'hf' or 'ngc' code path."""
    print(f"start downloading {repo_type} repository: {repo_id}")
    if repo_type == 'hf':
        return download_repo_hf(repo_id, token, download_path)
    elif repo_type == 'ngc':
        # Strip the optional scheme prefix from NGC repo ids.
        prefix = 'ngc://'
        if repo_id.startswith(prefix):
            repo_id = repo_id[len(prefix):]
        artifacts_dir = os.path.join(download_path, "ngc", extract_last_part(repo_id), "artifacts")
        os.makedirs(artifacts_dir, exist_ok=True)
        use_nimcli = canusenimcli(metadata)
        profile_map = metadata['ngcMetadata']
        # Single entry keyed by the profile SHA.
        profile_sha = next(iter(profile_map))
        spec = load_ngc_spec(ngc_spec)
        version = spec['registryVersion'] if 'registryVersion' in spec else ''
        if use_nimcli:
            model_name = repo_id.split(":")[0]
            if len(model_name.split("/")) != 3:
                raise NameError("Model name should have three parts " + model_name)
            execute_nim_download_command(repo_id, artifacts_dir, ngc_spec, profile_sha, version)
        else:
            # Fall back to downloading each workspace component via ngc CLI.
            for component in metadata['components']:
                print(f"Repo ID: {component['repo_id']}")
                print(f"Destination: {component['destination']}")
                execute_ngc_download_command(component['repo_id'], artifacts_dir, component.get('files'))
        print("Finish downloading artifacts")
def get_repo_info_hf(repo_id, token):
    """Fetch Hugging Face repository metadata via the HF REST API.

    Returns the parsed JSON dict, or None on any request/parse failure.
    """
    import requests
    headers = {"Accept": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    url = f"https://huggingface.co/api/models/{repo_id}"
    try:
        response = requests.get(url, headers=headers)
        # Raise for HTTP error statuses so they land in the handler below.
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, json.JSONDecodeError) as e:
        print(f"Failed to fetch huggingface repo metadata for the model: {str(e)}")
        return None
def get_repo_info_ngc(repo_id, spec_file):
    """Collect profile components from the spec plus model metadata from NGC.

    repo_id is expected to be "<model-id>:<tag>"; the profile components are
    looked up by the full id, while the NGC query splits id and tag.
    """
    spec = load_ngc_spec(spec_file)
    profile_info = extract_profile_components(spec, repo_id)
    parts = repo_id.split(':')
    model_metadata = get_ngc_model_info(parts[0], parts[1])
    return profile_info, model_metadata
def get_repo_info(repo_id, token, repo_type, download_path, ngc_spec):
    """Fetch repo metadata, persist it under the download path, return it."""
    print(f"Fetching information for {repo_type} repository: {repo_id}")
    print(f"Download path: {download_path}")
    if repo_type == 'hf':
        metadata_path = os.path.join(download_path, "hf", repo_id, "metadata")
        metadata = get_repo_info_hf(repo_id, token)
        os.makedirs(metadata_path, exist_ok=True)
        output_file = os.path.join(metadata_path, "metadata.json")
        with open(output_file, 'w') as f:
            json.dump(metadata, f, indent=2)
    elif repo_type == 'ngc':
        metadata_path = os.path.join(download_path, "ngc", extract_last_part(repo_id), "metadata")
        metadata, modelmetadata = get_repo_info_ngc(repo_id, ngc_spec)
        os.makedirs(metadata_path, exist_ok=True)
        # Dump with the tag-preserving dumper so custom tags survive.
        output_file = os.path.join(metadata_path, "metadata.yaml")
        with open(output_file, 'w') as f:
            yaml.dump(metadata['ngcMetadata'], f, default_flow_style=False,
                      sort_keys=False, allow_unicode=True, Dumper=SafeUnknownDumper)
        if modelmetadata:
            output_file = os.path.join(metadata_path, "modelmetadata.json")
            with open(output_file, 'w') as f:
                json.dump(modelmetadata, f, indent=2)
        print("finish downloading metadata file")
    print(f"Saved metadata to {output_file}")
    return metadata
def upload_to_cloud(src, dst, cloud, token=None, recursive=False, endpoint=None,
                    insecure=False, ca_bundle=None, account=None, container=None,
                    repo_id=None, repo_type=None):
    """Upload downloaded model artifacts to the selected cloud backend.

    Returns True on success, False on an unsupported provider or CLI error.
    """
    print(f"Start uploading {repo_type} repository: {repo_id}")
    try:
        if repo_type == "ngc":
            # NGC ids are nested (org/team/model); keep only the model part.
            repo_id = extract_last_part(repo_id)
        src = os.path.join(src, repo_type, repo_id)
        dst = dst + "/" + repo_type + "/" + repo_id
        if cloud == "aws":
            cmd = ["aws", "s3", "cp", src, f"{dst}/", "--recursive"]
            print(" ".join(cmd))
            # NOTE(review): shell=True with a list is Windows-oriented (this
            # gist targets Windows); behavior differs on POSIX.
            subprocess.run(cmd, check=True, shell=True)
        elif cloud == "azure":
            cmd = [
                "az", "storage", "blob", "upload-batch",
                "--account-name", account,
                "--destination", container,
                "--destination-path", dst,
                "--source", src
            ]
            if token:
                cmd.extend(["--sas-token", token])
            subprocess.run(cmd, check=True)
        elif cloud == "pvc":
            cmd = ["aws", "s3"]
            if endpoint:
                cmd.extend(["--endpoint", endpoint])
            # --no-verify-ssl and --ca-bundle are mutually exclusive here.
            if insecure:
                cmd.append("--no-verify-ssl")
            elif ca_bundle:
                cmd.extend(["--ca-bundle", ca_bundle])
            cmd.extend(["cp", src, dst, "--recursive"])
            subprocess.run(cmd, check=True)
        else:
            print(f"Unsupported cloud provider: {cloud}")
            return False
        print(f"Uploaded: {src} -> {dst}")
        print(f"Finish uploading {repo_type} repository: {repo_id} to {cloud}")
        return True
    except subprocess.SubprocessError as e:
        print(f"Error during upload: {str(e)}")
        return False
def print_models(models: List[Dict[str, str]], title: str = "Models"):
    """Pretty-print a list of model summary dicts under a section header."""
    print(f"\n=== {title.upper()} ===")
    if not models:
        print("No models found.")
        return
    for idx, entry in enumerate(models, 1):
        print(f"{idx}. {entry['name']}")
        print(f" Display Name: {entry['displayName']}")
        print(f" Category: {entry['category']}")
        print(f" Hub ID: {entry['modelHubID']}")
        if entry['description']:
            # Truncate long descriptions to keep the listing compact.
            desc = entry['description'][:100] + "..." if len(entry['description']) > 100 else entry['description']
            print(f" Description: {desc}")
        print()
def print_list(items: List[str], title: str):
    """Print a numbered list of items under a section header."""
    print(f"\n=== {title.upper()} ===")
    if not items:
        print(f"No {title.lower()} found.")
        return
    for idx, value in enumerate(items, 1):
        print(f"{idx}. {value}")
def print_detailed_model(model_data: Dict, model_name: str):
    """Print a model's full details including variants and profiles."""
    print(f"\n=== DETAILED INFO FOR '{model_name}' ===")
    if not model_data:
        print(f"Model '{model_name}' not found.")
        return
    for label, key in (("Name", 'name'), ("Display Name", 'displayName'),
                       ("Model Hub ID", 'modelHubID'), ("Category", 'category'),
                       ("Type", 'type'), ("Description", 'description'),
                       ("License", 'license')):
        print(f"{label}: {model_data.get(key, 'N/A')}")
    if 'labels' in model_data:
        print(f"Labels: {', '.join(model_data['labels'])}")
    if 'modelVariants' in model_data:
        print(f"\nModel Variants ({len(model_data['modelVariants'])}):")
        for i, variant in enumerate(model_data['modelVariants'], 1):
            print(f" {i}. {variant.get('variantId', 'N/A')}")
            if 'optimizationProfiles' in variant:
                print(f" Optimization Profiles ({len(variant['optimizationProfiles'])}):")
                for j, profile in enumerate(variant['optimizationProfiles'], 1):
                    print(f" {j}. {profile.get('profileId', 'N/A')}")
                    print(f" Display Name: {profile.get('displayName', 'N/A')}")
                    print(f" Framework: {profile.get('framework', 'N/A')}")
def main():
    """CLI entry point: parse args, then dispatch to spec queries,
    download, or upload; exits non-zero on failure."""
    parser = argparse.ArgumentParser(description="Hugging Face model management script")
    # Help option
    # Download options
    parser.add_argument("-t", "--token", default="", help="Token for authentication")
    parser.add_argument("-j", "--json", action="store_true", help="Output raw JSON")
    parser.add_argument("-do", "--download", action="store_true", help="Download the model repository")
    parser.add_argument("-p", "--path", default="./models", help="Path to download model files")
    parser.add_argument("-ri", "--repo-id", help="Repository ID to download")
    parser.add_argument("-rt", "--repo-type", default="hf", help="Repository type (default: hf)")
    # parser.add_argument("-sha", "--profile-sha", help="Sha of the p rofile of the repoID")
    # Upload options
    parser.add_argument("-c", "--cloud", default="aws", help="Cloud provider (aws, gcp, azure, pvc)")
    parser.add_argument("-s", "--src", help="Source directory for upload")
    parser.add_argument("-d", "--dst", help="Destination path in object storage")
    parser.add_argument("-r", "--recursive", action="store_true", help="Recursively upload folders")
    parser.add_argument("-e", "--endpoint", help="S3 gateway endpoint (Private cloud only)")
    parser.add_argument("-i", "--insecure", action="store_true", help="Allow insecure SSL connections")
    parser.add_argument("-ca", "--ca-bundle", help="Path to custom CA bundle file")
    parser.add_argument("-ac", "--account", help="Account for Azure uploads")
    parser.add_argument("-cn", "--container", help="Container name for Azure uploads")
    parser.add_argument("-ns", "--ngc-spec", help="NGC spec folder path")
    # Spec-query options
    parser.add_argument('-m', '--model-name', help='Name of the model to query')
    parser.add_argument('-vid', '--variant-id', help='Variant ID (used with model name for specific queries)')
    parser.add_argument('--list-all', action='store_true',help='List all available models')
    parser.add_argument('--list-variants', action='store_true',help='List all variant IDs for the specified model')
    parser.add_argument('--list-profiles', action='store_true',help='List all optimization profile IDs for the specified model (and variant if provided)')
    args = parser.parse_args()
    # # Show help if requested or no arguments provided
    # if args.help or len(sys.argv) == 1:
    # show_help()
    # Check requirements: cloud tooling is only verified when an upload
    # source was provided; download tooling when --download was passed.
    if not check_requirements(args.download, args.cloud if args.src else None):
        sys.exit(1)
    # Handle different command combinations.
    # Spec-query modes take precedence and return early.
    if args.list_all:
        ngc_spec_file = args.ngc_spec
        # NOTE(review): rebinds `parser` (the argparse object) to a
        # ModelParser — intentional name reuse, but easy to misread.
        parser = ModelParser(ngc_spec_file)
        models = parser.get_all_models()
        print_models(models, "All Models")
        return
    elif args.model_name:
        ngc_spec_file = args.ngc_spec
        parser = ModelParser(ngc_spec_file)
        if args.list_variants:
            variants = parser.get_model_variant_ids(args.model_name)
            print_list(variants, f"Variants for '{args.model_name}'")
        elif args.list_profiles:
            profiles = parser.get_optimization_profile_ids(args.model_name, args.variant_id)
            if args.variant_id:
                title = f"Optimization Profiles for '{args.model_name}' variant '{args.variant_id}'"
            else:
                title = f"Optimization Profiles for '{args.model_name}'"
            print_list(profiles, title)
        else:
            # Default: show basic info about the model
            model_data = parser.get_detailed_model_info(args.model_name)
            if model_data:
                model_info = {
                    'name': model_data.get('name', ''),
                    'displayName': model_data.get('displayName', ''),
                    'modelHubID': model_data.get('modelHubID', ''),
                    'category': model_data.get('category', ''),
                    'description': model_data.get('description', '')
                }
                print_models([model_info], f"Model '{args.model_name}'")
            else:
                print(f"Model '{args.model_name}' not found.")
        return
    # Handle download use case: fetch metadata first, then artifacts.
    if args.download:
        if not args.repo_id:
            print("Error: --repo-id is required for download")
            sys.exit(1)
        if not validate_repo_type(args.repo_type):
            sys.exit(1)
        # Get repository info and download
        metadata = get_repo_info(args.repo_id, args.token, args.repo_type, args.path, args.ngc_spec)
        if metadata is not None:
            download_repo(args.repo_id, args.token, args.path, args.repo_type, metadata, args.ngc_spec)
        else:
            print("Error: Failed to get repository metadata")
            sys.exit(1)
        # Download mode never falls through to the upload path.
        sys.exit(0)
    # Handle upload use case: both --src and --dst are required.
    if args.src and args.dst:
        if not upload_to_cloud(
            args.src, args.dst, args.cloud, args.token, args.recursive,
            args.endpoint, args.insecure, args.ca_bundle, args.account, args.container,
            args.repo_id, args.repo_type
        ):
            sys.exit(1)
        print("Upload completed.")
    else:
        print("Error: Missing required parameters.")
        show_help()
# Standard script entry-point guard.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment