Source code for corrGraph.corrGraph

import numbers
from typing import Dict

import networkx as nx
import pandas as pd
import plotly.graph_objects as go
from pyvis.network import Network

class Config:
    """
    Default parameters for the correlation graph visualizations.

    The values are plain class attributes; nothing here is instantiated.
    """

    # Hides axis lines, grid, ticks and axis titles so only the graph shows.
    AXIS_CONFIG = dict(
        showline=False,
        zeroline=False,
        showgrid=False,
        showticklabels=False,
        title='',
    )

    # Keyword arguments handed to the Plotly figure layout.
    LAYOUT_CONFIG = dict(
        # The title font is given through the nested ``title.font`` form
        # rather than the deprecated ``titlefont_size`` shortcut, which newer
        # Plotly releases reject as an invalid layout property.
        title=dict(text='<br>Correlation Graph', font=dict(size=16)),
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=[
            dict(
                text="Visualization of the correlation graph",
                showarrow=False,
                xref="paper",
                yref="paper",
                x=0.005,
                y=-0.002,
            )
        ],
        # Each axis gets its own copy so that mutating one axis configuration
        # cannot silently change the other (or AXIS_CONFIG itself).
        xaxis=dict(AXIS_CONFIG),
        yaxis=dict(AXIS_CONFIG),
        width=800,
        height=800,
    )

    # Fallbacks used when a node/edge carries no weight or when sizing by
    # weight is switched off.
    DEFAULT_NODE_SIZE = 0.5
    DEFAULT_EDGE_WIDTH = 0.01

    # Multipliers applied to weights when they drive node size / edge width.
    NODE_SIZE_SCALE = 20
    EDGE_WIDTH_SCALE = 1

    # Node colors for negative and non-negative node weights.
    NEGATIVE_COLOR = 'orange'
    POSITIVE_COLOR = 'blue'

    # Height and width passed to the PyVis network.
    FIGURE_SIZE = '800px'
