explainer

Defines an explainer for mG models.

This package defines the means to generate the sub-graph of all nodes that influenced the final label of some query node.

The package contains the following classes:

  • MGExplainer

MGExplainer

MGExplainer(model: MGModel)

Bases: Interpreter

Generates an explanation for the output of an mG model.

Generates the sub-graph of nodes that are responsible for the label of a given node.
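The idea of a "sub-graph of responsible nodes" can be illustrated without libmg. The sketch below is not the library's implementation: it assumes a directed edge list, assumes that a node can only influence the query node if it reaches it within a bounded number of hops, and uses a hypothetical helper `influence_levels` and a hypothetical `INF` sentinel for nodes that are never reached.

```python
from collections import deque

INF = float("inf")  # hypothetical sentinel meaning "never reaches the query node"


def influence_levels(num_nodes, edges, query_node, max_hops):
    """Return, per node, the hop count at which it reaches query_node, or INF."""
    # Reverse adjacency: for each node, the nodes that send information to it.
    predecessors = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        predecessors[dst].append(src)

    levels = [INF] * num_nodes
    levels[query_node] = 0
    queue = deque([query_node])
    while queue:
        node = queue.popleft()
        if levels[node] == max_hops:
            continue  # do not expand beyond the hop budget
        for pred in predecessors[node]:
            if levels[pred] == INF:  # first visit gives the shortest hop count
                levels[pred] = levels[node] + 1
                queue.append(pred)
    return levels


# Toy graph: 0 -> 1 -> 2 <- 3 <- 4, and 2 -> 5.
edges = [(0, 1), (1, 2), (3, 2), (4, 3), (2, 5)]
levels = influence_levels(6, edges, query_node=2, max_hops=1)
subgraph = [node for node, level in enumerate(levels) if level < INF]
print(subgraph)
```

With a one-hop budget only nodes 1 and 3 (plus the query node itself) are kept; node 5 is never included because it receives from node 2 rather than sending to it.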

Attributes:

  • model (MGModel) –

    The model to explain.

  • query_node (int | None) –

    The node of the input graph for which the sub-graph of relevant nodes will be generated.

  • context (Context) –

    The context in which the current expression is being evaluated.

  • compiler (MGCompiler) –

    The compiler for the explainer.

Parameters:

  • model (MGModel) –

    The model to explain.

Source code in libmg/explainer/explainer.py
def __init__(self, model: MGModel):
    """Initializes the instance with the model to explain.

    Args:
        model: The model to explain.
    """
    super().__init__()
    self.model = model
    if model.config is None:
        raise ValueError("Explained model must have a valid config!")
    self.query_node: int | None = None
    self.context = Context()
    self.compiler = MGCompiler(psi_functions={'node': MGExplainer.localize_node, 'or': MGExplainer.or_fun},
                               sigma_functions={'or': MGExplainer.or_agg},
                               phi_functions={'p3': MGExplainer.proj1},
                               config=model.config)

explain

explain(
    query_node: int, inputs: tuple[tf.Tensor, ...], filename: str | None = None, open_browser: bool = True, engine: Literal["pyvis", "cosmo"] = "pyvis"
) -> Graph

Explain the label of a query node by generating the sub-graph of nodes that affected its value.

With the pyvis engine, the explanation is saved as an HTML file in the working directory. With the cosmo (Cosmograph) engine, it is saved as a directory containing an index.html file, also in the working directory.

Parameters:

  • query_node (int) –

    The node for which to generate the explanation.

  • inputs (tuple[Tensor, ...]) –

    The inputs for the model to explain. This is the graph to which the query node belongs.

  • filename (str | None, default: None ) –

    The name of the .html file to save in the working directory. The string graph_ will be prepended to it.

  • open_browser (bool, default: True ) –

    If true, opens the default web browser and loads up the generated .html page.

  • engine (Literal['pyvis', 'cosmo'], default: 'pyvis' ) –

    The visualization engine to use. Options are pyvis for PyVis or cosmo for Cosmograph.

Returns:

  • Graph

    The generated sub-graph.

Source code in libmg/explainer/explainer.py
def explain(self, query_node: int, inputs: tuple[tf.Tensor, ...], filename: str | None = None, open_browser: bool = True,
            engine: Literal["pyvis", "cosmo"] = 'pyvis') -> Graph:
    """Explain the label of a query node by generating the sub-graph of nodes that affected its value.

    With the pyvis engine, the explanation is saved as an HTML file in the working directory. With the cosmo (Cosmograph) engine, it is saved as a
    directory containing an index.html file, also in the working directory.

    Args:
        query_node: The node for which to generate the explanation.
        inputs: The inputs for the model to explain. This is the graph to which the query node belongs.
        filename: The name of the .html file to save in the working directory. The string ``graph_`` will be prepended to it.
        open_browser: If true, opens the default web browser and loads up the generated .html page.
        engine: The visualization engine to use. Options are ``pyvis`` for PyVis or ``cosmo`` for Cosmograph.

    Returns:
        The generated sub-graph.
    """
    if self.model.expr is None:
        raise ValueError("Explained model must have a valid expr!")
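    # Build the explainer model
    self.query_node = query_node
    self.context.clear()
    # Run the original model once so its outputs can be attached to the explanation graph
    actual_outputs = self.model.call(inputs)
    try:
        # Visit a copy of the model expression so the original tree is left untouched
        right_branch = self.visit(deepcopy(self.model.expr))
    except VisitError:
        # Fall back to the expression stored in all_nodes_expr if the visit fails
        right_branch = mg_parser.parse(MGExplainer.all_nodes_expr)
    left_branch = mg_parser.parse('node[' + str(self.query_node) + ']')
    # Parse a placeholder 'left ; right' expression, then swap in the two real branches
    explainer_expr_tree = mg_parser.parse('left ; right')
    explainer_expr_tree.children = [left_branch, right_branch]
    explainer_model = self.compiler.compile(explainer_expr_tree)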

    # Run the model
    hierarchy = tf.squeeze(explainer_model.call(inputs))
    explanation = tf.math.less(hierarchy, MGExplainer.INF)
    graph = make_graph(explanation, hierarchy, inputs, actual_outputs)
    print_graph(graph, id_generator=self._get_original_ids_func(explanation), hierarchical=True,
                show_labels=True, filename=filename, open_browser=open_browser, engine=engine)
    return graph
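
The last lines of explain threshold the per-node hierarchy values against MGExplainer.INF to decide which nodes belong to the explanation. The plain-Python sketch below mirrors that step under stated assumptions: the value of `INF` here is hypothetical (the actual constant is not shown above), and `threshold_hierarchy` and `original_id` are illustrative helpers, with `original_id` only guessing at the purpose of `_get_original_ids_func`, whose implementation is also not shown.

```python
INF = 1e38  # hypothetical sentinel standing in for MGExplainer.INF


def threshold_hierarchy(hierarchy):
    """Return a boolean mask of retained nodes and their original ids."""
    # A node is part of the explanation when its hierarchy value is below INF.
    mask = [value < INF for value in hierarchy]
    # Positions of retained nodes, in order: new id -> original id.
    kept_original_ids = [i for i, keep in enumerate(mask) if keep]
    return mask, kept_original_ids


def original_id(kept_original_ids, new_id):
    """Map a node id in the explanation sub-graph back to the input graph."""
    return kept_original_ids[new_id]


hierarchy = [INF, 1.0, 0.0, 1.0, INF]
mask, kept = threshold_hierarchy(hierarchy)
print(mask)
print(kept)
```

In this toy input, nodes 0 and 4 carry the sentinel value and are dropped, so node 0 of the sub-graph corresponds to node 1 of the input graph.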