# Source code for crabml.io.trees

"""
Phylogenetic tree parsing and manipulation.
"""

import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional


@dataclass
class TreeNode:
    """
    Phylogenetic tree node.

    Equality and ``repr`` look only *downwards* (at the node's own data and
    its ``children``).  The ``parent`` back-reference is excluded from both,
    because a parent's ``children`` list contains this node again; following
    that cycle made ``==`` on structurally equal nodes recurse without bound.

    Attributes
    ----------
    id : int
        Node identifier
    name : Optional[str]
        Node name (for leaves)
    parent : Optional[TreeNode]
        Parent node (not part of equality comparison or repr)
    children : list[TreeNode]
        Child nodes
    branch_length : float
        Branch length to parent
    label : Optional[str]
        Branch label (e.g., '#1' for model specification)
    """

    id: int
    name: Optional[str] = None
    # Back-reference only: kept out of __eq__/__repr__ to break the
    # parent -> children -> parent cycle.
    parent: Optional["TreeNode"] = field(default=None, repr=False, compare=False)
    children: list["TreeNode"] = field(default_factory=list)
    branch_length: float = 0.0
    label: Optional[str] = None

    @property
    def is_leaf(self) -> bool:
        """Return True when the node has no children."""
        return not self.children


@dataclass
class Tree:
    """
    Phylogenetic tree.

    Attributes
    ----------
    root : TreeNode
        Root node of the tree
    n_nodes : int
        Total number of nodes
    n_leaves : int
        Number of leaf nodes
    leaf_names : list[str]
        Names of leaf nodes
    """

    root: TreeNode
    n_nodes: int
    n_leaves: int
    leaf_names: list[str]

    @classmethod
    def from_newick(cls, newick_string: str) -> "Tree":
        """
        Parse Newick format tree string.

        ``//`` line comments and ``/* ... */`` block comments (also in the
        spaced ``/ * ... * /`` form) are removed first.  Blank lines and
        lines consisting of exactly two integers (PAML-style headers) are
        skipped, and all remaining lines up to and including the first one
        containing ``;`` are joined into the tree string.

        Node ids are assigned in the order nodes are first encountered
        (pre-order), starting at 0 for the root.

        Parameters
        ----------
        newick_string : str
            Newick format tree

        Returns
        -------
        Tree
            Parsed tree

        Raises
        ------
        ValueError
            If there is no semicolon, a ',' or ')' is missing inside a
            parenthesised group, a branch length is not a number, or
            unparsed text remains between the root node and the ';'.
        """
        # `//` comments run to the end of their line.
        newick = re.sub(r'//.*', '', newick_string)
        # Block comments; DOTALL lets a single comment span several lines.
        newick = re.sub(r'/\s*\*.*?\*\s*/', '', newick, flags=re.DOTALL)
        newick = newick.strip()

        if ';' not in newick:
            raise ValueError("Invalid Newick format: missing semicolon")

        # Collect the lines that make up the first tree.
        tree_lines = []
        for line in newick.split('\n'):
            # Skip blank lines and header lines made of two integers.
            if line.strip() and not re.match(r'^\s*\d+\s+\d+\s*$', line):
                tree_lines.append(line)
                if ';' in line:
                    break

        if not tree_lines:
            raise ValueError("Invalid Newick format: no tree found")

        # Drop newlines and tabs but keep spaces (they may precede '#').
        tree_line = (
            ''.join(tree_lines)
            .replace('\n', '')
            .replace('\t', '')
            .replace('\r', '')
        )

        whitespace = ' \t\n\r'
        next_id = 0

        def skip_whitespace(s: str, pos: int) -> int:
            """Return the first index >= pos that is not whitespace."""
            while pos < len(s) and s[pos] in whitespace:
                pos += 1
            return pos

        def parse_node(
            s: str, start: int, parent: Optional[TreeNode] = None
        ) -> tuple[TreeNode, int]:
            """Parse one node starting at `start`; return it and the end index."""
            nonlocal next_id
            node = TreeNode(id=next_id)
            next_id += 1
            node.parent = parent

            pos = skip_whitespace(s, start)

            # '(' opens an internal node: parse its comma-separated children.
            if pos < len(s) and s[pos] == '(':
                pos = skip_whitespace(s, pos + 1)
                while True:
                    child, pos = parse_node(s, pos, node)
                    node.children.append(child)
                    pos = skip_whitespace(s, pos)

                    if pos < len(s) and s[pos] == ',':
                        pos = skip_whitespace(s, pos + 1)
                    elif pos < len(s) and s[pos] == ')':
                        pos = skip_whitespace(s, pos + 1)
                        break
                    else:
                        raise ValueError(
                            f"Expected ',' or ')' at position {pos}"
                        )

            # Optional name (leaf name or internal-node name).
            name_start = pos
            while pos < len(s) and s[pos] not in ',:();# \t\n\r':
                pos += 1
            if pos > name_start:
                node.name = s[name_start:pos]

            pos = skip_whitespace(s, pos)

            # Optional branch label such as '#1'; stored with the '#'.
            if pos < len(s) and s[pos] == '#':
                pos += 1
                label_start = pos
                while pos < len(s) and s[pos] not in ',:(); \t\n\r':
                    pos += 1
                node.label = '#' + s[label_start:pos]

            pos = skip_whitespace(s, pos)

            # Optional branch length, ':0.123' or ': 0.123'.
            if pos < len(s) and s[pos] == ':':
                pos = skip_whitespace(s, pos + 1)
                length_start = pos
                while pos < len(s) and s[pos] not in ',(); \t\n\r':
                    pos += 1
                try:
                    node.branch_length = float(s[length_start:pos])
                except ValueError as err:
                    raise ValueError(
                        f"Invalid branch length: {s[length_start:pos]}"
                    ) from err

            return node, pos

        root, pos = parse_node(tree_line, 0, None)

        # The root must be followed directly by the terminating ';'.
        # Otherwise part of the input would be dropped without notice.
        pos = skip_whitespace(tree_line, pos)
        if pos >= len(tree_line) or tree_line[pos] != ';':
            raise ValueError(
                f"Invalid Newick format: expected ';' at position {pos}"
            )

        def count_nodes(node: TreeNode) -> tuple[int, int, list[str]]:
            """Return (node count, leaf count, leaf names) for a subtree."""
            if node.is_leaf:
                # Unnamed leaves are reported by their numeric id.
                return 1, 1, [node.name if node.name else str(node.id)]

            total_nodes = 1
            total_leaves = 0
            names: list[str] = []
            for child in node.children:
                sub_nodes, sub_leaves, sub_names = count_nodes(child)
                total_nodes += sub_nodes
                total_leaves += sub_leaves
                names.extend(sub_names)
            return total_nodes, total_leaves, names

        n_nodes, n_leaves, leaf_names = count_nodes(root)

        return cls(
            root=root,
            n_nodes=n_nodes,
            n_leaves=n_leaves,
            leaf_names=leaf_names,
        )

    def postorder(self) -> list[TreeNode]:
        """
        Return nodes in post-order traversal (leaves to root).

        Returns
        -------
        list[TreeNode]
            Nodes in post-order
        """
        result = []

        def traverse(node: TreeNode) -> None:
            """Append all descendants of `node`, then `node` itself."""
            for child in node.children:
                traverse(child)
            result.append(node)

        traverse(self.root)
        return result

    def get_branches(self) -> list[tuple[TreeNode, TreeNode]]:
        """
        Get all branches as (parent, child) pairs.

        Branches are listed in pre-order: each branch appears before the
        branches of the subtree below it.

        Returns
        -------
        list[tuple[TreeNode, TreeNode]]
            List of (parent, child) tuples for each branch
        """
        branches = []

        def traverse(node: TreeNode) -> None:
            """Record each branch below `node`, descending depth-first."""
            for child in node.children:
                branches.append((node, child))
                traverse(child)

        traverse(self.root)
        return branches

    def get_branch_labels(self) -> list[int]:
        """
        Get integer branch labels for branch-site models.

        Converts string labels like '#0', '#1' to integers. Branches
        without labels are assigned 0 (background). The result is in the
        same order as `get_branches()`.

        Returns
        -------
        list[int]
            Branch labels as integers (0=background, 1=foreground, etc.)

        Raises
        ------
        ValueError
            If a label is not an integer after its leading '#'
        """
        labels = []
        for _parent, child in self.get_branches():
            if child.label is None:
                # Unlabelled branches default to background.
                labels.append(0)
                continue
            try:
                labels.append(int(child.label.lstrip('#')))
            except ValueError as err:
                raise ValueError(
                    f"Invalid branch label: {child.label}"
                ) from err
        return labels

    def validate_branch_site_labels(self) -> None:
        """
        Validate branch labels for branch-site models.

        Branch-site models (Model A, A1) require exactly 2 label types:
        - 0 (background)
        - 1 (foreground)

        On success a one-line confirmation is printed to stdout.

        Raises
        ------
        ValueError
            If labels are not valid for branch-site models
        """
        labels = self.get_branch_labels()
        unique_labels = sorted(set(labels))

        if unique_labels != [0, 1]:
            raise ValueError(
                f"Branch-site models require exactly 2 label types (0 and 1). "
                f"Found: {unique_labels}. "
                f"Mark foreground branches with '#1' in the tree."
            )

        # Past the check above at least one label equals 1, so no separate
        # "no foreground branch" test is needed.
        n_foreground = labels.count(1)
        print(f"✓ Tree validation passed: {n_foreground} foreground branch(es) marked")

    def to_newick(self) -> str:
        """
        Convert tree to Newick format string.

        Each node is written as its children in parentheses (if any), then
        its name, then its branch label, then ':<branch_length>'. The
        branch length is omitted for the root.

        Returns
        -------
        str
            Tree in Newick format
        """

        def node_to_newick(node: TreeNode) -> str:
            """Serialise the subtree rooted at `node`."""
            if node.is_leaf:
                result = node.name if node.name else ""
            else:
                child_strings = [node_to_newick(child) for child in node.children]
                result = f"({','.join(child_strings)})"
                if node.name:
                    result += node.name

            if node.label:
                result += node.label

            # The root has no branch to a parent.
            if node.parent is not None:
                result += f":{node.branch_length}"

            return result

        return node_to_newick(self.root) + ";"