Source code for genome_entropy.pipeline.runner

"""End-to-end pipeline orchestration for DNA to 3Di with entropy calculation."""

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

from ..config import (
    DEFAULT_GENETIC_CODE_TABLE,
    DEFAULT_MIN_AA_LENGTH,
    DEFAULT_PROSTT5_MODEL,
    DEFAULT_ENCODING_SIZE,
    MODERNPROST_MODELS,
)
from ..encode3di.prostt5 import ProstT5ThreeDiEncoder, ThreeDiRecord
from ..encode3di.modernprost import ModernProstThreeDiEncoder
from ..entropy.shannon import (
    EntropyReport,
    calculate_sequence_entropy,
    calculate_entropies_for_sequences,
)
from ..errors import PipelineError
from ..io.fasta import read_fasta
from ..io.genbank import read_genbank, extract_cds_features, match_orf_to_genbank_cds
from ..io.jsonio import write_json
from ..logging_config import get_logger
from ..orf.finder import find_orfs
from ..orf.types import OrfRecord
from ..translate.translator import ProteinRecord, translate_orfs

logger = get_logger(__name__)


@dataclass
class PipelineResult:
    """Outcome of the DNA to 3Di pipeline for a single input sequence.

    Attributes:
        input_id: Identifier of the DNA sequence that was processed.
        input_dna_length: Length of that DNA sequence.
        orfs: ORFs found in the sequence.
        proteins: Proteins translated from the ORFs.
        three_dis: 3Di encoded structures produced for the proteins.
        entropy: Entropy report covering all representations.
    """

    input_id: str
    input_dna_length: int
    orfs: List[OrfRecord]
    proteins: List[ProteinRecord]
    three_dis: List[ThreeDiRecord]
    entropy: EntropyReport
def _empty_entropy_report(dna_entropy_global: float = 0.0) -> EntropyReport:
    """Build an EntropyReport with no per-ORF entries and the given DNA entropy."""
    return EntropyReport(
        dna_entropy_global=dna_entropy_global,
        orf_nt_entropy={},
        protein_aa_entropy={},
        three_di_entropy={},
        alphabet_sizes={},
    )


def run_pipeline(
    input_fasta: Optional[Union[str, Path]] = None,
    table_id: int = DEFAULT_GENETIC_CODE_TABLE,
    min_aa_len: int = DEFAULT_MIN_AA_LENGTH,
    model_name: str = DEFAULT_PROSTT5_MODEL,
    compute_entropy: bool = True,
    output_json: Optional[Union[str, Path]] = None,
    device: Optional[str] = None,
    use_multi_gpu: bool = False,
    gpu_ids: Optional[List[int]] = None,
    genbank_file: Optional[Union[str, Path]] = None,
    encoding_size: Optional[int] = None,
) -> List[PipelineResult]:
    """Run the complete DNA to 3Di pipeline with entropy calculation.

    Pipeline steps, in execution order:
        1. Read the FASTA file and/or GenBank file.
        2. Find ORFs in each sequence (minimum length ``min_aa_len * 3`` nt)
           and, when GenBank CDS features were extracted, flag each ORF via
           ``match_orf_to_genbank_cds``.
        3. Translate ORFs to proteins.
        4. Encode proteins to 3Di structural tokens.
        5. Calculate entropy at all levels (or emit an empty report).
        6. Optionally write all results to JSON.

    Args:
        input_fasta: Path to input FASTA file. Optional if genbank_file is
            provided.
        table_id: NCBI genetic code table ID.
        min_aa_len: Minimum protein length in amino acids.
        model_name: Model name. Names listed in ``MODERNPROST_MODELS`` select
            the ModernProst encoder; anything else selects the ProstT5 encoder.
        compute_entropy: Whether to compute entropy values. When False, every
            result carries an empty report with a global DNA entropy of 0.0.
        output_json: Optional path to save results as JSON.
        device: Device for 3Di encoding ("cuda", "mps", "cpu", or None for
            auto). Documented as ignored when use_multi_gpu is True.
        use_multi_gpu: If True, build one ``MultiGPUEncoder`` up front and
            reuse it for every sequence.
        gpu_ids: Optional list of GPU IDs for multi-GPU encoding. If None and
            use_multi_gpu=True, GPUs are auto-discovered.
        genbank_file: Optional path to GenBank file. If provided alone, DNA
            sequences are read from it. If combined with input_fasta, the
            FASTA supplies the sequences and the GenBank file supplies only
            the CDS annotations.
        encoding_size: Maximum size (approx. amino acids) to encode per batch.
            If None, uses DEFAULT_ENCODING_SIZE from config.

    Returns:
        List of PipelineResult objects (one per input sequence).

    Raises:
        PipelineError: If any pipeline step fails; the original exception is
            attached as ``__cause__``.
        ValueError: If neither input_fasta nor genbank_file is provided.
    """
    # Validate that at least one input source is provided
    if not input_fasta and not genbank_file:
        raise ValueError("Must provide either input_fasta or genbank_file")

    logger.info("=" * 60)
    logger.info("Starting DNA to 3Di pipeline")
    logger.info("=" * 60)
    if input_fasta:
        logger.info("Input FASTA: %s", input_fasta)
    if genbank_file:
        logger.info("GenBank file: %s", genbank_file)
    logger.info("Genetic code table: %d", table_id)
    logger.info("Minimum AA length: %d", min_aa_len)
    logger.info("Model: %s", model_name)
    logger.info("Compute entropy: %s", compute_entropy)
    if use_multi_gpu:
        logger.info("Multi-GPU encoding: enabled")
        if gpu_ids:
            logger.info("GPU IDs: %s", gpu_ids)
        else:
            logger.info("GPU IDs: auto-discover")
    else:
        logger.info("Device: %s", device if device else "auto")

    try:
        # Step 1: Read FASTA or GenBank file
        logger.info("Step 1: Reading input file...")

        genbank_cds_list = []
        if genbank_file:
            logger.info("Reading GenBank file...")
            # The GenBank sequences are only needed when no FASTA overrides
            # them, so skip parsing them when a FASTA file is also given.
            if not input_fasta:
                sequences = read_genbank(genbank_file)
            genbank_cds_list = extract_cds_features(genbank_file)
            logger.info(
                "Extracted %d CDS features from GenBank", len(genbank_cds_list)
            )

        if input_fasta:
            if genbank_file:
                logger.info("Also reading FASTA file for DNA sequences...")
            sequences = read_fasta(input_fasta)

        logger.info("Read %d sequence(s)", len(sequences))

        # Initialize encoder once before processing all sequences so the
        # same instance is reused for every entry.
        logger.info("Initializing 3Di encoder (model will be loaded on first use)...")

        # Select encoder based on model name
        if model_name in MODERNPROST_MODELS:
            encoder_class = ModernProstThreeDiEncoder
        else:
            encoder_class = ProstT5ThreeDiEncoder
        encoder = encoder_class(model_name=model_name, device=device)

        actual_encoding_size = (
            encoding_size if encoding_size is not None else DEFAULT_ENCODING_SIZE
        )

        # Initialize the multi-GPU encoder once so it is not re-created for
        # each sequence.
        multi_gpu_encoder = None
        if use_multi_gpu:
            from ..encode3di.multi_gpu import MultiGPUEncoder

            logger.info(
                "Initializing multi-GPU encoder (will be reused for all sequences)..."
            )
            multi_gpu_encoder = MultiGPUEncoder(
                model_name=model_name,
                encoder_class=encoder_class,
                gpu_ids=gpu_ids,
            )
            # Load models on all GPUs once at initialization
            logger.info("Loading models on all GPUs...")
            for gpu_encoder in multi_gpu_encoder.encoders:
                gpu_encoder._load_model()

        # Loop-invariant: minimum ORF length in nucleotides.
        min_nt_len = min_aa_len * 3

        results = []
        for seq_idx, (seq_id, dna_sequence) in enumerate(sequences.items(), 1):
            logger.info("-" * 60)
            logger.info(
                "Processing sequence %d/%d: %s (length=%d)",
                seq_idx,
                len(sequences),
                seq_id,
                len(dna_sequence),
            )

            # Step 2: Find ORFs
            logger.info("Step 2: Finding ORFs...")
            orfs = find_orfs(
                {seq_id: dna_sequence},
                table_id=table_id,
                min_nt_length=min_nt_len,
            )

            if not orfs:
                logger.warning("No ORFs found for sequence '%s'", seq_id)
                # No ORFs found: still report the sequence, with only the
                # global DNA entropy filled in (when requested).
                empty_entropy = _empty_entropy_report(
                    calculate_sequence_entropy(dna_sequence)
                    if compute_entropy
                    else 0.0
                )
                results.append(
                    PipelineResult(
                        input_id=seq_id,
                        input_dna_length=len(dna_sequence),
                        orfs=[],
                        proteins=[],
                        three_dis=[],
                        entropy=empty_entropy,
                    )
                )
                continue

            logger.info("Found %d ORF(s)", len(orfs))

            # Match ORFs to GenBank CDS if GenBank file was provided
            if genbank_file and genbank_cds_list:
                logger.info("Matching ORFs to GenBank CDS annotations...")
                for orf in orfs:
                    orf.in_genbank = match_orf_to_genbank_cds(
                        orf.aa_sequence, genbank_cds_list
                    )
                matched_count = sum(1 for orf in orfs if orf.in_genbank)
                logger.info(
                    "Matched %d/%d ORFs to GenBank CDS", matched_count, len(orfs)
                )

            # Step 3: Translate ORFs
            logger.info("Step 3: Translating ORFs to proteins...")
            proteins = translate_orfs(orfs, table_id=table_id)
            logger.info("Translated %d protein(s)", len(proteins))

            # Step 4: Encode to 3Di (reusing the same encoder instance)
            logger.info("Step 4: Encoding proteins to 3Di tokens...")
            three_dis = encoder.encode_proteins(
                proteins,
                encoding_size=actual_encoding_size,
                use_multi_gpu=use_multi_gpu,
                gpu_ids=gpu_ids,
                multi_gpu_encoder=multi_gpu_encoder,
            )
            logger.info("Encoded %d 3Di sequence(s)", len(three_dis))

            # Step 5: Calculate entropy
            if compute_entropy:
                logger.info("Step 5: Calculating entropy at all levels...")
                entropy_report = calculate_pipeline_entropy(
                    dna_sequence, orfs, proteins, three_dis
                )
                logger.info("Calculated entropy values")
            else:
                logger.info("Step 5: Skipping entropy calculation")
                entropy_report = _empty_entropy_report()

            # Create result
            results.append(
                PipelineResult(
                    input_id=seq_id,
                    input_dna_length=len(dna_sequence),
                    orfs=orfs,
                    proteins=proteins,
                    three_dis=three_dis,
                    entropy=entropy_report,
                )
            )
            logger.info("Completed processing for sequence '%s'", seq_id)

        # Step 6: Write output if requested
        if output_json:
            logger.info("Step 6: Writing results to JSON...")
            write_json(results, output_json)

        logger.info("=" * 60)
        logger.info("Pipeline complete! Processed %d sequence(s)", len(results))
        logger.info("=" * 60)

        return results

    except Exception as e:
        logger.error("Pipeline execution failed: %s", e, exc_info=True)
        # Chain explicitly so callers can inspect the root cause via __cause__.
        raise PipelineError(f"Pipeline execution failed: {e}") from e
def calculate_pipeline_entropy(
    dna_sequence: str,
    orfs: List[OrfRecord],
    proteins: List[ProteinRecord],
    three_dis: List[ThreeDiRecord],
) -> EntropyReport:
    """Compute entropy for the DNA, ORF, protein and 3Di representations.

    The per-ORF, per-protein and per-3Di values are all keyed by ORF id.

    Args:
        dna_sequence: Original DNA sequence.
        orfs: ORF records; ``orf_id`` and ``nt_sequence`` are read.
        proteins: Protein records; ``orf.orf_id`` and ``aa_sequence`` are read.
        three_dis: 3Di records; ``protein.orf.orf_id`` and ``three_di`` are read.

    Returns:
        EntropyReport holding the global DNA entropy, the three per-sequence
        entropy mappings and the fixed alphabet sizes.
    """
    logger.debug("Calculating entropy at all representation levels")

    # Whole-sequence DNA entropy
    global_dna_entropy = calculate_sequence_entropy(dna_sequence)
    logger.debug("DNA entropy: %.4f bits", global_dna_entropy)

    # Nucleotide entropy of each ORF
    nt_by_orf = {}
    for orf_record in orfs:
        nt_by_orf[orf_record.orf_id] = orf_record.nt_sequence
    nt_entropy_by_orf = calculate_entropies_for_sequences(nt_by_orf)

    # Amino-acid entropy of each translated protein
    aa_by_orf = {}
    for protein_record in proteins:
        aa_by_orf[protein_record.orf.orf_id] = protein_record.aa_sequence
    aa_entropy_by_orf = calculate_entropies_for_sequences(aa_by_orf)

    # 3Di entropy of each encoded structure
    tokens_by_orf = {}
    for encoded in three_dis:
        tokens_by_orf[encoded.protein.orf.orf_id] = encoded.three_di
    token_entropy_by_orf = calculate_entropies_for_sequences(tokens_by_orf)

    logger.debug("Entropy calculation complete")

    return EntropyReport(
        dna_entropy_global=global_dna_entropy,
        orf_nt_entropy=nt_entropy_by_orf,
        protein_aa_entropy=aa_entropy_by_orf,
        three_di_entropy=token_entropy_by_orf,
        alphabet_sizes={"dna": 4, "protein": 20, "three_di": 20},
    )