Jaxを使った常微分方程式の数値解析

Jax.numpy使いたいし力学系も解析したい 今のところ、Jax.numpyの速さを実感しつつもメカニズムは理解していない。 が、どうせJax.numpyを使う機会が増えるので、力学系解析もJax.numpyでできるようにする。 ライブラリは、Diffraxを用いる。 https://github.com/patrick-kidger/diffrax ODEだけではなく、SDEも簡単に扱える。今回はODEだけ。 CDE(Controlled Differeitial Equation)も扱えるらしいが、これが何かはまだ分かっていない。 とりあえず、Duffing方程式の時間発展までを行ってコードを確認する。 $$ \begin{aligned} \dot{x} &= v,\ \dot{v} &= x - x^3 - ev + \gamma\cos(\omega t). \end{aligned} $$ importとソルバーに入れる関数は以下。 import jax.numpy as jnp from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, PIDController def Duffing(t, X, args): x, v = X dx = v dv = x - x**3 - args['e']*v + args['g']*jnp.cos(t) return [dx, dv] 関数の引数は、scipyのodeintと同じで、時間 tと変数 Xとパラメータ argsとする。 パラメータの与え方はいくつか考えられるが、今回は辞書型で与えている。 次に、実際に時間発展させるコードを見る。 ARGS = {'e': 0....

April 29, 2023 · 2 min · 300 words · rKamiura