HuggingFace Compatible Attention Mask
import torch
from torch import Tensor


class AttentionMask:
    """
    A (Batch, 1, Queries, Keys & Values) attention mask for attention between queries and keys/values.

    The mask is "additive" or "inverted", meaning it is a tensor of floating-point values
    that can be added to the attention scores before the softmax operation.

    >>> 0 = Unmasked
    >>> dtype.min = Masked
    """

    def __init__(self, value: Tensor):
        assert value.dim() == 4, f"expected a 4D (B, 1, Q, KV) tensor, got {value.dim()} dims"
        self.value = value
    @staticmethod
    def causal(
        batch: int,
        queries: int,
        keys_and_values: int,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | None = None,
    ) -> "AttentionMask":
        # There must be at least as many keys/values as queries; the extra leading
        # positions correspond to a cached prefix and stay fully visible.
        assert keys_and_values - queries >= 0, f"expected {keys_and_values} - {queries} >= 0"
        mask = torch.zeros(
            (queries, keys_and_values),
            dtype=dtype,
            device=device,
        )
        # Upper-triangular block over the trailing (queries x queries) square masks future positions.
        mask[:, (keys_and_values - queries):] = torch.triu(
            torch.full((queries, queries), torch.finfo(dtype).min, dtype=dtype, device=device), diagonal=1
        )
        return AttentionMask(mask.unsqueeze(0).expand(batch, 1, queries, keys_and_values))
    @staticmethod
    def full(
        batch: int,
        queries: int,
        keys_and_values: int,
        *,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | None = None,
    ) -> "AttentionMask":
        return AttentionMask(torch.zeros((batch, 1, queries, keys_and_values), dtype=dtype, device=device))
    @staticmethod
    def document(
        batch: int,
        queries: int,
        keys_and_values: int,
        docs: Tensor,
        *,
        causal: bool = False,
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device | None = None,
    ) -> "AttentionMask":
        assert keys_and_values >= queries
        assert docs.dim() == 2
        # `docs` holds one document id per key/value position, so it must span the full key/value length.
        assert docs.size(1) == keys_and_values, f"expected {docs.size(1)} == {keys_and_values}"
        if causal:
            mask = AttentionMask.causal(batch, queries, keys_and_values, dtype=dtype, device=device)
        else:
            mask = AttentionMask.full(batch, queries, keys_and_values, dtype=dtype, device=device)
        cache_size = keys_and_values - queries
        query_doc_ids = docs[:, cache_size:]
        key_doc_ids = docs
        # Expand dimensions so they can be compared:
        query_doc_ids = query_doc_ids.unsqueeze(2)  # (B, Q, 1)
        key_doc_ids = key_doc_ids.unsqueeze(1)      # (B, 1, KV)
        same_document = query_doc_ids == key_doc_ids
        # Keep the existing mask where query and key belong to the same document;
        # fully mask cross-document attention otherwise.
        mask.value = torch.where(
            same_document.unsqueeze(1), mask.value, torch.full_like(mask.value, torch.finfo(dtype).min)
        )
        return mask
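For context, here is a minimal usage sketch (not part of the gist): it builds a causal mask with a small cached prefix and a packed-document mask, then applies the additive mask to raw attention scores. The shapes, the float32 dtype, and the scaled_dot_product_attention call are illustrative assumptions, not requirements of the class.

import torch
import torch.nn.functional as F

batch, heads, queries, keys_and_values, head_dim = 2, 4, 8, 12, 16

q = torch.randn(batch, heads, queries, head_dim)
k = torch.randn(batch, heads, keys_and_values, head_dim)
v = torch.randn(batch, heads, keys_and_values, head_dim)

# Causal mask with a 4-position cached prefix (keys_and_values - queries == 4).
mask = AttentionMask.causal(batch, queries, keys_and_values, dtype=torch.float32)

# Manual attention: the additive mask is summed into the scores before softmax.
scores = q @ k.transpose(-2, -1) / head_dim**0.5
scores = scores + mask.value          # (B, 1, Q, KV) broadcasts over the head dimension
out = torch.softmax(scores, dim=-1) @ v

# The same float mask can be passed to scaled_dot_product_attention, which also
# adds a floating-point attn_mask to the attention scores.
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.value)

# Packed-document mask: one document id per key/value position (two documents per row here).
docs = torch.tensor([[0] * 6 + [1] * 6, [0] * 4 + [1] * 8])
doc_mask = AttentionMask.document(batch, queries, keys_and_values, docs, causal=True, dtype=torch.float32)

Because the mask is additive, combining constraints (causal structure plus document boundaries) is just another elementwise operation over the same (B, 1, Q, KV) tensor, which is the same 4D layout HuggingFace transformers models build internally for their additive attention masks; whether a model accepts such a mask directly depends on the model and library version.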