Implementing UPGMA in Python

{First, point out that the starter code includes two assertion functions: assert_square_matrix() and assert_same_number_species().}

def assert_square_matrix(mtx: DistanceMatrix) -> None:
    """
    Check that every row of a distance matrix is as long as the matrix is tall.

    Args:
        mtx (DistanceMatrix): The matrix to check.

    Raises:
        ValueError: On the first row whose length differs from the number of rows.
            A diagnostic line naming that row is printed before raising.
    """
    expected_len = len(mtx)
    for index, current_row in enumerate(mtx):
        if len(current_row) == expected_len:
            continue
        print(f"Row {index} of matrix has length {len(current_row)} and matrix has {expected_len} rows.")
        raise ValueError("Error! Matrix is not square.")


def assert_same_number_species(mtx: DistanceMatrix, species_names: list[str]) -> None:
    """
    Check that there is exactly one species label per matrix row.

    Args:
        mtx (DistanceMatrix): Square distance matrix.
        species_names (list[str]): Species labels, one per row of `mtx`.

    Raises:
        ValueError: If the number of rows and the number of labels differ.
    """
    matrix_size, species_count = len(mtx), len(species_names)
    if matrix_size == species_count:
        return
    raise ValueError("Error: Number of rows of matrix don't match number of species.")

{Next, we implement upgma(). Note that it uses a Node method, count_leaves(), to get the size of each cluster being merged; we will write that method later.}

def upgma(mtx: DistanceMatrix, species_names: list[str]) -> Tree:
    """
    Build a phylogenetic tree using the UPGMA algorithm.

    Given a distance matrix and species names, iteratively merges the closest
    clusters and updates the matrix using cluster-size–weighted averages.
    The resulting tree has `n` leaves (the species) and `n-1` internal nodes.

    The caller's matrix is left unchanged: the algorithm works on a private
    copy, because the matrix helpers grow and shrink the matrix in place.

    Args:
        mtx (DistanceMatrix): Square symmetric matrix of pairwise distances.
        species_names (list[str]): Names of the species, in the same order as `mtx`.

    Returns:
        Tree: A list of `Node` objects representing the full UPGMA tree.
              Conventionally, the last node (index -1) is the root.

    Raises:
        ValueError: If `mtx` is not square or its size does not match
            `species_names`.
    """
    assert_square_matrix(mtx)
    assert_same_number_species(mtx, species_names)

    # add_row_col appends to the matrix and delete_row_col uses `del` on it, so
    # without a copy the caller's distance matrix would be destroyed.
    mtx = [list(matrix_row) for matrix_row in mtx]

    num_leaves = len(mtx)
    t = initialize_tree(species_names)
    clusters = initialize_clusters(t)  # references to current cluster representatives

    for p in range(num_leaves, 2 * num_leaves - 1):
        # Each iteration creates one new internal node at index p
        row, col, min_val = find_min_element(mtx)

        # Age of the new node is half the inter-cluster distance
        t[p].age = min_val / 2.0

        # Set children to the cluster representatives we're merging
        t[p].child1 = clusters[row]
        t[p].child2 = clusters[col]

        cluster_size1 = t[p].child1.count_leaves()
        cluster_size2 = t[p].child2.count_leaves()

        # Update the distance matrix: add new row/col, then delete merged rows/cols
        mtx = add_row_col(row, col, cluster_size1, cluster_size2, mtx)
        mtx = delete_row_col(mtx, row, col)

        # Update the active clusters: append the new node and drop the merged children.
        # The new node goes last, matching the new last row/column of the matrix.
        clusters.append(t[p])
        clusters = delete_clusters(clusters, row, col)

    return t

{initialize_tree}

def initialize_tree(species_names: list[str]) -> Tree:
    """
    Preallocate every node UPGMA will need for the given species.

    For n species the returned list holds 2n - 1 nodes, with `num` equal to
    each node's list index:
      - indices 0..n-1 are the leaves, labeled with `species_names` in order;
      - indices n..2n-2 are the internal nodes, labeled "Ancestor Species: <index>".

    Args:
        species_names (list[str]): Species labels.

    Returns:
        Tree: The preallocated list of `Node` objects used by UPGMA.
    """
    leaf_count = len(species_names)
    tree: Tree = [Node(num=index) for index in range(2 * leaf_count - 1)]

    for index, node in enumerate(tree):
        node.label = species_names[index] if index < leaf_count else f"Ancestor Species: {index}"

    return tree

{initialize_clusters}

def initialize_clusters(t: Tree) -> list[Node]:
    """
    Return a fresh list of the leaf nodes, which are UPGMA's starting clusters.

    The leaf count is recovered from the tree size (2n - 1 nodes for n leaves).

    Args:
        t (Tree): The full node list allocated for UPGMA.

    Returns:
        list[Node]: A new list containing the first n nodes of `t`.
    """
    leaf_total = (len(t) + 1) // 2
    return list(t[:leaf_total])

{find_min_element}

def find_min_element(mtx: DistanceMatrix) -> tuple[int, int, float]:
    """
    Locate the smallest entry strictly above the main diagonal.

    Entries are scanned row by row, left to right; on ties the first one
    encountered wins.

    Args:
        mtx (DistanceMatrix): A square matrix with size >= 2.

    Returns:
        tuple[int, int, float]: (row_index, col_index, min_value) with col_index > row_index.

    Raises:
        ValueError: If the matrix has fewer than two rows or its first row has
            fewer than two columns.
    """
    if len(mtx) <= 1 or len(mtx[0]) <= 1:
        raise ValueError("One row or one column!")

    upper_triangle = (
        (i, j) for i in range(len(mtx) - 1) for j in range(i + 1, len(mtx[i]))
    )
    # min() only replaces its candidate on a strict "<", so the first minimum is kept.
    best_i, best_j = min(upper_triangle, key=lambda ij: mtx[ij[0]][ij[1]])
    return best_i, best_j, mtx[best_i][best_j]

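{Opportunity to discuss why we delete col first and then row: because row < col, removing the element at index col first leaves index row pointing at the same element, whereas removing row first would shift the element at col down to index col - 1. Not sure that we have used del much; need to check.}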

{figure needed}

def delete_clusters(clusters: list[Node], row: int, col: int) -> list[Node]:
    """
    Drop the two cluster representatives that were just merged.

    The list is modified in place and also returned. The entry at `col` is
    removed before the entry at `row`, so that with row < col the lower index
    is still valid after the first removal.

    Args:
        clusters (list[Node]): Active cluster representatives.
        row (int): Index of the first cluster (row < col).
        col (int): Index of the second cluster.

    Returns:
        list[Node]: The same list object, now without the two merged clusters.
    """
    for doomed in (col, row):
        del clusters[doomed]
    return clusters

{delete_row_col}

def delete_row_col(mtx: DistanceMatrix, row: int, col: int) -> DistanceMatrix:
    """
    Remove rows `row` and `col`, and columns `row` and `col`, from the matrix.

    The matrix is modified in place and also returned. In both the row pass
    and the column pass, index `col` is removed before index `row`, so that
    with row < col the lower index still refers to the intended entry.

    Args:
        mtx (DistanceMatrix): The distance matrix.
        row (int): The first row/column index to delete (row < col).
        col (int): The second row/column index to delete.

    Returns:
        DistanceMatrix: The same matrix object with the rows/columns removed.
    """
    removal_order = (col, row)

    for idx in removal_order:
        del mtx[idx]

    for remaining in mtx:
        for idx in removal_order:
            del remaining[idx]

    return mtx

{add_row_col}

def add_row_col(row: int, col: int, cluster_size1: int, cluster_size2: int, mtx: DistanceMatrix) -> DistanceMatrix:
    """
    Add a new cluster (row/column) to the matrix via size-weighted averaging.

    Computes distances from the new merged cluster to each existing cluster
    using a weighted average by cluster sizes, then appends this as a new
    last column (to every existing row) and a new last row, so the matrix
    stays square. The matrix is modified in place and also returned.

    Distances from the new cluster to the two clusters being merged (indices
    `row` and `col`) and to itself are left at 0.0. The `row` and `col`
    entries are expected to be deleted right afterwards.

    Args:
        row (int): Index of the first merged cluster (row < col).
        col (int): Index of the second merged cluster.
        cluster_size1 (int): Number of leaves in the first cluster.
        cluster_size2 (int): Number of leaves in the second cluster.
        mtx (DistanceMatrix): The current distance matrix.

    Returns:
        DistanceMatrix: The same matrix object, now one row and one column larger.
    """
    num_rows = len(mtx)
    total_size = cluster_size1 + cluster_size2
    new_row = [0.0] * (num_rows + 1)

    # Fill distances to the new cluster using weighted average
    for r in range(num_rows):
        if r != row and r != col:
            new_row[r] = (
                cluster_size1 * mtx[r][row] + cluster_size2 * mtx[r][col]
            ) / total_size

    # Mirror each distance into a new last column so the matrix stays square
    for r in range(num_rows):
        mtx[r].append(new_row[r])

    # Append the new cluster as the last row
    mtx.append(new_row)

    return mtx

{Now, we turn to implementing count_leaves, which will involve recursion}

Recursion

{Create recursion.py}

{Show how to implement summing_integers() in Python.}

{Exercise: implement factorial() and fibonacci() in Python.}

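{Show both solutions. Show that recursive_Fibonacci() times out for large inputs: its recursion tree is enormous because the same subproblems are recomputed over and over. Big tree bad.}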

{show count_leaves pseudocode}

{Updating our class definition to include this.}

@dataclass
class Node:
    """
    Represents a node in a phylogenetic tree.

    Attributes:
        num (int): Numeric identifier for the node (e.g., index in a tree list).
        age (float): Age (or height) of the node, typically half the distance between clusters.
        label (str): Label of the node, usually the species name for leaves.
        child1 (Self | None): The first child node, or None if this node is a leaf.
        child2 (Self | None): The second child node, or None if this node is a leaf.
    """

    num: int = 0
    age: float = 0.0
    label: str = ""
    child1: Self | None = None
    child2: Self | None = None

    def count_leaves(self) -> int:
        """
        Recursively count the number of leaf nodes in the subtree rooted at this node.

        Returns:
            int: The number of leaves under this node.
        """
        if self.is_leaf():
            return 1
        leaves = 0
        if self.child1 is not None:
            leaves += self.child1.count_leaves()
        if self.child2 is not None:
            leaves += self.child2.count_leaves()
        return leaves

{We also need an is_leaf() method.}

@dataclass
class Node:
    """
    Represents a node in a phylogenetic tree.

    Attributes:
        num (int): Numeric identifier for the node (e.g., index in a tree list).
        age (float): Age (or height) of the node, typically half the distance between clusters.
        label (str): Label of the node, usually the species name for leaves.
        child1 (Self | None): The first child node, or None if this node is a leaf.
        child2 (Self | None): The second child node, or None if this node is a leaf.
    """

    num: int = 0
    age: float = 0.0
    label: str = ""
    child1: Self | None = None
    child2: Self | None = None

    def count_leaves(self) -> int:
        """
        Recursively count the number of leaf nodes in the subtree rooted at this node.

        Returns:
            int: The number of leaves under this node.
        """
        if self.is_leaf():
            return 1
        leaves = 0
        if self.child1 is not None:
            leaves += self.child1.count_leaves()
        if self.child2 is not None:
            leaves += self.child2.count_leaves()
        return leaves

    def is_leaf(self) -> bool:
        """
        Return whether this node is a leaf (i.e., has no children).

        Returns:
            bool: True if both child1 and child2 are None, False otherwise.
        """
        return self.child1 is None and self.child2 is None

{We are ready to run our code, but it’s not clear how we should visualize the tree.}

Newick format

{Explain Newick format}

{Run the code and write each tree to a file. Now we just need a tool that can visualize a tree stored in Newick format.}

def hemoglobin_upgma() -> None:
    """
    Build a UPGMA tree from the hemoglobin alpha subunit 1 matrix and save it.

    Reads "Data/HBA1/hemoglobin.mtx", runs upgma(), and writes the result with
    write_newick_to_file() to "Output/HBA1/hba1.tre". Progress is printed along the way.
    """
    matrix_path = "Data/HBA1/hemoglobin.mtx"
    output_dir = "Output/HBA1"
    output_name = "hba1.tre"

    print("Read in Hemoglobin alpha subunit 1 matrix.")
    names, distances = read_matrix_from_file(matrix_path)

    print("Starting UPGMA.")
    tree = upgma(distances, names)

    print("UPGMA tree built. Writing to file.")
    write_newick_to_file(tree, output_dir, output_name)
    print("Tree written to file.")


def sars2_upgma() -> None:
    """
    Build a UPGMA tree from the UK SARS-CoV-2 spike matrix and save it.

    Reads "Data/UK-SARS-CoV-2/SARS_spike.mtx", runs upgma(), and writes the result
    with write_newick_to_file() to "Output/UK-SARS-CoV-2/sars-cov-2.tre".
    Progress is printed along the way.
    """
    source_file = "Data/UK-SARS-CoV-2/SARS_spike.mtx"
    destination_dir = "Output/UK-SARS-CoV-2"
    destination_file = "sars-cov-2.tre"

    print("Read in SARS-CoV-2 matrix.")
    labels, distances = read_matrix_from_file(source_file)
    print("Matrix read!")

    print("Generating UPGMA tree.")
    tree = upgma(distances, labels)
    print("UPGMA tree built! Writing to file.")

    write_newick_to_file(tree, destination_dir, destination_file)
    print("Tree written to file.")

Visualizing our trees

Scroll to Top
Programming for Lovers banner no background
programming for lovers logo cropped

Join our community!

programming for lovers logo cropped
Programming for Lovers banner no background

Join our community!