Using Torch
import brevitas.nn as qnn
import torch.nn as nn
import torch
# Number of input features the network expects (in_features of fc1).
N_FEAT = 12
# Bit-width shared by all quantizers below: activation quantization
# (QuantIdentity bit_width) and weight quantization (weight_bit_width).
n_bits = 3
class QATSimpleNet(nn.Module):
    """Small quantization-aware-training MLP built from Brevitas layers.

    The network is a 3-layer fully connected net: each linear layer uses
    ``n_bits``-bit quantized weights (biases left unquantized), and each
    activation is re-quantized to ``n_bits`` bits via a ``QuantIdentity``.
    Input has ``N_FEAT`` features; output has 2 units.
    """

    def __init__(self, n_hidden):
        super().__init__()
        # Quantize the raw input before it reaches the first linear layer.
        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(
            N_FEAT, n_hidden, bias=True, weight_bit_width=n_bits, bias_quant=None
        )
        # Re-quantize activations between layers so every layer input is quantized.
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc2 = qnn.QuantLinear(
            n_hidden, n_hidden, bias=True, weight_bit_width=n_bits, bias_quant=None
        )
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc3 = qnn.QuantLinear(
            n_hidden, 2, bias=True, weight_bit_width=n_bits, bias_quant=None
        )

    def forward(self, x):
        """Run the quantize -> linear -> relu -> quantize pipeline; return logits."""
        quantized_in = self.quant_inp(x)
        hidden = self.quant2(torch.relu(self.fc1(quantized_in)))
        hidden = self.quant3(torch.relu(self.fc2(hidden)))
        return self.fc3(hidden)
Configuring quantization parameters
target accumulator bit-width
activation bit-width
weight bit-width
number of active neurons
Running encrypted inference
Simulated FHE inference in the clear
Generic Quantization Aware Training import
Supported operators and activations
Operators
univariate operators
shape modifying operators
operators that take an encrypted input and unencrypted constants
operators that can take both encrypted+unencrypted and encrypted+encrypted inputs
Quantizers
Activations
Last updated
Was this helpful?