Source code for genome_entropy.encode3di.token_estimator

"""Token size estimation for optimal GPU memory usage in 3Di encoding."""

import random
from typing import Any, Dict, List, Optional

try:
    import torch
except ImportError:
    torch = None  # type: ignore[assignment]

from ..config import AA_ALPHABET
from ..logging_config import get_logger

logger = get_logger(__name__)


def generate_random_protein(length: int, seed: Optional[int] = None) -> str:
    """Build a random amino-acid string with ``length`` residues.

    Args:
        length: Number of residues to draw.
        seed: When not ``None``, the module-level ``random`` generator is
            reseeded with this value before any residue is drawn, which
            makes the result reproducible.

    Returns:
        A string whose characters are drawn independently, with
        replacement, from ``AA_ALPHABET``.
    """
    # Seeding is optional; an unseeded call continues the current RNG stream.
    if seed is not None:
        random.seed(seed)

    residues = list(AA_ALPHABET)
    drawn = random.choices(residues, k=length)
    return "".join(drawn)
def generate_combined_proteins(
    target_length: int, base_length: int = 100, seed: Optional[int] = None
) -> List[str]:
    """Generate several shorter proteins whose lengths add up to a target.

    Each protein length is drawn uniformly from ``base_length`` +/- 20%
    and then capped at the number of residues still missing, so for a
    positive ``target_length`` the returned lengths sum to exactly
    ``target_length``.

    Args:
        target_length: Total number of residues across all proteins. A value
            of zero or less yields an empty list.
        base_length: Approximate length of each individual protein. Must be
            at least 1.
        seed: When not ``None``, the module-level ``random`` generator is
            reseeded with this value first, making the output reproducible.

    Returns:
        List of random protein sequences.

    Raises:
        ValueError: If ``base_length`` is smaller than 1.
    """
    # A non-positive base length would produce zero/negative draws, so the
    # remaining budget would never shrink and the loop would never finish.
    if base_length < 1:
        raise ValueError(
            f"base_length must be a positive integer, got {base_length}"
        )

    if seed is not None:
        random.seed(seed)

    # The +/- 20% window does not change between iterations, compute it once.
    variation = int(base_length * 0.2)
    shortest = base_length - variation
    longest = base_length + variation

    proteins: List[str] = []
    remaining = target_length
    while remaining > 0:
        # Vary the length slightly for realism, but never exceed the target.
        length = min(random.randint(shortest, longest), remaining)
        proteins.append(generate_random_protein(length))
        remaining -= length

    return proteins
def _is_out_of_memory_error(exc: BaseException) -> bool:
    """Return True when ``exc`` reports that the device ran out of memory."""
    # ``torch.cuda.OutOfMemoryError`` is looked up defensively because not
    # every PyTorch build is guaranteed to expose it.
    oom_type = getattr(getattr(torch, "cuda", None), "OutOfMemoryError", None)
    if oom_type is not None and isinstance(exc, oom_type):
        return True
    # Fallback: out-of-memory conditions may also surface as a plain
    # RuntimeError whose message mentions "out of memory".
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()


def estimate_token_size(
    encoder: Any,
    start_length: int = 3000,
    end_length: int = 10000,
    step: int = 1000,
    num_trials: int = 3,
    base_protein_length: int = 100,
) -> Dict[str, Any]:
    """Estimate optimal token size for GPU encoding by testing increasing lengths.

    Random protein sets of increasing total length are generated, split
    into batches with ``encoder.token_budget_batches`` and encoded with
    ``encoder._encode_batch``. A length counts as successful when at least
    one of its trials encodes every batch. The sweep stops at the first
    length for which no trial succeeds. An out-of-memory error aborts the
    remaining trials of the current length; any other exception is logged
    and the next trial is attempted.

    Args:
        encoder: ProstT5ThreeDiEncoder-like object. It must provide
            ``encode``, ``device``, ``_load_model``,
            ``token_budget_batches`` and ``_encode_batch``.
        start_length: Starting total length to test (default: 3000)
        end_length: Maximum total length to test (default: 10000)
        step: Increment between test lengths, at least 1 (default: 1000)
        num_trials: Number of trials per length, at least 1 (default: 3)
        base_protein_length: Approximate length of individual proteins,
            at least 1 (default: 100)

    Returns:
        Dictionary with estimation results:
            - 'max_length': Largest tested length with a successful trial,
              or 0 when none succeeded
            - 'recommended_token_size': 90% of 'max_length'; falls back to
              ``start_length`` when no length succeeded
            - 'trials_per_length': Successful trial count per tested length
            - 'device': ``encoder.device``

    Raises:
        ValueError: If torch is not available, the encoder lacks a required
            attribute, or ``step``, ``num_trials`` or
            ``base_protein_length`` is smaller than 1.
    """
    if torch is None:
        raise ValueError("PyTorch is required for token size estimation")

    if not hasattr(encoder, "encode") or not hasattr(encoder, "device"):
        raise ValueError("encoder must be a ProstT5ThreeDiEncoder instance")

    if not hasattr(encoder, "_load_model"):
        raise ValueError("encoder must have _load_model method")

    # These two are what the trials actually call. Without this check a
    # wrong encoder would only show up as logged "unexpected errors" and a
    # misleading max_length of 0.
    for required in ("token_budget_batches", "_encode_batch"):
        if not hasattr(encoder, required):
            raise ValueError(f"encoder must have {required} method")

    # Reject sweep parameters that would produce an empty sweep (and a
    # recommendation nobody tested) or a non-terminating protein generator.
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step}")
    if num_trials < 1:
        raise ValueError(f"num_trials must be a positive integer, got {num_trials}")
    if base_protein_length < 1:
        raise ValueError(
            f"base_protein_length must be a positive integer, got {base_protein_length}"
        )

    logger.info("Starting token size estimation on device: %s", encoder.device)
    logger.info("Testing range: %d to %d (step: %d)", start_length, end_length, step)

    # Load the model before starting estimation
    logger.info("Loading model...")
    encoder._load_model()

    max_successful_length = 0
    trials_per_length: Dict[int, int] = {}

    for total_length in range(start_length, end_length + 1, step):
        logger.info("Testing total length: %d amino acids", total_length)
        successful_trials = 0

        for trial in range(num_trials):
            try:
                # Generate proteins that combine to target length
                proteins = generate_combined_proteins(
                    total_length,
                    base_length=base_protein_length,
                    seed=trial,  # Different seed per trial
                )

                logger.info(
                    " Trial %d/%d: encoding %d proteins (total %d AA)",
                    trial + 1,
                    num_trials,
                    len(proteins),
                    sum(len(p) for p in proteins),
                )

                # Use the encoder's own batching so the test is realistic.
                # Materialised as a list because the batch count is logged.
                batches = list(encoder.token_budget_batches(proteins, total_length))

                for batch_number, batch in enumerate(batches, start=1):
                    encoder._encode_batch([item.seq for item in batch])
                    logger.info(
                        " Batch %d/%d encoded successfully",
                        batch_number,
                        len(batches),
                    )
            except Exception as e:
                if _is_out_of_memory_error(e):
                    logger.warning(
                        " Trial %d/%d: Out of memory at length %d: %s",
                        trial + 1,
                        num_trials,
                        total_length,
                        str(e),
                    )
                    # Clear cache and stop trying this length on OOM
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    break

                logger.error(
                    " Trial %d/%d: Unexpected error at length %d: %s",
                    trial + 1,
                    num_trials,
                    total_length,
                    str(e),
                )
                # Don't break on other errors, might be transient
                continue
            else:
                successful_trials += 1
                logger.info(" Trial %d/%d: SUCCESS", trial + 1, num_trials)

        trials_per_length[total_length] = successful_trials

        # If no trials succeeded, we've hit the limit
        if successful_trials == 0:
            logger.info(
                "No successful trials at length %d, stopping estimation", total_length
            )
            break

        # At least one trial succeeded at this length
        max_successful_length = total_length

    # Calculate recommended token size (90% of max for safety margin)
    recommended = (
        int(max_successful_length * 0.9) if max_successful_length > 0 else start_length
    )

    results: Dict[str, Any] = {
        "max_length": max_successful_length,
        "recommended_token_size": recommended,
        "trials_per_length": trials_per_length,
        "device": encoder.device,
    }

    logger.info("=" * 60)
    logger.info("Token Size Estimation Complete")
    logger.info("=" * 60)
    logger.info("Device: %s", results["device"])
    logger.info("Max successful length: %d amino acids", results["max_length"])
    logger.info(
        "Recommended token size: %d amino acids", results["recommended_token_size"]
    )
    logger.info("Trials per length: %s", results["trials_per_length"])
    logger.info("=" * 60)

    return results