핸즈온머신러닝&딥러닝

MNIST 활용, 분류 2

threegopark 2021. 5. 19. 14:47
728x90

이전 포스팅에 이어,,,

 

 

정밀도 / 재현율 트레이드오프

 

  실제(양성) 실제(음성)
예측(양성) TruePositive(TP) FalsePositive(FP)
예측(음성) FalseNegative(FN) TrueNegative(TN)

정밀도 (Precision) : TP / (TP+FP)

민감도 (재현율, Recall) : TP / (TP+FN)

특이도 (Specificity) : TN / (FP + TN)

F1 Score : 2 * (Precision*Recall) / (Precision + Recall)

 

  • SGDClassifier 분류기 : 결정 함수를 사용하여 각 샘플의 점수 계산
  • 이 점수를 활용, 1. 점수 > 임곗값 --> 샘플 (양성)   2. 점수 < 임곗값 --> 샘플 (음성) 에 할당

예를 들어 임계값이 가운데 화살표인 경우(정밀도 80%, 재현율 67%), 임계값 오른쪽에는 4개의 실제 양성(5)과 1개의 거짓 양성(6)이 존재한다. 오른쪽 5개의 숫자 중 정확히 5라고 예측한 것이 4개이므로 정밀도는 80% (4/5) 이다. 전체 중 숫자 5는 6개 존재하며 해당 임계값에서는 분류기가 4개만 인식하였으므로 재현율은 67% (4/6) 이다.

 

이번에는 임계값을 오른쪽 화살표로 이동시켜보자. 마찬가지로 계산해보면 정밀도는 100%, 재현율은 이전보다 낮은 50%가 된다. 

 

이렇게 재현율이 커지면 정밀도가 작아지고, 정밀도가 커지면 재현율이 작아지는 현상을 정밀도/재현율 트레이드오프라고 한다. 사용자는 임계값을 선택하여 이러한 트레이드오프를 적절히 조율해야 한다.

 

 

위에서 분류기를 통해 예측에 사용한 점수를 계산하고 임계값과 비교하여 양성, 음성 판단을 한다고했다.

사이킷런에서는 decision_function() 함수를 사용하여 예측에 사용한 점수를 알 수 있다.

y_scores = sgd_clf.decision_function([some_digit])
y_scores

임계값이 0일 경우 해당 점수를 활용하여 some_digit라는 것이 True인지 (5인지) False인지 (5가 아닌지) 확인해보자.

threshold = 0   #임계값을 0으로 설정
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
#임계값보다 크면 양성 (True)

이번엔 임계값을 8000으로 올린 후 예측점수를 활용해보자.

threshold = 8000   #임계값을 8000으로 설정
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
#임계값보다 크면 양성 (True)

이번엔 False가 나왔다. 실제 5인데 5가 아니라고 예측한 것이다. 위의 그림에서 확인했던 것처럼 임계값을 올리면 재현율(실제 참을 참이라고 예측할 확률)이 떨어지는 것을 알 수 있다.

 

그렇다면 적절한 임계값은 어떻게 정할 수 있을까? 이를 위해서는

1. cross_val_predict()함수를 사용해 훈련 세트에 있는 모든 샘플의 점수를 구한다.

2. 하지만 예측 결과가 아니라 결정 점수를 반환받도록 지정해야 한다.

3. 이 점수로 precision_recall_curve() 함수를 사용하여 가능한 모든 임계값에 대해 정밀도와 재현율을 계산한다.

4. 시각화하여 확인한다.

 

y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method = "decision_function")

from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label = "Precision")
    plt.plot(thresholds, recalls[:-1], "g-", label = "Recall")
    plt.legend()
    plt.grid()
    
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.show()

 

또한 재현율에 대한 정밀도 곡선을 그려서 확인하는 방법도 있다. (PR 곡선)

def plot_precision_vs_recall(precisions, recall):
    plt.plot(recalls[:-1], precisions[:-1], "b--", label = "Precision")
    plt.legend()
    plt.grid()
    
plot_precision_vs_recall(precisions, recalls)
plt.show()

재현율 80% 근처에서 정밀도가 급격하게 줄어들기 시작한다. 이하강점 직전을 정밀도/재현율 트레이드오프로 선택하는 것이 바람직하다.

 

정밀도 90% 달성을 목표라고 하고 그에 해당하는 임계값을 찾아보자.

threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]

(+ agrmax() 함수는 최댓값의 첫 번째 인덱스를 반환한다.)

 

 

 

ROC 곡선(수신기 조작 특성, Receiver operating characteristic)

 

  • 거짓 양성 비율(FPR)에 대한 진짜 양성 비율(TPR)의 곡선
  • FPR : 양성으로 잘못 분류된 음성 샘플의 비율 (1 - 음성으로 정확하게 분류한 음성 샘플의 비율,TNR)
  • TNR : 특이도
  • 즉, 민감도에 대한 1 - 특이도 그래프
from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)

def plot_roc_curve(fpr, tpr, label=None):
    plt.plot(fpr,tpr, linewidth=2, label = label)
    plt.plot([0,1],[0,1],'k--')
    plt.grid()

plot_roc_curve(fpr, tpr)
plt.show()
  • roc_curve() 함수를 사용하여 fpr, tpr, thresholds를 구한 후 roc곡선을 그린다.

왼쪽 모서리가 가장 좋은 분류기이다.

+ 그래프 설명 : 곡선 아래의 면적(AUC)를 통해 분류기를 판단할 수 있다.

AUC = 1 : 완벽한 분류기

AUC = 0.5 : 완전한 랜덤 분류기

 

from sklearn.metrics import roc_auc_score

roc_auc_score(y_train_5, y_scores)

사이킷런의 roc_auc_score() 함수를 사용하여 아래 면적을 구할 수 있다.

 

 

그렇다면 랜덤포레스트 분류기를 활용하여 auc곡선을 그려보고 위의 sgd분류기와 비교해보자.

 

from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state = 42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv = 3, method = "predict_proba")
#랜덤포레스트에는 decision_function() 대신 predict_proba() 사용, ex)어떤 이미지가 5일 확률 70%

y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)

plt.plot(fpr, tpr, "b:", label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "random_f")
plt.legend(loc="lower right")
plt.show()

--> 랜덤포레스트 분류기가 sgd분류기보다 더 좋은 성능을 보여주는 것을 알 수 있다.

'핸즈온머신러닝&딥러닝' 카테고리의 다른 글

모델 훈련 (수정 필요)  (0) 2021.05.30
MNIST 활용, 분류 3  (0) 2021.05.27
MNIST 활용, 분류  (0) 2021.05.18
캘리포니아 주택 가격 예측2  (0) 2021.05.16
캘리포니아 주택 가격 예측  (0) 2021.05.15