Created
February 18, 2026 09:07
-
-
Save betatim/d28503b15177a4530534117187b568de to your computer and use it in GitHub Desktop.
Serialise cuml estimators with onnx via `as_sklearn`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Test: Validate that cuml native estimators can be converted to ONNX
via as_sklearn() -> skl2onnx -> onnxruntime.

Unlike cuml.accel proxies (which skl2onnx recognizes directly), native cuml
estimators must first be converted to sklearn via as_sklearn() before
skl2onnx.convert_sklearn() will accept them.

Run without cuml.accel:

    python test_onnx_as_sklearn.py
"""
import numpy as np
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as ort
# ── Data ──────────────────────────────────────────────────────────────────
# Both problems share the same generation and split settings; skl2onnx and
# onnxruntime operate on float32, hence the casts.
_gen_kwargs = dict(n_samples=500, n_features=10, n_informative=5, random_state=42)
_split_kwargs = dict(test_size=0.2, random_state=42)

X_cls, y_cls = make_classification(**_gen_kwargs)
X_cls = X_cls.astype(np.float32)
X_train_c, X_test_c, y_train_c, _ = train_test_split(X_cls, y_cls, **_split_kwargs)

X_reg, y_reg = make_regression(**_gen_kwargs)
X_reg = X_reg.astype(np.float32)
y_reg = y_reg.astype(np.float32)
X_train_r, X_test_r, y_train_r, _ = train_test_split(X_reg, y_reg, **_split_kwargs)

# Declared ONNX input signatures: batch dimension left dynamic (None).
initial_type_c = [("float_input", FloatTensorType([None, X_test_c.shape[1]]))]
initial_type_r = [("float_input", FloatTensorType([None, X_test_r.shape[1]]))]
def onnx_predict(onnx_model, X):
    """Run *X* through *onnx_model* with onnxruntime; return the list of all outputs."""
    serialized = onnx_model.SerializeToString()
    session = ort.InferenceSession(serialized)
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: X})
# ── Estimator definitions ─────────────────────────────────────────────────
from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
from cuml.neighbors import KNeighborsClassifier, KNeighborsRegressor
from cuml.svm import LinearSVC, LinearSVR, SVR, SVC
from cuml.linear_model import (
    LinearRegression,
    LogisticRegression,
    Ridge,
    ElasticNet,
    Lasso,
)
from cuml.decomposition import PCA, TruncatedSVD
from cuml.cluster import KMeans

# Each entry is (display name, unfitted estimator). The fixed random_state
# values keep runs reproducible.
classifiers = [
    ("RandomForestClassifier", RandomForestClassifier(n_estimators=20, max_depth=8, random_state=42)),
    ("KNeighborsClassifier", KNeighborsClassifier()),
    ("LinearSVC", LinearSVC()),
    # SVC/SVR omitted: running them in the same process as LinearSVC/LinearSVR
    # causes a C++ abort in cuml. They work individually (tested separately).
    # ("SVC", SVC(kernel="linear")),
    ("LogisticRegression", LogisticRegression()),
]
regressors = [
    ("RandomForestRegressor", RandomForestRegressor(n_estimators=20, max_depth=8, random_state=42)),
    ("KNeighborsRegressor", KNeighborsRegressor()),
    ("LinearSVR", LinearSVR()),
    # ("SVR", SVR(kernel="linear")),
    ("LinearRegression", LinearRegression()),
    ("Ridge", Ridge()),
    ("ElasticNet", ElasticNet()),
    ("Lasso", Lasso()),
]
transformers = [
    ("PCA", PCA(n_components=5)),
    ("TruncatedSVD", TruncatedSVD(n_components=5)),
    # KMeans is grouped with the transformers but is checked via predict()
    # (cluster labels) rather than transform() below.
    ("KMeans", KMeans(n_clusters=3, random_state=42, n_init=10)),
]

# Accumulates (name, "PASS"|"FAIL"|"ERROR") pairs for the final summary.
results = []
# ── Classifiers ───────────────────────────────────────────────────────────
# For each classifier: fit on GPU, convert via as_sklearn(), export with
# skl2onnx, then compare predicted labels across all three backends.
print("=" * 70)
print("CLASSIFIERS: cuml.as_sklearn() -> skl2onnx -> onnxruntime")
print("=" * 70)
for name, est in classifiers:
    try:
        est.fit(X_train_c, y_train_c)
        # .ravel() guards against a (n, 1) vs (n,) shape mismatch: comparing
        # such arrays with == would broadcast to (n, n) and silently skew the
        # match fraction. The regressor loop already flattens for this reason;
        # this makes the two loops consistent.
        cuml_preds = np.asarray(est.predict(X_test_c)).ravel()
        sklearn_est = est.as_sklearn()
        sklearn_preds = np.asarray(sklearn_est.predict(X_test_c)).ravel()
        # zipmap=False makes the probability output a plain tensor instead of
        # a list of dicts; only applicable to probabilistic classifiers.
        options = {}
        if hasattr(sklearn_est, "predict_proba"):
            options["zipmap"] = False
        onnx_model = convert_sklearn(
            sklearn_est, initial_types=initial_type_c, options=options
        )
        # Output 0 of a classifier ONNX graph is the predicted-label tensor.
        onnx_preds = np.asarray(onnx_predict(onnx_model, X_test_c)[0]).ravel()
        cuml_vs_sklearn = np.mean(cuml_preds == sklearn_preds)
        sklearn_vs_onnx = np.mean(sklearn_preds == onnx_preds)
        status = "PASS" if sklearn_vs_onnx >= 0.99 else "FAIL"
        print(f" {name}:")
        print(f" cuml vs sklearn match: {cuml_vs_sklearn:.4f}")
        print(f" sklearn vs ONNX match: {sklearn_vs_onnx:.4f} [{status}]")
    except Exception as e:
        status = "ERROR"
        print(f" {name}: ERROR - {e}")
    results.append((name, status))
# ── Regressors ────────────────────────────────────────────────────────────
# For each regressor: fit on GPU, convert via as_sklearn(), export with
# skl2onnx, and compare predictions numerically (max absolute difference).
print()
print("=" * 70)
print("REGRESSORS: cuml.as_sklearn() -> skl2onnx -> onnxruntime")
print("=" * 70)
for name, reg in regressors:
    try:
        reg.fit(X_train_r, y_train_r)
        # Flatten everything so (n, 1) vs (n,) outputs compare element-wise.
        preds_cuml = np.asarray(reg.predict(X_test_r)).flatten()
        skl_model = reg.as_sklearn()
        preds_skl = np.asarray(skl_model.predict(X_test_r)).flatten()
        model_onnx = convert_sklearn(skl_model, initial_types=initial_type_r)
        preds_onnx = np.asarray(onnx_predict(model_onnx, X_test_r)[0]).flatten()
        diff_cuml_skl = np.max(np.abs(preds_cuml - preds_skl))
        diff_skl_onnx = np.max(np.abs(preds_skl - preds_onnx))
        # Loose 1e-2 tolerance: ONNX runs in float32 end to end.
        status = "PASS" if diff_skl_onnx < 1e-2 else "FAIL"
        print(f" {name}:")
        print(f" cuml vs sklearn max diff: {diff_cuml_skl:.6e}")
        print(f" sklearn vs ONNX max diff: {diff_skl_onnx:.6e} [{status}]")
    except Exception as e:
        status = "ERROR"
        print(f" {name}: ERROR - {e}")
    results.append((name, status))
# ── Transformers ──────────────────────────────────────────────────────────
# PCA/TruncatedSVD are compared via transform() outputs; KMeans via its
# predicted cluster labels (ONNX output 0).
print()
print("=" * 70)
print("TRANSFORMERS: cuml.as_sklearn() -> skl2onnx -> onnxruntime")
print("=" * 70)
for name, est in transformers:
    try:
        est.fit(X_train_c)
        sklearn_est = est.as_sklearn()
        onnx_model = convert_sklearn(
            sklearn_est, initial_types=initial_type_c
        )
        onnx_results = onnx_predict(onnx_model, X_test_c)
        if name == "KMeans":
            # .ravel() guards against a (n, 1) vs (n,) label-shape mismatch:
            # comparing such arrays with == would broadcast to (n, n) and
            # silently skew the match fraction (the regressor loop flattens
            # for the same reason).
            cuml_out = np.asarray(est.predict(X_test_c)).ravel()
            sklearn_out = np.asarray(sklearn_est.predict(X_test_c)).ravel()
            onnx_out = np.asarray(onnx_results[0]).ravel()
            cuml_vs_sklearn = np.mean(cuml_out == sklearn_out)
            sklearn_vs_onnx = np.mean(sklearn_out == onnx_out)
            status = "PASS" if sklearn_vs_onnx >= 0.99 else "FAIL"
            print(f" {name}:")
            print(f" cuml vs sklearn match: {cuml_vs_sklearn:.4f}")
            print(f" sklearn vs ONNX match: {sklearn_vs_onnx:.4f} [{status}]")
        else:
            # Transform outputs are 2-D (n_samples, n_components); shapes
            # already agree, so compare numerically.
            cuml_out = np.asarray(est.transform(X_test_c))
            sklearn_out = np.asarray(sklearn_est.transform(X_test_c))
            onnx_out = np.asarray(onnx_results[0])
            cuml_vs_sklearn = np.max(np.abs(cuml_out - sklearn_out))
            sklearn_vs_onnx = np.max(np.abs(sklearn_out - onnx_out))
            status = "PASS" if sklearn_vs_onnx < 1e-4 else "FAIL"
            print(f" {name}:")
            print(f" cuml vs sklearn max diff: {cuml_vs_sklearn:.6e}")
            print(f" sklearn vs ONNX max diff: {sklearn_vs_onnx:.6e} [{status}]")
    except Exception as e:
        status = "ERROR"
        print(f" {name}: ERROR - {e}")
    results.append((name, status))
# ── Summary ───────────────────────────────────────────────────────────────
# Print one status line per estimator, then tally outcomes.
print()
print("=" * 70)
print("SUMMARY")
print("=" * 70)
counts = {"PASS": 0, "FAIL": 0, "ERROR": 0}
for est_name, est_status in results:
    print(f" {est_name:30s} {est_status}")
    if est_status in counts:
        counts[est_status] += 1
print(f"\n {counts['PASS']} passed, {counts['FAIL']} failed, {counts['ERROR']} errors")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment