Hands-on Tutorials

JAX implementation of FEA

And efficient inverse problem solving with neural networks

John T. Foster, PhD, PE
Towards Data Science
6 min read · Jan 18, 2021


Image by (and of) the Author

If you haven’t heard by now, JAX is getting a lot of attention online as a “NumPy on steroids”. At its core, it can be thought of as a drop-in replacement for NumPy where the array calculations can be accelerated on GPUs or TPUs when available. This alone makes it worth a look, especially if you have a lot of NumPy code that you would like to potentially speed up with GPU acceleration. Currently, most of the NumPy API is implemented in one-to-one correspondence, as well as some of the most commonly used functions from SciPy.

The accelerated NumPy is just the beginning of the utility of JAX. All of the JAX NumPy data structures can be used in combination with most pure Python code to create functions which can be automatically differentiated. This includes computing the gradient of scalar functions, as well as Jacobian matrices of vector functions. These operations can be composed to compute gradients-of-gradients, etc. More information on the automatic differentiation capabilities is documented here.
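
For example, a toy sketch of composing those transformations (this function is just an illustration, not part of the FEA code that follows):

```python
import jax.numpy as jnp
from jax import grad, jacfwd

def f(x):
    # a simple scalar-valued function of a vector argument
    return jnp.sum(jnp.sin(x) ** 2)

df = grad(f)           # gradient, same shape as x
d2f = jacfwd(grad(f))  # Hessian, by composing transformations

x = jnp.array([0.1, 0.2, 0.3])
print(df(x))
print(d2f(x))
```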

Additionally, there is a built-in just-in-time compiler for compiling functions to be executed on CPUs/GPUs/TPUs, and support for automatic vectorization, i.e. functions written for scalar arguments can be easily mapped across arrays. These can be used with the automatic differentiation functions previously mentioned.
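
Again as a toy sketch, here is what jit and vmap look like on a function written for a single scalar argument:

```python
import jax.numpy as jnp
from jax import jit, vmap

def g(x):
    # written as if x were a single scalar sample point
    return jnp.sin(x) * jnp.cos(x)

# vmap maps the scalar function over an array of points, and jit
# compiles the result with XLA for CPU/GPU/TPU execution
g_batched = jit(vmap(g))

values = g_batched(jnp.linspace(0.0, 1.0, 100))
```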

Finally, there is a very thin neural network library associated with JAX called stax. Other, more fully-featured libraries like Haiku, Flax, or Trax are under development on top of JAX technologies.

In what follows, I’ll highlight most of these features of JAX by implementing a finite element analysis (FEA) model and then using the finite element residual as part of the objective function when training a neural network to solve an inverse problem for a potentially unknown constitutive model.

As a model problem, we’ll start with the one-dimensional pressure diffusivity equation, which governs single-phase fluid flow in a porous medium with fluid density ρ and small compressibility c.

Assuming steady state, the equation reduces to d(λ dp/dx)/dx = 0. Multiplying by a test function δp on the left and integrating by parts over the domain (0, L), we have

∫₀ᴸ (dδp/dx) λ (dp/dx) dx − [δp q]₀ᴸ = 0

where

λ = κ/μ

κ is the porous medium’s permeability and μ is the fluid viscosity. λ is known as the mobility and is assumed to be spatially varying. q = λ dp/dx denotes the flux at the boundary arising from the integration by parts.

Using a Galerkin approximation, i.e. p = Nⱼ pⱼ and δp = Nᵢ for I, J = 1, 2, … basis functions, and splitting the domain into n intervals, we now have

0 = ∑ⱼ ( ∫₀ᴸ (dNᵢ/dx) λ (dNⱼ/dx) dx ) pⱼ − [Nᵢ q]₀ᴸ

where summation over the J basis functions is implied for those that have support on the Iᵗʰ node. The right-hand side above is our residual, i.e. Rᵢ.

Below, we’ll integrate this residual vector using Gauss quadrature and solve for the unknown nodal pressures pⱼ. Without loss of generality, we’ll only consider Dirichlet boundary conditions, i.e. q(x) = 0.

While this model problem is linear, we’ll implement the FEA model using the residual form of the equations and solve for the unknowns with a nonlinear Newton-Raphson solver, where the Jacobian matrix at each iteration is computed via automatic differentiation with JAX. All of the computations are written so that they can be accelerated on GPUs/TPUs and are just-in-time compiled.
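
In sketch form, that Newton-Raphson loop might look something like the following, with the finite element residual passed in as a function (a hypothetical helper, not the article’s exact code):

```python
import jax.numpy as jnp
from jax import jit, jacfwd

def newton_solve(residual, p0, tol=1e-10, max_iter=20):
    # Generic Newton-Raphson loop: the Jacobian of the residual is computed
    # with forward-mode automatic differentiation, and both the residual and
    # its Jacobian are just-in-time compiled.
    res = jit(residual)
    jac = jit(jacfwd(residual))
    p = p0
    for _ in range(max_iter):
        r = res(p)
        if jnp.linalg.norm(r) < tol:
            break
        # solve J dp = -r and update the nodal unknowns
        p = p + jnp.linalg.solve(jac(p), -r)
    return p
```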

Below are the imports we need. Note that we explicitly enable 64-bit floating point numbers for JAX, as 32-bit is the default.
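
A minimal version of those imports might look like the following (the exact list is an assumption):

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, vmap

# JAX defaults to 32-bit floats; enable 64-bit for the FEA calculations
jax.config.update("jax_enable_x64", True)
```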

Below we will solve the forward problem via FEA using the implementation above, both to verify that things are working correctly and to generate some reference data that we’ll use in the inverse problem in the sequel. Here, the mobility is a prescribed, spatially varying function that we will later attempt to recover.
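
As a self-contained sketch of the same ideas (linear elements, one-point Gauss quadrature, an illustrative reference mobility that is not the one from the original example, and the newton_solve helper sketched above):

```python
def fe_residual(p, nodes, mobility):
    # Assemble the 1-D finite element residual for the steady problem with
    # linear elements and one-point Gauss quadrature at element midpoints.
    h = nodes[1:] - nodes[:-1]               # element lengths
    x_mid = 0.5 * (nodes[1:] + nodes[:-1])   # quadrature points
    dpdx = (p[1:] - p[:-1]) / h              # dp/dx within each element
    flux = mobility(x_mid) * dpdx            # λ dp/dx at each quadrature point
    r = jnp.zeros_like(p)
    r = r.at[:-1].add(-flux)                 # contribution to node I
    r = r.at[1:].add(flux)                   # contribution to node I+1
    return r

def mobility_ref(x):
    # an illustrative, spatially varying reference mobility
    # (not the function used in the original example)
    return 1.0 + 0.5 * jnp.sin(2.0 * jnp.pi * x)

nodes = jnp.linspace(0.0, 1.0, 51)
p_left, p_right = 1.0, 0.0                   # Dirichlet boundary values

def residual_reference(p_interior):
    # clamp the Dirichlet values and return the residual at interior nodes
    p = jnp.concatenate([jnp.array([p_left]), p_interior, jnp.array([p_right])])
    return fe_residual(p, nodes, mobility_ref)[1:-1]

p_interior = newton_solve(residual_reference, jnp.zeros(len(nodes) - 2))
p_reference = jnp.concatenate(
    [jnp.array([p_left]), p_interior, jnp.array([p_right])])
```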

Here we will write our inverse problem solver. We will inherit from the FEAProblem class above so we can reuse some of the functions already defined.

Our objective function here will be the l₂-norm of the finite element residual when the “known data” is supplied as training data. Because the problem we are solving is a steady-state problem, we’ll need to provide the endpoints of the constitutive model to the objective function; otherwise, there are infinitely many valid solutions to learning the constitutive model that differ only by a constant. If we extended this technique to time-dependent problems, I believe the need to provide the boundary constraints could be avoided.
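
A sketch of that objective, reusing the fe_residual helper and nodes from the forward-solve sketch above (net_apply is the stax apply function constructed further below; the endpoint constraint is left out here for brevity):

```python
def loss(nn_params, p_observed, nodes, net_apply):
    def mobility_nn(x):
        # the network maps a spatial coordinate to a mobility value
        return net_apply(nn_params, x.reshape(-1, 1)).ravel()
    # only the interior rows carry the PDE residual; the boundary rows
    # correspond to reaction fluxes at the Dirichlet nodes
    r = fe_residual(p_observed, nodes, mobility_nn)[1:-1]
    return jnp.linalg.norm(r)
```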

We’ll use a few functions from the jax.experimental.stax module, just to make the neural network construction easier. Our minimizer here will use the second-order "BFGS" method from jax.scipy.optimize.minimize.

Here we assume our data is supplied at the nodes of the FE model, but this restriction could be easily generalized by evaluating the residuals at any given spatial location via the FE shape functions.

Below we’ll test out our inverse problem solver using the data generated earlier from the forward finite element solution. First we define our neural network architecture. This is a fairly simple function, so we don’t need a large and/or complex neural network. Here we have an input layer with only 4 nodes and a tanh activation function feeding to a single node output. More complicated architectures also work, yielding the same result at more computational cost.

We also need to define the layer Dense64, which is the same as stax.Dense but initialized to use 64-bit floats, to be consistent with our data structures in the FEA residual calculation.
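
A sketch of both pieces, assuming Dense64 is built with functools.partial and explicit 64-bit initializers (the exact initializers are an assumption):

```python
from functools import partial

from jax.experimental import stax  # jax.example_libraries.stax in newer JAX
from jax.nn.initializers import glorot_normal, normal

# like stax.Dense, but with initializers that explicitly produce
# 64-bit weights and biases
Dense64 = partial(stax.Dense,
                  W_init=glorot_normal(dtype=jnp.float64),
                  b_init=normal(dtype=jnp.float64))

# the small network described above: 4 tanh units feeding a single output
net_init, net_apply = stax.serial(Dense64(4), stax.Tanh, Dense64(1))
```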

Now we instantiate the model and solve the inverse problem, i.e. train the network. We do have to supply the endpoints of the constitutive model. Because we only consider the steady-state problem, there are infinitely many solutions to the inverse problem (they all have the same shape, but differ by a constant scale factor). We could remove this restriction by considering a time-dependent problem and supplying time-dependent training data, which we’ll leave for future work.
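
Putting the sketch together, instantiating the network and training with BFGS might look like this (hypothetical names carried over from the earlier snippets):

```python
from jax import random
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize

# initialize the network parameters; the input is a single spatial coordinate
_, nn_params = net_init(random.PRNGKey(0), (-1, 1))

# jax.scipy.optimize.minimize works on a flat parameter vector, so flatten
# the pytree of weights and keep the inverse ("unravel") transformation
theta0, unravel = ravel_pytree(nn_params)

result = minimize(
    lambda theta: loss(unravel(theta), p_reference, nodes, net_apply),
    theta0,
    method="BFGS",
)
trained_params = unravel(result.x)
```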

Plotting the neural network function over the range of the domain and comparing with the reference, we can see that the inverse solver has “learned” the mobility function well.

Just to verify, we’ll use our neural network as the mobility function in our forward finite element solver to demonstrate the resulting pressures are also accurate.
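
Continuing the sketch, that verification might look like the following (again using the hypothetical names from the earlier snippets):

```python
def mobility_learned(x):
    # wrap the trained network as an ordinary mobility callable
    return net_apply(trained_params, x.reshape(-1, 1)).ravel()

def residual_learned(p_interior):
    p = jnp.concatenate([jnp.array([p_left]), p_interior, jnp.array([p_right])])
    return fe_residual(p, nodes, mobility_learned)[1:-1]

p_interior = newton_solve(residual_learned, jnp.zeros(len(nodes) - 2))
p_learned = jnp.concatenate(
    [jnp.array([p_left]), p_interior, jnp.array([p_right])])

# p_learned should closely match p_reference from the forward solve above
```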

A major advantage of this approach over, say, physics-informed neural networks is that we have only “learned” the constitutive model, i.e. the mobility function, not the solution of the partial differential equation with the supplied boundary conditions. Instead, we rely on our finite element implementation to compute the solution, which means we can now use our “learned” constitutive model to accurately solve problems with different boundary conditions.
