"""ModernProst encoder for amino acid to 3Di structural token conversion.
This module implements an encoder for gbouras13/modernprost models,
adapted from the phold implementation.
Note: ModernProst models require transformers >= 4.47.0 for ModernBert support.
Multi-GPU support uses HuggingFace accelerate library.
"""
import time
from typing import Any, Iterator, List, Optional, Sequence
# torch and transformers are hard requirements for encoding. If either is
# missing, all four names below are set to None so the module still imports
# and the encoder constructor can raise a descriptive ModelError instead.
try:
    import torch
    import torch.nn.functional as F
    import transformers
    from transformers import AutoModel, AutoTokenizer
except ImportError:
    torch = None  # type: ignore[assignment]
    F = None  # type: ignore[assignment,misc]
    AutoModel = None  # type: ignore[assignment,misc]
    AutoTokenizer = None  # type: ignore[assignment,misc]
else:
    # Check transformers version for ModernBert support
    transformers_version = tuple(map(int, transformers.__version__.split(".")[:2]))
    if transformers_version < (4, 47):
        import warnings

        warnings.warn(
            f"ModernProst models require transformers >= 4.47.0 (current: {transformers.__version__}). "
            "Please upgrade: pip install --upgrade 'transformers>=4.47.0'",
            UserWarning,
        )

# accelerate is optional (only needed for the multi-GPU path), so it gets its
# own guard: a missing accelerate must not disable torch/transformers above.
try:
    from accelerate import Accelerator
except ImportError:
    Accelerator = None  # type: ignore[assignment,misc]
import numpy as np
from pathlib import Path
from ..config import (
AUTO_DEVICE,
CPU_DEVICE,
CUDA_DEVICE,
DEFAULT_ENCODING_SIZE,
MPS_DEVICE,
MODERNPROST_PROFILES_MODEL,
)
from ..errors import DeviceError, ModelError
from ..logging_config import get_logger
from ..translate.translator import ProteinRecord
from .encoding import encode
from .types import IndexedSeq, ThreeDiRecord
logger = get_logger(__name__)
def _is_model_cached(model_name: str) -> bool:
    """Check if a model is already cached locally.

    The check is best-effort: it only probes the HuggingFace cache for the
    model's ``config.json`` and treats any failure to probe as "not cached".

    Args:
        model_name: HuggingFace model identifier

    Returns:
        True if ``config.json`` for the model is found in the local cache,
        False otherwise (including when the cache cannot be inspected).
    """
    try:
        # Only huggingface_hub is needed for the probe. Importing unrelated
        # names here would turn their ImportError into a false "not cached".
        from huggingface_hub import try_to_load_from_cache

        # Try to find the model in the cache
        cache_path = try_to_load_from_cache(
            repo_id=model_name,
            filename="config.json",
        )
    except Exception as e:
        # If we can't check the cache, assume it's not cached
        logger.debug(f"Could not check cache for {model_name}: {e}")
        return False

    # A real path means the file is cached; None or a sentinel object means
    # it is not.
    if isinstance(cache_path, (str, Path)):
        logger.debug(f"Model {model_name} found in cache at {cache_path}")
        return True

    logger.debug(f"Model {model_name} not found in cache")
    return False
[docs]
class ModernProstThreeDiEncoder:
"""Encoder for converting amino acid sequences to 3Di structural tokens.
Uses ModernProst models (gbouras13/modernprost-base or modernprost-profiles)
from HuggingFace to predict 3Di tokens directly from protein sequences.
Based on implementation from phold:
https://github.com/gbouras13/phold/blob/main/src/phold/features/predict_3Di.py
"""
[docs]
def __init__(
    self,
    model_name: str,
    device: Optional[str] = None,
    use_accelerate: bool = False,
):
    """Set up the encoder without loading any weights.

    The model and tokenizer are left as ``None`` here and are loaded on
    demand by ``_load_model``.

    Args:
        model_name: HuggingFace model identifier (gbouras13/modernprost-base or modernprost-profiles)
        device: Device to use ("cuda", "mps", "cpu", or None for auto-detect).
            Not consulted when ``use_accelerate`` is True; the accelerator's
            device is used instead.
        use_accelerate: If True, use HuggingFace accelerate for multi-GPU support

    Raises:
        ModelError: If PyTorch/Transformers are not installed, or if
            ``use_accelerate`` is True and accelerate is not installed
        DeviceError: If specified device is not available
    """
    core_stack_missing = torch is None or AutoModel is None
    if core_stack_missing:
        raise ModelError(
            "PyTorch and Transformers are required for 3Di encoding. "
            "Install with: pip install torch transformers"
        )

    self.model_name = model_name
    self.use_accelerate = use_accelerate
    self.accelerator: Any = None

    if not use_accelerate:
        # Plain single-device mode: honour the caller's hint or auto-detect.
        self.device = self._select_device(device)
    else:
        if Accelerator is None:
            raise ModelError(
                "accelerate library is required for multi-GPU support. "
                "Install with: pip install accelerate"
            )
        # The accelerator decides which device this process uses.
        self.accelerator = Accelerator()
        self.device = str(self.accelerator.device)
        logger.info(f"Using accelerate with device: {self.device}")

    # Populated lazily by _load_model().
    self.model: Any = None
    self.tokenizer: Any = None
def _select_device(self, device_hint: Optional[str]) -> str:
    """Resolve the device string to run inference on.

    Args:
        device_hint: User-specified device, or None / the auto sentinel to
            auto-detect

    Returns:
        The hint itself when one was given and passed validation; otherwise
        the CUDA, MPS or CPU device constant, in that order of preference.

    Raises:
        DeviceError: If the CUDA or MPS device was explicitly requested but
            is not available
    """

    def mps_available() -> bool:
        # Older torch builds have no torch.backends.mps attribute at all.
        return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    wants_auto = not device_hint or device_hint == AUTO_DEVICE

    if wants_auto:
        # Auto-detect best device
        if torch.cuda.is_available():
            return CUDA_DEVICE
        return MPS_DEVICE if mps_available() else CPU_DEVICE

    # Explicit request: only the exact CUDA/MPS constants are validated; any
    # other hint is returned unchanged.
    if device_hint == CUDA_DEVICE and not torch.cuda.is_available():
        raise DeviceError("CUDA requested but not available")
    if device_hint == MPS_DEVICE and not mps_available():
        raise DeviceError("MPS requested but not available")
    return device_hint
def _load_model(self) -> None:
    """Load the tokenizer and model once; later calls are no-ops.

    Steps, in order: probe the local cache, optionally disable
    torch.compile/dynamo (non-accelerate mode only), load tokenizer and
    config, force ``reference_compile=False`` when the config has that
    field, load the weights, unwrap compiled submodules, switch to half
    precision on CUDA, and put the model in eval mode.

    Raises:
        ModelError: If any step of loading fails (the original exception is
            chained as the cause)
    """
    if self.model is not None:
        return  # Already loaded

    try:
        logger.info("Loading ModernProst model: %s", self.model_name)

        # The cache probe decides between offline loading and a download.
        is_cached = _is_model_cached(self.model_name)
        logger.info(
            "Model found in cache, using local files only"
            if is_cached
            else "Model not in cache, will download from HuggingFace"
        )

        if not self.use_accelerate:
            # Disable torch.compile/dynamo globally for multi-GPU compatibility
            # ModernBert uses compiled_mlp which conflicts with multi-threading
            import torch._dynamo

            torch._dynamo.config.suppress_errors = True
            torch._dynamo.reset()

            # Disable compilation by setting environment
            import os

            os.environ["PYTORCH_JIT"] = "0"
            os.environ["TORCH_COMPILE_DISABLE"] = "1"
            logger.info("Disabled torch.compile/dynamo for multi-GPU compatibility")

        # Options shared by every from_pretrained call below.
        hub_kwargs = {
            "trust_remote_code": True,
            "local_files_only": is_cached,
            "force_download": not is_cached,
        }

        # Load tokenizer with trust_remote_code for custom models
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, **hub_kwargs)

        # Load model config first
        import transformers

        config = transformers.AutoConfig.from_pretrained(self.model_name, **hub_kwargs)

        # Disable torch.compile if supported (ModernBert has this option)
        if hasattr(config, "reference_compile"):
            config.reference_compile = False
            logger.debug("Set reference_compile=False in model config")

        # Load model with modified config
        loaded = AutoModel.from_pretrained(
            self.model_name,
            config=config,
            **hub_kwargs,
        )
        if self.use_accelerate:
            # Use accelerate for multi-GPU - don't move to device yet
            self.model = loaded
            # Prepare model with accelerate for distributed inference
            self.model = self.accelerator.prepare(self.model)
            logger.info("Prepared model with accelerate for multi-GPU")
        else:
            # Standard single-GPU/CPU loading
            self.model = loaded.to(self.device)

        # Additional safety: Remove torch.compile from model if it was applied
        self._disable_torch_compile_in_model()

        # ModernProst models use half precision only on CUDA
        # CPU and MPS may not support half precision properly
        on_cuda = self.device.startswith("cuda") or self.device == CUDA_DEVICE
        if on_cuda:
            self.model = self.model.half()

        self.model = self.model.eval()

        logger.info("Loaded model %s on device %s", self.model_name, self.device)
        logger.debug("Model config:\n%s", self.model.config)
    except Exception as e:
        logger.error("Failed to load ModernProst model %s: %s", self.model_name, e)
        raise ModelError(
            f"Failed to load ModernProst model {self.model_name}: {e}"
        ) from e
def _disable_torch_compile_in_model(self) -> None:
    """Replace torch.compile wrappers in ``self.model`` with the wrapped modules.

    A module is treated as compiled when it exposes an ``_orig_mod``
    attribute. Each such module is swapped for its ``_orig_mod``: the root
    by rebinding ``self.model``, every other module by re-assigning the
    attribute on its parent.

    This is best-effort. Any failure is logged at debug level and
    swallowed, because disabling compilation through the model config may
    already have been sufficient.
    """
    try:
        # Snapshot name -> module once, before mutating anything, so parent
        # lookups keep working even after the root has been swapped out.
        modules_by_name = dict(self.model.named_modules())

        removed = 0
        for name, module in modules_by_name.items():
            if not hasattr(module, "_orig_mod"):
                continue

            logger.debug(f"Removing torch.compile from module: {name}")
            if not name:
                # The empty name is the root module itself.
                self.model = module._orig_mod
            else:
                # A direct child of the root ("encoder") has parent name "",
                # which the snapshot maps to the root; it must be re-attached
                # there, not treated as a replacement for the whole model.
                parent_name, _, attr_name = name.rpartition(".")
                setattr(modules_by_name[parent_name], attr_name, module._orig_mod)
            removed += 1

        # Only claim a removal when one actually happened.
        if removed:
            logger.info(
                "Removed torch.compile optimizations for multi-GPU compatibility"
            )
    except Exception as e:
        # If removal fails, just log a warning - the config approach might have worked
        logger.debug(f"Could not remove torch.compile (this may be okay): {e}")
[docs]
def token_budget_batches(
    self,
    aa_sequences: Sequence[str],
    token_budget: int,
) -> Iterator[List[IndexedSeq]]:
    """
    Lazily group sequences into batches whose padded size stays within a budget.

    The padded size of a batch is estimated as
    ``number_of_sequences * length_of_longest_sequence``. Sequences are
    sorted by length and consumed from both ends:

    1) Long sequences are taken from the long end while the estimate still
       fits the budget.
    2) The remaining room is filled with short sequences from the short end.

    A sequence longer than the whole budget is yielded as a batch of its
    own. Every sequence appears in exactly one batch, paired with its
    original position so callers can restore input order.

    Parameters
    ----------
    aa_sequences : Sequence[str] Unordered amino acid sequences.
    token_budget : int Maximum approximate "tokens" per batch

    Yields
    ------
    List[IndexedSeq] A batch of (original_index, sequence) records.

    Raises
    ------
    ValueError If ``token_budget`` is not positive (raised when iteration
        starts, since this is a generator).
    """
    if token_budget <= 0:
        raise ValueError("token_budget must be > 0")

    # Stable sort by length; original positions travel with each record.
    by_length: List[IndexedSeq] = sorted(
        (IndexedSeq(pos, seq) for pos, seq in enumerate(aa_sequences)),
        key=lambda rec: len(rec.seq),
    )

    lo = 0  # shortest sequence not yet batched
    hi = len(by_length) - 1  # longest sequence not yet batched

    while lo <= hi:
        batch: List[IndexedSeq] = []
        widest = 0

        # Phase 1: take long sequences from the long end.
        while hi >= lo:
            candidate = by_length[hi]
            length = len(candidate.seq)

            if length > token_budget:
                # Over budget on its own: flush what we have, emit it alone.
                if batch:
                    yield batch
                    batch = []
                    widest = 0
                yield [candidate]
                hi -= 1
                continue

            padded_width = max(widest, length)
            if (len(batch) + 1) * padded_width > token_budget:
                # No room for another long one; switch to short fillers.
                break

            batch.append(candidate)
            widest = padded_width
            hi -= 1

        # Phase 2: top up with short sequences from the short end.
        while lo <= hi:
            candidate = by_length[lo]
            padded_width = max(widest, len(candidate.seq))
            if (len(batch) + 1) * padded_width > token_budget:
                break

            batch.append(candidate)
            widest = padded_width
            lo += 1

        if batch:
            yield batch
def _encode_batch(self, aa_sequences: List[str]) -> List[str]:
    """
    Run one batch through the model and decode per-residue 3Di letters.

    Requires ``self.tokenizer`` and ``self.model`` to be loaded already.

    Args:
        aa_sequences: List of amino acid sequences (upper-case).

    Returns:
        List of 3Di token sequences (one per input sequence), each as long
        as its input sequence.

    NOTE(review): only U, Z and O are masked to X here, while ``encode``
    also masks B before calling this method.
    NOTE(review): slicing predictions to the first ``len(seq)`` positions
    assumes one token per residue and padding on the right - confirm
    against the tokenizer configuration.
    """
    # Replace non-standard amino acids (single pass instead of chained replaces)
    rare_to_x = str.maketrans("UZO", "XXX")
    masked = [seq.translate(rare_to_x) for seq in aa_sequences]
    lengths = [len(seq) for seq in masked]

    # Tokenize sequences (no special tokens, no prefix for modernprost)
    encoded = self.tokenizer(
        masked,
        padding="longest",
        truncation=False,
        return_tensors="pt",
        add_special_tokens=False,
    )
    encoded = encoded.to(self.device)

    # Run inference
    with torch.no_grad():
        outputs = self.model(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
        )

    # Argmax over the class axis of the logits gives one class id per position.
    class_ids = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

    # Class id -> 3Di letter (ids 0..19, in this alphabet order).
    id_to_letter = dict(enumerate("ACDEFGHIKLMNPQRSTVWY"))

    # Drop padded positions and convert each row to a 3Di string.
    return [
        "".join([id_to_letter[int(cid)] for cid in class_ids[row, :n_res]])
        for row, n_res in enumerate(lengths)
    ]
[docs]
def encode(
    self,
    aa_sequences: List[str],
    encoding_size: int = DEFAULT_ENCODING_SIZE,
    use_multi_gpu: bool = False,
    gpu_ids: Optional[List[int]] = None,
    multi_gpu_encoder: Optional[Any] = None,
) -> List[str]:
    """Encode amino acid sequences to 3Di tokens.

    Sequences have U/Z/O/B masked to X, are grouped into batches under
    ``encoding_size`` by ``token_budget_batches``, and are encoded batch
    by batch. Results are returned in the order of ``aa_sequences``.

    Args:
        aa_sequences: List of amino acid sequences (upper-case).
        encoding_size: Maximum size (approx. amino acids) to encode per batch
        use_multi_gpu: If True, run through the accelerate-prepared model.
            An encoder created without accelerate is re-initialized with
            ``use_accelerate=True`` first.
        gpu_ids: Accepted for interface compatibility; not used by this method.
        multi_gpu_encoder: Accepted for backward compatibility; not used by
            this method.

    Returns:
        List of 3Di token sequences (one per input sequence)

    Raises:
        ValueError: If ``encoding_size`` is not positive
        ModelError: If the model cannot be loaded
    """
    # Fail fast, before any model loading, instead of surfacing later as a
    # ZeroDivisionError or an error from inside the batching generator.
    if encoding_size <= 0:
        raise ValueError("encoding_size must be > 0")

    if use_multi_gpu and not self.use_accelerate:
        # Need to re-initialize with accelerate support
        logger.info("Re-initializing encoder with accelerate for multi-GPU support")
        self.__init__(
            model_name=self.model_name,
            device=None,
            use_accelerate=True,
        )

    self._load_model()

    # Replace rare/ambiguous amino acids (ModernProst needs no ProstT5 prefix)
    rare_to_x = str.maketrans("UZOB", "XXXX")
    processed_seqs = [seq.translate(rare_to_x) for seq in aa_sequences]
    total_sequences = len(processed_seqs)

    # Materialize the batches so progress and ETA use the real batch count.
    # A length-based estimate can be off in either direction.
    batches = list(self.token_budget_batches(processed_seqs, encoding_size))
    total_batches = len(batches)

    if not use_multi_gpu:
        # Single-device path: delegate the loop to the shared helper.
        from .encoding import process_batches

        return process_batches(
            iter(batches),
            self._encode_batch,
            total_sequences,
            total_batches,
        )

    # Accelerate path: run the loop here with progress logging.
    from .encoding import format_seconds, get_memory_info

    three_di_sequences: List[str] = [None] * total_sequences  # type: ignore[list-item]
    t0 = time.perf_counter()
    avg_batch_sec: Optional[float] = None

    for idx, batch in enumerate(batches, start=1):
        # Batches not yet finished, including the one about to run.
        remaining = total_batches - (idx - 1)
        eta_str = (
            "--"
            if avg_batch_sec is None
            else format_seconds(avg_batch_sec * remaining)
        )

        allocated, reserved = get_memory_info()
        logger.info(
            "3Di encoding batch %d of %d batches. "
            "Estimated %s remaining. Cuda memory allocated: %.1f GB reserved: %.1f GB",
            idx,
            total_batches,
            eta_str,
            allocated,
            reserved,
        )

        batch_results = self._encode_batch([rec.seq for rec in batch])

        # Store results in original order
        for rec, three_di in zip(batch, batch_results):
            three_di_sequences[rec.idx] = three_di

        # Running average over all batches completed so far.
        avg_batch_sec = (time.perf_counter() - t0) / idx

    return three_di_sequences
[docs]
def encode_proteins(
    self,
    proteins: List[ProteinRecord],
    encoding_size: int = DEFAULT_ENCODING_SIZE,
    use_multi_gpu: bool = False,
    gpu_ids: Optional[List[int]] = None,
    multi_gpu_encoder: Optional[Any] = None,
) -> List[ThreeDiRecord]:
    """Encode protein records and wrap each result in a ThreeDiRecord.

    Args:
        proteins: List of ProteinRecord objects
        encoding_size: Maximum size (approx. amino acids) to encode per batch
        use_multi_gpu: Forwarded to ``encode``
        gpu_ids: Forwarded to ``encode``
        multi_gpu_encoder: Forwarded to ``encode``

    Returns:
        List of ThreeDiRecord objects, paired with ``proteins`` in order
    """
    three_di_sequences = self.encode(
        [protein.aa_sequence for protein in proteins],
        encoding_size,
        use_multi_gpu=use_multi_gpu,
        gpu_ids=gpu_ids,
        multi_gpu_encoder=multi_gpu_encoder,
    )

    # Label records by model flavour; anything that is not the profiles
    # model is reported as the base model.
    if self.model_name == MODERNPROST_PROFILES_MODEL:
        method = "modernprost_profiles"
    else:
        method = "modernprost_base"

    # self.device is read after encode() because encode() may re-initialize
    # the encoder and change the device.
    return [
        ThreeDiRecord(
            protein=protein,
            three_di=three_di,
            method=method,
            model_name=self.model_name,
            inference_device=self.device,
        )
        for protein, three_di in zip(proteins, three_di_sequences)
    ]