# Source code for crabml.core.likelihood

"""
Likelihood calculation for phylogenetic models.

This module implements the Felsenstein pruning algorithm for computing
phylogenetic likelihood on a tree.
"""

import numpy as np
from typing import Optional

from ..io.sequences import Alignment
from ..io.trees import Tree, TreeNode
from .matrix import matrix_exponential


class LikelihoodCalculator:
    """
    Compute phylogenetic likelihood using Felsenstein's pruning algorithm.

    Attributes
    ----------
    alignment : Alignment
        Multiple sequence alignment
    tree : Tree
        Phylogenetic tree
    n_states : int
        Number of states (61 for codons, 20 for amino acids, 4 for nucleotides)
    n_sites : int
        Number of sites in alignment
    leaf_to_seq_idx : dict[str, int]
        Row index in ``alignment.sequences`` for each leaf name
    """

    # Number of states for each supported ``alignment.seqtype``.
    _N_STATES_BY_SEQTYPE = {"codon": 61, "aa": 20, "dna": 4}

    # PAML-style scaling constants. When the largest conditional likelihood of
    # an internal node at a site is below _SCALE_THRESHOLD, that site's partials
    # are reset to 1.0 and _SCALE_UNDERFLOW_LOG is charged instead of log(max).
    _SCALE_THRESHOLD = 1e-300
    _SCALE_UNDERFLOW_LOG = -800.0

    # Added to each site likelihood before taking the log so that a likelihood
    # of exactly zero yields a large negative value instead of -inf.
    _LOG_FLOOR = 1e-100

    def __init__(self, alignment: Alignment, tree: Tree):
        """
        Initialize likelihood calculator.

        Parameters
        ----------
        alignment : Alignment
            Multiple sequence alignment
        tree : Tree
            Phylogenetic tree

        Raises
        ------
        ValueError
            If the alignment and tree disagree on the number or names of taxa,
            if the alignment contains duplicate sequence names, or if
            ``alignment.seqtype`` is not "codon", "aa" or "dna".
        """
        self.alignment = alignment
        self.tree = tree

        # Verify alignment and tree are compatible
        if len(alignment.names) != tree.n_leaves:
            raise ValueError(
                f"Alignment has {len(alignment.names)} sequences but tree has "
                f"{tree.n_leaves} leaves"
            )

        alignment_names_set = set(alignment.names)
        # Duplicate names would make the leaf -> row mapping below ambiguous
        # (the last occurrence would silently win).
        if len(alignment_names_set) != len(alignment.names):
            raise ValueError("Alignment contains duplicate sequence names")

        # Check that all leaf names match
        tree_names_set = set(tree.leaf_names)
        if alignment_names_set != tree_names_set:
            raise ValueError(
                "Alignment and tree have different species. "
                f"In alignment but not tree: {alignment_names_set - tree_names_set}. "
                f"In tree but not alignment: {tree_names_set - alignment_names_set}"
            )

        # Determine number of states
        n_states = self._N_STATES_BY_SEQTYPE.get(alignment.seqtype)
        if n_states is None:
            raise ValueError(f"Unknown sequence type: {alignment.seqtype}")
        self.n_states = n_states

        self.n_sites = alignment.n_sites

        # Create mapping from leaf name to alignment row
        self.leaf_to_seq_idx = {name: i for i, name in enumerate(alignment.names)}

    def compute_log_likelihood(
        self,
        Q: np.ndarray,
        pi: np.ndarray,
        scale_branch_lengths: float = 1.0,
        use_scaling: bool = False,
    ) -> float:
        """
        Compute log-likelihood for given substitution model.

        Parameters
        ----------
        Q : np.ndarray
            Rate matrix (n_states x n_states)
        pi : np.ndarray
            Equilibrium frequencies (n_states,)
        scale_branch_lengths : float, optional
            Global scaling factor for branch lengths (default 1.0)
        use_scaling : bool, optional
            Use PAML-style scaling of conditional likelihoods to prevent
            underflow on large trees (default False, the original behavior)

        Returns
        -------
        float
            Log-likelihood value (sum of per-site log-likelihoods)

        Raises
        ------
        ValueError
            If ``Q`` or ``pi`` does not match ``n_states``.
        """
        if Q.shape != (self.n_states, self.n_states):
            raise ValueError(
                f"Q matrix has shape {Q.shape}, expected ({self.n_states}, {self.n_states})"
            )
        self._check_pi(pi)

        site_log_likelihoods = self._site_log_likelihoods(
            Q, pi, scale_branch_lengths, use_scaling
        )
        return np.sum(site_log_likelihoods)

    def compute_log_likelihood_site_classes(
        self,
        Q_matrices: list[np.ndarray],
        pi: np.ndarray,
        proportions: list[float],
        scale_branch_lengths: float = 1.0,
        use_scaling: bool = False,
    ) -> float:
        """
        Compute log-likelihood for site class model.

        For models with site classes (M1a, M2a, M3), the likelihood of each
        site is a mixture over site classes:

            P(data) = sum_k p_k * P(data | Q_k)

        The mixture is evaluated in log space with the log-sum-exp trick, so
        per-class likelihoods are never exponentiated back to linear scale.

        Parameters
        ----------
        Q_matrices : list[np.ndarray]
            List of rate matrices, one per site class
        pi : np.ndarray
            Equilibrium frequencies (n_states,)
        proportions : list[float]
            Proportion of sites in each class
        scale_branch_lengths : float, optional
            Global scaling factor for branch lengths (default 1.0)
        use_scaling : bool, optional
            Use PAML-style scaling to prevent underflow (default False)

        Returns
        -------
        float
            Log-likelihood value (sum of per-site log-likelihoods)

        Raises
        ------
        ValueError
            If the number of proportions differs from the number of rate
            matrices, if no rate matrix is given, or if any ``Q_k`` or ``pi``
            does not match ``n_states``.
        """
        n_classes = len(Q_matrices)
        if len(proportions) != n_classes:
            raise ValueError(
                f"Number of proportions ({len(proportions)}) must match "
                f"number of Q matrices ({n_classes})"
            )
        if n_classes == 0:
            raise ValueError("At least one site class (Q matrix) is required")
        for k, Q in enumerate(Q_matrices):
            if Q.shape != (self.n_states, self.n_states):
                raise ValueError(
                    f"Q matrix {k} has shape {Q.shape}, "
                    f"expected ({self.n_states}, {self.n_states})"
                )
        self._check_pi(pi)

        # log P(data_site | class k), kept in log space so that scaling is not
        # undone by an exp() that would underflow for long alignments/trees.
        log_class_likelihoods = np.empty((self.n_sites, n_classes))
        for k, Q in enumerate(Q_matrices):
            log_class_likelihoods[:, k] = self._site_log_likelihoods(
                Q, pi, scale_branch_lengths, use_scaling
            )

        # A zero proportion (e.g. a mixture at a parameter boundary) maps to
        # -inf, which simply removes that class from the mixture below.
        with np.errstate(divide="ignore"):
            log_proportions = np.log(np.array(proportions, dtype=float))

        # log(p_k * L_k) = log(p_k) + log(L_k)
        log_weighted = log_class_likelihoods + log_proportions[np.newaxis, :]

        # log(sum(exp(x_i))) = max(x_i) + log(sum(exp(x_i - max(x_i))))
        max_log = np.max(log_weighted, axis=1, keepdims=True)
        log_site_likelihoods = max_log[:, 0] + np.log(
            np.sum(np.exp(log_weighted - max_log), axis=1)
        )

        return np.sum(log_site_likelihoods)

    def _check_pi(self, pi: np.ndarray) -> None:
        """Raise ValueError unless ``pi`` has shape (n_states,)."""
        if pi.shape != (self.n_states,):
            raise ValueError(f"pi has shape {pi.shape}, expected ({self.n_states},)")

    def _site_log_likelihoods(
        self,
        Q: np.ndarray,
        pi: np.ndarray,
        scale_branch_lengths: float,
        use_scaling: bool,
    ) -> np.ndarray:
        """Return per-site log-likelihoods, shape (n_sites,), under rate matrix Q."""
        P_matrices = self._transition_matrices(Q, scale_branch_lengths)
        root_partials, log_scale = self._prune(P_matrices, use_scaling)
        # Site likelihood = sum_i pi_i * L_root[site, i]
        site_likelihoods = root_partials @ pi
        # The floor avoids log(0); log_scale restores factors removed by scaling.
        return np.log(site_likelihoods + self._LOG_FLOOR) + log_scale

    def _transition_matrices(
        self, Q: np.ndarray, scale_branch_lengths: float
    ) -> dict:
        """
        Return ``{node.id: matrix_exponential(Q, t)}`` for every non-root node.

        ``t`` is the node's ``branch_length`` multiplied by
        ``scale_branch_lengths``.
        """
        return {
            node.id: matrix_exponential(Q, node.branch_length * scale_branch_lengths)
            for node in self.tree.postorder()
            if node.parent is not None
        }

    def _leaf_partials(self, node: TreeNode) -> np.ndarray:
        """
        Return the (n_sites, n_states) conditional likelihoods of a leaf.

        A site with a non-negative state code gets 1.0 at that state and 0.0
        elsewhere. A negative code (missing or ambiguous data) gets 1.0 for
        every state.
        """
        partials = np.zeros((self.n_sites, self.n_states))
        seq_idx = self.leaf_to_seq_idx[node.name]
        for site in range(self.n_sites):
            obs_state = self.alignment.sequences[seq_idx, site]
            if obs_state >= 0:
                partials[site, obs_state] = 1.0
            else:
                partials[site, :] = 1.0
        return partials

    def _prune(self, P_matrices: dict, use_scaling: bool) -> tuple:
        """
        Run the post-order (leaves-to-root) pruning pass.

        Returns
        -------
        root_partials : np.ndarray
            (n_sites, n_states) conditional likelihoods at the root; each
            internal node's partials are rescaled when ``use_scaling`` is set
        log_scale : np.ndarray
            (n_sites,) sum over internal nodes of the log scale factors that
            were divided out (all zeros when ``use_scaling`` is False)
        """
        partials = {}
        log_scale = np.zeros(self.n_sites)
        for node in self.tree.postorder():
            if node.is_leaf:
                partials[node.id] = self._leaf_partials(node)
                continue
            # L_node[s, i] = prod_c sum_j P_c[i, j] * L_c[s, j]
            #             = prod_c (L_c @ P_c.T)[s, i]
            node_partials = np.ones((self.n_sites, self.n_states))
            for child in node.children:
                # Post-order visits each child before its parent and never
                # again, so its partials can be released once consumed.
                child_partials = partials.pop(child.id)
                node_partials *= child_partials @ P_matrices[child.id].T
            if use_scaling:
                node_partials, node_log_scale = self._rescale(node_partials)
                log_scale += node_log_scale
            partials[node.id] = node_partials
        return partials[self.tree.root.id], log_scale

    @classmethod
    def _rescale(cls, node_partials: np.ndarray) -> tuple:
        """
        PAML-style rescaling of one node's partials, site by site.

        Each site's row is divided by its maximum and log(maximum) is
        returned. If the maximum is below ``_SCALE_THRESHOLD``, the row is
        reset to 1.0 and ``_SCALE_UNDERFLOW_LOG`` is returned instead.
        """
        max_vals = node_partials.max(axis=1)
        underflow = max_vals < cls._SCALE_THRESHOLD
        # Divide underflowing rows by 1.0 to avoid 0/0; they are overwritten.
        divisors = np.where(underflow, 1.0, max_vals)
        scaled = node_partials / divisors[:, np.newaxis]
        scaled[underflow, :] = 1.0
        log_factors = np.where(underflow, cls._SCALE_UNDERFLOW_LOG, np.log(divisors))
        return scaled, log_factors