Skip to content

Instantly share code, notes, and snippets.

@lebrice
Last active October 4, 2022 17:32
Show Gist options
  • Select an option

  • Save lebrice/f3ff6274ecfd7eb96f0584d1195e8ab4 to your computer and use it in GitHub Desktop.

Select an option

Save lebrice/f3ff6274ecfd7eb96f0584d1195e8ab4 to your computer and use it in GitHub Desktop.
Consolidating the cache on the Mila cluster
"""Sets up a user cache directory for commonly used libraries, while reusing shared cache entries.
Use this to avoid having to download files to the $HOME directory, as well as to remove
duplicated downloads and free up space in your $HOME and $SCRATCH directories.
The user cache directory should be writeable, and doesn't need to be empty.
This command adds symlinks to (some of) the files contained in the *shared* cache directory to this
user cache directory.
The shared cache directory should be readable (e.g. a directory containing frequently-downloaded
weights/checkpoints, managed by the IT/IDT Team at Mila).
TODO:
This command also sets the environment variables via a block in the `$HOME/.bashrc` file, so that
these libraries look in the specified user cache for these files.
"""
from __future__ import annotations
import logging
import os
import shutil
from dataclasses import dataclass
from logging import getLogger as get_logger
from pathlib import Path
from typing import Literal
from tqdm import tqdm
logger = get_logger(__name__)

# The `scratch` entry in the user's home directory.
SCRATCH = Path.home().joinpath("scratch")

# Default location of the user cache.
# TODO: Might be a good idea to actually leave it at $HOME/.cache, and raise a warning,
# since that wouldn't require the users to set any environment variables.
DEFAULT_USER_CACHE_DIR = SCRATCH.joinpath(".cache")

# Default location of the shared cache whose entries get linked into the user cache.
# TODO: Change to an actual IDT-approved, read-only directory (intended to be
# `Path("/network/shared_cache")`). Using another dir in my scratch for now.
SHARED_CACHE_DIR = SCRATCH.joinpath("shared_cache")
def setup_cache(user_cache_dir: Path, shared_cache_dir: Path) -> None:
    """Set up the user cache directory so that it reuses the entries of the shared cache.

    Steps:
    1. Checks that `shared_cache_dir` is an existing directory. This happens before anything is
       created, so a wrong shared path doesn't leave a new, empty user cache directory behind.
    2. If `user_cache_dir` doesn't exist, creates it (and any missing parents).
    3. Removes the symlinks in `user_cache_dir` that point to files in `shared_cache_dir` that
       don't exist anymore.
    4. For every file in `shared_cache_dir`, creates a symbolic link to it at the same relative
       location in `user_cache_dir`.
    5. Sets the library cache environment variables in `os.environ`.

    Raises:
        RuntimeError: If `shared_cache_dir` is not an existing directory, or if `user_cache_dir`
            exists but is not a directory.
    """
    # Validate the (read-only) input before making any change on disk.
    if not shared_cache_dir.is_dir():
        raise RuntimeError(
            f"The shared cache directory {shared_cache_dir} doesn't exist, or isn't a directory! "
        )
    if not user_cache_dir.exists():
        user_cache_dir.mkdir(parents=True, exist_ok=False)
    if not user_cache_dir.is_dir():
        raise RuntimeError(f"cache_dir is not a directory: {user_cache_dir}")
    delete_broken_symlinks_to_shared_cache(user_cache_dir, shared_cache_dir)
    create_links(user_cache_dir, shared_cache_dir)
    set_environment_variables(user_cache_dir)
def set_environment_variables(user_cache_dir: Path):
    """Point each supported library at its own sub-directory of `user_cache_dir`.

    The values are written to `os.environ`: they affect the current process (and processes it
    spawns afterwards), but are not persisted anywhere.

    TODO: These changes won't persist. We probably need to add a block of code in .bashrc
    """
    huggingface_dir = user_cache_dir / "huggingface"
    cache_locations = {
        "TORCH_HOME": user_cache_dir / "torch",
        "HF_HOME": huggingface_dir,
        "TRANSFORMERS_CACHE": huggingface_dir / "transformers",
    }
    for variable, location in cache_locations.items():
        os.environ[variable] = str(location)
