

module concrete.ml.pytest.utils

Common functions and lists for test files that cannot be put in fixtures.

Global Variables

  • sklearn_models_and_datasets



function get_random_extract_of_sklearn_models_and_datasets

get_random_extract_of_sklearn_models_and_datasets()

Return a random sublist of sklearn_models_and_datasets.

The sublist contains exactly one model of each kind.

Returns: The sublist.
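The "exactly one model of each kind" selection can be pictured as grouping entries by model class and drawing one entry per group. This is a minimal sketch, not Concrete ML's implementation. For illustration only, it assumes the entries are plain (model_class, parameters) tuples; the actual structure of sklearn_models_and_datasets may differ. The function random_extract_one_per_kind and the toy classes are hypothetical.

```python
import random
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical entry layout: (model_class, dataset/hyper-parameters)
Entry = Tuple[type, Dict[str, Any]]


def random_extract_one_per_kind(
    entries: List[Entry], seed: Optional[int] = None
) -> List[Entry]:
    """Return a random sublist containing exactly one entry per model class."""
    rng = random.Random(seed)

    # Group entries by model class; dicts keep the first-seen class order
    grouped: Dict[type, List[Entry]] = defaultdict(list)
    for entry in entries:
        grouped[entry[0]].append(entry)

    # Draw one entry at random for each class
    return [rng.choice(candidates) for candidates in grouped.values()]


# Usage with toy, hypothetical model classes
class LinearModel:
    pass


class TreeModel:
    pass


entries = [
    (LinearModel, {"n_samples": 100}),
    (LinearModel, {"n_samples": 200}),
    (TreeModel, {"n_classes": 2}),
]
extract = random_extract_one_per_kind(entries, seed=0)
```

Passing a seed makes the draw reproducible, which helps when a randomly selected test case fails and needs to be re-run.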



function instantiate_model_generic

Instantiate any Concrete ML model type.

Args:

  • model_class (class): The type of the model to instantiate.

  • n_bits (int): The number of quantization bits to use when initializing the model. For QNNs, default parameters are used based on whether n_bits is greater or smaller than 8.

  • parameters (dict): Hyper-parameters for the model instantiation. For QNNs, these parameters will override the matching default ones.

Returns:

  • model_name (str): The type of the model as a string.

  • model (object): The model instance.
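The QNN rule described above (choose a default set from n_bits, then let parameters override the matching entries) can be sketched as a dictionary merge. Everything in this sketch is hypothetical. The toy classes, the is_qnn marker used to detect QNNs, and the default values are placeholders, not Concrete ML's real defaults. The docstring does not say which branch n_bits == 8 falls into, so the sketch assumes the "8 or fewer" branch.

```python
from typing import Any, Dict, Optional, Tuple


class HypotheticalQNN:
    """Toy stand-in for a quantized neural network class."""

    is_qnn = True  # hypothetical marker used only by this sketch

    def __init__(self, n_bits: int, **params: Any) -> None:
        self.n_bits = n_bits
        self.params = params


class HypotheticalLinearModel:
    """Toy stand-in for any non-QNN model class."""

    def __init__(self, n_bits: int, **params: Any) -> None:
        self.n_bits = n_bits
        self.params = params


# Placeholder defaults, not Concrete ML's real values
QNN_DEFAULTS_HIGH_BITS: Dict[str, Any] = {"n_layers": 2, "n_epochs": 5}
QNN_DEFAULTS_LOW_BITS: Dict[str, Any] = {"n_layers": 3, "n_epochs": 10}


def instantiate_model_sketch(
    model_class: type, n_bits: int, parameters: Optional[Dict[str, Any]] = None
) -> Tuple[str, Any]:
    """Instantiate model_class, filling QNN defaults chosen from n_bits."""
    parameters = dict(parameters or {})

    if getattr(model_class, "is_qnn", False):
        defaults = QNN_DEFAULTS_HIGH_BITS if n_bits > 8 else QNN_DEFAULTS_LOW_BITS
        # User-supplied parameters override the matching defaults
        parameters = {**defaults, **parameters}

    model = model_class(n_bits=n_bits, **parameters)
    return model_class.__name__, model


name, model = instantiate_model_sketch(HypotheticalQNN, n_bits=4, parameters={"n_epochs": 1})
# name == "HypotheticalQNN", model.params == {"n_layers": 3, "n_epochs": 1}
```

Because the merge is `{**defaults, **parameters}`, only the keys a test passes explicitly replace the defaults. Every other default is kept.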



function get_torchvision_dataset

Get the training or testing data-set.

Args:

  • param (Dict): Set of hyper-parameters to use based on the selected torchvision data-set. It must contain the data-set transformations (torchvision.transforms.Compose) and the data-set size (Optional[int]).

  • train_set (bool): Use the training data-set if True, else the testing data-set.

Returns: A torchvision data-set.



function data_calibration_processing

Reduce the size of the given data-set.

Args:

  • data: The input container to consider.

  • n_sample (int): Number of samples to keep from the given data-set.

  • targets: If data is a torch.utils.data.Dataset, it typically contains both the inputs and the corresponding targets; in this case, targets must be set to None. If data is an instance of torch.Tensor or numpy.ndarray, targets is expected.

Returns:

  • Tuple[numpy.ndarray, numpy.ndarray]: The input data and the target (respectively x and y).

Raises:

  • TypeError: If data does not match any of the expected types.
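The dispatch described above (arrays or tensors need explicit targets, while dataset-like objects yield (x, y) pairs themselves) can be sketched with NumPy alone. This is a hypothetical illustration, not Concrete ML's code. It assumes the first n_sample items are kept; the docstring does not say whether the real helper takes the first samples or a random subset. Tensors are detected by duck typing on .detach()/.numpy() so that the sketch does not import torch. Raising ValueError for a missing or unexpected targets argument is this sketch's own choice.

```python
from typing import Any, Optional, Tuple

import numpy as np


def _to_numpy(value: Any) -> np.ndarray:
    """Convert tensors (duck-typed on .detach/.numpy) or array-likes to NumPy."""
    if hasattr(value, "detach") and hasattr(value, "numpy"):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def reduce_dataset_sketch(
    data: Any, n_sample: int, targets: Optional[Any] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Keep the first n_sample (x, y) pairs of data, returned as NumPy arrays."""
    is_tensor = hasattr(data, "detach") and hasattr(data, "numpy")

    if isinstance(data, np.ndarray) or is_tensor:
        # Array-like inputs carry only x, so targets must be given separately
        if targets is None:
            raise ValueError("targets is required when data is an array or tensor")
        x, y = _to_numpy(data), _to_numpy(targets)

    elif hasattr(data, "__getitem__") and hasattr(data, "__len__"):
        # Dataset-like inputs (e.g. a torch Dataset) yield (x, y) pairs themselves
        if targets is not None:
            raise ValueError("targets must be None when data yields (x, y) pairs")
        pairs = [data[i] for i in range(min(n_sample, len(data)))]
        x = np.stack([_to_numpy(sample) for sample, _ in pairs])
        y = np.asarray([_to_numpy(label) for _, label in pairs])

    else:
        raise TypeError(f"Unsupported data type: {type(data).__name__}")

    return x[:n_sample], y[:n_sample]


features = np.arange(20).reshape(10, 2)
labels = np.arange(10)
x_small, y_small = reduce_dataset_sketch(features, n_sample=4, targets=labels)
# x_small.shape == (4, 2), y_small.tolist() == [0, 1, 2, 3]
```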



function load_torch_model

Load an object saved with torch.save() from a file or dict.

Args:

  • model_class (torch.nn.Module): A PyTorch or Brevitas network.

  • state_dict_or_path (Optional[Union[str, Path, Dict[str, Any]]]): Path to a file saved with torch.save(), or the state_dict itself.

  • params (Dict): The model's parameters.

  • device (str): Device type.

Returns:

  • torch.nn.Module: A PyTorch or Brevitas network.



function values_are_equal

Indicate if two values are equal.

This method takes into account objects of type None, numpy.ndarray, numpy.floating, numpy.integer, numpy.random.RandomState or any instance that provides a __eq__ method.

Args:

  • value_1 (Any): The first value to consider.

  • value_2 (Any): The second value to consider.

Returns:

  • bool: Whether the two values are equal.
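The special cases listed above exist because plain `==` is not a usable boolean test for them. On arrays it is element-wise, and RandomState objects have no value-based equality. The sketch below is a hypothetical reimplementation that illustrates those cases. It is not Concrete ML's code, and its strictness choices (matching dtypes and exact NumPy scalar types) are assumptions.

```python
from typing import Any

import numpy as np

NUMPY_SCALARS = (np.floating, np.integer)


def values_are_equal_sketch(value_1: Any, value_2: Any) -> bool:
    """Compare two values, handling None, NumPy arrays/scalars and RandomState."""
    if value_1 is None or value_2 is None:
        return value_1 is None and value_2 is None

    if isinstance(value_1, np.ndarray) or isinstance(value_2, np.ndarray):
        # `==` on arrays is element-wise, so reduce it to a single bool and
        # also require matching dtypes
        return (
            isinstance(value_1, np.ndarray)
            and isinstance(value_2, np.ndarray)
            and value_1.dtype == value_2.dtype
            and np.array_equal(value_1, value_2)
        )

    if isinstance(value_1, NUMPY_SCALARS) or isinstance(value_2, NUMPY_SCALARS):
        # NumPy scalars must share the exact type (e.g. float32 vs float64)
        return type(value_1) is type(value_2) and bool(value_1 == value_2)

    if isinstance(value_1, np.random.RandomState) or isinstance(
        value_2, np.random.RandomState
    ):
        if not (
            isinstance(value_1, np.random.RandomState)
            and isinstance(value_2, np.random.RandomState)
        ):
            return False
        # RandomState objects compare by identity, so compare their internal states
        state_1, state_2 = value_1.get_state(), value_2.get_state()
        return len(state_1) == len(state_2) and all(
            values_are_equal_sketch(a, b) for a, b in zip(state_1, state_2)
        )

    # Fall back on the object's own __eq__
    return bool(value_1 == value_2)
```

For example, two RandomState objects built with the same seed are reported as equal. A float32 and a float64 holding the same number are reported as different.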



function check_serialization

Check that the given object can properly be serialized.

This function serializes the given object using the dump, dumps, load and loads functions from Concrete ML. If the given object provides dump and dumps methods, it is also serialized using these.

Args:

  • object_to_serialize (Any): The object to serialize.

  • expected_type (Type): The object's expected type.

  • equal_method (Optional[Callable]): The function to use to compare the two loaded objects. Defaults to values_are_equal.

  • check_str (bool): Whether the JSON strings should also be checked. Defaults to True.
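The round-trip pattern described above (string-based and file-based serialization, a type check, an equality check, and an optional string re-check) can be sketched with the standard json module. In this hypothetical sketch, json stands in for Concrete ML's own dump/dumps/load/loads functions, whose exact signatures are not documented here. operator.eq stands in for values_are_equal. The function name check_json_round_trip is illustrative only.

```python
import io
import json
import operator
from typing import Any, Callable, Optional, Type


def check_json_round_trip(
    object_to_serialize: Any,
    expected_type: Type,
    equal_method: Optional[Callable[[Any, Any], bool]] = None,
    check_str: bool = True,
) -> None:
    """Round-trip an object through json (string and file APIs) and check it."""
    if equal_method is None:
        # Stand-in for values_are_equal
        equal_method = operator.eq

    # String round trip, mirroring a dumps/loads pair
    dumped_str = json.dumps(object_to_serialize, sort_keys=True)
    loaded_from_str = json.loads(dumped_str)

    # File round trip, mirroring a dump/load pair
    buffer = io.StringIO()
    json.dump(object_to_serialize, buffer, sort_keys=True)
    buffer.seek(0)
    loaded_from_file = json.load(buffer)

    for loaded in (loaded_from_str, loaded_from_file):
        assert isinstance(loaded, expected_type), f"Got {type(loaded)}"
        assert equal_method(object_to_serialize, loaded), "Round trip changed the object"

        if check_str:
            # Re-serializing the loaded object must reproduce the same JSON string
            assert json.dumps(loaded, sort_keys=True) == dumped_str


check_json_round_trip({"n_bits": 8, "name": "model"}, dict)
```

The expected_type check catches lossy round trips. For instance, a tuple comes back from JSON as a list, so `check_json_round_trip((1, 2), tuple)` fails.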
