Created
June 10, 2020 18:41
-
-
Save bsnacks000/e66a3fe0e76b1bea69a25fc2c6d8cca0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ A high-level API that wraps the sklearn Pipeline interface. If designed well we should be able to use this base class to | |
| encapsulate elements of jupyter notebook script. A user should subclass and override the specified hook methods of these classes | |
| to build a workflow for deployment and benchmarking. | |
| """ | |
| import abc | |
# from sklearn.externals.joblib import parallel_backend -- deprecation warning says to import parallel_backend from joblib directly
| from joblib import parallel_backend | |
| from sklearn.model_selection import GridSearchCV, RandomizedSearchCV,\ | |
| ParameterGrid, ParameterSampler | |
| import pandas as pd | |
class PipelineWrapper(abc.ABC):
    """A base class that models a sklearn workflow using the Pipeline and
    model_selection APIs.

    Subclasses must set the class attributes ``_pipeline``, ``_param_grid``
    and ``_search_class``, and may override the ``prepare_df`` /
    ``extract_X_y`` hooks to adapt the raw dataframe for their application.
    """

    _pipeline = None      # sklearn Pipeline instance (set by subclass)
    _param_grid = None    # param grid / distributions for the search (set by subclass)
    _search_class = None  # search class, e.g. GridSearchCV (set by subclass)

    def __init__(self, df, parallel_backend='loky', n_jobs=-1, use_dfs=False,
                 **search_settings):
        """Store the raw dataframe and the search/backend configuration.

        Args:
            df: the raw pandas DataFrame the workflow operates on.
            parallel_backend: joblib backend name used while fitting.
            n_jobs: worker count forwarded to the joblib backend.
            use_dfs: forwarded to ``extract_X_y``; when True the extraction
                should return DataFrames (useful with ColumnTransformers).
            **search_settings: extra kwargs forwarded to ``_search_class``.
        """
        self._df = df
        self.parallel_backend = parallel_backend
        self.n_jobs = n_jobs
        self.search_settings = search_settings  # stored for the search class later
        self.use_dfs = use_dfs
        # XXX validation here... all values

    @property
    def df(self):
        """The raw dataframe given at construction time."""
        return self._df

    @property
    def pipeline(self):
        """The class-level sklearn Pipeline (``_pipeline``)."""
        return self.__class__._pipeline

    @property
    def param_grid(self):
        """The class-level parameter grid (``_param_grid``)."""
        return self.__class__._param_grid

    def prepare_df(self, df, *args, **kwargs):
        """Spec for a preprocessing step, for data transformations that should
        take place outside of Transformer calls. This is to both separate
        concerns and save CPUs. There are no restrictions except that this
        must accept and return a valid dataframe. Other data or objects can be
        passed in to clean the dataframe for application-specific needs.

        Default implementation is the identity.
        """
        return df

    def extract_X_y(self, df, X_cols, y_cols):
        """Called after ``prepare_df`` and before the search runs; extracts
        the X and y columns in the correct formats.

        NOTE(review): relies on a custom ``extract`` DataFrame accessor being
        registered elsewhere in the project -- confirm it is imported before
        this is called. Returns DataFrames when ``use_dfs`` is True (needed
        when working with ColumnTransformers).
        """
        return df.extract.Xmat_y(X_cols, y_cols, self.use_dfs)

    def _prepare(self, X_cols, y_cols, *df_args, **df_kwargs):
        """Copy the dataframe, run the ``prepare_df`` hook, then extract X/y.

        Raises:
            TypeError: if ``prepare_df`` did not return a pandas DataFrame.
        """
        df = self._df.copy()
        df = self.prepare_df(df, *df_args, **df_kwargs)
        if not isinstance(df, pd.DataFrame):
            # message previously referenced a nonexistent `clean` hook
            raise TypeError('The result of prepare_df must be a pandas dataframe')
        return self.extract_X_y(df, X_cols, y_cols)  # XXX <-- possibly call sklearn check_X_y here? safety first...

    def run(self, X_cols, y_cols, df_args=(), df_kwargs=None, method='fit',
            **fit_params):
        """The main method. Prepares X/y, builds the search object and calls
        ``method`` (default ``'fit'``) on it under the joblib backend.

        Args:
            X_cols, y_cols: column name lists handed to ``extract_X_y``.
            df_args, df_kwargs: forwarded to the ``prepare_df`` hook.
            method: name of the search method to invoke (e.g. ``'fit'``).
            **fit_params: forwarded to the invoked search method.

        Returns:
            Whatever the invoked search method returns (the fitted search
            object for ``'fit'``).
        """
        # XXX <-- check that X_cols and y_cols are not empty ... should be list of at least 1
        df_kwargs = {} if df_kwargs is None else df_kwargs  # avoid mutable default
        # BUG FIX: df_args was accepted but never forwarded to _prepare.
        X, y = self._prepare(X_cols, y_cols, *df_args, **df_kwargs)
        # Pass the grid positionally: GridSearchCV names this parameter
        # `param_grid` while RandomizedSearchCV names it `param_distributions`,
        # so a keyword argument would break the randomized subclass.
        search = self._search_class(self.pipeline, self.param_grid,
                                    **self.search_settings)
        with parallel_backend(self.parallel_backend, n_jobs=self.n_jobs):
            return getattr(search, method)(X, y, **fit_params)
# implementations with Grid and Randomized
class GridSearchCVPipelineWrapper(PipelineWrapper):
    """PipelineWrapper that runs an exhaustive search over ``_param_grid``
    using sklearn's GridSearchCV.
    """
    _search_class = GridSearchCV
class RandomizedSearchPipelineWrapper(PipelineWrapper):
    """PipelineWrapper that samples candidates from ``_param_grid`` using
    sklearn's RandomizedSearchCV.

    NOTE(review): RandomizedSearchCV calls its grid parameter
    ``param_distributions`` (not ``param_grid``) -- confirm the base class
    constructs the search in a way RandomizedSearchCV accepts.
    """
    _search_class = RandomizedSearchCV
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment