Last active
January 11, 2025 23:48
-
-
Save ottobricks/d060101e06de7463c63f341ab1977cc7 to your computer and use it in GitHub Desktop.
Patch `pyspark.sql.DataFrame.hint` method in PySpark 3.3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, List, Union | |
| from pyspark.sql import DataFrame | |
| from pyspark.sql.column import Column, _to_java_column | |
def hint(
    self,
    name: str,
    *parameters: Union[str, int, float, Column, List[Union[int, str]]],
) -> DataFrame:
    """
    Patches method from :class:`DataFrame`.

    String and :class:`Column` parameters are turned into JVM column
    expressions (via ``_to_java_column(...).expr()``) before being handed to
    the JVM ``hint`` call; every other parameter is passed through unchanged.

    Reference: https://github.com/apache/spark/pull/37616

    Parameters
    ----------
    name : str
        Name of the hint.
    *parameters : str, int, float, Column or list
        Hint parameters. When exactly one argument is given and it is a
        list, that list is used as the full parameter sequence.

    Returns
    -------
    DataFrame
        A new DataFrame built from the result of the JVM ``hint`` call and
        this DataFrame's ``sparkSession``.

    Raises
    ------
    TypeError
        If ``name`` is not a ``str`` or any parameter is not one of the
        allowed types.
    """

    def _convert_hint_parameter(
        parameter: Union[str, int, float, list, Column]
    ) -> Any:
        # Strings and Columns become column expressions; everything else
        # (numbers, lists) is forwarded as-is.
        if isinstance(parameter, (Column, str)):
            return _to_java_column(parameter).expr()
        return parameter

    # A single list argument stands for the whole parameter sequence.
    if len(parameters) == 1 and isinstance(parameters[0], list):
        parameters = parameters[0]  # type: ignore[assignment]

    if not isinstance(name, str):
        raise TypeError("name should be provided as str, got {0}".format(type(name)))

    # Column must be allowed here: the converter above explicitly handles it,
    # and rejecting it at validation time made that branch unreachable.
    allowed_types = (str, list, float, int, Column)
    for p in parameters:
        if not isinstance(p, allowed_types):
            raise TypeError(
                "all parameters should be in {0}, got {1} of type {2}".format(
                    allowed_types, p, type(p)
                )
            )

    jdf = self._jdf.hint(
        name, self._jseq(parameters, converter=_convert_hint_parameter)
    )
    return DataFrame(jdf, self.sparkSession)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Reference: apache/spark#37616
After adding the function definition, apply the patch to `DataFrame`, e.g. in `__main__`: