{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to use `indigopy`\n",
    "\n",
    "Example code for how to use the `indigopy` package. The sample data used in this example notebook is derived from the [INDIGO](https://doi.org/10.1007/978-1-4939-8891-4_13), [INDIGO-MTB](https://doi.org/10.1128/mbio.02627-19), and [MAGENTA](https://doi.org/10.1371/journal.pcbi.1006677) publications.  \n",
    "\n",
    "## Set up environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import dependencies\n",
    "import pandas as pd\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.metrics import r2_score, classification_report\n",
    "from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Import package functions\n",
    "from indigopy.core import load_sample, featurize, classify"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: *E. coli*\n",
    "\n",
    "The following analysis and results were originally reported in the [INDIGO](https://doi.org/10.1007/978-1-4939-8891-4_13) paper.  \n",
    "- **Training dataset**: 105 two-way interactions between 15 antibiotics  \n",
    "- **Testing dataset**: 66 two-way interactions between the 15 antibiotics in the training set + 4 new antibiotics  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Defining INDIGO features: 100%|██████████| 105/105 [00:00<00:00, 424.50it/s]\n",
      "Defining INDIGO features: 100%|██████████| 66/66 [00:00<00:00, 441.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regression results:\n",
      "\tSpearman R = 0.6791\n",
      "\tSpearman p = 3.68e-10\n",
      "\tR2 = 0.4068\n",
      "Classification results:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           A       0.44      0.31      0.36        13\n",
      "           N       0.68      0.90      0.78        42\n",
      "           S       1.00      0.09      0.17        11\n",
      "\n",
      "    accuracy                           0.65        66\n",
      "   macro avg       0.71      0.43      0.44        66\n",
      "weighted avg       0.69      0.65      0.59        66\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Load sample data\n",
    "sample = load_sample('ecoli')\n",
    "\n",
    "# Define input arguments\n",
    "key             = sample['key']\n",
    "profiles        = sample['profiles']\n",
    "feature_names   = sample['feature_names']\n",
    "train_ixns      = sample['train']['interactions']\n",
    "train_scores    = sample['train']['scores']\n",
    "test_ixns       = sample['test']['interactions']\n",
    "test_scores     = sample['test']['scores']\n",
    "\n",
    "# Determine ML features\n",
    "train_data      = featurize(train_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "test_data       = featurize(test_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "X_train, X_test = train_data['feature_df'].to_numpy().transpose(), test_data['feature_df'].to_numpy().transpose()\n",
    "\n",
    "# Determine class labels\n",
    "thresh, classes = (-0.5, 2), ('S', 'N', 'A')\n",
    "train_labels    = classify(train_scores, thresholds=thresh, classes=classes)\n",
    "test_labels     = classify(test_scores, thresholds=thresh, classes=classes)\n",
    "\n",
    "# Train and apply a regression-based model\n",
    "reg_model = RandomForestRegressor()\n",
    "reg_model.fit(X_train, train_scores)\n",
    "reg_y = reg_model.predict(X_test)\n",
    "r1, p = spearmanr(test_scores, reg_y)\n",
    "r2 = r2_score(test_scores, reg_y)\n",
    "print('Regression results:')\n",
    "print('\\tSpearman R = {}'.format(round(r1, 4)))\n",
    "print('\\tSpearman p = {:.3g}'.format(p))\n",
    "print('\\tR2 = {}'.format(round(r2, 4)))\n",
    "\n",
    "# Train and apply a classification-based model\n",
    "class_model = RandomForestClassifier()\n",
    "class_model.fit(X_train, train_labels)\n",
    "class_y = class_model.predict(X_test)\n",
    "print('Classification results:')\n",
    "print(classification_report(test_labels, class_y))"
   ]
  },
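  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `thresh` and `classes` arguments above bin each continuous interaction score into a synergy (`S`), neutral (`N`), or antagonism (`A`) label. The snippet below is a minimal sketch of that kind of binning, not the `indigopy` implementation: the helper name `bin_scores` is hypothetical, and assigning a score that equals a threshold to the higher class is an assumption.\n",
    "\n",
    "```python\n",
    "from bisect import bisect_right\n",
    "\n",
    "\n",
    "def bin_scores(scores, thresholds=(-0.5, 2), classes=('S', 'N', 'A')):\n",
    "    # Hypothetical helper: one class per interval defined by the sorted thresholds.\n",
    "    # Assumption: a score equal to a threshold falls into the higher class.\n",
    "    if len(classes) != len(thresholds) + 1:\n",
    "        raise ValueError('need exactly one more class than thresholds')\n",
    "    cuts = sorted(thresholds)\n",
    "    return [classes[bisect_right(cuts, score)] for score in scores]\n",
    "\n",
    "\n",
    "example_labels = bin_scores([-1.2, 0.3, 2.5])\n",
    "print(example_labels)  # ['S', 'N', 'A']\n",
    "```\n",
    "\n",
    "With the *E. coli* thresholds, scores below -0.5 are labeled synergistic and scores above 2 antagonistic. How `classify` treats scores that fall exactly on a threshold should be checked against the package documentation."
   ]
  },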
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: *M. tuberculosis*\n",
    "\n",
    "The following analysis and results were originally reported in the [INDIGO-MTB](https://doi.org/10.1128/mbio.02627-19) paper.  \n",
    "- **Training dataset**: 196 two- to five-way interactions between 40 antibacterials  \n",
    "- **Testing dataset**: 36 two- to three-way interactions between the 13 antibacterials  \n",
    "- **Clinical dataset**: clinical outcomes for 57 two- to five-way interactions between 7 antibacterials  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Defining INDIGO features: 100%|██████████| 196/196 [00:00<00:00, 346.44it/s]\n",
      "Defining INDIGO features: 100%|██████████| 36/36 [00:00<00:00, 513.40it/s]\n",
      "Defining INDIGO features: 100%|██████████| 57/57 [00:00<00:00, 491.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regression results:\n",
      "\tSpearman R = 0.5264\n",
      "\tSpearman p = 0.000975\n",
      "\tR2 = 0.0905\n",
      "Classification results:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           A       0.67      0.38      0.48        16\n",
      "           N       0.00      0.00      0.00         1\n",
      "           S       0.73      0.84      0.78        19\n",
      "\n",
      "    accuracy                           0.61        36\n",
      "   macro avg       0.46      0.41      0.42        36\n",
      "weighted avg       0.68      0.61      0.63        36\n",
      "\n",
      "Clinical results:\n",
      "\tSpearman R = 0.3955\n",
      "\tSpearman p = 0.00232\n"
     ]
    }
   ],
   "source": [
    "# Load sample data\n",
    "sample = load_sample('mtb')\n",
    "\n",
    "# Define input arguments\n",
    "key             = sample['key']\n",
    "profiles        = sample['profiles']\n",
    "feature_names   = sample['feature_names']\n",
    "train_ixns      = sample['train']['interactions']\n",
    "train_scores    = sample['train']['scores']\n",
    "test_ixns       = sample['test']['interactions']\n",
    "test_scores     = sample['test']['scores']\n",
    "clinical_ixns   = sample['clinical']['interactions']\n",
    "clinical_scores = sample['clinical']['scores']\n",
    "\n",
    "# Determine ML features\n",
    "train_data      = featurize(train_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "test_data       = featurize(test_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "clinical_data   = featurize(clinical_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "X_train, X_test = train_data['feature_df'].to_numpy().transpose(), test_data['feature_df'].to_numpy().transpose()\n",
    "X_clinical      = clinical_data['feature_df'].to_numpy().transpose()\n",
    "\n",
    "# Determine class labels\n",
    "thresh, classes = (0.9, 1.1), ('S', 'N', 'A')\n",
    "train_labels    = classify(train_scores, thresholds=thresh, classes=classes)\n",
    "test_labels     = classify(test_scores, thresholds=thresh, classes=classes)\n",
    "\n",
    "# Train and apply a regression-based model\n",
    "reg_model = RandomForestRegressor()\n",
    "reg_model.fit(X_train, train_scores)\n",
    "reg_y = reg_model.predict(X_test)\n",
    "r, p = spearmanr(test_scores, reg_y)\n",
    "r2 = r2_score(test_scores, reg_y)\n",
    "print('Regression results:')\n",
    "print('\\tSpearman R = {}'.format(round(r, 4)))\n",
    "print('\\tSpearman p = {:.3g}'.format(p))\n",
    "print('\\tR2 = {}'.format(round(r2, 4)))\n",
    "\n",
    "# Train and apply a classification-based model\n",
    "class_model = RandomForestClassifier()\n",
    "class_model.fit(X_train, train_labels)\n",
    "class_y = class_model.predict(X_test)\n",
    "print('Classification results:')\n",
    "print(classification_report(test_labels, class_y))\n",
    "\n",
    "# Apply model to clinical data\n",
    "clinical_y = reg_model.predict(X_clinical)\n",
    "r, p = spearmanr(clinical_scores, clinical_y)\n",
    "print('Clinical results:')\n",
    "print('\\tSpearman R = {}'.format(round(-r, 4)))\n",
    "print('\\tSpearman p = {:.3g}'.format(p))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: *S. aureus*\n",
    "\n",
    "The following analysis and results were originally reported in the [INDIGO](https://doi.org/10.1007/978-1-4939-8891-4_13) paper.  \n",
    "- **Training dataset**: 171 two-way interactions between 19 antibiotics measured in *E. coli*  \n",
    "- **Testing dataset**: 45 two-way interactions between the 10 antibiotics measured in *S. aureus*  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Defining INDIGO features: 100%|██████████| 171/171 [00:00<00:00, 524.04it/s]\n",
      "Defining INDIGO features: 100%|██████████| 45/45 [00:00<00:00, 412.21it/s]\n",
      "Mapping orthologous genes: 100%|██████████| 1/1 [00:02<00:00,  2.55s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regression results:\n",
      "\tSpearman R = 0.4781\n",
      "\tSpearman p = 0.000894\n",
      "\tR2 = -1.1485\n",
      "Classification results:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           A       0.00      0.00      0.00         2\n",
      "           N       0.49      1.00      0.66        22\n",
      "           S       0.00      0.00      0.00        21\n",
      "\n",
      "    accuracy                           0.49        45\n",
      "   macro avg       0.16      0.33      0.22        45\n",
      "weighted avg       0.24      0.49      0.32        45\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\carol\\AppData\\Local\\Programs\\PythonCodingPack\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "# Load sample data\n",
    "sample = load_sample('saureus')\n",
    "\n",
    "# Define input arguments\n",
    "key             = sample['key']\n",
    "profiles        = sample['profiles']\n",
    "feature_names   = sample['feature_names']\n",
    "train_ixns      = sample['train']['interactions']\n",
    "train_scores    = sample['train']['scores']\n",
    "test_ixns       = sample['test']['interactions']\n",
    "test_scores     = sample['test']['scores']\n",
    "strains         = sample['orthology']['strains']\n",
    "orthology_map   = sample['orthology']['map']\n",
    "\n",
    "# Determine ML features\n",
    "train_data      = featurize(train_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "test_data       = featurize(test_ixns, profiles, feature_names=feature_names, key=key, silent=True, \n",
    "                            strains=strains, orthology_map=orthology_map)\n",
    "X_train, X_test = train_data['feature_df'].to_numpy().transpose(), test_data['feature_df'].to_numpy().transpose()\n",
    "\n",
    "# Determine class labels\n",
    "thresh, classes = (-0.5, 2), ('S', 'N', 'A')\n",
    "train_labels    = classify(train_scores, thresholds=thresh, classes=classes)\n",
    "test_labels     = classify(test_scores, thresholds=thresh, classes=classes)\n",
    "\n",
    "# Train and apply a regression-based model\n",
    "reg_model = RandomForestRegressor()\n",
    "reg_model.fit(X_train, train_scores)\n",
    "reg_y = reg_model.predict(X_test)\n",
    "r, p = spearmanr(test_scores, reg_y)\n",
    "r2 = r2_score(test_scores, reg_y)\n",
    "print('Regression results:')\n",
    "print('\\tSpearman R = {}'.format(round(r, 4)))\n",
    "print('\\tSpearman p = {:.3g}'.format(p))\n",
    "print('\\tR2 = {}'.format(round(r2, 4)))\n",
    "\n",
    "# Train and apply a classification-based model\n",
    "class_model = RandomForestClassifier()\n",
    "class_model.fit(X_train, train_labels)\n",
    "class_y = class_model.predict(X_test)\n",
    "print('Classification results:')\n",
    "print(classification_report(test_labels, class_y))"
   ]
  },
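  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the regression output above, the Spearman correlation is positive (0.4781) while R2 is negative (-1.1485). The two metrics answer different questions: Spearman correlation depends only on how well the predicted ordering matches the measured ordering, whereas R2 compares squared errors against a constant predictor at the mean of the test scores, so it drops below zero when predictions sit on the wrong scale or offset. The sketch below illustrates this with synthetic numbers and plain-Python helpers. The names `ranks`, `pearson`, `spearman_rho`, and `r2` are hypothetical and written only for this illustration; the notebook itself uses `scipy.stats.spearmanr` and `sklearn.metrics.r2_score`.\n",
    "\n",
    "```python\n",
    "def ranks(values):\n",
    "    # Rank positions (1 = smallest); assumes no tied values for simplicity.\n",
    "    order = sorted(range(len(values)), key=lambda i: values[i])\n",
    "    out = [0] * len(values)\n",
    "    for rank, idx in enumerate(order, start=1):\n",
    "        out[idx] = rank\n",
    "    return out\n",
    "\n",
    "\n",
    "def pearson(x, y):\n",
    "    # Pearson correlation: covariance divided by the product of the spreads.\n",
    "    mx, my = sum(x) / len(x), sum(y) / len(y)\n",
    "    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))\n",
    "    sx = sum((a - mx) ** 2 for a in x) ** 0.5\n",
    "    sy = sum((b - my) ** 2 for b in y) ** 0.5\n",
    "    return cov / (sx * sy)\n",
    "\n",
    "\n",
    "def spearman_rho(x, y):\n",
    "    # Spearman correlation is the Pearson correlation of the ranks.\n",
    "    return pearson(ranks(x), ranks(y))\n",
    "\n",
    "\n",
    "def r2(y_true, y_pred):\n",
    "    # Coefficient of determination: 1 - SS_res / SS_tot.\n",
    "    mean = sum(y_true) / len(y_true)\n",
    "    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))\n",
    "    ss_tot = sum((t - mean) ** 2 for t in y_true)\n",
    "    return 1 - ss_res / ss_tot\n",
    "\n",
    "\n",
    "# Synthetic data: predictions preserve the ordering but are on the wrong scale.\n",
    "y_true = [1.0, 2.0, 3.0, 4.0, 5.0]\n",
    "y_pred = [3 * t + 10 for t in y_true]\n",
    "print(spearman_rho(y_true, y_pred), r2(y_true, y_pred))\n",
    "```\n",
    "\n",
    "Here the predictions rank every sample correctly, so the Spearman correlation is 1 (up to floating-point rounding), yet R2 is strongly negative because every prediction is far from the measured value. Whether a similar shift in scale explains the *S. aureus* result is not tested in this notebook."
   ]
  },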
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: *A. baumannii*\n",
    "\n",
    "The following analysis and results were originally reported in the [MAGENTA](https://doi.org/10.1371/journal.pcbi.1006677) paper.  \n",
    "- **Training dataset**: 338 two- to three-way interactions between 24 antibiotics measured in *E. coli* cultured in various media conditions  \n",
    "- **Testing dataset**: 45 two-way interactions between the 8 antibiotics measured in *A. baumannii*  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Defining INDIGO features: 100%|██████████| 338/338 [00:00<00:00, 457.89it/s]\n",
      "Defining INDIGO features: 100%|██████████| 45/45 [00:00<00:00, 495.31it/s]\n",
      "Mapping orthologous genes: 100%|██████████| 1/1 [00:02<00:00,  2.90s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regression results:\n",
      "\tSpearman R = 0.5987\n",
      "\tSpearman p = 1.4e-05\n",
      "\tR2 = -0.4241\n",
      "Classification results:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           A       0.40      0.94      0.56        17\n",
      "           N       0.00      0.00      0.00        11\n",
      "           S       0.80      0.24      0.36        17\n",
      "\n",
      "    accuracy                           0.44        45\n",
      "   macro avg       0.40      0.39      0.31        45\n",
      "weighted avg       0.45      0.44      0.35        45\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\carol\\AppData\\Local\\Programs\\PythonCodingPack\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAEXCAYAAACDChKsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAsTAAALEwEAmpwYAAAklklEQVR4nO3deXxcVd3H8U+SpmyBtLaFUqEFQX7gg1qEqiWhdQHxAVtbFpdWyoNKBQTFjYIsD1R9oCiggBVBpVWsC8hWVgVKS0oBQRYR+MnWhbZIKU2gQCVN8vxx7pTpOElubmbmZjLf9+vVV+eeu/2SyczvnnPuPaeqo6MDERGRnqpOOwARESlPSiAiIpKIEoiIiCSiBCIiIokogYiISCJKICIikogSiAhgZrVmtsrMbou5/dlmdmmx4yo2M5tpZtOi14+Y2aCUQ5IyMiDtAET6iMnAY8C+ZraXuz+ZdkCl4O5nZb0enWIoUoaUQESCE4DfA88AJwNfibHPXma2CHgH8DBwgru/ZmafAr4LDAS2B+a6+5lm9hHgUnffGyB72czOBnaL/o0A7gf+DBwN7Aqc4u6/M7MdgJ8DOwDDgWXAZ9z9JTNbCswBPg6MBP7g7qdE5/kB8BywN7AF8FV3X2Bmc4DH3f1HZtYBDHP3l3v+65NKpCYsqXhm9h7gw8AfgbnAUWY2JMauuwOHA+8FqoAzzKwK+BZwtLvvFx33NDMbGuN4jcB/A3sBBwHvcfdxwInAOdE2nwOWuPtY4F3AG8BRWceoc/cDgP2Bk8xs16j8Q8AF7r4P8Evg7BjxiHRJCUQEjgdudvdX3P2vwPPEq4Fc6+5r3L0DuBI4KHo9gdAU9r/AhYTksk2M493h7i3u/iawCsj0xzxLqOXg7j8B7jWzbwKzCTWKuqxj3BBttxJ4KbMfsMzdH4le/y2rXCQxJRCpaGa2DTANaDSzpVEz0I7AV82stpvd27JeVwGt0fEeBj5A+KL+DtAare+I/s8YmHO8f+cst+aJdxYwE1gDXE5o5so+5ptZr7PP11m5SGJKIFLppgIvAyPcfRd334XQNFQHfKabfSea2WAzqwGmA7cC7wa2A85w9/nAeEKfQw3hS3+kmW0fNXVNShDvwcCP3f03hBrGQdGxRUpOCUQq3fHAhe6+qTbh7s3AxcDJZnaLmU3sZN8ngJuAvwPNwHmEO7luAp4ys78BE6Ptdnf3Jwgd4A8C9wGrE8Q7E/iRmT0EXAs0EfpiREquSsO5i4hIEqqBiIhIIkogIiKSiBKIiIgkogQiIiKJVMpQJlsAYwh3vbR1s62IiAQ1hOei/sp/PqdUMQlkDHBP2kGIiJSpAwi3jG+mUhLIaoB1616nvV23LYuIxFFdXcXgwdtAJ88sVUoCaQNob+9QAhER6bm8Tf/qRBcRkUSUQEREJBElEBERSUQJpB9obl7HeefNpKWlOe1QRKSCKIH0A/PnX8fTTzs33nht2qGISAVRAilzzc3raGpaSEdHB01Ni1QLEZGSUQIpc/PnX7fp1uT29nbVQkSkZJRAytySJYtpa9sIQFvbRpYsWZxyRCJSKVJ7kNDMtgPuBT7l7ktz1o0GrgDqgUXAce6+0cxGAlcB2wMOTHX39aWMu68ZO7aBRYvupq1tIzU1Axg7tiHtkESkQqRSAzGzDxHGVdmjk02uAk5y9z2AKuDYqHw2MNvd9yRMC3pmsWPt6yZMmEx1dRUA1dXVTJx4WMoRiUilSKsJ61jgq8Cq3BVmNgrYyt3vi4rmAEeaWS0wDrgmu7zokfZxgwYNprFxPFVVVTQ2jqO+flDaIYlIhUilCcvdvwxgZvlWj2DzgbtWAzsBQ4FX3X1jTnnFmzBhMitXvqDah4iUVF8cTLEqT1l7F+WxDRlSlyigvm7YsG254IIfph2GiFSYvphAVgLDs5Z3JDR1rQG2M7Mad2/LKo9t7dr1Go1XRCSm6uqqLi+8+9xtvO6+DNhgZpnb
iaYBt7p7K2FSqM9ml6cQooiI0IcSiJndYmb7RYtTgYvM7ElgG+DiqPwEYLqZPUGYIeuM0kcqIiIAVR0dFdGkswvwvJqwRETiy2rC2hVY+h/rSx2QiIj0D0ogIiKSiBKIiIgkogQiIiKJKIGIiEgiSiAiIpKIEoiIiCSiBCIiIokogYiISCJKICIikogSiIiIJKIEIiIiiSiBiIhIIkogIiKSiBKIiIgkogQiIiKJKIGIiEgiSiAiIpKIEoiIiCSiBCIiIokMSOOkZjYFOAMYCFzk7j/NWjcamJO1+TBgnbvvbWbTgFnAv6J1N7v76SUJWkRENlPyBGJm7wR+AOwL/Bu418wWuPsTAO7+CDA62nZr4AHguGj3McA33f13JQ5bRERypFEDORC4y91fATCza4AjgJl5tj0NWOjuTdHyGGB3MzsV+DtwkruvK0HMBbF48SKamhYW/LgtLc0A1NcPKvixARobx9PQMK4oxxaR8pVGH8gIYHXW8mpgp9yNzGwQMB04J2fbswk1lBXApUWKsay0tLTQ0tKSdhgiUmHSqIFU5Slrz1M2Fbje3V/KFLj75MxrMzsfeK4nJx4ypK4nmxfcpEmHMmnSoQU/7mmnnQbAueeeW/Bji4h0Jo0EshI4IGt5R2BVnu0mAf+XWTCzeuCL7n5RVFQFtPbkxGvXrqe9vaNHwZaD1tY2ANaseS3lSESkP6mururywjuNJqw7gI+b2bCok/xw4LbsDcysitDJviSreD1wipl9KFo+EbiuBPGKiEgeJU8g7r4SOB1YADwCzHP3B8zsFjPbL9psGPCWu2/I2q8N+AzwMzN7kpBgTilp8CIiskkqz4G4+zxgXk7ZIVmvXwKG59nvHuADRQ9QRES6pSfRRUQkkVRqICLlqByf49EzPFJMqoGIpEzP8Ui5Ug1EJKaGhnFFuZqfNet7AMyYcWbBjy1STKqBiIhIIkogIiKSiBKIiIgkogQiIiKJKIGIiEgiSiAiIpKIEoiIiCSiBCIiIokogYiISCJVHR39b4KlPHYBno87odS8eb9mxYplRQ+qUJYvD7GOHDkq5Uh6ZuedRzFlyrSCHrPc3jsoz/evGO+d9D1ZE0rtCizNXa+hTPJYsWIZ/vQz1Gw5KO1QYmlvqwHgmRUvpxxJfG0bmoty3BUrlrH0macYXlc+f9pbRzM6b3jxmZQjiefF9RvTDkH6iPL5lJVYzZaD2HrUx9MOo996Y9mdRTv28LoBHPO+dxTt+JXuysdeSTsE6SPUByIiIokogYiISCJKICIikogSiIiIJNJtJ7qZ1QGzgD2BI4FzgW+5+/qkJzWzKcAZwEDgInf/ac76s4AvAeuioivc/admNhq4AqgHFgHHubtuCRERSUGcGsjFQDOwA7AB2A64POkJzeydwA+ARuD9wHQze0/OZmOAz7n76OhfJsFcBZzk7nsAVcCxSeMQEZHeiZNA9nH304FWd38DmAqM7sU5DwTucvdX3P114BrgiJxt9gNmmNljZnapmW1pZqOArdz9vmibOYQakYiIpCDOcyBtOcs1ED35lMwIYHXW8mrgg5mFqMnsYeDbhCcf5wBnAjfl2W+nXsTRqZaWZto2NBf1WYVK17ahmZYWPYYkUs7ifIIXmdksYCszOxg4EVjQi3NW5SnblJCivpVDMstmdgHwK+DmrvaLI3okv1s1Nbq3oBRqaqoZNmzbgh6ztraGDQU9ouRTW1tT8PdOyk+cBDIDOBVoIfRd3A58rxfnXAkckLW8I7Aqs2BmI4ED3f1XUVEV0BrtN7yz/eKIOxZWXd121Gz5lp5EL6I3lt1JXd12rFnzWkGP29qaW2GWYmhtbSv4eyd9T9ZYWHnFSSAz3f00epc0st0BnG1mw4DXgcOB6Vnr3wTON7MFhCasrwLXufsyM9tgZg3uvhiYBtxaoJhERKSH4rTVfKqQJ3T3lcDphGawR4B57v6Amd1iZvu5+xrgK8B8wAk1kAui3acCF5nZk8A2hDvE
REQkBXFqIM+Z2Z+BJmDTsx/ufmHSk7r7PGBeTtkhWa//BPwpz36PktXhLiIi6YmTQDJDb+6aVVYRk4iIiEjnuk0g7n4MQPQcRq27l8ekBSIiUlRxhjLZHbiB8PxGtZm9DBzq7k8VOzgREem74nSiXwqc7+6D3b0e+D4wu7hhiYhIXxcngezg7nMzC+5+JTCseCGJiEg5iNOJPsDM3uHurwCY2VDUiS4iZWTx4kU0NS0s+HFbWpoBqK8fVPBjAzQ2jqehYVxRjl0IcRLIJcB9ZvaHaPmzwEXFC0lEpDy0tLQAxUsgfV2cu7AuN7OngU8SmryOd3eNMigiZaOhYVxRruRnzQoDdMyYcWbBj10Ouu0DiebvONLdZwC/AE4ys+Hd7CYiIv1cnE70uUDmlt1lwN2E0XFFRKSCxekDGeruFwO4+wbgx2Z2dHHDEkmmpaWZdes3cuVjr3S/sSTy4vqNDI46j6WyxamBDDCzEZkFM9uB/HN6iIhIBYlTA7kQeMTMbiPcvnsg8J2iRiWSUH39ILZ482WOed870g6l37rysVfYskLvOpLNdVsDiSZ2OpAwzeyDwMHRaLoiIlLB4k5Kvd7dLzKzw4HDzGyFu7cUM7C0ldOc6O0bwySu1QO2TDmS+No2NAND0w5DRHohzmCKP4/+/zHhocLbgV8CRxQ1shTtvPOotEPokeXLlwEwcudy+kIeWna/ZxHZXJwayL6ESZxOBea6+2lm9tfihpWuKVOmpR1Cj1T6w0wiko44d2FVu3s7cBBwV1S2TfFCEhGRchAngTxjZrcA7wLuNrPfAo8WNywREenr4iSQYwjzl49391bgHuBLRY1KRET6vDiDKb4OXJW1fFlRIxIRkbIQ9zbegjKzKcAZwEDgInf/ac76TwPnEJ54fx44xt3Xmdk0YBbwr2jTm9399NJFLiIiGSVPINHovj8g3N31b+BeM1vg7k9E67cDfgaMcfeVZjYTOBv4OjAG+Ka7/67UcYuIyObi9IEU2oHAXe7+StQ8dg2bP1NSC5zg7iuj5ceAkdHrMcA0M3vUzK4ys8Eli1pERDbTaQ3EzBbQxdS17v6xhOccAazOWl5NeM4kc9y1wPVRDFsRnj+5JGvb84AHgP8DLgWmxj3xkCF1CUPu22prawAYNmzblCNJX21tDUvLbDTe9W+1A1A3MI3ruZ57cf1G3l1bo7839Nnrqgnr0uj/yUA9YQ6QjcBRQHMvzplvJN/23AIzqyckkkfdfS6Au0/OWn8+8FxPTrx27Xra2/vfdO6trW0ArFnzWsqRpG/48J02/T7KxUvRSAJDh5fHk/m7EH7P+nvr/5+96uqqLi+8O00g7v4nADP7DrB/9DAhZnYzsKQXMa0EDsha3hFYlb2Bme1IGDLlLuAbUVk98EV3z8zHXgW09iIO6YfKbRQB0EgCUr5iTSgFbAm8ES1vC/RmrOw7gLPNbBjwOnA4MD2z0sxqgJuAP7r797P2Ww+cYmb3uvv9wInAdb2IQ0T6mHnzfs2KFcvSDiO2zDh0mYuAcrHzzqMKcrEVJ4HMA+43s2sJV/1HApcnPWF0Z9XpwALCbby/cPcHoqfdzwJ2BvYBasws07n+oLt/2cw+A/ws6hv5J1B+l5si0qkVK5bxz+ecmvqBaYcSS3tNaMJ6du3zKUcSX1vLWwU7VpwHCc8yswcJd091EG6jvbU3J43mE5mXU3ZI9PJBOrk7zN3vAT7Qm3OLSN9WUz+Q+nEjut9QEmlZtKr7jWKKe9vHi8A/gFPoXQe6iIj0E90mEDM7BriSkDzqgRvM7NhiByYiIn1bnBrIScBY4FV3f4nwBPnJxQxKRET6vjgJpM3dX80suPsKwvMgIiJSweIkkFfMbDTRU+lmNhUon8d8RUSkKOLcxnsycDWwm5mtAjYAny5mUCIi0vfFSSBPAe8H9gBqAAf656BSIiISW5wE8pC7fwB4MlNgZouB9xQtKhER6fO6Go33TsLw6Vub2atZqwYAfyt2
YCIi0rd1VQOZTBjz6leEedEzNrL5cOwiIlKBOr0Ly91fdfelwCRgirtnRjj7DrBV8UMTEZG+LM5tvFcCQ6LXzYTbea8oVkAiIlIe4nSiv9vdDwdw9xbgG2b2aHHDEhGRvi5ODaTWzLbLLJhZHflnFRQRkQoSpwbya8J8IFcTmq8OIzRriYhIBeu2BuLu5wIzCCPx1gGnuPuFxQ5MRET6tk4TSKbZyszeATQB5wA/INRGejOlrYiI9ANdNWHdTZj972WigRQjVdFyTfHCEhGRvq7TBBINX4K7x521UEREKkhXQ5lM62pHd/914cMREZFy0VUT1pHR/8OBPYG7CMOYfBR4mHB3ViJmNgU4AxgIXOTuP81ZP5rwsGI9sAg4zt03mtlI4Cpge8KowFPdfX3SOEpt8eJFNDUtLPhxly8PgwTMmvW9gh8boLFxPA0N44pybBEpX10NZTLB3ScALwGj3f3T0QOFo+nFjIRm9k5CZ3wjYZj46WaWO7LvVcBJ7r4Hoc8lMwf7bGC2u+8JPAicmTSO/qS+vp76+vq0wxCRChPnOZCR7v5sZsHdl5vZTr0454HAXe7+CoCZXQMcAcyMlkcBW7n7fdH2c4BzzOwXwDjC2FyZ8oWEW4zLQkPDOF3Ji0i/ESeBrDazcwhf2FXAdOC5XpxzBJuP5rsa+GA363cChgKvuvvGnPLYhgzRPFjS99TWhhsahw3bNuVI0pf5XUhx1dbWFOTvLU4COZrQdPQo0A7cxubDu/dUvmFQ2mOs726/bq1du5729o7uNxQpodbWNgDWrHkt5UjSl/ldSHG1trbF+nurrq7q8sK72wTi7quByWY22N3X9SjK/FYCB2Qt7wisylk/PM/6NcB2Zlbj7m159hMRkRLq9hkPC/4BPG5mI8zsSTPbsxfnvAP4uJkNM7OtgcMJtRoAonlHNphZQ1Q0DbjV3VuBe4DPZpf3Ig4REemFOE1YlwAnA+e7+yozuwS4nNCh3WPuvtLMTgcWEG7j/YW7P2BmtwBnufuDwFTgCjPblnDL8MXR7icAc83sDGA58PkkMYhI39TS0szG5n/TskiNC8WysfnftAxoLsix4iSQIe7+FzMDwN1nm9n03pzU3ecB83LKDsl6/Sibd6xnypcBH+nNuUVEpDDiJJAOM9uSaDwsMxuOxsESkSKorx/EyxvXUT9uRNqh9Fsti1ZRXz+oIMeKM87Vz4Dbge3N7FzgPsJdWSIiUsHi3IX1SzN7GjgUqAWmu/ufix6ZiIj0ad0mEDO7090/ThiTSkREBIjXhDXIzLYpeiQiIlJW4nSivw4sM7PHgE0j37r7xKJFJT3S3LyOyy67hOOP/1rBOsdERLoTJ4H8suhRSK/Mn38dTz/t3HjjtRx11BfTDkdEKkSXCcTM9gZeA+5395WlCUl6orl5HU1NC+no6KCpaRETJx6mWoiIlESnfSBmdgyh43wG8KiZfaJkUUls8+dft2mAyPb2dm688dqUIxKRStFVJ/rXgL3d/UPABODU0oQkPbFkyWLa2sII921tG1myZHHKEYlIpejyLix3XxX9vwQYVpKIpEfGjm2gpia0RNbUDGDs2IZu9hARKYyuEkjuxBmJp7GV4pkwYTLV1WGqlOrqaiZOPCzliESkUsR5DiRDMzH1QYMGDaaxcTxVVVU0No5TB7qIlExXd2G9z8xezVreOlquAjrcfbvihiZxTZgwmZUrX1DtQ0RKqqsEslvJopBeGTRoMKeeelbaYYgURFvLW2UzH0j7hjAFb/WW5TNAeVvLWzCkMMfqNIFEc2+IiJTMzjuPSjuEHlm+PHxNjhxSRnEPKdzvOc6T6CIiJTFlyrS0Q+iRWbO+B8CMGWemHEk6etKJLiIisokSiIiIJKIEIiIiiZS8D8TMRgJXAdsDDkx19/U52+wIXAkMB9qBb7v7XWZWC6wFnsvafF93bytJ8CIiskkaneizgdnu/nszOxM4kzBgY7YfAje5+6VmZsBCM3sn8D5g
ibsfXNqQRWDx4kU0NS0s+HEzd/JkOmQLqbFxPA0N4wp+XBEocQKJahDjgElR0RxgIf+ZQK4FFkSvnwG2BOqAMcAwM7svWjfD3Qv/iRYpofr6+rRDEEmk1DWQocCr7p4ZV2s1sFPuRu6ePSb5t4GH3b3FzDqA64HvAaOBW81sb3d/Oc7Jhwyp60XoUukmTTqUSZMOTTsM6UNqa8MDhMOGbZtyJOkoWgIxsyOBi3KK/5ln0/YujnEy8BVgPIC7/zxr9cNmdj/QANwQJ6a1a9dvmjtDRKS3WltD9+uaNa+lHElxVFdXdXnhXbQE4u5XA1dnl2U6wc2sJur43hHIO2aBmZ0PHAqMc/cXorKjgHvd/dlosyqgtUg/goiIdKGkt/G6eytwD/DZqGgacGvudlHN46NAQyZ5RN4PfCvaxoB9ouOJiEiJpXEX1gnAXDM7A1gOfB7AzI4DRgD/G/17Fbg75AkADgFmAr8ys8cJw8tPc/f+WXcUEenjSp5AokEaP5Kn/LKsxcFdHOKIQsckIiI9pyfRRUQkESUQERFJRAlEREQSUQIREZFElEBERCQRJRAREUlECURERBJRAukHmpvXcd55M2lpaU47FBGpIEog/cD8+dfx9NPOjTde2/3GIiIFogRS5pqb19HUtJCOjg6amhapFiIiJaMEUubmz79u0xD17e3tqoWISMkogZS5JUsW09YW5udqa9vIkiWLU45IRCqFEkiZGzu2gZqaMCZmTc0Axo5tSDkiEakUSiBlbsKEyVRXVwFQXV3NxImHpRyRiFQKJZAyN2jQYBobx1NVVUVj4zjq6welHZKIVIg0JpSSApswYTIrV76g2oeIlJQSSD8waNBgTj31rLTDEJEKoyYsERFJRAlEREQSUQIREZFESt4HYmYjgauA7QEHprr7+jzb/AN4Nir6l7sfbGYDgV8C+wFvAlPc/amSBS8iIpukUQOZDcx29z2BB4Ez82wzBpjn7qOjfwdH5V8DXnf3vYCTgbmlCFhERP5TSROImdUC44BroqI5wJF5Nh0D7G1mD5rZXWb23qj8UOC3AO6+CBga1VZERKTESt2ENRR41d03RsurgZ3ybLcB+LW7/9zMDgGuN7O9gBHRPuTsvzzOyYcMqUscuIhIrtraGgCGDds25UjSUbQEYmZHAhflFP8zz6btuQXufnbW61vM7FxgL6Aqzv6dWbt2/aaRa0VEequ1tQ2ANWteSzmS4qiururywrtoCcTdrwauzi6LmrDWmlmNu7cBOwKrcvc1s5MIfSBro6IqoBVYCQwHnonK8+4vIiLFV9I+EHdvBe4BPhsVTQNuzbPpeOBLAGY2HqgBngJuifbBzBqBDe4eq/lKREQKK427sE4AppvZE8ABwBkAZnacmc2Mtvk6cJCZPQ78CPi8u7cDlwBbmNk/gIuBo0oevYiIACk8B+Luy4CP5Cm/LOv1SuCgPNtsAI4uZnwiIhKPnkQXEZFElEBERCQRJRAREUlECURERBJRAhERkUSUQEREJBElEBERSUQJREREElECERGRRJRAREQkESUQERFJpKqjoyLmx9gFeF7zgYhUpsWLF9HUtLDgx12+fBkAI0eOKvixARobx9PQMK4ox44jaz6QXYGluetLPpiiiEh/UV9fn3YIqVINRERE8uquBqI+EBERSUQJREREElECERGRRJRAREQkESUQERFJRAlEREQSUQIREZFEKuVBwhoI9zSLiEg8Wd+ZNfnWV0oC2RFg8OBt0o5DRKQc7Qg8m1tYKU+ibwGMAVYDbSnHIiJSLmoIyeOvwL9zV1ZKAhERkQJTJ7qIiCSiBCIiIokogYiISCJKICIikogSiIiIJKIEIiIiiSiBiIhIIkogRWZme5tZh5kdHmPb6Wb2+RLFNcLMbinFufo7M9sleo8Pyilfama7JDjelWY2qof76IGuHjKzI8zsITN71Mz+bmbfSTumcqMEUnzHANcAx8XYdn/CU/NF5+6r3P2QUpyrQrQC
V5jZtgU41kcBDdxWRGb2TuAC4BPu/n5gLPA5M5uYbmTlpVLGwkqFmQ0AvgAcANxrZru5+7NmthT4DXAwsA0wDRgMTAQ+ZmargZXAJUAdsD1wgbtfbGb1wK+B3YHngJ2AycBy4MfAx4EO4DfuPsvMPgJ8F3gD2Av4OzAFGAHc7e67mNkU4BTCMC/PRzF/GDid8EW2GyEJtgCTorJD3P1fBf+lla9VwF8IX0rTs1eY2anAZwjDQtwOzABGEf3+o23OjjbfQHhvbjGzA4CHgPuB0YS/o68T3uN3AC8Dh7n7i8X7sfqtoUAtsDWw1t3Xm9nRQIOZ3evu+wNEZR8mvAefJPze3wX82d1PiLbp7P29jfAebQD+G7gMaCR8tjuA7wFfAu5x98ujYy0ATnX3+4v+GygA1UCK61Bgmbv/E7ge+ErWurXu/kHCH9V33f0O4EbgLHe/Hfgy8H13H0O4Iv1BtN9ZgLv7fwHnAO+Lyo8Ddo6WPwgcbmaHRuv2B04kJJCRhMSV7fuEK7F9gaeAPaPyDxFqUP8FHA+scff9gMeAzyX9pfRj3wIOzmnK+iSwL2Estn2AdwJTOzuAu59HSEaHuPvaqPhWdzdgO8J7s7+77wE809WxpHPu/ihwA/CcmT1gZrMICeByYLiZ7RZtejQwJ3q9P3A44TM2wczea2Zdvb8GfMHdDyR8PrchvH/HRNsD/IpwwUbUbLl9uSQPUAIptmOA30Wv/wD8j5kNjJZvi/5/nHBVk+tbwJZmdhohedRF5QcRai+4+4OEL3OAjwFz3L3N3d8Afku4UgV43N1fcPd24Mk855sPLDazHwI3ufsjWfutiI73MnBnVL6MUGOSLO7+KnAsmzdlHUhIxA8BfwP2IyTknrg/Ov4zhL+LL5vZBYRml7qudpTOufvxwC7Azwg1hvsItfm5wBfMbCSwQ9YX+r3u/lr0eXiO8Dnq6v19yd2XRq8PAn7r7h3uvoy3P0t3AyOivrJphNaFsqEmrCIxs+2BQ4D9zOzrhGafwYQrGAjVWghV2Xzt3X8E1hG+3H/P21f8beRP/LllVbz9/m7IKv+P87n7183sl4Qa01VRc8oLwFs5x9yY57ySxd3/bGaZpiwIV7U/dvcLAcxsEOH3OITN34daQj9KPm9G++5LuCC5kNCk2Ib6ShKJaud17v4H4ErgSjM7ltCkdALhAm8Dm3+h5/scdfb+DiV63yJ5P7fu3mFmc4HPE5rBclsH+jTVQIrnC8Cd7r6Tu+/i7qMINYmvdLHPRt7+0j+I0Jx1AzAewMxqCO3sU6Ll9wJ7E/6Y7wKONrMaM9uaUI1e0F2QZjbAzJ4GXnb3cwkfmH16/NNKtm8RvghGEN6Xo8ysLuoTux44AmgGBpvZMDPbgtDUlZH9d5BtPKHf5DLgCeATdDLRj3TrDeDczF1yZlYFvAd4OKohvEBotv1NN8fp7P3N9RdCJ32VmY0APkL43EJoIjsOWOHuq3rxM5WcEkjxHAPMzimbTeif2LKTfe4AvmtmRwBnA01m9jfCl9FSYFdCf8XuZvYYMBN4kXCl83PCH/2jwMPAje5+XXdBuvtGQr/KHWb2IDCOcIUrCWU1ZdUSapB/IjRDPQ48Asx19xbgh4R5Fu4AHsg6xE2ETvRdcw79B+D90Xt/F6H5MncbicHdFxD6EG8yMyf0/dUQPlMQav1PdPeF7u553988m14BvEa4iWUuoRn4zegYKwg3wczpzc+UBs0HUmbM7AvA8+6+OGqjXQjsFvVviEgvRTWJ3wBXu/u1BTrmoUCVu98U3Un5MKG/ZB1hwqaFwN7u/h+TNvVlqoGUn6eAn5jZw8C1wFeUPEQKI2rKWgW0E5qjCuUJ4FQze4SQLM5y91cIfaKPAqeVW/IA1UBERCQh1UBERCQRJRAREUlECURERBLRg4QieZjZxYRbmiE8H/A8bz8YNtbd38y7Y+HjmEMYEeBHpTifSE8ogYjk4e5fy7yOBr+cGg0dIyIRJRCRHoqG
ehlLuH//McKghkPd/cSs9UPd/cTonv+fAO8lPFh4J/Cd6AHO7GPWEUZfbiA8iX49YTTk7G2+SBjJYCBhHKbz3P1nZjacMILA0GjTm939zM7KC/RrEFEfiEhCo4APuPsXutnuIuChaKTjfQhf5t/Ms91MwggFexGGbm8gGsIGNiWYYwmj9O4DfBY4P1p9LPCcu3+AMOT7u6PE1Vm5SEGoBiKSzH25tYhOfAr4oJl9KVreqpPtDgS+6e5thIH3MuOf/Q9ANF/Fp4BDzezdhCSTGYn3NsLQJyMJw6Kc6u4tZpa3vIc/p0inVAMRSWZ91uvcEY4HZr2uAY5099HuPpow9PeJeY63kbcH18PMdjazIVnLOxHGWRoFNAFnZNa5+18JY2JdThie/AEz27+z8h7+nCKdUgIR6b01wL7RSKvbEEbJzbgd+Ea0bgvCpGH5EsgdhNGUq6PtriGrCYswbtIawiRjtxNqNkSjL58HnOnu1xNmLPwHsEdn5YX6oUWUQER677eEL/engVuAJVnrvkaYie7vhA73v/N230W2cwjzr2RGU74lZyC/PxNGW/ZoHLSR0Tl3J0xlPNrMHgceJNxy/LsuykUKQmNhiYhIIqqBiIhIIkogIiKSiBKIiIgkogQiIiKJKIGIiEgiSiAiIpKIEoiIiCSiBCIiIon8P+R4PEHarEqQAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load sample data\n",
    "sample = load_sample('abaumannii')\n",
    "\n",
    "# Define input arguments\n",
    "key             = sample['key']\n",
    "profiles        = sample['profiles']\n",
    "feature_names   = sample['feature_names']\n",
    "train_ixns      = sample['train']['interactions']\n",
    "train_scores    = sample['train']['scores']\n",
    "test_ixns       = sample['test']['interactions']\n",
    "test_scores     = sample['test']['scores']\n",
    "strains         = sample['orthology']['strains']\n",
    "orthology_map   = sample['orthology']['map']\n",
    "\n",
    "# Determine ML features\n",
    "train_data      = featurize(train_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "test_data       = featurize(test_ixns, profiles, feature_names=feature_names, key=key, silent=True, \n",
    "                            strains=strains, orthology_map=orthology_map)\n",
    "X_train, X_test = train_data['feature_df'].to_numpy().transpose(), test_data['feature_df'].to_numpy().transpose()\n",
    "\n",
    "# Determine class labels\n",
    "thresh, classes = (-0.5, 0), ('S', 'N', 'A')\n",
    "train_labels    = classify(train_scores, thresholds=thresh, classes=classes)\n",
    "test_labels     = classify(test_scores, thresholds=thresh, classes=classes)\n",
    "\n",
    "# Train and apply a regression-based model\n",
    "reg_model = RandomForestRegressor()\n",
    "reg_model.fit(X_train, train_scores)\n",
    "reg_y = reg_model.predict(X_test)\n",
    "r, p = spearmanr(test_scores, reg_y)\n",
    "r2 = r2_score(test_scores, reg_y)\n",
    "print('Regression results:')\n",
    "print('\\tSpearman R = {}'.format(round(r, 4)))\n",
    "print('\\tSpearman p = {:.3g}'.format(p))\n",
    "print('\\tR2 = {}'.format(round(r2, 4)))\n",
    "\n",
    "# Train and apply a classification-based model\n",
    "class_model = RandomForestClassifier()\n",
    "class_model.fit(X_train, train_labels)\n",
    "class_y = class_model.predict(X_test)\n",
    "print('Classification results:')\n",
    "print(classification_report(test_labels, class_y))\n",
    "\n",
    "# Visualize results\n",
    "df = pd.DataFrame({'x': test_labels, 'y': reg_y})\n",
    "df.replace({'A': 'Antagonism', 'N': 'Neutral', 'S': 'Synergy'}, inplace=True)\n",
    "sns.set(rc={'figure.figsize':(3, 4)})\n",
    "ax = sns.boxplot(x='x', y='y', data=df, order=['Antagonism', 'Neutral', 'Synergy'], )\n",
    "ax.set(title='A. baumannii', xlabel='True class', ylabel='Predicted score')\n",
    "plt.show()"
   ]
  },
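  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `UndefinedMetricWarning` shown in this example and in the *S. aureus* example appears because at least one class is never predicted, which makes its precision a 0/0 ratio. The sketch below shows the arithmetic on synthetic labels. The helper `precision_per_class` is hypothetical, and its `zero_division` argument is named only to mirror the wording of the warning.\n",
    "\n",
    "```python\n",
    "def precision_per_class(y_true, y_pred, zero_division=0.0):\n",
    "    # Precision = correct predictions of a class / all predictions of that class.\n",
    "    # zero_division is the value reported when a class is never predicted.\n",
    "    result = {}\n",
    "    for label in sorted(set(y_true) | set(y_pred)):\n",
    "        predicted = sum(1 for p in y_pred if p == label)\n",
    "        correct = sum(1 for t, p in zip(y_true, y_pred) if t == p == label)\n",
    "        result[label] = correct / predicted if predicted else zero_division\n",
    "    return result\n",
    "\n",
    "\n",
    "# Synthetic labels: the model never predicts 'N'.\n",
    "y_true = ['A', 'N', 'S', 'S']\n",
    "y_pred = ['A', 'A', 'S', 'S']\n",
    "print(precision_per_class(y_true, y_pred))  # {'A': 0.5, 'N': 0.0, 'S': 1.0}\n",
    "```\n",
    "\n",
    "As the warning text itself suggests, scikit-learn exposes a `zero_division` parameter for this case; passing it to `classification_report` should make the reported value explicit and suppress the warning."
   ]
  },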
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extension to CARAMeL: using metabolic flux data as drug profiles\n",
    "\n",
    "The following example shows how to use metabolic flux data to define the drug profiles, which was introduced in our [CARAMeL algorithm](https://github.com/sriram-lab/CARAMeL). Here we re-run the *E. coli* example using already saved and normalized metabolic flux data (`ecoli_flux_normalized.csv`), which was simulated using the [COBRA Toolbox](https://opencobra.github.io/cobratoolbox/stable/) in [MATLAB](https://www.mathworks.com/products/matlab.html). Please refer to the [CARAMeL publication](https://doi.org/10.1093/pnasnexus/pgac132) for further details.  \n",
    "- **Training dataset**: 105 two-way interactions between 15 antibiotics  \n",
    "- **Testing dataset**: 66 two-way interactions between the 15 antibiotics in the training set + 4 new antibiotics  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Defining INDIGO features: 100%|██████████| 105/105 [00:00<00:00, 530.35it/s]\n",
      "Defining INDIGO features: 100%|██████████| 66/66 [00:00<00:00, 515.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Regression results:\n",
      "\tSpearman R = 0.7238\n",
      "\tSpearman p = 6.57e-12\n",
      "\tR2 = 0.4797\n",
      "Classification results:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           A       0.53      0.69      0.60        13\n",
      "           N       0.72      0.79      0.75        42\n",
      "           S       0.67      0.18      0.29        11\n",
      "\n",
      "    accuracy                           0.67        66\n",
      "   macro avg       0.64      0.55      0.55        66\n",
      "weighted avg       0.67      0.67      0.64        66\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANwAAAEXCAYAAADGAipCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAsTAAALEwEAmpwYAAAps0lEQVR4nO3dd1yV5f/H8dfhABJOcmAuUlFyoGLuEQkqLiBHaSpqGmlqJmXiQHOVMzUtB2bZcKYoahqKpuHObxoucmQZhqCIgyXr+v3hz5MnwAPKuRl+no+Hj0f3ONd93afz5r7Pfa77/uiUUgohhCYs8rsDQjxNJHBCaEgCJ4SGJHBCaEgCJ4SGJHBCaOipDpyTkxOenp54e3sb/YuMjMzzbbm5uXHq1ClOnTrFqFGj8rx9U3x8fPjxxx8BCAgI4PTp05r3wVzWrl1LYGBgfncjRyzzuwP57euvv+bZZ5/VbHvOzs4sWrRIs+1l5dChQ/Tu3Ttf+5CXXn/99fzuQo499YF7HGlpacydO5d9+/ah1+txcXHhww8/RKfTMWvWLA4fPoxer6dBgwaMHz+eEiVKGF579OhRpk+fzvbt243aPHr0KB999BG2trYkJiayceNGDhw4wNKlS0lNTcXGxgZ/f39cXFy4dOkSEydOJCUlBaUUvXr1ol+/fixevJi4uDgmT54MkGkaYMGCBcTExDBmzBjmzJlDdHQ0S5cuRafTodfrGTt2LE2bNs20z8uXL2fz5s1YWlri4ODArFmzKFmyJJ9//jk//PADer2e6tWrM2nSJMqXL4+Pjw/16tXjyJEjxMbGMmDAAGJjYzl27BhJSUksXLgQJycnfHx8qFmzJqdPnyYuLg5vb2/DGcCyZcsIDQ3l3r17JCUl4e/vT4cOHVi8eDEnT54kJiYGJycnHBwcDPu5Zs0a1q1bh5WVFcWKFWPatGk4Ojpy4cIFpk2bxq1bt9DpdAwePJhXXnmFo0ePsmDBAqpWrcqFCxdISUlh8uTJtGjRwhwfHVBPsdq1a6tu3bopLy8vw7/hw4ebfN3XX3+t+vXrp5KSklR6erp699131ebNm9Wnn36qRo4cqVJSUlR6eroaN26cmjRpklJKqXbt2qnw8HB15MgR1bVr10xtHjlyRL3wwgsqMjJSKaXU5cuXVbdu3dTNmzeVUkqdP39etW7dWiUkJKjx48er5cuXK6WUiomJUaNHj1bp6elq0aJFaurUqYY2H57u37+/2rlzp1FflFLK3d1dnThxQimlVFhYmFq8eHGmvoWGhqqOHTuqW7duKaWU+vjjj9WSJUvUxo0bVe/evVVCQoJhe4MHDzZsb+TIkUoppU6ePKlq166t9uzZo5RS6qOPPlIBAQGG9Xx9fVVKSoq6ffu28vDwUHv37lWRkZHKx8dHJSUlKaWU2r59u+rWrZthOx4eHio1NdVoP9PS0lS9evVUdHS0UkqpzZs3q3Xr1qnU1FTl7u6uQkJClFJKXbt2TbVt21b9+uuv6siRI6pOnTrq7NmzSimlVq5cqfr162fqI/DYnvoj3OOcUh46dAhvb29sbGwAWLhwIQC9evXCz88PKysr4P73phEjRuS43eeee47KlSsDcPDgQWJiYhg0aJBhuU6n48qVK3To0AF/f3/Cw8Np2bIlAQEBWFg83tfxrl27MnLkSFxdXWndujW+vr6Z1jl8+DCdOnWidOnSAIwfPx6Ad999lx49emBrawvAgAEDWLZsGSkpKQB06NABgKpVqwLQtm1bAKpVq8axY8cM7ffu3RsrKyusrKzo1KkTBw4coF27dsyePZtt27bx119/8dtvv5GQkGB4TaNGjbC0NP746vV6OnXqRJ8+fXj55Zdp3bo1np6eXL58mXv37tGxY0cA7O3t6dixI2FhYTRv3pxKlSpRp04dAOrWrcvmzZsf673Miaf6osnj+u//6Bs3bhATE0NGRobR/IyMDFJTU3Pc7oMP7oPXtmzZ
kuDgYMO/DRs2UKtWLdq1a0dISAidO3fm3LlzeHp6cuXKFXQ6HeqhobE52bafnx9r166lfv36BAUF0bt370z7odfr0el0huk7d+4QGRlptK0HfU5LSzNMW1tbGy1/8Ifovx5+P5VSWFhYcObMGfr06UN8fDytW7fmzTffNHrNw+/Vw+bNm8eyZcuoVq0aK1asYOTIkZn258F2HvT1wR9OINN7mNckcI+hZcuWbN++nZSUFDIyMpgyZQo//PADbdu2Zd26daSmppKRkcHq1atp3br1Y22jRYsWHDx4kEuXLgGwf/9+vLy8uHfvHu+//z47duyga9eufPjhh5QoUYKoqCjs7Ow4c+YMSikSExM5cOBAlm3r9XrS0tJIS0vDzc2NxMREXn/9dT788EMuXbpkFBqAVq1asXv3buLj44H73w1XrVpFmzZtCAoKIjExEYBvv/2Wpk2bZgqaKVu3biUjI4Pbt2+zc+dO3Nzc+OWXX6hfvz5vvPEGzZo1Y8+ePaSnpz+ynZs3b+Lq6kqZMmUYNGgQo0eP5vfff6d69epYWVmxa9cuAKKjowkJCaFVq1a56mdeeOpPKQcOHJjpdOy9997D1dUVX19f+vTpg7u7u9HyPn36cPXqVXr06IFSimbNmuHj40NaWhqzZ8/mlVdeIS0tjQYNGjBp0qTH6letWrWYNm0a7733HkopLC0tWbp0Kba2tgwfPpyJEyeyfv169Ho97du3p1mzZtStW5ewsDA6duyIvb09Li4uWf61bt++PX5+fsyYMYMJEyYwZswYLC0t0el0fPzxx5kC4+rqysWLFw1XAx0dHZk+fTq2trZERUXx6quvkpGRgYODA/Pmzcv1viYnJ9OrVy8SEhLo27cvLVu2pFatWuzatYsuXbpgZWVFy5YtuX37tiH0WXn22Wd5++23GTRoEDY2Nuj1embMmIGVlRVLlixhxowZLF68mPT0dEaMGEGLFi04evRorvv7JHTKnMdPIUzw8fGhX79+dOrUKb+7ogk5pRRCQ3KEE0JDcoQTQkMSOCE0JIETQkMSOCE0VKh/h4uLSyAjQ675FCQHD4axbNlnpKamUrOmIxMmTKZ48X8Hb+/cuZ1161YbpuPj44mJiSY4eCfFixdn3rzZnDt3/8f7unXrM2aMP8WK2XDx4gXmzp1JcnISoGPYsBG0bHl/UMH+/Xv54ovlWFhYULJkScaNm0SVKlUN24iOvoav7yC++WYtZcrYAXDgwM9Mn/4hFStWNKy3ZMkXFC9e/InfAwsLHXZ2WbdTqK9SxsbGS+AKkLi4OHx8XmPp0pVUrVqNJUsWkZiYyJgx47JcPy0tjREjfOncuRuvvNKTwMAlREdfY+LEKSilmDZtElWrVuPNN4cxcGAfhgwZxksvvcwff1xk6NDB7Nixh4yMdLp2bc+qVWupUqUq69ev5vjxY8yd+ylwP+ArVy7n2rUotm8PpUyZMgAsW/YZtra2DBgwOM/fBwsLHWXLlsh6WZ5vTTy1fvnlCHXq1KVq1WoAdO/ei927d2Y7NvG771ZhZ2fHK6/0BKBRo8YMHDgECwsL9Ho9tWs7ce1aFAArV35H27auAFy9GknJkiWxsLAgPT0DpZRhBEpSUpJhpMyNG9cJC9tvCN/DTp8O53//O87gwf0ZPvxNTp78NW/fjGwU6lNKUbBER0dToYK9Ybp8+QokJCSQmJhgdFoJcOvWLdatW82XX35nmNes2b/3oF27FsWGDWsZO3YicH+As1KK117z5tq1KN599330ej22traMGTOet98eTKlSpcnIyGDp0pUAlCtXno8/nptlX0uVKo2HRxdcXdvx228nGT/+fVatWmPUf3OQwIk8o1TmUfkAFhb6TPO2bg2ibVtXKlWqnGlZRMQ5JkwYQ8+er9G6dVvDfJ1Ox4YNwfzzz1VGjPDl+edrUKaMHatWfcF3331P5cpV+P77dUycOJZVq9YY3eHwXw8HsWHDRtSv34BffjlK165eudnlXJNT
SpFn7O0rEht7wzB948Z1SpYsxTPPPJNp3T17dtOli2em+aGhIfj5jWDYsHcM369SU1MJDQ0x3GZTqVJlmjRpxvnzv3P06GGcnRtSuXIVAHr0eJXLly9x+/btbPt59+5dvvnmy/+c6qpMt12ZgwRO5JlmzVpw5sxp/v77CgBbtmwyfO962J07d7h69W+cnRsazf/pp1AWLpzHggWf0bHjv4OZraysWLFiKaGh92+vuXHjOr/+ehwXl8Y4Ob3AyZO/cvNmLABhYft47rlKhosjWbG1tSUo6Hv2798LwPnzEZw9e4bmzc1/u45cpRR56vDhAyxb9jlpaalUrlyFgICp/PPPVWbNmsGqVWsAOHfuDFOmTGT9+i1Gr+3Tpzvx8XcpV66CYZ6zc0Pef9+fS5cuMn/+bBISErCw0NGv3yDc3e/fUb5p0waCgjZgaWlFqVKl8PMbS40aNY3abtOmidFVyoiIsyxYMJfExAT0ektGjXqPxo2b5Ml78KirlBI4IfLYowInF01EjtiVtsbSulh+d6NASEu5R9ztlMd6rQRO5IildTH+N+dN0ys+BV4c+wXweIGTiyZCaMisR7ht27YZHmQ6aNAg+vXrZ1h27tw5xo37d8jPzZs3KV26dKYHpGrh0KEDLF/+GSkpKdSsWYvx4ydl+qH20qWLLFgwh4SEeCws9HzwwQReeKEOCxfO5eTJE4b1btyIoWzZcnz99TouXDjP/PmziI+Pp3jxEvj6vs2LLzY1uc1u3dobXTjo29eHjh07P7I9UTiY7aJJdHQ0r7/+OkFBQVhbW9OnTx/mz5+Po6NjpnWTkpJ49dVXmTJlCk2a5PxKUV5cNMnJ+L/k5GR69/Zm3LhJtGzZhrCwfSxdupg1azYZtRUV9Q8jRvgye/Z8atVyolcvT954w5euXb2Ijb3ByJFv8dlngVhY6LPd5pUrfzJ27HusWxeUqa/ZtVe2bLkneg9yonz5knJK+f9eHPsF16/fzXZ5voylPHToEC1atKBMmTLY2tri4eFhKCbxX8uXL6dp06a5Clteycn4v2PHjlCpUhVatmwDQJs2rkybNitTW7Nnz6B3777UquXErVu3iImJplOnrgCULVuOmjVrcfTo4Udu89SpcPR6C955ZygDB/bhq69WkJ6e/sj2ROFhtsDFxMRQvnx5w3SFChWIjo7OtN6dO3fYsGEDI0eONFdXHulR4/8e+PvvvyhbtiwzZ05jyBAfRo8ekekZiYcPHyQmJppevfoAUKZMGZ57rhI7d94/Rb56NZLw8JPExt545DbT09Np2rQ5n3yymM8+W8GxY4fZtGn9I9sThYfZvsNldaaa1di2bdu20b59e8qWLZvrbWR32M4NW1srbGysKF++JIDhIagVKpQ2PN23WDE9R44c4ptvvqFhw4aEhobi7z+an376yTAyfcuWDbz99jAqVixjaDswcDmzZ88mKGg9Tk5OtGv3MmXKlCAlJSXbbQ4ZMsCof76+b/Ltt98yYsTQbNt70I7QzuO+52YLnL29PcePHzdMx8TEUKFChUzrhYaGMnTo0MfaRl58hytRwo6rV/9nOCe/di2KkiVLkZCQTkLC/XnPPFOKatUcqFSpBtev36Vhw+akpaXx228RPP98deLi4jh58iRTpswyOrePjb3LtGlzDGP03n9/FE2atCQxMTHbbW7atA5Hx9o4OtYC4PbtRDIydFy/fjfb9h71fSKvSKiNFbjvcK1ateLw4cPcvHmTpKQkdu3axUsvvWS0jlKKM2fO4OLiYq5umJST8X8tWrQiKiqKiIhzAP9/75SO556rBMCpU7/xwgv1Mg3SnTPnY8LC9hnWuXz5Ek2aNH/kNv/44xIrVy4jPT2de/eS2bRpg2EIU3bticLDrEO7tm3bxvLly0lNTaVXr174+vri6+vLqFGjcHZ2JjY2Fi8vLw4ePPhY7efV0K6cjP87efJXPv/8U5KTk7Cysubdd8fQsGEjANas+YYrV/5i3Djjx5r/
8cdFZs2aQXJyEs88Y4uf31heeKFOttssVao0ycnJzJ8/m7NnT5OWlka7du15663h6HS6R7ZnbnKV8l9PcpVSxlKKHJHA/atA/iwghMisSI+lLFnKBptiWdcke9ok30vl7p3k/O7GU69IB86mmBV9x642veJTYM2cftxFApff5JRSCA1J4ITQkAROCA1J4ITQkAROCA1J4ITQkAROCA1J4ITQkAROCA1J4ITQkAROCA1J4ITQkAROCA1J4ITQkAROCA1J4ITQkAROCA2ZNXDbtm2jS5cudOjQgdWrM995/ccff+Dj44OXlxdDhgx5ZF1mIYoCswUuOjqaBQsWsGbNGoKDg1m/fj0XL140LFdK8fbbb+Pr68vWrVupU6cOgYGB5uqOEAVCvhXzOHPmDLa2toaHww4bNsyonJUQRVG+FfO4cuUK5cqVw9/fH09PTz788EPDs/yFKKryrZhHWloax44d47vvvsPZ2ZmFCxcya9YsZs3KXAYqO3lRzONpIvUB8k6hK+ZRvnx5HBwccHZ2BqBbt26MGjUqV9sw9eRl+YAZe5KiH/JeGitwT142VczDxcWFmzdvEhERAcDevXupV6+eubojRIFg1iOcn58fAwYMMBTzaNCggVExj88//5yAgACSkpKoWLEic+bMMVd3hCgQzPrkZU9PTzw9PY3mrVixwvDfDRs2ZOPGjebsghAFiow0EUJDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNCSBE0JDEjghNJSvxTw+++wz2rVrh7e3N97e3lmuI0RRYrandj0o5hEUFIS1tTV9+vShefPmODo6GtY5ffo08+fPx8XFxVzdEKJAybdiHnA/cCtWrMDT05Np06Zx7949c3VHiALBbEe4rIp5hIeHG6YTEhKoU6cO/v7+VK5cmXHjxrFkyRL8/PxyvA2pLZA78rjyvFPgaguYKuZRvHhxo4fCDh48mAkTJuQqcFJbIHektkDe0bS2QEpKisl17O3tuXHjhmH6v8U8/vnnH6OnLiulsLQ064Oghch32QYuJSWFDRs2EBISYjR/z549dO3a1WTDpop52NjYMHfuXP7++2+UUqxevZoOHTo8wa4IUfBle0iZPHkyFy5cID4+njt37uDm5saECRM4cuQIb775psmGc1LMY9q0abz99tukpqbSuHFj3njjjTzdOSEKmmwDd/z4cXbs2EFcXByjR4/myy+/pHLlymzfvp2qVavmqHFTxTw8PDzw8PB4zK4LUfhkG7jixYtjbW2Nvb09f/75Jz4+PgwfPlzLvglR5GT7He7hK4p2dnYSNiHyQI4CV6xYMU06I0RRl+0p5d9//82wYcMy/fcDy5YtM2/PhCiCsg3cxIkTDf8tFzaEyBvZBq579+5a9kOIp4LcDyeEhiRwQmhIAieEhkyOFk5KSuLHH3/k9u3bRncAyDAsIXLPZOD8/PyIiYmhdu3aRr/NCSFyz2Tg/vjjD3bs2CG3zgiRB0x+h6tYsaIW/RDiqWDysFW7dm0GDBhA27ZtsbGxMcyX73BC5J7JwCUkJODg4MCVK1e06I8QRZrJwM2cOVOLfgjxVDAZuBMnThAYGEhiYiJKKTIyMoiMjGTfvn0adE+IosXkRZOAgABcXFyIj4/H09OTEiVK0LFjRy36JkSRY/IIp9PpeOutt4iLi6NGjRp4eXnx+uuva9E3IYock0e44sWLA1CtWjUuXLhAsWLFSE9Pz1HjpmoLPLBv3z7c3Nxy2GUhCi+TR7gGDRowevRo3n33XYYOHcqff/6JXq832XBOagsA3Lhxg9mzZz/+HghRiJg8wk2YMIFBgwZRvXp1JkyYQEZGBvPmzTPZcE5qC8D974gjR458vN4LUciYDJxOp6Ny5cpcvHiRKlWq8Nprr5GRkWGy4axqC0RHRxut880331C3bl0aNmz4GF0X
ovDJ0e9wq1evpkSJf5+VrtPpOHz48CNfZ6q2wPnz59m1axerVq3i2rVruemzgRTzyB2pD5B3zFbMY/fu3YSFhWFnZ5erhu3t7Tl+/Lhh+r+1BX788UeuX79Oz549SU1NJSYmhr59+7JmzZocb0OKeeSOFPPIO2Yr5vH8889TqlSpXHfIVG2BUaNGERISQnBwMIGBgVSoUCFXYROiMDJ5hPPx8aF///40b97c6BYdUxc6clJbQIinjcnALV68mLJly3L3bu5PR0zVFnigSpUq7N27N9ftC1HY5OgRC1mFRAiReya/w9WqVYuIiAgt+iJEkWfyCBcTE0OvXr2oXLky1tbWhvnbtm0za8eEKIpMBm7kyJFGQRNCPD6TgZs3bx5btmzRoCtCFH0mv8PZ2Ng89kgQIYSxHF2ldHd3p2LFitja2hrmy3c4IXLPZOAeLlslhHgyJgPXrFkzbt26RVJSEkop0tPT5QleQjwmk4H79NNPCQwMBECv15Oamoqjo6OcUgrxGExeNAkODuann37Cw8ODXbt2MWvWrEx3bQshcsZk4J599lkqVKhAjRo1iIiIwNvbm7/++kuLvglR5JgMnKWlJVeuXKFGjRocP36ctLQ07ty5o0XfhChyTAZu6NChTJo0iZdffpndu3fz8ssv06JFCy36JkSRY/KiSbt27WjXrh0AW7Zs4a+//sLJycnsHROiKDJ5hEtISGDq1KkMHDiQe/fusWbNGhITE7XomxBFjsnAzZgxg1KlShEbG0uxYsWIj49n8uTJWvRNiCLHZODOnTuHn58flpaWPPPMM8ybN49z585p0TchihyTgbOwMF4lPT090zwhRM6YvGjStGlT5s6dS3JyMmFhYaxevZrmzZtr0TchihyTh6oxY8Zga2tLyZIlWbBgAU5OTowdOzZHjZsq5rF79248PT3p2rUr48aNIyUlJfd7IEQhYvIIZ2VlxYgRI+jfvz96vd7oCcyPYqqYR2JiItOmTWPz5s2UK1cOPz8/Nm/eTO/evZ9sj4QowEwe4S5dukTPnj1p1aoVzZo1o3///vzzzz8mGzZVzMPW1pa9e/dSrlw5EhMTiY2NfawHzgpRmJg8wk2YMIFXX32Vnj17opRi/fr1TJw4ka+++uqRr8uqmEd4eLjROlZWVuzfv5+xY8dSoUIF2rRpk6vOS22B3JHHlecds9UWSEpKok+fPoZpHx8fNmzYYLJhU8U8HnB1deXo0aPMnz+fKVOm8Mknn5hs+wGpLZA7Ulsg75ittkDVqlX59ddfDdPnz5+nSpUqJjtkb2/PjRs3DNP/LeZx69YtDhw4YJj29PTk999/N9muEIWZycBFR0fj4+NDjx49eO211+jRowcRERFZPsb8YaaKeSil+OCDDwzfB3fu3Enjxo3zYJeEKLhMnlLm9CeA/8pJMY/p06czdOhQdDodjo6OTJ069bG2JURh8cjAKaVo3LgxlpaWxMfHc+jQIWrXrs3zzz+fo8ZNFfNo37497du3z32vhSiksj2lvHjxIu7u7oSFhZGcnMyrr77KwoULGTx4MAcPHtSyj0IUGdkGbs6cOYwePZp27drxww8/oJRi+/btrF69msWLF2vZRyGKjGwDFxUVhZeXFwBHjx6lffv2WFhY8NxzzxEfH69ZB4UoSrIN3MN3BJw4cYKmTZsapu/du2feXglRRGV70aR06dJEREQQHx/P9evXDYH79ddfsbe316yDQhQl2QbuvffeY9CgQcTHxxvuGFi5ciXLli3j888/17KPQhQZ2QauUaNG/PzzzyQnJxsGFbu4uPD999/n+GcBIYSxR/4OZ21tbVSMUUaCCPFk5FkJQmhIAieEhiRwQmhIAieEhiRwQmhIAieEhiRwQmhIAieEhiRwQmhIAieEhiRwQmjIrIEzVVsgNDQUb29vvLy8GD58OLdv3zZnd4TId2YL3IPaAmvWrCE4OJj169dz8eJFw/L4+HimTJlCYGAgW7duxcnJSR7dIIo8
swXOVG2B1NRUpkyZYriZ1cnJiaioKHN1R4gCwWyBy6q2QHR0tGHazs7O8Ii85ORkAgMD5ZF5osgz+SDYx5XT2gJ3795l+PDhvPDCC3Tv3j1X25BiHrkj9QHyjtmKeTwue3t7jh8/bpj+b22BB/OGDBlCixYtmDBhQq63IcU8ckeKeeQdsxXzeFymagukp6czbNgwOnfuzMSJE7M8+glR1Jj1CPeo2gLXrl3j7NmzpKenExISAkD9+vX56KOPzNUlIfKd2QIHj64t4OzsTEREhDk3L0SBIyNNhNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDUnghNCQBE4IDeVrMY8H/P39CQoKMmdXhCgQ8q2Yx4N1hg0bZlRzQIiiLN+KecD9I6C7uzudO3c2VzeEKFDM9lzKrIp5hIeHG63z5ptvAvC///3PXN0QokDJ92IeT0KKeeSO1AfIO4WymMeTkmIeuSPFPPJOoSvmIcTTyGyBe7iYxyuvvEK3bt0MxTxOnTplrs0KUaDlWzGPh82aNcuc3RCiwJCRJkJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIYkcEJoSAInhIbytZjHuXPn6NmzJx4eHkycOJG0tDRzdkeIfJevxTw++OADJk2aREhICEopNmzYYK7uCFEgmO0xeQ8X8wAMxTxGjhwJwNWrV0lOTqZRo0YA9OjRg0WLFtG3b98cb8PCwvSj08vZFc9134uqnLxfj2Jdqmwe9aTwe9R7+ahl+VbM47/Ly5cvT3R0dK62YZeDMC0a/0qu2izKnrQWg/Ow2XnUk8Lvcd9Ls51SmirmoUWxDyEKGrM+6vzGjRuG6f8W8/jv8uvXr+d5sQ8hCpp8K+ZRuXJlihUrZqgNt2XLFin2IYo8ncrq3C6PbNu2jeXLl5OamkqvXr3w9fXF19eXUaNG4ezsTEREBAEBASQkJFC3bl1mzpyJtbW1ubojRL4za+CEEMZkpIkQGpLACaEhCZwQGpLACaEhCVwWnJycAIiMjMTJyYmDBw8aLXdzcyMyMpLIyEjq16+Pt7c33t7eeHh4MGrUKKPfFx+0BZCWlsbSpUvp3LkzXbp0wcPDg2XLlhkNAggPD2fAgAF4eHjQtWtXxo8fz82bN828x3kjPj6eqVOn0q1bN7y9vfHx8eHMmTOG5efPn8fJyYmQkBCj1/n4+NChQwfD++ju7s6gQYOM3keAnj17MmzYMKN5QUFBODk5sX37dqP5q1atwsnJicjIyGz7GxkZiZub2+Pu7mORwJlgZWXFpEmTiI+Pz3J5hQoVCA4OJjg4mB9//BEHBwdGjRqV5bpTp04lPDyc9evXs2PHDjZt2sThw4dZs2YNABcvXmT48OEMGzaMkJAQtm3bRvXq1RkwYAD37t0z2z7mhYyMDHx9fSldujRbtmwhODiYESNG4OvrS1xcHHA/HB4eHqxbty7T62fMmGF4H3fv3k2JEiX46quvDMt///13rKysiIiIICoqyui1FStWzBTi3bt3U6pUKTPs6ZORwJlQoUIFWrVqxezZpscR6nQ63nnnHS5cuEBERITRsmvXrrF161ZmzZpl+CCUKFGCyZMnU65cOQC++OILevfuTatWrQCwsLDgrbfewsbGhp07d+bxnuWto0ePEhMTw6hRo7C0vD9Et0WLFsycOZOMjAzS0tLYunUrfn5+nD17litXrmTbVmJiInFxcZQuXdowLygoiNatW+Pu7p7prpKmTZty+vRpEhMTgfsD44sXL07JkiUN6wQGBtK9e3e8vLyYM2dOlkMLtSCBy4Fx48Zx4MCBTKeWWbG2tsbBwYE//vjDaH54eDg1a9Y0+hAB1KxZEw8PDwBOnTpFgwYN
MrX54ANVkJ09exZnZ2csLIw/Uq6urpQtW5Z9+/ZRqVIlqlevTvv27TMd5QICAvDy8qJNmzaGPzqDBg0CIDU1la1bt9K5c2c6d+7Mxo0bje6dtLS0pE2bNuzfvx+AnTt30rlzZ8Pyn3/+mdOnT7Nx40a2bNlCdHQ0W7duNdM78WgSuBwoUaIE06dPf+Sp5cN0Oh02NjZZzn/gxx9/xNvbG09PT3r27GlYntVNuKmpqU/Qe21YWFg88qgRFBREt27dAOjSpQubN28mJSXFsHzGjBls3bqVRYsWcfv2bVxdXQ2jjvbv30/58uVxdHTkxRdfxMLCgp9++smo/c6dOxtOK0NDQ2nfvr1h2eHDhwkPD6dHjx50796d06dPZ7o3UysSuBxq06ZNjk4tU1JSuHz5Mo6Ojkbz69Wrx6VLlwyB7dSpE8HBwSxdutTwHadBgwacPHkyU5snTpygfv36ebMjZlK/fn3Onj2bKXTz589nx44d/Pzzz3z55Ze4ubkREBDAnTt32LVrV6Z2GjdujI+PD/7+/oY/Pps2bSIqKgo3Nzfc3d2Jj4/PdIRs3rw54eHhnD9/Hjs7O6PTyfT0dAYOHGj4jvj9999nuviiFQlcLjw4tYyJiclyeUZGBosXL6Zhw4ZUq1bNaFnlypXx8vLC39+fO3fuAPc/CPv27TOchg0dOpRNmzYZTl2VUixZsoTk5GSjU6SCqEmTJpQtW5bPPvuM9PR0AMLCwggKCiIyMpIWLVrw888/s3fvXn766SeGDRvG+vXrs2zrjTfeICkpiXXr1nHjxg0OHjzI9u3b2bt3L3v37mXLli0cOXKEv//+2/AavV5PmzZtmDx5Ml26dDFqr0WLFgQHB5OQkEBaWhojRozIdJFFK2a7AbUoenBqOWTIEMO8mJgYvL29gfuBq1OnDp988kmWr58yZQpfffUVAwYMQClFSkoKjRo1YsWKFQA4ODiwcuVK5s2bx4wZM0hPT+fFF1/k22+/pVixYubfwSeg0+lYsmQJM2fOpFu3blhaWmJnZ0dgYCD+/v74+fkZrd+3b1+++OILLl26lKkta2trRo8ezccff8ytW7dwdXXF3t7esLxq1aq4ubmxfv16atSoYZjfuXNngoODM13qd3NzIyIigtdee4309HTatm1L9+7duXr1Kv/88w8uLi6GdV988UW++OKLvHpbMpHBy0JoSE4phdCQBE4IDUnghNCQBE4IDUnghNCQBK6QSU1NpU2bNkY/TRw9etQwiiMv7NmzhxkzZgCwb98+Pv30U+D+aJGhQ4fm2XaeRvI7XCGze/dunJycOHPmDJcuXaJmzZp5vg13d3fc3d2B++M7b9++nefbeFpJ4AqZtWvX0qVLFxwcHPj666+ZNm2a0fKbN28yfvx4rly5QpkyZShfvjy1atXinXfe4fjx48yZM4ekpCSsrKwYPXo0L730EkFBQWzcuJGkpCRKlChB9+7dCQkJYfjw4axbt4709HRKliyJg4MD169f56233iIqKgq9Xs8nn3xCzZo18fHxoV69ehw5coTY2FgGDBhAbGwsx44dIykpiYULFxrdG/jUUqLQuHDhgqpfv76Ki4tTv/32m2rQoIG6efOmOnLkiOratatSSik/Pz81Z84cpZRS0dHRqnXr1mrRokXq5s2bqmXLlurkyZNKKaXOnz+vmjVrpq5cuaI2bdqkmjZtqu7evauUUmrTpk3qrbfeUkoptWjRIjV16lTD/CZNmqg///xTKaXU9OnT1fjx45VSSvXv31+NHDlSKaXUyZMnVe3atdWePXuUUkp99NFHKiAgQIu3qMCT73CFyNq1a3n55ZcpU6YMDRo0oEqVKpnGI+7fv5/evXsD9+/l69SpE3D/9qBq1arRsGFDAGrVqkXjxo05duwYcP/O9BIlTD8vv0GDBjg4OABQp04do7vRO3ToANwfegXQtm1bAKpVqyanpf9PTikLicTERLZs2UKxYsUMYwXj4+NZvXo1zs7OhvUs
LS2NRuw/GBidkZGRqU2lFGlpaVhZWWFra5ujfjy4uRTuj598eFv/fYivlZVVjtp8msgRrpDYtm0bdnZ2hIWFGUbNh4aGkpiYSGxsrGE9V1dXNm7cCEBcXByhoaHodDoaNmzI5cuXDRWMLly4wC+//EKzZs0euV29Xi+FMvOQHOEKibVr1/LGG2+g1+sN80qVKoWPjw9ff/21Yd748eMJCAjA09OTMmXKUKlSJWxsbHj22Wf59NNPmT59OsnJyeh0OmbOnEn16tU5ceJEtttt2bIl77zzDlZWVtSrV8+s+/g0kLsFipjVq1dTt25dXFxcSElJoW/fvrzzzju4urrmd9cEcoQrchwdHZk+fToZGRmkpqbSqVMnCVsBIkc4ITQkF02E0JAETggNSeCE0JAETggNSeCE0JAETggN/R9cbcqo9XwCSQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 216x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load sample data\n",
    "sample = load_sample('ecoli')\n",
    "\n",
    "# Define input arguments\n",
    "key             = sample['key']\n",
    "train_ixns      = sample['train']['interactions']\n",
    "train_scores    = sample['train']['scores']\n",
    "test_ixns       = sample['test']['interactions']\n",
    "test_scores     = sample['test']['scores']\n",
    "\n",
    "# Define drug profile information\n",
    "df              = pd.read_csv('ecoli_flux_normalized.csv')\n",
    "profiles        = df.loc[:, ~df.columns.isin(['rxn', 'rxnName', 'subSystem'])].to_dict('list')\n",
    "feature_names   = df['rxn'].tolist()\n",
    "\n",
    "# Determine ML features\n",
    "train_data      = featurize(train_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "test_data       = featurize(test_ixns, profiles, feature_names=feature_names, key=key, silent=True)\n",
    "X_train, X_test = train_data['feature_df'].to_numpy().transpose(), test_data['feature_df'].to_numpy().transpose()\n",
    "\n",
    "# Determine class labels\n",
    "thresh, classes = (-0.5, 2), ('S', 'N', 'A')\n",
    "train_labels    = classify(train_scores, thresholds=thresh, classes=classes)\n",
    "test_labels     = classify(test_scores, thresholds=thresh, classes=classes)\n",
    "\n",
    "# Train and apply a regression-based model\n",
    "reg_model = RandomForestRegressor()\n",
    "reg_model.fit(X_train, train_scores)\n",
    "reg_y = reg_model.predict(X_test)\n",
    "r, p = spearmanr(test_scores, reg_y)\n",
    "r2 = r2_score(test_scores, reg_y)\n",
    "print('Regression results:')\n",
    "print('\\tSpearman R = {}'.format(round(r, 4)))\n",
    "print('\\tSpearman p = {:.3g}'.format(p))\n",
    "print('\\tR2 = {}'.format(round(r2, 4)))\n",
    "\n",
    "# Train and apply a classification-based model\n",
    "class_model = RandomForestClassifier()\n",
    "class_model.fit(X_train, train_labels)\n",
    "class_y = class_model.predict(X_test)\n",
    "print('Classification results:')\n",
    "print(classification_report(test_labels, class_y))\n",
    "\n",
    "# Compare INDIGO vs. CARAMeL (r1 is assumed to hold the Spearman R from the INDIGO E. coli example above)\n",
    "df = pd.DataFrame({'Algorithm': ['INDIGO', 'CARAMeL'], 'Spearman R': [r1, r]})\n",
    "sns.set(rc={'figure.figsize':(3, 4)})\n",
    "ax = sns.barplot(x='Algorithm', y='Spearman R', data=df, order=['INDIGO', 'CARAMeL'])\n",
    "ax.set(title='E. coli results comparison')\n",
    "ax.bar_label(ax.containers[0])\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "0ac2a46207b2ff734e5406bb8bd0909b0a981f84a860af7db5bce33c6bd25d0b"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