class CorrGraph:
    """
    Build and visualize an undirected graph from a correlation matrix.

    Every column of the matrix becomes a node. Two nodes are connected when
    the absolute value of their correlation is at least ``threshold``.
    """

    def __init__(self, corr_matrix: pd.DataFrame, threshold: float = 0.5,
                 use_correlations_as_weights: bool = True) -> None:
        """
        Initialize the CorrGraph object and build the graph immediately.

        :param corr_matrix: A pandas DataFrame representing the correlation matrix.
        :param threshold: Minimum absolute correlation for two features to be
            connected by an edge.
        :param use_correlations_as_weights: If True, each edge stores the
            signed correlation in its ``weight`` attribute. If False, edges
            are created without a ``weight`` attribute.
        """
        self.corr_matrix: pd.DataFrame = corr_matrix
        self.threshold: float = threshold
        self.use_correlations_as_weights: bool = use_correlations_as_weights
        self.graph: nx.Graph = nx.Graph()
        self._create_graph()

    def _create_graph(self) -> None:
        """
        Populate ``self.graph`` from the correlation matrix.

        Only the upper triangle (column index greater than row index) is
        read, so each unordered pair of features is examined once and the
        diagonal is skipped.
        """
        features = list(self.corr_matrix.columns)
        # One bulk conversion instead of a pandas ``.iloc`` lookup per cell.
        values = self.corr_matrix.to_numpy()
        for i, source in enumerate(features):
            self.graph.add_node(source)
            for j in range(i + 1, len(features)):
                corr = values[i, j]
                # Written as a positive test on purpose: a NaN correlation
                # compares False and therefore never creates an edge.
                if abs(corr) >= self.threshold:
                    if self.use_correlations_as_weights:
                        self.graph.add_edge(source, features[j], weight=corr)
                    else:
                        # No weight attribute; the visualizations fall back
                        # to their default edge width for such edges.
                        self.graph.add_edge(source, features[j])

    def get_graph(self) -> nx.Graph:
        """
        Get the created graph.

        :return: The networkx Graph object held by this instance (not a copy).
        """
        return self.graph

    def update_node_weights(self, weights: Dict[str, float]) -> None:
        """
        Update the weights of the nodes in the graph.

        All entries are validated before any node is modified, so a failing
        call leaves the graph unchanged.

        :param weights: A dictionary where keys are node names and values are
            the weights to be assigned.
        :raises ValueError: If a node is not found in the graph.
        :raises TypeError: If a weight is not numeric.
        """
        for node, weight in weights.items():
            if node not in self.graph:
                raise ValueError(f"Feature '{node}' not found in the graph.")
            # numbers.Real covers int and float as well as numeric scalar
            # types that register with the numbers ABCs.
            if not isinstance(weight, numbers.Real):
                raise TypeError(f"Weight for feature '{node}' must be numeric.")
        for node, weight in weights.items():
            self.graph.nodes[node]['weight'] = weight

    def _node_style(self, node, node_weight_is_size: bool):
        """Return ``(size, color, weight_label)`` used to draw ``node``."""
        attrs = self.graph.nodes[node]
        if node_weight_is_size:
            size = abs(attrs.get('weight', Config.DEFAULT_NODE_SIZE)) * Config.NODE_SIZE_SCALE
        else:
            size = Config.DEFAULT_NODE_SIZE
        if attrs.get('weight', 0) < 0:
            color = Config.NEGATIVE_COLOR
        else:
            color = Config.POSITIVE_COLOR
        return size, color, attrs.get("weight", "N/A")

    def _edge_width(self, edge_weight, edge_weight_is_size: bool):
        """Return the non-negative drawing width for an edge weight."""
        if not edge_weight_is_size:
            return Config.DEFAULT_EDGE_WIDTH
        # The magnitude drives the width; a signed (negative) correlation
        # would otherwise yield a negative line width.
        return abs(edge_weight) * Config.EDGE_WIDTH_SCALE

    def visualize_graph_with_plotly(self, node_weight_is_size: bool = True,
                                    edge_weight_is_size: bool = True) -> None:
        """
        Visualize the graph using Plotly.

        Node positions come from ``nx.spring_layout``. Every edge and every
        node is added as its own scatter trace, and the figure is shown.

        :param node_weight_is_size: A boolean to decide if node weights should determine node sizes.
        :param edge_weight_is_size: A boolean to decide if edge weights should determine edge widths.
        """
        pos = nx.spring_layout(self.graph)
        fig = go.Figure(layout=Config.LAYOUT_CONFIG)

        for edge in self.graph.edges():
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]
            edge_weight = self.graph.edges[edge].get('weight', Config.DEFAULT_EDGE_WIDTH)
            edge_width = self._edge_width(edge_weight, edge_weight_is_size)
            fig.add_trace(
                go.Scatter(
                    x=[x0, x1, None],
                    y=[y0, y1, None],
                    line=dict(width=edge_width, color='#888'),
                    hoverinfo='none',
                    mode='lines',
                    text=f'{edge_width if edge_weight_is_size else "none"}',
                )
            )

        for node in self.graph.nodes():
            x, y = pos[node]
            node_size, node_color, weight_label = self._node_style(node, node_weight_is_size)
            fig.add_trace(
                go.Scatter(
                    x=[x],
                    y=[y],
                    mode='markers',
                    text=f'{node}<br>Weight: {weight_label}',
                    textposition="top center",
                    hoverinfo='text',
                    marker=dict(size=node_size, color=node_color, line_width=2),
                )
            )

        # Point-less traces that exist only to provide legend entries.
        # NOTE(review): Config.LAYOUT_CONFIG sets showlegend=False, which
        # presumably keeps these entries hidden -- confirm the intent.
        for color, label in (
            (Config.NEGATIVE_COLOR, 'Negative Weight'),
            (Config.POSITIVE_COLOR, 'Positive Weight'),
        ):
            fig.add_trace(
                go.Scatter(
                    x=[None],
                    y=[None],
                    mode='markers',
                    marker=dict(size=Config.DEFAULT_NODE_SIZE, color=color),
                    legendgroup=label,
                    showlegend=True,
                    name=label,
                )
            )

        fig.show()

    def visualize_graph_with_pyvis(self, node_weight_is_size: bool = True,
                                   edge_weight_is_size: bool = True,
                                   use_as_notebook: bool = True) -> None:
        """
        Visualize the graph using PyVis.

        The network is written to and shown from ``correlation_graph.html``.

        :param node_weight_is_size: A boolean to decide if node weights should determine node sizes.
        :param edge_weight_is_size: A boolean to decide if edge weights should determine edge widths.
        :param use_as_notebook: A boolean to decide if the visualization should be displayed in a Jupyter notebook.
        """
        net = Network(height=Config.FIGURE_SIZE, width=Config.FIGURE_SIZE, notebook=use_as_notebook)

        for node in self.graph.nodes():
            node_size, node_color, weight_label = self._node_style(node, node_weight_is_size)
            net.add_node(
                node,
                label=node,
                size=node_size,
                color=node_color,
                title=f'{node}\nWeight: {weight_label}',
            )

        for edge in self.graph.edges():
            edge_weight = self.graph.edges[edge].get('weight', Config.DEFAULT_EDGE_WIDTH)
            edge_title = f'Weight: {edge_weight}'
            # The sign of the weight is conveyed by color, its magnitude by width.
            edge_color = 'red' if edge_weight < 0 else 'green'
            if edge_weight_is_size:
                net.add_edge(
                    edge[0],
                    edge[1],
                    value=self._edge_width(edge_weight, edge_weight_is_size),
                    title=edge_title,
                    color=edge_color,
                )
            else:
                net.add_edge(edge[0], edge[1], title=edge_title, color=edge_color)

        # NOTE(review): some PyVis releases reportedly need an explicit
        # notebook flag on show() outside Jupyter -- confirm against the
        # PyVis version this project pins.
        net.show('correlation_graph.html')

    def visualize_graph(self, visualization_type: str = 'plotly',
                        node_weight_is_size: bool = True,
                        edge_weight_is_size: bool = True,
                        use_as_notebook: bool = True) -> None:
        """
        Visualize the graph using either Plotly or PyVis based on the visualization_type argument.

        :param visualization_type: A string to decide which visualization library to use ('plotly' or 'pyvis').
        :param node_weight_is_size: A boolean to decide if node weights should determine node sizes.
        :param edge_weight_is_size: A boolean to decide if edge weights should determine edge widths.
        :param use_as_notebook: A boolean to decide if the visualization should be displayed in a Jupyter notebook (only for PyVis).
        :raises ValueError: If visualization_type is neither 'plotly' nor 'pyvis'.
        """
        if visualization_type == 'plotly':
            self.visualize_graph_with_plotly(node_weight_is_size, edge_weight_is_size)
        elif visualization_type == 'pyvis':
            self.visualize_graph_with_pyvis(node_weight_is_size, edge_weight_is_size, use_as_notebook)
        else:
            raise ValueError("Invalid visualization_type. Choose either 'plotly' or 'pyvis'.")

    def get_centrality(self, centrality_type: str = 'degree') -> Dict[str, float]:
        """
        Calculate and return the centrality of each node in the graph based on the specified centrality type.

        :param centrality_type: A string to specify the type of centrality measure ('degree', 'betweenness', 'closeness', 'eigenvector').
        :return: A dictionary where keys are node names and values are their centrality measures.
        :raises ValueError: If an invalid centrality type is provided.
        """
        measures = {
            'degree': nx.degree_centrality,
            'betweenness': nx.betweenness_centrality,
            'closeness': nx.closeness_centrality,
            'eigenvector': nx.eigenvector_centrality,
        }
        try:
            measure = measures[centrality_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which are equally invalid.
            raise ValueError("Invalid centrality_type. Choose from 'degree', 'betweenness', 'closeness', or 'eigenvector'.") from None
        return measure(self.graph)