# Source: GitHub gist vanbasten23/7618a394435a024594fd58de57fa337d
# (author: @vanbasten23, created November 14, 2025).
import jax
from jax import export
import jax.numpy as jnp
import pickle
import time
import statistics
# Load the serialized jax.export artifact and the weights it is called with.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these pickles if they come from a trusted source.
with open("/home/xiowei_google_com/old_exports.pkl", "rb") as f:
    data = pickle.load(f)
exported = export.deserialize(data)

with open("/home/xiowei_google_com/old_weights.pkl", "rb") as f:
    weights = pickle.load(f)

# Synthetic inputs: NUM_TOKENS positions and a (NUM_TOKENS, HIDDEN_SIZE)
# bfloat16 hidden-state array drawn from a fixed PRNG key for reproducibility.
NUM_TOKENS = 16
HIDDEN_SIZE = 1536
positions = jnp.arange(NUM_TOKENS)
key = jax.random.key(0)
hidden_states = jax.random.normal(
    key, (NUM_TOKENS, HIDDEN_SIZE), dtype=jnp.bfloat16
)
# Time it (disabled). Runs 20 calls and reports the mean over iterations
# 5-19, treating the first 5 as warm-up. Last recorded result: 4338680 ns.
# all_time = []
# for _ in range(20):
#     start = time.perf_counter_ns()
#     exported.call(weights, (positions, hidden_states)).block_until_ready()
#     end = time.perf_counter_ns()
#     all_time.append(end - start)
#
# print("Running old jax finished in [ns] ", statistics.mean(all_time[5:])) # 4338680
# Profile it: call the exported function NUM_ITERS times, tracing only the
# calls after the first NUM_WARMUP (presumably to keep first-call /
# compilation overhead out of the trace).
profile_path = '/home/xiowei_google_com/myprofiles'
NUM_ITERS = 20
NUM_WARMUP = 5

for _ in range(NUM_WARMUP):
    exported.call(weights, (positions, hidden_states)).block_until_ready()

# jax.profiler.trace starts the trace on entry and stops it on exit, even if
# a traced call raises, so the profile is always flushed to profile_path.
with jax.profiler.trace(profile_path):
    for _ in range(NUM_ITERS - NUM_WARMUP):
        # block_until_ready makes each call's device work finish inside the
        # traced region rather than being dispatched asynchronously past it.
        exported.call(weights, (positions, hidden_states)).block_until_ready()
# (end of gist)