콘텐츠로 이동

고급 예제

Quickstart보다 복잡한 FHE 활용 예제들입니다. 데이터의 회전이나 추출 같은 몇 가지 핵심 개념을 소개합니다. 구글 Colab에서 코드를 직접 수행해볼 수 있습니다.

A. 모든 요소의 합

data = [1, 2, 3, 4, 5, 6, 7, 8] 와 같은 벡터를 암호화했을때, 여덟 개의 모든 요소를 더하는 예제입니다. 본 예제에서는 데이터를 회전하고 더하여 암호화된 벡터의 8번째 요소에 요소들의 합을 저장합니다.

from desilofhe import Engine

engine = Engine()

secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)
rotation_key = engine.create_rotation_key(secret_key)

data = [1, 2, 3, 4, 5, 6, 7, 8]
encrypted = engine.encrypt(data, public_key)

rotated_data = encrypted
added = rotated_data
for i in range(7):
    rotated_data = engine.rotate(rotated_data, rotation_key, delta=1)
    added = engine.add(added, rotated_data)

decrypted = engine.decrypt(added, secret_key)

# 8번째 요소에 모든 요소의 합이 저장되어 있습니다.
print(decrypted[7])  # ~36

최적화된 방법

이전 예제를 좀 더 최적화한 예제입니다. 이전 예제에서는 오른쪽으로 한 칸씩 7번의 회전과 7번의 덧셈을 수행하여 전체 합을 얻었습니다. 하지만 오른쪽으로 각각 1칸, 2칸, 4칸 회전하는 3번의 회전과 3번의 덧셈만으로도 동일한 연산을 보다 효율적인 방식으로 수행할 수 있습니다.

from desilofhe import Engine

engine = Engine()

secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)
rotation_key = engine.create_rotation_key(secret_key)

data = [1, 2, 3, 4, 5, 6, 7, 8]
encrypted = engine.encrypt(data, public_key)

added = encrypted

# 1칸 회전하고 더하기
rotated_data = engine.rotate(added, rotation_key, 1)
added = engine.add(added, rotated_data)

# 2칸 회전하고 더하기
rotated_data = engine.rotate(added, rotation_key, 2)
added = engine.add(added, rotated_data)

# 4칸 회전하고 더하기
rotated_data = engine.rotate(added, rotation_key, 4)
added = engine.add(added, rotated_data)

decrypted = engine.decrypt(added, secret_key)

# 8번째 요소에 모든 요소의 합이 저장되어 있습니다.
print(decrypted[7])  # ~36

B. 데이터 추출 및 재구성

숫자 벡터를 암호화한 암호문은 CKKS 스킴이 SIMD를 지원하기 때문에 벡터 간의 덧셈이나 두 암호화된 벡터 간의 계수별 곱셈을 쉽게 수행할 수 있습니다.

이전 예제는 SIMD를 사용하지 않는 사례로, 암호화된 벡터의 모든 요소를 더하여 출력 암호문의 첫 번째 계수에 넣는 과정을 설명했습니다. 이때 주변의 다른 계수들은 원래 벡터 요소들의 부분합으로 채워지며, 이는 우리가 원하는 값이 아닙니다.

이번에는 암호문에서 필요한 데이터 일부를 추출하고, 이들을 새롭고 깔끔한 벡터로 재조합하는 방법을 살펴보겠습니다. 이 벡터는 필요한 값만 포함하고 나머지 모든 위치는 0으로 채워집니다.

해당 예제에서는 4개의 암호화된 벡터로부터 정보 일부를 추출하고 이를 재구성하여 [1, 2, 3, 4, 5, 6, 7, 8] 벡터를 결과로 얻습니다. 이 과정은 암호화된 벡터와 mask 된 데이터 간의 곱셈을 통해 올바른 요소를 추출하고, 추출된 요소를 회전시켜 올바른 위치로 옮긴 후, 마지막으로 덧셈을 통해 모든 정보를 하나의 암호문으로 병합하는 방식으로 진행됩니다.

from desilofhe import Engine

engine = Engine()

secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)
rotation_key = engine.create_rotation_key(secret_key)

# 추출할 값들은 index로 표시되어 있습니다.
# index         0         1
data1 = [12, 7, 1, 15, 9, 2, 11, 10]
# index  2  3
data2 = [3, 4, 20, 11, 17, 6, 9, 16]
# index         5     4
data3 = [9, 18, 6, 9, 5, 11, 13, 8]
# index                  6          7
data4 = [20, 19, 18, 17, 7, 14, 15, 8]

