Using Torch
import brevitas.nn as qnn
import torch.nn as nn
import torch
# Model hyper-parameters shared by the example network below:
#   N_FEAT - number of input features (in_features of the first QuantLinear)
#   n_bits - bit width passed to the Brevitas quantizers
N_FEAT, n_bits = 12, 3
class QATSimpleNet(nn.Module):
    """Three-layer fully connected network for quantization-aware training.

    The input and both hidden activations are quantized with
    ``qnn.QuantIdentity``, and every ``qnn.QuantLinear`` quantizes its
    weights. All of these use the module-level ``n_bits``. No bias
    quantizer is configured (``bias_quant=None``). The network maps
    ``N_FEAT`` input features to 2 outputs.

    Args:
        n_hidden: Width of both hidden layers.
    """

    def __init__(self, n_hidden: int):
        super().__init__()

        self.quant_inp = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.fc1 = qnn.QuantLinear(N_FEAT, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant2 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        # Use the shared n_bits rather than a hard-coded 3 so that every layer
        # follows the same precision setting.
        self.fc2 = qnn.QuantLinear(n_hidden, n_hidden, True, weight_bit_width=n_bits, bias_quant=None)
        self.quant3 = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        # The weight bit width is n_bits. It was previously n_hidden, which tied
        # weight precision to the layer width (e.g. 100-bit weights for
        # n_hidden=100).
        self.fc3 = qnn.QuantLinear(n_hidden, 2, True, weight_bit_width=n_bits, bias_quant=None)

    def forward(self, x):
        """Run the quantized forward pass.

        Args:
            x: Input batch. Presumably shaped (batch, N_FEAT); confirm with
                the caller.

        Returns:
            The output of ``fc3``: 2 values per sample. No QuantIdentity is
            applied after this final layer.
        """
        x = self.quant_inp(x)
        # Apply ReLU to each hidden layer's output, then re-quantize it to
        # n_bits before it feeds the next QuantLinear.
        x = self.quant2(torch.relu(self.fc1(x)))
        x = self.quant3(torch.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


# Next section: Generic Quantization Aware Training import
Supported operators and activations
Operators
Univariate operators
Shape-modifying operators
Operators that take an encrypted input and unencrypted constants
Operators that can take both encrypted + unencrypted and encrypted + encrypted inputs
Quantizers
Activations
Last updated
Was this helpful?