Last active: September 28, 2022 08:32
-
-
Save sachdevkartik/4bb35299b9fb7b718d5c79bf1c672b74 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Optional | |
| from torchvision import transforms | |
| from PIL import Image | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| from torch.utils.data import DataLoader, Dataset | |
def get_transform_train(
    upsample_size: int, final_size: int, channels: Optional[int] = 1
):
    """Build the albumentations augmentation pipeline for the training set.

    The returned pipeline applies, in order: a horizontal flip (p=0.25),
    a vertical flip (p=0.25), a resize to ``final_size`` x ``final_size``
    (always applied), and a tensor conversion via ``ToTensorV2``.

    Args:
        upsample_size (int): intermediate upsampling size.
            NOTE(review): not currently used by the pipeline; kept so
            existing callers keep working.
        final_size (int): height and width of the image fed to the network.
        channels (Optional[int], optional): number of channels of the final
            image fed to the network. NOTE(review): not currently used by
            the pipeline. Defaults to 1.

    Returns:
        albumentations.Compose: the composed training transform (an
        albumentations ``A.Compose``, not a torchvision ``transforms.Compose``).

    Example:
        >>> transform = get_transform_train(387, 224, 1)
    """
    transform_train = A.Compose(
        [
            # Light geometric augmentation: each flip is sampled separately
            # with probability 0.25.
            A.HorizontalFlip(p=0.25),
            A.VerticalFlip(p=0.25),
            # Always resize (p=1.0) so every sample matches the network input size.
            A.Resize(final_size, final_size, p=1.0),
            ToTensorV2(),
        ]
    )
    return transform_train
def get_transform_test(final_size: int, channels: Optional[int] = 1):
    """Build the deterministic albumentations pipeline for the test set.

    No random augmentation is applied: the image is resized to
    ``final_size`` x ``final_size`` (always applied) and then converted
    with ``ToTensorV2``.

    Args:
        final_size (int): height and width of the image fed to the network.
        channels (Optional[int], optional): number of channels of the final
            image. NOTE(review): not currently used by the pipeline; kept so
            existing callers keep working. Defaults to 1.

    Returns:
        albumentations.Compose: the composed test transform (an
        albumentations ``A.Compose``, not a torchvision ``transforms.Compose``).

    Example:
        >>> transform = get_transform_test(224, 1)
    """
    transform_test = A.Compose(
        [
            # Same resize as the training pipeline, without the random flips,
            # so evaluation inputs are deterministic.
            A.Resize(final_size, final_size, p=1.0),
            ToTensorV2(),
        ]
    )
    return transform_test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.