{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(working_with_InferenceData)=\n", "\n", "# Working with InferenceData\n", "\n", "Here we present a collection of common manipulations you can use while working with `InferenceData`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpy as np\n", "import xarray as xr\n", "\n", "xr.set_options(display_expand_data=False, display_expand_attrs=False);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`display_expand_data=False` makes the default view for {class}`xarray.DataArray` fold the data values to a single line. To explore the values, click on the {fas}`database` icon on the left of the view, right under the `xarray.DataArray` text. It has no effect on `Dataset` objects, which already default to folded views.\n", "\n", "`display_expand_attrs=False` folds the attributes in both `DataArray` and `Dataset` objects to keep the views shorter. On this page we print DataArrays and Datasets several times and they always have the same attributes, so there is no need to show them every time." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata = az.load_arviz_data(\"centered_eight\")\n", "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get the dataset corresponding to a single group" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 165kB\n",
       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
       "Coordinates:\n",
       "  * chain    (chain) int64 32B 0 1 2 3\n",
       "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
       "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu       (chain, draw) float64 16kB ...\n",
       "    theta    (chain, draw, school) float64 128kB ...\n",
       "    tau      (chain, draw) float64 16kB ...\n",
       "Attributes: (6)
" ], "text/plain": [ " Size: 165kB\n", "Dimensions: (chain: 4, draw: 500, school: 8)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", " * school (school) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 181kB\n",
       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
       "Coordinates:\n",
       "  * chain    (chain) int64 32B 0 1 2 3\n",
       "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
       "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu       (chain, draw) float64 16kB ...\n",
       "    theta    (chain, draw, school) float64 128kB ...\n",
       "    tau      (chain, draw) float64 16kB 4.726 3.909 4.844 ... 2.741 2.932 4.461\n",
       "    log_tau  (chain, draw) float64 16kB 1.553 1.363 1.578 ... 1.008 1.076 1.495\n",
       "Attributes: (6)
" ], "text/plain": [ " Size: 181kB\n", "Dimensions: (chain: 4, draw: 500, school: 8)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", " * school (school) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 225kB\n",
       "Dimensions:  (sample: 2000, school: 8)\n",
       "Coordinates:\n",
       "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "  * sample   (sample) object 16kB MultiIndex\n",
       "  * chain    (sample) int64 16kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3\n",
       "  * draw     (sample) int64 16kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
       "Data variables:\n",
       "    mu       (sample) float64 16kB 7.872 3.385 9.1 7.304 ... 1.767 3.486 3.404\n",
       "    theta    (school, sample) float64 128kB 12.32 11.29 5.709 ... 8.452 1.295\n",
       "    tau      (sample) float64 16kB 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461\n",
       "    log_tau  (sample) float64 16kB 1.553 1.363 1.578 ... 1.008 1.076 1.495\n",
       "Attributes: (6)
" ], "text/plain": [ " Size: 225kB\n", "Dimensions: (sample: 2000, school: 8)\n", "Coordinates:\n", " * school (school) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 12kB\n",
       "Dimensions:  (sample: 100, school: 8)\n",
       "Coordinates:\n",
       "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "  * sample   (sample) object 800B MultiIndex\n",
       "  * chain    (sample) int64 800B 1 2 0 1 2 1 3 3 2 2 1 ... 1 2 1 0 0 1 0 2 1 0 0\n",
       "  * draw     (sample) int64 800B 203 316 58 22 372 214 ... 460 136 37 476 368\n",
       "Data variables:\n",
       "    mu       (sample) float64 800B 6.36 4.445 6.403 ... -0.8143 5.246 1.743\n",
       "    theta    (school, sample) float64 6kB 10.71 0.5876 9.016 ... 5.24 -0.8556\n",
       "    tau      (sample) float64 800B 4.929 3.515 3.592 7.412 ... 1.755 3.332 9.721\n",
       "    log_tau  (sample) float64 800B 1.595 1.257 1.279 ... 0.5626 1.203 2.274\n",
       "Attributes: (6)
" ], "text/plain": [ " Size: 12kB\n", "Dimensions: (sample: 100, school: 8)\n", "Coordinates:\n", " * school (school) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'school' (school: 8)> Size: 512B\n",
       "'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'\n",
       "Coordinates:\n",
       "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
" ], "text/plain": [ " Size: 512B\n", "'Choate' 'Deerfield' 'Phillips Andover' ... \"St. Paul's\" 'Mt. Hermon'\n", "Coordinates:\n", " * school (school) \n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 93kB\n",
             "Dimensions:  (chain: 2, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 16B 0 2\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu       (chain, draw) float64 8kB 7.872 3.385 9.1 ... 2.871 4.096 1.776\n",
             "    theta    (chain, draw, school) float64 64kB 12.32 9.905 ... 2.363 -2.968\n",
             "    tau      (chain, draw) float64 8kB 4.726 3.909 4.844 ... 4.09 2.72 1.917\n",
             "    log_tau  (chain, draw) float64 8kB 1.553 1.363 1.578 ... 1.408 1.001 0.6508\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 69kB\n",
             "Dimensions:  (chain: 2, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 16B 0 2\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 64kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 69kB\n",
             "Dimensions:  (chain: 2, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 16B 0 2\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 64kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 125kB\n",
             "Dimensions:              (chain: 2, draw: 500)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 16B 0 2\n",
             "  * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 8kB ...\n",
             "    energy_error         (chain, draw) float64 8kB ...\n",
             "    lp                   (chain, draw) float64 8kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 8kB ...\n",
             "    acceptance_rate      (chain, draw) float64 8kB ...\n",
             "    diverging            (chain, draw) bool 1kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 8kB ...\n",
             "    step_size_bar        (chain, draw) float64 8kB ...\n",
             "    step_size            (chain, draw) float64 8kB ...\n",
             "    energy               (chain, draw) float64 8kB ...\n",
             "    tree_depth           (chain, draw) int64 8kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 8kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 45kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (chain, draw) float64 4kB ...\n",
             "    theta    (chain, draw, school) float64 32kB ...\n",
             "    mu       (chain, draw) float64 4kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 37kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 32kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", " \n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.sel(chain=[0, 2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Remove the first n draws (burn-in)\n", "\n", "Let's say we want to remove the first 100 draws from all the chains and from all the `InferenceData` groups that have a `draw` dimension." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 145kB\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu       (chain, draw) float64 13kB 11.7 8.118 -5.88 ... 1.767 3.486 3.404\n",
             "    theta    (chain, draw, school) float64 102kB 14.23 9.72 ... 6.762 1.295\n",
             "    tau      (chain, draw) float64 13kB 4.289 2.765 2.457 ... 2.741 2.932 4.461\n",
             "    log_tau  (chain, draw) float64 13kB 1.456 1.017 0.8991 ... 1.008 1.076 1.495\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 106kB\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 102kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 106kB\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 102kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 197kB\n",
             "Dimensions:              (chain: 4, draw: 400)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 32B 0 1 2 3\n",
             "  * draw                 (draw) int64 3kB 100 101 102 103 ... 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 13kB ...\n",
             "    energy_error         (chain, draw) float64 13kB ...\n",
             "    lp                   (chain, draw) float64 13kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 13kB ...\n",
             "    acceptance_rate      (chain, draw) float64 13kB ...\n",
             "    diverging            (chain, draw) bool 2kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 13kB ...\n",
             "    step_size_bar        (chain, draw) float64 13kB ...\n",
             "    step_size            (chain, draw) float64 13kB ...\n",
             "    energy               (chain, draw) float64 13kB ...\n",
             "    tree_depth           (chain, draw) int64 13kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 13kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 36kB\n",
             "Dimensions:  (chain: 1, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (chain, draw) float64 3kB ...\n",
             "    theta    (chain, draw, school) float64 26kB ...\n",
             "    mu       (chain, draw) float64 3kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 29kB\n",
             "Dimensions:  (chain: 1, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 26kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.sel(draw=slice(100, None))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you check the result of the cell above you will see that the groups `posterior`, `posterior_predictive`, `log_likelihood`, `sample_stats`, `prior` and `prior_predictive` have 400 draws compared to `idata` that has 500. The groups `observed_data` and `constant_data` have not been affected because they do not have the `draw` dimension. Alternatively, you can specify which group or groups you want to change." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 145kB\n",
             "Dimensions:  (chain: 4, draw: 400, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu       (chain, draw) float64 13kB 11.7 8.118 -5.88 ... 1.767 3.486 3.404\n",
             "    theta    (chain, draw, school) float64 102kB 14.23 9.72 ... 6.762 1.295\n",
             "    tau      (chain, draw) float64 13kB 4.289 2.765 2.457 ... 2.741 2.932 4.461\n",
             "    log_tau  (chain, draw) float64 13kB 1.456 1.017 0.8991 ... 1.008 1.076 1.495\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 246kB\n",
             "Dimensions:              (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 32B 0 1 2 3\n",
             "  * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 16kB ...\n",
             "    energy_error         (chain, draw) float64 16kB ...\n",
             "    lp                   (chain, draw) float64 16kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 16kB ...\n",
             "    acceptance_rate      (chain, draw) float64 16kB ...\n",
             "    diverging            (chain, draw) bool 2kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 16kB ...\n",
             "    step_size_bar        (chain, draw) float64 16kB ...\n",
             "    step_size            (chain, draw) float64 16kB ...\n",
             "    energy               (chain, draw) float64 16kB ...\n",
             "    tree_depth           (chain, draw) int64 16kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 16kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 45kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (chain, draw) float64 4kB ...\n",
             "    theta    (chain, draw, school) float64 32kB ...\n",
             "    mu       (chain, draw) float64 4kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 37kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 32kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.sel(draw=slice(100, None), groups=\"posterior\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute posterior mean values along `draw` and `chain` dimensions\n", "\n", "To compute the mean value of the posterior samples, do the following:\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 32B\n",
       "Dimensions:  ()\n",
       "Data variables:\n",
       "    mu       float64 8B 4.486\n",
       "    theta    float64 8B 4.912\n",
       "    tau      float64 8B 4.124\n",
       "    log_tau  float64 8B 1.173
" ], "text/plain": [ " Size: 32B\n", "Dimensions: ()\n", "Data variables:\n", " mu float64 8B 4.486\n", " theta float64 8B 4.912\n", " tau float64 8B 4.124\n", " log_tau float64 8B 1.173" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "post.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This computes the mean along all dimensions. This is probably what you want for `mu` and `tau`, which have two dimensions (`chain` and `draw`), but maybe not what you expected for `theta`, which has one more dimension `school`. \n", "\n", "You can specify along which dimension you want to compute the mean (or other functions)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 600B\n",
       "Dimensions:  (school: 8)\n",
       "Coordinates:\n",
       "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu       float64 8B 4.486\n",
       "    theta    (school) float64 64B 6.46 5.028 3.938 4.872 3.667 3.975 6.581 4.772\n",
       "    tau      float64 8B 4.124\n",
       "    log_tau  float64 8B 1.173
" ], "text/plain": [ " Size: 600B\n", "Dimensions: (school: 8)\n", "Coordinates:\n", " * school (school) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset> Size: 1MB\n",
       "Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)\n",
       "Coordinates:\n",
       "  * chain              (chain) int64 32B 0 1 2 3\n",
       "  * draw               (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
       "  * school             (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
       "  * school_bis         (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'\n",
       "Data variables:\n",
       "    mu                 (chain, draw) float64 16kB 7.872 3.385 ... 3.486 3.404\n",
       "    theta              (chain, draw, school) float64 128kB 12.32 9.905 ... 1.295\n",
       "    tau                (chain, draw) float64 16kB 4.726 3.909 ... 2.932 4.461\n",
       "    log_tau            (chain, draw) float64 16kB 1.553 1.363 ... 1.076 1.495\n",
       "    mlogtau            (chain, draw) float64 16kB nan nan nan ... 1.496 1.511\n",
       "    theta_school_diff  (chain, draw, school, school_bis) float64 1MB 0.0 ... 0.0\n",
       "Attributes: (6)
" ], "text/plain": [ " Size: 1MB\n", "Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499\n", " * school (school) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500)> Size: 16kB\n",
       "2.415 2.156 -0.04943 1.228 3.384 9.662 ... -1.656 -0.4021 1.524 -3.372 -6.305\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 32B 0 1 2 3\n",
       "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
       "    school      <U16 64B 'Choate'\n",
       "    school_bis  <U16 64B 'Deerfield'
