Skip to content

Instantly share code, notes, and snippets.

@jinyangustc
Created April 19, 2025 22:14
Show Gist options
  • Select an option

  • Save jinyangustc/28fedfb71cd8c3b2e78c68931b8de3e6 to your computer and use it in GitHub Desktop.

Select an option

Save jinyangustc/28fedfb71cd8c3b2e78c68931b8de3e6 to your computer and use it in GitHub Desktop.
Python Sum Types
import time
import traceback
from typing import Literal, final
class String(object):
    """Tag for the ``str`` variant of ``MySumType``.

    A stateless marker: instances hold no data. It sits in the first slot
    of a ``MySumType.adt`` pair and is recognised by the ``String()``
    class pattern in a ``match`` statement.
    """
class ListOfInts(object):
    """Tag for the ``list[int]`` variant of ``MySumType``.

    A stateless marker: instances hold no data. It sits in the first slot
    of a ``MySumType.adt`` pair and is recognised by the ``ListOfInts()``
    class pattern in a ``match`` statement.
    """
@final
class MySumType:
    """A tagged union with two variants, stored as a ``(tag, payload)`` pair.

    ``adt`` is either ``(String(), str)`` or ``(ListOfInts(), list[int])``.
    Consumers destructure it with ``match x.adt`` and a class pattern on
    the tag. The annotation is checked only by a static type checker.
    At runtime, whatever is passed is stored unchanged.
    """

    def __init__(
        self,
        adt: tuple[String, str] | tuple[ListOfInts, list[int]],
    ) -> None:
        # Stored as-is: no runtime validation, by design of this demo.
        self.adt = adt

    def __repr__(self) -> str:
        """Return a readable form such as ``MySumType((String(), 'hi'))``.

        The tag is rendered by class name because the marker classes hold
        no data. Anything that is not a 2-tuple, which is possible since
        ``adt`` is not validated, falls back to ``repr(adt)`` so that
        ``repr`` never raises.
        """
        name = type(self).__name__
        adt = self.adt
        if isinstance(adt, tuple) and len(adt) == 2:
            tag, payload = adt
            return f'{name}(({type(tag).__name__}(), {payload!r}))'
        return f'{name}({adt!r})'
def consume(err_data: str):
x1 = MySumType((String(), 'hello'))
x2 = MySumType((ListOfInts(), [1, 2, 3]))
x3 = MySumType((ListOfInts(), err_data))
for x in [x1, x2, x3]:
match x.adt:
case (String(), data):
print(data.upper())
case (ListOfInts(), data):
print(sum(data))
if __name__ == '__main__':
try:
consume('lsp will show error')
except TypeError as e:
traceback.print_exc()
print('---')
# --- benchmark ---
data = ['hello', 'world', [1, 2, 3], list(range(1000))]
sum_type_data: list[MySumType] = []
tuple_data: list[
tuple[Literal['str'], str] | tuple[Literal['list_of_ints'], list[int]]
] = []
for x in data:
if isinstance(x, str):
sum_type_data.append(MySumType((String(), x)))
tuple_data.append(('str', x))
else:
sum_type_data.append(MySumType((ListOfInts(), x)))
tuple_data.append(('list_of_ints', x))
max_iter = 1_000_000
counter = 0
start_time = time.perf_counter()
for i in range(max_iter):
for x in sum_type_data:
match x.adt:
case (String(), xx):
counter += len(xx)
case (ListOfInts(), xx):
counter += sum(xx)
end_time = time.perf_counter()
print(
f'match with wrapper class: {end_time - start_time:.4f} seconds for {max_iter} iterations'
)
counter = 0
start_time = time.perf_counter()
for i in range(max_iter):
for x in tuple_data:
match x:
case ('str', xx):
counter += len(xx)
case ('list_of_ints', xx):
counter += sum(xx)
end_time = time.perf_counter()
print(
f'match without wrapper class: {end_time - start_time:.4f} seconds for {max_iter} iterations'
)
counter = 0
start_time = time.perf_counter()
for i in range(max_iter):
for x in data:
if isinstance(x, str):
counter += len(x)
else:
counter += sum(x)
end_time = time.perf_counter()
print(
f'if-else on native values: {end_time - start_time:.4f} seconds for {max_iter} iterations'
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment