Source code for bioneuralnet.graph_generation.smccnet

import os
import subprocess
import pandas as pd
from typing import List, Dict, Any
from ..utils.logger import get_logger
import json
from io import StringIO

class SmCCNet:
    """
    SmCCNet Class for Graph Generation using Sparse Multiple Canonical Correlation Networks (SmCCNet).

    This class handles the preprocessing of omics data, execution of the SmCCNet R script,
    and retrieval of the resulting adjacency matrix, all using in-memory data structures.
    """

    def __init__(
        self,
        phenotype_df: pd.DataFrame,
        omics_dfs: List[pd.DataFrame],
        data_types: List[str],
        kfold: int = 5,
        summarization: str = "PCA",
        seed: int = 732,
    ):
        """
        Initializes the SmCCNet instance.

        Args:
            phenotype_df (pd.DataFrame): DataFrame containing phenotype data. The first column should be sample IDs.
            omics_dfs (List[pd.DataFrame]): List of DataFrames, each representing an omics dataset. Each DataFrame should have sample IDs as the first column.
            data_types (List[str]): List of omics data types (e.g., ["protein", "metabolite"]).
            kfold (int, optional): Number of folds for cross-validation. Defaults to 5.
            summarization (str, optional): Summarization method. Defaults to "PCA".
            seed (int, optional): Random seed for reproducibility. Defaults to 732.
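
        Example:
            A minimal sketch, assuming ``proteins_df``, ``metabolites_df``, and
            ``phenotype_df`` are pandas DataFrames whose first column holds
            matching sample IDs (the variable names are placeholders)::

                smccnet = SmCCNet(
                    phenotype_df=phenotype_df,
                    omics_dfs=[proteins_df, metabolites_df],
                    data_types=["protein", "metabolite"],
                    kfold=5,
                    summarization="PCA",
                    seed=732,
                )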
        """
        self.phenotype_df = phenotype_df
        self.omics_dfs = omics_dfs
        self.data_types = data_types
        self.kfold = kfold
        self.summarization = summarization
        self.seed = seed

        self.logger = get_logger(__name__)
        self.logger.info("Initialized SmCCNet with the following parameters:")
        self.logger.info(f"K-Fold: {self.kfold}")
        self.logger.info(f"Summarization: {self.summarization}")
        self.logger.info(f"Seed: {self.seed}")

        if len(self.omics_dfs) != len(self.data_types):
            self.logger.error("Number of omics dataframes does not match number of data types.")
            raise ValueError("Number of omics dataframes does not match number of data types.")

    def preprocess_data(self) -> Dict[str, Any]:
        """
        Preprocesses the omics data to ensure alignment and handle missing values.

        Returns:
            Dict[str, Any]: Dictionary containing serialized phenotype and omics data.
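
        Example:
            A sketch of how the serialized payload can be inspected; the key
            names follow the ``'phenotype'`` / ``'omics_<i>'`` scheme used
            below, and each value is a CSV string::

                from io import StringIO

                serialized = smccnet.preprocess_data()
                # Each value round-trips back into a DataFrame.
                omics_1 = pd.read_csv(StringIO(serialized["omics_1"]))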
        """
        self.logger.info("Preprocessing omics data for NaN or infinite values.")
        phenotype_ids = self.phenotype_df.iloc[:, 0]
        self.logger.info(f"Number of samples in phenotype data: {len(phenotype_ids)}")

        valid_samples = pd.Series([True] * len(phenotype_ids), index=self.phenotype_df.index)

        serialized_data = {
            'phenotype': self.phenotype_df.to_csv(index=False)
        }

        for idx, omics_df in enumerate(self.omics_dfs):
            data_type = self.data_types[idx]
            self.logger.info(f"Processing omics DataFrame {idx+1}/{len(self.omics_dfs)}: Data Type = {data_type}")

            omics_ids = omics_df.iloc[:, 0]
            if not omics_ids.equals(phenotype_ids):
                self.logger.warning(f"Sample IDs in omics dataframe {idx+1} do not match phenotype data. Aligning data.")
                # Re-order rows by sample ID to match the phenotype data; indexing by the
                # ID column (rather than the ID Series) avoids a duplicate column on reset_index().
                omics_df = omics_df.set_index(omics_df.columns[0]).loc[phenotype_ids].reset_index()

            if omics_df.isnull().values.any():
                self.logger.warning(f"NaN values detected in omics dataframe {idx+1}. Marking samples with NaNs as invalid.")
                valid_samples &= ~omics_df.isnull().any(axis=1)

            if (omics_df == float('inf')).any().any() or (omics_df == -float('inf')).any().any():
                self.logger.warning(f"Infinite values detected in omics dataframe {idx+1}. Replacing with NaN and marking samples as invalid.")
                # Work on a copy so the caller's DataFrame is not mutated in place.
                omics_df = omics_df.replace([float('inf'), -float('inf')], pd.NA)
                valid_samples &= ~omics_df.isnull().any(axis=1)

            serialized_data[f'omics_{idx+1}'] = omics_df.to_csv(index=False)

        self.logger.info("Preprocessing completed successfully.")
        return serialized_data

    def run_smccnet(self, serialized_data: Dict[str, Any]) -> str:
        """
        Executes the SmCCNet R script by passing serialized data via standard input.

        Args:
            serialized_data (Dict[str, Any]): Dictionary containing serialized phenotype and omics data.

        Returns:
            str: Adjacency matrix serialized as a JSON string, as emitted by the R script.
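
        Example:
            A sketch of how the returned string is consumed; ``run()`` parses it
            with pandas' ``'split'`` orientation, which the R script is expected
            to emit::

                adjacency_json = smccnet.run_smccnet(serialized_data)
                adjacency_df = pd.read_json(StringIO(adjacency_json), orient='split')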
        """
        try:
            self.logger.info("Preparing data for SmCCNet R script.")
            json_data = json.dumps(serialized_data)

            script_dir = os.path.dirname(os.path.abspath(__file__))
            r_script = os.path.join(script_dir, "SmCCNet.R")

            if not os.path.isfile(r_script):
                self.logger.error(f"R script not found: {r_script}")
                raise FileNotFoundError(f"R script not found: {r_script}")

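            # Positional arguments consumed by SmCCNet.R: comma-separated data
            # types, k-fold count, summarization method, and random seed. The
            # serialized JSON payload itself is streamed to the script on stdin.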
            command = [
                "Rscript",
                r_script,
                ','.join(self.data_types),
                str(self.kfold),
                self.summarization,
                str(self.seed)
            ]

            self.logger.debug(f"Executing command: {' '.join(command)}")

            result = subprocess.run(
                command,
                input=json_data,
                text=True,
                capture_output=True,
                check=True
            )

            self.logger.info("SmCCNet R script executed successfully.")
            self.logger.debug(f"SmCCNet Output:\n{result.stdout}")

            if result.stderr:
                self.logger.warning(f"SmCCNet Warnings/Errors:\n{result.stderr}")

            adjacency_json = result.stdout.strip()

            return adjacency_json

        except subprocess.CalledProcessError as e:
            self.logger.error(f"R script execution failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Error during SmCCNet execution: {e}")
            raise

    def run(self) -> pd.DataFrame:
        """
        Executes the entire Sparse Multiple Canonical Correlation Network (SmCCNet) workflow.

        **Steps:**

        1. **Preprocessing Data**: formats and serializes the input omics and phenotype data for SmCCNet analysis.
        2. **Graph Generation**: constructs a global network by generating an adjacency matrix through SmCCNet.
        3. **Postprocessing Results**: deserializes the adjacency matrix (output of SmCCNet) into a Pandas DataFrame.

        **Returns**:

        - **pd.DataFrame**: A DataFrame containing the adjacency matrix, where each entry represents the strength of the correlation between features.

        **Raises**:

        - **ValueError**: If the input data is improperly formatted or missing.
        - **Exception**: For any unforeseen errors encountered during preprocessing, graph generation, or postprocessing.

        **Notes**:

        - SmCCNet is designed for multi-omics data and requires a well-preprocessed and normalized dataset.
        - Ensure that omics and phenotype data are properly aligned to avoid errors in graph construction.

        **Example**:

        .. code-block:: python

            smccnet = SmCCNet(
                phenotype_df=phenotype_data,
                omics_dfs=[omics_data],
                data_types=["protein"],
            )
            adjacency_matrix = smccnet.run()
            print(adjacency_matrix.head())
        """
        try:
            self.logger.info("Starting SmCCNet Graph Generation Workflow.")

            serialized_data = self.preprocess_data()
            adjacency_json = self.run_smccnet(serialized_data)
            adjacency_matrix = pd.read_json(StringIO(adjacency_json), orient='split')

            self.logger.info("Adjacency matrix deserialized successfully.")
            self.logger.info("SmCCNet Graph Generation completed successfully.")
            return adjacency_matrix

        except Exception as e:
            self.logger.error(f"Error in SmCCNet Graph Generation: {e}")
            raise