@ivanfioravanti
Created September 27, 2025 15:59
batch_generate_response_multi_batches.py
"""Run four batched generations with varying sampling settings."""
import argparse

import mlx.core as mx
from mlx_lm import batch_generate, load
from mlx_lm.sample_utils import make_sampler
def main(model_name: str, max_tokens: int = 64) -> None:
    mx.random.seed(42)
    model, tokenizer = load(path_or_hf_repo=model_name)
    print(f"Using model: {model_name}")

    base_prompts = [
        "What is the chemical formula of water? In a few words.",
        "What is the distance between the Earth and the Moon? In a few words.",
        "What is the speed of light? In a few words.",
        "How tall is Mount Everest? In a few words.",
    ]
    # Wrap each question in the model's chat template so the prompts are
    # tokenized the way the instruct model expects.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
        )
        for prompt in base_prompts
    ]
    # One batch per config; the first also exercises top_p/min_p filtering.
    sampler_configs = [
        {"temperature": 1.0, "top_k": 60, "min_p": 0.0, "top_p": 0.8},
        {"temperature": 2.0, "top_k": 70},
        {"temperature": 3.0, "top_k": 80},
    ]
    for idx, config in enumerate(sampler_configs, start=1):
        print(
            f"===== Batch {idx}: temp={config['temperature']}, top_k={config['top_k']} ====="
        )
        # Forward the optional top_p/min_p keys too; make_sampler's defaults
        # (top_p=0.0, min_p=0.0) leave those filters disabled when absent.
        sampler = make_sampler(
            temp=config["temperature"],
            top_k=config["top_k"],
            top_p=config.get("top_p", 0.0),
            min_p=config.get("min_p", 0.0),
        )
        result = batch_generate(
            model,
            tokenizer,
            prompts,
            max_tokens=max_tokens,
            sampler=sampler,
            verbose=False,
        )
        for prompt_text, completion in zip(base_prompts, result.texts):
            print(f"Prompt: {prompt_text}\n{completion}")
            print("-" * 20)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--model",
default="mlx-community/Llama-3.2-3B-Instruct-4bit",
help="Model to load",
)
parser.add_argument(
"--max-tokens",
type=int,
default=64,
help="Maximum tokens to generate per response",
)
args = parser.parse_args()
main(args.model, max_tokens=args.max_tokens)
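
A hypothetical invocation, assuming mlx-lm is installed (e.g. via pip install mlx-lm) and an Apple silicon Mac, which MLX requires:

    python batch_generate_response_multi_batches.py --model mlx-community/Llama-3.2-3B-Instruct-4bit --max-tokens 64

Each of the three batches prints all four completions, so the same prompts can be compared across increasingly aggressive sampling settings.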