Truncating
import matplotlib.pyplot as plt
import numpy as np
from concrete import fhe
# Bit width of the demo inputs: every integer in [0, 2**original_bit_width)
# is passed through fhe.truncate_bit_pattern.
original_bit_width = 5
# Number of least-significant bits to clear in each value.
lsbs_to_remove = 2

# Explicit check rather than `assert`, so the guard still runs under `python -O`.
# At least one bit must be removed and at least one bit must be kept.
if not 0 < lsbs_to_remove < original_bit_width:
    raise ValueError(
        f"lsbs_to_remove must satisfy 0 < lsbs_to_remove < {original_bit_width}, "
        f"got {lsbs_to_remove}"
    )

# All representable inputs, and their truncated counterparts in the same order.
original_values = list(range(2**original_bit_width))
truncated_values = [
    fhe.truncate_bit_pattern(value, lsbs_to_remove)
    for value in original_values
]
# Print each input next to its truncated value. The binary form is split into the
# kept high bits and the [removed low bits]. A blank line is printed whenever the
# truncated value changes, grouping inputs that truncate to the same result.
previous_truncated = truncated_values[0]
for original, truncated in zip(original_values, truncated_values):
    if truncated != previous_truncated:
        previous_truncated = truncated
        print()

    # NOTE(review): the display width is one bit wider than original_bit_width,
    # so every input is shown with an extra leading 0 — presumably deliberate
    # padding for readability; confirm.
    original_binary = np.binary_repr(original, width=(original_bit_width + 1))
    truncated_binary = np.binary_repr(truncated, width=(original_bit_width + 1))

    print(
        f"{original:2} = 0b_{original_binary[:-lsbs_to_remove]}[{original_binary[-lsbs_to_remove:]}] "
        f"=> "
        f"0b_{truncated_binary[:-lsbs_to_remove]}[{truncated_binary[-lsbs_to_remove:]}] = {truncated}"
    )
# Overlay the original values (the identity line) and the truncated values on a
# single axes, drawn through the Axes object rather than pyplot's implicit state.
fig = plt.figure()
ax = fig.add_subplot()

ax.plot(original_values, original_values, color="black", label="original")
ax.plot(original_values, truncated_values, color="green", label="truncated")
ax.legend()

# Square aspect ratio keeps the identity line at 45 degrees.
ax.set_aspect("equal", adjustable="box")
plt.show()

Auto Truncators

Last updated
Was this helpful?