NumPyroのお勉強

目次 (2023/4/25 ver.0.02)

1番目のタブが選択された!

1.はじめに

もくじ

1.1 NumPyroとは

1.NumPyroとは
  • NumpyroはUberが開発したPPLで、統計モデルを簡単に記述できると同時に、GPUを使った計算の高速化を実現しています。
  • 同じくUberが開発しているPyroと文法が似通っているのも特徴です。
  • では「NumpyroとPyroの違いは何か?」というと、以下のような差異があります。
    • NumPyro → MCMCに特化したPPL
    • Pyro → 変分推論に特化したPPL
2.データセットの作成
  • NumPyは、多次元配列を扱う数値演算ライブラリです。
  • 機械学習だけでなく画像処理、音声処理などコンピュータサイエンスをするならNumPyを学んでおくことで、あなたの日々の研究や開発の基礎力は格段にアップするはずです。

1.2 参考URL

1.NumPyroとは

以上

2.線形回帰

もくじ

2.1 線形回帰

本節は、「 NumPyroによるベイズモデリング入門【線形回帰編】 」を引用している。

1.ライブラリの準備
  • 以下を「import」します。
  • import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    import jax
    import jax.numpy as jnp
    import numpyro
    sns.set_style("darkgrid")
    
2.単回帰モデル
  • 今回は1番シンプルな、単回帰モデル「$y=ax + b + \epsilon$」をモデリングしていきます。
  • 確率モデル「$L(a,b,\epsilon)$」、事前分布「$p(a,b,\sigma)$」はそれぞれ以下のように設定します。
  • 事前分布
  • $\Large \displaystyle a \sim Normal(0, 100)$

    $\Large \displaystyle b \sim Normal(0, 100)$

    $\Large \displaystyle \sigma \sim LogNormal(0, 10)$

  • 確率モデル
  • $\Large \displaystyle y_i \sim Normal(ax_i+b, \sigma)$

  • これは、以下のようなデータの生成過程をモデル化することに対応しています。
    1. 回帰係数「$a$」、切片項「$a$」、観測の分散パラメータ「$\sigma$」を事前分布からサンプル
    2. それらと「$x_i$」 により確率モデルの形状を決定
    3. 確率モデルから観測「$y_i$」 をサンプル
    4. 2.3.を「$N$」回繰り返す
  • ここではあくまで、なんの学習(推論)も行っていない段階でこのようにモデル化できるだろうと考えているモデルです。
  • このような生成過程を記述する事前分布と確率モデルをNumPyroで書くと以下のようになります。
  • def model(x, y, N=100):
      ## 回帰係数 a、切片項 b の事前分布に平均0, 分散100の正規分布を置きます。
      a = numpyro.sample("a", numpyro.distributions.Normal(loc=jnp.array(0.), scale=jnp.array(100.)))
      b = numpyro.sample("b", numpyro.distributions.Normal(loc=jnp.array(0.), scale=jnp.array(100.)))
      ## 分散パラメータ sigma の事前分布に平均0, 分散10の対数正規分布を置きます。
      sigma = numpyro.sample("sigma", numpyro.distributions.LogNormal(0, 10))
      ## 観測についての確率モデルとして、正規分布を置きます。
      ## yは実際に観測されているものなので、データyと観測モデルを紐づけるために、obs=yと設定しておきます
      with numpyro.plate("data", N):
        numpyro.sample("obs",numpyro.distributions.Normal(a*x + b, sigma), obs=y)
    
  • (jnp.arrayはnp.arrayのようなものであると考えてください。)
  • このようにNumPyroでは、自分で考えた生成過程を直感的にプログラムに落とし込むことができます。
  • 確率モデルの記述には numpyro.plate を用いています。
  • plateは、「$y_i$」のように複数の確率変数をまとめて書くことができる便利な関数です。
3.データの生成
  • 以下のようなモデルからデータを生成します。
  • $\Large \displaystyle y_i \sim Normal(-5x+3, 1.0)$

  • 確率モデルは、
  • $\Large \displaystyle y_i \sim Normal(ax_i+b, \sigma)$

  • と設定していたので、「$a=−5,b=3,\sigma=1.0$」とした場合のデータを生成することになります。
  • 当たり前ですが、「$a=-5$」「$b=3$」「$\sigma=1.0$」は分析者にとって未知の値です。
  • ここでの問題設定ではこの「$a,b,\sigma$」の分布を推論します。
  • 推論結果として事後分布は、それぞれ「$a=-5$」「$b=3$」「$\sigma=1.0$」あたりにピークが来るような分布であると嬉しいです。
  • 今回はデータ数を 5, 10, 50 として、それぞれの挙動を確認してみましょう
4.MCMCによるサンプリング
5.結果
6.予測分布

以上

3.単回帰モデル

もくじ

3.1 モデルを用いた予測

この記事は「 Numpyroでベイズ統計モデリング~単回帰モデル~ 」を参照しています。

1.データの準備・確認
  • 以下のようなビールの売れ行き「sales」と温度「temperature」の1変量による単回帰モデル用のcsvファイル形式のデータを使用する。
  • sales,temperature
    41.68,13.7
    110.99,24
    65.32,21.5
    72.64,13.4
    76.54,28.9
    62.76,28.9
     :
    
  • 以下のコードで「3-2-1-beer-sales-2.csv」の内容が確認できる。
  • file_bear_sales_2 = pd.read_csv("3-2-1-beer-sales-2.csv")
    file_bear_sales_2.head()
    plt.figure(figsize=(10,5))
    sns.scatterplot(x="temperature",y="sales",data=file_bear_sales_2)
    plt.show()
    
2.モデル
  • 単回帰モデルなので非常に単純である。
  • $\Large \displaystyle sales_i=Normal(Intercept+beta*temperature_i,\sigma^2)$

  • 具体的には、以下のように「model」関数を作成する。
  • def model(
        sales,
        temperature
    ):
      Intercept = numpyro.sample("Intercept",dist.Normal(0,100))
      beta = numpyro.sample("beta",dist.Normal(0,100))
      sigma = numpyro.sample("sigma",dist.HalfNormal(100))
    
      numpyro.sample("sales",dist.Normal(Intercept + beta * temperature, sigma),obs = sales)
    
3.推論
  • mcmc.run()は第一引数に再現性を取るためのrandom.PRNGKeyを与え、 残りの引数は modelで指定した引数を与えます。
  • data_dict = {
        "temperature":file_bear_sales_2["temperature"].values,
        "sales":file_bear_sales_2["sales"].values
    }
    kernel = NUTS(model)
    sample_kwargs = dict(
        sampler=kernel, 
        num_warmup=2000, 
        num_samples=2000, 
        num_chains=4, 
        chain_method="parallel"
    )
    mcmc = MCMC(**sample_kwargs)
    mcmc.run(random.PRNGKey(0), **data_dict)
    
    mcmc.print_summary()
    

3.2 事後予測分布

この記事は「 Numpyroでベイズ統計モデリング~事後予測分布~ 」を参照しています。

1.データ、モデル推定
  • モデル定義
  • def model(
        N,
        sales,
        temperature
    ):
      Intercept = numpyro.sample("Intercept",dist.Normal(0,100))
      beta = numpyro.sample("beta",dist.Normal(0,100))
      sigma = numpyro.sample("sigma",dist.HalfNormal(100))
    
      with numpyro.plate("N",N):
        numpyro.sample("sales",dist.Normal(Intercept + beta * temperature, sigma),obs = sales)
    
  • MCMCによる事後分布サンプリング
  • data_dict = {
        "N":len(file_bear_sales_2),
        "temperature":file_bear_sales_2["temperature"].values,
        "sales":file_bear_sales_2["sales"].values
    }
    kernel = NUTS(model)
    sample_kwargs = dict(
        sampler=kernel, 
        num_warmup=2000, 
        num_samples=2000, 
        num_chains=4, 
        chain_method="parallel"
    )
    mcmc = MCMC(**sample_kwargs)
    mcmc.run(random.PRNGKey(0), **data_dict)
    
    mcmc.print_summary()
    
2.事後予測分布の生成
  • 事後分布のサンプリングからパラメータのMCMCサンプルを得るところから。
  • numpyro.infer.Predictive インスタンスの利用と、予測したい説明変数を生成し、作成したモデルで事後予測分布を取得する準備をします。
  • 今回は本紙に則り、気温が11度~30度までの区間の事後予測分布を得ることを目的とします。
  • mcmc_samples=mcmc.get_samples()
    predictive = numpyro.infer.Predictive(model, mcmc_samples)
    temperature_pred = jnp.arange(11,31)
    
  • あとはpredictiveインスタンスの引数に乱数と予測したいモデルの説明変数を与えるだけです。
  • 目的変数となる観測データ(observations)はNoneを指定し、モデルから出力することを明示します。
  • ppc_samples = predictive(random.PRNGKey(0),N = len(temperature_pred), temperature = temperature_pred, sales=None)
    
  • これで事後予測分布が取得できました。
  • arviz による可視化のため、 InferenceDataオブジェクトに変換します。
  • idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)
    
3.可視化による事後予測分布のチェック
  • あとは arvizのプロットに必要な InferenceDataオブジェクトを渡すだけです。
  • 軸ラベルなどが適当ですが、arvizのラベルガイドに基づけば柔軟に対応できそうです。
  • ※読み込み大変なので、ここではデフォルトでご容赦ください。
  • まずは95%ベイズ予測区間の可視化です。
  • az.plot_forest(idata_ppc.posterior_predictive["sales"],
                      var_names=["sales"],
                      hdi_prob=0.95,
                      combined=True,
                      colors='b');
    
    plt.show()
    
  • 軸ラベルが分かりづらいですが、本紙と同じく、気温11度~30度までの各予測分布を可視化できました。
  • 特定の事後予測分布を並列で可視化することも可能です。
  • ここでは、本紙と同じく気温11度と気温30度の事後予測分布を可視化します。
  • az.plot_forest(idata_ppc.posterior_predictive["sales"][:,:,[0,19]],
                      kind="ridgeplot",
                      var_names=["sales"],
                      hdi_prob=0.95,
                      ridgeplot_overlap=0.9,
                      ridgeplot_truncate = False,
                      ridgeplot_quantiles=[.5],
                      combined=True,
                      ridgeplot_alpha=0.5,
                      figsize=(8,5),
                      colors='b');
    
    plt.show()
    
  • 本紙のRにおける bayesplotとは微妙に見た目違いますが、同じものが出力できました。
  • arviz はベイズモデリングの結果解釈でかなり有用になりそうです。
4.おまけ
  • 本紙第3部第3章にはありませんが、回帰直線による予測区間の可視化も可能です。
  • sales_pred =  idata_ppc.posterior_predictive['sales']
    az.plot_hdi(temperature_pred, sales_pred,fill_kwargs={'alpha': 0.3})
    plt.plot(temperature_pred, sales_pred.mean(axis=0).mean(axis=0),color="orange")
    
    sns.scatterplot(x=file_bear_sales_2["temperature"], y=file_bear_sales_2["sales"],s=50,color="gray")
    
    plt.show()
    

以上

4.重回帰モデル

もくじ

4.1 TODO

1.ライブラリのインポート

以上

5.ガウス過程回帰モデル

もくじ

5.1 NumPyro:ガウス過程

本節は、「 NumPyro:ガウス過程 」を引用している。

1.ライブラリのインポート
  • 以下の通りです。
  • import os
    
    import jax
    import jax.numpy as jnp
    from jax import vmap
    from jax import random
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    import numpyro
    from numpyro.diagnostics import hpdi
    import numpyro.distributions as dist
    import numpyro.distributions.constraints as constraints
    from numpyro.infer import MCMC, NUTS
    from numpyro.infer import Predictive
    from numpyro.infer.util import initialize_model
    from numpyro.infer import (
        init_to_feasible,
        init_to_median,
        init_to_sample,
        init_to_uniform,
        init_to_value,
    )
    
    import arviz as az
    
    az.style.use("arviz-darkgrid")
    
    assert numpyro.__version__.startswith("0.11.0")
    
    numpyro.enable_x64(True)
    numpyro.set_platform("cpu")
    numpyro.set_host_device_count(1)
    
2.RBFカーネルと事前分布からのサンプリング
  • 今回は代表的なカーネルであるRBFカーネルを使用します。
  • def RBF(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
        # https://docs.pyro.ai/en/stable/_modules/pyro/contrib/gp/kernels/isotropic.html#RBF
        X = jnp.asarray(X)
        Z = jnp.asarray(Z)
        scaled_X = X / length
        scaled_Z = Z / length
        X2 = (scaled_X**2).sum(axis=1, keepdims=True)
        Z2 = (scaled_Z**2).sum(axis=1, keepdims=True)
        XZ = jnp.matmul(scaled_X, scaled_Z.T)
        r2 = X2 - 2 * XZ + Z2.T
        r2 = jnp.clip(r2, a_min=0)
        k = var * jnp.exp(-0.5 * r2)
        
        if include_noise:
            k += (noise + jitter) * jnp.eye(X.shape[0])
        return k
    
    def model_prior(X):
        K = RBF(X, X, var=1, length=0.2, noise=0.)
        numpyro.sample("y", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=K))
        
    x_sim = np.linspace(-1, 1, 100)
    x_sim = x_sim[:, None]
        
    rng_key = random.PRNGKey(0)
    prior_predictive = Predictive(model_prior, num_samples=20)
    prior_predictions = prior_predictive(rng_key, X=x_sim)["y"]
    
    plt.figure(figsize=(6, 4))
    for i in range(20):
        plt.plot(x_sim[:], prior_predictions[i,:])
    
3.ガウス過程
  • データの準備
    • NumPyroチュートリアル[https://num.pyro.ai/en/latest/examples/gp.html]のデータを使用します。
    # create artificial regression dataset
    def get_data(N=30, sigma_obs=0.15, N_test=400):
        np.random.seed(0)
        X = jnp.linspace(-1, 1, N)
        Y = X + 0.2 * jnp.power(X, 3.0) + 0.5 * jnp.power(0.5 + X, 2.0) * jnp.sin(4.0 * X)
        Y += sigma_obs * np.random.randn(N)
        Y -= jnp.mean(Y)
        Y /= jnp.std(Y)
    
        assert X.shape == (N,)
        assert Y.shape == (N,)
    
        X_test = jnp.linspace(-1.3, 1.3, N_test)
    
        return X, Y, X_test
    
    X_, Y, X_test_ = get_data(N=25)
    X = X_[:, jnp.newaxis]
    X_test = X_test_[:, jnp.newaxis]
    
  • the marginal likelihood GP
    • 数値的により推奨されているコレスキー分解を使用していない形式です。イメージは一番掴みやすいコードになっています。
    def model_marginal_likelihood_GP(X, Y):
        
      var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
      length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
      noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
      
      K = RBF(X, X, var, length, noise)
      numpyro.sample(
          "Y",
          dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=K),
          obs=Y,
      )
    
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    
    kernel = NUTS(model_marginal_likelihood_GP, init_strategy=init_to_feasible)
    mcmc = MCMC(
        kernel,
        num_warmup=1000,
        num_samples=1000,
        num_chains=1,
        thinning=2,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    
                         mean       std    median      5.0%     95.0%     n_eff     r_hat
      kernel_length      0.70      0.23      0.66      0.37      1.00    322.50      1.01
       kernel_noise      0.06      0.02      0.06      0.03      0.09    423.82      1.00
         kernel_var      2.57      4.00      1.36      0.29      5.31    328.78      1.00
    
    Number of divergences: 0
    
    def predict(rng_key, X, Y, X_test, var, length, noise):
        # compute kernels between train and test data, etc.
        k_pp = RBF(X_test, X_test, var, length, noise, include_noise=True)
        k_pX = RBF(X_test, X, var, length, noise, include_noise=False)
        k_XX = RBF(X, X, var, length, noise, include_noise=True)
        K_xx_inv = jnp.linalg.inv(k_XX)
        K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
            rng_key, X_test.shape[:1]
        )
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
        # we return both the mean function and a sample from the posterior predictive for the
        # given set of hyperparameters
        return mean, mean + sigma_noise
    
    # do prediction
    vmap_args = (
        random.split(rng_key_predict, samples["kernel_var"].shape[0]),
        samples["kernel_var"],
        samples["kernel_length"],
        samples["kernel_noise"],
    )
    means, predictions = vmap(
        lambda rng_key, var, length, noise: predict(
            rng_key, X, Y, X_test, var, length, noise
        )
    )(*vmap_args)
    
    mean_prediction = np.mean(means, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)
    
    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    
    # plot training data
    ax.plot(X.ravel(), Y, "kx")
    # plot 90% confidence level of predictions
    ax.fill_between(X_test.ravel(), percentiles[0, :], percentiles[1, :], color="lightblue")
    # plot mean prediction
    ax.plot(X_test.ravel(), mean_prediction, "blue", ls="solid", lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
    

以上

6.自動微分

もくじ

6.1 自動微分とは

1.自動微分とは
  • 自動微分(Automatic Differentiationあるいは Algorithmic Differentiationともいわれ、ADと略される場合が多い)とは、コンピュータープログラムで表現された関数を効率的かつ正確に計算する技術です。
  • もともとは流体力学、原子核工学、気象科学などで使用されていた手法ですが、近年機械学習や金融への応用が注目されています。
  • そこでここでは、自動微分の基礎について紹介します。

以上