module concrete.ml.torch.hybrid_model
Implement the conversion of a torch model to hybrid FHE/torch inference.
Global Variables
MAX_BITWIDTH_BACKWARD_COMPATIBLE
function tuple_to_underscore_str
tuple_to_underscore_str(tup: Tuple) → str
Convert a tuple to a string representation.
Args:
tup (Tuple): a tuple to change into string representation
Returns:
str: a string representing the tuple
function underscore_str_to_tuple
Convert a string representation of a tuple back to a tuple.
Args:
tup (str): a string representing the tuple
Returns:
Tuple: the tuple represented by the string
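Since the two functions above are documented as inverses, a round trip is a quick way to check them. The sketch below is a minimal, hypothetical re-implementation that assumes tuples of integers (such as tensor shapes); the exact string format produced by Concrete ML may differ.

```python
import ast
from typing import Tuple


def tuple_to_underscore_str(tup: Tuple) -> str:
    # Hypothetical encoding: (1, 8, 768) -> "po_1_8_768_pc"
    # ("po" / "pc" stand in for the opening / closing parentheses)
    return str(tup).replace("(", "po_").replace(")", "_pc").replace(", ", "_")


def underscore_str_to_tuple(tup: str) -> Tuple:
    # Restore the parenthesis markers first so their underscores are not
    # mistaken for separators, then parse the literal safely
    text = tup.replace("po_", "(").replace("_pc", ")").replace("_", ", ")
    return ast.literal_eval(text)


shape = (1, 8, 768)
encoded = tuple_to_underscore_str(shape)
print(encoded)  # po_1_8_768_pc
assert underscore_str_to_tuple(encoded) == shape
```

Avoiding spaces, commas and parentheses keeps the result usable as a file or directory name, which is convenient if shapes are stored as path components, as the input_shape (str) arguments of HybridFHEModelServer suggest.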
function convert_conv1d_to_linear
Convert all Conv1D layers in a module, or a single Conv1D layer, to nn.Linear.
Args:
layer_or_module (nn.Module or Conv1D): The module to search recursively for Conv1D layers, or a Conv1D layer itself.
Returns:
nn.Module or nn.Linear: The updated module with Conv1D layers converted to Linear layers, or the Conv1D layer converted to a Linear layer.
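Assuming Conv1D here is the Hugging Face transformers layer (as used in GPT-2), the conversion amounts to transposing the weight: Conv1D stores its weight as (in_features, out_features) and computes x @ W + b, whereas nn.Linear stores (out_features, in_features) and computes x @ W.T + b. The torch-free sketch below checks that equivalence on plain lists; conv1d_forward, linear_forward and transpose are hypothetical helpers.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def conv1d_forward(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # Conv1D-style layer: weight has shape (in_features, out_features), y = x @ W + b
    return [
        sum(x[i] * weight[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def linear_forward(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # nn.Linear-style layer: weight has shape (out_features, in_features), y = x @ W.T + b
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


conv_weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # 2 inputs -> 3 outputs
bias = [0.5, 0.0, -1.0]
x = [1.0, 2.0]

# Conv1D -> Linear: transpose the weight, keep the bias unchanged
linear_weight = transpose(conv_weight)
print(conv1d_forward(x, conv_weight, bias))    # [9.5, 12.0, 14.0]
print(linear_forward(x, linear_weight, bias))  # [9.5, 12.0, 14.0]
```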
class HybridFHEMode
Simple enum for the different execution modes of HybridFHEModel.
class RemoteModule
A wrapper class for the modules to be evaluated remotely with FHE.
method __init__
method forward
Forward pass of the remote module.
To change the behavior of this forward function, set the fhe_local_mode attribute. The choices are:
disable: forward using torch module
remote: forward with fhe client-server
simulate: forward with local fhe simulation
calibrate: forward for calibration
Args:
x (torch.Tensor): The input tensor.
Returns:
(torch.Tensor): The output tensor.
Raises:
ValueError: if fhe_local_mode is not supported
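A simplified, torch-free sketch of how forward can dispatch on fhe_local_mode. RemoteModuleSketch, its constructor arguments and calibration_data are hypothetical stand-ins, and recording inputs during calibration so they can later be used for compilation is an assumption about what "calibrate" does.

```python
from enum import Enum
from typing import Any, Callable, List


class HybridFHEMode(Enum):
    # Values mirror the choices listed for forward (sketch only)
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class RemoteModuleSketch:
    """Hypothetical stand-in for RemoteModule, showing mode-based dispatch."""

    def __init__(
        self,
        private_module: Callable[[Any], Any],
        remote_call: Callable[[Any], Any],
        simulate: Callable[[Any], Any],
    ):
        self.private_module = private_module  # computation in the clear
        self.remote_call = remote_call        # client-server FHE round trip
        self.simulate = simulate              # local FHE simulation
        self.calibration_data: List[Any] = []
        self.fhe_local_mode = HybridFHEMode.CALIBRATE

    def forward(self, x: Any) -> Any:
        mode = self.fhe_local_mode
        if mode == HybridFHEMode.DISABLE:
            return self.private_module(x)
        if mode == HybridFHEMode.REMOTE:
            return self.remote_call(x)
        if mode == HybridFHEMode.SIMULATE:
            return self.simulate(x)
        if mode == HybridFHEMode.CALIBRATE:
            # Keep the input for later compilation, then compute in the clear
            self.calibration_data.append(x)
            return self.private_module(x)
        raise ValueError(f"{mode} is not supported")


module = RemoteModuleSketch(
    private_module=lambda x: 2 * x,
    remote_call=lambda x: 2 * x,
    simulate=lambda x: 2 * x,
)
print(module.forward(3))  # 6, and the input 3 is recorded for calibration
module.fhe_local_mode = HybridFHEMode.SIMULATE
print(module.forward(4))  # 8
```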
method init_fhe_client
Set the client's keys.
Args:
path_to_client (str): Path where the client.zip is located.
path_to_keys (str): Path where the keys are located.
Raises:
ValueError: if anything goes wrong with the server.
method remote_call
Call the remote server to get the private module inference.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The result of the FHE computation
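The round trip behind remote_call can be pictured as follows. This is a hypothetical sketch: encrypt, decrypt and send stand in for the real client-side encryption and the network transport, and the payload fields simply mirror the arguments documented for HybridFHEModelServer.compute.

```python
from typing import Callable, Dict, List


def remote_call_sketch(
    x: List[int],
    encrypt: Callable[[List[int]], bytes],
    decrypt: Callable[[bytes], List[int]],
    send: Callable[[Dict[str, object]], bytes],
    uid: str,
    model_name: str,
    module_name: str,
    input_shape: str,
) -> List[int]:
    # 1. Encrypt the (already quantized) input on the client
    payload: Dict[str, object] = {
        "model_input": encrypt(x),
        "uid": uid,  # identifies the public key previously registered with add_key
        "model_name": model_name,
        "module_name": module_name,
        "input_shape": input_shape,
    }
    # 2. The server evaluates the circuit and returns an encrypted result
    encrypted_result = send(payload)
    # 3. Only the client, which holds the private key, can decrypt it
    return decrypt(encrypted_result)


# Toy stand-ins: no real encryption, and the "server" doubles each value
result = remote_call_sketch(
    [1, 2, 3],
    encrypt=bytes,
    decrypt=list,
    send=lambda payload: bytes(2 * v for v in payload["model_input"]),
    uid="1234",
    model_name="my_model",
    module_name="fc1",
    input_shape="(1, 3)",
)
print(result)  # [2, 4, 6]
```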
class HybridFHEModel
Convert a model to a hybrid model.
This is done by replacing the targeted modules with RemoteModules. The model is modified in place.
Args:
model (nn.Module): The model to modify (in-place modification)
module_names (Union[str, List[str]]): The module name(s) to replace with FHE server.
server_remote_address: The remote address of the FHE server
model_name (str): Model name identifier
verbose (int): Whether logs should be printed when interacting with the FHE server
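To illustrate what replacing the targeted modules in place means, here is a torch-free sketch that resolves dotted module names (as in encoder.fc1) and swaps each target for a wrapper. replace_modules and the SimpleNamespace model are hypothetical; on a real nn.Module, a similar lookup could use get_submodule followed by setattr on the parent.

```python
from types import SimpleNamespace
from typing import Any, Callable, List, Union


def replace_modules(
    model: Any,
    module_names: Union[str, List[str]],
    wrap: Callable[[Any], Any],
) -> Any:
    # Accept either a single name or a list, like the module_names argument
    names = [module_names] if isinstance(module_names, str) else list(module_names)
    for name in names:
        *parent_path, child = name.split(".")
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)
        # In-place replacement: the original object now holds the wrapper
        setattr(parent, child, wrap(getattr(parent, child)))
    return model


model = SimpleNamespace(
    encoder=SimpleNamespace(fc1="linear_1", fc2="linear_2"), head="linear_3"
)
replace_modules(model, ["encoder.fc1", "head"], lambda m: ("remote", m))
print(model.encoder.fc1, model.encoder.fc2, model.head)
# ('remote', 'linear_1') linear_2 ('remote', 'linear_3')
```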
method __init__
method compile_model
Compile the specified layers to FHE.
Args:
x (torch.Tensor): The input tensor for the model. This is used to run the model once for calibration.
n_bits (int): The bit precision for quantization during FHE model compilation. Default is 8.
rounding_threshold_bits (int): The number of bits to use for the rounding threshold during FHE model compilation. Default is 8.
p_error (float): Error allowed for each table look-up in the circuit.
configuration (Configuration): A Concrete Configuration object specifying the FHE encryption parameters. If not specified, a default configuration is used.
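To give a feel for what n_bits and rounding_threshold_bits control, the sketch below applies plain uniform (affine) quantization to a few calibration values and rounds an integer accumulator to its most significant bits, which is roughly the idea behind a rounding threshold. It is a simplified illustration, not Concrete ML's quantizer; quantize, dequantize and round_to_msbs are hypothetical helpers.

```python
from typing import List, Tuple


def quantize(values: List[float], n_bits: int) -> Tuple[List[int], float, int]:
    # Affine quantization: map [min, max] onto the integers [0, 2**n_bits - 1]
    lo, hi = min(values), max(values)
    q_max = 2**n_bits - 1
    scale = (hi - lo) / q_max if hi > lo else 1.0
    zero_point = round(-lo / scale)
    # Clip to the valid integer range after shifting by the zero point
    q = [min(max(round(v / scale) + zero_point, 0), q_max) for v in values]
    return q, scale, zero_point


def dequantize(q: List[int], scale: float, zero_point: int) -> List[float]:
    return [(qi - zero_point) * scale for qi in q]


def round_to_msbs(acc: int, acc_bits: int, keep_bits: int) -> int:
    # Round a non-negative acc_bits-wide integer to its keep_bits most significant bits
    shift = acc_bits - keep_bits
    if shift <= 0:
        return acc
    return ((acc + (1 << (shift - 1))) >> shift) << shift


values = [-1.0, -0.3, 0.2, 0.9]
q, scale, zp = quantize(values, n_bits=8)
errors = [abs(a - b) for a, b in zip(values, dequantize(q, scale, zp))]
print(q)                         # [0, 94, 161, 255]
print(max(errors) <= scale / 2)  # True
print(round_to_msbs(183, acc_bits=8, keep_bits=4))  # 176
```

Fewer bits mean a coarser grid (larger scale) and smaller integers, which is what makes FHE evaluation cheaper at the cost of precision.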
method init_client
Initialize the client for all remote modules.
Args:
path_to_clients (Optional[Path]): Path to the client.zip files.
path_to_keys (Optional[Path]): Path to the keys folder.
method publish_to_hub
Allow the user to push the model and the required FHE files to the Hugging Face Hub.
method save_and_clear_private_info
Save the PyTorch model to the provided path, along with the corresponding FHE circuit.
Args:
path (Path): The directory where the model and the FHE circuit will be saved.
via_mlir (bool): Whether the FHE circuits should be serialized using the via_mlir option, which is useful for cross-platform use (compiling on one architecture and running on another).
method set_fhe_mode
Set the hybrid FHE mode for all remote modules.
Args:
hybrid_fhe_mode (Union[str, HybridFHEMode]): Hybrid FHE mode to set for all remote modules.
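Because hybrid_fhe_mode accepts either a string or a HybridFHEMode, a natural approach is to normalize the value once and apply it to every remote module. The sketch assumes the enum values are the lowercase mode names listed for RemoteModule.forward; RemoteModuleStub and the dictionary of modules are hypothetical.

```python
from enum import Enum
from typing import Dict, Union


class HybridFHEMode(Enum):
    # Values mirror the choices listed for RemoteModule.forward (sketch only)
    DISABLE = "disable"
    REMOTE = "remote"
    SIMULATE = "simulate"
    CALIBRATE = "calibrate"


class RemoteModuleStub:
    """Hypothetical stand-in that only carries the mode attribute."""

    def __init__(self) -> None:
        self.fhe_local_mode = HybridFHEMode.CALIBRATE


def set_fhe_mode(
    remote_modules: Dict[str, RemoteModuleStub],
    hybrid_fhe_mode: Union[str, HybridFHEMode],
) -> None:
    # HybridFHEMode("simulate") looks the member up by value; passing a member
    # returns it unchanged, and an unknown string raises ValueError
    mode = HybridFHEMode(hybrid_fhe_mode)
    for module in remote_modules.values():
        module.fhe_local_mode = mode


modules = {"encoder.fc1": RemoteModuleStub(), "head": RemoteModuleStub()}
set_fhe_mode(modules, "simulate")
print({name: m.fhe_local_mode.value for name, m in modules.items()})
# {'encoder.fc1': 'simulate', 'head': 'simulate'}
```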
class LoggerStub
Placeholder type for a typical logger, such as the one from loguru.
method info
Placeholder function for logger.info.
Args:
msg (str): the message to output
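A placeholder type like this lets any object with an info(msg) method be passed where a logger is expected. One way to express that is a typing.Protocol, as sketched below; this structural-typing variant, LoggerLike, ListLogger and report are illustrative assumptions, not necessarily how LoggerStub itself is defined.

```python
from typing import List, Protocol


class LoggerLike(Protocol):
    # Any object with a matching info method satisfies this type
    def info(self, msg: str) -> None: ...


class ListLogger:
    """Hypothetical logger that keeps messages in memory (no inheritance needed)."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def info(self, msg: str) -> None:
        self.messages.append(msg)


def report(logger: LoggerLike, uid: str) -> None:
    # Only the info method is used, so any structurally matching logger works
    logger.info(f"Registered key {uid}")


log = ListLogger()
report(log, "1234")
print(log.messages)  # ['Registered key 1234']
```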
class HybridFHEModelServer
Hybrid FHE Model Server.
This class serves FHE models serialized using HybridFHEModel.
method __init__
method add_key
Add a public key.
Args:
key (bytes): public key
model_name (str): model name
module_name (str): name of the module in the model
input_shape (str): input shape of said module
Returns:
Dict[str, str]: a dictionary mapping "uid" to the personal uid assigned to the key
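A minimal sketch of the bookkeeping add_key implies: the server assigns a fresh uid to each incoming public key, stores the key under that uid, and returns it in a dictionary. Storing each key as a file named after its uid in a key directory is an assumption (load_key does state that keys are read from the file system); KeyStoreSketch is hypothetical, and the model, module and shape arguments are omitted for brevity.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Dict


class KeyStoreSketch:
    """Hypothetical server-side key storage keyed by uid."""

    def __init__(self, key_dir: Path) -> None:
        self.key_dir = key_dir
        self.key_dir.mkdir(parents=True, exist_ok=True)

    def add_key(self, key: bytes) -> Dict[str, str]:
        uid = uuid.uuid4()  # random identifier for this client's key
        (self.key_dir / str(uid)).write_bytes(key)
        return {"uid": str(uid)}


with tempfile.TemporaryDirectory() as tmp:
    store = KeyStoreSketch(Path(tmp) / "keys")
    answer = store.add_key(b"serialized-public-key")
    stored = (Path(tmp) / "keys" / answer["uid"]).read_bytes()
    print(stored)  # b'serialized-public-key'
```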
method check_inputs
Check that the given configuration exists in the compiled models folder.
Args:
model_name (str): name of the model
module_name (Optional[str]): name of the module in the model
input_shape (Optional[str]): input shape of the module
Raises:
ValueError: if the given configuration does not exist.
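Assuming the compiled models folder has one directory level per argument (model_name/module_name/input_shape), a check like the one described can walk down the tree and stop at the first missing level. The layout and check_inputs_sketch are assumptions for illustration.

```python
import tempfile
from pathlib import Path
from typing import Optional


def check_inputs_sketch(
    base_path: Path,
    model_name: str,
    module_name: Optional[str] = None,
    input_shape: Optional[str] = None,
) -> None:
    # Assumed layout: base_path / model_name / module_name / input_shape
    path = base_path / model_name
    if not path.is_dir():
        raise ValueError(f"Model {model_name!r} not found")
    if module_name is None:
        return
    path = path / module_name
    if not path.is_dir():
        raise ValueError(f"Module {module_name!r} not found in {model_name!r}")
    if input_shape is None:
        return
    if not (path / input_shape).is_dir():
        raise ValueError(f"Shape {input_shape!r} not found for {module_name!r}")


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    (base / "my_model" / "fc1" / "1_8").mkdir(parents=True)
    check_inputs_sketch(base, "my_model", "fc1", "1_8")  # passes silently
    try:
        check_inputs_sketch(base, "my_model", "fc2")
    except ValueError as err:
        print(err)  # Module 'fc2' not found in 'my_model'
```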
method compute
Compute the circuit over the encrypted input.
Args:
model_input (bytes): input of the circuit
uid (str): uid of the public key to use
model_name (str): model name
module_name (str): name of the module in the model
input_shape (str): input shape of said module
Returns:
bytes: the result of the circuit
method dump_key
Dump a public key to a stream.
Args:
key_bytes (bytes): the serialized public key to dump
uid (Union[str, uuid.UUID]): uid of the public key to dump
method get_circuit
Get the circuit based on the model name, module name and input shape.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
input_shape (str): input shape of the module
Returns:
FHEModelServer: an FHE model server for the given module of the given model, for the given input shape
method get_client
Get the client.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
input_shape (str): input shape of the module
Returns:
Path: the path to the correct client
Raises:
ValueError: if client couldn't be found
method list_modules
List all modules in a model.
Args:
model_name (str): name of the model
Returns: Dict[str, Dict[str, Dict]]
method list_shapes
List all input shapes of a module in a model.
Args:
model_name (str): name of the model
module_name (str): name of the module in the model
Returns: Dict[str, Dict]
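Under an assumed folder layout of model_name/module_name/input_shape, listing modules and shapes reduces to reading directory names. The nested dictionaries below only mimic the documented Dict[str, Dict[str, Dict]] and Dict[str, Dict] return types; what the innermost dictionaries contain is an assumption (left empty here), and both functions are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Dict


def list_shapes_sketch(base_path: Path, model_name: str, module_name: str) -> Dict[str, Dict]:
    # One entry per input-shape directory of the module
    module_dir = base_path / model_name / module_name
    return {p.name: {} for p in sorted(module_dir.iterdir()) if p.is_dir()}


def list_modules_sketch(base_path: Path, model_name: str) -> Dict[str, Dict[str, Dict]]:
    # One entry per module directory, each holding that module's shapes
    model_dir = base_path / model_name
    return {
        p.name: list_shapes_sketch(base_path, model_name, p.name)
        for p in sorted(model_dir.iterdir())
        if p.is_dir()
    }


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    for module, shape in [("fc1", "1_8"), ("fc1", "1_16"), ("fc2", "1_8")]:
        (base / "my_model" / module / shape).mkdir(parents=True)
    modules = list_modules_sketch(base, "my_model")
    print(modules)
    # {'fc1': {'1_16': {}, '1_8': {}}, 'fc2': {'1_8': {}}}
```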
method load_key
Load a public key from the key path in the file system.
Args:
uid (Union[str, uuid.UUID]): uid of the public key to load
Returns:
bytes: the bytes of the public key
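dump_key and load_key act as inverses around the key storage: one writes the serialized public key for a uid, the other reads it back. The sketch assumes one file per uid, named after the uid, in a key directory; dump_key_sketch, load_key_sketch and the directory are hypothetical. Converting the uid with str() lets both str and uuid.UUID values be accepted, provided string uids use the canonical form produced by str(uuid).

```python
import tempfile
import uuid
from pathlib import Path
from typing import Union


def dump_key_sketch(key_dir: Path, key_bytes: bytes, uid: Union[str, uuid.UUID]) -> None:
    # str() maps a UUID object to its canonical hyphenated text
    (key_dir / str(uid)).write_bytes(key_bytes)


def load_key_sketch(key_dir: Path, uid: Union[str, uuid.UUID]) -> bytes:
    return (key_dir / str(uid)).read_bytes()


with tempfile.TemporaryDirectory() as tmp:
    key_dir = Path(tmp)
    uid = uuid.uuid4()
    dump_key_sketch(key_dir, b"public-key-bytes", uid)
    # Loading works with either the UUID object or its string form
    loaded = load_key_sketch(key_dir, str(uid))
    print(loaded)  # b'public-key-bytes'
```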