Last active
October 4, 2022 17:32
-
-
Save lebrice/f3ff6274ecfd7eb96f0584d1195e8ab4 to your computer and use it in GitHub Desktop.
Consolidating the cache on the Mila cluster
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Sets up a user cache directory for commonly used libraries, while reusing shared cache entries. | |
| Use this to avoid having to download files to the $HOME directory, as well as to remove | |
| duplicated downloads and free up space in your $HOME and $SCRATCH directories. | |
| The user cache directory should be writeable, and doesn't need to be empty. | |
| This command adds symlinks to (some of) the files contained in the *shared* cache directory to this | |
| user cache directory. | |
| The shared cache directory should be readable (e.g. a directory containing frequently-downloaded | |
| weights/checkpoints, managed by the IT/IDT Team at Mila). | |
| TODO: | |
| This command also sets the environment variables via a block in the `$HOME/.bashrc` file, so that | |
| these libraries look in the specified user cache for these files. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import shutil | |
| from dataclasses import dataclass | |
| from logging import getLogger as get_logger | |
| from pathlib import Path | |
| from typing import Literal | |
| from tqdm import tqdm | |
# Module-level logger, named after this module.
logger = get_logger(__name__)

# The `scratch` directory located directly inside the user's home directory.
SCRATCH = Path.home() / "scratch"

# Default location of the per-user cache directory.
# TODO: Might be a good idea to actually leave it at $HOME/.cache, and raise a warning,
# since that wouldn't require the users to set any environment variables.
DEFAULT_USER_CACHE_DIR = SCRATCH / ".cache"

# Default location of the shared cache directory.
# TODO: Change to an actual IDT-approved, read-only directory. Using another dir in my scratch for now.
# SHARED_CACHE_DIR = Path("/network/shared_cache")
SHARED_CACHE_DIR = SCRATCH / "shared_cache"
def setup_cache(user_cache_dir: Path, shared_cache_dir: Path) -> None:
    """Prepare `user_cache_dir` so that it reuses the entries of `shared_cache_dir`.

    The following steps are performed, in order:

    1. `user_cache_dir` is created (including missing parents) if it doesn't exist yet.
    2. Broken symlinks in `user_cache_dir` that point to files that don't exist anymore in
       `shared_cache_dir` are removed.
    3. A symbolic link is created in `user_cache_dir` for every file of `shared_cache_dir`.
    4. The cache-related environment variables of the supported libraries are set for the
       current process.

    Raises:
        RuntimeError: If `user_cache_dir` exists but isn't a directory, or if
            `shared_cache_dir` doesn't exist or isn't a directory.
    """
    if user_cache_dir.exists():
        if not user_cache_dir.is_dir():
            raise RuntimeError(f"cache_dir is not a directory: {user_cache_dir}")
    else:
        user_cache_dir.mkdir(parents=True, exist_ok=False)

    if not shared_cache_dir.is_dir():
        raise RuntimeError(
            f"The shared cache directory {shared_cache_dir} doesn't exist, or isn't a directory! "
        )

    # Clean up stale links first, then (re-)create the links and point the libraries at the cache.
    delete_broken_symlinks_to_shared_cache(user_cache_dir, shared_cache_dir)
    create_links(user_cache_dir, shared_cache_dir)
    set_environment_variables(user_cache_dir)
def set_environment_variables(user_cache_dir: Path):
    """Point the cache-related environment variables at sub-directories of `user_cache_dir`.

    Sets `TORCH_HOME`, `HF_HOME` and `TRANSFORMERS_CACHE` in the environment of the current
    process.
    """
    # TODO: These changes won't persist. We probably need to add a block of code in .bashrc
    huggingface_dir = user_cache_dir / "huggingface"
    os.environ.update(
        {
            "TORCH_HOME": str(user_cache_dir / "torch"),
            "HF_HOME": str(huggingface_dir),
            "TRANSFORMERS_CACHE": str(huggingface_dir / "transformers"),
        }
    )
def is_child(path: Path, parent: Path) -> bool:
    """Tell whether `path` is located strictly below the `parent` directory.

    The comparison is purely lexical (nothing is looked up on the filesystem), and a path is
    never considered to be a child of itself.
    """
    try:
        path.relative_to(parent)
    except ValueError:
        # `path` doesn't start with `parent`.
        return False
    # `relative_to` also succeeds when both paths are equal, which doesn't count as a child.
    return path != parent