def is_child(path: Path, parent: Path) -> bool:
    """Tell whether `path` lies strictly below `parent`.

    The check is purely lexical (based on `Path.relative_to`). A path is never considered a child
    of itself.
    """
    is_same_path = path == parent
    if is_same_path:
        return False
    try:
        _ = path.relative_to(parent)
    except ValueError:
        # `relative_to` raises when `parent` is not a prefix of `path`.
        return False
    return True
def delete_broken_symlinks_to_shared_cache(
    user_cache_dir: Path, shared_cache_dir: Path
) -> None:
    """Delete the symlinks in the user cache directory that point into the shared cache directory
    at files that don't exist anymore.

    Symlinks that point elsewhere, and working symlinks into the shared cache, are left alone.
    """
    # `Path.resolve()` returns an absolute path with every symlink followed. The shared directory
    # must be resolved the same way, or the `is_child` comparison below never matches when
    # `shared_cache_dir` is relative or goes through a symlink (e.g. if `~/scratch` is a link).
    shared_root = shared_cache_dir.resolve()
    # Collect the candidates first, so entries aren't deleted while `rglob` is still walking.
    symlinks = [path for path in user_cache_dir.rglob("*") if path.is_symlink()]
    for link in symlinks:
        target = link.resolve()
        if is_child(target, shared_root) and not target.exists():
            logger.debug(f"Removing broken symlink: {link}")
            # A symlink is always removed with `unlink`, even if its target used to be a
            # directory (`rmdir` would fail on a symlink).
            link.unlink()
def create_links(user_cache_dir: Path, shared_cache_dir: Path) -> None:
    """Create symlinks in the user cache directory to every entry of the shared cache directory.

    The directory tree of `shared_cache_dir` is mirrored as regular directories under
    `user_cache_dir`. Every file of the shared cache gets a symlink at the same relative location
    in the user cache, pointing to the *absolute* path of the shared file. This also applies to
    symlinks inside the shared cache, including symlinks to directories.

    - Destinations that already are symlinks (e.g. from a previous run) are left untouched.
    - Destinations that are "real" files are replaced with a symlink to the shared file.

    NOTE: Replacing user files assumes they have the same contents as the shared ones. We might
    want to do a checksum or something first, to check that they have exactly the same contents.
    """
    # Absolute, so the created links don't depend on the current working directory.
    shared_root = Path(shared_cache_dir).absolute()
    user_root = Path(user_cache_dir)
    with tqdm() as pbar:
        # `os.walk` (rather than `shutil.copytree`): copytree copies each source directory's
        # permission bits onto the destination, which would make the user's cache directories
        # read-only when the shared cache is read-only.
        for dirpath, dirnames, filenames in os.walk(shared_root):
            src_dir = Path(dirpath)
            dst_dir = user_root / src_dir.relative_to(shared_root)
            dst_dir.mkdir(parents=True, exist_ok=True)
            # `os.walk` lists symlinks to directories in `dirnames` without descending into them;
            # link them like files.
            linked_dirs = [name for name in dirnames if (src_dir / name).is_symlink()]
            for name in [*filenames, *linked_dirs]:
                src = src_dir / name
                dst = dst_dir / name
                # Check `is_symlink` before `exists`: `exists` follows links and is False for a
                # dangling one, which `os.symlink` would then fail to overwrite.
                if dst.is_symlink():
                    continue
                if dst.exists():
                    # Replace "real" files with symlinks.
                    dst.unlink()
                pbar.set_description(f"Linking {src.relative_to(shared_root)}")
                pbar.update(1)
                os.symlink(src, dst)
@dataclass
class Options:
    """ Options for the setup_cache command. """

    user_cache_dir: Path = DEFAULT_USER_CACHE_DIR
    """The user cache directory. Should probably be in $SCRATCH (not $HOME!) """

    shared_cache_dir: Path = SHARED_CACHE_DIR
    """ The path to the shared cache directory.
    This defaults to the path of the shared cache setup by the IDT team on the Mila cluster.
    """

    framework_subdirectory: str = "all"
    """The name of a subdirectory of `shared_cache_dir` to link, or 'all' to create symlinks for
    every file in `shared_cache_dir`. Defaults to 'all'.
    """

    def __post_init__(self):
        """Validate `framework_subdirectory` and narrow both cache directories down to it.

        With 'all', nothing changes. Otherwise `framework_subdirectory` is appended to both
        `user_cache_dir` and `shared_cache_dir`.

        Raises:
            ValueError: If `framework_subdirectory` is neither 'all' nor the name of a directory
                directly inside `shared_cache_dir`.
        """
        if self.framework_subdirectory == "all":
            return
        # Only directories are valid choices. A plain file with a matching name would pass a
        # name-only check and then fail later in `setup_cache` with a less helpful error.
        available_subdirectories = sorted(
            p.name for p in self.shared_cache_dir.iterdir() if p.is_dir()
        )
        if self.framework_subdirectory not in available_subdirectories:
            raise ValueError(
                f"The framework subdirectory '{self.framework_subdirectory}' does not exist in "
                f"{self.shared_cache_dir}. \n"
                f"Frameworks/subdirectories available in the shared cache: {available_subdirectories}"
            )
        self.user_cache_dir = self.user_cache_dir / self.framework_subdirectory
        self.shared_cache_dir = self.shared_cache_dir / self.framework_subdirectory
def main():
    """Command-line entry point: parse the `Options` and set up the user cache accordingly."""
    # Imported locally; presumably so `simple_parsing` is only required when running the CLI.
    from simple_parsing import ArgumentParser

    arg_parser = ArgumentParser(description=__doc__)
    arg_parser.add_arguments(Options, dest="options")
    opts: Options = arg_parser.parse_args().options
    setup_cache(
        user_cache_dir=opts.user_cache_dir,
        shared_cache_dir=opts.shared_cache_dir,
    )


if __name__ == "__main__":
    main()
from __future__ import annotations
import os
from pathlib import Path
from typing import Callable
import pytest
import torchvision
from torchvision.models import resnet18
from milatools.setup_cache import create_links
def create_dummy_dir_tree(
    parent_dir: Path, files: list[str], mode: int | None = None
) -> None:
    """Create `parent_dir` and fill it with small text files.

    Args:
        parent_dir: Directory to create. Its parent must already exist, and it must not exist yet.
        files: Paths of the files to create, relative to `parent_dir`. Intermediate directories
            are created as needed.
        mode: If given, permission bits applied to each created file.
    """
    parent_dir.mkdir()
    for relative_name in files:
        file_path = parent_dir / relative_name
        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Each file's content mentions its own path, so files are distinguishable by content.
        file_path.write_text(f"Hello, this is the content of {file_path}")
        if mode is None:
            continue
        file_path.chmod(mode)
def test_create_links(tmp_path: Path):
    """`create_links` links every shared file into the user dir, replaces a user's "real" copy of
    a shared file with a symlink, and leaves the other user files untouched."""
    user_dir = tmp_path / "user"
    shared_dir = tmp_path / "shared"
    user_files = ["foo.txt", "bar/baz.txt"]
    # "bar/baz.txt" is in both trees, to check that the user's real file gets replaced.
    shared_files = ["shared.txt", "bar/baz.txt"]
    # Create some dummy directories.
    create_dummy_dir_tree(user_dir, user_files)
    create_dummy_dir_tree(shared_dir, shared_files)
    # Keyed by the *relative* file name, so user and shared contents can be looked up the same
    # way. The contents differ because each file's content includes its own path.
    user_file_contents = {f: (user_dir / f).read_text() for f in user_files}
    shared_file_contents = {f: (shared_dir / f).read_text() for f in shared_files}

    create_links(user_dir, shared_dir)

    for user_file in user_files:
        user_path = user_dir / user_file
        assert user_path.exists()
        if user_file not in shared_files:
            # User paths and contents stay the same.
            assert not user_path.is_symlink()
            assert user_path.read_text() == user_file_contents[user_file]
        else:
            # The 'real' path should be replaced with a symbolic link!
            assert user_path.is_symlink()
            assert user_path.read_text() == shared_file_contents[user_file]

    for shared_file in shared_files:
        new_symlink_location = user_dir / shared_file
        target = shared_dir / shared_file
        assert new_symlink_location.exists()
        assert new_symlink_location.is_symlink()
        # Resolve both sides: `tmp_path` itself may go through a symlink.
        assert new_symlink_location.resolve() == target.resolve()
        assert new_symlink_location.read_text() == shared_file_contents[shared_file]
@pytest.fixture
def empty_shared_cache_dir(tmp_path_factory, monkeypatch):
    """Fake, empty shared cache directory.

    The directory gets mode 0o444 (readable, but neither writable nor traversable). `TORCH_HOME`
    points to it for the duration of the test.
    """
    fake_shared_dir: Path = tmp_path_factory.mktemp("shared_empty_cache_dir")
    fake_shared_dir.chmod(0o444)
    monkeypatch.setitem(os.environ, "TORCH_HOME", str(fake_shared_dir))
    yield fake_shared_dir
@pytest.fixture(scope="session")
def polulated_shared_cache_dir(tmp_path_factory):
    """Fake shared cache directory, populated with the resnet18 weights and made read-only.

    Yields the shared cache directory. `TORCH_HOME` points to it while the fixture is active.
    The original permissions are restored at teardown.
    """
    shared_cache_dir: Path = tmp_path_factory.mktemp("shared_cache_dir")
    # The built-in `monkeypatch` fixture is function-scoped, so requesting it from a
    # session-scoped fixture raises a ScopeMismatch error. Use a dedicated MonkeyPatch instead.
    with pytest.MonkeyPatch.context() as mp:
        mp.setitem(os.environ, "TORCH_HOME", str(shared_cache_dir.absolute()))
        resnet18(pretrained=True, progress=True)
        checkpoint_file = shared_cache_dir / "hub" / "checkpoints" / "resnet18-f37072fd.pth"
        assert checkpoint_file.exists()

        checkpoint_mode = checkpoint_file.stat().st_mode
        shared_cache_dir_mode = shared_cache_dir.stat().st_mode
        checkpoint_file.chmod(0o444)
        # 0o555 (not 0o444): a directory without the execute bit can't be traversed, so the
        # checkpoint inside it couldn't be read, nor re-chmodded at teardown.
        shared_cache_dir.chmod(0o555)
        try:
            yield shared_cache_dir
        finally:
            # Restore the directory first, so that its contents are reachable again.
            shared_cache_dir.chmod(shared_cache_dir_mode)
            checkpoint_file.chmod(checkpoint_mode)
# For each library function, the files it is expected to create, relative to the directory that
# `TORCH_HOME` points to.
library_function_to_created_files = {
    resnet18: [Path("hub") / "checkpoints" / "resnet18-f37072fd.pth"],
}
def get_all_files_in_dir(dir_path: Path) -> list[Path]:
    """Return every non-directory entry found recursively under `dir_path`, relative to it."""
    found: list[Path] = []
    for entry in dir_path.rglob("*"):
        if entry.is_dir():
            continue
        found.append(entry.relative_to(dir_path))
    return found
class TestTorchvision:
    """Tests for how torchvision uses the directory pointed to by `TORCH_HOME`.

    TODO: Add more tests specific to each library.
    """

    def test_cant_write_to_shared_cache_dir(
        self, empty_shared_cache_dir: Path, monkeypatch
    ):
        """Downloading weights into a permission-restricted `TORCH_HOME` fails."""
        monkeypatch.setitem(os.environ, "TORCH_HOME", str(empty_shared_cache_dir))
        assert not list(empty_shared_cache_dir.iterdir())
        with pytest.raises(IOError, match="Permission denied"):
            resnet18(pretrained=True, progress=True)

    def test_changing_torch_home_works(self, tmp_path: Path, monkeypatch):
        """Changing the `TORCH_HOME` environment variable changes where the weights are saved."""
        torch_home = tmp_path.absolute()
        monkeypatch.setitem(os.environ, "TORCH_HOME", str(torch_home))
        assert not list(tmp_path.iterdir())

        resnet18(pretrained=True, progress=True)

        # A single top-level entry (`hub/`) is created, holding just the checkpoint file.
        top_level_entries = list(tmp_path.iterdir())
        assert len(top_level_entries) == 1
        expected_checkpoint = Path("hub") / "checkpoints" / "resnet18-f37072fd.pth"
        assert get_all_files_in_dir(tmp_path) == [expected_checkpoint]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment