@dhbrojas
Created June 27, 2025 09:07
HuggingFace Compatible Attention Mask
import torch
from torch import Tensor


class AttentionMask:
    """
    A (Batch, 1, Queries, Keys & Values) attention mask for attention between queries and keys/values.

    The mask is "additive" (or "inverted"), meaning it is a tensor of floating point values
    that can be added to the attention scores before the softmax operation:

        0         = unmasked
        dtype.min = masked
    """

    def __init__(self, value: Tensor):
        assert value.dim() == 4, f"expect {value.dim()} == 4, (B, 1, Q, KV)"
        self.value = value

    @staticmethod
    def causal(
        batch: int,
        queries: int,
        keys_and_values: int,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | None = None,
    ) -> "AttentionMask":
        assert keys_and_values - queries >= 0, f"expect {keys_and_values} - {queries} >= 0"
        mask = torch.zeros(
            (queries, keys_and_values),
            dtype=dtype,
            device=device,
        )
        # Only the last `queries` columns (the new tokens) receive the upper-triangular
        # mask; any cached prefix positions stay fully visible.
        mask[:, (keys_and_values - queries):] = torch.triu(
            torch.full((queries, queries), torch.finfo(dtype).min, dtype=dtype, device=device), diagonal=1
        )
        return AttentionMask(mask.unsqueeze(0).expand(batch, 1, queries, keys_and_values))

    @staticmethod
    def full(
        batch: int,
        queries: int,
        keys_and_values: int,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | None = None,
    ) -> "AttentionMask":
        return AttentionMask(torch.zeros((batch, 1, queries, keys_and_values), dtype=dtype, device=device))

    @staticmethod
    def document(
        batch: int,
        queries: int,
        keys_and_values: int,
        docs: Tensor,
        *,
        causal: bool = False,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | None = None,
    ) -> "AttentionMask":
        assert keys_and_values >= queries
        assert docs.dim() == 2
        # `docs` holds one document id per key/value position (including any cached
        # prefix), so its second dimension must match keys_and_values.
        assert docs.size(1) == keys_and_values, f"expect {docs.size(1)} == {keys_and_values}"
        if causal:
            mask = AttentionMask.causal(batch, queries, keys_and_values, dtype=dtype, device=device)
        else:
            mask = AttentionMask.full(batch, queries, keys_and_values, dtype=dtype, device=device)
        cache_size = keys_and_values - queries
        query_doc_ids = docs[:, cache_size:]
        key_doc_ids = docs
        # Expand dimensions so they can be compared:
        # (B, Q, 1) == (B, 1, KV) -> (B, Q, KV)
        query_doc_ids = query_doc_ids.unsqueeze(2)
        key_doc_ids = key_doc_ids.unsqueeze(1)
        same_document = query_doc_ids == key_doc_ids
        # Mask out every query/key pair that falls in different documents.
        mask.value = torch.where(
            same_document.unsqueeze(1), mask.value, torch.full_like(mask.value, torch.finfo(dtype).min)
        )
        return mask
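
A minimal usage sketch (not part of the gist; the sizes, example tensors, and the call to torch.nn.functional.scaled_dot_product_attention are illustrative assumptions). SDPA accepts an additive float mask through its attn_mask argument, and a (batch, 1, queries, keys_and_values) mask broadcasts across the head dimension; recent transformers versions likewise accept a pre-built 4D additive mask as attention_mask, which is presumably what "HuggingFace compatible" refers to here.

import torch
import torch.nn.functional as F

# Illustrative sizes only (assumptions, not from the gist).
batch, heads, queries, keys_and_values, head_dim = 2, 4, 8, 8, 64
dtype = torch.float32

q = torch.randn(batch, heads, queries, head_dim, dtype=dtype)
k = torch.randn(batch, heads, keys_and_values, head_dim, dtype=dtype)
v = torch.randn(batch, heads, keys_and_values, head_dim, dtype=dtype)

# Plain causal self-attention.
causal_mask = AttentionMask.causal(batch, queries, keys_and_values, dtype=dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask.value)

# Packed documents: one document id per key/value position; tokens attend
# only within their own document (causally, in this case).
docs = torch.tensor([
    [0, 0, 0, 1, 1, 2, 2, 2],
    [0, 0, 1, 1, 1, 1, 2, 2],
])
doc_mask = AttentionMask.document(batch, queries, keys_and_values, docs, causal=True, dtype=dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=doc_mask.value)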