def delete_broken_symlinks_to_shared_cache(
    user_cache_dir: Path, shared_cache_dir: Path
):
    """Delete the dangling symlinks of the user cache that point into the shared cache.

    Every symlink found (recursively) in `user_cache_dir` whose target is located under
    `shared_cache_dir` and doesn't exist anymore is removed. Symlinks pointing anywhere else,
    as well as regular files and directories, are left untouched.
    """
    # Symlink targets are compared in their resolved form, so the shared cache root has to be
    # resolved as well. Otherwise nothing would ever match when `shared_cache_dir` is relative
    # or is itself reached through a symlink.
    shared_root = shared_cache_dir.resolve()

    # Take a snapshot of the tree first, so that entries aren't deleted while it is being walked.
    for file in list(user_cache_dir.rglob("*")):
        if not file.is_symlink():
            continue
        target = file.resolve()
        if is_child(target, shared_root) and not target.exists():
            logger.debug("Removing broken symlink: %s", file)
            # A dangling symlink is never a directory (`is_dir()` follows the link), and a
            # symlink is always removed with `unlink`, whatever it used to point to.
            file.unlink()
def create_links(user_cache_dir: Path, shared_cache_dir: Path):
    """Mirror the shared cache directory into the user cache directory using symlinks.

    The directory structure of `shared_cache_dir` is re-created under `user_cache_dir`
    (existing directories are reused), and every file of the shared cache gets a symlink at the
    same relative location in the user cache:

    - a valid symlink that is already there (e.g. from a previous run) is kept as-is;
    - a dangling symlink is replaced;
    - a "real" file is deleted and replaced with a symlink to the shared file.

    Raises:
        shutil.Error: If some entries could not be linked (raised by `shutil.copytree`).
    """
    # NOTE(review): `shutil.copytree` also copies the metadata (permissions, timestamps) of each
    # shared directory onto the matching user directory. Confirm that a read-only shared cache
    # doesn't end up making the user cache directories read-only.

    # The progress bar is used as a context manager so that it always gets closed.
    with tqdm() as pbar:

        def _link_fn(src: str, dst: str) -> None:
            # NOTE: This also overwrites the files in the user directory with symlinks to the
            # same files in the shared directory. We might not necessarily want to do that.
            # For instance, we might want to do a checksum or something first, to check that
            # they have exactly the same contents.
            dst_path = Path(dst)
            rel_s = Path(src).relative_to(shared_cache_dir)

            # `is_symlink` has to be checked first: `exists` follows symlinks and is False for a
            # dangling link, in which case `os.symlink` below would fail with FileExistsError.
            if dst_path.is_symlink():
                if dst_path.exists():
                    # From a previous run.
                    return
                # Dangling link: replace it.
                dst_path.unlink()
            elif dst_path.exists():
                # Replace "real" files with symlinks.
                dst_path.unlink()

            pbar.set_description(f"Linking {rel_s}")
            pbar.update(1)
            os.symlink(src, dst)  # Create symlinks instead of copying.

        shutil.copytree(
            shared_cache_dir,
            user_cache_dir,
            symlinks=True,
            copy_function=_link_fn,
            dirs_exist_ok=True,
        )
@dataclass
class Options:
    """ Options for the setup_cache command. """

    user_cache_dir: Path = DEFAULT_USER_CACHE_DIR
    """The user cache directory. Should probably be in $SCRATCH (not $HOME!) """

    shared_cache_dir: Path = SHARED_CACHE_DIR
    """ The path to the shared cache directory.
    This defaults to the path of the shared cache setup by the IDT team on the Mila cluster.
    """

    framework_subdirectory: str = "all"
    """The name of a subdirectory of `shared_cache_dir` to link, or 'all' to create symlinks for
    every file in `shared_cache_dir`. Defaults to 'all'.
    """

    def __post_init__(self):
        # With 'all', the two cache directories are used as-is.
        if self.framework_subdirectory == "all":
            return

        # Only actual directories are valid choices: plain files located at the root of the
        # shared cache can't be linked as a "framework subdirectory". The names are sorted so
        # that the error message below is deterministic.
        available_subdirectories = sorted(
            p.name for p in self.shared_cache_dir.iterdir() if p.is_dir()
        )
        if self.framework_subdirectory not in available_subdirectories:
            raise ValueError(
                f"The framework subdirectory '{self.framework_subdirectory}' does not exist in "
                f"{self.shared_cache_dir}. \n"
                f"Frameworks/subdirectories available in the shared cache: {available_subdirectories}"
            )

        # Narrow both directories down to the selected subdirectory.
        self.user_cache_dir = self.user_cache_dir / self.framework_subdirectory
        self.shared_cache_dir = self.shared_cache_dir / self.framework_subdirectory