encrypted1 = engine.encrypt(data1, public_key)
encrypted2 = engine.encrypt(data2, public_key)
encrypted3 = engine.encrypt(data3, public_key)
encrypted4 = engine.encrypt(data4, public_key)

# data1에서 세 번째 요소를 추출
mask = [0, 0, 1, 0, 0, 0, 0, 0]
multiplied = engine.multiply(encrypted1, mask)
# 첫 번째 요소에 위치하도록 회전
rotated1 = engine.rotate(multiplied, rotation_key, -2)

# data1에서 여섯 번째 요소를 추출
mask = [0, 0, 0, 0, 0, 1, 0, 0]
multiplied = engine.multiply(encrypted1, mask)
# 두 번째 요소에 위치하도록 회전
rotated2 = engine.rotate(multiplied, rotation_key, -4)

# data2에서 첫 번째와 두 번째 요소를 추출
mask = [1, 1, 0, 0, 0, 0, 0, 0]
multiplied = engine.multiply(encrypted2, mask)
# 세 번째와 네 번째 요소에 위치하도록 회전
rotated34 = engine.rotate(multiplied, rotation_key, 2)

# data3에서 세 번째 요소를 추출
mask = [0, 0, 1, 0, 0, 0, 0, 0]
multiplied = engine.multiply(encrypted3, mask)
# 여섯 번째 요소에 위치하도록 회전
rotated6 = engine.rotate(multiplied, rotation_key, 3)

# data3에서 다섯 번째 요소를 추출
# (이미 올바른 위치에 있으므로 회전이 필요하지 않음)
mask = [0, 0, 0, 0, 1, 0, 0, 0]
rotated5 = engine.multiply(encrypted3, mask)

# data4에서 다섯 번째 요소를 추출
mask = [0, 0, 0, 0, 1, 0, 0, 0]
multiplied = engine.multiply(encrypted4, mask)
# 일곱 번째 요소에 위치하도록 회전
rotated7 = engine.rotate(multiplied, rotation_key, 2)

# data4에서 여덟 번째 요소를 추출
# (이미 올바른 위치에 있으므로 회전이 필요하지 않음)
mask = [0, 0, 0, 0, 0, 0, 0, 1]
rotated8 = engine.multiply(encrypted4, mask)

# 모든 요소를 더하여 하나의 암호문으로 병합
added = engine.add(rotated1, rotated2)
added = engine.add(added, rotated34)
added = engine.add(added, rotated5)
added = engine.add(added, rotated6)
added = engine.add(added, rotated7)
added = engine.add(added, rotated8)

# 복호화 후 결과 출력
decrypted = engine.decrypt(added, secret_key)
print(decrypted[:8])  # [~1 ~2 ~3 ~4 ~5 ~6 ~7 ~8]

C. 다항식 연산

이 예제에서는 SIMD 방식으로 다음과 같은 다항식을 평가하고자 합니다: x^3 - x^2 + sqrt(2)*x + 1. x의 입력값들은 [1, 2, 3, 4, 5, 6, 7, 8]을 암호화한 암호문에 저장됩니다. DESILO FHE 라이브러리를 사용하면 어떤 다항식이든 평가하는 것이 매우 간단합니다. 함수 evaluate_polynomial을 이용하여 바로 다항식을 평가할 수 있습니다.

from desilofhe import Engine
import math

engine = Engine()

secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)

data = [1, 2, 3, 4, 5, 6, 7, 8]
encrypted = engine.encrypt(data, public_key)

# 다항식 연산 p(x) = x^3 - x^2 + sqrt(2)*x + 1
coefficients = [1, math.sqrt(2), -1, 1]
polynomial = engine.evaluate_polynomial(encrypted, coefficients, relinearization_key)

# 복호화 후 결과 출력
decrypted = engine.decrypt(polynomial, secret_key)
print(decrypted[:8])
# [~2.4142 ~7.8284 ~23.2426 ~54.6569 ~108.0711 ~189.4853 ~304.8995 ~460.3137]

혹은 덧셈, 뺄셈, 그리고 곱셈을 활용하여 직접 다항식을 평가할 수도 있습니다.

from desilofhe import Engine
import math

engine = Engine()

secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)

data = [1, 2, 3, 4, 5, 6, 7, 8]
encrypted = engine.encrypt(data, public_key)

# 다항식 연산 p(x) = x^3 - x^2 + sqrt(2)*x + 1
coeff0 = 1
sqrt2 = math.sqrt(2)
# x^2를 연산
x2 = engine.square(encrypted, relinearization_key)
# x^3을 연산
x3 = engine.multiply(encrypted, x2, relinearization_key)
# sqrt(2)*x을 연산
x1 = engine.multiply(encrypted, sqrt2)
# 다항식 전체를 연산
polynomial = engine.subtract(x3, x2)
polynomial = engine.add(polynomial, x1)
polynomial = engine.add(polynomial, coeff0)

# 복호화 후 결과 출력
decrypted = engine.decrypt(polynomial, secret_key)
print(decrypted[:8])
# [~2.4142 ~7.8284 ~23.2426 ~54.6569 ~108.0711 ~189.4853 ~304.8995 ~460.3137]

D. 부호 함수 (Sign Function) 근사

이 예제에서는 구간 [-1,1]에서 부호 함수를 다항식으로 근사한 후, 이를 동형암호를 이용해 계산하고자 합니다. 이 함수는 입력값에 따라 다음과 같은 출력을 생성합니다: 음수 입력일 때 -1, 양수 입력일 때 1, 입력이 0일 때 0을 반환합니다. 이 예제는 해당 논문에서 제안된 근사 다항식을 사용합니다. 아래 예시 코드에서 확인할 수 있듯이 이 근사 방법은 총 8번 곱셈 레벨을 필요로 합니다. 이 다항식 근사 방법은 실제 부호 함수와의 최대 오차가 0.008 미만이며, 정확도를 더 높이고자 할 경우 더 높은 차수의 다항식을 사용해 근사 정밀도를 개선할 수 있습니다.

이 근사 다항식은 다항식 Q(X)와 P(X)의 합성함수인 Q(P(X))로 표현됩니다. 여기서 Q(X)의 계수는 p72, P(X)의 계수는 p71로 각각 주어집니다.

def sign(x, relinearization_key):
    # 다항식 근사 p(x) = p_{7,2}(p_{7,1}(x))
    p71 = [
        3.60471572275560 * 10**-36,
        7.30445164958251,
        -5.05471704202722 * 10**-35,
        -3.46825871108659 * 10,
        1.16564665409095 * 10**-34,
        5.98596518298826 * 10,
        -6.54298492839531 * 10**-35,
        -3.18755225906466 * 10,
    ]
    p72 = [
        -9.46491402344260 * 10**-49,
        2.40085652217597,
        6.41744632725342 * 10**-48,
        -2.63125454261783,
        -7.25338564676814 * 10**-48,
        1.54912674773593,
        2.06916466421812 * 10**-48,
        -3.31172956504304 * 10**-1,
    ]

    # y = p_{7,1}(x)를 연산
    y = engine.evaluate_polynomial(x, p71, relinearization_key)

    # p_{7,2}(p_{7,1}(x))를 연산
    return engine.evaluate_polynomial(y, p72, relinearization_key)

E. 인공 뉴런

이 예제에서는 SIMD 방식으로 인공 뉴런을 계산하고자 합니다. 여기서 뉴런은 4개의 실수를 입력으로 받아, 가중치 벡터(크기 4)와의 내적을 계산하고, 여기에 편향(bias)을 더합니다. 마지막 연산에서는 내적 결과에 대해 ReLU 함수를 적용합니다. ReLU 함수는 입력 실수 x에 대해, x가 양수이면 x를 출력하고 그렇지 않으면 0을 출력하는 함수입니다. 이 함수는 위에서 설명한 바와 같이 다항식 근사를 통해 계산됩니다.

내적과 편향

우선 내적 계산과 편향의 덧셈부터 시작해 보겠습니다.

def inner_product_bias(encrypted_data, weight, bias):
    inner_product = bias
    for encrypted_data, weight in zip(encrypted, weights):
        product = engine.multiply(encrypted_data, weight)
        inner_product = engine.add(inner_product, product)
    return inner_product

ReLU

[-1,1] 범위의 실수에 대해 ReLU 함수의 다항식 근사를 계산하는 함수를 작성해보겠습니다. ReLU 함수는 ReLU(x) = 0.5 * (x + x*sign(x))와 같이 부호 함수를 활용하여 표현할 수 있습니다. 그리고 부호 함수를 다항식 근사하여 ReLU 함수 전체를 다항식 형태로 계산합니다. 이를 위해 예제 D에서 사용한 부호 함수의 근사 기법을 사용합니다. 앞서 설명했듯이 해당 논문에서 제시된 근사 다항식을 사용합니다.

def relu(x, relinearization_key):
    # 다항식 근사 p(x) = 0.5 * (x + x * sign(x))

    # sign(x)을 연산
    sign_evaluation = sign(x, relinearization_key)

    # x * sign(x)을 연산
    multiplied = engine.multiply(sign_evaluation, x, relinearization_key)

    # x + x * sign(x)을 연산
    added = engine.add(multiplied, x)

    # 최종적으로 0.5 * (x + x * sign(x))을 연산
    result = engine.multiply(added, 0.5)
    return result

인공 뉴런

이제 앞서 작성한 두 함수를 합쳐 최종 코드를 작성할 수 있습니다. 연산을 더 빠르게 수행하기 위해, 더 많은 곱셈 레벨을 지원하는 엔진을 병렬 CPU 모드로 사용해 보겠습니다.

from desilofhe import Engine
import math

# CPU 병렬 모드로 높은 곱셈 레벨을 가진 엔진을 사용합니다.
engine = Engine(max_level=17, mode="parallel")

secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)
rotation_key = engine.create_rotation_key(secret_key)

# data는 SIMD 방식으로 뉴런을 한 번에 8번 연산할 수 있을 만큼 충분한 데이터를 저장하고 있습니다.
# [0.1, 0.9, 1.5, 0.8]이 첫 번째 입력, [0.2, 1.0, 0.3, 1.0]이 두 번째 입력 같은 방식으로 이어집니다.
data = [
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
    [0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 0.7, 1.6],
    [1.5, 0.3, 0.0, 0.7, 1.1, 1.3, 0.2, 0.8],
    [0.8, 1.0, 1.6, 1.2, 0.3, 0.7, 0.1, 1.1],
]

# 암호화
encrypted = [engine.encrypt(d, public_key) for d in data]

# 입력과 가중치의 내적에 편향을 더함
weights = [-0.4, -1.2, 0.6, 1.0]
bias = 0.34
inner_product = inner_product_bias(encrypted, weights, bias)
# [~0.92 ~0.24 ~0.5 ~0.36 ~-0.46 ~-0.1 ~-0.56 ~-0.32]

# ReLU 연산
neuron_output = relu(inner_product, relinearization_key)

# 복호화 후 결과 출력
decrypted = engine.decrypt(neuron_output, secret_key)
print(decrypted[:8])  # [~0.92 ~0.24 ~0.5 ~0.36 ~0.0 ~0.0 ~0.0 ~0.0]

F. Argmax

이 예제에서는 테이블 형태의 값들에 대해 argmax 함수를 동형암호를 이용해 계산하고자 합니다. 예를 들어, 입력이 [0.1, 0.6, 0.2, 0.3]인 경우, 출력은 [0.0, 1.0, 0.0, 0.0]이 되어야 합니다. 우선, 같은 암호문 안에 존재하는 두 암호화된 값 중 최댓값을 계산하는 함수를 정의하는 것부터 시작합니다.

이 예제에서 사용하는 방법은 해당 논문에서 제안된 방법을 기반으로 합니다.

두 값 사이 최댓값 함수

여기에서는 두 개의 암호화된 값 리스트(즉, 두 개의 암호문) 간에 SIMD 방식으로 최댓값을 계산하는 함수를 구현하고자 합니다. 예제 D에서는 부호 함수의 동형 연산을 구현한 바 있으며, 이번에는 이를 활용하여 최댓값 함수를 구성합니다. 실제로, 두 값 ab 사이의 최댓값은 다음 수식으로 계산할 수 있습니다: 0.5 * (a + b + (a - b) * sign(a - b)). 이 최댓값 함수는 10 곱셈 레벨을 소모하므로, 이 연산을 수행하기 위해서는 입력 암호문의 곱셈 레벨이 최소 10 이상이어야 합니다.

def max(a, b, relinearization_key):
    # max(a, b) = 0.5 * (a + b + (a - b) * sign(a - b))

    # a - b을 연산
    subtracted = engine.subtract(a, b)

    # sign(a - b)을 연산
    subtracted_sign = sign(subtracted, relinearization_key)

    # (a - b) * sign(a - b)을 연산
    multiplied = engine.multiply(
        subtracted, subtracted_sign, relinearization_key
    )

    # a + b + (a - b) * sign(a - b)을 연산
    added = engine.add(multiplied, engine.add(a, b))

    # 0.5 * (a + b + (a - b) * sign(a - b))을 연산
    result = engine.multiply(added, 0.5)

    return result

테이블 내 최댓값 함수

이 함수의 목표는 테이블 내 최댓값을 계산한 후, 그 값을 테이블의 모든 위치에 채워 출력하는 것입니다. 예를 들어 입력이 [0.1, 0.2, 0.3, 0.4]라면, 출력이 [0.4, 0.4, 0.4, 0.4]가 되어야 합니다.

이 연산은 네 개의 0.4 값을 한 번에 계산하는 방식으로 이루어집니다. 이를 위해 먼저 테이블의 크기를 두 배로 확장하고, 원래 데이터를 두 번 복제합니다. 예를 들어 [0.1, 0.2, 0.3, 0.4][0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]로 변환됩니다. 그 후 분할 정복(divide and conquer) 방식으로 진행하며, log2(크기)만큼 반복하여 암호문을 회전시킨 뒤, 회전된 암호문과 현재 암호문 간의 최댓값 함수를 계산합니다.

최댓값 연산은 10개의 곱셈 레벨을 소모하기 때문에, 해당 연산을 수행하기 전 암호문의 곱셈 레벨이 10보다 작다면 부트스트래핑을 수행해야 합니다. 부트스트래핑은 암호문의 곱셈 레벨을 다시 10으로 복원해 줍니다. 암호문 ct의 현재 곱셈 레벨은 ct.level을 통해 확인할 수 있으며, 필요할 경우 최댓값 함수 내부에서 부트스트래핑을 호출할 수도 있습니다.

해당 논문에서는 왼쪽 회전을 사용하지만, 본 예제는 효율성을 고려하여 반복문에서는 오른쪽 회전을 사용하고, 마지막에 한 번만 왼쪽 회전을 수행합니다. 또한, 더 이상 필요하지 않은 값들은 mask를 곱하여 0으로 초기화합니다.

def quick_max(
    a, n, rotation_key, relinearization_key, conjugation_key, bootstrap_key
):  # n은 2의 제곱수이고, 2n < N
    log2n = int(math.log2(n))

    right_rotated = engine.rotate(a, rotation_key, n)
    added = engine.add(a, right_rotated)

    mask = [1] * n

    for i in range(log2n):
        left_rotated = engine.rotate(added, rotation_key, -(2**i))
        temp = max(added, left_rotated, relinearization_key)

        # 부트스트랩
        temp = engine.bootstrap(
            temp, relinearization_key, conjugation_key, bootstrap_key
        )
        added = temp

    # n개의 원소를 출력하기 위해 mask를 곱함
    multiplied = engine.multiply(added, mask)
    return multiplied

테이블의 Argmax 함수

Argmax 계산의 핵심 아이디어는 원래 데이터에서 quick_max 결과를 뺀 후, 그 차이에 부호 함수를 적용하여 최댓값의 위치를 판별하는 것입니다. 예를 들어 [0.1, 0.2, 0.3, 0.4]는 먼저 [0.4, 0.4, 0.4, 0.4]로 변환되고, 원래 데이터와의 차이는 [-0.3, -0.2, -0.1, 0.0]가 됩니다. 이 값에 부호 함수를 적용하면 [-1.0, -1.0, -1.0, 0.0]가 되고, 여기에 1을 더하면 최댓값의 위치를 나타내는 [0.0, 0.0, 0.0, 1.0]을 얻을 수 있습니다.

def argmax(
    a, n, rotation_key, relinearization_key, conjugation_key, bootstrap_key
):  # n은 2의 제곱수이고, 2n < N
    # quick_max을 연산
    a_max = quick_max(
        a, n, rotation_key, relinearization_key, conjugation_key, bootstrap_key
    )

    # a와 a_max의 차이를 연산
    diff_values = engine.subtract(a, a_max)

    # diff_values의 부호 함수를 연산
    # 이 경우에 0과 -1만 존재
    sign_values = sign(diff_values, relinearization_key)

    # 1을 더하여 -1과 0으로 이루어진 데이터를 0과 1로 이루어진 데이터로 변환
    # n번째 이후의 불필요한 값을을 0으로 유지하기 위해 mask를 사용
    mask = [1] * n
    output = engine.add(sign_values, mask)

    return output

코드 실행하기

이제 암호화된 데이터를 사용하여 argmax 함수를 직접 실행해볼 수 있습니다.

from desilofhe import Engine
import math

engine = Engine(use_bootstrap=True)
secret_key = engine.create_secret_key()
public_key = engine.create_public_key(secret_key)
rotation_key = engine.create_rotation_key(secret_key)
relinearization_key = engine.create_relinearization_key(secret_key)
conjugation_key = engine.create_conjugation_key(secret_key)
bootstrap_key = engine.create_bootstrap_key(secret_key, stage_count=3)

# 데이터
data = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
n = len(data)

# 암호화
encrypted = engine.encrypt(data, public_key)

# argmax를 연산
encrypted_argmax = argmax(
    encrypted,
    n,
    rotation_key,
    relinearization_key,
    conjugation_key,
    bootstrap_key,
)

# 복호화 후 결과를 출력
decrypted = engine.decrypt(encrypted_argmax, secret_key)
print((decrypted[:16]))  # [~0 ~0 ~0 ~0 ~0 ~0 ~0 ~1 ~0 ...]