"""bioneuralnet.network_embedding.gnn_embedding: Graph Neural Network (GNN) based node embeddings."""

import os
from typing import Dict, Optional
import torch
import torch.nn as nn
import pandas as pd
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from datetime import datetime
from .gnn_models import GCN, GAT, SAGE, GIN
from ..utils.logger import get_logger


class GNNEmbedding:
    """
    GNNEmbedding Class for Generating Graph Neural Network (GNN) Based Embeddings.

    Node features are built by correlating each network node's omics values
    with every clinical variable over the samples shared by ``omics_data`` and
    ``clinical_data``. These features and the graph described by
    ``adjacency_matrix`` are passed through a newly constructed GNN
    (``GCN``, ``GAT``, ``SAGE`` or ``GIN``) in evaluation mode. This class does
    not train the model.

    Args:
        adjacency_matrix: Adjacency matrix whose index (and columns) are the
            network's node names. Must be non-empty.
        omics_data: Table whose index holds samples and whose columns must
            include every node in ``adjacency_matrix.index``. Must be non-empty.
        clinical_data: Table whose index holds samples (matched against
            ``omics_data.index``) and whose columns are clinical variables.
            Must be non-empty.
        model_type: One of ``'GCN'``, ``'GAT'``, ``'SAGE'``, ``'GIN'``.
        gnn_input_dim: Input feature dimension. When ``None`` it is set in
            ``run`` to the number of node-feature columns (clinical variables).
        gnn_hidden_dim: Hidden dimension passed to the GNN constructor.
        gnn_layer_num: Number of layers passed to the GNN constructor.
        dropout: Dropout flag passed to the GNN constructor.
        output_dir: Stored on the instance; ``run`` does not write to it.

    Raises:
        ValueError: If ``adjacency_matrix``, ``omics_data`` or ``clinical_data``
            is empty.
    """

    def __init__(
        self,
        adjacency_matrix: pd.DataFrame,
        omics_data: pd.DataFrame,
        clinical_data: pd.DataFrame,
        model_type: str = 'GCN',
        gnn_input_dim: Optional[int] = None,
        gnn_hidden_dim: int = 64,
        gnn_layer_num: int = 2,
        dropout: bool = True,
        output_dir: Optional[str] = None,
    ):
        if adjacency_matrix.empty:
            raise ValueError("Adjacency matrix cannot be empty.")
        if omics_data.empty:
            raise ValueError("Omics data cannot be empty.")
        if clinical_data.empty:
            raise ValueError("Clinical data cannot be empty.")

        self.adjacency_matrix = adjacency_matrix
        self.omics_data = omics_data
        self.clinical_data = clinical_data
        self.model_type = model_type
        self.gnn_input_dim = gnn_input_dim
        self.gnn_hidden_dim = gnn_hidden_dim
        self.gnn_layer_num = gnn_layer_num
        self.dropout = dropout
        # No directory is created at construction time; _create_output_dir()
        # is available but is not called by this class.
        self.output_dir = output_dir

        self.logger = get_logger(__name__)
        self.logger.info("Initialized GNNEmbedding with direct data inputs.")

    def _create_output_dir(self) -> str:
        """
        Create a timestamped output directory in the current working directory.

        Returns:
            The name of the created (or already existing) directory, of the form
            ``gnn_embedding_output_YYYY-MM-DD_HH-MM-SS``.
        """
        base_dir = "gnn_embedding_output"
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_dir = f"{base_dir}_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)
        return output_dir

    def run(self) -> Dict[str, torch.Tensor]:
        """
        Generate GNN-based embeddings from the adjacency matrix and node features.

        **Steps:**

        1. **Node Feature Preparation**: computes the correlation between each
           network node's omics values and every clinical variable.
        2. **Building PyG Data Object**: converts the adjacency matrix and node
           features into a PyTorch Geometric ``Data`` object.
        3. **Model Inference**: runs the selected GNN (GCN, GAT, SAGE or GIN) in
           evaluation mode without gradient tracking. The model is newly
           constructed here and is not trained.

        The embeddings are returned to the caller; nothing is written to disk.

        **Returns**: Dict[str, torch.Tensor]
            ``{'graph': embeddings}``, where ``embeddings`` is the model's forward
            output (expected to be ``(num_nodes, embedding_dim)``; the exact
            shape depends on the GNN implementation).

        **Raises**:
            - **ValueError**: If nodes of the adjacency matrix are missing from
              ``omics_data``, if ``omics_data`` and ``clinical_data`` share no
              samples, or if ``model_type`` is not supported.

        **Notes**:
            - The adjacency matrix's index must list nodes present as columns of
              ``omics_data``.
            - When ``gnn_input_dim`` is ``None`` it is set (and kept on the
              instance) to the number of clinical variables.

        **Example**:

        .. code-block:: python

            gnn_embedding = GNNEmbedding(adjacency_matrix, omics_data, clinical_data, model_type='GCN')
            embeddings = gnn_embedding.run()
            print(embeddings['graph'].shape)
        """
        self.logger.info("Running GNN Embedding process.")
        node_features = self._prepare_node_features()

        if self.gnn_input_dim is None:
            self.gnn_input_dim = node_features.shape[1]
            self.logger.info(f"GNN input dimension set to {self.gnn_input_dim}")

        data = self._adjacency_to_data(node_features)
        model = self._initialize_gnn_model(
            model_type=self.model_type,
            input_dim=self.gnn_input_dim,
            hidden_dim=self.gnn_hidden_dim,
            layer_num=self.gnn_layer_num,
            dropout=self.dropout,
        )
        embeddings = self._generate_embeddings(model, data)
        return {'graph': embeddings}

    def _prepare_node_features(self) -> pd.DataFrame:
        """
        Build one feature row per network node from omics/clinical correlations.

        For every node in ``adjacency_matrix.index`` and every clinical variable,
        compute ``pandas.Series.corr`` (Pearson by default) between the node's
        omics column and the clinical column over the samples present in both
        tables. Undefined correlations (NaN, e.g. a constant column or fewer than
        two overlapping observations) are replaced by 0.0 so they cannot
        propagate NaNs to neighbouring nodes during message passing.

        Returns:
            DataFrame indexed by node name with one column per clinical variable.

        Raises:
            ValueError: If a network node is missing from ``omics_data`` or the
                two tables share no samples.
        """
        self.logger.info("Preparing node features from omics_data and clinical_data.")
        nodes_in_network = self.adjacency_matrix.index.tolist()

        missing_nodes = set(nodes_in_network) - set(self.omics_data.columns)
        if missing_nodes:
            raise ValueError(f"Nodes missing in omics_data: {missing_nodes}")

        common_samples = self.omics_data.index.intersection(self.clinical_data.index)
        if len(common_samples) == 0:
            raise ValueError("No common samples between omics_data and clinical_data.")

        node_data = self.omics_data.loc[common_samples, nodes_in_network]
        clinical_data_aligned = self.clinical_data.loc[common_samples]

        # Materialize the clinical columns once; items() yields them by position.
        clinical_columns = [series for _, series in clinical_data_aligned.items()]
        correlations = [
            [node_series.corr(clinical_series) for clinical_series in clinical_columns]
            for _, node_series in node_data.items()
        ]
        node_features = pd.DataFrame(
            correlations,
            index=node_data.columns,
            columns=clinical_data_aligned.columns,
        )

        nan_count = int(node_features.isna().sum().sum())
        if nan_count:
            self.logger.warning(
                f"{nan_count} undefined node-clinical correlation(s) (NaN) replaced with 0.0."
            )
            node_features = node_features.fillna(0.0)
        return node_features

    def _adjacency_to_data(self, node_features: pd.DataFrame) -> Data:
        """
        Convert ``adjacency_matrix`` into a PyTorch Geometric ``Data`` object.

        Node names are relabelled to consecutive integers in the graph's own
        node order, and ``data.x`` holds the rows of ``node_features`` in that
        same order, so row ``i`` of ``data.x`` describes node ``i``.

        Args:
            node_features: Feature table indexed by node name.

        Returns:
            A ``Data`` object produced by ``from_networkx`` with ``x`` set to the
            node features as a float tensor.
        """
        self.logger.info("Converting adjacency matrix to PyTorch Geometric Data object.")
        G = nx.from_pandas_adjacency(self.adjacency_matrix)

        # Take the order from the graph itself so the integer labels, the node
        # order seen by from_networkx, and the rows of data.x always agree.
        node_order = list(G.nodes())
        node_mapping = {node_name: idx for idx, node_name in enumerate(node_order)}
        G = nx.relabel_nodes(G, node_mapping)

        data = from_networkx(G)
        data.x = torch.tensor(node_features.loc[node_order].values, dtype=torch.float)
        return data

    def _initialize_gnn_model(
        self,
        model_type: str,
        input_dim: int,
        hidden_dim: int,
        layer_num: int,
        dropout: bool,
    ) -> nn.Module:
        """
        Construct the GNN selected by ``model_type``.

        Args:
            model_type: One of ``'GCN'``, ``'GAT'``, ``'SAGE'``, ``'GIN'``.
            input_dim: Input feature dimension.
            hidden_dim: Hidden dimension.
            layer_num: Number of layers.
            dropout: Dropout flag forwarded to the model constructor.

        Returns:
            A newly constructed (untrained) model instance.

        Raises:
            ValueError: If ``model_type`` is not one of the supported names.
        """
        model_classes = {'GCN': GCN, 'GAT': GAT, 'SAGE': SAGE, 'GIN': GIN}
        if model_type not in model_classes:
            raise ValueError(f"Unsupported GNN model type: {model_type}")
        return model_classes[model_type](input_dim, hidden_dim, layer_num, dropout)

    def _generate_embeddings(self, model: nn.Module, data: Data) -> torch.Tensor:
        """
        Run a forward pass of ``model`` on ``data`` without gradient tracking.

        The model is switched to evaluation mode first; it is not trained here.

        Returns:
            The model's forward output.
        """
        self.logger.info("Generating embeddings using the GNN model.")
        model.eval()
        with torch.no_grad():
            embeddings = model(data)
        return embeddings