{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Ocean salinity prediction based on marine microbiome data\n\nWe reproduce an example of predicting ocean salinity from ocean microbiome data\nthat was introduced in `this article <https://www.biorxiv.org/content/10.1101/2020.09.01.277632v1.full>`_,\nwhere the R package `trac <https://github.com/jacobbien/trac>`_ (which uses c-lasso)\nwas used.\n\nThe data originally come from `trac <https://github.com/jacobbien/trac>`_\nand were then preprocessed in Python in this `notebook <https://github.com/Leo-Simpson/c-lasso/examples/Tara/preprocess>`_.\n\n\n\nBien, J., Yan, X., Simpson, L. and M\u00fcller, C. (2020).\nTree-Aggregated Predictive Modeling of Microbiome Data:\n\n\"Integrative marine data collection efforts such as Tara Oceans (Sunagawa et al., 2020)\nor the Simons CMAP (https://simonscmap.com)\nprovide the means to investigate ocean ecosystems on a global scale.\nUsing Tara\u2019s environmental and microbial survey of ocean surface water (Sunagawa, 2015),\nwe next illustrate how trac can be used to globally connect environmental covariates\nand the ocean microbiome. As an example, we learn a global predictive model of ocean salinity\nfrom n = 136 samples and p = 8916 miTAG OTUs (Logares et al., 2014).\ntrac identifies four taxonomic aggregations,\nthe kingdom bacteria and the phylum Bacteroidetes being negatively associated\nand the classes Alpha and Gammaproteobacteria being positively associated with marine salinity.\"\n"
      ]
    },
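    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import sys\nfrom os.path import join, dirname, abspath\n\n# Notebooks do not define __file__, so the literal string \"__file__\" is resolved\n# against the working directory; two dirname calls then give its parent directory,\n# which is assumed to be the c-lasso repository root.\nclasso_dir = dirname(dirname(abspath(\"__file__\")))\nsys.path.append(classo_dir)\nfrom classo import classo_problem\nimport matplotlib.pyplot as plt\nimport numpy as np"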
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load data\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "data_dir = join(classo_dir, \"examples/Tara\")\n\ndata = np.load(join(data_dir, \"tara.npz\"))\n\nx = data[\"x\"]\nlabel = data[\"label\"]\ny = data[\"y\"]\ntr = data[\"tr\"]\n\nA = np.load(join(data_dir, \"A.npy\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Preprocess: taxonomy aggregation\n\n"
      ]
    },
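    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Keep only the last taxonomic level of each label for display\nlabel_short = np.array([l.split(\"::\")[-1] for l in label])\n\n# Log-transform the counts, with a pseudo-count to avoid log(0)\npseudo_count = 1\nX = np.log(pseudo_count + x)\n# Column sums of A: the number of leaves under each node, assuming A is a binary leaf-to-node matrix\nnleaves = np.sum(A, axis=0)\n# Average log-count over the leaves of each node (log of their geometric mean)\nlogGeom = X.dot(A) / nleaves"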
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Cross-validation and path computation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "problem = classo_problem(logGeom[tr], y[tr], label=label_short)\n\nproblem.formulation.w = 1 / nleaves\nproblem.formulation.intercept = True\nproblem.formulation.concomitant = False\n\nproblem.model_selection.StabSel = False\nproblem.model_selection.PATH = True\nproblem.model_selection.CV = True\n# one could also change logscale, Nsubset, oneSE in CVparameters\nproblem.model_selection.CVparameters.seed = 6\nprint(problem)\n\nproblem.solve()\nprint(problem.solution)\n\n\nselection = problem.solution.CV.selected_param[1:]  # exclude the intercept\nprint(label[selection])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prediction plot\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Test indices are all sample indices that are not in the training set\nte = np.setdiff1d(np.arange(len(y)), tr)\nalpha = problem.solution.CV.refit\n# alpha[0] is the intercept, alpha[1:] are the coefficients\nyhat = logGeom[te].dot(alpha[1:]) + alpha[0]\n\ny_max, y_min = max(y[te]), min(y[te])\nplt.plot(yhat, y[te], \"bo\", label=\"sample of the testing set\")\nplt.plot([y_max, y_min], [y_max, y_min], \"k-\", label=\"identity\")\nplt.xlabel(\"predictor yhat\")\nplt.ylabel(\"real y\")\nplt.legend()\nplt.tight_layout()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Stability selection\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "problem = classo_problem(logGeom[tr], y[tr], label=label_short)\n\nproblem.formulation.w = 1 / nleaves\nproblem.formulation.intercept = True\nproblem.formulation.concomitant = False\n\n# enable stability selection explicitly rather than relying on the default\nproblem.model_selection.StabSel = True\nproblem.model_selection.PATH = False\nproblem.model_selection.CV = False\n# one can change q, B, nS, method, threshold, etc. in problem.model_selection.StabSelparameters\n\nproblem.solve()\n\nprint(problem, problem.solution)\n\nselection = problem.solution.StabSel.selected_param[1:]  # exclude the intercept\nprint(label[selection])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prediction plot\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Test indices are all sample indices that are not in the training set\nte = np.setdiff1d(np.arange(len(y)), tr)\nalpha = problem.solution.StabSel.refit\n# alpha[0] is the intercept, alpha[1:] are the coefficients\nyhat = logGeom[te].dot(alpha[1:]) + alpha[0]\n\ny_max, y_min = max(y[te]), min(y[te])\nplt.plot(yhat, y[te], \"bo\", label=\"sample of the testing set\")\nplt.plot([y_max, y_min], [y_max, y_min], \"k-\", label=\"identity\")\nplt.xlabel(\"predictor yhat\")\nplt.ylabel(\"real y\")\nplt.legend()\nplt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}