# Source code for genome_entropy.encode3di.multi_gpu

"""Multi-GPU asynchronous encoding for protein to 3Di conversion."""

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterator, List, Optional, Tuple

try:
    import torch
except ImportError:
    torch = None  # type: ignore[assignment]

from ..errors import EncodingError
from ..logging_config import get_logger
from .gpu_utils import (
    discover_available_gpus,
    select_device_for_gpu,
    validate_gpu_availability,
)
from .types import IndexedSeq

# Module-level logger obtained from the package's logging_config helper.
logger = get_logger(__name__)


class MultiGPUEncoder:
    """Manages multi-GPU encoding of amino acid sequences to 3Di tokens.

    One encoder instance (built from ``encoder_class``) and one single-worker
    thread pool are created per GPU. Batches are placed on a shared asyncio
    queue and drained by one worker coroutine per GPU, so a GPU that finishes
    early simply picks up the next batch. Each batch's blocking encode call is
    run on that GPU's thread pool so the event loop stays responsive.

    If no GPUs are requested/available, a single encoder is created with
    ``device=None`` and batches are encoded sequentially.

    Encoding fails fast: once any batch fails, workers stop taking new batches
    and the first failure is raised as :class:`EncodingError`.

    Call :meth:`close` (or use the instance as a context manager) to shut down
    the per-GPU thread pools when the encoder is no longer needed.
    """

    def __init__(
        self,
        model_name: str,
        encoder_class: type,
        gpu_ids: Optional[List[int]] = None,
    ):
        """Initialize multi-GPU encoder.

        Args:
            model_name: HuggingFace model identifier
            encoder_class: Encoder class to instantiate (e.g., ProstT5ThreeDiEncoder).
                It is called as ``encoder_class(model_name=..., device=...)``.
            gpu_ids: List of GPU IDs to use. If None, auto-discover available GPUs.
                If empty (or nothing is discovered/validated), falls back to a
                single encoder created with ``device=None``.
        """
        self.model_name = model_name
        self.encoder_class = encoder_class

        # Discover and validate GPUs
        if gpu_ids is None:
            gpu_ids = discover_available_gpus()
        self.gpu_ids = validate_gpu_availability(gpu_ids) if gpu_ids else []

        # One encoder per GPU, each paired with a dedicated single-thread
        # executor so work submitted for a given GPU runs one batch at a time.
        self.encoders: List[Any] = []
        self.executors: List[ThreadPoolExecutor] = []

        if self.gpu_ids:
            logger.info(
                "Initializing multi-GPU encoding with %d GPU(s): %s",
                len(self.gpu_ids),
                self.gpu_ids,
            )
            try:
                for gpu_id in self.gpu_ids:
                    device = select_device_for_gpu(gpu_id)
                    encoder = encoder_class(model_name=model_name, device=device)
                    self.encoders.append(encoder)
                    self.executors.append(
                        ThreadPoolExecutor(
                            max_workers=1, thread_name_prefix=f"gpu{gpu_id}"
                        )
                    )
                    logger.info("Created encoder for %s", device)
            except Exception:
                # Do not leak the thread pools created before the failure.
                self.close()
                raise
        else:
            # Fallback to single GPU/CPU encoder
            logger.info("No GPUs available, falling back to single device encoder")
            encoder = encoder_class(model_name=model_name, device=None)
            self.encoders.append(encoder)

    def close(self) -> None:
        """Shut down the per-GPU thread pools (waiting for in-flight work).

        Safe to call more than once. The encoder must not be used for
        multi-GPU encoding afterwards.
        """
        for executor in self.executors:
            executor.shutdown(wait=True)

    def __enter__(self) -> "MultiGPUEncoder":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    @property
    def num_gpus(self) -> int:
        """Number of encoders (devices) in use.

        Note: this counts encoders, so in the single-device fallback it is 1
        even though no GPU IDs were selected.
        """
        return len(self.encoders)

    def is_multi_gpu(self) -> bool:
        """Return True if more than one encoder/device is in use."""
        return len(self.encoders) > 1

    async def encode_batch_async(
        self,
        encoder_idx: int,
        batch: List[IndexedSeq],
    ) -> Tuple[List[int], List[str]]:
        """Encode a single batch on a specific GPU asynchronously.

        The encoder's blocking ``_encode_batch`` call runs on that GPU's
        dedicated thread pool, or on the event loop's default executor in the
        single-device fallback (where no dedicated pools exist).

        Args:
            encoder_idx: Index of encoder/GPU to use
            batch: List of IndexedSeq objects to encode (``.seq`` is encoded,
                ``.idx`` is returned so results can be re-ordered)

        Returns:
            Tuple of (original_indices, encoded_3di_sequences)
        """
        encoder = self.encoders[encoder_idx]
        # None selects the loop's default executor instead of raising IndexError
        # when no per-GPU executors were created.
        executor = self.executors[encoder_idx] if self.executors else None
        gpu_id = self.gpu_ids[encoder_idx] if self.gpu_ids else None

        batch_seqs = [x.seq for x in batch]
        batch_idxs = [x.idx for x in batch]

        logger.info(
            "GPU %s: Encoding batch with %d sequences (total len: %d)",
            gpu_id if gpu_id is not None else "default",
            len(batch_seqs),
            sum(len(s) for s in batch_seqs),
        )

        # Run encoding in a thread pool to avoid blocking the event loop.
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            executor, encoder._encode_batch, batch_seqs
        )
        return batch_idxs, results

    async def encode_all_batches_async(
        self,
        batches: List[List[IndexedSeq]],
        total_sequences: int,
    ) -> List[str]:
        """Encode all batches across multiple GPUs asynchronously.

        Args:
            batches: List of batches to encode
            total_sequences: Total number of sequences (size of the output list;
                every ``IndexedSeq.idx`` must be below this)

        Returns:
            List of encoded 3Di sequences in original input order

        Raises:
            EncodingError: If encoding of any batch fails (the first failure is
                chained as the cause)
            RuntimeError: If some sequence did not receive an encoding
        """
        three_di_sequences: List[str] = [None] * total_sequences  # type: ignore[list-item]
        t0 = time.perf_counter()
        total_batches = len(batches)

        logger.info(
            "Starting multi-GPU encoding of %d sequences in %d batches across %d GPU(s)",
            total_sequences,
            total_batches,
            len(self.encoders),
        )

        # Every batch is known up front, so the queue is filled once and the
        # workers drain it; an empty queue means there is no work left.
        batch_queue: asyncio.Queue[Tuple[int, List[IndexedSeq]]] = asyncio.Queue()
        for batch_idx, batch in enumerate(batches):
            batch_queue.put_nowait((batch_idx, batch))

        completed = 0
        errors: List[Exception] = []

        async def gpu_worker(gpu_idx: int) -> None:
            """Pull batches for one GPU until the queue is empty or any worker fails."""
            nonlocal completed
            # Checking the shared error list makes every worker stop taking new
            # batches once any GPU has failed (fail fast).
            while not errors:
                try:
                    batch_idx, batch = batch_queue.get_nowait()
                except asyncio.QueueEmpty:
                    return

                try:
                    batch_idxs, batch_results = await self.encode_batch_async(
                        gpu_idx, batch
                    )
                    # Store results in original order
                    for bi, br in zip(batch_idxs, batch_results):
                        three_di_sequences[bi] = br
                except Exception as e:
                    errors.append(e)
                    logger.error(
                        "GPU %d failed encoding batch %d: %s",
                        gpu_idx,
                        batch_idx,
                        e,
                        exc_info=True,
                    )
                    return

                # No await between read and update, so no lock is needed in
                # single-threaded asyncio code.
                completed += 1
                elapsed = time.perf_counter() - t0
                avg_batch_time = elapsed / completed
                eta_remaining = avg_batch_time * (total_batches - completed)
                logger.info(
                    "Completed batch %d/%d (%.1f%%) - Elapsed: %.1fs, ETA: %.1fs",
                    completed,
                    total_batches,
                    100.0 * completed / total_batches,
                    elapsed,
                    eta_remaining,
                )

        # Create one worker per GPU
        workers = [
            asyncio.create_task(gpu_worker(gpu_idx))
            for gpu_idx in range(len(self.encoders))
        ]

        # return_exceptions=True lets the other workers finish their in-flight
        # batch instead of being cancelled; unexpected worker exceptions are
        # collected here rather than silently dropped.
        outcomes = await asyncio.gather(*workers, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                logger.error(
                    "Multi-GPU encoding failed during worker execution: %s",
                    outcome,
                    exc_info=outcome,
                )
                errors.append(outcome)

        if errors:
            first_error = errors[0]
            raise EncodingError(
                f"Multi-GPU encoding failed: {first_error}"
            ) from first_error

        self._raise_if_missing(three_di_sequences)

        elapsed_total = time.perf_counter() - t0
        logger.info(
            "Multi-GPU encoding complete! Encoded %d sequences in %.1fs (%.2f seqs/sec)",
            total_sequences,
            elapsed_total,
            total_sequences / elapsed_total if elapsed_total > 0 else 0,
        )

        return three_di_sequences

    def encode_multi_gpu(
        self,
        aa_sequences: List[str],
        token_budget_batches_fn: Callable[[List[str], int], Iterator[Any]],
        encoding_size: int,
        skip_model_loading: bool = False,
    ) -> List[str]:
        """Encode sequences using multiple GPUs.

        This is a synchronous wrapper around the async encoding method. It can
        also be called from a thread that is already running an event loop
        (e.g. Jupyter); the coroutine is then run on a helper thread.

        Args:
            aa_sequences: List of preprocessed amino acid sequences
            token_budget_batches_fn: Function to create batches under token budget
            encoding_size: Maximum size (approx. amino acids) per batch
            skip_model_loading: If True, skip model loading (assumes models already
                loaded). This is useful when the encoder is being reused across
                multiple calls.

        Returns:
            List of 3Di token sequences (one per input sequence)
        """
        batches = list(token_budget_batches_fn(aa_sequences, encoding_size))
        total_sequences = len(aa_sequences)

        # Load models for all encoders (unless already loaded)
        if not skip_model_loading:
            logger.info("Loading models on all GPUs...")
            for encoder in self.encoders:
                encoder._load_model()

        if self.is_multi_gpu():
            logger.info("Using multi-GPU parallel encoding")
            return self._run_coroutine_blocking(
                lambda: self.encode_all_batches_async(batches, total_sequences)
            )

        logger.info("Using single-GPU sequential encoding")
        # Fall back to sequential processing for single GPU
        return self._encode_single_gpu_sequential(batches, total_sequences)

    @staticmethod
    def _run_coroutine_blocking(make_coro: Callable[[], Any]) -> Any:
        """Run the coroutine produced by ``make_coro`` to completion and return its result.

        ``asyncio.run`` refuses to start while an event loop is already running
        in the current thread, so in that case the coroutine runs on a fresh
        loop in a helper thread. The coroutine is created lazily so it is never
        left un-awaited.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the common, simple case.
            return asyncio.run(make_coro())
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(lambda: asyncio.run(make_coro())).result()

    @staticmethod
    def _raise_if_missing(three_di_sequences: List[Any]) -> None:
        """Raise RuntimeError if any output slot was never filled (still None)."""
        missing = [i for i, v in enumerate(three_di_sequences) if v is None]
        if missing:
            raise RuntimeError(
                f"Missing encodings for {len(missing)} sequences "
                f"(e.g., indices {missing[:10]})"
            )

    def _encode_single_gpu_sequential(
        self,
        batches: List[List[IndexedSeq]],
        total_sequences: int,
    ) -> List[str]:
        """Fallback to sequential single-GPU encoding.

        Args:
            batches: List of batches to encode
            total_sequences: Total number of sequences

        Returns:
            List of encoded 3Di sequences in original input order

        Raises:
            RuntimeError: If some sequence did not receive an encoding
        """
        three_di_sequences: List[str] = [None] * total_sequences  # type: ignore[list-item]
        encoder = self.encoders[0]

        logger.info(
            "Encoding %d sequences in %d batches (single GPU)",
            total_sequences,
            len(batches),
        )

        t0 = time.perf_counter()
        for batch_idx, batch in enumerate(batches, 1):
            batch_seqs = [x.seq for x in batch]
            batch_idxs = [x.idx for x in batch]

            results = encoder._encode_batch(batch_seqs)
            for bi, br in zip(batch_idxs, results):
                three_di_sequences[bi] = br

            elapsed = time.perf_counter() - t0
            avg_time = elapsed / batch_idx
            eta = avg_time * (len(batches) - batch_idx)
            logger.info(
                "Batch %d/%d complete - Elapsed: %.1fs, ETA: %.1fs",
                batch_idx,
                len(batches),
                elapsed,
                eta,
            )

        self._raise_if_missing(three_di_sequences)
        return three_di_sequences