def main():
    """Command-line entry point: parse the `Options` and set up the user cache accordingly."""
    # Imported here so that `simple_parsing` is only needed when running from the command line.
    from simple_parsing import ArgumentParser

    parser = ArgumentParser(description=__doc__)
    parser.add_arguments(Options, dest="options")

    options: Options = parser.parse_args().options
    setup_cache(options.user_cache_dir, options.shared_cache_dir)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Callable | |
| import pytest | |
| import torchvision | |
| from torchvision.models import resnet18 | |
| from milatools.setup_cache import create_links | |
def create_dummy_dir_tree(
    parent_dir: Path, files: list[str], mode: int | None = None
) -> None:
    """Create `parent_dir` and fill it with dummy text files.

    Args:
        parent_dir: The root directory to create. It must not exist yet.
        files: Paths of the files to create, relative to `parent_dir`. Missing intermediate
            directories are created as needed.
        mode: When given, the permission bits to apply to each of the created files.
    """
    parent_dir.mkdir()
    for relative_name in files:
        file_path = parent_dir / relative_name
        file_path.parent.mkdir(exist_ok=True, parents=True)
        # The content embeds the path of the file, so that each file has a distinct content.
        file_path.write_text(f"Hello, this is the content of {file_path}")
        if mode is not None:
            file_path.chmod(mode)
def test_create_links(tmp_path: Path):
    """Check that `create_links` links the shared files without touching the user-only files."""
    # Create a user tree and a shared tree, call create_links, and see if it works.
    user_dir = tmp_path / "user"
    shared_dir = tmp_path / "shared"

    # "common.txt" is present in both trees, so that the case where a "real" user file gets
    # replaced with a symlink to the shared file is actually exercised.
    user_files = ["foo.txt", "bar/baz.txt", "common.txt"]
    shared_files = ["shared.txt", "common.txt"]

    # Create some dummy directories.
    create_dummy_dir_tree(user_dir, user_files)
    create_dummy_dir_tree(shared_dir, shared_files)

    user_paths = [user_dir / p for p in user_files]
    shared_paths = [shared_dir / p for p in shared_files]

    # Contents before linking, keyed by the absolute path of each file.
    user_file_contents = {p: p.read_text() for p in user_paths}
    shared_file_contents = {p: p.read_text() for p in shared_paths}

    create_links(user_dir, shared_dir)

    for user_file in user_files:
        user_path = user_dir / user_file
        assert user_path.exists()
        if user_file not in shared_files:
            # User paths and contents stay the same.
            assert not user_path.is_symlink()
            assert user_path.read_text() == user_file_contents[user_path]
        else:
            # The 'real' path should be replaced with a symbolic link!
            assert user_path.is_symlink()
            # `shared_file_contents` is keyed by the paths in the *shared* directory.
            assert user_path.read_text() == shared_file_contents[shared_dir / user_file]

    for shared_file in shared_files:
        new_symlink_location = user_dir / shared_file
        assert new_symlink_location.exists()
        target = shared_dir / shared_file
        assert new_symlink_location.is_symlink()
        # Both sides are resolved, in case the temporary directory is itself behind a symlink.
        assert new_symlink_location.resolve() == target.resolve()
        assert new_symlink_location.read_text() == shared_file_contents[target]
