# module `concrete.ml.sklearn.tree_to_numpy`

Implements the conversion of a tree model to a numpy function.

## Global Variables

- `MAX_BITWIDTH_BACKWARD_COMPATIBLE`
- `OPSET_VERSION_FOR_ONNX_EXPORT`
---

## function `get_onnx_model`

```python
get_onnx_model(model: Callable, x: ndarray, framework: str) → ModelProto
```

Create ONNX model with Hummingbird convert method.

**Args:**

- `model` (Callable): The tree model to convert.
- `x` (numpy.ndarray): Dataset used to trace the tree inference and convert the model to ONNX.
- `framework` (str): The framework from which the ONNX model is generated (options: 'xgboost', 'sklearn').

**Returns:**

- `onnx.ModelProto`: The ONNX model.
---

## function `workaround_squeeze_node_xgboost`

Workaround for a torch issue that does not export the proper axis in the ONNX Squeeze node.

FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2778. The Squeeze op does not have the proper dimensions, so the axis attribute is added to the Squeeze node here. Remove this workaround once the issue is fixed.

**Args:**

- `onnx_model` (onnx.ModelProto): The ONNX model.
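To make the fix concrete, here is a minimal sketch of the idea, using plain dicts as stand-ins for ONNX `NodeProto` objects (the real code mutates an `onnx.ModelProto`; the helper and field names below are hypothetical, not the library's API):

```python
# Hypothetical sketch: give Squeeze nodes the explicit 'axes' attribute
# that the torch export left out. Plain dicts stand in for ONNX nodes.

def add_axis_to_squeeze_nodes(nodes, axis=1):
    """Add an 'axes' attribute to every Squeeze node that lacks one."""
    for node in nodes:
        if node["op_type"] == "Squeeze" and "axes" not in node["attributes"]:
            node["attributes"]["axes"] = [axis]
    return nodes

graph = [
    {"op_type": "Gemm", "attributes": {}},
    {"op_type": "Squeeze", "attributes": {}},  # 'axes' missing after export
]
add_axis_to_squeeze_nodes(graph)
print(graph[1]["attributes"])  # {'axes': [1]}
```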
---

## function `add_transpose_after_last_node`

Add transpose after last node.

**Args:**

- `onnx_model` (onnx.ModelProto): The ONNX model.
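As a rough illustration of what this step does, the sketch below appends a Transpose node consuming the last node's output, again with dict stand-ins for ONNX nodes (all names are illustrative assumptions, not the library's internals):

```python
# Hypothetical sketch: append a Transpose node after the graph's last
# node so the model's output axes are swapped.

def append_transpose(nodes, output_name="transposed_output"):
    """Append a Transpose node that consumes the last node's output."""
    last_output = nodes[-1]["outputs"][-1]
    nodes.append({
        "op_type": "Transpose",
        "inputs": [last_output],
        "outputs": [output_name],
        "attributes": {"perm": [1, 0]},  # swap the two output axes
    })
    return nodes

graph = [{"op_type": "MatMul", "inputs": ["x", "w"], "outputs": ["y"]}]
append_transpose(graph)
print(graph[-1]["inputs"])  # ['y']
```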
---

## function `preprocess_tree_predictions`

Apply post-processing to the tree predictions extracted from the graph.

**Args:**

- `init_tensor` (numpy.ndarray): Model parameters to be pre-processed.
- `output_n_bits` (int): The number of bits of the output.

**Returns:**

- `QuantizedArray`: Quantizer for the tree predictions.
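For intuition, here is a stdlib-only sketch of the uniform n-bit quantization applied to the tree's output values; the real code returns a `QuantizedArray`, and the helpers below are simplified stand-ins:

```python
# Simplified sketch of uniform n-bit quantization of float values
# (the library's QuantizedArray encapsulates this logic).

def quantize(values, n_bits):
    """Map float values onto integers in [0, 2**n_bits - 1]."""
    vmin, vmax = min(values), max(values)
    scale = (vmax - vmin) / (2**n_bits - 1)
    return [round((v - vmin) / scale) for v in values], scale, vmin

def dequantize(qvalues, scale, zero_point):
    """Recover approximate float values from the integers."""
    return [q * scale + zero_point for q in qvalues]

qvals, scale, zp = quantize([0.0, 0.5, 1.0], n_bits=2)
print(qvals)  # [0, 2, 3]
```

With only 2 bits, the three values collapse onto a 4-level grid; `output_n_bits=8` gives 256 levels and a correspondingly smaller quantization error.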
---

## function `tree_onnx_graph_preprocessing`

Apply pre-processing onto the ONNX graph.

**Args:**

- `onnx_model` (onnx.ModelProto): The ONNX model.
- `framework` (str): The framework from which the ONNX model is generated (options: 'xgboost', 'sklearn').
- `expected_number_of_outputs` (int): The expected number of outputs in the ONNX model.
---

## function `tree_values_preprocessing`

Pre-process tree values.

**Args:**

- `onnx_model` (onnx.ModelProto): The ONNX model.
- `framework` (str): The framework from which the ONNX model is generated (options: 'xgboost', 'sklearn').
- `output_n_bits` (int): The number of bits of the output.

**Returns:**

- `QuantizedArray`: Quantizer for the tree predictions.
---

## function `tree_to_numpy`

Convert the tree inference to a numpy function using Hummingbird.

**Args:**

- `model` (Callable): The tree model to convert.
- `x` (numpy.ndarray): The input data.
- `framework` (str): The framework from which the ONNX model is generated (options: 'xgboost', 'sklearn').
- `output_n_bits` (int): The number of bits of the output. Defaults to 8.

**Returns:**

- `Tuple[Callable, List[QuantizedArray], onnx.ModelProto]`: A tuple containing a function that takes a numpy array and returns a numpy array, the QuantizedArray objects used to quantize and de-quantize the tree's output, and the ONNX model.
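Conceptually, the returned callable evaluates the tree with plain comparisons and indexing over integer leaf values, and the quantizers turn those integers back into floats. The toy sketch below mimics that contract for a hard-coded decision stump; every name in it is illustrative, not the library's API:

```python
# Conceptual sketch of tree_to_numpy's output: a pure inference
# function over quantized leaf values, plus de-quantization parameters.

LEAF_VALUES_Q = [0, 255]           # 8-bit quantized leaf predictions
SCALE, ZERO_POINT = 1.0 / 255, 0.0  # toy quantization parameters

def stump_inference(x):
    """Quantized predictions for samples x (list of feature rows)."""
    return [LEAF_VALUES_Q[1] if row[0] > 0.5 else LEAF_VALUES_Q[0]
            for row in x]

def dequantize(qvalues):
    """De-quantize integer predictions back to floats."""
    return [q * SCALE + ZERO_POINT for q in qvalues]

q_preds = stump_inference([[0.2], [0.9]])
preds = dequantize(q_preds)        # approximately [0.0, 1.0]
```

Keeping the inference purely numerical (comparisons, indexing, integer arithmetic) is what lets the function later be compiled to operate on encrypted data.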