batch_generate_response_multi_batches.py
| """Run four batched generations with varying sampling settings.""" | |
| import argparse | |
| import mlx.core as mx | |
| from mlx_lm import batch_generate, load | |
| from mlx_lm.sample_utils import make_sampler | |
| def main(model_name: str, max_tokens: int = 64) -> None: | |
| mx.random.seed(42) | |
| model, tokenizer = load(path_or_hf_repo=model_name) | |
| print(f"Using model: {model_name}") | |
| base_prompts = [ | |
| "What is the chemical formula of water? In few words.", | |
| "What is the distance between the Earth and the Moon? In few words.", | |
| "What is the speed of light? In few words.", | |
| "How tall is Mount Everest? In few words.", | |
| ] | |
| prompts = [ | |
| tokenizer.apply_chat_template( | |
| [{"role": "user", "content": prompt}], | |
| add_generation_prompt=True, | |
| ) | |
| for prompt in base_prompts | |
| ] | |
| sampler_configs = [ | |
| {"temperature": 1.0, "top_k": 60, "min_p": 0.0, "top_p": 0.8}, | |
| {"temperature": 2.0, "top_k": 70}, | |
| {"temperature": 3.0, "top_k": 80}, | |
| ] | |
| for idx, config in enumerate(sampler_configs, start=1): | |
| print( | |
| f"===== Batch {idx}: temp={config['temperature']}, top_k={config['top_k']}" | |
| ) | |
| sampler = make_sampler( | |
| temp=config["temperature"], | |
| top_k=config["top_k"], | |
| ) | |
| result = batch_generate( | |
| model, | |
| tokenizer, | |
| prompts, | |
| max_tokens=max_tokens, | |
| sampler=sampler, | |
| verbose=False, | |
| ) | |
| for prompt_text, completion in zip(base_prompts, result.texts): | |
| print(f"Prompt: {prompt_text}\n{completion}") | |
| print("-" * 20) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description=__doc__) | |
| parser.add_argument( | |
| "--model", | |
| default="mlx-community/Llama-3.2-3B-Instruct-4bit", | |
| help="Model to load", | |
| ) | |
| parser.add_argument( | |
| "--max-tokens", | |
| type=int, | |
| default=64, | |
| help="Maximum tokens to generate per response", | |
| ) | |
| args = parser.parse_args() | |
| main(args.model, max_tokens=args.max_tokens) |
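A minimal usage sketch, assuming mlx-lm is installed (for example via pip install mlx-lm) and the default 4-bit model can be fetched from the Hugging Face Hub; the flags are the ones defined by the argument parser above:

    python batch_generate_response_multi_batches.py
    python batch_generate_response_multi_batches.py --model mlx-community/Llama-3.2-3B-Instruct-4bit --max-tokens 32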