[Machine Learning] Classification with K-NN

건휘맨 2024. 4. 15. 17:45

KNeighborsClassifier(): classifies each sample by a majority vote among its n nearest data points (its neighbors) in the training set
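
To make the idea concrete, here is a minimal from-scratch sketch of the nearest-neighbor vote. It is not scikit-learn's implementation: the function name `knn_predict` and the toy data are made up for illustration, and it assumes Euclidean distance with equally weighted votes, which matches the default settings of KNeighborsClassifier. Tie-breaking may differ from the library's.

```python
from collections import Counter
from math import dist


def knn_predict(X_train, y_train, x_new, n_neighbors=5):
    """Predict the label of x_new by majority vote of its nearest neighbors.

    Hypothetical helper for illustration only (not part of scikit-learn).
    """
    # Indices of the training points, sorted by Euclidean distance to x_new
    by_distance = sorted(range(len(X_train)), key=lambda i: dist(X_train[i], x_new))
    nearest = by_distance[:n_neighbors]
    # Count the labels of the nearest neighbors and return the most common one
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Toy data (made up): two well-separated clusters
X_train = [(1.0, 1.0), (1.2, 0.8), (0.9, 1.1), (5.0, 5.0), (5.2, 4.9), (4.8, 5.1)]
y_train = [0, 0, 0, 1, 1, 1]

print(knn_predict(X_train, y_train, (1.1, 1.0), n_neighbors=3))  # 0
print(knn_predict(X_train, y_train, (5.1, 5.0), n_neighbors=3))  # 1
```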

 

from sklearn.neighbors import KNeighborsClassifier

classifier = KNeighborsClassifier(n_neighbors=n)
# n_neighbors: number of nearest neighbors used for the vote (default: 5)
>>> classifier = KNeighborsClassifier(n_neighbors=7)
>>> classifier.fit(X_train, y_train)

>>> y_pred = classifier.predict(X_test)
>>> y_pred
array([0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0,
       1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1,
       0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,
       1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], dtype=int64)
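
The session above assumes X_train, y_train, and X_test already exist. For anyone who wants to run the same steps end to end, here is a rough sketch with stand-in data. The synthetic dataset from make_classification and the split sizes are assumptions, not the data used in this post, so the predictions will not match the array above.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Stand-in data (assumption): 400 samples, 2 features, 2 classes
X, y = make_classification(
    n_samples=400, n_features=2, n_informative=2, n_redundant=0, random_state=0
)

# Hold out 25% of the samples (100 rows) as the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)

# Same steps as in the post: fit on the training set, predict on the test set
classifier = KNeighborsClassifier(n_neighbors=7)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

print(y_pred.shape)  # one predicted label per test sample
```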
       

from sklearn.metrics import confusion_matrix, accuracy_score

>>> cm = confusion_matrix(y_test, y_pred)
>>> cm
array([[49,  9],
       [ 3, 39]], dtype=int64)
       
>>> accuracy_score(y_test, y_pred)  # (49 + 39) / 100: diagonal of cm divided by cm.sum()
0.88
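
The accuracy can be checked by hand from the confusion matrix. In scikit-learn's confusion_matrix, rows are the true labels and columns are the predicted labels, so the diagonal holds the correct predictions. A small sketch using the matrix values shown above (the variable names are just for illustration):

```python
import numpy as np

# Confusion matrix from the output above: rows = true labels, columns = predicted labels
cm = np.array([[49, 9],
               [3, 39]])

correct = int(np.trace(cm))   # diagonal sum: 49 + 39 correct predictions
total = int(cm.sum())         # all test samples
accuracy = correct / total

print(correct, total, accuracy)  # 88 100 0.88
```

Off the diagonal, 9 samples of class 0 were predicted as 1 and 3 samples of class 1 were predicted as 0.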