The-X NN is an extension of `torch.nn.Module` that adds the FHE (fully homomorphic encryption) features needed to build neural networks that can run inference on encrypted data.
- Adds `model.encrypt()` and `model.decrypt()` methods alongside the standard `model.train()` and `model.eval()`, letting the user choose between plain and encrypted data modes. When the data state changes, the submodules and functions in the model are switched to the corresponding state as well, which is achieved through the TenSEAL interface.
- Supports built-in agent models for non-polynomial functions that cannot be directly converted into TenSEAL CKKS format. These agent models approximate the non-polynomial functions with polynomial ones (see the sketch after this list).
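CKKS can only evaluate additions and multiplications, so a non-polynomial function such as ReLU has to be replaced by a polynomial stand-in. As a minimal, hypothetical illustration of the idea (not the library's actual agent implementation), one can least-squares-fit a low-degree polynomial to ReLU and evaluate that instead:

```python
import numpy as np

# Fit a degree-2 polynomial to ReLU on [-4, 4]; a polynomial like this
# can be evaluated homomorphically under CKKS, while ReLU itself cannot.
xs = np.linspace(-4, 4, 200)
coeffs = np.polyfit(xs, np.maximum(xs, 0), deg=2)

def relu_agent(x):
    """Polynomial stand-in ("agent") for ReLU."""
    return np.polyval(coeffs, x)

print(relu_agent(2.0))   # ~1.8, a rough approximation of ReLU(2) = 2
print(relu_agent(-2.0))  # ~-0.2, a rough approximation of ReLU(-2) = 0
```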
To use The-X NN, simply import the library and use it as you would any other PyTorch module. The added `encrypt()` and `decrypt()` methods switch between plain and encrypted data states, enabling secure inference on encrypted data.
```python
import numpy as np

from thex import xnn, cxt_man


class Model(xnn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = xnn.MultiLayerPerceptron(10, 5, activation=xnn.ReLU())
        self.fc2 = xnn.MultiLayerPerceptron(5, 2, activation=xnn.ReLU())

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


model = Model()
x = np.arange(2 * 3 * 10).reshape([2, 3, 10])  # feature size matches fc1

# encrypt the data into FHE ciphertexts
x_enc = cxt_man.encrypt(x)
# run inference on the encrypted data
y_enc = model(x_enc)
# decrypt to get the plain result
y_dec = cxt_man.decrypt(y_enc)
```
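Here `cxt_man` is the library's encryption context manager. As a rough sketch of what such a manager wraps (the parameter choices below are illustrative, not The-X's defaults), the standalone TenSEAL API looks like this:

```python
import numpy as np
import tenseal as ts

# Create a CKKS context; it holds the keys used for encrypt/decrypt.
ctx = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
ctx.global_scale = 2 ** 40
ctx.generate_galois_keys()

vec = np.arange(4, dtype=float)
enc = ts.ckks_vector(ctx, vec.tolist())   # encrypt
dec = np.array(enc.decrypt())             # decrypt back to plaintext
print(np.allclose(vec, dec, atol=1e-3))   # True -- CKKS is approximate
```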
The `xnn.Module` base class is a thin wrapper around `torch.nn.Module` that tracks the encryption state:

```python
import torch
import tenseal as ts


class Module(torch.nn.Module):
    """Wrapper class for torch.nn.Module to add FHE support."""

    def __init__(self, encryption_context=None):
        super(Module, self).__init__()
        self.is_encrypted = False
        self.encryption_context = encryption_context

    def encrypt(self, encryption_context=None):
        """Switches the module to FHE mode."""
        self.is_encrypted = True
        if encryption_context is not None:
            self.encryption_context = encryption_context
        elif self.encryption_context is None:
            # fall back to a default CKKS context
            self.encryption_context = ts.context(
                ts.SCHEME_TYPE.CKKS,
                poly_modulus_degree=4096,
                coeff_mod_bit_sizes=[40, 40, 40, 40],
            )

    def decrypt(self):
        """Switches the module back to plain mode."""
        self.is_encrypted = False
        self.encryption_context = None

    def forward(self, *inputs):
        """Abstract method that must be overridden by all subclasses."""
        raise NotImplementedError
```
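A subclass can then branch on `is_encrypted` to pick the plaintext or ciphertext code path. The layer below is a hypothetical example, not an operator from the library:

```python
class Square(Module):
    """Hypothetical layer: squares its input in either mode."""

    def forward(self, x):
        if self.is_encrypted:
            # TenSEAL CKKS vectors support ciphertext multiplication
            return x * x
        return x ** 2


layer = Square()
layer.encrypt()   # layer now expects TenSEAL ciphertexts as input
layer.decrypt()   # back to plain torch tensors
```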
The library provides FHE-ready versions of the operators needed to build a transformer (a composition sketch follows this list):

- `xnn.Linear`: a fully-connected layer operator, used for the linear transformation of input data.
- `xnn.MultiHeadAttention`: computes self-attention over the input, an essential part of the transformer model.
- `xnn.LayerNorm`: implements layer normalization, used in various parts of the transformer model.
- `xnn.ReLU`: a rectified linear unit operator, used as the activation function in various parts of the model.
- `xnn.Dropout`: provides regularization by randomly dropping out part of the input during training.
- `xnn.TransformerEncoderLayer`: defines a single layer of the transformer encoder, from which the full encoder can be built.
- `xnn.TransformerDecoderLayer`: defines a single layer of the transformer decoder, from which the full decoder can be built.
- `xnn.Softmax`: normalizes the attention scores computed from the dot product of the queries and keys.
- `xnn.Transpose`: transposes the output of a previous operator to match the expected shape.
- `xnn.MaskedFill`: masks out the padded values in the input sequence.
- `xnn.MatMul`: performs matrix multiplication, used to compute the final output of the attention mechanism.
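Assuming the constructors mirror their `torch.nn` counterparts (an assumption; the exact signatures are not documented here), these operators could compose into a minimal encoder block:

```python
from thex import xnn


class TinyEncoderLayer(xnn.Module):
    """Sketch of an encoder block built from the operators above."""

    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.attn = xnn.MultiHeadAttention(d_model, n_heads)  # assumed signature
        self.norm1 = xnn.LayerNorm(d_model)
        self.norm2 = xnn.LayerNorm(d_model)
        self.ff = xnn.Linear(d_model, d_model)
        self.act = xnn.ReLU()

    def forward(self, x):
        # residual self-attention followed by a feed-forward block
        x = self.norm1(x + self.attn(x, x, x))
        return self.norm2(x + self.act(self.ff(x)))
```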
- pretrained model -> origin model
- origin model -> converted model
- converted model -> TenSEAL-supported model
- Operators
- bert-toy enc test
- bert-tiny convert
- bert-tiny enc test
- benchmark of each operator