콘텐츠로 이동

행렬 곱셈

기본적으로 Ciphertext는 긴 배열이기 때문에 동형암호 연산으로도 기본적으로 여러 암호문을 만들고 더하는 방식으로 행렬과 백터의 곱을 구할 수 있습니다. 하지만 이런 단순한 방식은 엄청난 시간과 메모리를 소비하기에 DESILO FHE 라이브러리는 좀 더 효율적인 방식으로 행렬 곱셈을 하는 연산을 지원합니다. 해당 문서에서는 제공되는 연산들을 활용해 어떻게 하면 더 효율적으로 행렬 곱셈을 할 수 있는지 설명합니다.

엔진 slot count 설정

효율화된 행렬 곱셈을 사용하기 위해서는 우선 입력되는 행렬이 정방행렬이어야만 합니다. 암호문 상태에서의 효율적인 연산을 위해 기본적으로 엔진에 입력되는 행렬은 Engine.slot_countEngine.slot_count열이어야 합니다. 연산 시간과 메모리 사용량은 slot_count가 커질수록 제곱 이상의 비율로 증가하므로, 행렬 크기에 따라 엔진의 slot_count를 적절히 설정하는 것이 중요합니다.

from desilofhe import Engine

engine = Engine(slot_count=64)  # 크기가 64x64인 행렬의 연산 지원

연산 방법

동일한 행렬 곱셈에 대해 회전 키만을 사용하는 방법과 행렬 곱셈 키와 행렬 평문을 활용하는 방법 두 가지가 있습니다. 회전 키를 사용하는 곱셈은 메모리 친화적이나 느리고, 행렬 곱셈 키와 행렬 평문을 사용하는 방법은 빠르지만 메모리를 더 많이 사용합니다.

회전 키

기본적으로는 회전키만 있으면 다양한 타입의 인풋에 대하여 행렬 곱셈이 가능합니다. 다음은 리스트, numpy array, PyTorch tensor에 대해 각각 행렬 곱셈을 수행하는 예시입니다. 회전키를 사용한 행렬 곱셈은 메모리 사용량이 적지만 더 느리고, 행렬 곱셈키와 행렬 평문을 사용한 곱셈은 빠르지만 메모리 사용량이 더 큽니다.

import torch
import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
rotation_key = engine.create_rotation_key(secret_key)

matrix = [[i * 64 + j for j in range(64)] for i in range(64)]
numpy_matrix = np.arange(64 * 64).reshape(64, 64)
tensor_matrix = torch.arange(64 * 64).reshape(64, 64)

message = [2] * 64
ciphertext = engine.encrypt(message, secret_key)

multiplied1 = engine.multiply_matrix(matrix, ciphertext, rotation_key)
multiplied2 = engine.multiply_matrix(numpy_matrix, ciphertext, rotation_key)
multiplied3 = engine.multiply_pytorch_tensor_matrix(
    tensor_matrix, ciphertext, rotation_key
)

주의할 점은 PyTorch tensor의 경우 API의 이름이 다르다는 점입니다. 라이브러리에서는 현재 제공되는 다양한 tensor 관련 API들은 쉬운 용도 식별을 위해 pytorch_tensor라는 접미어를 붙이고 있습니다.

행렬 곱셈 키 & 행렬 평문

일반 행렬 곱셈은 편리하지만, 내부적으로 많은 회전 연산을 수행하기 때문에 시간이 오래 걸립니다. 좀 더 효율적인 연산을 위해 DESILO FHE 라이브러리는 행렬 곱셈 키(matrix multiplication key)와 행렬 평문(plain matrix)라는 두 가지 자료구조를 추가적으로 제공합니다. 해당 방법은 연산 속도를 향상시키지만 추가적인 자료구조들을 미리 생성해야하기에 메모리를 더 소비한다는 점에 주의해야 합니다.

행렬 곱셈 키

행렬 곱셈 키는 특정 행렬 크기에 맞춰 미리 계산된 고정 회전 키들의 집합입니다. 행렬 곱셈 키를 사용하면 일반 행렬 곱셈보다 훨씬 적은 수의 회전 연산으로 행렬 곱셈을 수행할 수 있습니다.

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)

행렬 평문

행렬 평문은 효율적인 행렬 곱셈을 위해 미리 최적화된 평문의 집합입니다. 단순히 행렬을 평문으로 인코딩하는 방식보다 더 빠른 행렬 곱셈 연산을 가능하게 합니다.

import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)

message = np.arange(64 * 64).reshape(64, 64)
plain_matrix = engine.encode_to_plain_matrix(message)

두 자료구조를 모두 활용한 행렬 곱셈은 다음과 같습니다.

import numpy as np

from desilofhe import Engine

engine = Engine(slot_count=64)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)

matrix = np.arange(64 * 64).reshape(64, 64)
plain_matrix = engine.encode_to_plain_matrix(matrix)

message = [2] * 64
ciphertext = engine.encrypt(message, secret_key)

multiplied = engine.multiply_matrix(
    plain_matrix, ciphertext, matrix_multiplication_key
)

대각선 인덱스 목록을 지정한 행렬 평문

본 라이브러리에서 행렬은 대각선 단위로 인코딩됩니다. 대각선은 행의 순서에 따라 정의됩니다. 예를 들어, 4x4 행렬에서 각 위치의 대각선 인덱스는 다음과 같습니다:

[[0, 3, 2, 1],
 [1, 0, 3, 2],
 [2, 1, 0, 3],
 [3, 2, 1, 0]]

대각선이 모두 0으로 구성되어 있으면 해당 대각선과의 곱셈 결과도 0입니다. 따라서 0으로 채워진 대각선은 인코딩할 필요가 없습니다. 인코딩이 불필요한 대각선이 있다면, diagonal_indices 파라미터를 사용하여 필요한 대각선만 인코딩하여 메모리 사용량과 연산 시간을 줄일 수 있습니다.

대각선 2와 3에만 0이 아닌 값이 있는 4x4 희소 행렬을 예로 들어보겠습니다.

원본 행렬은 다음과 같습니다:

[[0, 1, 2, 0],
 [0, 0, 3, 4],
 [5, 0, 0, 6],
 [7, 8, 0, 0]]

원본 행렬에서 각 대각선을 추출하면: - 대각선 0: [0, 0, 0, 0] - 대각선 1: [0, 0, 0, 0] - 대각선 2: [5, 8, 2, 4] - 대각선 3: [7, 1, 3, 6]

대각선 0과 1이 모두 0이므로, 대각선 2와 3만 인코딩하면 됩니다:

import numpy as np

from desilofhe import Engine

example_matrix = np.array(
    [[0, 1, 2, 0], [0, 0, 3, 4], [5, 0, 0, 6], [7, 8, 0, 0]]
)

engine = Engine(slot_count=4)
secret_key = engine.create_secret_key()
matrix_multiplication_key = engine.create_matrix_multiplication_key(secret_key)

diagonal_indices = [2, 3]
plain_matrix = engine.encode_to_plain_matrix(
    example_matrix, engine.max_level - 1, diagonal_indices
)

message = [1, 2, 3, 4]
ciphertext = engine.encrypt(message, secret_key)

multiplied = engine.multiply_matrix(
    plain_matrix, ciphertext, matrix_multiplication_key
)