@mlazos
Created September 23, 2025 17:07
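
import itertools


def llama_shapes():
    """Return (M, N, K) GEMM problem sizes for Llama-style linear layers."""
    # M values: batch size * sequence length (total tokens), 2**4 .. 2**16
    BS = [2**i for i in range(4, 17)]
    # BS = [2**i for i in range(16, 17)]  # alternative: only the largest M
    # (K, N) weight shapes, four per hidden size
    # attn: wqkv, wo; ffn: w13, w2
    KN = [
        # hidden size 4096
        (4096, 12288),
        (4096, 4096),
        (4096, 22016),
        (11008, 4096),
        # hidden size 8192
        (8192, 1280),
        (1024, 8192),
        (8192, 7168),
        (3584, 8192),
        # hidden size 16384
        (16384, 2304),
        (2048, 16384),
        (16384, 13312),
        (6656, 16384),
    ]
    # Pair every M with every weight shape; note the output order is (M, N, K)
    return [(bs, n, k) for bs, (k, n) in itertools.product(BS, KN)]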
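A minimal sketch of how each (M, N, K) triple might be turned into cost metrics when benchmarking. It assumes the GEMM computes C[M, N] = A[M, K] @ B[K, N] with a 2-byte element type (e.g. bf16). The helper names `gemm_flops`, `gemm_bytes`, and `arithmetic_intensity` are hypothetical and are not part of the gist.

```python
# Hypothetical helpers: cost metrics for an (M, N, K) triple such as those
# produced by llama_shapes(). Assumes C[M, N] = A[M, K] @ B[K, N] and a
# 2-byte element type by default (an assumption, not stated in the gist).


def gemm_flops(m: int, n: int, k: int) -> int:
    """Count each multiply-add as 2 floating-point operations."""
    return 2 * m * n * k


def gemm_bytes(m: int, n: int, k: int, bytes_per_elem: int = 2) -> int:
    """Lower bound on memory traffic: read A and B once, write C once."""
    return bytes_per_elem * (m * k + k * n + m * n)


def arithmetic_intensity(m: int, n: int, k: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte moved; low values suggest a memory-bound GEMM."""
    return gemm_flops(m, n, k) / gemm_bytes(m, n, k, bytes_per_elem)


if __name__ == "__main__":
    # Smallest and largest M (16 and 65536) paired with the first (K, N) pair,
    # written in the (M, N, K) order that llama_shapes() returns.
    sample = [(16, 12288, 4096), (65536, 12288, 4096)]
    for m, n, k in sample:
        print(
            f"M={m:>6} N={n:>6} K={k:>6} "
            f"GFLOP={gemm_flops(m, n, k) / 1e9:10.2f} "
            f"AI={arithmetic_intensity(m, n, k):8.1f} FLOP/B"
        )
```

Because the weight matrix B dominates traffic when M is small, intensity is low at M = 16 and rises with M. This is one plausible reason to sweep the token count over a power-of-two range.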