[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX #398

Open
mbohlool wants to merge 1 commit into main from text_encoder_tpu3

@mbohlool (Collaborator) commented May 4, 2026

Description

This PR transitions the LTX-2 pipeline's text encoding process to utilize TorchAX, bridging the Gemma-3 model natively into JAX and significantly optimizing memory usage to prevent TPU out-of-memory errors. Minor PyLint warnings across the pipeline and text encoder wrapper were also resolved during the refactor.

Key changes include:

  • TorchAX Integration: Replaced the eager PyTorch-based text encoder execution with the JAX-native TorchaxGemma3TextEncoder wrapping the HuggingFace Gemma3ForConditionalGeneration model. This allows full compiler optimization via JAX tracing.
  • VAE Memory Optimization: Updated the VAE decoding loop to conditionally apply sharding constraints depending on batch size. For batch_size <= 2, it utilizes standard VAE replication and slicing. For batch_size > 2, it dynamically disables sequential slicing and skips replication, applying NamedSharding constraints on the batch dimension of the latents across the mesh axes. This prevents JAX from trying to concatenate massive arrays on the TPU, avoiding HBM out-of-memory crashes.
  • Lint & Test Quality Cleanup: Addressed PyLint warnings across the pipeline, mock-patched the smoke tests to bypass loading the full 4B-parameter text encoder when it is not needed, and ensured clean end-to-end execution.
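
To illustrate the mock-patching idea from the last bullet, here is a minimal sketch using only `unittest.mock` from the standard library. `TextEncoderLoader`, `Pipeline`, and `build_pipeline_for_smoke_test` are hypothetical names invented for this example; they are not taken from the PR and do not reflect how the actual smoke tests are written.

```python
from unittest import mock


class TextEncoderLoader:
  """Hypothetical stand-in for the code path that loads the full text encoder."""

  @staticmethod
  def load():
    raise RuntimeError("would load multi-billion-parameter weights")


class Pipeline:
  """Hypothetical pipeline that loads its text encoder at construction time."""

  def __init__(self):
    self.text_encoder = TextEncoderLoader.load()

  def encode(self, prompt):
    return self.text_encoder(prompt)


def build_pipeline_for_smoke_test():
  # The fake encoder returns a tiny fixed embedding instead of running a model.
  fake_encoder = mock.MagicMock(return_value=[[0.0, 0.0, 0.0]])
  # Patch only for the duration of construction; the real loader is restored after.
  with mock.patch.object(TextEncoderLoader, "load", return_value=fake_encoder) as patched:
    pipeline = Pipeline()
  return pipeline, patched


pipeline, patched_load = build_pipeline_for_smoke_test()
print(pipeline.encode("a cat on a skateboard"))
```

Scoping the patch to a `with` block keeps the real loader intact for any test that does need the full encoder.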

Benchmarks

Performance comparison demonstrating latency and throughput improvements, based on robust averages of repeated runs (with the furthest outlier removed for each configuration).

| Configuration | Text Encoding (CPU) | Text Encoding (TorchAX) | Text Encoding Impr. | Total Time (TE on CPU) | Total Time (TE on TorchAX) | Generation Impr. |
| --- | --- | --- | --- | --- | --- | --- |
| Batch Size 1 (Latency Optimized) | 3.55s | 2.20s | +38.0% | 12.77s | 11.43s | +10.5% |
| Batch Size 1 (w/ Upsampler) | 3.35s | 2.60s | +22.4% | 16.04s | 15.47s | +3.6% |
| Batch Size 8 (Throughput Optimized) | 23.50s | 4.40s | +81.3% | 80.81s | 58.78s | +27.3% |
| Batch Size 8 (w/ Upsampler) | 23.60s | 4.65s | +80.3% | 113.38s | 85.61s | +24.5% |
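
The improvement columns are consistent with a relative reduction in time, (CPU - TorchAX) / CPU. The sketch below recomputes them from the reported timings. The `robust_average` helper is only one possible reading of "furthest outlier removed": dropping the run furthest from the median is an assumption, since the description does not say how the outlier was chosen, and the sample run times passed to it are invented for illustration.

```python
from statistics import mean, median


def robust_average(samples):
  """Mean of the samples after dropping the one furthest from the median (assumed rule)."""
  if len(samples) < 3:
    return mean(samples)
  center = median(samples)
  furthest = max(range(len(samples)), key=lambda i: abs(samples[i] - center))
  kept = [s for i, s in enumerate(samples) if i != furthest]
  return mean(kept)


def improvement_pct(baseline_s, new_s):
  """Relative time reduction of new_s versus baseline_s, in percent."""
  return (baseline_s - new_s) / baseline_s * 100.0


# (CPU seconds, TorchAX seconds) for total generation time, copied from the table.
total_times = {
    "Batch Size 1 (Latency Optimized)": (12.77, 11.43),
    "Batch Size 1 (w/ Upsampler)": (16.04, 15.47),
    "Batch Size 8 (Throughput Optimized)": (80.81, 58.78),
    "Batch Size 8 (w/ Upsampler)": (113.38, 85.61),
}

for name, (cpu_s, torchax_s) in total_times.items():
  print(f"{name}: +{improvement_pct(cpu_s, torchax_s):.1f}%")

# Invented run times, only to show the outlier being dropped.
print(robust_average([2.1, 2.2, 2.3, 5.0]))
```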

Note

Crucial VAE Memory Optimization Impact:
For Batch Size 8 (w/ Upsampler), running without the conditional VAE batch sharding constraints (enable_dynamic_vae_sharding=False) causes the generation to immediately fail with a TPU HBM Out-of-Memory (OOM) crash during VAE decoding.

With enable_dynamic_vae_sharding=True for larger batches, the pipeline avoids the OOM entirely and finishes in 85.61s, which is 24.5% less total generation time than the 113.38s measured with the text encoder on CPU.
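
The batch-size-dependent behaviour described above can be sketched in JAX roughly as follows. This is not the PR's implementation: `decode_latents`, `toy_decode`, and the single mesh axis named `"data"` are assumptions made for the example, and the per-sample loop only stands in for the VAE's slicing path. The threshold of 2 and the `enable_dynamic_vae_sharding` flag name come from the description.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def decode_latents(latents, decode_fn, mesh, enable_dynamic_vae_sharding=True):
  batch_size = latents.shape[0]
  if batch_size <= 2 or not enable_dynamic_vae_sharding:
    # Small batch: decode one sample at a time (stand-in for VAE slicing).
    slices = [decode_fn(latents[i : i + 1]) for i in range(batch_size)]
    return jnp.concatenate(slices, axis=0)

  # Large batch: shard the leading (batch) dimension over the assumed "data" axis
  # and leave every other dimension unsharded.
  spec = PartitionSpec("data", *([None] * (latents.ndim - 1)))
  batch_sharding = NamedSharding(mesh, spec)

  @jax.jit
  def sharded_decode(x):
    x = jax.lax.with_sharding_constraint(x, batch_sharding)
    return decode_fn(x)

  return sharded_decode(latents)


def toy_decode(x):
  # Hypothetical stand-in for the VAE decoder: doubles the last axis.
  return jnp.repeat(x, 2, axis=-1)


num_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Keep the batch a multiple of the device count so it divides evenly across the mesh.
large = jnp.arange(4 * num_devices * 3, dtype=jnp.float32).reshape(4 * num_devices, 3)
small = large[:2]

large_out = decode_latents(large, toy_decode, mesh)
small_out = decode_latents(small, toy_decode, mesh)
print(large_out.shape, small_out.shape)
```

On a single device the constraint is a no-op, so the sketch runs anywhere; the memory benefit only appears when the mesh spans several devices.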

@mbohlool mbohlool requested a review from entrpn as a code owner May 4, 2026 20:08
@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch 5 times, most recently from a449a5c to d681d61 Compare May 6, 2026 07:44
Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
@github-actions

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

## 📋 Review Summary

This PR successfully integrates TorchAX for the LTX-2 pipeline's text encoder, bringing significant performance improvements and memory optimizations on TPU. The transition from eager PyTorch to JAX-native execution is well-implemented, and the additional sharding constraints for both the text encoder and VAE are effective strategies for preventing OOM crashes.

🔍 General Feedback

  • TorchAX Integration: The use of TorchaxGemma3TextEncoder and the manual batch sharding logic is a great addition for efficiency.
  • Memory Management: The conditional sharding and slicing disabling in the VAE decoding loop correctly addresses HBM issues for larger batches.
  • Distributed Performance: One critical observation is the explicit un-sharding of text encoder hidden states to a single device, which should be avoided to ensure optimal performance in multi-host environments.
  • Code Cleanliness: Small refactors to use getattr instead of broad try/except blocks will improve maintainability.
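
As a small illustration of the `getattr` suggestion in the last bullet, the sketch below contrasts the two styles. The config classes and helper functions are hypothetical and are not taken from the PR; only the `enable_dynamic_vae_sharding` attribute name comes from the description.

```python
class ConfigWithFlag:
  """Hypothetical config that defines the flag."""

  enable_dynamic_vae_sharding = True


class LegacyConfig:
  """Hypothetical older config that lacks the flag."""


class BrokenConfig:
  """Hypothetical config whose flag lookup hits a real bug."""

  @property
  def enable_dynamic_vae_sharding(self):
    raise ValueError("misconfigured mesh")


def flag_with_broad_except(config):
  # Broad handler: a missing attribute and a genuine bug look identical.
  try:
    return config.enable_dynamic_vae_sharding
  except Exception:
    return False


def flag_with_getattr(config):
  # getattr with a default only absorbs AttributeError; other errors propagate.
  return getattr(config, "enable_dynamic_vae_sharding", False)


print(flag_with_getattr(ConfigWithFlag()), flag_with_getattr(LegacyConfig()))
print(flag_with_broad_except(BrokenConfig()))  # the ValueError is silently swallowed
```

With `getattr`, the same `BrokenConfig` raises the `ValueError` instead of quietly returning `False`, which makes the misconfiguration visible.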

Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
Comment thread src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py Outdated
Comment thread src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py Outdated
@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch from d681d61 to 8caeb1c Compare May 27, 2026 21:58
@mbohlool mbohlool force-pushed the text_encoder_tpu3 branch from 8caeb1c to 7b28885 Compare May 27, 2026 22:01