Created August 27, 2021 20:51
-
-
Save bsnacks000/360988929b917b436c73b6b96908d7f7 to your computer and use it in GitHub Desktop.
YearMonth pydantic type
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, List, NamedTuple, Optional | |
| import pydantic | |
| import datetime | |
| import re | |
# "YYYY-MM" shape: four digits, a hyphen, two digits. The pattern carries no
# ^/$ anchors, so use fullmatch() when the entire string must conform.
_ym_re = re.compile(pattern=r'[0-9]{4}-[0-9]{2}', flags=0)
class YearMonth(NamedTuple):
    """A custom YearMonth type. Can validate against YYYY-MM, ('YYYY', 'MM') or (int, int).

    Checks that year and month are valid ISO 8601 ranges.

    Inputs accepted by ``validate``:
      * a ``str`` that is exactly ``YYYY-MM`` (e.g. ``'2021-08'``);
      * a 2-element ``tuple`` or ``list`` whose items convert with ``int()``
        (lists are accepted because JSON arrays decode to lists);
      * an existing ``YearMonth`` (it is a tuple subclass).

    The year must lie in ``datetime.MINYEAR..datetime.MAXYEAR`` and the
    month in ``1..12``.
    """
    year: int
    month: int

    @classmethod
    def _parse_yearmonth_str(cls, v):
        """Split an exact ``YYYY-MM`` string into ``(year, month)`` ints.

        Raises:
            ValueError: if ``v`` is not exactly four digits, '-', two digits.
        """
        # fullmatch rather than match: match() only anchors at the start, so
        # '2021-011' (-> month 11) or '2021-01-15' would slip past the check.
        if not _ym_re.fullmatch(v):
            raise ValueError(f'str format not correct. Must be {_ym_re.pattern}')
        year, month = v.split('-')
        return int(year), int(month)

    @classmethod
    def _parse_yearmonth_tuple(cls, v):
        """Convert a 2-element sequence into ``(year, month)`` ints.

        Raises:
            ValueError: if ``v`` does not have exactly two elements, or an
                element is a string that ``int()`` cannot parse.
        """
        if len(v) != 2:
            raise ValueError(f'expected (year, month), got {len(v)} element(s)')
        year, month = v
        return int(year), int(month)

    @classmethod
    def _validate_ym(cls, year, month):
        """Raise ValueError unless year is in MINYEAR..MAXYEAR and month in 1..12."""
        # The upper bound keeps tuple input consistent with the 4-digit string form.
        if not (datetime.MINYEAR <= year <= datetime.MAXYEAR and 1 <= month <= 12):
            raise ValueError(f'{year}-{month} is out of range')

    @classmethod
    def __get_validators__(cls):
        """Pydantic v1 hook: yield the validator(s) applied to fields of this type."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Coerce ``v`` into a range-checked ``YearMonth``.

        Raises:
            ValueError: malformed string, wrong sequence length, or out-of-range
                year/month.
            TypeError: ``v`` is not a str, tuple, or list.
        """
        if isinstance(v, str):
            year, month = cls._parse_yearmonth_str(v)
        elif isinstance(v, (tuple, list)):
            # YearMonth subclasses tuple, so existing instances land here too
            # (and have their fields re-coerced with int()).
            year, month = cls._parse_yearmonth_tuple(v)
        else:
            # Previously fell through with year/month unbound -> UnboundLocalError.
            raise TypeError(
                f'YearMonth must be a str, tuple or list, got {type(v).__name__}'
            )
        cls._validate_ym(year, month)
        return cls(year=year, month=month)

    def __repr__(self):
        return f'YearMonth(year={self.year}, month={self.month})'
# example usage
class MonthlyTimeSeries(pydantic.BaseModel):
    """Defines a monthly series: a list of YearMonth observations (12 or more)."""
    obs: List[YearMonth]

    @pydantic.validator('obs')
    def validate_obs(cls, v: List[YearMonth]) -> List[YearMonth]:
        """Require at least 12 observations.

        Raises:
            ValueError: if fewer than 12 points are supplied.
        """
        # First argument must be named `cls`: pydantic v1 validators are
        # classmethods and it rejects a first argument named `self`.
        if len(v) < 12:
            raise ValueError('Need at least 12 points for a monthly series.')
        # Pydantic stores the validator's return value as the field value;
        # without this, a valid `obs` would silently become None.
        return v
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.