@pytest.fixture
def empty_shared_cache_dir(tmp_path_factory, monkeypatch):
    """Fake shared cache directory: empty, with its permission bits set to 0o444.

    `TORCH_HOME` points to this directory for the duration of the test.
    """
    fake_shared_dir: Path = tmp_path_factory.mktemp("shared_empty_cache_dir")
    fake_shared_dir.chmod(0o444)
    monkeypatch.setenv("TORCH_HOME", str(fake_shared_dir))
    return fake_shared_dir
@pytest.fixture(scope="session")
def polulated_shared_cache_dir(tmp_path_factory):
    """Fake shared cache directory that contains the pretrained weights of `resnet18`.

    The weights are downloaded once per test session into a temporary directory, by pointing
    `TORCH_HOME` at it. The permission bits of the checkpoint file and of the directory are then
    set to 0o444, and `TORCH_HOME` keeps pointing to the directory while the fixture is active.

    Yields:
        The path of the populated shared cache directory.
    """
    shared_cache_dir: Path = tmp_path_factory.mktemp("shared_cache_dir")

    # The function-scoped `monkeypatch` fixture can't be requested by a session-scoped fixture
    # (pytest fails with a ScopeMismatch error), so a dedicated MonkeyPatch context is used.
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setitem(os.environ, "TORCH_HOME", str(shared_cache_dir.absolute()))
        resnet18(pretrained=True, progress=True)
        checkpoint_file = shared_cache_dir / "hub" / "checkpoints" / "resnet18-f37072fd.pth"
        assert checkpoint_file.exists()

        # TODO: Double-check that this works.
        # NOTE(review): 0o444 on the directory also removes its search (execute) bit, which
        # looks like it makes its content unreachable for non-root users -- confirm whether
        # 0o555 was intended for the directory.
        checkpoint_file_mode = checkpoint_file.stat().st_mode
        shared_cache_dir_mode = shared_cache_dir.stat().st_mode
        checkpoint_file.chmod(0o444)
        shared_cache_dir.chmod(0o444)

        try:
            yield shared_cache_dir
        finally:
            # NOTE: This might not actually be necessary. My concern was that pytest wouldn't be
            # able to delete these if I changed the permission here.
            # The directory is restored first: the checkpoint file can't be reached through it
            # for as long as its search bit is missing.
            shared_cache_dir.chmod(shared_cache_dir_mode)
            checkpoint_file.chmod(checkpoint_file_mode)
# Maps a library function to a list of relative file paths.
# NOTE(review): presumably the files that calling the function creates in its cache directory;
# this mapping isn't referenced by the code visible here -- confirm before relying on it.
library_function_to_created_files = {
    resnet18: [Path("hub/checkpoints/resnet18-f37072fd.pth")],
}
def get_all_files_in_dir(dir_path: Path) -> list[Path]:
    """List every non-directory entry found (recursively) under `dir_path`.

    The returned paths are relative to `dir_path`.
    """
    relative_files = []
    for entry in dir_path.rglob("*"):
        if entry.is_dir():
            continue
        relative_files.append(entry.relative_to(dir_path))
    return relative_files
class TestTorchvision:
    """ TODO: Add more tests specific to each library. """

    def test_cant_write_to_shared_cache_dir(
        self, empty_shared_cache_dir: Path, monkeypatch
    ):
        """Downloading weights fails when `TORCH_HOME` points to the read-only shared cache."""
        monkeypatch.setenv("TORCH_HOME", str(empty_shared_cache_dir))
        # The shared cache starts out empty.
        assert not list(empty_shared_cache_dir.iterdir())

        with pytest.raises(IOError, match="Permission denied"):
            resnet18(pretrained=True, progress=True)

    def test_changing_torch_home_works(self, tmp_path: Path, monkeypatch):
        """Test that changing the TORCH_HOME environment variable changes where the weights are saved. """
        monkeypatch.setenv("TORCH_HOME", str(tmp_path.absolute()))
        assert not list(tmp_path.iterdir())

        resnet18(pretrained=True, progress=True)

        # Basically, we want to check that this created a `hub/` directory (and nothing else),
        # containing only the checkpoint file.
        top_level_entries = list(tmp_path.iterdir())
        assert len(top_level_entries) == 1
        expected_files = [Path("hub") / "checkpoints" / "resnet18-f37072fd.pth"]
        assert get_all_files_in_dir(tmp_path) == expected_files
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment