utility.py

# -*- coding: utf-8 -*-
import csv

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics, svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.neighbors import KNeighborsClassifier
from xgboost import XGBClassifier

def show_classifier_metrics(y_train, pred_train, y_test, pred_test,
                            print_classification_report=True,
                            print_confusion_matrix=True):
    """Plot confusion matrices and print evaluation metrics for training and test data."""
    if print_confusion_matrix:
        plt.figure(figsize=(6, 4))
        sns.heatmap(metrics.confusion_matrix(y_train, pred_train),
                    annot=True, fmt='d', cmap="viridis")
        plt.title('Confusion matrix | Training data')
        plt.show()
        plt.figure(figsize=(6, 4))
        sns.heatmap(metrics.confusion_matrix(y_test, pred_test),
                    annot=True, fmt='d', cmap="viridis")
        plt.title('Confusion matrix | Test data')
        plt.show()
    if print_classification_report:
        print('Classification report | Test data')
        print(metrics.classification_report(y_test, pred_test))
    print('Accuracy | Test data: %f%%' % (metrics.accuracy_score(y_test, pred_test) * 100))
    print('Accuracy | Training data: %f%%' % (metrics.accuracy_score(y_train, pred_train) * 100))

def get_classifier_metrics(y_test, pred_test):
    """Return accuracy (in percent) plus macro-averaged precision, recall and F1 score."""
    precision, recall, f1_score, _ = score(y_test, pred_test, average='macro')
    acc = metrics.accuracy_score(y_test, pred_test) * 100
    return acc, precision, recall, f1_score

def zero_rule_baseline(y):
    """Zero-rule baseline: the accuracy (in percent) of always predicting the majority class."""
    baseline = max(y.value_counts()) * 100 / len(y)
    return baseline

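# Example (hypothetical values): for a label Series in which the majority class
# covers 70 of 100 rows, zero_rule_baseline returns 70.0, the accuracy a
# trivial majority-class predictor would achieve.
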
def create_classifier(classifier, dataset):
    """Instantiate one of the supported classifiers: 'rf', 'svm', 'knn' or 'xgb'."""
    if classifier == 'rf':
        clf = RandomForestClassifier(n_estimators=300, oob_score=True,
                                     min_samples_split=5, max_depth=10,
                                     random_state=10)
    elif classifier == 'svm':
        clf = svm.LinearSVC(max_iter=1000, dual=False)
    elif classifier == 'knn':
        clf = KNeighborsClassifier(n_neighbors=10)
    elif classifier == 'xgb':
        param = {}
        # cahousing and cmc are three-class problems; the remaining datasets
        # use the binary default objective.
        if dataset in ('cahousing', 'cmc'):
            param['objective'] = 'multi:softmax'
            param['num_class'] = 3
        param['learning_rate'] = 0.1
        param['verbosity'] = 1
        param['colsample_bylevel'] = 0.9
        param['colsample_bytree'] = 0.9
        param['subsample'] = 0.9
        param['reg_lambda'] = 1.5
        param['max_depth'] = 5
        param['n_estimators'] = 100
        param['seed'] = 10
        clf = XGBClassifier(**param)
    else:
        # Fail loudly instead of returning an unbound `clf`.
        raise ValueError('Invalid classifier: %s' % classifier)
    return clf

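# Illustrative usage sketch, not part of the original module: `X_train`,
# `X_test`, `y_train`, `y_test` and the dataset name 'adult' are assumed to
# come from the surrounding experiment pipeline.
#
#     clf = create_classifier('rf', 'adult')
#     clf.fit(X_train, y_train)
#     acc, precision, recall, f1 = get_classifier_metrics(y_test, clf.predict(X_test))
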
def write_results(s, ml_res, anon_method, output_path, num=''):
    """Write the collected ML metrics to a CSV file, one row per run."""
    if len(ml_res.acc) <= 1:
        return
    # For OLA, every suppression level gets its own results file; `s` is the
    # suppression value appended to the file name (after stripping '.csv').
    if anon_method in ['ola']:
        output_path = output_path[:-4] + '_s_' + str(s) + '.csv'
    # newline='' because in Python 3 the csv writer otherwise emits an extra blank row:
    # https://stackoverflow.com/questions/16271236/python-3-3-csv-writer-writes-extra-blank-rows
    with open(output_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        writer.writerows(zip(*ml_res))
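
# Sketch of the expected `ml_res` shape, inferred from the attribute access
# `ml_res.acc` and the transpose via `zip(*ml_res)`: a namedtuple of
# equal-length metric lists, written out with one CSV row per run and one
# column per metric. The field names and values below are illustrative
# placeholders.
#
#     from collections import namedtuple
#     MLRes = namedtuple('MLRes', ['acc', 'precision', 'recall', 'f1'])
#     ml_res = MLRes(acc=[81.2, 80.7], precision=[0.79, 0.78],
#                    recall=[0.75, 0.74], f1=[0.77, 0.76])
#     write_results(0, ml_res, 'ola', 'results/adult.csv')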