.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_convert_pipeline_vectorizer.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_convert_pipeline_vectorizer.py:


Train, convert and predict with ONNX Runtime
============================================

This example demonstrates an end-to-end scenario, starting with the
training of a scikit-learn pipeline that takes as input
not a regular vector but a dictionary ``{ int: float }``,
because its first step is a `DictVectorizer `_.

.. contents::
    :local:

Train a pipeline
++++++++++++++++

The first step consists of retrieving the Boston dataset.

.. GENERATED FROM PYTHON SOURCE LINES 22-34

.. code-block:: default

    import pandas
    from sklearn.datasets import load_boston
    boston = load_boston()
    X, y = boston.data, boston.target

    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    X_train_dict = pandas.DataFrame(X_train[:, 1:]).T.to_dict().values()
    X_test_dict = pandas.DataFrame(X_test[:, 1:]).T.to_dict().values()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/runner/.local/lib/python3.8/site-packages/sklearn/utils/deprecation.py:87: FutureWarning: Function load_boston is deprecated; `load_boston` is deprecated in 1.0 and will be removed in 1.2.

        The Boston housing prices dataset has an ethical problem. You can refer to
        the documentation of this function for further details.

        The scikit-learn maintainers therefore strongly discourage the use of this
        dataset unless the purpose of the code is to study and educate about
        ethical issues in data science and machine learning.

        In this special case, you can fetch the dataset from the original
        source::

            import pandas as pd
            import numpy as np

            data_url = "http://lib.stat.cmu.edu/datasets/boston"
            raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
            data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
            target = raw_df.values[1::2, 2]

        Alternative datasets include the California housing dataset (i.e.
        :func:`~sklearn.datasets.fetch_california_housing`) and the Ames housing
        dataset. You can load the datasets as follows::

            from sklearn.datasets import fetch_california_housing
            housing = fetch_california_housing()

        for the California housing dataset and::

            from sklearn.datasets import fetch_openml
            housing = fetch_openml(name="house_prices", as_frame=True)

        for the Ames housing dataset.

      warnings.warn(msg, category=FutureWarning)

.. GENERATED FROM PYTHON SOURCE LINES 35-36

We create a pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 36-45

.. code-block:: default

    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.pipeline import make_pipeline

    pipe = make_pipeline(DictVectorizer(sparse=False), GradientBoostingRegressor())

    pipe.fit(X_train_dict, y_train)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline(steps=[('dictvectorizer', DictVectorizer(sparse=False)),
                    ('gradientboostingregressor', GradientBoostingRegressor())])

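To make the expected input format concrete, the following sketch (not part of the original example) applies the same ``pandas`` transpose-and-``to_dict`` trick used to build ``X_train_dict`` to a tiny array of made-up values, then feeds the resulting dictionaries to a ``DictVectorizer``. It assumes only ``numpy``, ``pandas`` and ``scikit-learn`` are installed; the names ``X``, ``rows``, ``vec`` and ``M`` are illustrative.

.. code-block:: default

    import numpy as np
    import pandas
    from sklearn.feature_extraction import DictVectorizer

    # Hypothetical stand-in for X_train: 3 observations, 2 features (made-up values).
    X = np.array([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]])

    # Transposing makes each DataFrame column one observation, so to_dict()
    # maps observation index -> {feature index: value}; .values() keeps
    # only the per-observation dictionaries.
    rows = list(pandas.DataFrame(X).T.to_dict().values())
    print(rows[0])

    # DictVectorizer learns one output column per distinct key (sorted by key)
    # and turns the list of dictionaries back into a dense matrix.
    vec = DictVectorizer(sparse=False)
    M = vec.fit_transform(rows)
    print(M)
    print(vec.feature_names_)

    # A key missing from a dictionary is filled with 0 in the output.
    print(vec.transform([{0: 7.0}]))

In this sketch ``M`` reproduces ``X``, and the last call shows that an absent feature becomes 0. The pipeline above relies on this same dictionary-to-matrix mapping as its first step.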
.. GENERATED FROM PYTHON SOURCE LINES 46-48

We compute the predictions on the test set and evaluate them
with the :math:`R^2` score.

.. GENERATED FROM PYTHON SOURCE LINES 48-53

.. code-block:: default

    from sklearn.metrics import r2_score

    pred = pipe.predict(X_test_dict)
    print(r2_score(y_test, pred))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.8456228739441438

.. GENERATED FROM PYTHON SOURCE LINES 54-60

Conversion to ONNX format
+++++++++++++++++++++++++

We use the `sklearn-onnx `_ module
to convert the model into ONNX format.

.. GENERATED FROM PYTHON SOURCE LINES 60-70

.. code-block:: default

    from skl2onnx import convert_sklearn
    from skl2onnx.common.data_types import DictionaryType, FloatTensorType, Int64TensorType

    # The model input is a single dictionary mapping int64 keys to float values.
    initial_type = [("float_input", DictionaryType(Int64TensorType([1]), FloatTensorType([])))]
    onx = convert_sklearn(pipe, initial_types=initial_type)
    with open("pipeline_vectorize.onnx", "wb") as f:
        f.write(onx.SerializeToString())

.. GENERATED FROM PYTHON SOURCE LINES 71-73

We load the model with ONNX Runtime and look at
its input and output.

.. GENERATED FROM PYTHON SOURCE LINES 73-84

.. code-block:: default

    import onnxruntime as rt
    from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument

    sess = rt.InferenceSession("pipeline_vectorize.onnx", providers=rt.get_available_providers())

    inp, out = sess.get_inputs()[0], sess.get_outputs()[0]
    print("input name='{}' and shape={} and type={}".format(inp.name, inp.shape, inp.type))
    print("output name='{}' and shape={} and type={}".format(out.name, out.shape, out.type))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    input name='float_input' and shape=[] and type=map(int64,tensor(float))
    output name='variable' and shape=[None, 1] and type=tensor(float)

.. GENERATED FROM PYTHON SOURCE LINES 85-87

We compute the predictions.
We could try to do that in one call:

.. GENERATED FROM PYTHON SOURCE LINES 87-93

.. code-block:: default

    try:
        pred_onx = sess.run([out.name], {inp.name: X_test_dict})[0]
    except (RuntimeError, InvalidArgument) as e:
        print(e)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: ((seq(map(int64,tensor(float))))) , expected: ((map(int64,tensor(float))))

.. GENERATED FROM PYTHON SOURCE LINES 94-96

But it fails because, in the case of a DictVectorizer,
ONNX Runtime expects one observation at a time.

.. GENERATED FROM PYTHON SOURCE LINES 96-98

.. code-block:: default

    # Run the session once per dictionary and keep the scalar prediction.
    pred_onx = [sess.run([out.name], {inp.name: row})[0][0, 0] for row in X_test_dict]

.. GENERATED FROM PYTHON SOURCE LINES 99-100

We compare them to the predictions of the scikit-learn model.

.. GENERATED FROM PYTHON SOURCE LINES 100-102

.. code-block:: default

    print(r2_score(pred, pred_onx))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.9999999999999319

.. GENERATED FROM PYTHON SOURCE LINES 103-105

Very similar. *ONNX Runtime* uses floats instead of doubles,
which explains the small discrepancies.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.915 seconds)

.. _sphx_glr_download_auto_examples_plot_convert_pipeline_vectorizer.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_convert_pipeline_vectorizer.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_convert_pipeline_vectorizer.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_