
Matrix Multiplication

Since a ciphertext encodes a vector, matrix–vector multiplication can be performed with basic homomorphic encryption operations alone: multiply the ciphertext slot by slot with each row of the matrix, sum the slots of each product (which requires rotations) to obtain one entry of the result, and combine the entries. However, this naive method needs a large number of multiplications and rotations, which makes it slow and memory-intensive, so the DESILO FHE library provides optimized matrix multiplication operations. This document explains how to use them to perform matrix multiplication more efficiently.
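To see where the cost of the naive method comes from, it can be modeled in plain NumPy by treating an array as the slots of a ciphertext and np.roll as a rotation. This is an illustrative simulation, not library code; the helpers rotate and naive_matvec are hypothetical. Each row costs one slotwise multiplication plus log2(n) rotate-and-add steps to sum the slots, so a 64 by 64 matrix already needs 384 rotations.

```python
import numpy as np


def rotate(v: np.ndarray, k: int) -> np.ndarray:
    """Cyclic left rotation by k slots, standing in for a ciphertext rotation."""
    return np.roll(v, -k)


def naive_matvec(matrix, vector):
    """Row-by-row matrix-vector product using only slotwise ops and rotations.

    Returns (result, number_of_rotations). Assumes len(vector) is a power of two.
    """
    matrix = np.asarray(matrix, dtype=float)
    vector = np.asarray(vector, dtype=float)
    n = len(vector)
    result = np.zeros(n)
    rotations = 0
    for i, row in enumerate(matrix):
        prod = row * vector  # one slotwise multiplication per row
        # Rotate-and-add log2(n) times so that every slot holds the full sum.
        step = 1
        while step < n:
            prod = prod + rotate(prod, step)
            rotations += 1
            step *= 2
        # Keep the dot product only in slot i (a mask multiplication in FHE).
        mask = np.zeros(n)
        mask[i] = 1.0
        result = result + prod * mask
    return result, rotations


matrix = np.arange(64 * 64).reshape(64, 64)
result, rotations = naive_matvec(matrix, np.full(64, 2.0))
assert np.allclose(result, matrix @ np.full(64, 2.0))
print(rotations)  # 384, i.e. 64 rows x log2(64) rotations each
```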

Setting the engine slot count

Optimized matrix multiplication requires a square input matrix. Specifically, for efficient computation in the encrypted state, the engine expects a slot_count by slot_count matrix. Because computation time and memory usage grow at least quadratically with the slot count, it is important to set the engine's slot_count to match the matrix size.

from desilofhe import Engine

engine = Engine(slot_count=64)  # Supports matrices of size 64x64
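If a matrix is smaller than the slot count you want to use, one common approach (an assumption here, not something this page prescribes) is to zero-pad the matrix and the vector up to slot_count: the extra zero rows and columns leave the first n entries of the product unchanged. The helpers smallest_slot_count and pad_to_slot_count are hypothetical, and the power-of-two rule they assume reflects how CKKS-style slot counts usually work; check which slot_count values Engine actually accepts.

```python
import numpy as np


def smallest_slot_count(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (assumes power-of-two slot counts)."""
    return 1 << (n - 1).bit_length()


def pad_to_slot_count(matrix, vector, slot_count: int):
    """Zero-pad an n x n matrix and a length-n vector up to slot_count."""
    matrix = np.asarray(matrix, dtype=float)
    vector = np.asarray(vector, dtype=float)
    n = len(vector)
    if matrix.shape != (n, n):
        raise ValueError("expected an n x n matrix and a length-n vector")
    if n > slot_count:
        raise ValueError(f"size {n} exceeds slot_count {slot_count}")
    padded_matrix = np.zeros((slot_count, slot_count))
    padded_matrix[:n, :n] = matrix
    padded_vector = np.zeros(slot_count)
    padded_vector[:n] = vector
    return padded_matrix, padded_vector


m = np.arange(50 * 50).reshape(50, 50)
v = np.ones(50)
slot_count = smallest_slot_count(50)  # 64
pm, pv = pad_to_slot_count(m, v, slot_count)
# The first 50 entries of the padded product equal the original product.
assert np.allclose((pm @ pv)[:50], m @ v)
```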

Matrix Multiplication Methods

There are two ways to perform the same matrix multiplication: one uses a rotation key, and the other uses a matrix multiplication key together with a plain matrix. The rotation-key method is more memory-efficient but slower, while the method that uses a matrix multiplication key and a plain matrix is faster but consumes more memory.

Rotation Key

A matrix multiplication can be performed on various input types with just a rotation key. Below are examples for lists, NumPy arrays, and PyTorch tensors.

import torch
import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
rotation_key = engine.create_rotation_key(secret_key)

matrix = [[i * 64 + j for j in range(64)] for i in range(64)]
numpy_matrix = np.arange(64 * 64).reshape(64, 64)
tensor_matrix = torch.arange(64 * 64).reshape(64, 64)

message = [2] * 64
ciphertext = engine.encrypt(message, secret_key)

multiplied1 = engine.multiply_matrix(matrix, ciphertext, rotation_key)
multiplied2 = engine.multiply_matrix(numpy_matrix, ciphertext, rotation_key)
multiplied3 = engine.multiply_pytorch_tensor_matrix(
    tensor_matrix, ciphertext, rotation_key
)

Note that the API name is slightly different when the matrix is a PyTorch tensor. For clarity, the library includes pytorch_tensor in the name of every PyTorch tensor-related API, as in multiply_pytorch_tensor_matrix.
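This page does not say which algorithm multiply_matrix runs internally. For intuition only, the sketch below models a standard rotation-based technique, the diagonal method of Halevi and Shoup, in plain NumPy; np.roll stands in for a ciphertext rotation, and diagonal and diagonal_matvec are hypothetical helpers, not library functions. It writes M · v as a sum over the n generalized diagonals of M, each multiplied by a rotated copy of v, so it needs n - 1 rotations rather than the n · log2(n) of a row-by-row approach.

```python
import numpy as np


def diagonal(matrix: np.ndarray, k: int) -> np.ndarray:
    """k-th generalized diagonal: d_k[i] = M[i, (i + k) mod n]."""
    n = matrix.shape[0]
    idx = np.arange(n)
    return matrix[idx, (idx + k) % n]


def diagonal_matvec(matrix, vector):
    """M @ v as a sum of (diagonal * rotated vector) terms.

    Returns (result, number_of_rotations).
    """
    matrix = np.asarray(matrix, dtype=float)
    vector = np.asarray(vector, dtype=float)
    n = len(vector)
    result = np.zeros(n)
    rotations = 0
    for k in range(n):
        # Left-rotate by k: slot i now holds v[(i + k) mod n].
        rotated = np.roll(vector, -k)
        if k > 0:
            rotations += 1
        result += diagonal(matrix, k) * rotated  # one plaintext multiplication
    return result, rotations


matrix = np.arange(64 * 64).reshape(64, 64)
vector = np.full(64, 2.0)
result, rotations = diagonal_matvec(matrix, vector)
assert np.allclose(result, matrix @ vector)
print(rotations)  # 63
```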

Matrix Multiplication Key & Plain Matrix

Although the rotation-key method is convenient, it performs many rotation operations internally, which makes it slow. To improve efficiency, the DESILO FHE library provides two additional data structures: MatrixMultiplicationKey and PlainMatrix. While this method improves computation speed, it consumes much more memory because these data structures are created ahead of time.

Matrix Multiplication Key

A matrix multiplication key is a precomputed set of fixed rotation keys optimized for a specific matrix size. Using it allows matrix multiplication with far fewer rotation operations than the basic method.

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)
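The page does not describe how the key reduces the number of rotations. One standard technique that fits the description of a fixed, size-specific set of rotation keys is the baby-step giant-step (BSGS) variant of the diagonal method; treat it as an illustrative assumption, not the library's documented algorithm. For an n by n matrix with n = n1 · n2, BSGS needs only the rotation amounts 1, ..., n1 - 1 and n1, 2·n1, ..., (n2 - 1)·n1, roughly 2√n rotations instead of the n - 1 of the plain diagonal method. The hypothetical helpers below only count them.

```python
def diagonal_rotation_count(n: int) -> int:
    """Rotations for the plain diagonal method: one per nonzero shift."""
    return n - 1


def bsgs_rotation_steps(n: int, n1: int) -> list:
    """Distinct rotation amounts used by baby-step giant-step with n = n1 * n2."""
    if n % n1 != 0:
        raise ValueError("n1 must divide n")
    n2 = n // n1
    baby_steps = list(range(1, n1))               # rotate the vector by 1 .. n1-1
    giant_steps = [n1 * j for j in range(1, n2)]  # rotate partial sums by n1*j
    return baby_steps + giant_steps


steps = bsgs_rotation_steps(64, 8)
print(len(steps), diagonal_rotation_count(64))  # 14 63
print(steps)
```

Because the set of rotation amounts depends only on n and the split n1, keys for exactly those amounts can be generated once per matrix size.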

Plain Matrix

A plain matrix is a set of plaintexts prepared specifically for matrix multiplication. Multiplying with a plain matrix is faster than multiplying with a matrix simply encoded as an ordinary Plaintext.

import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)

matrix = np.arange(64 * 64).reshape(64, 64)
plain_matrix = engine.encode_to_plain_matrix(matrix)
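Why would a dedicated encoding help? Assuming, for illustration only, a baby-step giant-step (BSGS) evaluation (the library's internals are not documented here), each diagonal of the matrix can be pre-rotated once at encoding time so that a single giant-step rotation later lines everything up. The NumPy sketch below uses the hypothetical helpers encode_bsgs_diagonals and bsgs_matvec: the encoding is computed once and reused for many vectors. It also hints at the memory cost, since in a real scheme each stored diagonal would be a full plaintext.

```python
import numpy as np


def diagonal(matrix: np.ndarray, k: int) -> np.ndarray:
    """k-th generalized diagonal: d_k[i] = M[i, (i + k) mod n]."""
    n = matrix.shape[0]
    idx = np.arange(n)
    return matrix[idx, (idx + k) % n]


def encode_bsgs_diagonals(matrix, n1: int) -> dict:
    """Precompute pre-rotated diagonals once (a stand-in for a 'plain matrix')."""
    matrix = np.asarray(matrix, dtype=float)
    n = matrix.shape[0]
    encoded = {}
    for j in range(n // n1):
        for i in range(n1):
            # Right-rotate diagonal k = n1*j + i by n1*j; the giant-step
            # left rotation in bsgs_matvec undoes this shift.
            encoded[(j, i)] = np.roll(diagonal(matrix, n1 * j + i), n1 * j)
    return encoded


def bsgs_matvec(encoded: dict, vector, n1: int) -> np.ndarray:
    """M @ v with n1 - 1 baby-step and n2 - 1 giant-step rotations."""
    vector = np.asarray(vector, dtype=float)
    n = len(vector)
    baby = [np.roll(vector, -i) for i in range(n1)]  # baby steps: rotate v by i
    result = np.zeros(n)
    for j in range(n // n1):
        inner = np.zeros(n)
        for i in range(n1):
            inner += encoded[(j, i)] * baby[i]  # plaintext multiplications only
        result += np.roll(inner, -n1 * j)  # one giant-step rotation per j > 0
    return result


rng = np.random.default_rng(0)
m = rng.standard_normal((64, 64))
plain = encode_bsgs_diagonals(m, n1=8)  # encode once ...
for _ in range(3):  # ... and reuse for several vectors
    v = rng.standard_normal(64)
    assert np.allclose(bsgs_matvec(plain, v, n1=8), m @ v)
```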

Matrix multiplication that uses both structures looks like the following:

import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)

matrix = np.arange(64 * 64).reshape(64, 64)
plain_matrix = engine.encode_to_plain_matrix(matrix)

message = [2] * 64
ciphertext = engine.encrypt(message, secret_key)

multiplied = engine.multiply_matrix(
    plain_matrix, ciphertext, matrix_multiplication_key
)
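To sanity-check a result such as multiplied, it helps to compute the expected product in plaintext. The NumPy snippet below does this for the example's matrix and message, assuming the operation computes the matrix-vector product M · v. The commented-out lines are a hypothetical comparison: they assume a decrypt method with the signature shown, and they use a tolerance because an approximate (CKKS-style) scheme does not return exact values.

```python
import numpy as np

# Plaintext reference for the example above.
matrix = np.arange(64 * 64).reshape(64, 64)
message = np.full(64, 2.0)
expected = matrix @ message

# Row i sums to 4096*i + (0 + 1 + ... + 63) = 4096*i + 2016; times 2:
assert np.allclose(expected, 8192 * np.arange(64) + 4032)

# Hypothetical check, assuming engine.decrypt(ciphertext, secret_key) exists
# and returns the slot values:
# decrypted = np.array(engine.decrypt(multiplied, secret_key))
# assert np.allclose(decrypted, expected, rtol=1e-3)
```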