The Leanpub 60 Day 100% Happiness Guarantee
Within 60 days of purchase you can get a 100% refund on any Leanpub purchase, in two clicks.
See full terms...
Deep learning in JAX and Flax
Minimum price: $10.00
Suggested price: $29.00
About the Book
A book about JAX and Flax, covering how to perform deep learning with both libraries.
JAX (What it is and how to use it in Python)
What is XLA?
Installing JAX
Setting up TPUs on Google Colab
Data types in JAX
Ways to create JAX arrays
Generating random numbers with JAX
Pure functions
JAX NumPy operations
JAX arrays are immutable
Out-of-Bounds Indexing
Data placement on devices in JAX
How fast is JAX?
Using jit() to speed up functions
How JIT works
Taking derivatives with grad()
Auto-vectorization with vmap
Parallelization with pmap
Debugging NaNs in JAX
Double (64-bit) precision
What is a pytree?
Handling state in JAX
Loading datasets with JAX
Building neural networks with JAX
Final thoughts
Optimizers in JAX and Flax
Adaptive vs stochastic gradient descent (SGD) optimizers
AdaBelief
AdaGrad
Adam – Adaptive moment estimation
AdamW
RAdam – Rectified Adam optimizer
AdaFactor
Fromage
LAMB – Layer-wise adaptive large batch optimization
LARS – Layer-wise Adaptive Rate Scaling
SM3 – Square-root of Minima of Sums of Maxima of Squared-gradients Method
SGD – Stochastic Gradient Descent
Noisy SGD
Optimistic GD
Differentially Private SGD
RMSProp
Yogi
Final thoughts
JAX loss functions
What is a loss function?
Creating custom loss functions in JAX
Which loss functions are available in JAX?
Sigmoid binary cross entropy
Softmax cross entropy
Cosine distance
Cosine similarity
Huber loss
L2 loss
Log-cosh loss
Smooth labels
Computing loss with JAX Metrics
How to monitor JAX loss functions
Why JAX loss becomes NaN
Final thoughts
Activation functions in JAX and Flax
ReLU – Rectified linear unit
PReLU – Parametric Rectified Linear Unit
Sigmoid
Log sigmoid
Softmax
Log softmax
ELU – Exponential linear unit activation
CELU – Continuously-differentiable exponential linear unit
GELU – Gaussian error linear unit activation
GLU – Gated linear unit activation
Soft sign
Softplus
Swish – Sigmoid Linear Unit (SiLU)
Custom activation functions in JAX and Flax
Final thoughts
How to load datasets in JAX with TensorFlow
How to load text data in JAX
Clean the text data
Label encode the sentiment column
Text preprocessing with TensorFlow
How to load image data in JAX
How to load CSV data in JAX
Final thoughts
Image classification with JAX & Flax
Loading the dataset
Define Convolutional Neural Network with Flax
Define loss
Compute metrics
Create training state
Define training step
Define evaluation step
Training function
Evaluate the model
Train and evaluate the model
Model performance
Final thoughts
Distributed training with JAX & Flax
Perform standard imports
Set up TPUs on Colab
Download the dataset
Load the dataset
Define the model with Flax
Create training state
Apply the model
Training function
Train the model
Model evaluation
Final thoughts
How to use TensorBoard in JAX & Flax
How to use TensorBoard
How to install TensorBoard
Using TensorBoard with Jupyter notebooks and Google Colab
How to launch TensorBoard
TensorBoard dashboards
How to use TensorBoard with Flax
How to log images with TensorBoard in Flax
How to log text with TensorBoard in Flax
Track model training in JAX using TensorBoard
How to profile JAX programs with TensorBoard
Programmatic profiling
Manual profiling with TensorBoard
How to profile JAX programs on a remote machine
Share TensorBoard dashboards
Final thoughts
Handling state in JAX & Flax (BatchNorm and Dropout layers)
Perform standard imports
Download the dataset
Loading datasets in JAX
Data processing with PyTorch
Define Flax model with BatchNorm and Dropout
Create loss function
Compute metrics
Create custom Flax training state
Training step
Evaluation step
Train Flax model
Set up TensorBoard in Flax
Train model
Save Flax model
Load Flax model
Evaluate Flax model
Visualize Flax model performance
Final thoughts
LSTM in JAX & Flax
Dataset download
Data processing with NLTK
Text vectorization with Keras
Create tf.data dataset
Define LSTM model in Flax
Compute metrics in Flax
Create training state
Define training step
Evaluate the Flax model
Create training function
Train LSTM model in Flax
Visualize LSTM model performance in Flax
Save LSTM model
Final thoughts
Flax vs. TensorFlow
Random number generation in TensorFlow and Flax
Model definition in Flax and TensorFlow
Activations in Flax and TensorFlow
Optimizers in Flax and TensorFlow
Metrics in Flax and TensorFlow
Computing gradients in Flax and TensorFlow
Loading datasets in Flax and TensorFlow
Training models in Flax vs. TensorFlow
Distributed training in Flax and TensorFlow
Working with TPU accelerators
Model evaluation
Visualize model performance
Final thoughts
Train ResNet in Flax from scratch (Distributed ResNet training)
Install Flax models
Perform standard imports
Download dataset
Loading dataset in Flax
Data transformation in Flax
Instantiate Flax ResNet model
Compute metrics
Create Flax model training state
Apply model function
TensorBoard in Flax
Train Flax ResNet model
Evaluate model with TensorBoard
Visualize Flax model performance
Save Flax ResNet model
Load Flax ResNet model
Final thoughts
Transfer learning with JAX & Flax
Install JAX ResNet
Download dataset
Data loading in JAX
Data processing
ResNet model definition
Create head network
Combine ResNet backbone with head
Load pre-trained ResNet 50
Get model and variables
Zero gradients
Define Flax optimizer
Define Flax loss function
Compute Flax metrics
Create Flax training state
Training step
Evaluation step
Train ResNet model in Flax
Set up TensorBoard in Flax
Train model
Save Flax model
Load saved Flax model
Evaluate Flax ResNet model
Visualize model performance
Final thoughts
Elegy (High-level API for deep learning in JAX & Flax)
Data pre-processing
Model definition in Elegy
Elegy model summary
Distributed training in Elegy
Keras-like callbacks in Flax
Train Elegy models
Evaluate Elegy models
Visualize Elegy model with TensorBoard
Plot model performance with Matplotlib
Making predictions with Elegy models
Saving and loading Elegy models
Final thoughts
Appendix
About the Author
Derrick Mwiti is a Data Scientist with expertise in machine learning, data analytics, and visualization. He enjoys working with data to derive meaningful insights that help business executives make decisions. He is an alumnus of the prestigious Meltwater Entrepreneurial School of Technology (MEST). Derrick is an avid contributor to the data science community, writing for popular data science publications such as KDnuggets, Heartbeat, and DataCamp. He holds a Bachelor of Science in Mathematics and Computer Science from Multimedia University.

