# Optimize table lookups
import time
import numpy as np
import matplotlib.pyplot as plt
from concrete import fhe
def f(x):
    """Halve ``x`` with floor division — the single-TLU circuit under benchmark."""
    return x // 2
# Compile the same lookup circuit at increasing bit widths and record how
# circuit complexity and per-call latency grow.  `bit_widths`, `complexities`
# and `timings` stay module-level: the plotting code below reads them.
bit_widths = list(range(2, 9))
complexities = []
timings = []
for bit_width in bit_widths:
    # Random samples covering the full unsigned range of this bit width.
    inputset = fhe.inputset(lambda _: np.random.randint(0, 2 ** bit_width))
    compiler = fhe.Compiler(f, {"x": "encrypted"})
    circuit = compiler.compile(inputset)
    circuit.keygen()
    # Warm up: the first few runs pay one-off costs we don't want to measure.
    for sample in inputset[:3]:
        circuit.encrypt_run_decrypt(*sample)
    current_timings = []
    for sample in inputset[3:13]:
        # perf_counter is the monotonic high-resolution clock recommended
        # for interval timing (time.time can jump with wall-clock changes).
        start = time.perf_counter()
        result = circuit.encrypt_run_decrypt(*sample)
        end = time.perf_counter()
        # Sanity-check: the homomorphic result must match the clear function.
        assert np.array_equal(result, f(*sample))
        current_timings.append(end - start)
    complexities.append(int(circuit.complexity))
    timings.append(float(np.mean(current_timings)))
    print(f"{bit_width} bits -> {complexities[-1]:>13_} complexity -> {timings[-1]:.06f}s")
# Plot complexity (left axis, red) and execution time (right axis, blue)
# against bit width on a shared x-axis.
fig, left_axis = plt.subplots()

left_color = "tab:red"
left_axis.set_xlabel("bit width")
left_axis.set_ylabel("complexity", color=left_color)
left_axis.plot(bit_widths, complexities, color=left_color)
left_axis.tick_params(axis="y", labelcolor=left_color)

right_axis = left_axis.twinx()  # second y-axis sharing the same x-axis
right_color = "tab:blue"
right_axis.set_ylabel("execution time", color=right_color)
right_axis.plot(bit_widths, timings, color=right_color)
right_axis.tick_params(axis="y", labelcolor=right_color)

fig.tight_layout()
plt.show()
