Created
September 10, 2025 12:56
-
-
Save b1tg/417c333e740bdfd42abd0d2ac275d10f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compare JAX dtype promotion against tinygrad, including FP8 pairs.

- Prints tinygrad.least_upper_dtype(a, b) next to JAX's jnp.result_type(a, b);
  the table columns are: A, B, tinygrad, jax.
- Covers every unordered pair (each dtype also paired with itself) drawn from
  fp8e4m3, fp8e5m2, float16, bfloat16, float32, float64, int64 and uint64.

Usage:
    python3 jax_promo.py

Notes:
- tinygrad is required. jax and ml_dtypes are optional: without them the jax
  column shows the import error instead of a dtype.
- This script does not execute arithmetic; it only queries dtype promotion.
- Unless JAX's ``jax_enable_x64`` flag is on, JAX canonicalizes 64-bit dtypes
  to 32-bit, so float64/int64/uint64 inputs then report 32-bit results.
"""
| from __future__ import annotations | |
| from typing import Dict, List, Tuple | |
| # tinygrad side | |
| from tinygrad import dtypes as TG | |
| from tinygrad.dtype import least_upper_dtype | |
# JAX side (optional): without it the script still runs and the jax column
# reports the import error instead of a dtype.
JAX_IMPORT_ERR = None  # exception raised by the failed import, if any
try:
    import jax
    import jax.numpy as jnp
    import ml_dtypes as ml
except Exception as e:  # pragma: no cover
    HAS_JAX = False
    JAX_IMPORT_ERR = e
else:
    HAS_JAX = True
    # JAX canonicalizes 64-bit dtypes to 32-bit unless x64 mode is enabled,
    # which would make every float64/int64/uint64 comparison report a 32-bit
    # result. Enable it here, at import time, before any JAX arrays exist.
    jax.config.update("jax_enable_x64", True)
# Logical dtype keys under test: the two FP8 variants, the wider floats, and
# the 64-bit integers, in the order the comparison table walks them.
DT_KEYS: List[str] = (
    ["fp8e4m3", "fp8e5m2"]
    + ["float16", "bfloat16", "float32", "float64"]
    + ["int64", "uint64"]
)
# Logical key -> tinygrad dtype. Each key is spelled exactly like the
# attribute it maps to on tinygrad's `dtypes` namespace.
TG_MAP: Dict[str, TG.DType] = dict(
    fp8e4m3=TG.fp8e4m3,
    fp8e5m2=TG.fp8e5m2,
    float16=TG.float16,
    bfloat16=TG.bfloat16,
    float32=TG.float32,
    float64=TG.float64,
    int64=TG.int64,
    uint64=TG.uint64,
)
# Logical key -> JAX dtype. The FP8 entries come from ml_dtypes; the rest come
# from jax.numpy. The mapping stays empty when the JAX import failed.
JX_MAP: Dict[str, object] = {}
if HAS_JAX:
    JX_MAP = {
        # Plain attribute access: getattr() with a constant name adds nothing.
        "fp8e4m3": ml.float8_e4m3fn,
        "fp8e5m2": ml.float8_e5m2,
        "float16": jnp.float16,
        "bfloat16": jnp.bfloat16,
        "float32": jnp.float32,
        "float64": jnp.float64,
        "int64": jnp.int64,
        "uint64": jnp.uint64,
    }
def tinygrad_name(dt: TG.DType) -> str:
    """Return a compact display name for a tinygrad dtype.

    Takes ``repr(dt)`` and removes every ``"dtypes."`` occurrence,
    e.g. ``dtypes.half`` -> ``half``.
    """
    full_repr = repr(dt)
    compact = full_repr.replace("dtypes.", "")
    return compact
def jax_name(dt: object) -> str:
    """Return a display name for a JAX / NumPy / ml_dtypes dtype.

    Prefers a truthy ``dt.name`` (as a string). If there is none, or reading
    or converting it raises, falls back to ``str(dt)``.
    """
    try:
        attr = getattr(dt, "name", None)
        label = str(attr) if attr else None
    except Exception:
        label = None
    return label if label is not None else str(dt)
def jax_promo(a: object, b: object) -> str:
    """Return JAX's promoted dtype name for the pair (a, b) as a table cell.

    Args:
        a, b: dtypes (or dtype names) handed to ``jnp.result_type``.

    Returns:
        The display name of ``jnp.result_type(a, b)``;
        ``<error: ExcType>`` when JAX rejects the pair (e.g. a
        TypePromotionError for FP8 mixed with other floats); or
        ``<no jax: ...>`` when the JAX import failed.
    """
    if not HAS_JAX:
        return f"<no jax: {type(JAX_IMPORT_ERR).__name__}: {JAX_IMPORT_ERR}>"
    try:
        return jax_name(jnp.result_type(a, b))
    except Exception as e:
        # Only the exception type is reported, keeping the table cell short.
        return f"<error: {type(e).__name__}>"
def tinygrad_promo(a: TG.DType, b: TG.DType) -> str:
    """Return tinygrad's promoted dtype name for (a, b) as a table cell.

    Any exception from ``least_upper_dtype`` or from name formatting is
    rendered inline as ``<error: ExcType: message>``.
    """
    try:
        joined = least_upper_dtype(a, b)
        cell = tinygrad_name(joined)
    except Exception as err:
        cell = f"<error: {type(err).__name__}: {err}>"
    return cell
def pairs_to_test(keys: List[str]) -> List[Tuple[str, str]]:
    """Return every unordered pair of *keys*, each key also paired with itself.

    Pairs follow the order of *keys*: ``[a, b]`` yields
    ``[(a, a), (a, b), (b, b)]``. The cross-variant FP8 pair
    ``("fp8e4m3", "fp8e5m2")`` is always included. If it was not already
    produced (in either orientation), it is prepended.
    """
    from itertools import combinations_with_replacement

    pairs: List[Tuple[str, str]] = list(combinations_with_replacement(keys, 2))
    fp8_pair = ("fp8e4m3", "fp8e5m2")
    # Check both orientations so reordered keys do not produce a duplicate.
    if fp8_pair not in pairs and fp8_pair[::-1] not in pairs:
        pairs.insert(0, fp8_pair)
    return pairs
def main() -> int:
    """Print a side-by-side table of tinygrad and JAX dtype promotion.

    Each row holds the two input keys from ``pairs_to_test(DT_KEYS)``,
    tinygrad's ``least_upper_dtype`` result and JAX's ``result_type`` result
    (or an inline ``<error: ...>`` / ``<no jax: ...>`` marker).

    Returns:
        0, used as the process exit status.
    """
    print("JAX available:", HAS_JAX)
    if not HAS_JAX:
        print("Hint: pip install -U jax ml_dtypes # CPU-only")
    pairs = pairs_to_test(DT_KEYS)

    # One shared row format keeps the header and data columns aligned.
    row = "{:<14} {:<14} {:<24} {:<24}".format
    # Title lists the libraries in the same order as the columns below.
    print("\nComparing dtype promotion (tinygrad vs JAX):\n")
    print(row("A", "B", "tinygrad", "jax"))
    print("-" * 80)
    for ka, kb in pairs:
        tg_a, tg_b = TG_MAP[ka], TG_MAP[kb]
        # Keys missing from JX_MAP fall back to their raw name string.
        jx_out = jax_promo(JX_MAP.get(ka, ka), JX_MAP.get(kb, kb))
        tg_out = tinygrad_promo(tg_a, tg_b)
        print(row(ka, kb, tg_out, jx_out))
    return 0
if __name__ == "__main__":
    # Run the comparison and use main()'s return value as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| % python jax_promo.py | |
| JAX available: True | |
| Comparing dtype promotion (JAX vs tinygrad): | |
| A B tinygrad jax | |
| -------------------------------------------------------------------------------- | |
| fp8e4m3 fp8e4m3 fp8e4m3 float8_e4m3fn | |
| fp8e4m3 fp8e5m2 half <error: TypePromotionError> | |
| fp8e4m3 float16 half <error: TypePromotionError> | |
| fp8e4m3 bfloat16 bfloat16 <error: TypePromotionError> | |
| fp8e4m3 float32 float <error: TypePromotionError> | |
| fp8e4m3 float64 double <error: TypePromotionError> | |
| fp8e4m3 int64 half float8_e4m3fn | |
| fp8e4m3 uint64 half float8_e4m3fn | |
| fp8e5m2 fp8e5m2 fp8e5m2 float8_e5m2 | |
| fp8e5m2 float16 half <error: TypePromotionError> | |
| fp8e5m2 bfloat16 bfloat16 <error: TypePromotionError> | |
| fp8e5m2 float32 float <error: TypePromotionError> | |
| fp8e5m2 float64 double <error: TypePromotionError> | |
| fp8e5m2 int64 half float8_e5m2 | |
| fp8e5m2 uint64 half float8_e5m2 | |
| float16 float16 half float16 | |
| float16 bfloat16 float float32 | |
| float16 float32 float float32 | |
| float16 float64 double float32 | |
| float16 int64 half float16 | |
| float16 uint64 half float16 | |
| bfloat16 bfloat16 bfloat16 bfloat16 | |
| bfloat16 float32 float float32 | |
| bfloat16 float64 double float32 | |
| bfloat16 int64 bfloat16 bfloat16 | |
| bfloat16 uint64 bfloat16 bfloat16 | |
| float32 float32 float float32 | |
| float32 float64 double float32 | |
| float32 int64 float float32 | |
| float32 uint64 float float32 | |
| float64 float64 double float32 | |
| float64 int64 double float32 | |
| float64 uint64 double float32 | |
| int64 int64 long int32 | |
| int64 uint64 half float32 | |
| uint64 uint64 ulong uint32 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment