1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import orthojax as ojax ### easy orthogonal polynomials
from typing import List, Callable, Tuple
import equinox as eqx
import matplotlib.pyplot as plt
import jax
from jax import random as jr, numpy as jnp
from jaxtyping import Float, Array
from jax.scipy.stats import norm, multivariate_normal
import optax
from layers import NormalizingFlow
from functools import partial
import diffrax
from abc import ABC, abstractmethod
import numpy as np
key = jr.PRNGKey(42)
Problem Setup
We’re going to take a look at solving the 1D viscous Burgers’ equation
\[\frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}\]
where $u(x,t)$ is the velocity field and $\nu$ is the viscosity parameter.
We view this as a forward model with
\[y = \mathcal{G}(\nu) + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 0.1\,\mathbb{I}), \quad \nu \sim \mathcal{N}(1,0.5)\]
where $\mathcal{G}$ is the solution operator which maps the viscosity to the solution $u(x,T;\nu)$ at time $T=0.3$ for a fixed initial condition.
As such we have our prior as $\nu \sim \mathcal{N}(1,0.5)$, our likelihood as $p(y|\nu)$ and our posterior as $p(\nu|y)$. We aim to learn the posterior of a quantity of interest which is a pushforward of the posterior. In other words, a conditional distribution
\[p(\text{qoi}|y) = \int p(\text{qoi}|\nu,y)\, p(\nu|y)\, d\nu\]
We take the $\text{qoi}$ to be the energy of the solution at a fixed time, $u(x,t=0.3)$, which can be computed by taking the squared $\ell_2$ norm of the vector of coefficients associated with projecting it onto a sine basis. Our solver is coefficient-based anyway, so this is not an inconvenience.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from burgers_solver import Viscous_Burgers_Spectral_Galerkin

### problem size and reproducibility
n_prior_samples = 1000
np.random.seed(0)
n_modes = 100

### generate fixed initial condition: only coefficient index 1 is non-zero
U0 = np.zeros([1, n_modes])
U0[:, 1] = 1

### final time, samples of nu, time step size
### NOTE(review): the accompanying text states a Gaussian prior on nu, but the
### code draws nu uniformly from [0.1, 0.5) -- confirm which one is intended.
t = 0.3
nu = np.random.uniform(0.1, 0.5, size=(n_prior_samples))
dt = 0.001


def solve(nu):
    """Run the Burgers solver for a single viscosity value ``nu``.

    The solver is called with the module-level ``dt``, the number of steps
    needed to reach ``t`` and the fixed initial condition ``U0``; its result
    is returned unchanged.
    """
    ### round before casting so floating point error in t/dt cannot drop a step
    n_steps = int(round(t / dt))
    U_soln = Viscous_Burgers_Spectral_Galerkin(nu, dt, n_steps, U0)  ### 1, n_modes
    return U_soln


### solve for all nu samples; the loop variable is deliberately NOT called `nu`,
### otherwise the array of samples would be replaced by its last element
U_solns = [solve(nu_i).squeeze() for nu_i in nu]
U_soln = np.asarray(U_solns)  ### n_prior_samples, n_modes

### noisy observations of the solution coefficients
epsilon = np.random.randn(n_prior_samples, n_modes) * 0.1
y = U_soln + epsilon

### quantity of interest: sum of squared noise-free coefficients per sample
qoi = energy = np.sum(U_soln**2, axis=-1)[:, None]  ### (n_prior_samples,1)
qoi_dim = z_dim = 1
y_dim = n_modes
Base Distribution
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class ConditionalDiagonalGaussianBaseDensity(eqx.Module):
    """Gaussian base density whose mean and diagonal covariance depend on ``y``.

    ``mu`` maps the conditioning vector to the mean; ``Sig`` maps it to values
    that are passed through a softplus and placed on the diagonal of the
    covariance matrix.
    """
    mu: eqx.Module
    Sig: eqx.Module

    def __init__(self, *, y_dim, z_dim, mu_function, Sig_function, key):
        """Build the two networks.

        ``mu_function`` and ``Sig_function`` are factories called with
        ``in_size=y_dim``, ``out_size=z_dim`` and a PRNG ``key``; each one
        receives its own subkey.
        """
        mu_key, sig_key = jr.split(key)
        self.mu = mu_function(in_size=y_dim, out_size=z_dim, key=mu_key)
        self.Sig = Sig_function(in_size=y_dim, out_size=z_dim, key=sig_key)

    def _covariance(self, y):
        """Return the diagonal covariance matrix for conditioning input ``y``."""
        ### softplus keeps the diagonal entries positive
        return jnp.diag(jax.nn.softplus(self.Sig(y)))

    def logpdf(self, z, y):
        """Log-density of ``z`` under the Gaussian conditioned on ``y``."""
        return multivariate_normal.logpdf(z, mean=self.mu(y), cov=self._covariance(y))

    def rvs(self, y, key, shape):
        """Draw samples of batch shape ``shape`` from the Gaussian conditioned on ``y``."""
        return jr.multivariate_normal(key, mean=self.mu(y), cov=self._covariance(y), shape=shape)
### ... or not
class FixedStandardNormal(eqx.Module):
    """Standard normal base density of dimension ``dim``.

    The conditioning argument ``y`` is accepted but not used by either method.
    """
    dim: int

    def __init__(self, *, dim):
        self.dim = dim

    def logpdf(self, z, y):
        """Log-density of ``z`` under a zero-mean, identity-covariance Gaussian."""
        zero_mean = jnp.zeros((self.dim,))
        identity_cov = jnp.eye(self.dim)
        return multivariate_normal.logpdf(z, mean=zero_mean, cov=identity_cov)

    def rvs(self, y, key, shape):
        """Draw samples of batch shape ``shape`` from the standard normal."""
        zero_mean = jnp.zeros((self.dim,))
        identity_cov = jnp.eye(self.dim)
        return jr.multivariate_normal(key, mean=zero_mean, cov=identity_cov, shape=shape)
### advance the global key before building the base-density networks
key, _ = jr.split(key)

### both networks use the same MLP settings (depth 1, width 16, tanh)
_mlp_settings = dict(depth=1, width_size=16, activation=jax.nn.tanh)
m_function = partial(eqx.nn.MLP, **_mlp_settings)
S_function = partial(eqx.nn.MLP, **_mlp_settings)

base_dist = ConditionalDiagonalGaussianBaseDensity(
    y_dim=y_dim,
    z_dim=z_dim,
    mu_function=m_function,
    Sig_function=S_function,
    key=key,
)
Now Define Vector Field
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
### Base layer
class ConcatSquash(eqx.Module):
    """Linear layer whose output is gated and shifted by functions of time.

    Computes ``lin1(x) * sigmoid(lin2(t)) + lin3(t)``, where ``lin2`` and
    ``lin3`` take a length-1 time input and ``lin3`` has no bias.
    """
    lin1: eqx.nn.Linear
    lin2: eqx.nn.Linear
    lin3: eqx.nn.Linear

    def __init__(self, *, in_size, out_size, key):
        data_key, gate_key, shift_key = jr.split(key, 3)
        self.lin1 = eqx.nn.Linear(in_size, out_size, key=data_key)
        self.lin2 = eqx.nn.Linear(1, out_size, key=gate_key)
        self.lin3 = eqx.nn.Linear(1, out_size, use_bias=False, key=shift_key)

    def __call__(self, t, x):
        gate = jax.nn.sigmoid(self.lin2(t))
        shift = self.lin3(t)
        return self.lin1(x) * gate + shift
### Compositional
class ConditionalVectorField(eqx.Module):
    """Stack of time-dependent layers that re-reads the condition at every layer.

    ``depth + 1`` layers are built from ``base_layer``. Each one receives the
    current features concatenated with the conditioning vector; a tanh follows
    every layer except the last.
    """
    layers: List[eqx.Module]

    def __init__(self, *, base_layer, data_size, cond_size, width_size, depth, key):
        keys = jr.split(key, depth + 1)
        first_in = data_size + cond_size
        if depth == 0:
            ### single layer maps straight from (data, condition) to data
            layers = [base_layer(in_size=first_in, out_size=data_size, key=keys[0])]
        else:
            hidden_in = width_size + cond_size
            layers = [base_layer(in_size=first_in, out_size=width_size, key=keys[0])]
            layers.extend(
                base_layer(in_size=hidden_in, out_size=width_size, key=k)
                for k in keys[1:-1]
            )
            layers.append(base_layer(in_size=hidden_in, out_size=data_size, key=keys[-1]))
        self.layers = layers

    def __call__(self, t, x, args):
        """Evaluate the field at time ``t`` and state ``x``; ``args[0]`` is the condition."""
        cond = args[0]
        t = jnp.asarray(t)[None]  ### make t shape (1,) vs scalar
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = jax.nn.tanh(layer(t, jnp.concatenate((x, cond))))
        return output_layer(t, jnp.concatenate((x, cond)))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def approx_logp_wrapper(t, x, args):
    """Vector field for the augmented state using a one-probe trace estimate.

    ``x`` is a pair ``(state, _)``; only the state is used. ``args`` must end
    with ``(..., eps, func)``: ``func(t, state, args=leading_args)`` is the
    vector field and ``eps`` is the probe vector. Returns ``(f, logp)`` where
    ``f`` is the field value and ``logp`` is ``eps^T (df/dx) eps``.
    """
    x, _ = x
    *args, eps, func = args
    fn = lambda x: func(t, x, args=args)
    f, vjp_fn = jax.vjp(fn, x)
    (eps_dfdx,) = vjp_fn(eps)  ### e^T @ \nabla_x F
    ### contract with eps again to obtain the scalar trace estimate
    logp = jnp.sum(eps_dfdx * eps)
    return f, logp
def exact_logp_wrapper(t, x, args):
    """Vector field for the augmented state using the exact Jacobian trace.

    ``x`` is a pair ``(state, _)``; only the state is used. The last entry of
    ``args`` is the vector field ``func``; every earlier entry is forwarded to
    it as ``args``. Returns ``(f, logp)`` where ``f`` is the field value and
    ``logp`` is the trace of ``df/dx``.
    """
    state, _ = x
    *forwarded, vector_field = args

    def field_at_t(s):
        return vector_field(t, s, args=forwarded)

    velocity, pullback = jax.vjp(field_at_t, state)
    (n_dims,) = state.shape  # this implementation only works for 1D input
    ### pulling back each basis vector yields the rows of the Jacobian
    (jacobian,) = jax.vmap(pullback)(jnp.eye(n_dims))
    return velocity, jnp.trace(jacobian)
Model which does the solve!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
### adjoint passed to the training-time ODE solve
adjoint = diffrax.RecursiveCheckpointAdjoint()


class ConditionalCNF(eqx.Module):
    """Conditional continuous normalizing flow driven by the vector field ``Func``.

    ``train`` integrates the augmented state from ``t1`` to ``t0`` and returns
    a log-density; ``sample`` draws from the base distribution and integrates
    the plain vector field from ``t0`` to ``t1``.
    """
    Func: eqx.Module
    t0: float
    t1: float
    dt0: float
    ODETerm: diffrax.ODETerm
    solver: diffrax.AbstractSolver
    base_dist: Callable

    def __init__(self, *, Func, t0, t1, dt0, exact_logp, base_dist, key):
        """Build the vector field via ``Func(key=...)`` and pick the log-density wrapper."""
        key, _ = jr.split(key)
        self.Func = Func(key=key)
        self.t0 = t0
        self.t1 = t1
        self.dt0 = dt0
        logp_wrapper = exact_logp_wrapper if exact_logp else approx_logp_wrapper
        self.ODETerm = diffrax.ODETerm(logp_wrapper)
        self.solver = diffrax.Tsit5()
        self.base_dist = base_dist

    def train(self, x, y, key):
        """Return ``(z, logp_x)`` for one data point ``x`` conditioned on ``y``.

        ``z`` is the state after integrating from ``t1`` to ``t0``; ``logp_x``
        adds the accumulated log-density change to ``base_dist.logpdf(z, y)``.
        ``key`` seeds the probe vector handed to the wrapper.
        """
        eps = jr.normal(key, x.shape)
        augmented_state = (x, 0.0)  ### (state, running log-density change)
        sol = diffrax.diffeqsolve(
            self.ODETerm,
            self.solver,
            t0=self.t1,
            t1=self.t0,
            dt0=-self.dt0,  ### negative step: integrating from t1 down to t0
            y0=augmented_state,
            args=(y, eps, self.Func),
            adjoint=adjoint,
        )
        (z,), (delta_log_likelihood,) = sol.ys
        return z, delta_log_likelihood + self.base_dist.logpdf(z, y)

    def sample(self, key, y):
        """Draw one sample conditioned on ``y`` by integrating from ``t0`` to ``t1``."""
        z0 = self.base_dist.rvs(y, key=key, shape=())
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.Func),
            diffrax.Tsit5(),
            self.t0,
            self.t1,
            self.dt0,
            z0,
            args=(y,),
        )
        (sample_at_t1,) = sol.ys
        return sample_at_t1
Instantiate Model
1
2
3
4
5
6
7
8
### advance the global key before building the flow
key, _ = jr.split(key)

### could swap this to polynomial or kernel or whatever we want
base_layer = ConcatSquash

### vector-field factory; only the PRNG key is left to be supplied
Func = partial(
    ConditionalVectorField,
    base_layer=base_layer,
    data_size=qoi_dim,
    cond_size=y_dim,
    width_size=32,
    depth=3,
)

model = ConditionalCNF(
    Func=Func,
    t0=0.0,
    t1=1.0,
    dt0=0.05,
    exact_logp=True,
    base_dist=base_dist,
    key=key,
)
Training/Optimizer Configuration
1
2
3
4
### number of steps the learning-rate schedule is laid out over
epochs = 10000

### one-cycle cosine learning rate peaking at 1e-3, fed to AdamW
lr_schedule = optax.schedules.cosine_onecycle_schedule(
    transition_steps=epochs,
    peak_value=1e-3,
)
opt = optax.adamw(learning_rate=lr_schedule, weight_decay=0.1)

### optimizer state is built from the model's inexact-array leaves only
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
@eqx.filter_jit
def train_step(model, opt_state, batch, batch_key):
    """Take one optimizer step on the mean negative log-likelihood of ``batch``.

    ``batch`` is a pair ``(x, y)`` whose leading axes are mapped over together.
    ``batch_key`` is split into one key per batch element so that every sample
    gets its own probe vector in ``model.train``. Returns the updated model,
    the new optimizer state and the loss value.
    """
    x, y = batch
    ### one independent key per sample instead of sharing a single key
    sample_keys = jr.split(batch_key, x.shape[0])

    def nll(model):
        ### T is a function of x i.e. x_cond_y and y
        _, logp_x = jax.vmap(lambda xi, yi, ki: model.train(xi, yi, ki))(x, y, sample_keys)
        return -jnp.mean(logp_x)

    loss, grads = eqx.filter_value_and_grad(nll)(model)
    ### adamw needs the current parameters to apply weight decay
    updates, opt_state = opt.update(grads, opt_state, eqx.filter(model, eqx.is_inexact_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
def sample_conditional(model, yi, key, n_samples=1000):
    """Draw ``n_samples`` samples from ``model`` for a single condition ``yi``.

    ``key`` is split into ``n_samples`` subkeys and ``model.sample(subkey, yi)``
    is mapped over them, so the result has a leading axis of length
    ``n_samples``. The default of 1000 matches the previously hard-coded value.
    """
    keys = jr.split(key, n_samples)
    x_cond_yi_samples = jax.vmap(lambda k: model.sample(k, yi))(keys)  ### n_samples, q_dim
    return x_cond_yi_samples
@eqx.filter_jit
def eval(model, y, key):
    """Sample the model for every row of ``y``.

    The same ``key`` is handed to ``sample_conditional`` for each row. The
    mapped axis is placed second in the output, i.e. (n_samples, n_y, q_dim).
    """
    def draw_for_observation(yi):
        return sample_conditional(model, yi, key)

    ### out_axes=1 puts the per-observation axis after the sample axis
    return jax.vmap(draw_for_observation, out_axes=1)(y)
Loop
1
2
3
4
5
6
### NOTE(review): the learning-rate schedule is laid out over `epochs` steps
### while this loop runs 100 -- confirm the two are meant to differ.
for epoch in range(100):
    key, epoch_key = jr.split(key)
    model, opt_state, loss = train_step(model, opt_state, (qoi, y), epoch_key)
    if epoch % 1 == 0:
        print(f'{epoch=}, nll:{loss.item():.5f}')

### Sample once after training rather than inside every epoch: the per-epoch
### samples were overwritten without being used and dominated the run time.
### A fresh subkey is used so the carried `key` is never consumed twice.
key, eval_key = jr.split(key)
samples = eval(model, y, eval_key)  ### (n_samples, n_y, q_dim)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
epoch=0, nll:1.69384
epoch=1, nll:1.69292
epoch=2, nll:1.69200
epoch=3, nll:1.69109
epoch=4, nll:1.69017
epoch=5, nll:1.68926
epoch=6, nll:1.68835
epoch=7, nll:1.68743
epoch=8, nll:1.68652
epoch=9, nll:1.68561
epoch=10, nll:1.68470
epoch=11, nll:1.68379
epoch=12, nll:1.68288
epoch=13, nll:1.68197
epoch=14, nll:1.68106
epoch=15, nll:1.68015
epoch=16, nll:1.67925
epoch=17, nll:1.67834
epoch=18, nll:1.67743
epoch=19, nll:1.67653
epoch=20, nll:1.67562
epoch=21, nll:1.67472
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[48], line 4
2 key, epoch_key = jr.split(key)
3 model, opt_state, loss = train_step(model, opt_state, (qoi, y), epoch_key)
----> 4 samples = eval(model, y, key) ### (n_samples, n_y, q_dim)
5 if epoch % 1 == 0:
6 print(f'{epoch=}, nll:{loss.item():.5f}')
[... skipping hidden 3 frame]
File ~/miniconda3/envs/mk/lib/python3.12/site-packages/jax/_src/pjit.py:292, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
287 if config.no_tracing.value:
288 raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
289 "`jit`, but 'no_tracing' is set")
291 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data,
--> 292 executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
294 maybe_fastpath_data = _get_fastpath_data(
295 executable, out_tree, args_flat, out_flat, attrs_tracked, box_data,
296 jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler)
298 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
File ~/miniconda3/envs/mk/lib/python3.12/site-packages/jax/_src/pjit.py:153, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
151 args_flat = map(core.full_lower, args_flat)
152 core.check_eval_args(args_flat)
--> 153 out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
154 else:
155 out_flat = pjit_p.bind(*args_flat, **p.params)
File ~/miniconda3/envs/mk/lib/python3.12/site-packages/jax/_src/pjit.py:1877, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)
1869 fingerprint = fingerprint.hex()
1870 distributed_debug_log(("Running pjit'd function", name),
1871 ("in_shardings", in_shardings),
1872 ("out_shardings", out_shardings),
(...) 1875 ("abstract args", map(core.abstractify, args)),
1876 ("fingerprint", fingerprint))
-> 1877 return compiled.unsafe_call(*args), compiled, pgle_profiler
File ~/miniconda3/envs/mk/lib/python3.12/site-packages/jax/_src/profiler.py:354, in annotate_function.<locals>.wrapper(*args, **kwargs)
351 @wraps(func)
352 def wrapper(*args, **kwargs):
353 with TraceAnnotation(name, **decorator_kwargs):
--> 354 return func(*args, **kwargs)
File ~/miniconda3/envs/mk/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1297, in ExecuteReplicated.__call__(self, *args)
1294 if (self.ordered_effects or self.has_unordered_effects
1295 or self.has_host_callbacks):
1296 input_bufs = self._add_tokens_to_inputs(input_bufs)
-> 1297 results = self.xla_executable.execute_sharded(
1298 input_bufs, with_tokens=True
1299 )
1301 result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
1302 len(self.ordered_effects))
1303 sharded_runtime_token = results.consume_token()
KeyboardInterrupt:
1
1
### histogram of component 0 of the samples drawn for the first observation
first_observation_samples = samples[:, 0, 0]
plt.hist(first_observation_samples)
1
2
3
4
5
(array([ 6., 18., 72., 174., 220., 234., 165., 81., 27., 3.]),
array([-2.87272191, -2.33275294, -1.79278409, -1.25281525, -0.71284628,
-0.17287731, 0.36709142, 0.90706038, 1.44702935, 1.98699832,
2.52696729]),
<BarContainer object of 10 artists>)

1