Model Evaluation
Model evaluation measures how well a trained machine learning model performs, ideally on data it did not see during training. It shows where the model is strong, where it is weak, and whether it is likely to generalize.
Evaluation Metrics
For Classification
- Accuracy: Proportion of correct predictions.
- Precision: Proportion of true positive predictions among all positive predictions.
- Recall: Proportion of true positive predictions among all actual positives.
- F1 Score: Harmonic mean of precision and recall.
- ROC AUC: Area under the Receiver Operating Characteristic curve, which plots the true positive rate against the false positive rate across decision thresholds.
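The first four definitions above can be checked by hand from the counts in a binary confusion matrix. The sketch below does that in plain Python. The function name classification_metrics and the counts (40 true positives, 10 false positives, 20 false negatives, 30 true negatives) are made up for illustration and do not come from any real model.

```python
def classification_metrics(tp, fp, fn, tn):
    """Compute accuracy, precision, recall and F1 from confusion-matrix counts.

    Hypothetical helper for illustration; assumes a binary problem and
    non-zero denominators.
    """
    accuracy = (tp + tn) / (tp + fp + fn + tn)   # correct / all predictions
    precision = tp / (tp + fp)                   # true positives / predicted positives
    recall = tp / (tp + fn)                      # true positives / actual positives
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Illustrative counts only
metrics = classification_metrics(tp=40, fp=10, fn=20, tn=30)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

With these counts, precision (0.8) is higher than recall (about 0.667), and F1 (about 0.727) falls between them, closer to the lower value, as expected from a harmonic mean.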
For Regression
- Mean Squared Error (MSE): Average squared difference between predicted and actual values.
- Root Mean Squared Error (RMSE): Square root of MSE.
- Mean Absolute Error (MAE): Average absolute difference between predicted and actual values.
- R-squared: Proportion of the variance in the dependent variable that is predictable from the independent variable(s).
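The regression metrics above follow directly from their definitions. The sketch below computes them in plain Python on four made-up values; the function name regression_metrics and the numbers are illustrative assumptions, and in practice sklearn.metrics provides mean_squared_error, mean_absolute_error and r2_score for the same quantities.

```python
import math


def regression_metrics(y_true, y_pred):
    """Compute MSE, RMSE, MAE and R-squared from two equal-length sequences.

    Hypothetical helper for illustration; assumes y_true is not constant
    (otherwise the total sum of squares is zero).
    """
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mse = sum(e ** 2 for e in errors) / n        # average squared error
    rmse = math.sqrt(mse)                        # same units as the target
    mae = sum(abs(e) for e in errors) / n        # average absolute error
    mean_true = sum(y_true) / n
    ss_res = sum(e ** 2 for e in errors)                 # residual sum of squares
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)   # total sum of squares
    r2 = 1 - ss_res / ss_tot
    return {"mse": mse, "rmse": rmse, "mae": mae, "r2": r2}


# Illustrative values only
reg = regression_metrics([3.0, 5.0, 2.0, 7.0], [2.0, 5.0, 4.0, 8.0])
for name, value in reg.items():
    print(f"{name}: {value:.4f}")
```

Because MSE squares each error, the single error of 2 contributes more to MSE (1.5) than it does to MAE (1.0), which is why MSE and RMSE are more sensitive to outliers.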
Implementing Evaluation Metrics
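The snippets in this section assume that model, X, y, X_test and y_test already exist. One possible setup, shown here only as a sketch, uses a synthetic binary dataset and a logistic regression classifier. The dataset, the choice of LogisticRegression, the split ratio and the random seeds are all assumptions made for illustration; substitute your own data and estimator.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic binary classification data (illustrative only)
X, y = make_classification(n_samples=500, n_features=10, random_state=0)

# Hold out 25% of the rows for evaluation, keeping class proportions similar
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

# Any classifier with predict and predict_proba would work here
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
```

Evaluating on X_test rather than X_train keeps the reported metrics from being inflated by examples the model has already fit.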
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
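# Hard class labels, used by accuracy, precision, recall and F1
y_pred = model.predict(X_test)
# Predicted probability of the positive class (column 1), used by ROC AUC in the binary case
y_pred_proba = model.predict_proba(X_test)[:, 1]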
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
print(f"Precision: {precision_score(y_test, y_pred)}")
print(f"Recall: {recall_score(y_test, y_pred)}")
print(f"F1 Score: {f1_score(y_test, y_pred)}")
print(f"ROC AUC: {roc_auc_score(y_test, y_pred_proba)}")
# Confusion Matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()
# Learning Curve
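import numpy as np
from sklearn.model_selection import learning_curve
# 5-fold cross-validated scores at 10 training-set sizes, from 10% to 100% of the data
train_sizes, train_scores, test_scores = learning_curve(
    model, X, y, cv=5, n_jobs=-1, train_sizes=np.linspace(0.1, 1.0, 10))
# Each score array has one row per training size and one column per fold, so average over axis 1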
plt.plot(train_sizes, np.mean(train_scores, axis=1), label='Training score')
plt.plot(train_sizes, np.mean(test_scores, axis=1), label='Cross-validation score')
plt.xlabel('Training examples')
plt.ylabel('Score')
plt.legend()
plt.show()