" ], "text/plain": [ " Size: 16kB\n", "2.415 2.156 -0.04943 1.228 3.384 9.662 ... -1.656 -0.4021 1.524 -3.372 -6.305\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", " school \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500,\n",
       "                                       pairwise_school_diff: 3)> Size: 48kB\n",
       "2.415 -6.741 -1.84 2.156 -3.474 3.784 ... -2.619 6.923 -6.305 1.667 -6.641\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 32B 0 1 2 3\n",
       "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
       "    school      (pairwise_school_diff) <U16 192B 'Choate' ... 'Mt. Hermon'\n",
       "    school_bis  (pairwise_school_diff) <U16 192B 'Deerfield' ... 'Lawrenceville'\n",
       "Dimensions without coordinates: pairwise_school_diff
" ], "text/plain": [ " Size: 48kB\n", "2.415 -6.741 -1.84 2.156 -3.474 3.784 ... -2.619 6.923 -6.305 1.667 -6.641\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", " school (pairwise_school_diff) \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, school: 3,\n",
       "                                       school_bis: 3)> Size: 144kB\n",
       "2.415 0.0 -4.581 -4.326 -6.741 -11.32 ... 1.667 -6.077 -5.203 1.102 -6.641\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 32B 0 1 2 3\n",
       "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
       "  * school      (school) <U16 192B 'Choate' 'Hotchkiss' 'Mt. Hermon'\n",
       "  * school_bis  (school_bis) <U16 192B 'Deerfield' 'Choate' 'Lawrenceville'
