[LTX-2] Run Gemma-3 Text Encoder natively in JAX via TorchAX #398
Open
mbohlool wants to merge 1 commit into
Perseus14
reviewed
May 7, 2026
This PR successfully integrates TorchAX for the LTX-2 pipeline's text encoder, bringing significant performance improvements and memory optimizations on TPU. The transition from eager PyTorch to JAX-native execution is well-implemented, and the additional sharding constraints for both the text encoder and VAE are effective strategies for preventing OOM crashes.
🔍 General Feedback
- TorchAX Integration: The use of `TorchaxGemma3TextEncoder` and the manual batch sharding logic is a great addition for efficiency.
- Memory Management: The conditional sharding and slicing disabling in the VAE decoding loop correctly addresses HBM issues for larger batches.
- Distributed Performance: One critical observation is the explicit un-sharding of text encoder hidden states to a single device, which should be avoided to ensure optimal performance in multi-host environments.
- Code Cleanliness: Small refactors to use `getattr` instead of broad `try/except` blocks will improve maintainability.
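The Code Cleanliness point can be illustrated with a minimal sketch. The class `DummyEncoderConfig` and the attribute names below are hypothetical and are not taken from the PR; the sketch only contrasts a broad `try/except` with a `getattr` default.

```python
class DummyEncoderConfig:
    """Hypothetical config object; only `hidden_size` is defined."""

    hidden_size = 1024


def read_with_broad_except(config, name, default):
    # A broad `except Exception` also swallows unrelated failures
    # (for example an error raised inside a property), hiding real bugs.
    try:
        return getattr(config, name)
    except Exception:
        return default


def read_with_getattr(config, name, default):
    # `getattr` with a default only covers the "attribute is missing" case;
    # any other error still propagates to the caller.
    return getattr(config, name, default)


config = DummyEncoderConfig()
present = read_with_getattr(config, "hidden_size", 0)
missing = read_with_getattr(config, "num_layers", 0)
```

Both helpers return the same values for plain attributes; the difference shows up only when attribute access fails for a reason other than the attribute being absent.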
Description
This PR transitions the LTX-2 pipeline's text encoding process to utilize TorchAX, bridging the Gemma-3 model natively into JAX and significantly optimizing memory usage to prevent TPU out-of-memory errors. Minor PyLint warnings across the pipeline and text encoder wrapper were also resolved during the refactor.
Key changes include:
- `TorchaxGemma3TextEncoder` wrapping the HuggingFace `Gemma3ForConditionalGeneration` model. This allows full compiler optimization via JAX tracing.
- For `batch_size <= 2`, VAE decoding uses standard VAE replication and slicing. For `batch_size > 2`, it dynamically disables sequential slicing and skips replication, applying `NamedSharding` constraints on the batch dimension of the latents across the mesh axes. This prevents JAX from trying to concatenate massive arrays on the TPU, avoiding HBM out-of-memory crashes.

Benchmarks
Performance comparison demonstrating latency and throughput improvements, based on robust averages of repeated runs (with the furthest outlier removed for each configuration).
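One way to compute such an average is sketched below. The PR does not state how the outlier is chosen, so this sketch assumes "furthest" means furthest from the median; the function name and the sample timings are illustrative only and are not measurements from this PR.

```python
from statistics import mean, median


def average_without_furthest_outlier(samples):
    """Drop the single sample furthest from the median, then average the rest."""
    if len(samples) < 3:
        raise ValueError("need at least 3 samples to drop one outlier")
    center = median(samples)
    # Index of the sample with the largest absolute distance from the median.
    furthest = max(range(len(samples)), key=lambda i: abs(samples[i] - center))
    kept = [s for i, s in enumerate(samples) if i != furthest]
    return mean(kept)


# Illustrative timings in seconds (made up for this example).
illustrative_runs = [10.0, 10.2, 9.8, 15.0]
trimmed = average_without_furthest_outlier(illustrative_runs)
```

Dropping a single run keeps the estimate from being skewed by a one-off slow run (for example, one that includes compilation), at the cost of discarding a real measurement.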
Note
Crucial VAE Memory Optimization Impact:
For Batch Size 8 (w/ Upsampler), running without the conditional VAE batch sharding constraints (`enable_dynamic_vae_sharding=False`) causes the generation to immediately fail with a TPU HBM Out-of-Memory (OOM) crash during VAE decoding.

By conditionally enabling `enable_dynamic_vae_sharding=True` for larger batches, the pipeline avoids the OOM completely and finishes in 85.61s (a +24.5% net generation speedup).
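The conditional batch sharding can be sketched roughly as follows. This is not the PR's implementation: the mesh axis name `"data"`, the helpers `batch_spec` and `place_latents`, the constant `SMALL_BATCH_LIMIT`, and the stand-in `toy_decode` are all assumptions made for illustration. Only the `enable_dynamic_vae_sharding` flag name and the batch-size threshold of 2 come from the description.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

SMALL_BATCH_LIMIT = 2  # mirrors the batch_size <= 2 / > 2 split in the description


def batch_spec(batch_size, enable_dynamic_vae_sharding=True):
    """Replicate small batches; shard larger ones along the batch dimension."""
    if enable_dynamic_vae_sharding and batch_size > SMALL_BATCH_LIMIT:
        return P("data")  # split dim 0 across the (assumed) "data" mesh axis
    return P()  # fully replicated on every device


def place_latents(latents, mesh, enable_dynamic_vae_sharding=True):
    """Lay the latents out on the mesh according to the chosen spec."""
    spec = batch_spec(latents.shape[0], enable_dynamic_vae_sharding)
    return jax.device_put(latents, NamedSharding(mesh, spec))


@jax.jit
def toy_decode(latents):
    # Stand-in for the VAE decoder; the real decoder is a neural network.
    return jnp.tanh(latents)


# 1-D mesh over all visible devices (a single device on a CPU-only host).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
batch = 4 * len(jax.devices())  # keep the batch divisible by the mesh size
latents = jnp.ones((batch, 8, 16, 16), dtype=jnp.float32)
decoded = toy_decode(place_latents(latents, mesh))
```

With the batch dimension sharded, each device holds only its slice of the latents, so no device needs to materialize the full batch. Inside an already-jitted function, `jax.lax.with_sharding_constraint` would be the analogous way to request the same layout.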