Episode 125
An Interview with Derrick Mwiti
You can get the free Community Edition in PDF or EPUB just by sharing your name and email address with the author, or you can just click this link to read a shorter sample online...
We pay 80% royalties on purchases of $7.99 or more, and 80% royalties minus a 50 cent flat fee on purchases between $0.99 and $7.98. You earn $8 on a $10 sale, and $16 on a $20 sale. So, if we sell 5000 non-refunded copies of your book for $20, you'll earn $80,000.
(Yes, some authors have already earned much more than that on Leanpub.)
In fact, authors have earned over $15 million writing, publishing and selling on Leanpub.
If you buy a Leanpub book, you get free updates for as long as the author updates the book! Many authors use Leanpub to publish their books in-progress, while they are writing them. All readers get free updates, regardless of when they bought the book or how much they paid (including free).
Most Leanpub books are available in PDF (for computers) and EPUB (for phones, tablets and Kindle). The formats that a book includes are shown at the top right corner of this page.
Finally, Leanpub books don't have any DRM copy-protection nonsense, so you can easily read them on any supported device.
You can use Leanpub to easily write, publish and sell in-progress and completed ebooks and online courses!
Leanpub is a powerful platform for serious authors, combining a simple, elegant writing and publishing workflow with a store focused on selling in-progress ebooks.
Leanpub is a magical typewriter for authors: just write in plain text, and to publish your ebook, just click a button. (Or, if you are producing your ebook your own way, you can even upload your own PDF and/or EPUB files and then publish with one click!) It really is that easy.