JAX and Flax

Deep learning in JAX and Flax

About the Book

A book about JAX and Flax, covering how to perform deep learning with both libraries.

JAX (What it is and how to use it in Python)  

What is XLA?  

Installing JAX 

Setting up TPUs on Google Colab 

Data types in JAX 

Ways to create JAX arrays

Generating random numbers with JAX

Pure functions

JAX NumPy operations

JAX arrays are immutable

Out-of-Bounds Indexing

Data placement on devices in JAX

How fast is JAX?

Using jit() to speed up functions

How JIT works

Taking derivatives with grad()

Auto-vectorization with vmap

Parallelization with pmap

Debugging NaNs in JAX

Double (64-bit) precision

What is a pytree?

Handling state in JAX

Loading datasets with JAX

Building neural networks with JAX

Final thoughts

Optimizers in JAX and Flax

Adaptive vs stochastic gradient descent (SGD) optimizers

Adam – Adaptive moment estimation

RAdam – Rectified Adam optimizer

LAMB – Layer-wise adaptive large batch optimization

LARS – Layer-wise Adaptive Rate Scaling

SM3 – Square-root of Minima of Sums of Maxima of Squared-gradients

SGD – Stochastic Gradient Descent

Noisy SGD

Optimistic GD

Differentially Private SGD

Final thoughts

JAX loss functions

What is a loss function?

Creating custom loss functions in JAX

Which loss functions are available in JAX?

Sigmoid binary cross entropy

Softmax cross entropy

Cosine distance

Cosine similarity

Huber loss

L2 loss

Log cosh

Smooth labels

Computing loss with JAX Metrics

How to monitor JAX loss functions

Why JAX loss becomes NaN

Final thoughts

Activation functions in JAX and Flax

ReLU – Rectified linear unit

PReLU – Parametric Rectified Linear Unit

Log sigmoid

Log softmax

ELU – Exponential linear unit activation

CELU – Continuously-differentiable exponential linear unit

GELU – Gaussian error linear unit activation

GLU – Gated linear unit activation

Soft sign

Swish – Sigmoid Linear Unit (SiLU)

Custom activation functions in JAX and Flax

Final thoughts

How to load datasets in JAX with TensorFlow

How to load text data in JAX

Clean the text data

Label encode the sentiment column

Text preprocessing with TensorFlow

How to load image data in JAX

How to load CSV data in JAX

Final thoughts

Image classification with JAX & Flax

Loading the dataset

Define a convolutional neural network with Flax

Define loss

Compute metrics

Create training state

Define training step

Define evaluation step

Training function

Evaluate the model

Train and evaluate the model

Model performance

Final thoughts

Distributed training with JAX & Flax

Perform standard imports

Set up TPUs on Colab

Download the dataset

Load the dataset

Define the model with Flax

Create training state

Apply the model

Training function

Train the model

Model evaluation

Final thoughts

How to use TensorBoard in JAX & Flax

How to use TensorBoard

How to install TensorBoard

Using TensorBoard with Jupyter notebooks and Google Colab

How to launch TensorBoard

TensorBoard dashboards

How to use TensorBoard with Flax

How to log images with TensorBoard in Flax

How to log text with TensorBoard in Flax

Track model training in JAX using TensorBoard

How to profile JAX programs with TensorBoard

Programmatic profiling

Manual profiling with TensorBoard

How to profile JAX programs on a remote machine

Share TensorBoard dashboards

Final thoughts

Handling state in JAX & Flax (BatchNorm and Dropout layers)

Perform standard imports

Download the dataset

Loading datasets in JAX

Data processing with PyTorch

Define Flax model with BatchNorm and Dropout

Create loss function

Compute metrics

Create custom Flax training state

Training step

Evaluation step

Train Flax model

Set up TensorBoard in Flax

Train model

Save Flax model

Load Flax model

Evaluate Flax model

Visualize Flax model performance

Final thoughts

LSTM in JAX & Flax

Dataset download

Data processing with NLTK

Text vectorization with Keras

Create tf.data dataset

Define LSTM model in Flax

Compute metrics in Flax

Create training state

Define training step

Evaluate the Flax model

Create training function

Train LSTM model in Flax

Visualize LSTM model performance in Flax

Save LSTM model

Final thoughts

Flax vs. TensorFlow

Random number generation in TensorFlow and Flax

Model definition in Flax and TensorFlow

Activations in Flax and TensorFlow

Optimizers in Flax and TensorFlow

Metrics in Flax and TensorFlow

Computing gradients in Flax and TensorFlow

Loading datasets in Flax and TensorFlow

Training model in Flax vs. TensorFlow

Distributed training in Flax and TensorFlow

Working with TPU accelerators

Model evaluation

Visualize model performance

Final thoughts

Train ResNet in Flax from scratch (distributed ResNet training)

Install Flax models

Perform standard imports

Download dataset

Loading dataset in Flax

Data transformation in Flax

Instantiate Flax ResNet model

Compute metrics

Create Flax model training state

Apply model function

TensorBoard in Flax

Train Flax ResNet model

Evaluate model with TensorBoard

Visualize Flax model performance

Save Flax ResNet model

Load Flax ResNet model

Final thoughts

Transfer learning with JAX & Flax

Install JAX ResNet

Download dataset

Data loading in JAX

Data processing

ResNet model definition

Create head network

Combine ResNet backbone with head

Load pre-trained ResNet-50

Get model and variables

Zero gradients

Define Flax optimizer

Define Flax loss function

Compute Flax metrics

Create Flax training state

Training step

Evaluation step

Train ResNet model in Flax

Set up TensorBoard in Flax

Train model

Save Flax model

Load saved Flax model

Evaluate Flax ResNet model

Visualize model performance

Final thoughts

Elegy (high-level API for deep learning in JAX & Flax)

Data pre-processing

Model definition in Elegy

Elegy model summary

Distributed training in Elegy

Keras-like callbacks in Flax

Train Elegy models

Evaluate Elegy models

Visualize Elegy model with TensorBoard

Plot model performance with Matplotlib

Making predictions with Elegy models

Saving and loading Elegy models

Final thoughts

About the Author

Derrick Mwiti

Derrick Mwiti is a data scientist with expertise in machine learning, data analytics, and visualization. He enjoys working with data to derive meaningful insights that help business executives make decisions. He is an alumnus of the prestigious Meltwater Entrepreneurial School of Technology (MEST). Derrick is an avid contributor to the data science community and writes for popular data science publications such as KDnuggets, Heartbeat, and DataCamp. He holds a Bachelor of Science in Mathematics and Computer Science from Multimedia University.
