@mlazos
Created October 3, 2025 00:09

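def test_get_current_stream_return(self):
    # Enter the stream passed in as an argument, query the current CUDA
    # stream inside that context, and return it alongside the input tensor.
    def fn(x, s):
        with s:
            s0 = torch.cuda.current_stream()
        return x, s0

    s_inp = torch.Stream(device="cuda")
    inp = (torch.ones(2, 2) + 1, s_inp)
    # fullgraph=True raises an error on any graph break, so Dynamo must trace
    # both the stream context manager and the current_stream() call.
    fn_opt = torch.compile(fn, fullgraph=True)
    _, s0 = fn_opt(*inp)
    _, s1 = fn_opt(*inp)
    # The compiled fn should return the stream that was entered, and repeated
    # calls should return the same stream.
    self.assertEqual(s_inp, s0)
    self.assertEqual(s0, s1)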