콘텐츠로 이동

행렬 곱셈

기본적으로 Ciphertext는 긴 배열이기 때문에 동형암호 연산으로도 기본적으로 여러 암호문을 만들고 더하는 방식으로 행렬과 백터의 곱을 구할 수 있습니다. 하지만 이런 단순한 방식은 엄청난 시간과 메모리를 소비하기에 DESILO FHE 라이브러리는 좀 더 효율적인 방식으로 행렬 곱셈을 하는 연산을 지원합니다. 해당 문서에서는 제공되는 연산들을 활용해 어떻게 하면 더 효율적으로 행렬 곱셈을 할 수 있는지 설명합니다.

엔진 slot count 설정

효율화된 행렬 곱셈을 사용하기 위해서는 우선 입력되는 행렬이 정방행렬이어야만 합니다. 암호문 상태에서의 효율적인 연산을 위해 기본적으로 엔진에 입력되는 행렬은 Engine.slot_countEngine.slot_count열이어야 합니다. 연산 시간과 메모리 사용량은 slot_count가 커질수록 제곱 이상의 비율로 증가하므로, 행렬 크기에 따라 엔진의 slot_count를 적절히 설정하는 것이 중요합니다.

from desilofhe import Engine

engine = Engine(slot_count=64)  # 크기가 64x64인 행렬의 연산 지원

연산 방법

동일한 행렬 곱셈에 대해 회전 키만을 사용하는 방법과 행렬 곱셈 키와 행렬 평문을 활용하는 방법 두 가지가 있습니다. 회전 키를 사용하는 곱셈은 메모리 친화적이나 느리고, 행렬 곱셈 키와 행렬 평문을 사용하는 방법은 빠르지만 메모리를 더 많이 사용합니다.

회전 키

기본적으로는 회전키만 있으면 다양한 타입의 인풋에 대하여 행렬 곱셈이 가능합니다. 다음은 리스트, numpy array, PyTorch tensor에 대해 각각 행렬 곱셈을 수행하는 예시입니다. 회전키를 사용한 행렬 곱셈은 메모리 사용량이 적지만 더 느리고, 행렬 곱셈키와 행렬 평문을 사용한 곱셈은 빠르지만 메모리 사용량이 더 큽니다.

import torch
import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
rotation_key = engine.create_rotation_key(secret_key)

matrix = [[i * 64 + j for j in range(64)] for i in range(64)]
numpy_matrix = np.arange(64 * 64).reshape(64, 64)
tensor_matrix = torch.arange(64 * 64).reshape(64, 64)

message = [2] * 64
ciphertext = engine.encrypt(message, secret_key)

multiplied1 = engine.multiply_matrix(matrix, ciphertext, rotation_key)
multiplied2 = engine.multiply_matrix(numpy_matrix, ciphertext, rotation_key)
multiplied3 = engine.multiply_pytorch_tensor_matrix(
    tensor_matrix, ciphertext, rotation_key
)

주의할 점은 PyTorch tensor의 경우 API의 이름이 다르다는 점입니다. 라이브러리에서는 현재 제공되는 다양한 tensor 관련 API들은 쉬운 용도 식별을 위해 pytorch_tensor라는 접미어를 붙이고 있습니다.

행렬 곱셈 키 & 행렬 평문

일반 행렬 곱셈은 편리하지만, 내부적으로 많은 회전 연산을 수행하기 때문에 시간이 오래 걸립니다. 좀 더 효율적인 연산을 위해 DESILO FHE 라이브러리는 행렬 곱셈 키(matrix multiplication key)와 행렬 평문(plain matrix)라는 두 가지 자료구조를 추가적으로 제공합니다. 해당 방법은 연산 속도를 향상시키지만 추가적인 자료구조들을 미리 생성해야하기에 메모리를 더 소비한다는 점에 주의해야 합니다.

행렬 곱셈 키

행렬 곱셈 키는 특정 행렬 크기에 맞춰 미리 계산된 고정 회전 키들의 집합입니다. 행렬 곱셈 키를 사용하면 일반 행렬 곱셈보다 훨씬 적은 수의 회전 연산으로 행렬 곱셈을 수행할 수 있습니다.

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)

행렬 평문

행렬 평문은 효율적인 행렬 곱셈을 위해 미리 최적화된 평문의 집합입니다. 단순히 행렬을 평문으로 인코딩하는 방식보다 더 빠른 행렬 곱셈 연산을 가능하게 합니다.

import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)

message = np.arange(64 * 64).reshape(64, 64)
plain_matrix = engine.encode_to_plain_matrix(message)

두 자료구조를 모두 활용한 행렬 곱셈은 다음과 같습니다.

import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)

matrix = np.arange(64 * 64).reshape(64, 64)
plain_matrix = engine.encode_to_plain_matrix(matrix)

message = [2] * 64
ciphertext = engine.encrypt(message, secret_key)

multiplied = engine.multiply_matrix(
    plain_matrix, ciphertext, matrix_multiplication_key
)