以前書いた #Julia言語 版HMC(Hamiltonian Monte Carlo)のサンプルコード
ポテンシャル函数φ(x)から、確率分布p(x)=exp(-φ(x))/Zのi.i.d.サンプルを生成する方法の1つ。
60行程度しかない。
#Julia言語 leapfrog法でHamiltonの正準方程式を解くので、leapfrog法のために必要な情報をLFProblem型の変数に格納し、それをHMC函数に渡すと分布p(x)=exp(-φ(x))/Zのサンプルを返してくれる。
そういうシンプルなコードになっています。
nbviewer.org/github/genkuro…
#Julia言語 この手の問題では、ポテンシャル函数φ(x)がパラメータに依存している場合が多いので、ポテンシャル函数はφ(x, param)の形式の函数で与える仕様になっています。
だから、HMC函数および関連の函数にはポテンシャル函数を決めるためのparamを渡す必要があります。
nbviewer.org/github/genkuro…
#Julia言語 Juliaでは(constも含めて)グローバル変数を引数を経由せずに函数の中で使うことは損になります。計算速度が落ちたり、柔軟な試行錯誤の可能性が潰されてしまう。
それを防ぐ最も単純な方法は「グローバル変数を使うときには毎回函数に引数として渡す」です。トートロジカルな方法が良い!
#Julia言語 しかし、複数の函数に毎回引数として渡される変数達が多いと面倒になります。
だから、複数の函数に毎回引数として渡される情報達を1つ(もしくは高々2〜3個)の変数にまとめるとよいです。
そのために自前でstructを定義して利用することが多いです。
#Julia言語 コードは
* 問題を記述するための情報を格納するとオブジェクトのコンストラクタ
* 問題を記述する情報を渡すと問題を解いてくれる函数
の形式で整理すると分かりやすくなることが多いです。
#Julia言語 日本に限らず、
* 問題を記述する情報を沢山のグローバル変数にべた書き
* 長大なmain函数で問題を解く
のようなスタイルのコードを書くスタイルの悪しき教育を受けた人達は多く、上で解説したスタイルなどで整理されたコードの書き方が普及するとよいと思います。
#Julia言語 HMCの短いサンプルコード
nbviewer.org/github/genkuro…
で参考になると思われる点は他にもあります。
まず、JuliaにおけるStaticArrays.jlを用いた計算の効率化。Juliaには多彩な配列の型があり、使いこなすと楽に見易く高速なコードを書けます。続く
#Julia言語 次に、自動微分!
Hamiltonの正準方程式ではポテンシャル函数φ(x, param)のxに関する導函数(gradient)が必要になります。
手計算でgradientを計算して、それを函数として実装して、問題を記述する情報が格納された変数に格納するのは面倒です。続く#
nbviewer.org/github/genkuro…
#Julia言語 φ(x, param)を与えるだけで、そのgradientを丸め誤差を除いて正確にかつ高速に自動的に計算してくれるならば、人間側はgradientを手計算する作業から解放されます。
nbviewer.org/github/genkuro… のサンプルコードでは実際にそれをForwardDiff.jlによる自動微分で実現しています!
#Julia言語 シンプルなサンプルコードだけでどれだけのことができるのか?
例1. φ(x) = (x₁² + x₁x₂ + x₂²)/2の場合の分布p(x)=exp(-φ(x))/Z (これは2次元の正規分布になる)のサンプルの生成。
問題を記述するデータを格納した変数lfを
lf = My.LFProblem(2, φ)
で作っています。
#Julia言語 例2. φ(x) = a(x₁² - 1)² の場合の分布p(x)=exp(-φ(x))/Zのサンプルの生成。a=3,4,5,6,7,8の場合を計算しています。
aが大きいほど、2つの山が強く分離されるようになり、偏りのないサンプル生成が難しくなります。
#Julia言語 例3. 正規分布モデルでのベイズ統計
HMC法はベイズ統計での事後分布のサンプルの構成でも役立つ。
面倒なので事前分布はフラット事前分布にしてしまっています(手抜き)。その場合のポテンシャル函数φ(x)はモデルのパラメータxの対数尤度函数の-1倍になります。
nbviewer.org/github/genkuro…
#Julia言語 サンプルコード
nbviewer.org/github/genkuro…
はJuliaで効率的に計算するための基礎になる「型安定性」と「アロケーションの削減」もきっちり実現しています。
こういう点も参考になると思います。
#Julia言語 ループの内側でrandn(rng, n)で長さnの配列の乱数を生成すると、毎回その分だけにメモリ割当が発生し、計算効率が悪化してしまいます。
サンプルコードでは randn!(rng, vtmp) の形式で事前割当された配列vtmpに乱数を書き込むことによってそれを防いでいます。
nbviewer.org/github/genkuro…
#Julia言語 注意・警告: StaticArrays.jlの使用は長さ30程度の配列を扱う場合程度までは有効ですが、サイズがそれを超えると大幅に遅くなる場合があるので要注意です。
StaticArrays.jlが使えないほど大きな問題では、in-place計算でメモリ割当を削減する必要があります。
Share this Scrolly Tale with your friends.
A Scrolly Tale is a new way to read Twitter threads with a more visually immersive experience.
Discover more beautiful Scrolly Tales like this.