Created
October 21, 2025 23:20
-
-
Save richardliaw/93d9398d26544cf801e04e3663acfc45 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Batch LLM inference over a Ray Dataset using the Ray Data vLLM processor.

Builds a small prompt dataset, runs it through a vLLM engine replica
(unsloth/Llama-3.2-1B-Instruct) via ray.data.llm, prints the generations,
and writes the results to ./outputs as JSON.
"""

import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor

# Initialize Ray. If the model needs a HuggingFace token on the workers,
# pass it through the runtime environment, e.g.:
#   ray.init(runtime_env={"env_vars": {"HF_TOKEN": "..."}})
ray.init()

# Create a sample dataset with prompts.
data = [
    {"prompt": "What is the capital of France?"},
    {"prompt": "Explain quantum computing in one sentence."},
    {"prompt": "Write a haiku about programming."},
    {"prompt": "What are the benefits of using Ray for distributed computing?"},
]

# Create a Ray Dataset from the in-memory rows.
ds = ray.data.from_items(data)
print("Original dataset:")
ds.show(limit=4)

# Configure the vLLM processor for batch inference.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.2-1B-Instruct",
    # Each engine replica is scheduled on one CPU and one TPU chip;
    # GPU is explicitly set to 0 so no GPU is requested.
    placement_group_config={
        "bundles": [{"CPU": 1, "TPU": 1, "GPU": 0}],
    },
    concurrency=1,   # number of vLLM engine replicas
    batch_size=16,   # rows handed to the engine per batch
)

# Build the LLM processor.
processor = build_llm_processor(
    config,
    # preprocess: turn each row into a chat request with sampling params.
    preprocess=lambda row: dict(
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": f"{row['prompt']}"},
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=20,
            detokenize=False,
        ),
    ),
    # postprocess: keep the generation and echo back the original columns.
    postprocess=lambda row: dict(
        resp=row["generated_text"],
        **row,  # This will return all the original columns in the dataset.
    ),
)

# Run batch inference (materialize forces execution of the lazy pipeline).
print("\nRunning batch inference...")
result_ds = processor(ds)
result_ds = result_ds.materialize()

# Display results.
print("\nResults:")
for i, result in enumerate(result_ds.take_all()):
    print(f"\n--- Result {i+1} ---")
    print(f"Prompt: {result['prompt']}")
    print(f"Generated: {result['generated_text']}")

# You can also save results to various formats.
print("\nSaving results to JSON...")
result_ds.write_json("./outputs", try_create_dir=True)
print("Results saved to ./outputs directory")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.