Skip to content

Instantly share code, notes, and snippets.

@rajbot
Created March 3, 2026 06:08
Show Gist options
  • Select an option

  • Save rajbot/9a931e3ab135ecd68e5a94979887e03b to your computer and use it in GitHub Desktop.

Select an option

Save rajbot/9a931e3ab135ecd68e5a94979887e03b to your computer and use it in GitHub Desktop.
Compare local Ollama vs. MLX inference speed for Llama 3.1 8B Instruct.
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
import asyncio
import traceback
import time
# Ollama-backed chat model: llama3.1 8B served by a local Ollama daemon.
ollama_model = init_chat_model("llama3.1:8b", model_provider="ollama")

import mlx_lm

# Keep a handle on the real generate() so the shim below can delegate to it.
_unpatched_mlx_generate = mlx_lm.generate


def _generate_without_formatter(*args, **kwargs):
    """Delegate to the original ``mlx_lm.generate`` minus the ``formatter`` kwarg.

    langchain-community's MLXPipeline still forwards ``formatter=``, which
    presumably newer mlx_lm versions reject — TODO confirm against the
    installed mlx_lm release. Dropping it here keeps the pipeline working.
    """
    kwargs.pop("formatter", None)
    return _unpatched_mlx_generate(*args, **kwargs)


# Monkeypatch so every caller (including MLXPipeline) goes through the shim.
mlx_lm.generate = _generate_without_formatter
from langchain_community.llms.mlx_pipeline import MLXPipeline
from langchain_community.chat_models.mlx import ChatMLX
from langchain_core.messages import HumanMessage

# Local MLX chat model: 4-bit quantized Llama 3.1 8B Instruct, wrapped for LangChain.
mlx_model = ChatMLX(
    llm=MLXPipeline.from_model_id("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
)

# One agent per backend so both are exercised through the same agent code path.
ollama_agent = create_agent(ollama_model)
mlx_agent = create_agent(mlx_model)
# Fixed prompt reused for every warm-up and timed call.
messages = [HumanMessage(content="How many feet do snails have?")]

# Warm up both backends once (ollama first, then mlx) so one-time costs
# such as model loading are excluded from the timed runs below.
for _agent in (ollama_agent, mlx_agent):
    response = _agent.invoke({"messages": messages})
    print(response)
print("warm up complete")
def _timed_invoke(agent, payload):
    """Invoke *agent* once with *payload* and return the elapsed seconds.

    Uses ``time.perf_counter()`` instead of ``time.time()``: perf_counter is
    monotonic and high-resolution, so it is the correct clock for
    benchmarking (time.time() can jump backwards/forwards if the system
    clock is adjusted mid-run).
    """
    start = time.perf_counter()
    agent.invoke(payload)
    return time.perf_counter() - start


# Accumulators consumed by the summary section below.
total_ollama_time = 0
total_mlx_time = 0
ollama_times = []
mlx_times = []
iterations = 10  # number of timed iterations per backend

for i in range(iterations):
    print(f"iteration {i}")

    ollama_time = _timed_invoke(ollama_agent, {"messages": messages})
    total_ollama_time += ollama_time
    ollama_times.append(ollama_time)
    print(f"ollama time: {ollama_time}")

    mlx_time = _timed_invoke(mlx_agent, {"messages": messages})
    total_mlx_time += mlx_time
    mlx_times.append(mlx_time)
    print(f"mlx time: {mlx_time}")
import statistics

# Sample standard deviation over the timed iterations (needs >= 2 samples,
# which holds for the iterations=10 runs produced above).
ollama_std = statistics.stdev(ollama_times)
mlx_std = statistics.stdev(mlx_times)

avg_ollama = total_ollama_time / iterations
avg_mlx = total_mlx_time / iterations

# Print order matters for anyone diffing runs: totals, averages, std devs.
print(f"total ollama time: {total_ollama_time}")
print(f"total mlx time: {total_mlx_time}")
print(f"average ollama time: {avg_ollama}")
print(f"average mlx time: {avg_mlx}")
print(f"ollama std dev: {ollama_std}")
print(f"mlx std dev: {mlx_std}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment