Created
August 11, 2025 01:47
-
-
Save Shivanshu-Gupta/e36704de82caa01e16d6d57c7e1b1c41 to your computer and use it in GitHub Desktop.
A dataclass with dict- and attribute-style access and lots of additional freebies!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Taken from https://github.com/Shivanshu-Gupta/attr-hyperparams | |
| import attr | |
| import cattr | |
| from collections.abc import MutableMapping | |
| from collections import UserList | |
| from typing import IO, Iterable, Text, Union, get_origin, get_args | |
| from pathlib import Path | |
| from itertools import product | |
| from copy import deepcopy | |
# Shared cattrs converter used to (un)structure every DictDataClass.
converter = cattr.Converter()
# pathlib.Path <-> str, so serialized dicts contain plain strings instead of Path objects.
converter.register_structure_hook(Path, lambda d, t: Path(d))
converter.register_unstructure_hook(Path, lambda d: str(d))  # type: ignore

# Accepted spellings when structuring a bool (compared case-insensitively via str(v).lower()).
_BOOL_STRINGS = {"false": False, "true": True}


def _structure_bool(value, _type):
    """
    Structure a bool from a bool or a case-insensitive 'true' / 'false' string.

    Raises:
        ValueError: if `value` is anything else (e.g. 'yes', '1'), naming the bad value
            instead of surfacing an opaque KeyError.
    """
    try:
        return _BOOL_STRINGS[str(value).lower()]
    except KeyError:
        raise ValueError(
            f"Cannot interpret {value!r} as a bool; expected 'true' or 'false'"
        ) from None


converter.register_structure_hook(bool, _structure_bool)
def nest_dict(flat, sep='.'):
    """
    Convert a flat mapping whose keys are `sep`-joined paths into a nested dict.

    Every occurrence of `sep` in a key starts a new nesting level, e.g.
    ``nest_dict({'a.b': 1, 'a.c': 2, 'd': 3}) == {'a': {'b': 1, 'c': 2}, 'd': 3}`` and
    ``nest_dict({'a/b/c': 1}, sep='/') == {'a': {'b': {'c': 1}}}``.

    Args:
        flat: mapping from `sep`-joined string keys to values.
        sep: separator between nesting levels.

    Returns:
        A new nested dict; `flat` is not modified.
    """
    result = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = result
        for part in parents:
            # Create intermediate levels on demand; reuse them when another key shares the prefix.
            node = node.setdefault(part, {})
        node[leaf] = value
    return result
def default_value(default):
    """
    Return an ``attr.ib`` whose default is a fresh deep copy of `default` for every instance.

    Intended for mutable defaults (lists, dicts, nested parameter objects). Each instance
    gets its own copy, so mutating one instance's value cannot leak into other instances.
    """
    return attr.ib(default=attr.Factory(lambda: deepcopy(default)))
@attr.s(auto_attribs=True)
class DictDataClass(MutableMapping):
    """
    Allow dict-like access to attributes using the ``[]`` operator in addition to dot-access.

    Keys containing ``.`` address nested attributes, e.g. ``obj['model.encoder']``.
    Easy serialisation to / deserialisation from nested dicts, flattened dicts, JSON and YAML.
    """

    def __iter__(self):
        # Keys are the instance's attribute names, in ``vars(self)`` order.
        return iter(vars(self))

    def __getitem__(self, item):
        """Return attribute `item`; a dotted key walks through nested objects one segment at a time."""
        if '.' in item:
            d = self
            for k in item.split('.'):
                d = d[k]
            return d
        return getattr(self, item)

    def __setitem__(self, key, value):
        """Set attribute `key`; for a dotted key, set the last segment on the nested object it names."""
        if '.' in key:
            *parents, leaf = key.split('.')
            d = self
            for k in parents:
                d = d[k]
            setattr(d, leaf, value)
        else:
            setattr(self, key, value)

    def __delitem__(self, key):
        # Attributes are declared fields; removing one is not supported.
        raise NotImplementedError

    def __len__(self):
        return len(vars(self))

    def to_dict(self) -> dict:
        """
        Serialize to a nested dict via the module-level cattrs ``converter``.

        ``attr.asdict`` is not used because it leaves values such as ``Path`` objects
        untouched, which are not JSON serializable. The converter applies the registered
        unstructure hooks (e.g. ``Path`` -> ``str``).
        """
        return converter.unstructure(self)

    def to_flattened_dict(self, sep='.', _parent_key='') -> dict:
        """
        Serialize to a flattened dict whose keys join nested attribute names with `sep`.

        Only nested ``DictDataClass`` values are flattened; every other value (including
        lists and plain dicts) is stored unchanged. `_parent_key` is the prefix used
        internally during recursion.
        """
        flat_d = {}
        for k, v in self.items():
            if _parent_key:
                k = _parent_key + sep + k
            if isinstance(v, DictDataClass):
                flat_d.update(v.to_flattened_dict(sep=sep, _parent_key=k))
            else:
                flat_d[k] = v
        return flat_d

    @classmethod
    def from_flattened_dict(cls, d: dict, sep='.'):
        """Deserialize from a flattened dict `d` whose keys are joined with `sep`."""
        return cls.from_dict(nest_dict(d, sep=sep))

    @classmethod
    def from_dict(cls, d: dict):
        """
        Deserialize from a nested dict `d`.

        Union-typed fields are resolved with the hooks from ``_get_all_disambiguators``.
        These hooks are registered on a private copy of the module-level converter. The
        shared converter is never rebound or modified, so one class's hooks cannot leak
        into another class's deserialization.
        """
        local_converter = converter.copy()
        for union_type, func in cls._get_all_disambiguators().items():
            # `func(obj, union_type)` returns the concrete type to structure into. Loop variables
            # are bound as defaults so each hook keeps its own disambiguator.
            local_converter.register_structure_hook(
                union_type,
                lambda o, t, hook=func, conv=local_converter: conv.structure(o, hook(o, t)),
            )
        return local_converter.structure(d, cls)

    @classmethod
    def get_disambiguators(cls):
        """
        Override to map Union types to ``func(obj, union_type) -> concrete_type``.

        ``from_dict`` uses these to choose which member of a Union to structure into.
        """
        return {}

    @classmethod
    def _get_all_disambiguators(cls):
        """
        Collect disambiguators from this class and, recursively, from nested classes.

        Nested classes are ``DictDataClass`` subclasses found among the field types,
        including inherited fields. Generic field types (``Union``/``Optional``,
        ``List[...]``, ...) are searched one level deep through their type arguments.
        A nested class's entry overrides this class's entry for the same Union type.
        """
        # Copy so we never mutate a dict that an overriding get_disambiguators() might reuse.
        disambiguators = dict(cls.get_disambiguators())
        for field in attr.fields(cls):
            t = field.type
            for candidate in get_args(t) or (t,):
                try:
                    is_nested = isinstance(candidate, type) and issubclass(candidate, DictDataClass)
                except TypeError:
                    # Some parametrised generics (e.g. ``list[int]``) look like types but reject issubclass.
                    is_nested = False
                if is_nested:
                    disambiguators.update(candidate._get_all_disambiguators())
        return disambiguators

    def to_json(self, fp: Union[IO[str], None] = None, **kwargs):
        """
        Serialize to JSON.

        Returns the JSON string when `fp` is None. Otherwise writes to `fp` and returns None.
        Extra `kwargs` are forwarded to ``json.dumps`` / ``json.dump``.
        """
        import json
        if fp is None:
            return json.dumps(self.to_dict(), **kwargs)
        json.dump(self.to_dict(), fp, **kwargs)

    @classmethod
    def from_json(cls, fp):
        """Deserialize from a readable JSON file object `fp`."""
        import json
        return cls.from_dict(json.load(fp))

    def to_yaml(self, stream: IO[str]):
        """Serialize to YAML, writing to `stream`."""
        import yaml
        yaml.dump(self.to_dict(), stream)

    @classmethod
    def from_yaml(cls, stream: Union[bytes, IO[bytes], str, IO[Text]]):
        """Deserialize from a YAML document given as a string/bytes or a file object."""
        import yaml
        # SECURITY NOTE(review): yaml.FullLoader is not intended for untrusted input.
        # Switch to yaml.safe_load if configs may come from untrusted sources.
        return cls.from_dict(yaml.load(stream, Loader=yaml.FullLoader))
class Settings(UserList):
    """A list-like container of parameter settings; the initial items must be supplied."""

    def __init__(self, data):
        # All storage and list behaviour is delegated to UserList (items live in ``self.data``).
        super().__init__(data)
class Parameters(DictDataClass):
    """
    Adds the `get_settings` functionality, useful for creating parameter grids.

    - If a parameter is assigned a list value, `get_settings` creates one object per list
      element with the parameter set to that element (class A in the example below).
      Note that *any* list value is treated as a grid axis, whatever the field's declared type.
    - It works recursively, so a nested ``Parameters`` attribute may itself be a grid
      (class B in the example below). A list may also contain ``Parameters`` objects,
      each of which is expanded.

    Eg.
    ```
    @attr.s(auto_attribs=True)
    class A(Parameters):
        x: int
    a_l = A(x=[1, 2]).get_settings()
    print(a_l)
    # [A(x=1), A(x=2)]

    @attr.s(auto_attribs=True)
    class B(Parameters):
        a: A
        y: str
    b_l = B(a=A(x=[1, 2]), y=['here', 'there']).get_settings()
    print(b_l)
    # [B(a=A(x=1), y='here'), B(a=A(x=1), y='there'), B(a=A(x=2), y='here'), B(a=A(x=2), y='there')]
    ```
    """

    def is_settings_grid(self):
        """Return True if some parameter, directly or in a nested ``Parameters``, holds a list value."""
        return any(
            isinstance(v, list) or (isinstance(v, Parameters) and v.is_settings_grid())
            for v in (getattr(self, k) for k in self.keys())
        )

    def get_settings(self, key_order=None):
        """
        Expand this (possibly nested) parameter grid into a list of concrete settings.

        Args:
            key_order: order in which grid parameters are combined in the cartesian product
                (the last one varies fastest). Must contain exactly this object's keys.
                Defaults to attribute order.

        Returns:
            A list of deep copies of this object, one per combination. In each copy, every
            list-valued parameter is replaced by one of its elements and every nested
            ``Parameters`` by one of its own settings. Scalar parameters are copied unchanged.

        Raises:
            ValueError: if `key_order` does not match the keys, or a parameter is an empty list.
        """
        key_order = key_order or list(self.keys())
        # Raise instead of assert: validation must survive `python -O`.
        if set(key_order) != set(self.keys()):
            raise ValueError('key_order must contain all the keys')

        grid_keys = []
        value_lists = []
        for k in key_order:
            v = getattr(self, k)
            if isinstance(v, Parameters):
                choices = v.get_settings()
            elif isinstance(v, list):
                if len(v) == 0:
                    raise ValueError(f"Empty settings list for {k}")
                # Expand each element on its own, so a list may mix Parameters and plain values.
                choices = [
                    s
                    for item in v
                    for s in (item.get_settings() if isinstance(item, Parameters) else [item])
                ]
            else:
                continue  # scalar parameter: kept as-is in every setting
            grid_keys.append(k)
            value_lists.append(choices)

        settings = []
        for values in product(*value_lists):
            # Deep-copy both the template and the chosen values so no two settings share state.
            overrides = dict(zip(grid_keys, deepcopy(values)))
            settings.append(attr.evolve(deepcopy(self), **overrides))
        return settings
class InstantiationMixin:
    """
    Mixin that enables direct instantiation of an object from a `DictDataClass`
    containing the constructor parameters of a particular class.

    Note: If a `DictDataClass` is given this mixin, then every "sub-parameter" that is
    also a `DictDataClass` needs this mixin too in order to be instantiated recursively.
    """

    @staticmethod
    def _resolve_class(_type):
        """
        Map a ``type`` attribute value to the callable that builds the object.

        A string is a dotted import path (``'pkg.module.Class'``); a bare name without a
        module part is looked up in ``builtins``. Classes and other callables are used as-is.

        Raises:
            TypeError: if `_type` is neither a string nor callable.
        """
        import builtins
        import importlib
        if isinstance(_type, str):
            module_name, _, class_name = _type.rpartition('.')
            module = importlib.import_module(module_name) if module_name else builtins
            return getattr(module, class_name)
        if callable(_type):
            return _type
        raise TypeError(
            f'Unsupported type attribute {_type!r}: expected a dotted path string or a callable.'
        )

    def instantiate(self, **kwargs):
        """
        Recursively instantiate the class named by the ``type`` attribute.

        The constructor receives this object's other attributes as keyword arguments,
        overridden by `kwargs`. Attribute values are deep-copied first, so the new object
        never shares mutable state with this parameters object. Attributes that are
        themselves `InstantiationMixin`s are instantiated first.

        `kwargs` is a nested dict holding any additional constructor arguments. An entry
        whose key names an `InstantiationMixin` attribute must be a dict; it is forwarded
        to that attribute's own ``instantiate`` call. At the very least, `kwargs` must
        contain values for every required argument not specified in the `DictDataClass`.

        Example:
        ```
        class SimpleTagger:
            def __init__(self, embedding_param=50, encoder=None):
                self.embedding_param = embedding_param
                self.encoder = encoder

        @attr.s(auto_attribs=True)
        class EncoderParams(Parameters, InstantiationMixin):
            type: str = 'torch.nn.LSTM'
            hidden_size: int = 100
            num_layers: int = 1

        @attr.s(auto_attribs=True)
        class ModelParams(Parameters, InstantiationMixin):
            type: type = SimpleTagger
            embedding_param: Union[int, str] = 50
            encoder: Optional[EncoderParams] = None

        mp = ModelParams(encoder=EncoderParams())
        ```
        Since `input_size` is a required positional argument of `torch.nn.LSTM`,
        `mp.instantiate(encoder={'input_size': 10})` works, but `mp.instantiate()` does not.

        Returns:
            The constructed object, or None when ``type`` is None or ``''`` (nothing configured).

        Raises:
            ValueError: if there is no ``type`` attribute.
            TypeError: if ``type`` is neither a string nor callable.
        """
        if not hasattr(self, 'type'):
            raise ValueError('Missing type attribute.')
        parameters = deepcopy(vars(self))
        _type = parameters.pop('type')
        # An empty / missing type means "no object configured here".
        if _type is None or (isinstance(_type, str) and _type == ''):
            return None
        _class = self._resolve_class(_type)

        instantiated_attrs = {}
        for attr_name, attr_params in parameters.items():
            if isinstance(attr_params, InstantiationMixin):
                # Overrides for a sub-object go to its own instantiate() call and are removed
                # from kwargs so they are not also passed to this constructor.
                sub_kwargs = kwargs.pop(attr_name, {})
                instantiated_attrs[attr_name] = attr_params.instantiate(**sub_kwargs)
        parameters.update(instantiated_attrs)
        parameters.update(kwargs)
        return _class(**parameters)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment