@finetunej
finetunej / to_hf_weights.py
Created August 17, 2021 12:24
Converts trained GPT-J checkpoints into a PyTorch Hugging Face format.
####
# run with 'help' arg for usage.
####
"""
python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
"""
import numpy as np
import jax.numpy as jnp

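def apply_reshard(pytree_params_in, pytree_params_out, shards_in, shards_out):
    def override_dtype(x):
        # Presumably bfloat16 data that NumPy loaded as raw 2-byte void ('V2');
        # reinterpret those bytes as bfloat16 in place.
        if x.dtype == np.dtype('V2'):
            x.dtype = jnp.bfloat16
        return x

    def is_leaf(x):
        # True only for plain NumPy arrays (exact type match, subclasses excluded).
        return type(x) == np.ndarray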