Source code for bioneuralnet.graph_generation.wgcna

import os
import subprocess
import pandas as pd
from typing import List, Dict, Any
from ..utils.logger import get_logger
import json
from io import StringIO


class WGCNA:
    """
    WGCNA Class for Graph Construction using Weighted Gene Co-expression Network Analysis (WGCNA).

    This class handles the preprocessing of omics data, execution of the WGCNA R script,
    and retrieval of the resulting adjacency matrix, all using in-memory data structures.
    """

    def __init__(
        self,
        phenotype_df: pd.DataFrame,
        omics_dfs: List[pd.DataFrame],
        data_types: List[str],
        soft_power: int = 6,
        min_module_size: int = 30,
        merge_cut_height: float = 0.25,
    ):
        """
        Initializes the WGCNA instance.

        Args:
            phenotype_df (pd.DataFrame): DataFrame containing phenotype data. The first column should be sample IDs.
            omics_dfs (List[pd.DataFrame]): List of DataFrames, each representing an omics dataset. Each DataFrame should have sample IDs as the first column.
            data_types (List[str]): List of data types corresponding to each omics dataset.
            soft_power (int, optional): Soft-thresholding power. Defaults to 6.
            min_module_size (int, optional): Minimum module size. Defaults to 30.
            merge_cut_height (float, optional): Merge cut height. Defaults to 0.25.
        """
        self.phenotype_df = phenotype_df
        self.omics_dfs = omics_dfs
        self.data_types = data_types
        self.soft_power = soft_power
        self.min_module_size = min_module_size
        self.merge_cut_height = merge_cut_height

        self.logger = get_logger(__name__)
        self.logger.info("Initialized WGCNA with the following parameters:")
        self.logger.info(f"Soft Power: {self.soft_power}")
        self.logger.info(f"Minimum Module Size: {self.min_module_size}")
        self.logger.info(f"Merge Cut Height: {self.merge_cut_height}")

        if len(self.omics_dfs) != len(self.data_types):
            self.logger.error("Number of omics dataframes does not match number of data types.")
            raise ValueError("Number of omics dataframes does not match number of data types.")

    def preprocess_data(self) -> Dict[str, Any]:
        """
        Preprocesses the omics data to ensure alignment and handle missing values.

        Returns:
            Dict[str, Any]: Dictionary containing serialized phenotype and omics data.
        """
        self.logger.info("Preprocessing omics data for NaN or infinite values.")
        phenotype_ids = self.phenotype_df.iloc[:, 0]
        self.logger.info(f"Number of samples in phenotype data: {len(phenotype_ids)}")
        valid_samples = pd.Series([True] * len(phenotype_ids), index=self.phenotype_df.index)

        serialized_data = {
            'phenotype': self.phenotype_df.to_csv(index=False)
        }

        for idx, omics_df in enumerate(self.omics_dfs):
            data_type = self.data_types[idx]
            self.logger.info(f"Processing omics DataFrame {idx+1}/{len(self.omics_dfs)}: Data Type = {data_type}")

            omics_ids = omics_df.iloc[:, 0]
            if not omics_ids.equals(phenotype_ids):
                self.logger.warning(f"Sample IDs in omics dataframe {idx+1} do not match phenotype data. Aligning data.")
                omics_df = omics_df.set_index(omics_ids).loc[phenotype_ids].reset_index()

            if omics_df.isnull().values.any():
                self.logger.warning(f"NaN values detected in omics dataframe {idx+1}. Marking samples with NaNs as invalid.")
                valid_samples &= ~omics_df.isnull().any(axis=1)

            if (omics_df == float('inf')).any().any() or (omics_df == -float('inf')).any().any():
                self.logger.warning(f"Infinite values detected in omics dataframe {idx+1}. Replacing with NaN and marking samples as invalid.")
                omics_df.replace([float('inf'), -float('inf')], pd.NA, inplace=True)
                valid_samples &= ~omics_df.isnull().any(axis=1)

            num_valid_before = valid_samples.sum()
            self.logger.info(f"Number of valid samples before filtering: {num_valid_before}")
            omics_df_clean = omics_df[valid_samples].reset_index(drop=True)
            num_valid_after = omics_df_clean.shape[0]
            self.logger.info(f"Number of valid samples after filtering: {num_valid_after}")

            if num_valid_after == 0:
                self.logger.error("No valid samples remaining after preprocessing. Aborting WGCNA run.")
                raise ValueError("No valid samples remaining after preprocessing.")

            serialized_data[f'omics_{idx+1}'] = omics_df_clean.to_csv(index=False)

        self.logger.info("Preprocessing completed successfully.")
        return serialized_data

    def run_wgcna(self, serialized_data: Dict[str, Any]) -> str:
        """
        Executes the WGCNA R script by passing serialized data via standard input.

        Args:
            serialized_data (Dict[str, Any]): Dictionary containing serialized phenotype and omics data.

        Returns:
            str: Serialized adjacency matrix CSV string from R script.
        """
        try:
            self.logger.info("Preparing data for WGCNA R script.")
            json_data = json.dumps(serialized_data)
            script_dir = os.path.dirname(os.path.abspath(__file__))
            r_script = os.path.join(script_dir, "WGCNA.R")

            if not os.path.isfile(r_script):
                self.logger.error(f"R script not found: {r_script}")
                raise FileNotFoundError(f"R script not found: {r_script}")

            command = [
                "Rscript",
                r_script,
                str(self.soft_power),
                str(self.min_module_size),
                str(self.merge_cut_height)
            ]

            self.logger.debug(f"Executing command: {' '.join(command)}")
            result = subprocess.run(
                command,
                input=json_data,
                text=True,
                capture_output=True,
                check=True
            )

            self.logger.info("WGCNA R script executed successfully.")
            self.logger.debug(f"WGCNA Output:\n{result.stdout}")

            if result.stderr:
                self.logger.warning(f"WGCNA Warnings/Errors:\n{result.stderr}")

            adjacency_json = result.stdout.strip()

            return adjacency_json

        except subprocess.CalledProcessError as e:
            self.logger.error(f"R script execution failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Error during WGCNA execution: {e}")
            raise

[docs] def run(self) -> pd.DataFrame: """ Executes the entire Weighted Gene Co-expression Network Analysis (WGCNA) workflow. **Steps:** 1. **Preprocessing Data:** - Prepares and formats the input omics data for WGCNA analysis. - Serializes the data into a format suitable for the WGCNA pipeline. 2. **Running WGCNA:** - Constructs a weighted correlation network based on the serialized omics data. - Identifies co-expression modules among genes or features in the dataset. 3. **Postprocessing Results:** - Deserializes the adjacency matrix (output of WGCNA) into a Pandas DataFrame. - Logs successful completion and prepares the matrix for downstream tasks. **Returns**: pd.DataFrame - A DataFrame containing the adjacency matrix, where each entry represents the weighted correlation between features. **Raises**: - ValueError If the input data is improperly formatted or missing. - Exception For any unforeseen errors encountered during preprocessing, network construction, or postprocessing. **Notes**: - The WGCNA workflow is sensitive to input data quality and formatting. - Ensure that the input omics data is preprocessed, normalized, and properly indexed to align with expected formats. - This method is designed for large-scale multi-omics data and may require significant computational resources depending on the dataset size. Example: .. code-block:: python wgcna = WGCNA(omics_data) adjacency_matrix = wgcna.run() print(adjacency_matrix.head()) """ try: self.logger.info("Starting WGCNA Network Construction Workflow.") serialized_data = self.preprocess_data() adjacency_json = self.run_wgcna(serialized_data) adjacency_matrix = pd.read_json(StringIO(adjacency_json), orient='split') self.logger.info("Adjacency matrix deserialized successfully.") self.logger.info("WGCNA Network Construction completed successfully.") return adjacency_matrix except Exception as e: self.logger.error(f"Error in WGCNA Network Construction: {e}") raise