@b1tg
Created September 10, 2025 12:56
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compare JAX dtype promotion against tinygrad, including FP8 pairs.
- Shows JAX result_type(a,b) vs tinygrad.least_upper_dtype(a,b)
- Covers fp8e4m3/fp8e5m2 with {fp8, f16, bf16, f32, f64} and int64/uint64 combos
Usage:
python3 jax_promo.py
Notes:
- Requires jax and ml_dtypes installed for JAX side.
- This script does not execute arithmetic; it only queries dtype promotion.
"""
from __future__ import annotations
from typing import Dict, List, Tuple
# tinygrad side
from tinygrad import dtypes as TG
from tinygrad.dtype import DType, least_upper_dtype
# jax side (optional)
try:
    import jax
    import jax.numpy as jnp
    import ml_dtypes as ml
    HAS_JAX = True
except Exception as e:  # pragma: no cover
    HAS_JAX = False
    JAX_IMPORT_ERR = e
# logical dtype keys we want to test
DT_KEYS: List[str] = [
    "fp8e4m3", "fp8e5m2",
    "float16", "bfloat16", "float32", "float64",
    "int64", "uint64",
]
# mapping: logical key -> tinygrad dtype
TG_MAP: Dict[str, DType] = {
    "fp8e4m3": TG.fp8e4m3,
    "fp8e5m2": TG.fp8e5m2,
    "float16": TG.float16,
    "bfloat16": TG.bfloat16,
    "float32": TG.float32,
    "float64": TG.float64,
    "int64": TG.int64,
    "uint64": TG.uint64,
}
# mapping: logical key -> jax dtype (using ml_dtypes for fp8)
JX_MAP: Dict[str, object] = {}
if HAS_JAX:
    JX_MAP = {
        "fp8e4m3": ml.float8_e4m3fn,
        "fp8e5m2": ml.float8_e5m2,
        "float16": jnp.float16,
        "bfloat16": jnp.bfloat16,
        "float32": jnp.float32,
        "float64": jnp.float64,
        "int64": jnp.int64,
        "uint64": jnp.uint64,
    }
def tinygrad_name(dt: DType) -> str:
    # __repr__ prints like "dtypes.float16"; trim the prefix for compactness
    s = repr(dt)
    return s.replace("dtypes.", "")
def jax_name(dt: object) -> str:
    # ml_dtypes and jnp dtypes expose .name; fall back to str()
    try:
        n = getattr(dt, "name", None)
        if n: return str(n)
    except Exception:
        pass
    return str(dt)
def jax_promo(a: object, b: object) -> str:
    if not HAS_JAX:
        return f"<no jax: {type(JAX_IMPORT_ERR).__name__}: {JAX_IMPORT_ERR}>"
    try:
        out = jnp.result_type(a, b)
        return jax_name(out)
    except Exception as e:
        return f"<error: {type(e).__name__}>"
def tinygrad_promo(a: DType, b: DType) -> str:
    try:
        return tinygrad_name(least_upper_dtype(a, b))
    except Exception as e:
        return f"<error: {type(e).__name__}: {e}>"
def pairs_to_test(keys: List[str]) -> List[Tuple[str, str]]:
    pairs: List[Tuple[str, str]] = []
    # test all unordered pairs (upper triangle, including same-dtype pairs)
    for i, ka in enumerate(keys):
        for kb in keys[i:]:
            pairs.append((ka, kb))
    # make sure the cross-variant fp8 pair is present (put it first for readability)
    if ("fp8e4m3", "fp8e5m2") not in pairs:
        pairs.insert(0, ("fp8e4m3", "fp8e5m2"))
    return pairs
def main() -> int:
    print("JAX available:", HAS_JAX)
    if not HAS_JAX:
        print("Hint: pip install -U jax ml_dtypes # CPU-only")
    pairs = pairs_to_test(DT_KEYS)
    # header
    print("\nComparing dtype promotion (JAX vs tinygrad):\n")
    print(f"{'A':<14} {'B':<14} {'tinygrad':<24} {'jax':<24}")
    print("-" * 80)
    for ka, kb in pairs:
        tg_a, tg_b = TG_MAP[ka], TG_MAP[kb]
        jx_a = JX_MAP.get(ka, ka)
        jx_b = JX_MAP.get(kb, kb)
        jx_out = jax_promo(jx_a, jx_b)
        tg_out = tinygrad_promo(tg_a, tg_b)
        print(f"{ka:<14} {kb:<14} {tg_out:<24} {jx_out:<24}")
    return 0
if __name__ == "__main__":
    raise SystemExit(main())
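
For a quick one-off check without printing the whole table, the two promotion queries the script wraps can be called directly. A minimal sketch, assuming the same packages as above are installed (the example pair and its results are taken from the table below):

from tinygrad import dtypes
from tinygrad.dtype import least_upper_dtype
import jax.numpy as jnp

# tinygrad: least upper bound of the two dtypes in its promotion lattice
print(least_upper_dtype(dtypes.float16, dtypes.bfloat16))  # dtypes.float

# JAX: NumPy-style result_type under JAX's promotion rules
print(jnp.result_type(jnp.float16, jnp.bfloat16))          # float32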
% python jax_promo.py
JAX available: True

Comparing dtype promotion (JAX vs tinygrad):

A              B              tinygrad                 jax
--------------------------------------------------------------------------------
fp8e4m3        fp8e4m3        fp8e4m3                  float8_e4m3fn
fp8e4m3        fp8e5m2        half                     <error: TypePromotionError>
fp8e4m3        float16        half                     <error: TypePromotionError>
fp8e4m3        bfloat16       bfloat16                 <error: TypePromotionError>
fp8e4m3        float32        float                    <error: TypePromotionError>
fp8e4m3        float64        double                   <error: TypePromotionError>
fp8e4m3        int64          half                     float8_e4m3fn
fp8e4m3        uint64         half                     float8_e4m3fn
fp8e5m2        fp8e5m2        fp8e5m2                  float8_e5m2
fp8e5m2        float16        half                     <error: TypePromotionError>
fp8e5m2        bfloat16       bfloat16                 <error: TypePromotionError>
fp8e5m2        float32        float                    <error: TypePromotionError>
fp8e5m2        float64        double                   <error: TypePromotionError>
fp8e5m2        int64          half                     float8_e5m2
fp8e5m2        uint64         half                     float8_e5m2
float16        float16        half                     float16
float16        bfloat16       float                    float32
float16        float32        float                    float32
float16        float64        double                   float32
float16        int64          half                     float16
float16        uint64         half                     float16
bfloat16       bfloat16       bfloat16                 bfloat16
bfloat16       float32        float                    float32
bfloat16       float64        double                   float32
bfloat16       int64          bfloat16                 bfloat16
bfloat16       uint64         bfloat16                 bfloat16
float32        float32        float                    float32
float32        float64        double                   float32
float32        int64          float                    float32
float32        uint64         float                    float32
float64        float64        double                   float32
float64        int64          double                   float32
float64        uint64         double                   float32
int64          int64          long                     int32
int64          uint64         half                     float32
uint64         uint64         ulong                    uint32
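
Note on the jax column: this run left JAX in its default 32-bit mode, so float64/int64 inputs are canonicalized down to float32/int32 before promotion. A minimal sketch of re-checking with 64-bit types enabled (an assumption about the environment, not part of the original run):

import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit dtypes; best set at program startup
import jax.numpy as jnp

print(jnp.result_type(jnp.float64, jnp.float64))  # float64 once x64 is enabled
print(jnp.result_type(jnp.int64, jnp.int64))      # int64 once x64 is enabled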