215  Deep Learning Frameworks: TensorFlow and JAX

Modern deep learning rests on a small number of numerical computing frameworks that handle automatic differentiation, hardware acceleration, and distributed execution. This chapter examines two systems that originated at Google: TensorFlow, the production-oriented platform built around the Keras API, and JAX, a research-oriented library built around composable functional transformations. Both can compile through the same accelerator compiler, XLA, yet they embody sharply different philosophies. Understanding their design choices, and how they compare to PyTorch, clarifies the tradeoffs that govern framework selection.

215.1 1. The TensorFlow and Keras Ecosystem

215.1.1 1.1 What TensorFlow Provides

TensorFlow began in 2015 as a successor to Google’s internal DistBelief system. Its central abstraction is the tensor, an \(n\)-dimensional array, flowing through a computation. The library bundles far more than a math kernel. It ships a serving stack (TensorFlow Serving), a mobile and embedded runtime (TensorFlow Lite, now LiteRT), a browser runtime (TensorFlow.js), data pipeline tooling (tf.data), experiment tracking (TensorBoard), and a model exchange format (SavedModel). This breadth is the historical reason TensorFlow dominated production deployment: a model trained in Python could be exported as a language-neutral artifact and served in C++, Java, or on a phone without reimplementation.

215.1.2 1.2 Keras as the Front End

Keras, created by François Chollet in 2015, started as a high-level API that could sit on top of multiple backends. With TensorFlow 2.0 (2019) it became the official high-level interface, and as of Keras 3 (2023) it is once again multi-backend, able to run on TensorFlow, JAX, or PyTorch. Keras offers three model-building styles that trade convenience against flexibility.

The Sequential API stacks layers linearly. The Functional API treats layers as callables on symbolic tensors, expressing arbitrary directed acyclic graphs such as residual or multi-input networks. Model subclassing lets the user write a call method imperatively for full control. A typical Functional model looks like this.

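import keras
from keras import layers

# Functional API: each layer is called on a symbolic tensor.
inputs = keras.Input(shape=(784,))             # 784 features per example
x = layers.Dense(256, activation="relu")(inputs)
x = layers.Dropout(0.3)(x)                     # active only during training
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# categorical_crossentropy expects one-hot labels of shape (N, 10);
# use sparse_categorical_crossentropy for integer class labels.
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
# x_train (N, 784) and y_train (N, 10) are assumed to be defined elsewhere.
model.fit(x_train, y_train, epochs=10, batch_size=128)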

The compile and fit pattern hides the training loop, gradient computation, and device placement. For practitioners who want standard supervised training, this is the fastest path from data to a trained model. The cost is that the loop is opaque; custom training logic requires either callbacks or dropping to a lower level.

215.2 2. Graph Mode and Eager Mode

215.2.1 2.1 The Original Define-and-Run Model

TensorFlow 1.x used a define-and-run execution model. The programmer first constructed a static computation graph of symbolic operations, then launched a Session to feed data and execute it. The graph was a complete, serializable description of the computation, which made aggressive whole-program optimization, serialization, and distribution straightforward. The drawback was severe: debugging required inspecting graph nodes rather than concrete values, control flow needed special operators such as tf.cond and tf.while_loop, and the mental model diverged from ordinary Python. This friction drove many researchers toward PyTorch.

215.2.2 2.2 Eager Execution by Default

TensorFlow 2.x made eager execution the default. Operations now run immediately and return concrete values, exactly as in NumPy. An expression like tf.matmul(a, b) produces a result on the spot, so standard Python debuggers, print statements, and control flow work naturally. This closed most of the usability gap with PyTorch while retaining the option to recover graph performance when needed.

215.2.3 2.3 tf.function and the Bridge Between Modes

The bridge is the tf.function decorator. Applied to a Python function, it traces the function the first time it is called with a given input signature to build a graph, then executes that graph on subsequent calls with the same signature. This combines the ergonomics of eager development with the throughput of graph execution.

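import tensorflow as tf

# model, loss_fn, and optimizer are assumed to be defined elsewhere,
# for example a Keras model, a Keras loss, and a Keras optimizer.
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:              # record the forward pass
        preds = model(x, training=True)
        loss = loss_fn(y, preds)
    grads = tape.gradient(loss, model.trainable_variables)   # backward pass
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss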

Tracing introduces subtleties. The function is retraced whenever it receives tensors with a new shape or dtype signature, or Python arguments with new values, and Python side effects such as printing or list mutation run only during tracing, not during graph execution. The mechanism that records operations for differentiation is tf.GradientTape, which automatically watches trainable variables and applies reverse-mode automatic differentiation when tape.gradient is called. The tape gives explicit, imperative control over the backward pass, which is essential for custom losses, gradient penalties, and higher-order derivatives obtained by nesting tapes.
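To make the recording idea concrete, here is a minimal pure-Python sketch of reverse-mode differentiation with a tape. It is not how TensorFlow implements tf.GradientTape. The names ToyTape, Var, and gradient are hypothetical, and only scalar addition and multiplication are supported.

```python
class ToyTape:
    """Stores one record per operation: (output, [(input, local_derivative), ...])."""

    def __init__(self):
        self.records = []


class Var:
    """A scalar whose arithmetic is logged on a tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __add__(self, other):
        out = Var(self.value + other.value, self.tape)
        # d(out)/d(self) = 1 and d(out)/d(other) = 1
        self.tape.records.append((out, [(self, 1.0), (other, 1.0)]))
        return out

    def __mul__(self, other):
        out = Var(self.value * other.value, self.tape)
        # d(out)/d(self) = other.value and d(out)/d(other) = self.value
        self.tape.records.append((out, [(self, other.value), (other, self.value)]))
        return out


def gradient(tape, target, sources):
    """Replay the tape backward, accumulating d(target)/d(node) via the chain rule."""
    adjoint = {id(target): 1.0}
    for out, inputs in reversed(tape.records):
        upstream = adjoint.get(id(out), 0.0)
        for node, local in inputs:
            adjoint[id(node)] = adjoint.get(id(node), 0.0) + upstream * local
    return [adjoint.get(id(s), 0.0) for s in sources]


tape = ToyTape()
x, y = Var(3.0, tape), Var(4.0, tape)
z = x * y + x                       # z = x*y + x
dz_dx, dz_dy = gradient(tape, z, [x, y])
print(dz_dx, dz_dy)                 # dz/dx = y + 1, dz/dy = x
```

The forward pass only appends records. All derivative work happens in the single backward sweep, which is the property that lets a real tape compute gradients for every variable at once.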

215.3 3. JAX and Functional Transformations

215.3.1 3.1 The Functional Philosophy

JAX, released in 2018, takes a different stance. It is built on the premise that numerical programs are best expressed as pure functions, functions with no side effects whose output depends only on their inputs. On top of a NumPy-compatible API (jax.numpy), JAX provides a small set of composable function transformations. Because the transformations operate on pure functions, they compose freely and predictably. The price of purity is that JAX arrays are immutable and that randomness and state must be threaded explicitly rather than hidden in global objects.
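A short sketch of what immutability looks like in practice follows. The array values and the function name add_bias are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# The functional alternative returns a new array and leaves x untouched.
y = x.at[0].set(1.0)


def add_bias(v, b):
    # A pure function: the result depends only on the arguments.
    return v + b


z = add_bias(y, 2.0)
print(mutated, x, y, z)
```

Code that would mutate an array in NumPy is written in JAX as an expression that produces a new value, and that is what lets transformations treat functions as mathematical objects.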

215.3.2 3.2 grad: Automatic Differentiation

The transformation jax.grad takes a scalar-valued function and returns a new function that computes its gradient.

import jax
import jax.numpy as jnp

def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

grad_loss = jax.grad(loss)        # differentiates with respect to the first argument, w
g = grad_loss(w, x, y)            # w, x, y are arrays assumed to be defined elsewhere

Differentiation is itself a function transform, so jax.grad(jax.grad(f)) yields a second derivative, and jax.jacobian, jax.hessian, jax.jvp, and jax.vjp expose forward-mode and reverse-mode primitives directly. For a function \(f: \mathbb{R}^n \to \mathbb{R}\), reverse mode computes the full gradient \(\nabla f\) at a cost that is a small constant multiple of one forward evaluation, independent of \(n\), which is why it underlies neural network training.
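The following sketch exercises these transforms on two toy functions. The functions cube and sum_squares are illustrative choices, not part of any library.

```python
import jax
import jax.numpy as jnp


def cube(x):
    return x ** 3                      # scalar -> scalar


def sum_squares(v):
    return jnp.sum(v ** 2)             # R^n -> R


first = jax.grad(cube)(2.0)            # 3 x^2 at x = 2
second = jax.grad(jax.grad(cube))(2.0) # 6 x at x = 2

v = jnp.array([1.0, 2.0, 3.0])
grad_v = jax.grad(sum_squares)(v)      # gradient is 2 v
hess_v = jax.hessian(sum_squares)(v)   # Hessian is 2 I

# Forward mode: push a tangent vector through the function.
e0 = jnp.array([1.0, 0.0, 0.0])
_, directional = jax.jvp(sum_squares, (v,), (e0,))

# Reverse mode: pull a cotangent back from the output to the input.
out, pullback = jax.vjp(sum_squares, v)
(grad_via_vjp,) = pullback(jnp.ones_like(out))
```

One jvp call yields one directional derivative, so recovering the whole gradient in forward mode would take n calls. One vjp call yields the entire gradient.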

215.3.3 3.3 jit: Compilation with XLA

The jax.jit transform traces a function into an intermediate representation called a jaxpr, then hands it to XLA for compilation into a fused, optimized kernel.

@jax.jit
def update(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return params - 0.01 * grads

As with tf.function, JIT tracing replaces concrete values with abstract tracers, so Python control flow that depends on array values must be expressed with structured primitives such as jax.lax.cond, jax.lax.while_loop, and jax.lax.scan. Recompilation is triggered by changes in input shape or dtype, or in any argument marked static, which is why JAX strongly favors static, fixed-shape computation.
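A small sketch makes both points observable. The list trace_log and the function names are hypothetical, and the Python side effect is used only to reveal when tracing happens.

```python
import jax
import jax.numpy as jnp

trace_log = []


@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)          # Python side effect: runs only while tracing
    return jnp.sum(x) * 2.0


scaled_sum(jnp.ones(2))                # first call with shape (2,): traces and compiles
scaled_sum(jnp.zeros(2))               # same shape and dtype: reuses the compiled code
scaled_sum(jnp.ones(3))                # new shape: traces and compiles again


@jax.jit
def flip_if_negative(x):
    # A Python `if jnp.sum(x) < 0:` would fail under jit because the
    # condition is a tracer, so the branch is expressed with lax.cond.
    return jax.lax.cond(jnp.sum(x) < 0, lambda v: -v, lambda v: v, x)


flipped = flip_if_negative(jnp.array([1.0, -3.0]))
print(trace_log, flipped)
```

The log holds one entry per distinct input shape rather than one per call, which is the caching behavior that makes fixed shapes cheap and varying shapes expensive.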

215.3.4 3.4 vmap: Automatic Vectorization

The jax.vmap transform adds a batch dimension to a function written for a single example, removing the need to write batched code by hand or to insert manual broadcasting.

batched_predict = jax.vmap(predict, in_axes=(None, 0))
predictions = batched_predict(params, batch_of_inputs)

Here in_axes=(None, 0) means parameters are shared across the batch while the second argument is mapped over its leading axis. Because vmap composes with grad and jit, one can write clean per-example logic and obtain a batched, differentiated, compiled function by stacking transforms, for example jit(vmap(grad(f))). This per-example clarity also makes constructs such as per-sample gradients, needed for differential privacy, far simpler than in graph-first frameworks.
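As a self-contained sketch, assume predict is a linear model on a single example and that the per-example loss is a squared error. Both are illustrative choices, since the text leaves predict unspecified.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    return jnp.dot(x, params)          # one example: x has shape (d,), output is a scalar


def example_loss(params, x, y):
    return (predict(params, x) - y) ** 2


params = jnp.array([1.0, -2.0, 0.5])
xs = jnp.array([[1.0, 0.0, 2.0],
                [0.0, 1.0, 0.0]])      # batch of two examples
ys = jnp.array([1.0, 0.0])

# Share params across the batch and map over the leading axis of xs.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batched_predict(params, xs)    # shape (2,)

# Stack transforms: per-example gradients, vectorized, then compiled.
per_example_grad = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))
grads = per_example_grad(params, xs, ys)   # shape (2, 3): one gradient row per example
```

The result has one gradient row per example rather than a single gradient averaged over the batch, which is the quantity that per-sample clipping in differentially private training operates on.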

215.3.5 3.5 pmap and the Modern Sharding Model

The jax.pmap transform replicates a computation across multiple devices and runs the replicas in single-program multiple-data fashion. Inside the mapped function, collective operations such as jax.lax.psum and jax.lax.pmean perform cross-device reductions over a named axis. It maps naturally onto data parallelism across GPUs or TPU cores.

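from functools import partial

@partial(jax.pmap, axis_name="batch")   # name the device axis so collectives can refer to it
def parallel_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)              # gradients on each device's shard
    grads = jax.lax.pmean(grads, axis_name="batch")   # average across devices
    return grads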

For large models, the contemporary recommendation has shifted toward jax.jit combined with explicit sharding annotations from jax.sharding, which let the compiler partition computation across a logical device mesh, with shard_map available when manual per-device control is needed. This brings data, tensor, and pipeline parallelism under one mental model and scales to the thousand-chip TPU pods used to train frontier models.
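A minimal sketch of the sharding workflow follows, assuming a one-dimensional mesh over whatever devices are available, including a single CPU. The axis name "data" and the function mean_square are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D logical mesh over all visible devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n = len(devices)

# Make the batch dimension divisible by the number of devices.
x = jnp.arange(8.0 * n).reshape(4 * n, 2)

# Split rows across the "data" axis and replicate the columns.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))


@jax.jit
def mean_square(a):
    # The function is written as if for one device; the compiler inserts
    # any cross-device communication the reduction needs.
    return jnp.mean(a ** 2)


result = mean_square(x_sharded)
print(x_sharded.sharding, result)
```

The numerical code does not mention devices at all. Changing the partition spec changes the parallelization strategy without touching mean_square.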

215.4 4. The XLA Compiler

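XLA (Accelerated Linear Algebra) is the compiler shared by TensorFlow and JAX, and it explains much of why these frameworks perform as they do. JAX routes all computation through XLA, while TensorFlow uses it on TPUs and, on CPU and GPU, when requested, for example with tf.function(jit_compile=True). XLA ingests a graph of high-level operations and applies optimizations that a naive op-by-op executor cannot.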

The most consequential optimization is operator fusion. Consider an elementwise chain such as \(y = \text{relu}(\alpha x + b)\). Executed eagerly, each operation reads its inputs from device memory and writes its output back, so the data crosses the memory bus several times. Accelerators are typically memory-bandwidth bound on such kernels, meaning the bottleneck is data movement rather than arithmetic. XLA fuses the multiply, add, and activation into a single kernel that keeps intermediate values in registers, eliminating the round trips. Beyond fusion, XLA performs constant folding, buffer assignment to reuse memory, layout optimization, and algebraic simplification. The result is compiled, hardware-specific code for CPU, GPU, and TPU.
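The memory-traffic argument can be worked through with a back-of-the-envelope calculation. It assumes float32 data, a scalar \(\alpha\) and a scalar \(b\) that cost nothing to load, and that each unfused kernel reads one full array and writes one full array. The function name bytes_moved is hypothetical.

```python
def bytes_moved(num_elements, kernels, bytes_per_element=4):
    """Device-memory traffic when every kernel reads N elements and writes N elements."""
    return kernels * 2 * num_elements * bytes_per_element


N = 1_000_000

# Unfused: multiply, add, and relu each make a full round trip to memory.
unfused = bytes_moved(N, kernels=3)

# Fused: x is read once and y is written once; intermediates stay in registers.
fused = bytes_moved(N, kernels=1)

ratio = unfused / fused
print(unfused, fused, ratio)
```

Under these assumptions fusion cuts memory traffic by a factor of three. For a bandwidth-bound kernel that translates roughly into the same factor in run time, although real speedups depend on the hardware and on what else the compiler fuses.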

The tradeoff is compilation latency and a preference for static shapes. Each new input shape can trigger a recompilation, so workloads with highly variable shapes, such as variable-length sequences, often require bucketing or padding to a fixed set of shapes. This is the practical reason both JAX and tf.function reward static, regular computation and penalize dynamic Python that changes structure at runtime.
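Bucketing can be sketched in a few lines of plain Python. The bucket sizes, the padding value, and the function names below are illustrative assumptions rather than recommendations.

```python
from bisect import bisect_left

BUCKETS = (16, 32, 64, 128)            # hypothetical, sorted bucket lengths


def bucket_length(n, buckets=BUCKETS):
    """Return the smallest bucket that can hold a sequence of length n."""
    i = bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"sequence of length {n} exceeds the largest bucket")
    return buckets[i]


def pad_to_bucket(tokens, pad_id=0, buckets=BUCKETS):
    """Right-pad a token list so that its length is one of the bucket sizes."""
    target = bucket_length(len(tokens), buckets)
    return list(tokens) + [pad_id] * (target - len(tokens))


padded = pad_to_bucket(list(range(1, 21)))           # length 20 is padded to 32
distinct_shapes = {bucket_length(n) for n in range(1, 129)}
print(len(padded), sorted(distinct_shapes))
```

With this scheme 128 possible sequence lengths collapse to four shapes, so a compiled function is built at most four times. In practice the padded positions also need a mask so that they do not contribute to the loss.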

215.5 5. Comparison to PyTorch

215.5.1 5.1 Execution Model

PyTorch is eager by default and builds its autograd graph dynamically on each forward pass, a define-by-run approach that makes debugging and dynamic control flow straightforward. Historically this was its decisive advantage over TensorFlow 1.x and drove its dominance in research. PyTorch 2.0 narrowed the performance gap with torch.compile, which captures Python code via TorchDynamo and lowers it to optimized kernels through the TorchInductor backend. This is conceptually parallel to what tf.function and jax.jit do. PyTorch can also target XLA through the PyTorch/XLA project for TPU execution.

215.5.2 5.2 State and Programming Style

The frameworks differ most in how they treat state. PyTorch and Keras are object-oriented: a module holds its parameters as mutable attributes, and optimizers mutate those tensors in place. JAX is functional: parameters live in explicit data structures called pytrees, which are nested containers such as dictionaries, lists, and tuples with arrays at the leaves, and every update returns a new structure. Neural network libraries built on JAX, such as Flax and Equinox, formalize this pattern. The functional style makes transformations and reproducibility cleaner but shifts the burden of state management onto the programmer.
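The functional update pattern can be sketched with a small nested dictionary of parameters. The layer name "dense", the loss, and the learning rate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Parameters as a pytree: a nested dict with arrays at the leaves.
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}


def loss_fn(p, x):
    h = x @ p["dense"]["w"] + p["dense"]["b"]
    return jnp.sum(h ** 2)


def sgd_step(p, x, lr=0.1):
    grads = jax.grad(loss_fn)(p, x)    # grads has the same tree structure as p
    # Apply the update leaf by leaf and return a new tree; p is left unchanged.
    return jax.tree_util.tree_map(lambda a, g: a - lr * g, p, grads)


x = jnp.array([1.0, 2.0])
new_params = sgd_step(params, x)
print(new_params["dense"]["b"], params["dense"]["b"])
```

Because sgd_step takes the old parameters and returns new ones, it is a pure function and can be wrapped in jit, vmap, or grad like any other.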

Randomness illustrates the divide. PyTorch and TensorFlow by default rely on a global random seed and implicit generator state. JAX requires an explicit, splittable random key passed into each stochastic operation, so a function’s output is fully determined by its arguments. Because the key alone determines the random stream, sampling is reproducible and does not depend on execution order or on how work is divided among parallel replicas, at the cost of more verbose code.
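A brief sketch of explicit key handling follows. The seed value and variable names are arbitrary.

```python
import jax

key = jax.random.PRNGKey(0)

# Split the key instead of mutating hidden state; each subkey is used once.
key, subkey_a, subkey_b = jax.random.split(key, 3)

a1 = jax.random.normal(subkey_a, (3,))
a2 = jax.random.normal(subkey_a, (3,))   # same key, so the same numbers
b = jax.random.normal(subkey_b, (3,))    # different key, so an independent draw

print(a1, a2, b)
```

Reusing a key reproduces the sample exactly, so fresh randomness always comes from a fresh subkey, and the carried-forward key is what gets passed to the next step.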

215.5.3 5.3 Ecosystem and Selection Criteria

The table below summarizes the practical contrasts.

| Dimension | TensorFlow / Keras | JAX | PyTorch |
|---|---|---|---|
| Default execution | Eager; graph via tf.function | Eager op-by-op dispatch; trace and compile via jit | Eager; compile via torch.compile |
| State model | Object-oriented, mutable | Functional, immutable pytrees | Object-oriented, mutable |
| Autodiff | GradientTape | grad transform | autograd |
| Compiler backend | XLA (opt-in on CPU/GPU) | XLA | TorchInductor, optional XLA |
| Typical strength | Production, mobile, serving | Research, large-scale TPU | Research, broad community |

A reasonable heuristic follows. Choose TensorFlow and Keras when deployment breadth matters, when targeting mobile or browser, or when a high-level fit-and-predict workflow suffices. Choose JAX when the work involves custom numerical methods, higher-order derivatives, per-example gradients, or training at TPU scale where the functional model and XLA sharding pay off. Choose PyTorch when ecosystem momentum, the largest pool of pretrained models and tutorials, and frictionless debugging are the priorities, which describes most current research and a growing share of production. Keras 3, by running on all three backends, blurs these boundaries and lets a single high-level codebase migrate between them.

215.6 6. Conclusion

TensorFlow and JAX express two philosophies that meet at XLA. TensorFlow wraps a comprehensive production ecosystem around an eager core that can recover graph performance through tf.function, with Keras providing an accessible front end. JAX reduces deep learning to pure functions acted on by composable transforms, where grad, jit, vmap, and pmap combine to yield differentiated, compiled, vectorized, and parallelized programs from concise code. PyTorch occupies a pragmatic middle ground that has captured the research mainstream. The deciding factors are rarely raw speed, since all three converge on similar compiled kernels, but rather state model, deployment targets, and the scale of parallelism a project demands.
