{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# pH prediction using the Central Park soil dataset \n\n\nThe next microbiome example considers the [Central Park Soil dataset](./examples/CentralParkSoil) from [Ramirez et al.](https://royalsocietypublishing.org/doi/full/10.1098/rspb.2014.1988). The sample locations are shown in the Figure on the right.)\n\nThe task is to predict pH concentration in the soil from microbial abundance data.\n\nThis task is also done in `Tree-Aggregated Predictive Modeling of Microbiome Data <https://www.biorxiv.org/content/10.1101/2020.09.01.277632v1>`_.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Import the package\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import sys, os\nfrom os.path import join, dirname, abspath\n\nclasso_dir = dirname(dirname(abspath(\"__file__\")))\nsys.path.append(classo_dir)\n\nfrom classo import classo_problem\nimport matplotlib.pyplot as plt\nimport numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load data\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "data_dir = join(classo_dir, \"examples/CentralParkSoil\")\ndata = np.load(join(data_dir, \"cps.npz\"))\n\nx = data[\"x\"]\nlabel = data[\"label\"]\ny = data[\"y\"]\n\nA = np.load(join(data_dir, \"A.npy\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Preprocess: taxonomy aggregation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "label_short = np.array([l.split(\"::\")[-1] for l in label])\n\npseudo_count = 1\nX = np.log(pseudo_count + x)\nnleaves = np.sum(A, axis=0)\nlogGeom = X.dot(A) / nleaves\n\nn, d = logGeom.shape\n\ntr = np.random.permutation(n)[: int(0.8 * n)]"
      ]
    },
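    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A small sketch of the aggregation above, on made-up numbers rather than the Central Park data. It assumes, as the preceding code suggests, that `A` is a leaf-by-node indicator matrix; the names `toy_counts` and `toy_A` are hypothetical and chosen so that they do not overwrite the variables of this notebook. Dividing the summed logs by the number of leaves gives, for each node, the log of the geometric mean of the pseudo-counted abundances below it. In the second toy sample all counts are equal, so every node, leaf or internal, gets the same value log(4).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Toy illustration only: hypothetical counts and tree, not the Central Park data\nimport numpy as np\n\n# 2 samples x 3 leaf taxa\ntoy_counts = np.array([[0.0, 9.0, 99.0], [3.0, 3.0, 3.0]])\n\n# Columns: leaf 0, leaf 1, leaf 2, and one internal node grouping leaves 0 and 1\ntoy_A = np.array([[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 1, 0]])\n\ntoy_log = np.log(1 + toy_counts)  # same pseudo-count of 1 as above\ntoy_nleaves = toy_A.sum(axis=0)  # number of leaves under each node\ntoy_logGeom = toy_log.dot(toy_A) / toy_nleaves  # mean of logs per node\n\nprint(toy_nleaves)\nprint(toy_logGeom)"
      ]
    },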
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Cross validation and Path Computation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "problem = classo_problem(logGeom[tr], y[tr], label=label_short)\n\nproblem.formulation.w = 1 / nleaves\nproblem.formulation.intercept = True\nproblem.formulation.concomitant = False\n\nproblem.model_selection.StabSel = False\nproblem.model_selection.PATH = True\nproblem.model_selection.CV = True\nproblem.model_selection.CVparameters.seed = (\n    6  # one could change logscale, Nsubset, oneSE\n)\nprint(problem)\n\nproblem.solve()\nprint(problem.solution)\n\nselection = problem.solution.CV.selected_param[1:]  # exclude the intercept\nprint(label[selection])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prediction plot\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "te = np.array([i for i in range(len(y)) if not i in tr])\nalpha = problem.solution.CV.refit\nyhat = logGeom[te].dot(alpha[1:]) + alpha[0]\n\nM1, M2 = max(y[te]), min(y[te])\nplt.plot(yhat, y[te], \"bo\", label=\"sample of the testing set\")\nplt.plot([M1, M2], [M1, M2], \"k-\", label=\"identity\")\nplt.xlabel(\"predictor yhat\"), plt.ylabel(\"real y\"), plt.legend()\nplt.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Stability selection\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "problem = classo_problem(logGeom[tr], y[tr], label=label_short)\n\nproblem.formulation.w = 1 / nleaves\nproblem.formulation.intercept = True\nproblem.formulation.concomitant = False\n\n\nproblem.model_selection.PATH = False\nproblem.model_selection.CV = False\n# can change q, B, nS, method, threshold etc in problem.model_selection.StabSelparameters\n\nproblem.solve()\n\nprint(problem, problem.solution)\n\nselection = problem.solution.StabSel.selected_param[1:]  # exclude the intercept\nprint(label[selection])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prediction plot\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "te = np.array([i for i in range(len(y)) if not i in tr])\nalpha = problem.solution.StabSel.refit\nyhat = logGeom[te].dot(alpha[1:]) + alpha[0]\n\nM1, M2 = max(y[te]), min(y[te])\nplt.plot(yhat, y[te], \"bo\", label=\"sample of the testing set\")\nplt.plot([M1, M2], [M1, M2], \"k-\", label=\"identity\")\nplt.xlabel(\"predictor yhat\"), plt.ylabel(\"real y\"), plt.legend()\nplt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}