Rounding
import matplotlib.pyplot as plt
import numpy as np
from concrete import fhe
# Demo parameters: inputs span ``original_bit_width`` bits, and rounding
# discards the ``lsbs_to_remove`` least-significant bits of each input.
original_bit_width = 5
lsbs_to_remove = 2

# At least one bit must be removed and at least one bit must be kept.
assert 0 < lsbs_to_remove < original_bit_width

# Enumerate every unsigned value that fits in ``original_bit_width`` bits,
# then round each one with Concrete's ``fhe.round_bit_pattern``.
original_values = list(range(1 << original_bit_width))
rounded_values = list(
    map(lambda v: fhe.round_bit_pattern(v, lsbs_to_remove), original_values)
)
# Print every input beside its rounded value, showing the removed LSBs in
# square brackets. A blank line separates consecutive runs of inputs that map
# to the same rounded value.
#
# The binary strings are one bit wider than the inputs. The extra bit leaves
# room for a carry when a value rounds up past 2**original_bit_width - 1
# (assumes fhe.round_bit_pattern rounds to nearest -- confirm in Concrete docs).
binary_width = original_bit_width + 1

previous_rounded = rounded_values[0]
for original, rounded in zip(original_values, rounded_values):
    # Start a new visual group whenever the rounded value changes.
    if rounded != previous_rounded:
        previous_rounded = rounded
        print()

    original_binary = np.binary_repr(original, width=binary_width)
    rounded_binary = np.binary_repr(rounded, width=binary_width)

    print(
        f"{original:2} = 0b_{original_binary[:-lsbs_to_remove]}[{original_binary[-lsbs_to_remove:]}] "
        f"=> "
        f"0b_{rounded_binary[:-lsbs_to_remove]}[{rounded_binary[-lsbs_to_remove:]}] = {rounded}"
    )
# Plot the identity line (original values) against the stepped rounded values
# on a single square-aspect axes, so the rounding staircase is easy to see.
fig = plt.figure()
ax = fig.add_subplot()

ax.plot(original_values, original_values, label="original", color="black")
ax.plot(original_values, rounded_values, label="rounded", color="green")
ax.legend()

# Equal scaling on both axes keeps the identity line at 45 degrees.
ax.set_aspect("equal", adjustable="box")
plt.show()

Auto Rounders

Last updated
Was this helpful?