The confusion matrix shows, for each true label (rows), how many times the model predicted each class (columns). Entries on the diagonal are correct predictions; off-diagonal entries are misclassifications.
# Render the validation confusion matrix as an annotated heatmap:
# rows are the true labels, columns are the predicted classes.
plt.figure(figsize=(10, 8))
sns.heatmap(
    data=confusion_mtx,
    annot=True,
    fmt="g",  # general number format for the per-cell annotations
    xticklabels=class_names,
    yticklabels=class_names,
).set(
    # heatmap() draws on (and returns) the current Axes, so the labels and
    # title can be set directly on the returned object.
    xlabel="Prediction",
    ylabel="Label",
    title="Validation Confusion Matrix",
)
plt.show()