2.5. Comparative study of BMM methods in the bivariate linear module using the Coleman toy models.#
The best way to learn Taweret is to use it. You can run, modify and experiment with this notebook using GitHub Codespaces.
The models can be found in Coleman's thesis: https://go.exlibris.link/3fVZCfhl
This notebook shows how to use the Bayesian model mixing (BMM) methods available in the bivariate_linear mixing module of the Taweret package on a toy problem.
Author : Dan Liyanage
Date : 08/14/2023
import sys
import os
# You will have to change the following imports depending on where you have
# the packages installed
# ! pip install Taweret # if using Colab, uncomment to install
# Setting Taweret path
cwd = os.getcwd()
# Get the first part of this path and append to the sys.path
tw_path = cwd.split("Taweret/")[0] + "Taweret"
sys.path.append(tw_path)
# For plotting
import matplotlib.pyplot as plt
! pip install seaborn # comment if installed
! pip install ptemcee # comment if installed
import seaborn as sns
sns.set_context('poster')
# To define priors. (uncoment if not using default priors)
# ! pip install bilby # uncomment if not already installed
import bilby
# For other operations
import numpy as np
Requirement already satisfied: seaborn in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (0.13.2)
Requirement already satisfied: numpy!=1.24.0,>=1.20 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from seaborn) (2.4.6)
Requirement already satisfied: pandas>=1.2 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from seaborn) (3.0.3)
Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from seaborn) (3.11.0)
Requirement already satisfied: contourpy>=1.0.1 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (4.63.0)
Requirement already satisfied: kiwisolver>=1.3.1 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.5.0)
Requirement already satisfied: packaging>=20.0 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (26.2)
Requirement already satisfied: pillow>=9 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (12.2.0)
Requirement already satisfied: pyparsing>=3 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (3.3.2)
Requirement already satisfied: python-dateutil>=2.7 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (2.9.0.post0)
Requirement already satisfied: six>=1.5 in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from python-dateutil>=2.7->matplotlib!=3.6.1,>=3.4->seaborn) (1.17.0)
Requirement already satisfied: ptemcee in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (1.0.0)
Requirement already satisfied: numpy in /home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages (from ptemcee) (2.4.6)
# Toy models exposing an `evaluate` method, plus the underlying truth.
from Taweret.models import coleman_models as toy_models

m1 = toy_models.coleman_model_1()
m2 = toy_models.coleman_model_2()
truth = toy_models.coleman_truth()

# Ten "experimental" inputs x = 0, 1, ..., 9 and a dense 100-point grid
# over the same interval for smooth curves.
g = np.linspace(start=0, stop=9, num=10)
plot_g = np.linspace(start=0, stop=9, num=100)

# Truth on the dense grid (for plotting) and at the experimental inputs.
true_output = truth.evaluate(plot_g)
exp_data = truth.evaluate(g)
2.5.1. 1. The models and the experimental data.#
Truth
\(f(x) = 2-0.1(x-4)^2\), where \(x \in [-1, 9]\)
Model 1
\(f_1(x,\theta)= 0.5(x+\theta)-2\) , where \(\theta \in [1, 6]\)
Model 2
\(f_2(x,\theta)= -0.5(x-\theta) + 3.7\) , where \(\theta \in [-2, 3]\)
Experimental data
Sampled from the truth with a fixed standard deviation of 0.3.
sns.set_context('notebook')
fig, axs = plt.subplots(1, 2, figsize=(20, 5))
prior_ranges = [(1, 6), (-2, 3)]
# One panel per model: (axis, model, theta prior range, y-axis label).
panels = zip(axs.flatten(), (m1, m2), prior_ranges, (r'$f_1(x)$', r'$f_2(x)$'))
for panel_ax, model, (theta_low, theta_high), y_label in panels:
    # Truth curve and experimental points, with exp_data[1] used as error bars.
    panel_ax.plot(plot_g, true_output[0], label='truth', color='black')
    panel_ax.errorbar(g, exp_data[0], exp_data[1], fmt='o', label='experimental data', color='r')
    panel_ax.legend()
    panel_ax.set_ylim(-2, 4)
    # Draw the model for 10 evenly spaced theta values spanning its prior range.
    for theta in np.linspace(theta_low, theta_high, 10):
        prediction = model.evaluate(plot_g, theta, full_corr=False)
        panel_ax.plot(plot_g, prediction[0])
    panel_ax.set_ylabel(y_label)
    panel_ax.set_xlabel('x')
2.5.2. 2. Choose a Mixing method#
# Mixing method: three BivariateLinear mixers that share every setting and
# differ only in the flag selecting the BMM variant.
from Taweret.mix.bivariate_linear import BivariateLinear as BL

models = {'model1': m1, 'model2': m2}


def _make_mixer(**variant_flag):
    """Build an 'addstepasym' BivariateLinear mixer over `models`.

    A fresh nargs_model_dic is created for every call, so the mixers never
    share that dictionary.
    """
    return BL(models_dic=models, method='addstepasym',
              nargs_model_dic={'model1': 1, 'model2': 1},
              same_parameters=False, **variant_flag)


mix_model_BMMC_mix = _make_mixer()
mix_model_BMMcor_mix = _make_mixer(BMMcor=True)
mix_model_mean_mix = _make_mixer(mean_mix=True)
mix_models = [mix_model_BMMC_mix, mix_model_BMMcor_mix, mix_model_mean_mix]
# Replace the default priors on the three 'addstepasym' mixing-function
# parameters with uniform priors, and hand the same PriorDict to every mixer.
priors = bilby.core.prior.PriorDict()
priors['addstepasym_0'] = bilby.core.prior.Uniform(0, 9, name="addstepasym_0")
priors['addstepasym_1'] = bilby.core.prior.Uniform(0, 9, name="addstepasym_1")
priors['addstepasym_2'] = bilby.core.prior.Uniform(0, 1, name="addstepasym_2")
for mix_model in mix_models:
    mix_model.set_prior(priors)
# Print the prior actually stored on each mixer. The loop variable must be
# the one used in the body; otherwise the last mixer's prior is shown 3 times.
for mix_model in mix_models:
    print(mix_model.prior)
{'addstepasym_0': Uniform(minimum=0, maximum=9, name='addstepasym_0', latex_label='addstepasym_0', unit=None, boundary=None), 'addstepasym_1': Uniform(minimum=0, maximum=9, name='addstepasym_1', latex_label='addstepasym_1', unit=None, boundary=None), 'addstepasym_2': Uniform(minimum=0, maximum=1, name='addstepasym_2', latex_label='addstepasym_2', unit=None, boundary=None), 'model1_0': Uniform(minimum=1, maximum=6, name='model1_0', latex_label='model1_0', unit=None, boundary=None), 'model2_0': Uniform(minimum=-2, maximum=3, name='model2_0', latex_label='model2_0', unit=None, boundary=None)}
{'addstepasym_0': Uniform(minimum=0, maximum=9, name='addstepasym_0', latex_label='addstepasym_0', unit=None, boundary=None), 'addstepasym_1': Uniform(minimum=0, maximum=9, name='addstepasym_1', latex_label='addstepasym_1', unit=None, boundary=None), 'addstepasym_2': Uniform(minimum=0, maximum=1, name='addstepasym_2', latex_label='addstepasym_2', unit=None, boundary=None), 'model1_0': Uniform(minimum=1, maximum=6, name='model1_0', latex_label='model1_0', unit=None, boundary=None), 'model2_0': Uniform(minimum=-2, maximum=3, name='model2_0', latex_label='model2_0', unit=None, boundary=None)}
{'addstepasym_0': Uniform(minimum=0, maximum=9, name='addstepasym_0', latex_label='addstepasym_0', unit=None, boundary=None), 'addstepasym_1': Uniform(minimum=0, maximum=9, name='addstepasym_1', latex_label='addstepasym_1', unit=None, boundary=None), 'addstepasym_2': Uniform(minimum=0, maximum=1, name='addstepasym_2', latex_label='addstepasym_2', unit=None, boundary=None), 'model1_0': Uniform(minimum=1, maximum=6, name='model1_0', latex_label='model1_0', unit=None, boundary=None), 'model2_0': Uniform(minimum=-2, maximum=3, name='model2_0', latex_label='model2_0', unit=None, boundary=None)}
2.5.3. 3. Train to find posterior#
# Sanity check on the experimental input grid (a 1-D array).
np.shape(g)
(10,)
#from Taweret.utils.utils import normed_mvn_loglike
# Options forwarded to bilby's ptemcee sampler: temperature ladder, walker
# count, fixed burn-in discard, number of posterior samples and thread count.
kwargs_for_sampler = dict(
    sampler='ptemcee',
    ntemps=5,
    nwalkers=40,
    Tmax=100,
    burn_in_fixed_discard=500,
    nsamples=3000,
    threads=6,
    verbose=False,
)
# Further convergence options that may be added:
#   'safety': 2, 'autocorr_tol': 5
import os
import shutil

# One output directory and one label per mixer, in the order of mix_models.
outdirs = ['outdir/mix_model_1', 'outdir/mix_model_2', 'outdir/mix_model_3']
labels = ['BMMC', 'BMMcor', 'BMMmean']
results = []
for mix_model, label, outdir in zip(mix_models, labels, outdirs):
    # Start each run from an empty output directory (presumably so bilby does
    # not pick up files from an earlier run -- confirm).
    if not os.path.isdir(outdir):
        print(f'file does not exist {outdir}')
    else:
        print(f'removing {outdir}')
        shutil.rmtree(outdir)
    # Inputs, observations and their uncertainties as column vectors.
    x_column = g.reshape(-1, 1)
    y_column = exp_data[0].reshape(-1, 1)
    err_column = exp_data[1].reshape(-1, 1)
    result = mix_model.train(x_exp=x_column, y_exp=y_column, y_err=err_column,
                             kwargs_for_sampler=kwargs_for_sampler,
                             label=label, outdir=outdir)
    results.append(result)
/home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages/bilby/core/likelihood.py:127: FutureWarning: Setting non-trivial parameters for <class 'Taweret.sampler.likelihood_wrappers.likelihood_wrapper_for_bilby'>. This is deprecated behaviour and will be removed in Bilby version 3. See https://bilby-dev.github.io/bilby/parameters for more details.
warnings.warn(msg, FutureWarning)
21:51 bilby INFO : Running for label 'BMMC', output will be saved to 'outdir/mix_model_1'
file does not exist outdir/mix_model_1
/home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages/bilby/core/sampler/ptemcee.py:134: FutureWarning: The ptemcee sampler interface in bilby is deprecated and will be removed in Bilby version 3. Please use the `ptemcee-bilby`sampler plugin instead: https://github.com/bilby-dev/ptemcee-bilby.
warnings.warn(msg, FutureWarning)
21:51 bilby INFO : Analysis priors:
21:51 bilby INFO : addstepasym_0=Uniform(minimum=0, maximum=9, name='addstepasym_0', latex_label='addstepasym_0', unit=None, boundary=None)
21:51 bilby INFO : addstepasym_1=Uniform(minimum=0, maximum=9, name='addstepasym_1', latex_label='addstepasym_1', unit=None, boundary=None)
21:51 bilby INFO : addstepasym_2=Uniform(minimum=0, maximum=1, name='addstepasym_2', latex_label='addstepasym_2', unit=None, boundary=None)
21:51 bilby INFO : model1_0=Uniform(minimum=1, maximum=6, name='model1_0', latex_label='model1_0', unit=None, boundary=None)
21:51 bilby INFO : model2_0=Uniform(minimum=-2, maximum=3, name='model2_0', latex_label='model2_0', unit=None, boundary=None)
21:51 bilby INFO : Analysis likelihood class: <class 'Taweret.sampler.likelihood_wrappers.likelihood_wrapper_for_bilby'>
21:51 bilby INFO : Analysis likelihood noise evidence: nan
21:51 bilby INFO : Single likelihood evaluation took 1.763e-04 s
21:51 bilby INFO : Using sampler Ptemcee with kwargs {'ntemps': 5, 'nwalkers': 40, 'Tmax': 100, 'betas': None, 'a': 2.0, 'adaptation_lag': 10000, 'adaptation_time': 100, 'random': None, 'adapt': False, 'swap_ratios': False}
21:51 bilby INFO : Global meta data was removed from the result object for compatibility. Use the `BILBY_INCLUDE_GLOBAL_METADATA` environment variable to include it. This behaviour will be removed in a future release. For more details see: https://bilby-dev.github.io/bilby/faq.html#global-meta-data
21:51 bilby INFO : Using convergence inputs: ConvergenceInputs(autocorr_c=5, autocorr_tol=50, autocorr_tau=1, gradient_tau=0.1, gradient_mean_log_posterior=0.1, Q_tol=1.02, safety=1, burn_in_nact=50, burn_in_fixed_discard=500, mean_logl_frac=0.01, thin_by_nact=0.5, nsamples=3000, ignore_keys_for_tau=None, min_tau=1, niterations_per_check=5)
21:51 bilby INFO : Generating pos0 samples
21:51 bilby INFO : Starting to sample
21:52 bilby INFO : Finished sampling
21:52 bilby INFO : Writing checkpoint and diagnostics
21:52 bilby INFO : Finished writing checkpoint
21:52 bilby INFO : Sampling time: 0:01:22.584247
21:52 bilby WARNING : Result.save_to_file called with extension=True. This will default to json, and ignore the extension from the filename. This behaviour is deprecated and will be removed.
21:52 bilby WARNING : Result.save_to_file called with extension=True. This will default to json, and ignore the extension from the filename. This behaviour is deprecated and will be removed.
21:52 bilby INFO : Summary of results:
nsamples: 3040
ln_noise_evidence: nan
ln_evidence: -9.076 +/- 2.597
ln_bayes_factor: nan +/- 2.597
/home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages/bilby/core/likelihood.py:127: FutureWarning: Setting non-trivial parameters for <class 'Taweret.sampler.likelihood_wrappers.likelihood_wrapper_for_bilby'>. This is deprecated behaviour and will be removed in Bilby version 3. See https://bilby-dev.github.io/bilby/parameters for more details.
warnings.warn(msg, FutureWarning)
21:52 bilby INFO : Running for label 'BMMcor', output will be saved to 'outdir/mix_model_2'
file does not exist outdir/mix_model_2
21:52 bilby INFO : Analysis priors:
21:52 bilby INFO : addstepasym_0=Uniform(minimum=0, maximum=9, name='addstepasym_0', latex_label='addstepasym_0', unit=None, boundary=None)
21:52 bilby INFO : addstepasym_1=Uniform(minimum=0, maximum=9, name='addstepasym_1', latex_label='addstepasym_1', unit=None, boundary=None)
21:52 bilby INFO : addstepasym_2=Uniform(minimum=0, maximum=1, name='addstepasym_2', latex_label='addstepasym_2', unit=None, boundary=None)
21:52 bilby INFO : model1_0=Uniform(minimum=1, maximum=6, name='model1_0', latex_label='model1_0', unit=None, boundary=None)
21:52 bilby INFO : model2_0=Uniform(minimum=-2, maximum=3, name='model2_0', latex_label='model2_0', unit=None, boundary=None)
21:52 bilby INFO : Analysis likelihood class: <class 'Taweret.sampler.likelihood_wrappers.likelihood_wrapper_for_bilby'>
21:52 bilby INFO : Analysis likelihood noise evidence: nan
21:52 bilby INFO : Single likelihood evaluation took 2.527e-04 s
21:52 bilby INFO : Using sampler Ptemcee with kwargs {'ntemps': 5, 'nwalkers': 40, 'Tmax': 100, 'betas': None, 'a': 2.0, 'adaptation_lag': 10000, 'adaptation_time': 100, 'random': None, 'adapt': False, 'swap_ratios': False}
21:52 bilby INFO : Global meta data was removed from the result object for compatibility. Use the `BILBY_INCLUDE_GLOBAL_METADATA` environment variable to include it. This behaviour will be removed in a future release. For more details see: https://bilby-dev.github.io/bilby/faq.html#global-meta-data
21:52 bilby INFO : Using convergence inputs: ConvergenceInputs(autocorr_c=5, autocorr_tol=50, autocorr_tau=1, gradient_tau=0.1, gradient_mean_log_posterior=0.1, Q_tol=1.02, safety=1, burn_in_nact=50, burn_in_fixed_discard=500, mean_logl_frac=0.01, thin_by_nact=0.5, nsamples=3000, ignore_keys_for_tau=None, min_tau=1, niterations_per_check=5)
21:52 bilby INFO : Generating pos0 samples
21:52 bilby INFO : Starting to sample
21:54 bilby INFO : Finished sampling
21:54 bilby INFO : Writing checkpoint and diagnostics
21:54 bilby INFO : Finished writing checkpoint
21:54 bilby INFO : Sampling time: 0:01:52.086143
21:54 bilby WARNING : Result.save_to_file called with extension=True. This will default to json, and ignore the extension from the filename. This behaviour is deprecated and will be removed.
21:54 bilby WARNING : Result.save_to_file called with extension=True. This will default to json, and ignore the extension from the filename. This behaviour is deprecated and will be removed.
21:54 bilby INFO : Summary of results:
nsamples: 3960
ln_noise_evidence: nan
ln_evidence: 7.023 +/- 5.136
ln_bayes_factor: nan +/- 5.136
/home/runner/work/Taweret/Taweret/.tox/book/lib/python3.13/site-packages/bilby/core/likelihood.py:127: FutureWarning: Setting non-trivial parameters for <class 'Taweret.sampler.likelihood_wrappers.likelihood_wrapper_for_bilby'>. This is deprecated behaviour and will be removed in Bilby version 3. See https://bilby-dev.github.io/bilby/parameters for more details.
warnings.warn(msg, FutureWarning)
21:54 bilby INFO : Running for label 'BMMmean', output will be saved to 'outdir/mix_model_3'
file does not exist outdir/mix_model_3
21:54 bilby INFO : Analysis priors:
21:54 bilby INFO : addstepasym_0=Uniform(minimum=0, maximum=9, name='addstepasym_0', latex_label='addstepasym_0', unit=None, boundary=None)
21:54 bilby INFO : addstepasym_1=Uniform(minimum=0, maximum=9, name='addstepasym_1', latex_label='addstepasym_1', unit=None, boundary=None)
21:54 bilby INFO : addstepasym_2=Uniform(minimum=0, maximum=1, name='addstepasym_2', latex_label='addstepasym_2', unit=None, boundary=None)
21:54 bilby INFO : model1_0=Uniform(minimum=1, maximum=6, name='model1_0', latex_label='model1_0', unit=None, boundary=None)
21:54 bilby INFO : model2_0=Uniform(minimum=-2, maximum=3, name='model2_0', latex_label='model2_0', unit=None, boundary=None)
21:54 bilby INFO : Analysis likelihood class: <class 'Taweret.sampler.likelihood_wrappers.likelihood_wrapper_for_bilby'>
21:54 bilby INFO : Analysis likelihood noise evidence: nan
21:54 bilby INFO : Single likelihood evaluation took 1.718e-04 s
21:54 bilby INFO : Using sampler Ptemcee with kwargs {'ntemps': 5, 'nwalkers': 40, 'Tmax': 100, 'betas': None, 'a': 2.0, 'adaptation_lag': 10000, 'adaptation_time': 100, 'random': None, 'adapt': False, 'swap_ratios': False}
21:54 bilby INFO : Global meta data was removed from the result object for compatibility. Use the `BILBY_INCLUDE_GLOBAL_METADATA` environment variable to include it. This behaviour will be removed in a future release. For more details see: https://bilby-dev.github.io/bilby/faq.html#global-meta-data
21:54 bilby INFO : Using convergence inputs: ConvergenceInputs(autocorr_c=5, autocorr_tol=50, autocorr_tau=1, gradient_tau=0.1, gradient_mean_log_posterior=0.1, Q_tol=1.02, safety=1, burn_in_nact=50, burn_in_fixed_discard=500, mean_logl_frac=0.01, thin_by_nact=0.5, nsamples=3000, ignore_keys_for_tau=None, min_tau=1, niterations_per_check=5)
21:54 bilby INFO : Generating pos0 samples
21:54 bilby INFO : Starting to sample
21:56 bilby INFO : Finished sampling
21:56 bilby INFO : Writing checkpoint and diagnostics
21:56 bilby INFO : Finished writing checkpoint
21:56 bilby INFO : Run interrupted by signal 15: checkpoint and exit on 77
21:56 bilby INFO : Run interrupted by signal 15: checkpoint and exit on 77
21:56 bilby INFO : Sampling time: 0:01:31.564568
21:56 bilby WARNING : Result.save_to_file called with extension=True. This will default to json, and ignore the extension from the filename. This behaviour is deprecated and will be removed.
21:56 bilby WARNING : Result.save_to_file called with extension=True. This will default to json, and ignore the extension from the filename. This behaviour is deprecated and will be removed.
21:56 bilby INFO : Summary of results:
nsamples: 3360
ln_noise_evidence: nan
ln_evidence: -1.374 +/- 3.866
ln_bayes_factor: nan +/- 3.866
import pandas as pd

# Pool the posterior samples of all mixers into one long DataFrame, tagging
# each row with the method (label) that produced it.
posteriors = []
for result, label in zip(results, labels):
    # Drop the last two posterior columns (presumably bilby's log_likelihood /
    # log_prior bookkeeping -- verify). .copy() makes this an independent
    # frame, so adding the 'model' column does not trigger pandas'
    # SettingWithCopyWarning (pandas < 3).
    posterior = result.posterior.iloc[:, 0:-2].copy()
    posterior['model'] = label
    posteriors.append(posterior)
df = pd.concat(posteriors, ignore_index=True, sort=False)
# head(-10) returns every row except the last 10; the notebook display then
# truncates the middle. Use df.head(10) to show only the first rows.
df.head(-10)
| addstepasym_0 | addstepasym_1 | addstepasym_2 | model1_0 | model2_0 | model | |
|---|---|---|---|---|---|---|
| 0 | 4.328721 | 3.085605 | 0.858453 | 4.087503 | 1.054924 | BMMC |
| 1 | 4.445307 | 7.286557 | 0.984294 | 4.484985 | 2.324976 | BMMC |
| 2 | 2.821524 | 1.905797 | 0.949112 | 4.795993 | 1.266464 | BMMC |
| 3 | 3.044459 | 7.077850 | 0.848450 | 4.552938 | 1.368657 | BMMC |
| 4 | 3.226768 | 8.468429 | 0.883749 | 4.924694 | 1.034284 | BMMC |
| ... | ... | ... | ... | ... | ... | ... |
| 10345 | 2.345494 | 5.990116 | 0.926905 | 4.653338 | 0.624478 | BMMmean |
| 10346 | 3.264760 | 6.364910 | 0.918606 | 4.938764 | 0.520867 | BMMmean |
| 10347 | 3.631256 | 4.068621 | 0.948427 | 4.379929 | 1.108604 | BMMmean |
| 10348 | 4.216015 | 2.464983 | 0.944545 | 4.370439 | 1.239395 | BMMmean |
| 10349 | 4.138656 | 2.835907 | 0.942286 | 4.335914 | 1.221215 | BMMmean |
10350 rows × 6 columns
# LaTeX axis names for the corner plot; 'model' becomes the 'method' hue column.
column_labels = {
    'addstepasym_0': r'$\beta_0$',
    'addstepasym_1': r'$\beta_1$',
    'addstepasym_2': r'$\alpha$',
    'model1_0': r'$\theta_1$',
    'model2_0': r'$\theta_2$',
    'model': 'method',
}
df_renamed = df.rename(columns=column_labels)

import seaborn as sns
sns.set_context('paper', font_scale=1.5)
base_palette = sns.color_palette()
method_colors = {'BMMC': base_palette[2], 'BMMcor': base_palette[3], 'BMMmean': base_palette[-1]}
# Lower-triangle corner plot of the pooled posteriors, one hue per method.
gg = sns.PairGrid(df_renamed, hue="method", diag_sharey=False,
                  hue_kws={'alpha': 0.5}, corner=True, palette=method_colors)
gg.map_lower(sns.kdeplot, fill=True)
gg.map_diag(sns.kdeplot, linewidth=2, fill=True)
gg.add_legend(loc='upper center')
plt.tight_layout()
plt.savefig('comparative_posterior', dpi=100)
2.5.3.1. 4. Predictions#
sns.set_context('paper', font_scale=1.9)
fig, axs = plt.subplots(1, 2, figsize=(20, 10))
ax, ax2 = axs.flatten()
#fig2, ax2 = plt.subplots(figsize=(10,10))
# One colour per mixing method, keyed by the training labels.
colors = {'BMMC': sns.color_palette()[2], 'BMMcor': sns.color_palette()[3], 'BMMmean': sns.color_palette()[-1]}
for i, mix_model in enumerate(mix_models):
    # Prior and posterior predictive: mean plus the 5/20/80/95 percentile curves.
    _, mean_prior, CI_prior, _ = mix_model.prior_predict(plot_g, CI=[5, 20, 80, 95])
    _, mean, CI, _ = mix_model.predict(plot_g, CI=[5, 20, 80, 95])
    per5, per20, per80, per95 = CI
    prior5, prior20, prior80, prior95 = CI_prior
    # MAP prediction: map[0:3] are passed as the mixing-function parameters and
    # map[3], map[4] as the two model parameters (presumably prior order).
    model_params = [np.array(mix_model.map[3]), np.array(mix_model.map[4])]
    map_prediction = mix_model.evaluate(mix_model.map[0:3], plot_g, model_params=model_params)
    print(mix_model.map)
    # Panel (a): 20th-80th percentile band of the model weight.
    _, _, CI_weights, _ = mix_model.predict_weights(plot_g, CI=[5, 20, 80, 95])
    perw_5, perw_20, perw_80, perw_95 = CI_weights
    #ax.fill_between(plot_g,perw_5,perw_95,color=colors[labels[i]], alpha=0.2, label='90% C.I.')
    ax.fill_between(plot_g, perw_20, perw_80, color=colors[labels[i]], alpha=0.3, label=labels[i])
    if i == 0:
        # Prior band, data and prior mean are drawn once, from the first mixer.
        ax2.fill_between(plot_g, prior20.flatten(), prior80.flatten(), color=sns.color_palette()[7], alpha=0.2, label='60% C.I. Prior')
        # fmt='x' alone gives 'x' markers with no connecting line; combining
        # marker='x' with fmt='.' raised a "marker is redundantly defined"
        # UserWarning for the same rendered result.
        ax2.errorbar(g, exp_data[0], yerr=exp_data[1], fmt='x', label='experimental data', color='red')
        ax2.plot(plot_g, mean_prior.flatten(), label='prior mean')
    #ax2.plot(plot_g, mean.flatten(), label=labels[i])
    #ax2.fill_between(plot_g,per5.flatten(),per95.flatten(),color=sns.color_palette()[4], alpha=0.2, label='90% C.I.')
    # Panel (b): 20th-80th percentile posterior band and dashed MAP curve per method.
    ax2.fill_between(plot_g, per20.flatten(), per80.flatten(), color=colors[labels[i]], alpha=0.3, label=labels[i])
    ax2.plot(plot_g, map_prediction.flatten(), color=colors[labels[i]], linestyle='dashed')
ax.legend()
ax.set_xlabel('x')
ax.set_ylabel('Model weight (w)')
ax2.set_ybound(-1, 4)
ax2.legend(loc='upper center')
ax2.set_xlabel('x')
ax2.set_ylabel('Model output')
ax.set_title('(a)')
ax2.set_title('(b)')
plt.tight_layout()
# NOTE(review): the output filename's spelling is kept unchanged in case other
# pages reference the saved image.
fig.savefig('comparative_posterior_prditcions', dpi=100)
#fig2.savefig('comparative_posterior_predict', dpi=100)
#fig2.savefig('comparative_posterior_predict', dpi=100)
[3.1784058 7.8851678 0.99982408 5.00752331 1.29121783]
/tmp/ipykernel_4826/1972745161.py:26: UserWarning: marker is redundantly defined by the 'marker' keyword argument and the fmt string "." (-> marker='.'). The keyword argument will take precedence.
ax2.errorbar(g,exp_data[0],yerr=exp_data[1], marker='x', label='experimental data', color='red', fmt='.')
[3.43904458 3.02770927 0.99134379 5.00311711 1.27175777]
[3.56428245 5.85040468 0.99686365 4.97645555 1.24519691]