본문 바로가기

Machine Learning_모델설계_Python

판별분석_4 이차판별분석(QDA)

안녕하십니까 배도리입니다. 오늘은 이차판별분석(Quadratic Discriminant Analysis, QDA)에 대해서 설명하겠습니다. 이차판별분석은 선형판별분석(Linear Discriminant Analysis, LDA)와 마찬가지로 분류 문제를 해결하기 위한 기법 중 하나입니다. 그러나 QDA는 LDA와 큰 차이점이 있습니다.

 

QDA는 각 클래스의 데이터가 서로 다른 공분산 구조를 가질 수 있다고 가정합니다. 이는 각 클래스에 속하는 데이터가 다른 형태의 분포를 가질 수 있다는 의미입니다. 예를 들어, 한 클래스의 데이터는 원형 분포를 보이는 반면, 다른 클래스의 데이터는 타원형 분포를 보일 수 있습니다(비선형 분류 가능) . QDA는 이러한 다양한 공분산 구조를 수용할 수 있기 때문에, 데이터의 복잡성을 더 잘 반영할 수 있습니다. 그러나 QDA는 서로 다른 공분산 데이터 분류를 위해 샘플이 많이 필요하다는 단점이 있습니다. 결정경계는 𝑥에 대한 선형방정식이 아닌 이차식의 형태가 됩니다.

 

 

 

Σ𝑘 = Σ 가 성립하지 않기 때문에, 결정경계는 𝑥에 대한 선형방정식이 아닌 이차식의 형태가 됩니다.

 

 

 

이전 자료의 데이터셋 기반으로 살펴보겠습니다.

from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

qda = QuadraticDiscriminantAnalysis(store_covariance=True).fit(X, y)
#store_covariance=True 옵션을 사용하면, 각 클래스별 공분산 행렬을 저장할 수 있습니다.
#qda.covariance_를 사용하여 저장된 공분산 행렬을 검색합니다.
#qda 객체를 사용하여 데이터 X와 레이블 y를 사용하여 모델을 학습합니다. 이렇게 하면 QDA 모델이 학습되어 분류를 수행할 수 있게 됩니다.

 

 

각 클래스의 사전 확률을 봅시다. 사전 확률은 학습 데이터에서 각 클래스가 발생할 확률로, 클래스 불균형이 있는 경우 분류기의 성능에 영향을 줄 수 있습니다.

qda.priors_
array([0.33333333, 0.33333333, 0.33333333])

 

 

각 클래스에 대한 평균 벡터를 보겠습니다. 이전 자료에서 말했듯이 평균 벡터는 다변량 정규 분포의 중심 위치를 나타내며, 각 클래스에 대한 데이터의 중심 경향을 설명합니다. qda.means_ 속성은 각 클래스의 특성별 평균 값을 포함하는 2차원 배열 형태로 반환됩니다. 이를 통해 클래스 간 데이터의 차이를 이해할 수 있으며, QDA 모델에서 이러한 차이를 활용하여 분류를 수행합니

qda.means_
.array(
[[-8.01254084e-04, 1.19457204e-01],
[ 1.16303727e+00, 1.03930605e+00],
[-8.64060404e-01, 1.02295794e+00]])

 

 

각 클래스에 대한 공분산 행렬을 보겠습니다. 각 클래스에 대한 공분산 행렬은 해당 클래스의 데이터 분포와 관련된 정보를 제공합니다. 공분산 행렬은 변수 간의 선형 관계와 분산을 설명합니다.

qda.covariance_
[array(
[[ 0.73846319, -0.01762041],
[-0.01762041, 0.72961278]]),
array(
[[0.66534246, 0.21132313],
[0.21132313, 0.78806006]]),
array(
[[0.9351386 , 0.22880955],
[0.22880955, 0.79142383]])]
 

 

 

 

시각화를 통해서 살펴보겠습니다.

import seaborn as sns
import matplotlib as mpl

x1min, x1max = -5, 5
x2min, x2max = -4, 5
XX1, XX2 = np.meshgrid(np.arange(x1min, x1max, (x1max-x1min)/1000),
                       np.arange(x2min, x2max, (x2max-x2min)/1000))
YY = np.reshape(qda.predict(np.array([XX1.ravel(), XX2.ravel()]).T), XX1.shape)
cmap = mpl.colors.ListedColormap(sns.color_palette(["r", "g", "b"]).as_hex())
plt.contourf(XX1, XX2, YY, cmap=cmap, alpha=0.5)
plt.scatter(X1[:, 0], X1[:, 1], alpha=0.8, s=50, marker="o", color='r', label="class 1")
plt.scatter(X2[:, 0], X2[:, 1], alpha=0.8, s=50, marker="s", color='g', label="class 2")
plt.scatter(X3[:, 0], X3[:, 1], alpha=0.8, s=50, marker="x", color='b', label="class 3")
plt.xlim(x1min, x1max)
plt.ylim(x2min, x2max)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("QDA Result")
plt.legend()
plt.show()