" ], "text/plain": [ " Size: 144kB\n", "2.415 0.0 -4.581 -4.326 -6.741 -11.32 ... 1.667 -6.077 -5.203 1.102 -6.641\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n", " * school (school) \n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 1MB\n",
             "Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)\n",
             "Coordinates:\n",
             "  * chain              (chain) int64 32B 0 1 2 3\n",
             "  * draw               (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
             "  * school             (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "  * school_bis         (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu                 (chain, draw) float64 16kB 7.872 3.385 ... 3.486 3.404\n",
             "    theta              (chain, draw, school) float64 128kB 12.32 9.905 ... 1.295\n",
             "    tau                (chain, draw) float64 16kB 4.726 3.909 ... 2.932 4.461\n",
             "    log_tau            (chain, draw) float64 16kB 1.553 1.363 ... 1.076 1.495\n",
             "    mlogtau            (chain, draw) float64 16kB nan nan nan ... 1.496 1.511\n",
             "    theta_school_diff  (chain, draw, school, school_bis) float64 1MB 0.0 ... 0.0\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 36kB\n",
             "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 32B 0 1 2 3\n",
             "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * new_school  (new_school) <U13 104B 'Essex College' 'Moordale'\n",
             "Data variables:\n",
             "    obs         (chain, draw, new_school) float64 32kB 2.041 -2.556 ... -0.2822\n",
             "Attributes: (2)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 246kB\n",
             "Dimensions:              (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 32B 0 1 2 3\n",
             "  * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 16kB ...\n",
             "    energy_error         (chain, draw) float64 16kB ...\n",
             "    lp                   (chain, draw) float64 16kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 16kB ...\n",
             "    acceptance_rate      (chain, draw) float64 16kB ...\n",
             "    diverging            (chain, draw) bool 2kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 16kB ...\n",
             "    step_size_bar        (chain, draw) float64 16kB ...\n",
             "    step_size            (chain, draw) float64 16kB ...\n",
             "    energy               (chain, draw) float64 16kB ...\n",
             "    tree_depth           (chain, draw) int64 16kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 16kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 45kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (chain, draw) float64 4kB ...\n",
             "    theta    (chain, draw, school) float64 32kB ...\n",
             "    mu       (chain, draw) float64 4kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 37kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 32kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", " \n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> predictions\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rng = np.random.default_rng(3)\n", "idata.add_groups(\n", " {\"predictions\": {\"obs\": rng.normal(size=(4, 500, 2))}},\n", " dims={\"obs\": [\"new_school\"]},\n", " coords={\"new_school\": [\"Essex College\", \"Moordale\"]},\n", ")\n", "idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Add Transformations to Multiple Groups" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also add transformations to Multiple InferenceData Groups using {meth}`arviz.InferenceData.map`. It takes a function as an input and applies the function groupwise to the selected InferenceData groups and overwrites the group with the result of the function." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 309kB\n",
             "Dimensions:            (draw: 500, school: 8, school_bis: 8)\n",
             "Coordinates:\n",
             "  * draw               (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
             "  * school             (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "  * school_bis         (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu                 (draw) float64 4kB 5.974 5.096 7.177 ... 4.739 3.146\n",
             "    theta              (draw, school) float64 32kB 9.519 5.554 ... 5.595 3.773\n",
             "    tau                (draw) float64 4kB 4.068 3.156 3.603 ... 3.225 2.979\n",
             "    log_tau            (draw) float64 4kB 1.322 1.118 1.234 ... 1.035 0.9508\n",
             "    mlogtau            (draw) float64 4kB nan nan nan nan ... 1.002 1.01 1.021\n",
             "    theta_school_diff  (draw, school, school_bis) float64 256kB 0.0 ... 0.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 36kB\n",
             "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 32B 0 1 2 3\n",
             "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * new_school  (new_school) <U13 104B 'Essex College' 'Moordale'\n",
             "Data variables:\n",
             "    obs         (chain, draw, new_school) float64 32kB 2.041 -2.556 ... -0.2822\n",
             "Attributes: (2)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 246kB\n",
             "Dimensions:              (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 32B 0 1 2 3\n",
             "  * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 16kB ...\n",
             "    energy_error         (chain, draw) float64 16kB ...\n",
             "    lp                   (chain, draw) float64 16kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 16kB ...\n",
             "    acceptance_rate      (chain, draw) float64 16kB ...\n",
             "    diverging            (chain, draw) bool 2kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 16kB ...\n",
             "    step_size_bar        (chain, draw) float64 16kB ...\n",
             "    step_size            (chain, draw) float64 16kB ...\n",
             "    energy               (chain, draw) float64 16kB ...\n",
             "    tree_depth           (chain, draw) int64 16kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 16kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 45kB\n",
             "Dimensions:  (draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (draw) float64 4kB 1.941 3.388 4.208 5.687 ... 0.8353 0.06893 2.145\n",
             "    theta    (draw, school) float64 32kB 4.866 4.59 -0.7404 ... -2.031 6.045\n",
             "    mu       (draw) float64 4kB 3.903 3.915 -1.751 2.595 ... -2.294 0.7908 2.869

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 37kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 32kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> predictions\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "selected_groups = [\"posterior\", \"prior\"]\n", "\n", "def calc_mean(dataset, *args, **kwargs):\n", " result = dataset.mean(dim=\"chain\", *args, **kwargs)\n", " return result\n", "\n", "means = idata.map(calc_mean, groups=selected_groups, inplace=False)\n", "means" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also pass a lambda function in `map`" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 1MB\n",
             "Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)\n",
             "Coordinates:\n",
             "  * chain              (chain) int64 32B 0 1 2 3\n",
             "  * draw               (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
             "  * school             (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "  * school_bis         (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu                 (chain, draw) float64 16kB 10.87 6.385 ... 6.486 6.404\n",
             "    theta              (chain, draw, school) float64 128kB 15.32 12.91 ... 4.295\n",
             "    tau                (chain, draw) float64 16kB 7.726 6.909 ... 5.932 7.461\n",
             "    log_tau            (chain, draw) float64 16kB 4.553 4.363 ... 4.076 4.495\n",
             "    mlogtau            (chain, draw) float64 16kB nan nan nan ... 4.496 4.511\n",
             "    theta_school_diff  (chain, draw, school, school_bis) float64 1MB 3.0 ... 3.0\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 36kB\n",
             "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 32B 0 1 2 3\n",
             "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * new_school  (new_school) <U13 104B 'Essex College' 'Moordale'\n",
             "Data variables:\n",
             "    obs         (chain, draw, new_school) float64 32kB 2.041 -2.556 ... -0.2822\n",
             "Attributes: (2)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 246kB\n",
             "Dimensions:              (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 32B 0 1 2 3\n",
             "  * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 16kB ...\n",
             "    energy_error         (chain, draw) float64 16kB ...\n",
             "    lp                   (chain, draw) float64 16kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 16kB ...\n",
             "    acceptance_rate      (chain, draw) float64 16kB ...\n",
             "    diverging            (chain, draw) bool 2kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 16kB ...\n",
             "    step_size_bar        (chain, draw) float64 16kB ...\n",
             "    step_size            (chain, draw) float64 16kB ...\n",
             "    energy               (chain, draw) float64 16kB ...\n",
             "    tree_depth           (chain, draw) int64 16kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 16kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 45kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (chain, draw) float64 4kB 1.941 3.388 4.208 ... 0.06893 2.145\n",
             "    theta    (chain, draw, school) float64 32kB 4.866 4.59 ... -2.031 6.045\n",
             "    mu       (chain, draw) float64 4kB 3.903 3.915 -1.751 ... 0.7908 2.869\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 37kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 32kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> predictions\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata_shifted_obs = idata.map(lambda x: x + 3, groups=\"posterior\")\n", "idata_shifted_obs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also add extra coordinates using `map`" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 1MB\n",
             "Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)\n",
             "Coordinates:\n",
             "  * chain              (chain) int64 32B 0 1 2 3\n",
             "  * draw               (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
             "  * school             (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "  * school_bis         (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    mu                 (chain, draw) float64 16kB 7.872 3.385 ... 3.486 3.404\n",
             "    theta              (chain, draw, school) float64 128kB 12.32 9.905 ... 1.295\n",
             "    tau                (chain, draw) float64 16kB 4.726 3.909 ... 2.932 4.461\n",
             "    log_tau            (chain, draw) float64 16kB 1.553 1.363 ... 1.076 1.495\n",
             "    mlogtau            (chain, draw) float64 16kB nan nan nan ... 1.496 1.511\n",
             "    theta_school_diff  (chain, draw, school, school_bis) float64 1MB 0.0 ... 0.0\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8, Upper: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "    upper    (Upper) <U16 512B 'CHOATE' 'DEERFIELD' ... 'MT. HERMON'\n",
             "Dimensions without coordinates: Upper\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 36kB\n",
             "Dimensions:     (chain: 4, draw: 500, new_school: 2)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 32B 0 1 2 3\n",
             "  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * new_school  (new_school) <U13 104B 'Essex College' 'Moordale'\n",
             "Data variables:\n",
             "    obs         (chain, draw, new_school) float64 32kB 2.041 -2.556 ... -0.2822\n",
             "Attributes: (2)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 133kB\n",
             "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 32B 0 1 2 3\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 128kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 246kB\n",
             "Dimensions:              (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain                (chain) int64 32B 0 1 2 3\n",
             "  * draw                 (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "Data variables: (12/16)\n",
             "    max_energy_error     (chain, draw) float64 16kB ...\n",
             "    energy_error         (chain, draw) float64 16kB ...\n",
             "    lp                   (chain, draw) float64 16kB ...\n",
             "    index_in_trajectory  (chain, draw) int64 16kB ...\n",
             "    acceptance_rate      (chain, draw) float64 16kB ...\n",
             "    diverging            (chain, draw) bool 2kB ...\n",
             "    ...                   ...\n",
             "    smallest_eigval      (chain, draw) float64 16kB ...\n",
             "    step_size_bar        (chain, draw) float64 16kB ...\n",
             "    step_size            (chain, draw) float64 16kB ...\n",
             "    energy               (chain, draw) float64 16kB ...\n",
             "    tree_depth           (chain, draw) int64 16kB ...\n",
             "    perf_counter_diff    (chain, draw) float64 16kB ...\n",
             "Attributes: (6)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 45kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    tau      (chain, draw) float64 4kB 1.941 3.388 4.208 ... 0.06893 2.145\n",
             "    theta    (chain, draw, school) float64 32kB 4.866 4.59 ... -2.031 6.045\n",
             "    mu       (chain, draw) float64 4kB 3.903 3.915 -1.751 ... 0.7908 2.869\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 37kB\n",
             "Dimensions:  (chain: 1, draw: 500, school: 8, Upper: 8)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 8B 0\n",
             "  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "    upper    (Upper) <U16 512B 'CHOATE' 'DEERFIELD' ... 'MT. HERMON'\n",
             "Dimensions without coordinates: Upper\n",
             "Data variables:\n",
             "    obs      (chain, draw, school) float64 32kB ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 1kB\n",
             "Dimensions:  (school: 8, Upper: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "    upper    (Upper) <U16 512B 'CHOATE' 'DEERFIELD' ... 'MT. HERMON'\n",
             "Dimensions without coordinates: Upper\n",
             "Data variables:\n",
             "    obs      (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset> Size: 576B\n",
             "Dimensions:  (school: 8)\n",
             "Coordinates:\n",
             "  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
             "Data variables:\n",
             "    scores   (school) float64 64B ...\n",
             "Attributes: (4)

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> predictions\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "_upper = np.array([\n", " x.upper() for x in idata.observed_data.school.values\n", "]).T \n", "idata_with_upper = idata.map(\n", " lambda ds, **kwargs: ds.assign_coords(**kwargs),\n", " groups=\"observed_vars\",\n", " upper=(\"Upper\", _upper),\n", ")\n", "idata_with_upper" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }