Skip to content

Instantly share code, notes, and snippets.

@sachdevkartik
Last active September 28, 2022 08:32
Show Gist options
  • Select an option

  • Save sachdevkartik/4bb35299b9fb7b718d5c79bf1c672b74 to your computer and use it in GitHub Desktop.

Select an option

Save sachdevkartik/4bb35299b9fb7b718d5c79bf1c672b74 to your computer and use it in GitHub Desktop.
from typing import Optional
from torchvision import transforms
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
def get_transform_train(
    upsample_size: int, final_size: int, channels: Optional[int] = 1
) -> "A.Compose":
    """Build the augmentation pipeline applied to training images.

    The pipeline, in order:
      1. flips the image horizontally with probability 0.25,
      2. flips the image vertically with probability 0.25,
      3. resizes the image to ``final_size`` x ``final_size``,
      4. converts the result to a tensor via ``ToTensorV2``.

    Args:
        upsample_size (int): intermediate upsampling size. Accepted as part
            of the signature but not used by the current pipeline.
        final_size (int): height and width of the image fed to the network.
        channels (Optional[int], optional): number of channels of the final
            image to the network. Accepted as part of the signature but not
            used by the current pipeline. Defaults to 1.

    Returns:
        A.Compose: an albumentations pipeline (not a
        ``torchvision.transforms.Compose``).

    Example:
        >>> transform_train = get_transform_train(387, 224, 1)
    """
    transform_train = A.Compose(
        [
            A.HorizontalFlip(p=0.25),
            A.VerticalFlip(p=0.25),
            # Resize is always applied (p=1.0) so every sample has the same size.
            A.Resize(final_size, final_size, p=1.0),
            ToTensorV2(),
        ]
    )
    return transform_train
def get_transform_test(
    final_size: int, channels: Optional[int] = 1
) -> "A.Compose":
    """Build the deterministic pipeline applied to test images.

    The pipeline resizes the image to ``final_size`` x ``final_size`` and
    converts the result to a tensor via ``ToTensorV2``. No random
    augmentation is applied.

    Args:
        final_size (int): height and width of the image fed to the network.
        channels (Optional[int], optional): number of channels of the final
            image. Accepted as part of the signature but not used by the
            current pipeline. Defaults to 1.

    Returns:
        A.Compose: an albumentations pipeline (not a
        ``torchvision.transforms.Compose``).

    Example:
        >>> transform_test = get_transform_test(224, 1)
    """
    transform_test = A.Compose(
        [
            # Resize is always applied (p=1.0) so every sample has the same size.
            A.Resize(final_size, final_size, p=1.0),
            ToTensorV2(),
        ]
    )
    return transform_test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment