218 Graph-Based Retrieval-Augmented Generation

218.1 Introduction and Motivation

Retrieval-augmented generation addresses a fundamental limitation of large language models: their knowledge is frozen at training time and encoded implicitly within billions of parameters. When a user poses a question requiring information beyond the training corpus or demanding precise factual recall, standard language models either confabulate plausible-sounding but incorrect responses or acknowledge uncertainty without providing useful guidance. Retrieval-augmented generation circumvents this limitation by augmenting the generation process with explicit retrieval from an external knowledge store, allowing the model to condition its output on retrieved evidence.

The canonical retrieval-augmented generation architecture employs dense vector retrieval over a corpus of text chunks. Documents are segmented into passages, each passage is embedded into a high-dimensional vector space using a learned encoder, and at query time, the system retrieves passages whose embeddings exhibit high similarity to the query embedding. These retrieved passages are concatenated with the query and provided as context to the language model, which then generates a response grounded in the retrieved evidence.

This architecture, while effective for many applications, exhibits systematic weaknesses when confronted with queries requiring multi-hop reasoning, synthesis across disparate sources, or understanding of complex relational structures. Consider a query such as “What are the main research contributions of scientists who collaborated with both Geoffrey Hinton and Yann LeCun?” Answering this query requires identifying collaborators of Hinton, identifying collaborators of LeCun, computing the intersection, and then aggregating research contributions across the resulting set. A flat retrieval system operating over text chunks struggles with such queries because the relevant information is distributed across many documents, no single passage contains the complete answer, and the retrieval model has no mechanism for expressing or enforcing the relational constraints implicit in the query.

Graph-based retrieval-augmented generation addresses these limitations by organizing knowledge into a graph structure that explicitly represents entities and their relationships. Rather than treating a corpus as an undifferentiated collection of text chunks, graph-based approaches extract structured knowledge, represent it as nodes and edges in a graph, and leverage graph algorithms and graph neural networks to perform retrieval that respects relational structure. The graph serves as both an index and a reasoning substrate, enabling the system to traverse multi-hop paths, aggregate information across related entities, and provide the language model with contextualized, structured evidence.

This chapter develops the theory, algorithms, and implementation of graph-based retrieval-augmented generation systems. We begin with the mathematical foundations of graphs and graph neural networks, proceed through knowledge graph construction and entity-relation extraction, develop retrieval algorithms that operate over graph structures, examine mechanisms for integrating graph-derived context into language model generation, and conclude with production considerations including scalability, evaluation, and deployment architecture.

218.2 Mathematical Foundations

218.2.1 Graph Theory Preliminaries

A graph \(G = (V, E)\) consists of a vertex set \(V\) and an edge set \(E \subseteq V \times V\). In the context of knowledge representation, vertices typically correspond to entities such as people, organizations, concepts, or documents, while edges represent relationships between entities. We denote the number of vertices as \(n = |V|\) and the number of edges as \(m = |E|\).

For knowledge graphs, we typically work with directed labeled graphs where edges carry semantic types. Formally, a knowledge graph is a tuple \(G = (V, E, R, \phi)\) where \(R\) is a set of relation types and \(\phi: E \to R\) assigns a relation type to each edge. An edge \((u, v) \in E\) with \(\phi((u,v)) = r\) is often written as a triple \((u, r, v)\), representing the assertion that entity \(u\) stands in relation \(r\) to entity \(v\). Because two entities can be linked by several relations at once, it is often more convenient to define the edge set directly as a set of triples, \(E \subseteq V \times R \times V\).

The adjacency matrix \(A \in \{0,1\}^{n \times n}\) encodes graph structure, with \(A_{ij} = 1\) if and only if \((i, j) \in E\). For weighted graphs, \(A_{ij}\) takes values in \(\mathbb{R}_{\geq 0}\). The degree matrix \(D\) is diagonal with \(D_{ii} = \sum_j A_{ij}\). The graph Laplacian \(L = D - A\) plays a central role in spectral graph theory. For an undirected graph, where \(A\) is symmetric, \(L\) is symmetric positive semidefinite and its eigenvalues encode structural properties of the graph; for example, the multiplicity of the eigenvalue \(0\) equals the number of connected components.

For an undirected graph with no isolated vertices, the normalized Laplacian \(\mathcal{L} = I - D^{-1/2} A D^{-1/2}\) has eigenvalues in \([0, 2]\) and provides a basis for defining spectral graph convolutions. Because \(\mathcal{L}\) is real and symmetric, it admits the eigendecomposition \(\mathcal{L} = U \Lambda U^T\), where the columns of \(U\) are orthonormal eigenvectors and \(\Lambda\) is the diagonal matrix of eigenvalues.
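A small numerical check can make these definitions concrete. The sketch below uses plain Python with nested lists instead of a matrix library, purely to stay dependency-free; the helper names `laplacians` and `quad_form` are illustrative, and the graph is assumed undirected with no isolated vertices. It verifies the identity \(x^T L x = \sum_{(i,j) \in E} (x_i - x_j)^2\), which is why \(L\) is positive semidefinite, and checks that the Rayleigh quotient of \(\mathcal{L}\) falls inside \([0, 2]\).

```python
import math
import random

def laplacians(n, edges):
    """Return (L, L_norm) as nested lists for an undirected, unweighted graph."""
    A = [[0.0] * n for _ in range(n)]
    for i, j in edges:
        A[i][j] = A[j][i] = 1.0                  # undirected: A is symmetric
    deg = [sum(row) for row in A]                # D_ii = sum_j A_ij
    L = [[(deg[i] if i == j else 0.0) - A[i][j] for j in range(n)] for i in range(n)]
    # L_norm = I - D^{-1/2} A D^{-1/2}; requires every degree to be positive
    L_norm = [[(1.0 if i == j else 0.0) - A[i][j] / math.sqrt(deg[i] * deg[j])
               for j in range(n)] for i in range(n)]
    return L, L_norm

def quad_form(M, x):
    """Compute x^T M x."""
    n = len(x)
    return sum(x[i] * M[i][j] * x[j] for i in range(n) for j in range(n))

edges = [(0, 1), (1, 2), (2, 0), (2, 3)]         # a triangle plus one pendant vertex
L, L_norm = laplacians(4, edges)

rng = random.Random(0)
x = [rng.uniform(-1.0, 1.0) for _ in range(4)]

# x^T L x equals the sum of squared differences across edges, hence L is PSD
edge_sum = sum((x[i] - x[j]) ** 2 for i, j in edges)
print(abs(quad_form(L, x) - edge_sum) < 1e-9)    # True

# The Rayleigh quotient of L_norm lies between its extreme eigenvalues, inside [0, 2]
rq = quad_form(L_norm, x) / sum(v * v for v in x)
print(0.0 <= rq <= 2.0)                          # True
```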

218.2.2 Graph Neural Networks

Graph neural networks learn representations of graph-structured data by iteratively aggregating information from local neighborhoods. The fundamental operation is message passing: each vertex updates its representation by aggregating messages from its neighbors and combining the result with its current state.

Let \(h_v^{(l)} \in \mathbb{R}^d\) denote the hidden representation of vertex \(v\) at layer \(l\). The general message passing framework is expressed as:

\[m_v^{(l+1)} = \text{AGGREGATE}\left(\left\{ h_u^{(l)} : u \in \mathcal{N}(v) \right\}\right)\]

\[h_v^{(l+1)} = \text{UPDATE}\left(h_v^{(l)}, m_v^{(l+1)}\right)\]

where \(\mathcal{N}(v)\) denotes the neighborhood of vertex \(v\), AGGREGATE is a permutation-invariant function, and UPDATE combines the aggregated message with the current representation.
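The two equations translate almost line for line into code. The following is a minimal plain-Python sketch under simple assumptions: node features are lists of floats, `neighbors` maps each vertex to \(\mathcal{N}(v)\) and every vertex has at least one neighbor, AGGREGATE is an elementwise mean, and UPDATE averages the old state with the message. All function names are illustrative.

```python
def message_passing_layer(h, neighbors, aggregate, update):
    """One round of message passing over all vertices.

    h:         dict vertex -> feature vector (list of floats), i.e. h_v^(l)
    neighbors: dict vertex -> list of neighboring vertices, i.e. N(v)
    aggregate: permutation-invariant function of a list of vectors
    update:    function (h_v, m_v) -> h_v^(l+1)
    """
    new_h = {}
    for v in h:
        m_v = aggregate([h[u] for u in neighbors[v]])
        new_h[v] = update(h[v], m_v)
    # Every update reads only layer-l states, so vertex order does not matter
    return new_h

def mean_aggregate(vectors):
    """Elementwise mean; reordering the input list leaves the result unchanged."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]

def average_update(h_v, m_v):
    """Blend the current state with the aggregated message."""
    return [0.5 * a + 0.5 * b for a, b in zip(h_v, m_v)]

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1], 1: [0, 2], 2: [1]}          # the path 0 - 1 - 2
print(message_passing_layer(h, neighbors, mean_aggregate, average_update))
# {0: [0.5, 0.5], 1: [0.5, 0.75], 2: [0.5, 1.0]}
```

Swapping in a sum or max for `mean_aggregate`, or a learned function for `average_update`, gives other members of the same family.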

The Graph Convolutional Network (GCN) instantiates this framework with a symmetrically degree-normalized sum over each vertex's neighbors and the vertex itself, followed by a shared linear transformation. The layer-wise propagation rule is:

\[H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)\]

where \(\tilde{A} = A + I\) adds self-loops, \(\tilde{D}\) is the corresponding degree matrix, \(W^{(l)} \in \mathbb{R}^{d_l \times d_{l+1}}\) is a learnable weight matrix, and \(\sigma\) is a nonlinearity such as ReLU.
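For a small graph the propagation rule can be evaluated directly. The plain-Python sketch below uses nested lists in place of matrices, takes \(\sigma\) to be ReLU, and builds \(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\) explicitly; `matmul` and `gcn_layer` are illustrative helpers, and a practical implementation would use sparse matrix products instead.

```python
import math

def matmul(X, Y):
    """Dense matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def gcn_layer(A, H, W):
    """H^(l+1) = ReLU(D~^{-1/2} (A + I) D~^{-1/2} H W)."""
    n = len(A)
    A_tilde = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    d = [sum(row) for row in A_tilde]            # degrees counting the self-loop
    A_hat = [[A_tilde[i][j] / math.sqrt(d[i] * d[j]) for j in range(n)]
             for i in range(n)]
    Z = matmul(matmul(A_hat, H), W)
    return [[max(0.0, z) for z in row] for row in Z]

# Path graph 0 - 1 - 2, one-hot input features, and a 3 x 1 all-ones weight matrix,
# so each output equals a row sum of the normalized adjacency matrix
A = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
H = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
W = [[1.0], [1.0], [1.0]]
out = gcn_layer(A, H, W)
print([round(row[0], 4) for row in out])         # [0.9082, 1.1498, 0.9082]
```

The middle vertex receives a larger value because it aggregates over three entries (two neighbors plus itself), while the normalization keeps the scale bounded regardless of degree.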

This propagation rule can be derived from a first-order approximation to spectral graph convolutions. The spectral convolution of signal \(x\) with filter \(g_\theta\) is defined as:

\[g_\theta \star x = U g_\theta(\Lambda) U^T x\]

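where \(g_\theta(\Lambda)\) is the filter evaluated at the eigenvalues. Approximating \(g_\theta(\Lambda) \approx \sum_{k=0}^{K} \theta_k T_k(\tilde{\Lambda})\) with Chebyshev polynomials \(T_k\) and \(\tilde{\Lambda} = \frac{2}{\lambda_{\max}} \Lambda - I\), truncating at \(K=1\), and setting \(\lambda_{\max} = 2\) gives

\[g_\theta \star x \approx \theta_0 x + \theta_1 (\mathcal{L} - I) x = \theta_0 x - \theta_1 D^{-1/2} A D^{-1/2} x.\]

Constraining \(\theta = \theta_0 = -\theta_1\) leaves \(\theta \left(I + D^{-1/2} A D^{-1/2}\right) x\). The operator \(I + D^{-1/2} A D^{-1/2}\) has eigenvalues in \([0, 2]\), so stacking many such layers can be numerically unstable; the renormalization trick therefore replaces it with \(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\). Generalizing the scalar \(\theta\) to a weight matrix \(W^{(l)}\) acting on all feature channels and applying \(\sigma\) yields the GCN propagation rule.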

The Graph Attention Network (GAT) replaces uniform neighborhood aggregation with learned attention weights:

\[\alpha_{ij} = \frac{\exp\left(\text{LeakyReLU}\left(a^T [W h_i \| W h_j]\right)\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(\text{LeakyReLU}\left(a^T [W h_i \| W h_k]\right)\right)}\]

\[h_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j^{(l)}\right)\]

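where \(\|\) denotes concatenation, \(W\) is a shared linear transformation, \(a\) is a learned attention vector, and the neighborhood \(\mathcal{N}(i)\) includes \(i\) itself. Multi-head attention extends this by concatenating or averaging outputs from multiple attention heads with independent parameters; the original GAT architecture concatenates in hidden layers and averages in the output layer.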
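The attention computation for a single head can be sketched in plain Python. This is an illustrative sketch rather than a reference implementation: `gat_layer` is a hypothetical helper, each neighborhood list is assumed to already contain the vertex itself, the output nonlinearity \(\sigma\) is omitted so the weighted sum stays visible, and the LeakyReLU negative slope of 0.2 follows the original GAT paper. Subtracting the maximum score before exponentiating leaves the softmax unchanged but avoids overflow.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def gat_layer(h, neighbors, W, a):
    """Single-head graph attention (output nonlinearity omitted).

    h:         dict vertex -> input feature list (length d_in)
    neighbors: dict vertex -> list of vertices; should include the vertex itself
    W:         shared weight matrix as a list of d_out rows of length d_in
    a:         attention vector of length 2 * d_out
    """
    Wh = {v: [sum(w * x for w, x in zip(row, h[v])) for row in W] for v in h}
    out, alpha = {}, {}
    for i in h:
        # e_ij = LeakyReLU(a^T [W h_i || W h_j]); list "+" performs the concatenation
        e = {j: leaky_relu(sum(ak * zk for ak, zk in zip(a, Wh[i] + Wh[j])))
             for j in neighbors[i]}
        m = max(e.values())
        z = {j: math.exp(e_ij - m) for j, e_ij in e.items()}
        total = sum(z.values())
        alpha[i] = {j: z_j / total for j, z_j in z.items()}   # softmax over N(i)
        out[i] = [sum(alpha[i][j] * Wh[j][k] for j in neighbors[i])
                  for k in range(len(W))]
    return out, alpha

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [0, 1], 1: [0, 1, 2], 2: [1, 2]}
W = [[1.0, 0.0], [0.0, 1.0]]
a = [0.5, -0.5, 1.0, 0.2]
out, alpha = gat_layer(h, neighbors, W, a)
print({i: round(sum(weights.values()), 6) for i, weights in alpha.items()})
# {0: 1.0, 1: 1.0, 2: 1.0}
```

Unlike the fixed GCN coefficients, these weights depend on the features: with the values above, vertex 1 attends most strongly to vertex 2.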

GraphSAGE introduces sampling-based aggregation for scalability to large graphs:

\[h_v^{(l+1)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(h_v^{(l)}, \text{AGG}\left(\{h_u^{(l)} : u \in \mathcal{S}(v)\}\right)\right)\right)\]

where \(\mathcal{S}(v)\) is a fixed-size sample from \(\mathcal{N}(v)\). Common aggregators include mean, LSTM over a random permutation, and max pooling.
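A mean-aggregator GraphSAGE layer with fixed-size sampling might look like the sketch below. Assumptions to flag: `sage_mean_layer` is an illustrative name, \(\sigma\) is ReLU, every vertex has at least one neighbor, and when a neighborhood is smaller than the sample budget the sketch samples with replacement, which is one common choice rather than the only one. The weight matrix has \(2d\) columns because it multiplies the concatenation of the vertex's own state and the aggregated neighbor state. The published method also normalizes each output to unit length, which is omitted here.

```python
import random

def sage_mean_layer(h, neighbors, W, sample_size, rng):
    """GraphSAGE layer: ReLU(W . CONCAT(h_v, MEAN({h_u : u in S(v)}))).

    h:         dict vertex -> feature list of length d
    neighbors: dict vertex -> non-empty list of neighbors
    W:         list of d_out rows, each of length 2 * d
    """
    new_h = {}
    for v, h_v in h.items():
        nbrs = neighbors[v]
        # S(v): a fixed-size sample, drawn with replacement if N(v) is too small
        if len(nbrs) >= sample_size:
            sampled = rng.sample(nbrs, sample_size)
        else:
            sampled = [rng.choice(nbrs) for _ in range(sample_size)]
        agg = [sum(h[u][k] for u in sampled) / sample_size for k in range(len(h_v))]
        concat = h_v + agg                        # list concatenation
        z = [sum(w * x for w, x in zip(row, concat)) for row in W]
        new_h[v] = [max(0.0, x) for x in z]
    return new_h

# Star graph: vertex 0 has four neighbors, but only two are sampled per layer
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.0, 1.0], 3: [1.0, 1.0], 4: [2.0, 0.0]}
neighbors = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
W = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]  # computes h_v + agg elementwise
print(sage_mean_layer(h, neighbors, W, sample_size=2, rng=random.Random(0)))
```

The cost per vertex is bounded by `sample_size` no matter how large the true neighborhood is, which is the point of sampling on high-degree hub entities.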

218.2.3 Relational Graph Neural Networks

Knowledge graphs contain multiple relation types, requiring architectures that can handle heterogeneous edge labels. The Relational Graph Convolutional Network (R-GCN) extends GCN with relation-specific transformations:

\[h_i^{(l+1)} = \sigma\left(W_0^{(l)} h_i^{(l)} + \sum_{r \in R} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{c_{i,r}} W_r^{(l)} h_j^{(l)}\right)\]

where \(\mathcal{N}_r(i)\) is the set of neighbors connected to \(i\) by relation \(r\), \(W_r^{(l)}\) is a relation-specific weight matrix, and \(c_{i,r}\) is a normalization constant typically set to \(|\mathcal{N}_r(i)|\).

The number of parameters grows linearly with the number of relations, which can be prohibitive for knowledge graphs with thousands of relation types. Two regularization strategies address this challenge. Basis decomposition represents each \(W_r\) as a linear combination of basis matrices: \(W_r^{(l)} = \sum_{b=1}^{B} a_{rb}^{(l)} V_b^{(l)}\) where \(B \ll |R|\). Block-diagonal decomposition constrains each \(W_r\) to be block-diagonal, reducing parameters while maintaining expressiveness.
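The basis decomposition is easy to see in a small sketch. The plain-Python function below is illustrative (`rgcn_layer` and its argument layout are assumptions): each relation's weight matrix is assembled as \(\sum_b a_{rb} V_b\), messages flow from head to tail, \(c_{i,r} = |\mathcal{N}_r(i)|\), and \(\sigma\) is ReLU. In practice inverse relations are commonly added as extra relation types so that information also flows from tail to head.

```python
def rgcn_layer(h, edges, W0, bases, coeffs):
    """One R-GCN layer with basis-decomposed relation weights.

    h:      dict vertex -> feature list of length d_in
    edges:  list of (head, relation, tail) triples; messages flow head -> tail
    W0:     self-connection matrix, d_out rows of length d_in
    bases:  list of B basis matrices V_b, each d_out x d_in
    coeffs: dict relation -> list of B coefficients a_rb
    """
    def matvec(M, x):
        return [sum(m * xi for m, xi in zip(row, x)) for row in M]

    d_out, d_in = len(bases[0]), len(bases[0][0])
    # W_r = sum_b a_rb V_b: only B matrices plus |R| * B scalars are learned
    W = {r: [[sum(c * V[p][q] for c, V in zip(a_r, bases)) for q in range(d_in)]
             for p in range(d_out)]
         for r, a_r in coeffs.items()}

    # c_{i,r} = |N_r(i)|, the number of r-neighbors sending messages to i
    counts = {}
    for _, r, v in edges:
        counts[(v, r)] = counts.get((v, r), 0) + 1

    out = {i: matvec(W0, h[i]) for i in h}       # self-connection term W_0 h_i
    for u, r, v in edges:
        msg = matvec(W[r], h[u])
        out[v] = [o + m / counts[(v, r)] for o, m in zip(out[v], msg)]
    return {i: [max(0.0, x) for x in vec] for i, vec in out.items()}

identity = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edges = [(0, "cites", 2), (1, "cites", 2), (0, "extends", 1)]
coeffs = {"cites": [1.0, 0.0], "extends": [0.0, 1.0]}
print(rgcn_layer(h, edges, identity, [identity, swap], coeffs))
# {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [1.5, 1.5]}
```

With \(B\) bases, each layer learns \(B\) matrices plus \(|R| B\) coefficients instead of \(|R|\) full matrices, so adding a relation type costs only \(B\) new parameters.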

218.2.4 Knowledge Graph Embeddings

Knowledge graph embedding methods learn low-dimensional representations of entities and relations that capture the semantic structure of the graph. These embeddings can serve as initial node features for graph neural networks and support efficient approximate nearest-neighbor retrieval.

TransE models relations as translations in embedding space. Given a triple \((h, r, t)\), TransE enforces \(\mathbf{h} + \mathbf{r} \approx \mathbf{t}\) where \(\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{R}^d\) are embeddings. The scoring function is:

\[f(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|_{p}\]

where \(p \in \{1, 2\}\). Training minimizes a margin-based ranking loss:

\[\mathcal{L} = \sum_{(h,r,t) \in S} \sum_{(h',r,t') \in S'} \max(0, \gamma + f(h',r,t') - f(h,r,t))\]

where \(S\) is the set of positive triples, \(S'\) is a set of corrupted triples with either head or tail replaced, and \(\gamma\) is the margin.
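A minimal sketch of the TransE score and margin loss follows, in plain Python with toy two-dimensional embeddings. The entity names, vectors, and helper names are illustrative assumptions, and each positive triple receives a single corruption. A randomly corrupted triple may occasionally be a true fact; practical training pipelines often filter such cases.

```python
import random

def transe_score(h, r, t, p=2):
    """f(h, r, t) = -||h + r - t||_p."""
    return -sum(abs(hi + ri - ti) ** p for hi, ri, ti in zip(h, r, t)) ** (1.0 / p)

def corrupt(triple, entities, rng):
    """Replace either the head or the tail with a different random entity."""
    h, r, t = triple
    if rng.random() < 0.5:
        return rng.choice([e for e in entities if e != h]), r, t
    return h, r, rng.choice([e for e in entities if e != t])

def margin_loss(positives, ent, rel, entities, gamma=1.0, seed=0):
    """Sum over positives of max(0, gamma + f(h', r, t') - f(h, r, t))."""
    rng = random.Random(seed)
    loss = 0.0
    for h, r, t in positives:
        h2, _, t2 = corrupt((h, r, t), entities, rng)
        pos = transe_score(ent[h], rel[r], ent[t])
        neg = transe_score(ent[h2], rel[r], ent[t2])
        loss += max(0.0, gamma + neg - pos)
    return loss

entities = ["paris", "france", "berlin"]
rel = {"capital_of": [0.0, 1.0]}
positives = [("paris", "capital_of", "france")]

# Embeddings where paris + capital_of lands exactly on france
good = {"paris": [1.0, 0.0], "france": [1.0, 1.0], "berlin": [5.0, 5.0]}
# Embeddings where france sits far from paris + capital_of
bad = {"paris": [1.0, 0.0], "france": [9.0, 9.0], "berlin": [1.0, 1.0]}

print(margin_loss(positives, good, rel, entities))       # 0.0
print(margin_loss(positives, bad, rel, entities) > 0.0)  # True
```

The loss is zero only once every positive triple outscores its corruption by at least \(\gamma\); gradient descent on the embeddings drives the `bad` configuration toward the `good` one.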

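TransE struggles with one-to-many, many-to-one, and many-to-many relations: if \(\mathbf{h} + \mathbf{r} \approx \mathbf{t}\) must hold for several tails sharing the same head and relation, all of those tails are pushed toward the same embedding. It also cannot represent symmetric relations well, since \(\mathbf{h} + \mathbf{r} \approx \mathbf{t}\) and \(\mathbf{t} + \mathbf{r} \approx \mathbf{h}\) together force \(\mathbf{r} \approx \mathbf{0}\). RotatE targets these relation-pattern limitations by modeling relations as rotations in complex space: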

\[f(h, r, t) = -\|\mathbf{h} \circ \mathbf{r} - \mathbf{t}\|\]

where \(\mathbf{h}, \mathbf{t} \in \mathbb{C}^d\), \(\mathbf{r} \in \mathbb{C}^d\) with \(|\mathbf{r}_i| = 1\), and \(\circ\) denotes element-wise product. This formulation can model symmetric, antisymmetric, inverse, and compositional relation patterns.

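DistMult uses a bilinear scoring function \(f(h, r, t) = \mathbf{h}^T \text{diag}(\mathbf{r}) \mathbf{t}\), which is symmetric in \(h\) and \(t\) and therefore assigns \((h, r, t)\) and \((t, r, h)\) the same score. ComplEx extends DistMult to complex space and conjugates the tail, \(f(h, r, t) = \text{Re}(\mathbf{h}^T \text{diag}(\mathbf{r}) \bar{\mathbf{t}})\) where \(\bar{\mathbf{t}}\) is the complex conjugate; the conjugation breaks this symmetry and enables asymmetric relation modeling.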
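The three scoring functions, and the relation patterns discussed above, can be checked directly with Python's built-in complex numbers. The sketch below is illustrative: the function names and the toy one- and two-dimensional embeddings are assumptions, and RotatE relations are parameterized by phases \(\theta_i\) so that \(r_i = e^{i\theta_i}\) automatically has unit modulus.

```python
import cmath
import math

def distmult(h, r, t):
    """h^T diag(r) t with real vectors."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))

def complex_score(h, r, t):
    """Re(h^T diag(r) conj(t)) with complex vectors (ComplEx)."""
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real

def rotate_score(h, phases, t):
    """-||h o r - t|| with r_i = exp(i * theta_i), so |r_i| = 1 (RotatE)."""
    r = [cmath.exp(1j * theta) for theta in phases]
    return -math.sqrt(sum(abs(hi * ri - ti) ** 2 for hi, ri, ti in zip(h, r, t)))

# DistMult cannot tell (h, r, t) from (t, r, h)
a, b, r = [0.5, -1.0], [2.0, 0.3], [1.0, 0.7]
print(math.isclose(distmult(a, r, b), distmult(b, r, a)))        # True

# ComplEx with a purely imaginary relation scores the reversed triple with opposite sign
h, t, r_c = [1 + 0j], [1j], [1j]
print(complex_score(h, r_c, t), complex_score(t, r_c, h))        # 1.0 -1.0

# RotatE: a rotation by pi/2 maps h onto t, and the inverse rotation maps t back
print(abs(rotate_score(h, [math.pi / 2], t)) < 1e-12)            # True
print(abs(rotate_score(t, [-math.pi / 2], h)) < 1e-12)           # True

# RotatE: a phase of pi gives r = -1; rotating twice by pi is the identity,
# so the same relation maps h to -h and -h back to h (a symmetric pattern)
print(abs(rotate_score([1 + 0j], [math.pi], [-1 + 0j])) < 1e-12)  # True
```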

218.3 Knowledge Graph Construction

218.3.1 Entity and Relation Extraction

Constructing a knowledge graph from unstructured text requires extracting entities and the relations between them. The pipeline typically proceeds in stages: named entity recognition identifies entity mentions, entity linking grounds mentions to canonical entities in a knowledge base, and relation extraction identifies semantic relationships between entity pairs.

Named entity recognition is a sequence labeling task. Given a token sequence \(x_1, \ldots, x_n\), the model predicts a label sequence \(y_1, \ldots, y_n\) where labels follow a tagging scheme such as BIO (Beginning, Inside, Outside), with each B and I tag also carrying an entity type, as in B-PER or I-ORG. Many modern NER systems use a transformer encoder followed by a linear classification layer:

\[P(y_i | x) = \text{softmax}(W h_i + b)\]

where \(h_i\) is the contextualized representation of token \(x_i\) from the transformer.
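Downstream, the predicted labels must be converted back into entity spans. A minimal decoder is sketched below, assuming labels are strings such as "B-PER", "I-PER", and "O"; the function name and the rule for repairing a stray I- tag (treating it as the start of a new span) are illustrative choices.

```python
def bio_to_spans(tokens, labels):
    """Convert BIO labels into (text, type, start, end) spans with end exclusive.

    An I- tag that does not continue a span of the same type starts a new span,
    one common convention for repairing invalid label sequences.
    """
    spans = []
    start, etype = None, None
    for i, label in enumerate(labels + ["O"]):   # sentinel "O" closes a trailing span
        tag, _, typ = label.partition("-")
        continues = tag == "I" and typ == etype
        if start is not None and not continues:
            spans.append((" ".join(tokens[start:i]), etype, start, i))
            start, etype = None, None
        if tag == "B" or (tag == "I" and not continues):
            start, etype = i, typ
    return spans

tokens = ["Geoffrey", "Hinton", "joined", "Google", "Brain"]
labels = ["B-PER", "I-PER", "O", "B-ORG", "I-ORG"]
print(bio_to_spans(tokens, labels))
# [('Geoffrey Hinton', 'PER', 0, 2), ('Google Brain', 'ORG', 3, 5)]
```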

Relation extraction determines whether a semantic relation holds between a pair of identified entities. Given a sentence containing entities \(e_1\) and \(e_2\), the task is to classify the relation into one of a predefined set \(R \cup \{\text{NONE}\}\). Transformer-based relation extraction encodes the sentence with special markers indicating entity spans:

\[\text{[CLS]} \ldots \text{[E1]} e_1 \text{[/E1]} \ldots \text{[E2]} e_2 \text{[/E2]} \ldots \text{[SEP]}\]

The concatenation of \([\text{E1}]\) and \([\text{E2}]\) representations is passed through a classification layer.
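The marking step itself is simple list manipulation. The sketch below works on pre-tokenized input with end-exclusive, non-overlapping token offsets; `mark_entities` is a hypothetical helper, and in a real system the marker strings would typically be added to the tokenizer's vocabulary as special tokens. It also returns the positions of the [E1] and [E2] tokens, whose encoder outputs are the representations that get concatenated.

```python
def mark_entities(tokens, e1, e2):
    """Insert [E1] ... [/E1] and [E2] ... [/E2] markers plus [CLS]/[SEP].

    e1, e2: (start, end) token offsets, end exclusive; the spans must not overlap.
    Returns the marked tokens and the indices of the [E1] and [E2] markers.
    """
    opens = {e1[0]: "[E1]", e2[0]: "[E2]"}
    closes = {e1[1]: "[/E1]", e2[1]: "[/E2]"}
    marked, positions = ["[CLS]"], {}
    for i in range(len(tokens) + 1):
        if i in closes:                      # close a span before opening the next one
            marked.append(closes[i])
        if i in opens:
            positions[opens[i]] = len(marked)
            marked.append(opens[i])
        if i < len(tokens):
            marked.append(tokens[i])
    marked.append("[SEP]")
    return marked, positions["[E1]"], positions["[E2]"]

tokens = ["Hinton", "collaborated", "with", "LeCun"]
marked, i1, i2 = mark_entities(tokens, (0, 1), (3, 4))
print(marked)
# ['[CLS]', '[E1]', 'Hinton', '[/E1]', 'collaborated', 'with', '[E2]', 'LeCun', '[/E2]', '[SEP]']
print(i1, i2)   # 1 6
```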

Joint entity and relation extraction avoids error propagation by modeling both tasks simultaneously. The span-based approach enumerates all token spans up to a maximum length, classifies each span as an entity type or non-entity, and classifies each ordered pair of entity spans as a relation type or non-relation.
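The enumeration step can be written down directly, and it makes the cost of the approach visible: the number of candidate spans grows with sentence length times maximum span length, and the number of candidate pairs grows quadratically in the number of entity spans. In the skeleton below, `span_classifier` and `pair_classifier` are hypothetical stand-ins for trained scoring heads, and spans use end-exclusive token offsets.

```python
from itertools import permutations

def enumerate_spans(n_tokens, max_len):
    """All (start, end) spans with 1 <= end - start <= max_len, end exclusive."""
    return [(s, s + k) for s in range(n_tokens) for k in range(1, max_len + 1)
            if s + k <= n_tokens]

def joint_extract(n_tokens, max_len, span_classifier, pair_classifier):
    """Span-based joint extraction skeleton.

    span_classifier(span) -> entity type or None
    pair_classifier(a, b) -> relation type or None, for the ordered pair (a, b)
    """
    entities = {}
    for span in enumerate_spans(n_tokens, max_len):
        label = span_classifier(span)
        if label is not None:
            entities[span] = label
    relations = []
    for a, b in permutations(entities, 2):       # ordered: (a, b) differs from (b, a)
        rel = pair_classifier(a, b)
        if rel is not None:
            relations.append((a, rel, b))
    return entities, relations

print(len(enumerate_spans(4, max_len=2)))        # 7
people = {(0, 1), (3, 4)}                        # toy labels for "Hinton collaborated with LeCun"
entities, relations = joint_extract(
    4, 2,
    span_classifier=lambda s: "PER" if s in people else None,
    pair_classifier=lambda a, b: "collaborated_with" if (a, b) == ((0, 1), (3, 4)) else None,
)
print(relations)    # [((0, 1), 'collaborated_with', (3, 4))]
```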

218.3.2 Large Language Models for Knowledge Extraction

Large language models can perform knowledge extraction through carefully designed prompts. The extraction process formulates entity and relation identification as text generation conditioned on schema specifications.

Code
from typing import Any
from dataclasses import dataclass
import json


@dataclass
class Entity:
    """Represents an extracted entity."""
    name: str
    entity_type: str
    description: str = ""
    source_text: str = ""
    confidence: float = 1.0
    
    def __hash__(self):
        return hash((self.name.lower(), self.entity_type.lower()))
    
    def __eq__(self, other):
        if not isinstance(other, Entity):
            return False
        return (self.name.lower() == other.name.lower() and 
                self.entity_type.lower() == other.entity_type.lower())


@dataclass
class Relation:
    """Represents an extracted relation between entities."""
    head: Entity
    relation_type: str
    tail: Entity
    confidence: float = 1.0
    source_text: str = ""
    
    def __hash__(self):
        return hash((hash(self.head), self.relation_type.lower(), hash(self.tail)))
    
    def __eq__(self, other):
        if not isinstance(other, Relation):
            return False
        return (self.head == other.head and 
                self.relation_type.lower() == other.relation_type.lower() and
                self.tail == other.tail)


@dataclass
class ExtractionSchema:
    """Defines the schema for knowledge extraction."""
    entity_types: list[str]
    relation_types: list[tuple[str, str, str]]  # (head_type, relation, tail_type)
    
    def to_prompt_description(self) -> str:
        entity_desc = "Entity Types:\n" + "\n".join(
            f"  - {et}" for et in self.entity_types
        )
        relation_desc = "Relation Types:\n" + "\n".join(
            f"  - {head} --[{rel}]--> {tail}" 
            for head, rel, tail in self.relation_types
        )
        return f"{entity_desc}\n\n{relation_desc}"


class KnowledgeExtractor:
    """Extracts entities and relations from text using an LLM."""
    
    def __init__(self, llm_client: Any, schema: ExtractionSchema):
        self.llm = llm_client
        self.schema = schema
        self.extraction_prompt = self._build_extraction_prompt()
    
    def _build_extraction_prompt(self) -> str:
        return f"""You are a knowledge extraction system. Extract entities and 
relations from the provided text according to the following schema.

{self.schema.to_prompt_description()}

For each piece of text, output a JSON object with the following structure:
{{
    "entities": [
        {{"name": "...", "type": "...", "description": "..."}}
    ],
    "relations": [
        {{"head": "...", "relation": "...", "tail": "..."}}
    ]
}}

Only extract entities and relations that match the schema. Be precise and 
conservative - only extract information that is explicitly stated or strongly 
implied by the text. Do not infer or speculate."""

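    def extract(self, text: str) -> tuple[list[Entity], list[Relation]]:
        """Extract entities and relations from text."""
        response = self.llm.generate(
            system_prompt=self.extraction_prompt,
            user_prompt=f"Extract knowledge from the following text:\n\n{text}",
            response_format="json"
        )
        
        # LLM output is not guaranteed to be valid JSON; treat failures as empty
        try:
            parsed = json.loads(response)
        except json.JSONDecodeError:
            return [], []
        if not isinstance(parsed, dict):
            return [], []
        
        entities = {}
        for e in parsed.get("entities", []):
            # Skip malformed items instead of failing the whole document
            if not isinstance(e, dict) or not e.get("name") or not e.get("type"):
                continue
            entity = Entity(
                name=e["name"],
                entity_type=e["type"],
                description=e.get("description", ""),
                source_text=text[:500]
            )
            entities[entity.name.lower()] = entity
        
        relations = []
        for r in parsed.get("relations", []):
            if not isinstance(r, dict) or not all(
                r.get(key) for key in ("head", "relation", "tail")
            ):
                continue
            head_name = r["head"].lower()
            tail_name = r["tail"].lower()
            
            # Keep only relations whose endpoints were extracted as entities
            if head_name in entities and tail_name in entities:
                relation = Relation(
                    head=entities[head_name],
                    relation_type=r["relation"],
                    tail=entities[tail_name],
                    source_text=text[:500]
                )
                relations.append(relation)
        
        return list(entities.values()), relations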
    
    def extract_batch(
        self, 
        texts: list[str], 
        deduplicate: bool = True
    ) -> tuple[list[Entity], list[Relation]]:
        """Extract from multiple texts with optional deduplication."""
        all_entities = []
        all_relations = []
        
        for text in texts:
            entities, relations = self.extract(text)
            all_entities.extend(entities)
            all_relations.extend(relations)
        
        if deduplicate:
            all_entities = list(set(all_entities))
            all_relations = list(set(all_relations))
        
        return all_entities, all_relations
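The extraction prompt asks the model to respect the schema, but `extract` never checks that it did, and the `(head_type, relation, tail_type)` signatures stored in `ExtractionSchema.relation_types` go unused. A post-hoc filter is a cheap safeguard. The self-contained sketch below operates on the parsed JSON dictionary before it is turned into `Entity` and `Relation` objects; the function name, the case-insensitive matching, and the choice to silently drop non-conforming items are assumptions.

```python
def filter_to_schema(parsed, entity_types, relation_types):
    """Drop entities and relations that do not conform to the extraction schema.

    parsed:         dict with "entities" and "relations" lists, as in the prompt format
    entity_types:   allowed entity type names
    relation_types: allowed (head_type, relation, tail_type) signatures
    """
    allowed_types = {t.lower() for t in entity_types}
    allowed_sigs = {(h.lower(), r.lower(), t.lower()) for h, r, t in relation_types}

    entities, type_of = [], {}
    for e in parsed.get("entities", []):
        etype = (e.get("type") or "").lower()
        if etype in allowed_types and e.get("name"):
            entities.append(e)
            type_of[e["name"].lower()] = etype

    relations = []
    for rel in parsed.get("relations", []):
        head_type = type_of.get((rel.get("head") or "").lower())
        tail_type = type_of.get((rel.get("tail") or "").lower())
        signature = (head_type, (rel.get("relation") or "").lower(), tail_type)
        # Both endpoints must be kept entities and the typed signature must match
        if head_type and tail_type and signature in allowed_sigs:
            relations.append(rel)
    return {"entities": entities, "relations": relations}

parsed = {
    "entities": [
        {"name": "Alice", "type": "Person"},
        {"name": "Acme Corp", "type": "Organization"},
        {"name": "neural networks", "type": "Concept"},            # type not in schema
    ],
    "relations": [
        {"head": "Alice", "relation": "works_at", "tail": "Acme Corp"},
        {"head": "Acme Corp", "relation": "works_at", "tail": "Alice"},  # wrong direction
    ],
}
clean = filter_to_schema(parsed, ["Person", "Organization"],
                         [("Person", "works_at", "Organization")])
print([e["name"] for e in clean["entities"]])    # ['Alice', 'Acme Corp']
print(len(clean["relations"]))                   # 1
```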

218.3.3 Coreference Resolution and Entity Linking

Raw extractions often contain multiple surface forms referring to the same underlying entity. Coreference resolution groups mentions that refer to the same entity within a document, while entity linking grounds mentions to canonical entities in an external knowledge base.

Code
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from collections import defaultdict


@dataclass
class EntityMention:
    """A mention of an entity in text."""
    text: str
    start_char: int
    end_char: int
    document_id: str
    entity_type: str | None = None
    embedding: np.ndarray | None = None


@dataclass 
class CanonicalEntity:
    """A canonical entity in the knowledge base."""
    entity_id: str
    name: str
    aliases: list[str] = field(default_factory=list)
    entity_type: str = ""
    description: str = ""
    embedding: np.ndarray | None = None


class EntityLinker:
    """Links entity mentions to canonical entities."""
    
    def __init__(
        self, 
        embedding_model: Any,
        similarity_threshold: float = 0.85
    ):
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.entity_index: dict[str, CanonicalEntity] = {}
        self.embedding_matrix: np.ndarray | None = None
        self.entity_ids: list[str] = []
    
    def build_index(self, entities: list[CanonicalEntity]):
        """Build search index from canonical entities."""
        self.entity_index = {e.entity_id: e for e in entities}
        self.entity_ids = list(self.entity_index.keys())
        
        # Compute embeddings for all entities
        texts = []
        for eid in self.entity_ids:
            entity = self.entity_index[eid]
            text = f"{entity.name}. {entity.description}"
            texts.append(text)
        
        embeddings = self.embedding_model.encode(texts)
        self.embedding_matrix = np.array(embeddings)
        
        # Normalize for cosine similarity
        norms = np.linalg.norm(self.embedding_matrix, axis=1, keepdims=True)
        self.embedding_matrix = self.embedding_matrix / (norms + 1e-10)
    
    def link(self, mention: EntityMention) -> tuple[str | None, float]:
        """Link a mention to a canonical entity."""
        if self.embedding_matrix is None:
            raise ValueError("Index not built. Call build_index first.")
        
        # Compute mention embedding (asarray also handles encoders that return lists)
        mention_emb = np.asarray(self.embedding_model.encode([mention.text])[0], dtype=float)
        mention_emb = mention_emb / (np.linalg.norm(mention_emb) + 1e-10)
        
        # Compute similarities
        similarities = self.embedding_matrix @ mention_emb
        best_idx = np.argmax(similarities)
        best_score = similarities[best_idx]
        
        if best_score >= self.similarity_threshold:
            return self.entity_ids[best_idx], float(best_score)
        return None, float(best_score)
    
    def link_batch(
        self, 
        mentions: list[EntityMention]
    ) -> list[tuple[str | None, float]]:
        """Link multiple mentions efficiently."""
        if not mentions:
            return []
        
        texts = [m.text for m in mentions]
        mention_embs = self.embedding_model.encode(texts)
        mention_embs = np.array(mention_embs)
        norms = np.linalg.norm(mention_embs, axis=1, keepdims=True)
        mention_embs = mention_embs / (norms + 1e-10)
        
        similarities = mention_embs @ self.embedding_matrix.T
        best_indices = np.argmax(similarities, axis=1)
        best_scores = similarities[np.arange(len(mentions)), best_indices]
        
        results = []
        for idx, score in zip(best_indices, best_scores):
            if score >= self.similarity_threshold:
                results.append((self.entity_ids[idx], float(score)))
            else:
                results.append((None, float(score)))
        
        return results


class CoreferenceResolver:
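    """Clusters mentions whose surface forms have similar embeddings.
    
    Document boundaries are not enforced: pass a single document's mentions
    for within-document coreference, or mentions from many documents for
    cross-document resolution (as GraphConstructionPipeline does)."""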
    
    def __init__(self, embedding_model: Any, threshold: float = 0.8):
        self.embedding_model = embedding_model
        self.threshold = threshold
    
    def resolve(
        self, 
        mentions: list[EntityMention]
    ) -> list[list[EntityMention]]:
        """Cluster mentions that refer to the same entity."""
        if not mentions:
            return []
        
        # Compute embeddings
        texts = [m.text for m in mentions]
        embeddings = np.array(self.embedding_model.encode(texts))
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norms + 1e-10)
        
        # Compute pairwise similarities
        similarity_matrix = embeddings @ embeddings.T
        
        # Agglomerative clustering with threshold
        n = len(mentions)
        cluster_assignments = list(range(n))
        
        for i in range(n):
            for j in range(i + 1, n):
                if similarity_matrix[i, j] >= self.threshold:
                    # Merge clusters
                    old_cluster = cluster_assignments[j]
                    new_cluster = cluster_assignments[i]
                    for k in range(n):
                        if cluster_assignments[k] == old_cluster:
                            cluster_assignments[k] = new_cluster
        
        # Group by cluster
        clusters = defaultdict(list)
        for idx, cluster_id in enumerate(cluster_assignments):
            clusters[cluster_id].append(mentions[idx])
        
        return list(clusters.values())
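The merge loop in `resolve` relabels every member of a cluster on each merge, adding \(O(n)\) work per merge on top of the \(O(n^2)\) pair scan. A union-find (disjoint-set) structure produces the same transitive, single-linkage clusters with near-constant cost per merge. The sketch below is plain Python, takes a precomputed similarity matrix, and uses an illustrative function name. Single linkage can chain together mentions that are not directly similar (mentions 0 and 2 in the example), so a conservative threshold matters.

```python
def threshold_clusters(similarity, threshold):
    """Single-linkage clustering: i and j share a cluster if a chain of pairs with
    similarity >= threshold connects them. Uses union-find with path halving."""
    n = len(similarity)
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]    # path halving keeps trees shallow
            x = parent[x]
        return x

    for i in range(n):
        for j in range(i + 1, n):
            if similarity[i][j] >= threshold:
                ri, rj = find(i), find(j)
                if ri != rj:
                    parent[rj] = ri          # union the two clusters
    clusters = {}
    for i in range(n):
        clusters.setdefault(find(i), []).append(i)
    return list(clusters.values())

similarity = [
    [1.0, 0.9, 0.2, 0.1],
    [0.9, 1.0, 0.85, 0.1],
    [0.2, 0.85, 1.0, 0.1],
    [0.1, 0.1, 0.1, 1.0],
]
print(threshold_clusters(similarity, 0.8))   # [[0, 1, 2], [3]]
```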

218.3.4 Graph Construction Pipeline

The complete pipeline integrates extraction, coreference resolution, and entity linking to construct a knowledge graph from a document corpus.

Code
from dataclasses import dataclass, field
from typing import Any
import hashlib
from collections import defaultdict


@dataclass
class KnowledgeGraph:
    """A knowledge graph with entities and relations."""
    entities: dict[str, Entity] = field(default_factory=dict)
    relations: list[Relation] = field(default_factory=list)
    entity_to_documents: dict[str, set[str]] = field(
        default_factory=lambda: defaultdict(set)
    )
    
    def add_entity(self, entity: Entity, document_id: str | None = None):
        """Add an entity to the graph."""
        entity_id = self._entity_id(entity)
        if entity_id not in self.entities:
            self.entities[entity_id] = entity
        if document_id:
            self.entity_to_documents[entity_id].add(document_id)
        return entity_id
    
    def add_relation(self, relation: Relation):
        """Add a relation to the graph."""
        if relation not in self.relations:
            self.relations.append(relation)
    
    def _entity_id(self, entity: Entity) -> str:
        """Generate stable ID for entity."""
        key = f"{entity.name.lower()}:{entity.entity_type.lower()}"
        return hashlib.sha256(key.encode()).hexdigest()[:16]
    
    def get_neighbors(
        self, 
        entity_id: str, 
        relation_type: str | None = None
    ) -> list[tuple[str, str, str]]:
        """Get neighboring entities via relations."""
        neighbors = []
        entity = self.entities.get(entity_id)
        if not entity:
            return neighbors
        
        for rel in self.relations:
            if rel.head == entity:
                if relation_type is None or rel.relation_type == relation_type:
                    tail_id = self._entity_id(rel.tail)
                    neighbors.append((tail_id, rel.relation_type, "outgoing"))
            elif rel.tail == entity:
                if relation_type is None or rel.relation_type == relation_type:
                    head_id = self._entity_id(rel.head)
                    neighbors.append((head_id, rel.relation_type, "incoming"))
        
        return neighbors
    
    def to_triples(self) -> list[tuple[str, str, str]]:
        """Convert to list of (head_id, relation, tail_id) triples."""
        triples = []
        for rel in self.relations:
            head_id = self._entity_id(rel.head)
            tail_id = self._entity_id(rel.tail)
            triples.append((head_id, rel.relation_type, tail_id))
        return triples
    
    def subgraph(
        self, 
        entity_ids: set[str], 
        max_hops: int = 1
    ) -> "KnowledgeGraph":
        """Extract subgraph around specified entities."""
        # Breadth-first expansion up to max_hops; only the newest frontier is expanded
        current_ids = set(entity_ids)
        frontier = set(entity_ids)
        for _ in range(max_hops):
            new_ids = set()
            for eid in frontier:
                for neighbor_id, _, _ in self.get_neighbors(eid):
                    if neighbor_id not in current_ids:
                        new_ids.add(neighbor_id)
            if not new_ids:
                break
            current_ids |= new_ids
            frontier = new_ids
        
        # Build subgraph; copy document sets so the two graphs do not share state
        subgraph = KnowledgeGraph()
        for eid in current_ids:
            if eid in self.entities:
                subgraph.entities[eid] = self.entities[eid]
                subgraph.entity_to_documents[eid] = set(
                    self.entity_to_documents.get(eid, set())
                )
        
        for rel in self.relations:
            head_id = self._entity_id(rel.head)
            tail_id = self._entity_id(rel.tail)
            if head_id in current_ids and tail_id in current_ids:
                subgraph.relations.append(rel)
        
        return subgraph


class GraphConstructionPipeline:
    """End-to-end pipeline for knowledge graph construction."""
    
    def __init__(
        self,
        extractor: KnowledgeExtractor,
        linker: EntityLinker | None = None,
        coreference_resolver: CoreferenceResolver | None = None
    ):
        self.extractor = extractor
        self.linker = linker
        self.coreference_resolver = coreference_resolver
    
    def process_documents(
        self, 
        documents: list[dict[str, str]]
    ) -> KnowledgeGraph:
        """
        Process documents to build knowledge graph.
        
        Args:
            documents: List of {"id": ..., "text": ...} dictionaries
        
        Returns:
            Constructed knowledge graph
        """
        graph = KnowledgeGraph()
        all_mentions = []
        mention_to_doc = {}
        
        # Extract from each document
        for doc in documents:
            doc_id = doc["id"]
            text = doc["text"]
            
            entities, relations = self.extractor.extract(text)
            
            for entity in entities:
                entity_id = graph.add_entity(entity, doc_id)
                mention = EntityMention(
                    text=entity.name,
                    start_char=0,
                    end_char=len(entity.name),
                    document_id=doc_id,
                    entity_type=entity.entity_type
                )
                all_mentions.append(mention)
                mention_to_doc[id(mention)] = (entity, doc_id)
            
            for relation in relations:
                graph.add_relation(relation)
        
        # Apply coreference resolution if available
        if self.coreference_resolver:
            clusters = self.coreference_resolver.resolve(all_mentions)
            # Merge entities within clusters
            for cluster in clusters:
                if len(cluster) > 1:
                    self._merge_cluster(graph, cluster, mention_to_doc)
        
        return graph
    
    def _merge_cluster(
        self, 
        graph: KnowledgeGraph, 
        cluster: list[EntityMention],
        mention_to_doc: dict
    ):
        """Merge entities in a coreference cluster into one canonical entity."""
        # Use most frequent mention as canonical
        mention_counts = defaultdict(int)
        for mention in cluster:
            mention_counts[mention.text.lower()] += 1
        
        canonical_text = max(mention_counts, key=mention_counts.get)
        
        # Find the canonical entity
        canonical_entity = None
        for mention in cluster:
            if mention.text.lower() == canonical_text:
                entity, _ = mention_to_doc.get(id(mention), (None, None))
                if entity:
                    canonical_entity = entity
                    break
        
        if not canonical_entity:
            return
        
        canonical_id = graph._entity_id(canonical_entity)
        merged = set()  # entities being folded into the canonical entity
        for mention in cluster:
            entity, _ = mention_to_doc.get(id(mention), (None, None))
            if entity and entity != canonical_entity:
                merged.add(entity)
                old_id = graph._entity_id(entity)
                # Remove old entity
                graph.entities.pop(old_id, None)
                # Move its document mappings to the canonical entity
                if old_id in graph.entity_to_documents:
                    graph.entity_to_documents[canonical_id].update(
                        graph.entity_to_documents.pop(old_id)
                    )
        
        if not merged:
            return
        
        # Update relations to point to the canonical entity, dropping duplicates
        seen = set()
        rewritten = []
        for rel in graph.relations:
            new_rel = Relation(
                head=canonical_entity if rel.head in merged else rel.head,
                relation_type=rel.relation_type,
                tail=canonical_entity if rel.tail in merged else rel.tail,
                confidence=rel.confidence,
                source_text=rel.source_text
            )
            if new_rel not in seen:
                seen.add(new_rel)
                rewritten.append(new_rel)
        graph.relations = rewritten

218.4 Graph-Based Retrieval

218.4.1 Embedding-Based Entity Retrieval

The first stage of graph-based retrieval identifies entities relevant to the query. Dense embedding methods project both query and entities into a shared vector space, enabling efficient similarity search.

Code
import numpy as np
from typing import Any
from dataclasses import dataclass
import heapq


@dataclass
class RetrievedEntity:
    """An entity retrieved for a query."""
    entity_id: str
    entity: Entity
    score: float
    retrieval_method: str


class EntityRetriever:
    """Retrieves entities relevant to a query."""
    
    def __init__(
        self, 
        embedding_model: Any,
        graph: KnowledgeGraph
    ):
        self.embedding_model = embedding_model
        self.graph = graph
        self.entity_embeddings: np.ndarray | None = None
        self.entity_id_list: list[str] = []
        
    def build_index(self):
        """Build embedding index for all entities."""
        self.entity_id_list = list(self.graph.entities.keys())
        
        texts = []
        for eid in self.entity_id_list:
            entity = self.graph.entities[eid]
            text = f"{entity.name}. Type: {entity.entity_type}. {entity.description}"
            texts.append(text)
        
        embeddings = self.embedding_model.encode(texts)
        self.entity_embeddings = np.array(embeddings)
        
        # Normalize
        norms = np.linalg.norm(self.entity_embeddings, axis=1, keepdims=True)
        self.entity_embeddings = self.entity_embeddings / (norms + 1e-10)
    
    def retrieve(
        self, 
        query: str, 
        top_k: int = 10
    ) -> list[RetrievedEntity]:
        """Retrieve top-k entities for query."""
        if self.entity_embeddings is None:
            raise ValueError("Index not built")
        
        query_emb = self.embedding_model.encode([query])[0]
        query_emb = query_emb / (np.linalg.norm(query_emb) + 1e-10)
        
        scores = self.entity_embeddings @ query_emb
        top_indices = np.argsort(scores)[-top_k:][::-1]
        
        results = []
        for idx in top_indices:
            entity_id = self.entity_id_list[idx]
            results.append(RetrievedEntity(
                entity_id=entity_id,
                entity=self.graph.entities[entity_id],
                score=float(scores[idx]),
                retrieval_method="embedding"
            ))
        
        return results


class HybridEntityRetriever:
    """Combines embedding and keyword-based retrieval."""
    
    def __init__(
        self,
        embedding_model: Any,
        graph: KnowledgeGraph,
        bm25_weight: float = 0.3
    ):
        self.embedding_retriever = EntityRetriever(embedding_model, graph)
        self.graph = graph
        self.bm25_weight = bm25_weight
        self.bm25_index = None
        
    def build_index(self):
        """Build both embedding and BM25 indices."""
        self.embedding_retriever.build_index()
        self._build_bm25_index()
    
    def _build_bm25_index(self):
        """Build BM25 index for keyword matching."""
        # BM25 statistics: per-entity term frequencies and lengths, plus document frequencies
        from collections import Counter
        import math
        
        self.doc_freqs = Counter()
        self.term_freqs = {}
        self.doc_lengths = {}
        
        for eid in self.graph.entities:
            entity = self.graph.entities[eid]
            text = f"{entity.name} {entity.entity_type} {entity.description}"
            tokens = text.lower().split()
            
            self.term_freqs[eid] = Counter(tokens)
            self.doc_lengths[eid] = len(tokens)
            
            for token in set(tokens):
                self.doc_freqs[token] += 1
        
        self.avg_doc_length = np.mean(list(self.doc_lengths.values()))
        self.num_docs = len(self.graph.entities)
    
    def _bm25_score(
        self, 
        query_tokens: list[str], 
        entity_id: str,
        k1: float = 1.5,
        b: float = 0.75
    ) -> float:
        """Compute BM25 score for entity."""
        import math
        
        score = 0.0
        tf = self.term_freqs.get(entity_id, Counter())
        doc_len = self.doc_lengths.get(entity_id, 0)
        
        for token in query_tokens:
            if token not in tf:
                continue
            
            freq = tf[token]
            df = self.doc_freqs.get(token, 0)
            
            idf = math.log((self.num_docs - df + 0.5) / (df + 0.5) + 1)
            tf_component = (freq * (k1 + 1)) / (
                freq + k1 * (1 - b + b * doc_len / self.avg_doc_length)
            )
            
            score += idf * tf_component
        
        return score
    
    def retrieve(
        self, 
        query: str, 
        top_k: int = 10
    ) -> list[RetrievedEntity]:
        """Retrieve using hybrid scoring."""
        # Get embedding scores
        embedding_results = self.embedding_retriever.retrieve(query, top_k * 2)
        embedding_scores = {r.entity_id: r.score for r in embedding_results}
        
        # Get BM25 scores
        query_tokens = query.lower().split()
        bm25_scores = {}
        for eid in self.graph.entities:
            bm25_scores[eid] = self._bm25_score(query_tokens, eid)
        
        # Normalize scores
        max_emb = max(embedding_scores.values()) if embedding_scores else 1.0
        max_bm25 = max(bm25_scores.values()) if bm25_scores else 1.0
        
        # Combine scores
        combined_scores = {}
        all_ids = set(embedding_scores.keys()) | set(bm25_scores.keys())
        
        for eid in all_ids:
            emb_score = embedding_scores.get(eid, 0) / max_emb
            bm25_score = bm25_scores.get(eid, 0) / (max_bm25 + 1e-10)
            
            combined_scores[eid] = (
                (1 - self.bm25_weight) * emb_score + 
                self.bm25_weight * bm25_score
            )
        
        # Get top-k
        top_ids = heapq.nlargest(top_k, combined_scores, key=combined_scores.get)
        
        results = []
        for eid in top_ids:
            results.append(RetrievedEntity(
                entity_id=eid,
                entity=self.graph.entities[eid],
                score=combined_scores[eid],
                retrieval_method="hybrid"
            ))
        
        return results
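A short worked example helps show how the two signals in `HybridEntityRetriever` behave. The sketch below re-implements the per-term BM25 formula from `_bm25_score` and the linear fusion from `retrieve` as standalone functions; the names `bm25_term_score` and `fuse` are illustrative and not part of the classes above, and the corpus statistics are made-up numbers chosen for easy arithmetic.

```python
import math


def bm25_term_score(freq, df, num_docs, doc_len, avg_doc_len, k1=1.5, b=0.75):
    """BM25 contribution of one query term to one entity's score.

    Uses the same IDF and length normalization as
    HybridEntityRetriever._bm25_score.
    """
    idf = math.log((num_docs - df + 0.5) / (df + 0.5) + 1)
    length_norm = 1 - b + b * doc_len / avg_doc_len
    return idf * (freq * (k1 + 1)) / (freq + k1 * length_norm)


def fuse(emb_score, bm25_score, max_emb, max_bm25, bm25_weight=0.3):
    """Max-normalize each signal, then take a weighted sum."""
    emb = emb_score / max_emb
    keyword = bm25_score / (max_bm25 + 1e-10)
    return (1 - bm25_weight) * emb + bm25_weight * keyword


# Hypothetical index: 100 entities, each description exactly average length (8 tokens)
rare = bm25_term_score(freq=1, df=1, num_docs=100, doc_len=8, avg_doc_len=8)
common = bm25_term_score(freq=1, df=50, num_docs=100, doc_len=8, avg_doc_len=8)
repeated = bm25_term_score(freq=3, df=1, num_docs=100, doc_len=8, avg_doc_len=8)
print(round(rare, 3), round(common, 3), round(repeated / rare, 3))  # 4.21 0.693 1.667
print(round(fuse(0.8, 2.0, max_emb=0.8, max_bm25=4.0), 3))  # 0.85
```

A term found in one entity out of 100 contributes about six times as much as a term found in half of them, while tripling a term's frequency raises its contribution only by a factor of about 1.67, because `k1` saturates term frequency. Since both signals are divided by their per-query maximum, fused scores are comparable within one query but not across queries.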

218.4.2 Multi-Hop Graph Traversal

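After identifying seed entities, graph traversal expands the context by following edges to discover related entities and the paths that connect them. The traversal strategy must balance coverage against relevance: with a cap of \(n\) neighbors per hop and \(h\) hops, a single seed can reach up to \(n + n^2 + \dots + n^h\) entities, so expansion has to be pruned to keep the subgraph small enough for the model's context while still capturing the entities the query depends on.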

Code
from dataclasses import dataclass, field
from collections import deque
from typing import Callable
import numpy as np


@dataclass
class TraversalPath:
    """A path through the knowledge graph."""
    nodes: list[str]  # Entity IDs
    edges: list[str]  # Relation types
    score: float = 1.0
    
    def __len__(self):
        return len(self.nodes)
    
    @property
    def head(self) -> str:
        return self.nodes[0]
    
    @property
    def tail(self) -> str:
        return self.nodes[-1]


class GraphTraverser:
    """Traverses knowledge graph to expand context."""
    
    def __init__(
        self,
        graph: KnowledgeGraph,
        max_hops: int = 2,
        max_neighbors_per_hop: int = 10,
        relevance_scorer: Callable[[str, str], float] | None = None
    ):
        self.graph = graph
        self.max_hops = max_hops
        self.max_neighbors = max_neighbors_per_hop
        self.relevance_scorer = relevance_scorer or (lambda q, e: 0.0)
    
    def bfs_traverse(
        self,
        seed_entities: list[str],
        query: str
    ) -> tuple[set[str], list[TraversalPath]]:
        """
        Breadth-first traversal from seed entities.
        
        Returns:
            Tuple of (visited entity IDs, paths found)
        """
        visited = set(seed_entities)
        paths = [TraversalPath(nodes=[eid], edges=[]) for eid in seed_entities]
        all_paths = paths.copy()
        
        for hop in range(self.max_hops):
            new_paths = []
            
            for path in paths:
                current_id = path.tail
                neighbors = self.graph.get_neighbors(current_id)
                
                # Score neighbors by relevance
                scored_neighbors = []
                for neighbor_id, relation, direction in neighbors:
                    if neighbor_id not in visited:
                        entity = self.graph.entities.get(neighbor_id)
                        if entity:
                            score = self.relevance_scorer(query, entity.name)
                            scored_neighbors.append(
                                (neighbor_id, relation, direction, score)
                            )
                
                # Take top neighbors
                scored_neighbors.sort(key=lambda x: x[3], reverse=True)
                top_neighbors = scored_neighbors[:self.max_neighbors]
                
                for neighbor_id, relation, direction, score in top_neighbors:
                    if neighbor_id in visited:
                        # Skip duplicates when several edges lead to the same neighbor
                        continue
                    visited.add(neighbor_id)
                    new_path = TraversalPath(
                        nodes=path.nodes + [neighbor_id],
                        edges=path.edges + [relation],
                        score=path.score * (0.5 + 0.5 * score)
                    )
                    new_paths.append(new_path)
                    all_paths.append(new_path)
            
            paths = new_paths
        
        return visited, all_paths
    
    def beam_search_traverse(
        self,
        seed_entities: list[str],
        query: str,
        beam_width: int = 5
    ) -> list[TraversalPath]:
        """
        Beam search traversal prioritizing relevant paths.
        """
        # Initialize beam with seed entities
        beam = [
            TraversalPath(nodes=[eid], edges=[], score=1.0) 
            for eid in seed_entities
        ]
        
        for hop in range(self.max_hops):
            candidates = []
            
            for path in beam:
                current_id = path.tail
                neighbors = self.graph.get_neighbors(current_id)
                
                for neighbor_id, relation, direction in neighbors:
                    if neighbor_id not in path.nodes:  # Avoid cycles
                        entity = self.graph.entities.get(neighbor_id)
                        if entity:
                            score = self.relevance_scorer(query, entity.name)
                            new_path = TraversalPath(
                                nodes=path.nodes + [neighbor_id],
                                edges=path.edges + [relation],
                                score=path.score * (0.5 + 0.5 * score)
                            )
                            candidates.append(new_path)
            
            # Stop if no path can be extended, keeping the previous beam
            if not candidates:
                break
            
            # Select top paths for beam
            candidates.sort(key=lambda p: p.score, reverse=True)
            beam = candidates[:beam_width]
        
        return beam
    
    def relation_constrained_traverse(
        self,
        seed_entities: list[str],
        relation_sequence: list[str]
    ) -> list[TraversalPath]:
        """
        Traverse following a specific sequence of relations.
        """
        current_paths = [
            TraversalPath(nodes=[eid], edges=[]) 
            for eid in seed_entities
        ]
        
        for required_relation in relation_sequence:
            new_paths = []
            
            for path in current_paths:
                current_id = path.tail
                neighbors = self.graph.get_neighbors(
                    current_id, 
                    relation_type=required_relation
                )
                
                for neighbor_id, relation, direction in neighbors:
                    new_path = TraversalPath(
                        nodes=path.nodes + [neighbor_id],
                        edges=path.edges + [relation],
                        score=path.score
                    )
                    new_paths.append(new_path)
            
            current_paths = new_paths
            
            if not current_paths:
                break
        
        return current_paths
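Relation-constrained traversal is what the two-anchor query from the introduction needs: follow one relation from each anchor, intersect the results, then follow another relation from the intersection. The following is a minimal, self-contained sketch over a hypothetical toy graph (entity names and relation labels are invented for illustration). It tracks only frontier sets rather than `TraversalPath` objects, and, like `relation_constrained_traverse`, which discards the `direction` field, it ignores edge direction.

```python
from collections import defaultdict

# Hypothetical toy graph: (head, relation, tail) triples, not real data.
triples = [
    ("alice", "collaborated_with", "carol"),
    ("bob", "collaborated_with", "carol"),
    ("alice", "collaborated_with", "dave"),
    ("bob", "collaborated_with", "erin"),
    ("carol", "authored", "paper_1"),
]

# Index every edge in both directions so relations can be followed either way.
neighbors = defaultdict(list)
for head, relation, tail in triples:
    neighbors[head].append((tail, relation))
    neighbors[tail].append((head, relation))


def follow(seeds, relation_sequence):
    """Return the entities reached from `seeds` by following
    `relation_sequence` hop by hop (end nodes only, not full paths)."""
    frontier = set(seeds)
    for required in relation_sequence:
        frontier = {
            neighbor
            for node in frontier
            for neighbor, relation in neighbors[node]
            if relation == required
        }
        if not frontier:
            break
    return frontier


# "Who collaborated with both alice and bob, and what did they author?"
shared = follow({"alice"}, ["collaborated_with"]) & follow({"bob"}, ["collaborated_with"])
contributions = follow(shared, ["authored"])
print(shared, contributions)  # {'carol'} {'paper_1'}
```

The intersection is the step that flat chunk retrieval cannot express; over traversal frontiers it is a single set operation. Because the sketch does no cycle avoidance, a two-hop `collaborated_with` query from `alice` returns `alice` along with `bob`.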

218.4.3 Graph Neural Network Retrieval

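Graph neural networks can score how relevant a candidate subgraph is to a query: the query embedding is injected into every node representation, message passing propagates it along the subgraph's edges, and the resulting node representations are pooled into a single subgraph-level score. The network's parameters must be learned from examples of relevant and irrelevant subgraphs.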

Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import Data, Batch
from typing import Any
import numpy as np


class QueryConditionedGNN(nn.Module):
    """GNN that conditions node representations on query."""
    
    def __init__(
        self,
        node_dim: int,
        query_dim: int,
        hidden_dim: int,
        num_layers: int = 3,
        num_heads: int = 4,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.query_proj = nn.Linear(query_dim, hidden_dim)
        self.node_proj = nn.Linear(node_dim, hidden_dim)
        
        # GAT layers
        self.gat_layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = hidden_dim if i == 0 else hidden_dim * num_heads
            self.gat_layers.append(
                GATConv(in_dim, hidden_dim, heads=num_heads, dropout=dropout)
            )
        
        # Query-node attention
        self.query_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * num_heads,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
        
        # Final scoring
        self.score_mlp = nn.Sequential(
            nn.Linear(hidden_dim * num_heads * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )
        
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
    
    def forward(
        self,
        node_features: torch.Tensor,  # (num_nodes, node_dim)
        edge_index: torch.Tensor,  # (2, num_edges)
        query_embedding: torch.Tensor,  # (batch_size, query_dim)
        batch: torch.Tensor  # Node to graph assignment
    ) -> torch.Tensor:
        """
        Compute relevance scores for subgraphs.
        
        Returns:
            Tensor of shape (batch_size,) with relevance scores
        """
        # Project inputs
        h = self.node_proj(node_features)  # (num_nodes, hidden_dim)
        q = self.query_proj(query_embedding)  # (batch_size, hidden_dim)
        
        # Expand query to match nodes per graph
        batch_size = query_embedding.size(0)
        q_expanded = q[batch]  # (num_nodes, hidden_dim)
        
        # Condition node features on query
        h = h + q_expanded
        
        # GNN message passing
        for gat in self.gat_layers:
            h = F.elu(gat(h, edge_index))
        
        # Pool to graph level
        graph_emb = global_mean_pool(h, batch)  # (batch_size, hidden_dim * heads)
        
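        # Query-to-node attention. Attending over the single pooled vector
        # would make the softmax trivially equal to 1 (one key), so the query
        # attends over each graph's nodes instead. Nodes are scattered into a
        # padded (batch_size, max_nodes, dim) tensor; this assumes `batch` is
        # sorted, as in PyTorch Geometric mini-batches.
        counts = torch.bincount(batch, minlength=batch_size)
        offsets = torch.cumsum(counts, dim=0) - counts
        positions = torch.arange(batch.size(0), device=h.device) - offsets[batch]
        node_dense = h.new_zeros(batch_size, int(counts.max()), h.size(-1))
        node_dense[batch, positions] = h
        padding_mask = torch.ones(
            batch_size, node_dense.size(1), dtype=torch.bool, device=h.device
        )
        padding_mask[batch, positions] = False  # True marks padding to ignore
        
        # Tile the query to match the GAT output width (hidden_dim * heads)
        q_final = q.repeat(1, self.num_heads).unsqueeze(1)
        
        attn_output, _ = self.query_attention(
            q_final,
            node_dense,
            node_dense,
            key_padding_mask=padding_mask
        )
        attn_output = attn_output.squeeze(1)  # (batch_size, hidden_dim * heads)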
        
        # Compute score
        combined = torch.cat([graph_emb, attn_output], dim=-1)
        scores = self.score_mlp(combined).squeeze(-1)
        
        return scores


class GNNRetriever:
    """Retriever using GNN for subgraph scoring."""
    
    def __init__(
        self,
        gnn_model: QueryConditionedGNN,
        embedding_model: Any,
        graph: KnowledgeGraph,
        device: str = "cuda"
    ):
        self.gnn = gnn_model.to(device)
        self.embedding_model = embedding_model
        self.graph = graph
        self.device = device
        
        self._build_graph_tensors()
    
    def _build_graph_tensors(self):
        """Build tensor representations of the graph."""
        self.entity_id_to_idx = {
            eid: idx for idx, eid in enumerate(self.graph.entities.keys())
        }
        self.idx_to_entity_id = {
            idx: eid for eid, idx in self.entity_id_to_idx.items()
        }
        
        # Build edge index
        edges = []
        for rel in self.graph.relations:
            head_id = self.graph._entity_id(rel.head)
            tail_id = self.graph._entity_id(rel.tail)
            
            if head_id in self.entity_id_to_idx and tail_id in self.entity_id_to_idx:
                head_idx = self.entity_id_to_idx[head_id]
                tail_idx = self.entity_id_to_idx[tail_id]
                edges.append([head_idx, tail_idx])
        
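        if edges:
            self.edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            # Keep the (2, 0) shape expected downstream when there are no relations
            self.edge_index = torch.empty((2, 0), dtype=torch.long)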
        
        # Build node features (entity embeddings)
        texts = []
        for eid in self.entity_id_to_idx:
            entity = self.graph.entities[eid]
            text = f"{entity.name}. {entity.description}"
            texts.append(text)
        
        embeddings = self.embedding_model.encode(texts)
        self.node_features = torch.tensor(embeddings, dtype=torch.float32)
    
    def retrieve_subgraph(
        self,
        query: str,
        seed_entities: list[str],
        max_hops: int = 2,
        top_k: int = 5
    ) -> list[tuple[set[str], float]]:
        """
        Retrieve and score subgraphs around seed entities.
        
        Returns:
            List of (entity_id_set, score) tuples
        """
        # Generate candidate subgraphs
        traverser = GraphTraverser(
            self.graph,
            max_hops=max_hops,
            max_neighbors_per_hop=10
        )
        
        subgraph_candidates = []
        for seed_id in seed_entities:
            visited, _ = traverser.bfs_traverse([seed_id], query)
            subgraph_candidates.append(visited)
        
        # Also score the union of all seed neighborhoods, reusing the
        # per-seed traversals instead of running them a second time
        if len(seed_entities) > 1:
            subgraph_candidates.append(set().union(*subgraph_candidates))
        
        # Score subgraphs
        query_embedding = torch.tensor(
            self.embedding_model.encode([query]),
            dtype=torch.float32
        ).to(self.device)
        
        scored_subgraphs = []
        
        self.gnn.eval()
        with torch.no_grad():
            for entity_ids in subgraph_candidates:
                # Build subgraph data
                node_indices = [
                    self.entity_id_to_idx[eid] 
                    for eid in entity_ids 
                    if eid in self.entity_id_to_idx
                ]
                
                if not node_indices:
                    continue
                
                # Remap indices for subgraph
                idx_map = {old: new for new, old in enumerate(node_indices)}
                
                subgraph_features = self.node_features[node_indices].to(self.device)
                
                # Build subgraph edges
                subgraph_edges = []
                for i in range(self.edge_index.size(1)):
                    src, dst = self.edge_index[:, i].tolist()
                    if src in idx_map and dst in idx_map:
                        subgraph_edges.append([idx_map[src], idx_map[dst]])
                
                if subgraph_edges:
                    subgraph_edge_index = torch.tensor(
                        subgraph_edges, dtype=torch.long
                    ).t().contiguous().to(self.device)
                else:
                    subgraph_edge_index = torch.empty(
                        (2, 0), dtype=torch.long
                    ).to(self.device)
                
                batch = torch.zeros(
                    len(node_indices), dtype=torch.long
                ).to(self.device)
                
                score = self.gnn(
                    subgraph_features,
                    subgraph_edge_index,
                    query_embedding,
                    batch
                )
                
                scored_subgraphs.append((entity_ids, float(score.cpu())))
        
        # Return top-k
        scored_subgraphs.sort(key=lambda x: x[1], reverse=True)
        return scored_subgraphs[:top_k]
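To make the data flow of `QueryConditionedGNN` concrete without PyTorch, the following NumPy sketch keeps its three stages (add the query to every node, pass messages along edges, mean-pool into one vector per subgraph) but replaces the learned GAT layers and scoring MLP with unweighted mean aggregation and a dot product with the query. It illustrates the mechanics under those simplifying assumptions and is not equivalent to the model; the function names are hypothetical. A real `QueryConditionedGNN` starts from random weights, so it must be trained on labeled query-subgraph relevance data before `GNNRetriever` scores are meaningful.

```python
import numpy as np


def mean_aggregate(h, edge_index):
    """One round of mean-aggregation message passing with self-loops.

    edge_index has shape (2, E) with rows (source, target), following the
    PyTorch Geometric convention; messages flow from source to target.
    """
    num_nodes = h.shape[0]
    self_loops = np.arange(num_nodes)
    src = np.concatenate([edge_index[0], self_loops])
    dst = np.concatenate([edge_index[1], self_loops])
    out = np.zeros(h.shape, dtype=float)
    np.add.at(out, dst, h[src])  # sum incoming messages per target node
    counts = np.bincount(dst, minlength=num_nodes)[:, None]
    return out / counts  # every node gets at least its self-loop message


def score_subgraph(node_feats, edge_index, query, num_layers=2):
    """Query-conditioned subgraph score with fixed (untrained) weights."""
    h = node_feats + query  # additive conditioning, like h + q_expanded
    for _ in range(num_layers):
        h = np.maximum(mean_aggregate(h, edge_index), 0.0)  # ReLU
    graph_emb = h.mean(axis=0)  # mean pooling, like global_mean_pool
    return float(graph_emb @ query)


query = np.array([1.0, 0.0])
edges = np.array([[0], [1]])  # a single edge 0 -> 1
on_topic = np.array([[1.0, 0.0], [1.0, 0.0]])
off_topic = np.array([[0.0, 1.0], [0.0, 1.0]])
print(score_subgraph(on_topic, edges, query))   # 2.0
print(score_subgraph(off_topic, edges, query))  # 1.0
```

The subgraph whose node features align with the query scores higher. In the learned model, GAT attention weights replace the uniform mean and the MLP replaces the dot product, which lets training decide which neighbors and feature directions matter.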

218.4.4 Community Detection for Document Clustering

For large knowledge graphs, community detection algorithms identify densely connected clusters of entities that can serve as coherent retrieval units. The Louvain algorithm greedily maximizes modularity, which measures how much denser the connections within communities are than would be expected in a random graph with the same node degrees:

\[Q = \frac{1}{2m} \sum_{ij} \left[ A_{ij} - \frac{k_i k_j}{2m} \right] \delta(c_i, c_j)\]

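where \(A_{ij}\) is the weight of the edge between nodes \(i\) and \(j\) (1 or 0 in an unweighted graph), \(k_i = \sum_j A_{ij}\) is the degree of node \(i\), \(c_i\) is the community assignment of node \(i\), \(m\) is the total number of edges (so that \(\sum_i k_i = 2m\)), and \(\delta\) is the Kronecker delta. The implementation below also accepts a resolution parameter \(\gamma\) that multiplies the null-model term \(k_i k_j / 2m\); values above 1 favor smaller communities, and the formula above corresponds to \(\gamma = 1\).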
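The double sum collapses to a per-community form that is easier to check by hand: each edge inside community \(c\) appears twice in the symmetric sum (as \(A_{ij}\) and \(A_{ji}\)), contributing \(L_c/m\) in total, and the null-model term factorizes into \((d_c/2m)^2\), where \(L_c\) is the number of internal edges and \(d_c\) is the summed degree of the community. The sketch below is a hypothetical helper that evaluates this form, assuming an unweighted graph without self-loops; `resolution` multiplies the null-model term in the same way as the `resolution` parameter of `LouvainCommunityDetector`.

```python
from collections import defaultdict


def modularity(edges, community_of, resolution=1.0):
    """Modularity of a partition of an undirected, unweighted graph without self-loops.

    Uses the per-community form Q = sum_c [L_c / m - resolution * (d_c / 2m)^2].
    """
    m = len(edges)
    internal_edges = defaultdict(int)  # L_c
    degree_sum = defaultdict(int)      # d_c
    for u, v in edges:
        degree_sum[community_of[u]] += 1
        degree_sum[community_of[v]] += 1
        if community_of[u] == community_of[v]:
            internal_edges[community_of[u]] += 1
    return sum(
        internal_edges[c] / m - resolution * (degree_sum[c] / (2 * m)) ** 2
        for c in degree_sum
    )


# Two triangles joined by the bridge edge (2, 3)
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
split = {0: "A", 1: "A", 2: "A", 3: "B", 4: "B", 5: "B"}
merged = {node: "A" for node in range(6)}
print(round(modularity(edges, split), 4))   # 0.3571  (= 5/14)
print(round(modularity(edges, merged), 4))  # 0.0
```

Splitting at the bridge gives \(Q = 2(3/7 - 1/4) = 5/14\), while a single community containing every node always gives \(Q = 0\) at \(\gamma = 1\), since then \(L_c = m\) and \(d_c = 2m\).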

Code
from collections import defaultdict
import random


class LouvainCommunityDetector:
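    """Detects communities with the local-moving phase of the Louvain algorithm.
    
    Full Louvain also aggregates each community into a single node and repeats
    the local-moving phase on the coarsened graph; that aggregation phase is
    omitted here, so the result is a single-level partition.
    """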
    
    def __init__(self, graph: KnowledgeGraph, resolution: float = 1.0):
        self.graph = graph
        self.resolution = resolution
        self._build_adjacency()
    
    def _build_adjacency(self):
        """Build adjacency structure for modularity computation."""
        self.adjacency = defaultdict(lambda: defaultdict(float))
        self.degrees = defaultdict(float)
        self.total_weight = 0.0
        
        entity_ids = list(self.graph.entities.keys())
        
        for rel in self.graph.relations:
            head_id = self.graph._entity_id(rel.head)
            tail_id = self.graph._entity_id(rel.tail)
            
            if head_id in self.graph.entities and tail_id in self.graph.entities:
                self.adjacency[head_id][tail_id] += 1.0
                self.adjacency[tail_id][head_id] += 1.0
                self.degrees[head_id] += 1.0
                self.degrees[tail_id] += 1.0
                self.total_weight += 2.0
    
    def _modularity_gain(
        self,
        node: str,
        community: set[str],
        node_to_community: dict[str, int]
    ) -> float:
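        """Return half the modularity gain of moving an isolated node into community.
        
        With total_weight = 2m, this is k_i,in / 2m - resolution * k_i * sigma_tot / (2m)^2,
        i.e. the standard Louvain gain divided by 2; the constant factor does not
        change which move is best or whether a gain is positive.
        """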
        if self.total_weight == 0:
            return 0.0
        
        ki = self.degrees[node]
        ki_in = sum(
            self.adjacency[node].get(neighbor, 0)
            for neighbor in community
        )
        sigma_tot = sum(self.degrees[n] for n in community)
        
        gain = (
            ki_in / self.total_weight -
            self.resolution * ki * sigma_tot / (self.total_weight ** 2)
        )
        
        return gain
    
    def detect(self) -> dict[str, int]:
        """
        Detect communities using Louvain algorithm.
        
        Returns:
            Dictionary mapping entity IDs to community IDs
        """
        nodes = list(self.graph.entities.keys())
        
        # Initialize: each node in its own community
        node_to_community = {node: i for i, node in enumerate(nodes)}
        communities = {i: {node} for i, node in enumerate(nodes)}
        
        improved = True
        iteration = 0
        max_iterations = 100
        
        while improved and iteration < max_iterations:
            improved = False
            iteration += 1
            
            # Shuffle for randomization
            random.shuffle(nodes)
            
            for node in nodes:
                current_community_id = node_to_community[node]
                current_community = communities[current_community_id]
                
                # Remove node from current community
                current_community.discard(node)
                
                best_gain = 0.0
                best_community_id = current_community_id
                
                # Try neighboring communities
                neighbor_communities = set()
                for neighbor in self.adjacency[node]:
                    neighbor_communities.add(node_to_community[neighbor])
                
                for community_id in neighbor_communities:
                    community = communities[community_id]
                    gain = self._modularity_gain(
                        node, community, node_to_community
                    )
                    
                    if gain > best_gain:
                        best_gain = gain
                        best_community_id = community_id
                
                # Move to best community
                if best_community_id != current_community_id:
                    improved = True
                
                communities[best_community_id].add(node)
                node_to_community[node] = best_community_id
                
                # Clean up empty community
                if not current_community and current_community_id != best_community_id:
                    del communities[current_community_id]
        
        # Renumber communities
        community_ids = sorted(communities.keys())
        id_mapping = {old: new for new, old in enumerate(community_ids)}
        
        return {
            node: id_mapping[comm_id]
            for node, comm_id in node_to_community.items()
        }


class CommunitySummarizer:
    """Generates summaries for entity communities."""
    
    def __init__(self, llm_client: Any, graph: KnowledgeGraph):
        self.llm = llm_client
        self.graph = graph
    
    def summarize_community(
        self,
        entity_ids: set[str],
        max_entities: int = 20
    ) -> str:
        """Generate a summary for a community of entities."""
        entities = [
            self.graph.entities[eid]
            for eid in list(entity_ids)[:max_entities]
            if eid in self.graph.entities
        ]
        
        if not entities:
            return ""
        
        entity_descriptions = "\n".join(
            f"- {e.name} ({e.entity_type}): {e.description}"
            for e in entities
        )
        
        # Get internal relations
        internal_relations = []
        for rel in self.graph.relations:
            head_id = self.graph._entity_id(rel.head)
            tail_id = self.graph._entity_id(rel.tail)
            if head_id in entity_ids and tail_id in entity_ids:
                internal_relations.append(
                    f"{rel.head.name} --[{rel.relation_type}]--> {rel.tail.name}"
                )
        
        relations_text = "\n".join(internal_relations[:30])
        
        prompt = f"""Summarize the following group of related entities into a 
coherent paragraph that captures the main themes, relationships, and significance.

Entities:
{entity_descriptions}

Key Relationships:
{relations_text}

Write a concise summary (2-3 sentences) that would help someone understand 
what this group of entities represents and how they relate to each other."""

        summary = self.llm.generate(user_prompt=prompt)
        return summary
    
    def summarize_all_communities(
        self,
        communities: dict[str, int]
    ) -> dict[int, str]:
        """Generate summaries for all communities."""
        # Group entities by community
        community_entities = defaultdict(set)
        for entity_id, community_id in communities.items():
            community_entities[community_id].add(entity_id)
        
        summaries = {}
        for community_id, entity_ids in community_entities.items():
            if len(entity_ids) >= 2:  # Only summarize non-trivial communities
                summaries[community_id] = self.summarize_community(entity_ids)
        
        return summaries

218.5 Context Assembly and Generation

218.5.1 Hierarchical Context Construction

Retrieved graph elements must be assembled into a coherent context for the language model. Hierarchical construction organizes context from general to specific, starting with community summaries and proceeding to entity details and relationships.
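`GraphContext.to_text` truncates the concatenated string at a character limit, which can cut the last included item mid-sentence. One alternative that keeps the general-to-specific priority explicit is to spend the budget item by item, in section order, and stop at the first item that does not fit. The sketch below assumes sections arrive already ordered from general to specific and, like `to_text`, uses characters as a stand-in for tokens; `assemble_hierarchical` is a hypothetical helper, not part of the classes in this section.

```python
def assemble_hierarchical(sections, max_chars):
    """Fill a character budget section by section, most general section first.

    sections: list of (title, items) pairs ordered from general to specific.
    Items are kept or dropped whole, and assembly stops at the first item that
    does not fit, so a specific detail never displaces a more general one.
    Only item text counts against the budget (headers and separators do not).
    """
    parts = []
    used = 0
    for title, items in sections:
        kept = []
        for item in items:
            if used + len(item) > max_chars:
                if kept:
                    parts.append(f"## {title}\n" + "\n".join(kept))
                return "\n\n".join(parts)
            kept.append(item)
            used += len(item)
        if kept:
            parts.append(f"## {title}\n" + "\n".join(kept))
    return "\n\n".join(parts)


# Placeholder items; only their lengths matter here
sections = [
    ("Overview", ["aaaa"]),
    ("Key Entities", ["bbb", "cc"]),
    ("Relationships", ["d"]),
]
# Prints the Overview section and only the first Key Entities item
print(assemble_hierarchical(sections, max_chars=8))
```

With an 8-character budget, the relationship `d` is dropped even though it would fit, because the entity item `cc` ahead of it did not. Filling leftover space with later, more specific items is a reasonable alternative when strict priority is not required.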

Code
from dataclasses import dataclass, field
from typing import Any


@dataclass
class GraphContext:
    """Structured context extracted from knowledge graph."""
    community_summaries: list[str] = field(default_factory=list)
    entity_descriptions: list[str] = field(default_factory=list)
    relationships: list[str] = field(default_factory=list)
    paths: list[str] = field(default_factory=list)
    source_documents: list[str] = field(default_factory=list)
    
    def to_text(self, max_tokens: int = 4000) -> str:
        """Convert to text format for LLM context."""
        sections = []
        
        if self.community_summaries:
            sections.append("## Overview\n" + "\n\n".join(self.community_summaries))
        
        if self.entity_descriptions:
            sections.append("## Key Entities\n" + "\n".join(self.entity_descriptions))
        
        if self.relationships:
            sections.append("## Relationships\n" + "\n".join(self.relationships))
        
        if self.paths:
            sections.append("## Reasoning Paths\n" + "\n".join(self.paths))
        
        if self.source_documents:
            sections.append("## Source Documents\n" + "\n---\n".join(self.source_documents))
        
        full_text = "\n\n".join(sections)
        
        # Simple truncation (in practice, use proper tokenization)
        if len(full_text) > max_tokens * 4:  # Rough char to token ratio
            full_text = full_text[:max_tokens * 4] + "\n[Context truncated...]"
        
        return full_text


class ContextAssembler:
    """Assembles context from graph retrieval results."""
    
    def __init__(
        self,
        graph: KnowledgeGraph,
        community_summaries: dict[int, str] | None = None,
        communities: dict[str, int] | None = None
    ):
        self.graph = graph
        self.community_summaries = community_summaries or {}
        self.communities = communities or {}
    
    def assemble(
        self,
        retrieved_entities: list[RetrievedEntity],
        paths: list[TraversalPath],
        include_communities: bool = True,
        max_entities: int = 20,
        max_relationships: int = 30,
        max_paths: int = 10
    ) -> GraphContext:
        """Assemble context from retrieval results."""
        context = GraphContext()
        
        # Collect entity IDs
        entity_ids = set()
        for re in retrieved_entities[:max_entities]:
            entity_ids.add(re.entity_id)
        for path in paths[:max_paths]:
            entity_ids.update(path.nodes)
        
        # Add community summaries
        if include_communities and self.community_summaries:
            relevant_communities = set()
            for eid in entity_ids:
                if eid in self.communities:
                    relevant_communities.add(self.communities[eid])
            
            for comm_id in relevant_communities:
                if comm_id in self.community_summaries:
                    context.community_summaries.append(
                        self.community_summaries[comm_id]
                    )
        
        # Add entity descriptions
        for re in retrieved_entities[:max_entities]:
            entity = re.entity
            desc = f"**{entity.name}** ({entity.entity_type})"
            if entity.description:
                desc += f": {entity.description}"
            context.entity_descriptions.append(desc)
        
        # Add relationships
        relationship_count = 0
        for rel in self.graph.relations:
            head_id = self.graph._entity_id(rel.head)
            tail_id = self.graph._entity_id(rel.tail)
            
            if head_id in entity_ids or tail_id in entity_ids:
                rel_text = f"{rel.head.name} --[{rel.relation_type}]--> {rel.tail.name}"
                context.relationships.append(rel_text)
                relationship_count += 1
                
                if relationship_count >= max_relationships:
                    break
        
        # Add paths as reasoning chains
        for path in paths[:max_paths]:
            if len(path.nodes) > 1:
                path_parts = []
                for i, node_id in enumerate(path.nodes):
                    entity = self.graph.entities.get(node_id)
                    if entity:
                        path_parts.append(entity.name)
                        if i < len(path.edges):
                            path_parts.append(f"--[{path.edges[i]}]-->")
                
                path_text = " ".join(path_parts)
                context.paths.append(path_text)
        
        # Add source-text snippets (first 500 characters) for up to 10
        # entities that are linked to at least one document
        for eid in list(entity_ids)[:10]:
            if eid in self.graph.entity_to_documents:
                entity = self.graph.entities.get(eid)
                if entity and entity.source_text:
                    context.source_documents.append(
                        f"[{entity.name}] {entity.source_text[:500]}"
                    )
        
        return context


class GraphRAGGenerator:
    """Generates responses using graph-retrieved context."""
    
    def __init__(
        self,
        llm_client: Any,
        entity_retriever: HybridEntityRetriever,
        graph: KnowledgeGraph,
        traverser: GraphTraverser,
        context_assembler: ContextAssembler
    ):
        self.llm = llm_client
        self.entity_retriever = entity_retriever
        self.graph = graph
        self.traverser = traverser
        self.context_assembler = context_assembler
    
    def generate(
        self,
        query: str,
        top_k_entities: int = 10,
        max_hops: int = 2,
        include_reasoning: bool = True
    ) -> dict[str, Any]:
        """
        Generate response using graph-based retrieval.
        
        Returns:
            Dictionary with response, context, and metadata
        """
        # Retrieve relevant entities
        retrieved_entities = self.entity_retriever.retrieve(
            query, 
            top_k=top_k_entities
        )
        
        # Traverse graph from retrieved entities
        seed_ids = [re.entity_id for re in retrieved_entities[:5]]
        visited, paths = self.traverser.bfs_traverse(seed_ids, query)
        
        # Assemble context
        context = self.context_assembler.assemble(
            retrieved_entities=retrieved_entities,
            paths=paths,
            max_entities=15,
            max_relationships=25,
            max_paths=8
        )
        
        # Build prompt
        context_text = context.to_text()
        
        system_prompt = """You are a helpful assistant that answers questions 
using the provided knowledge graph context. Base your answers on the information 
in the context. If the context doesn't contain enough information to fully 
answer the question, acknowledge this and provide what you can based on the 
available information.

When reasoning about relationships between entities, trace the connections 
explicitly using the paths provided."""

        user_prompt = f"""Context from Knowledge Graph:

{context_text}

Question: {query}

Please provide a comprehensive answer based on the context above."""

        if include_reasoning:
            user_prompt += """

First, briefly explain your reasoning by tracing relevant entity relationships, 
then provide your final answer."""

        response = self.llm.generate(
            system_prompt=system_prompt,
            user_prompt=user_prompt
        )
        
        return {
            "response": response,
            "context": context,
            "retrieved_entities": [
                {"name": re.entity.name, "score": re.score}
                for re in retrieved_entities
            ],
            "paths_traversed": len(paths),
            "entities_in_context": len(visited)
        }

218.5.2 Query Decomposition for Multi-Hop Questions

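Complex queries often need to be broken into simpler sub-queries whose answers are combined afterwards. Some sub-queries are independent of each other (the collaborators of Hinton, the collaborators of LeCun). Others can only be answered after earlier ones (the contributions of the scientists who appear in both lists). The decomposer asks the language model to make this relational structure explicit. It returns the sub-queries, the dependencies between them, and a strategy for aggregating their results.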

Code
from dataclasses import dataclass
from typing import Any
import json


@dataclass
class DecomposedQuery:
    """A query decomposed into sub-queries."""
    original_query: str
    sub_queries: list[str]
    dependencies: list[list[int]]  # Which sub-queries depend on which
    aggregation_strategy: str  # "union", "intersection", "chain", "synthesis"


class QueryDecomposer:
    """Decomposes complex queries into sub-queries."""
    
    def __init__(self, llm_client: Any):
        self.llm = llm_client
    
    def decompose(self, query: str) -> DecomposedQuery:
        """Decompose a complex query into sub-queries."""
        prompt = f"""Analyze the following question and determine if it requires 
multi-step reasoning. If so, decompose it into simpler sub-questions that can 
be answered independently and then combined.

Question: {query}

Respond with a JSON object:
{{
    "needs_decomposition": true/false,
    "sub_queries": ["sub-question 1", "sub-question 2", ...],
    "dependencies": [[indices of sub-queries that sub-query 0 depends on], ...],
    "aggregation_strategy": "union" | "intersection" | "chain" | "synthesis",
    "reasoning": "brief explanation of decomposition strategy"
}}

Aggregation strategies:
- "union": Combine results from all sub-queries
- "intersection": Find common elements across sub-query results  
- "chain": Results from earlier sub-queries feed into later ones
- "synthesis": Combine insights to form a new conclusion

If the question is simple and doesn't need decomposition, return:
{{"needs_decomposition": false, "sub_queries": ["{query}"], "dependencies": [[]], "aggregation_strategy": "union"}}"""

        response = self.llm.generate(
            user_prompt=prompt,
            response_format="json"
        )
        
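        # Fall back to the undecomposed query if the model returns malformed JSON
        try:
            parsed = json.loads(response)
            sub_queries = list(parsed["sub_queries"])
            dependencies = list(parsed.get("dependencies") or [[] for _ in sub_queries])
            strategy = parsed.get("aggregation_strategy", "union")
        except (json.JSONDecodeError, KeyError, TypeError):
            sub_queries, dependencies, strategy = [query], [[]], "union"
        
        # Every sub-query needs exactly one dependency list
        if not sub_queries or len(dependencies) != len(sub_queries):
            sub_queries, dependencies, strategy = [query], [[]], "union"
        
        return DecomposedQuery(
            original_query=query,
            sub_queries=sub_queries,
            dependencies=dependencies,
            aggregation_strategy=strategy
        )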


class MultiHopGraphRAG:
    """Graph RAG with query decomposition and multi-hop reasoning."""
    
    def __init__(
        self,
        llm_client: Any,
        graph_rag: GraphRAGGenerator,
        decomposer: QueryDecomposer
    ):
        self.llm = llm_client
        self.graph_rag = graph_rag
        self.decomposer = decomposer
    
    def generate(self, query: str) -> dict[str, Any]:
        """Generate response with multi-hop reasoning."""
        # Decompose query
        decomposed = self.decomposer.decompose(query)
        
        if len(decomposed.sub_queries) == 1:
            # Simple query, no decomposition needed
            return self.graph_rag.generate(query)
        
        # Execute sub-queries in dependency order (a simple topological sort)
        num_sub_queries = len(decomposed.sub_queries)
        sub_results: list[dict | None] = [None] * num_sub_queries
        executed: set[int] = set()
        
        while len(executed) < num_sub_queries:
            progressed = False
            for i, sub_query in enumerate(decomposed.sub_queries):
                if i in executed:
                    continue
                
                # Run a sub-query only after all of its dependencies have run
                deps = decomposed.dependencies[i]
                if all(d in executed for d in deps):
                    # Augment sub-query with prior results if chained
                    augmented_query = sub_query
                    if deps and decomposed.aggregation_strategy == "chain":
                        prior_context = "\n".join(
                            f"Prior finding: {sub_results[d]['response'][:500]}"
                            for d in deps
                            if sub_results[d]
                        )
                        augmented_query = f"{prior_context}\n\nCurrent question: {sub_query}"
                    
                    sub_results[i] = self.graph_rag.generate(augmented_query)
                    executed.add(i)
                    progressed = True
            
            # Cyclic or out-of-range dependencies would otherwise loop forever
            if not progressed:
                pending = sorted(set(range(num_sub_queries)) - executed)
                raise ValueError(f"Unsatisfiable dependencies for sub-queries {pending}")
        
        # Aggregate results
        final_response = self._aggregate_results(
            query=query,
            decomposed=decomposed,
            sub_results=sub_results
        )
        
        return {
            "response": final_response,
            "decomposed_query": decomposed,
            "sub_results": sub_results
        }
    
    def _aggregate_results(
        self,
        query: str,
        decomposed: DecomposedQuery,
        sub_results: list[dict]
    ) -> str:
        """Aggregate sub-query results into final response."""
        sub_findings = "\n\n".join(
            f"Sub-question {i+1}: {decomposed.sub_queries[i]}\n"
            f"Finding: {result['response']}"
            for i, result in enumerate(sub_results)
            if result
        )
        
        synthesis_prompt = f"""Original question: {query}

The question was broken down into sub-questions, and here are the findings:

{sub_findings}

Aggregation strategy: {decomposed.aggregation_strategy}

Based on these findings, provide a comprehensive answer to the original question.
Synthesize the information coherently, resolving any conflicts and highlighting
key connections between the sub-findings."""

        return self.llm.generate(user_prompt=synthesis_prompt)
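The scheduling loop in `MultiHopGraphRAG.generate` is a topological sort over the dependency lists the decomposer returns. Python's standard-library `graphlib` module (Python 3.9+) can express the same ordering and rejects malformed or cyclic dependency lists before any sub-query runs. The sketch below is a hypothetical helper, `execution_order`, not part of the classes above. It assumes dependencies are integer indices into the sub-query list, as the decomposition prompt requests.

```python
from graphlib import CycleError, TopologicalSorter


def execution_order(sub_queries: list[str], dependencies: list[list[int]]) -> list[int]:
    """Order sub-query indices so each one runs after everything it depends on.

    Raises ValueError for malformed input (wrong number of dependency lists,
    non-integer, out-of-range, or self references) and for cycles.
    """
    n = len(sub_queries)
    if len(dependencies) != n:
        raise ValueError(f"expected {n} dependency lists, got {len(dependencies)}")

    sorter: TopologicalSorter[int] = TopologicalSorter()
    for i, deps in enumerate(dependencies):
        bad = [d for d in deps if not isinstance(d, int) or not 0 <= d < n or d == i]
        if bad:
            raise ValueError(f"sub-query {i} has invalid dependencies {bad}")
        sorter.add(i, *deps)  # the predecessors of i are the sub-queries it depends on

    try:
        return list(sorter.static_order())
    except CycleError as exc:
        # exc.args[1] holds the nodes that form the cycle
        raise ValueError(f"cyclic sub-query dependencies: {exc.args[1]}") from exc
```

Because every dependency precedes its dependents in the returned order, a chained sub-query can read earlier results without further checks. Sub-queries with no ordering constraint between them, such as the two collaborator lookups, could also run concurrently. `TopologicalSorter.get_ready()` yields them in ready groups for that purpose.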

218.6 Production Architecture

218.6.1 Indexing Pipeline

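A production indexing pipeline stores each source document, extracts entities and relations into a graph store, and embeds entity descriptions into a vector store. Each document is processed independently, so new documents can be added to the indices without rebuilding them. A failure is reported for the affected document rather than aborting the whole batch.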

Code
from dataclasses import dataclass, field
from typing import Any, Iterator
import hashlib
from datetime import datetime
from abc import ABC, abstractmethod
import json


@dataclass
class Document:
    """A document to be indexed."""
    document_id: str
    text: str
    metadata: dict = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)


@dataclass
class IndexingResult:
    """Result of indexing a document."""
    document_id: str
    entities_extracted: int
    relations_extracted: int
    processing_time_ms: float
    success: bool
    error_message: str | None = None


class DocumentStore(ABC):
    """Abstract document storage."""
    
    @abstractmethod
    def store(self, document: Document) -> None:
        pass
    
    @abstractmethod
    def get(self, document_id: str) -> Document | None:
        pass
    
    @abstractmethod
    def list_ids(self) -> list[str]:
        pass


class GraphStore(ABC):
    """Abstract graph storage."""
    
    @abstractmethod
    def add_entity(self, entity: Entity, document_id: str) -> str:
        pass
    
    @abstractmethod
    def add_relation(self, relation: Relation) -> None:
        pass
    
    @abstractmethod
    def get_entity(self, entity_id: str) -> Entity | None:
        pass
    
    @abstractmethod
    def get_relations(
        self, 
        entity_id: str, 
        relation_type: str | None = None
    ) -> list[Relation]:
        pass


class VectorStore(ABC):
    """Abstract vector storage for embeddings."""
    
    @abstractmethod
    def add(
        self, 
        ids: list[str], 
        embeddings: list[list[float]], 
        metadata: list[dict]
    ) -> None:
        pass
    
    @abstractmethod
    def search(
        self, 
        query_embedding: list[float], 
        top_k: int
    ) -> list[tuple[str, float]]:
        pass


class IndexingPipeline:
    """Production indexing pipeline for graph RAG."""
    
    def __init__(
        self,
        document_store: DocumentStore,
        graph_store: GraphStore,
        vector_store: VectorStore,
        extractor: KnowledgeExtractor,
        embedding_model: Any,
        batch_size: int = 10
    ):
        self.document_store = document_store
        self.graph_store = graph_store
        self.vector_store = vector_store
        self.extractor = extractor
        self.embedding_model = embedding_model
        self.batch_size = batch_size
    
    def index_document(self, document: Document) -> IndexingResult:
        """Index a single document."""
        import time
        start_time = time.time()
        
        try:
            # Store document
            self.document_store.store(document)
            
            # Extract entities and relations
            entities, relations = self.extractor.extract(document.text)
            
            # Store entities and compute embeddings
            entity_ids = []
            entity_texts = []
            
            for entity in entities:
                entity_id = self.graph_store.add_entity(entity, document.document_id)
                entity_ids.append(entity_id)
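                # Join only non-empty fields so a missing description is not embedded as "None"
                entity_texts.append(
                    ". ".join(
                        str(part)
                        for part in (entity.name, entity.entity_type, entity.description)
                        if part
                    )
                )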
            
            # Store relations
            for relation in relations:
                self.graph_store.add_relation(relation)
            
            # Compute and store embeddings
            if entity_texts:
                embeddings = self.embedding_model.encode(entity_texts)
                metadata = [
                    {"document_id": document.document_id, "entity_name": e.name}
                    for e in entities
                ]
                self.vector_store.add(entity_ids, embeddings.tolist(), metadata)
            
            processing_time = (time.time() - start_time) * 1000
            
            return IndexingResult(
                document_id=document.document_id,
                entities_extracted=len(entities),
                relations_extracted=len(relations),
                processing_time_ms=processing_time,
                success=True
            )
            
        except Exception as e:
            processing_time = (time.time() - start_time) * 1000
            return IndexingResult(
                document_id=document.document_id,
                entities_extracted=0,
                relations_extracted=0,
                processing_time_ms=processing_time,
                success=False,
                error_message=str(e)
            )
    
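    def index_batch(
        self, 
        documents: list[Document]
    ) -> Iterator[IndexingResult]:
        """Yield one result per document, walking the list in chunks of batch_size.

        Documents inside a chunk are still indexed one at a time.
        """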
        for i in range(0, len(documents), self.batch_size):
            batch = documents[i:i + self.batch_size]
            for doc in batch:
                yield self.index_document(doc)
    
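    def reindex_all(self) -> list[IndexingResult]:
        """Re-run extraction and embedding for every stored document.

        The store interfaces above define no delete operation. Unless
        add_entity, add_relation, and VectorStore.add deduplicate, this adds
        a second copy of entities, relations, and vectors that already exist.
        """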
        document_ids = self.document_store.list_ids()
        results = []
        
        for doc_id in document_ids:
            doc = self.document_store.get(doc_id)
            if doc:
                result = self.index_document(doc)
                results.append(result)
        
        return results
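Re-running `index_document` on a document whose text has not changed repeats extraction and embedding, which are usually the most expensive indexing steps. One common way to make re-indexing incremental is to keep a content hash per document id and skip documents whose hash is unchanged. This sketch assumes the document text is the only input that determines its index entries. `ContentHashTracker` is a hypothetical helper, not part of the pipeline above.

```python
import hashlib


class ContentHashTracker:
    """Hypothetical helper: remembers a SHA-256 digest per document id."""

    def __init__(self) -> None:
        self._digests: dict[str, str] = {}

    @staticmethod
    def digest(text: str) -> str:
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def needs_indexing(self, document_id: str, text: str) -> bool:
        """True if the document is new or its text changed since mark_indexed."""
        return self._digests.get(document_id) != self.digest(text)

    def mark_indexed(self, document_id: str, text: str) -> None:
        """Record the digest; call only after indexing succeeded."""
        self._digests[document_id] = self.digest(text)
```

In a loop over documents, one would call `needs_indexing` before `pipeline.index_document(doc)`. Call `mark_indexed` only when the returned `IndexingResult.success` is true, so that failed documents are retried. A document whose text did change also needs its previous entities, relations, and vectors removed first, which the store interfaces above do not yet support.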

218.6.2 Caching and Performance Optimization

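Caching at several levels reduces latency. Query result caching skips retrieval and generation entirely for repeated questions. Embedding caching avoids re-encoding text that has already been embedded. Subgraph caching avoids repeating traversals from the same seed entities. The implementation below provides the first two levels.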

Code
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod
import hashlib
import time
from collections import OrderedDict
import threading


T = TypeVar('T')


class Cache(ABC, Generic[T]):
    """Abstract cache interface."""
    
    @abstractmethod
    def get(self, key: str) -> T | None:
        pass
    
    @abstractmethod
    def set(self, key: str, value: T, ttl_seconds: int | None = None) -> None:
        pass
    
    @abstractmethod
    def delete(self, key: str) -> None:
        pass


class LRUCache(Cache[T]):
    """Thread-safe LRU cache implementation."""
    
    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self.cache: OrderedDict[str, tuple[T, float | None]] = OrderedDict()
        self.lock = threading.Lock()
    
    def get(self, key: str) -> T | None:
        with self.lock:
            if key not in self.cache:
                return None
            
            value, expiry = self.cache[key]
            
            # Check expiry
            if expiry is not None and time.time() > expiry:
                del self.cache[key]
                return None
            
            # Move to end (most recently used)
            self.cache.move_to_end(key)
            return value
    
    def set(self, key: str, value: T, ttl_seconds: int | None = None) -> None:
        with self.lock:
            expiry = None
            if ttl_seconds is not None:
                expiry = time.time() + ttl_seconds
            
            if key in self.cache:
                self.cache.move_to_end(key)
            self.cache[key] = (value, expiry)
            
            # Evict oldest if over capacity
            while len(self.cache) > self.max_size:
                self.cache.popitem(last=False)
    
    def delete(self, key: str) -> None:
        with self.lock:
            if key in self.cache:
                del self.cache[key]


@dataclass
class CachedQueryResult:
    """Cached result for a query."""
    response: str
    context: GraphContext
    retrieved_entities: list[dict]
    timestamp: float


class CachedGraphRAG:
    """Graph RAG with multi-level caching."""
    
    def __init__(
        self,
        graph_rag: GraphRAGGenerator,
        embedding_model: Any,
        query_cache: Cache[CachedQueryResult] | None = None,
        embedding_cache: Cache[list[float]] | None = None,
        cache_ttl_seconds: int = 3600
    ):
        self.graph_rag = graph_rag
        self.embedding_model = embedding_model
        self.query_cache = query_cache or LRUCache(max_size=1000)
        self.embedding_cache = embedding_cache or LRUCache(max_size=10000)
        self.cache_ttl = cache_ttl_seconds
    
    def _query_cache_key(self, query: str, params: dict[str, Any] | None = None) -> str:
        """Generate cache key from the normalized query and generation parameters."""
        normalized = query.lower().strip()
        # Parameters such as top_k_entities or max_hops change the result,
        # so they must be part of the key
        param_str = repr(sorted((params or {}).items()))
        return hashlib.sha256(f"{normalized}|{param_str}".encode()).hexdigest()
    
    def _embedding_cache_key(self, text: str) -> str:
        """Generate cache key for embedding."""
        return hashlib.sha256(text.encode()).hexdigest()
    
    def get_embedding(self, text: str) -> list[float]:
        """Get embedding with caching."""
        cache_key = self._embedding_cache_key(text)
        
        cached = self.embedding_cache.get(cache_key)
        if cached is not None:
            return cached
        
        embedding = self.embedding_model.encode([text])[0].tolist()
        self.embedding_cache.set(cache_key, embedding, self.cache_ttl)
        
        return embedding
    
    def generate(
        self,
        query: str,
        use_cache: bool = True,
        **kwargs
    ) -> dict[str, Any]:
        """Generate with query result caching."""
        cache_key = self._query_cache_key(query, kwargs)
        
        if use_cache:
            cached = self.query_cache.get(cache_key)
            if cached is not None:
                return {
                    "response": cached.response,
                    "context": cached.context,
                    "retrieved_entities": cached.retrieved_entities,
                    "cached": True,
                    "cache_age_seconds": time.time() - cached.timestamp
                }
        
        # Generate fresh result
        result = self.graph_rag.generate(query, **kwargs)
        
        # Cache result
        cached_result = CachedQueryResult(
            response=result["response"],
            context=result["context"],
            retrieved_entities=result["retrieved_entities"],
            timestamp=time.time()
        )
        self.query_cache.set(cache_key, cached_result, self.cache_ttl)
        
        result["cached"] = False
        return result
    
    def invalidate_query(self, query: str) -> None:
        """Invalidate cached result for query."""
        cache_key = self._query_cache_key(query)
        self.query_cache.delete(cache_key)
    
    def warm_cache(self, queries: list[str]) -> None:
        """Pre-warm cache with common queries."""
        for query in queries:
            self.generate(query, use_cache=False)
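`CachedGraphRAG` covers the query-result and embedding levels, but not the subgraph level named at the start of this section. The sketch below is a minimal, single-threaded subgraph cache. `SubgraphCache` is hypothetical, and it assumes a traversal from a given seed set and hop limit depends only on the graph. `bfs_traverse` also receives the query, so if the traverser scores neighbours against the query, the query (or its embedding) would need to be part of the key as well.

```python
from collections import OrderedDict
from typing import Any


class SubgraphCache:
    """Hypothetical LRU cache of traversal results keyed by (seed set, max_hops).

    Each entry records the graph version it was computed against; bumping the
    version after a graph write makes all older entries stale at once.
    """

    def __init__(self, max_size: int = 512) -> None:
        self.max_size = max_size
        self.graph_version = 0
        self._entries: OrderedDict[tuple, tuple[int, Any]] = OrderedDict()

    @staticmethod
    def _key(seed_ids: list[str], max_hops: int) -> tuple:
        # Order-insensitive: the same seed set maps to the same key
        return (tuple(sorted(set(seed_ids))), max_hops)

    def get(self, seed_ids: list[str], max_hops: int) -> Any:
        """Return the cached traversal, or None on a miss or a stale entry."""
        key = self._key(seed_ids, max_hops)
        entry = self._entries.get(key)
        if entry is None:
            return None
        version, value = entry
        if version != self.graph_version:
            del self._entries[key]  # computed against an older graph
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return value

    def put(self, seed_ids: list[str], max_hops: int, value: Any) -> None:
        key = self._key(seed_ids, max_hops)
        self._entries[key] = (self.graph_version, value)
        self._entries.move_to_end(key)
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)  # evict least recently used

    def invalidate_all(self) -> None:
        """Call after any graph write, e.g. at the end of an indexing batch."""
        self.graph_version += 1
```

Version-based invalidation avoids scanning the cache on every write, at the cost of discarding entries that the write did not actually affect. Unlike `LRUCache`, this sketch takes no lock, so calls would need to be guarded with a `threading.Lock` before the cache is shared across threads.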

218.6.3 Evaluation Framework

Rigorous evaluation requires metrics that capture both retrieval quality and generation accuracy. For graph RAG, additional metrics assess the quality of graph traversal and context assembly.

Code
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from collections import defaultdict


@dataclass
class EvaluationExample:
    """A single evaluation example."""
    query: str
    ground_truth_answer: str
    ground_truth_entities: list[str]  # Entity names that should be retrieved
    ground_truth_paths: list[list[str]] | None = None  # Expected reasoning paths


@dataclass
class RetrievalMetrics:
    """Metrics for retrieval evaluation."""
    precision: float
    recall: float
    f1: float
    mrr: float  # Mean Reciprocal Rank
    ndcg: float  # Normalized Discounted Cumulative Gain


@dataclass
class GenerationMetrics:
    """Metrics for generation evaluation."""
    answer_relevance: float  # 0-1 score from LLM judge
    factual_consistency: float  # 0-1 score for consistency with context
    completeness: float  # 0-1 score for answer completeness


@dataclass 
class GraphRAGMetrics:
    """Combined metrics for graph RAG evaluation."""
    retrieval: RetrievalMetrics
    generation: GenerationMetrics
    path_accuracy: float  # Fraction of expected paths found
    context_efficiency: float  # Relevant entities / total entities in context


class GraphRAGEvaluator:
    """Evaluates graph RAG system performance."""
    
    def __init__(self, llm_judge: Any):
        self.llm_judge = llm_judge
    
    def evaluate_retrieval(
        self,
        retrieved_entities: list[str],
        ground_truth_entities: list[str]
    ) -> RetrievalMetrics:
        """Evaluate retrieval quality."""
        retrieved_set = set(e.lower() for e in retrieved_entities)
        ground_truth_set = set(e.lower() for e in ground_truth_entities)
        
        if not retrieved_set:
            return RetrievalMetrics(0, 0, 0, 0, 0)
        
        # Precision and recall
        true_positives = len(retrieved_set & ground_truth_set)
        precision = true_positives / len(retrieved_set)
        recall = true_positives / len(ground_truth_set) if ground_truth_set else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        # MRR
        mrr = 0.0
        for i, entity in enumerate(retrieved_entities):
            if entity.lower() in ground_truth_set:
                mrr = 1.0 / (i + 1)
                break
        
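        # NDCG (binary relevance; a repeated entity counts only at its first rank)
        seen: set[str] = set()
        relevance = []
        for e in retrieved_entities:
            key = e.lower()
            relevance.append(1.0 if key in ground_truth_set and key not in seen else 0.0)
            seen.add(key)
        dcg = sum(rel / np.log2(i + 2) for i, rel in enumerate(relevance))
        
        # Ideal ranking puts as many relevant entities first as the list can hold,
        # so relevant entities that were never retrieved lower the score
        ideal_hits = min(len(retrieved_entities), len(ground_truth_set))
        idcg = sum(1.0 / np.log2(i + 2) for i in range(ideal_hits))
        ndcg = dcg / idcg if idcg > 0 else 0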
        
        return RetrievalMetrics(
            precision=precision,
            recall=recall,
            f1=f1,
            mrr=mrr,
            ndcg=ndcg
        )
    
    def evaluate_generation(
        self,
        query: str,
        generated_answer: str,
        ground_truth_answer: str,
        context: str
    ) -> GenerationMetrics:
        """Evaluate generation quality using LLM judge."""
        # Answer relevance
        relevance_prompt = f"""Rate how well the generated answer addresses the question.

Question: {query}
Generated Answer: {generated_answer}
Reference Answer: {ground_truth_answer}

Rate from 0 to 1 where:
0 = Completely irrelevant or incorrect
0.5 = Partially relevant but missing key information
1 = Fully relevant and accurate

Output only a number between 0 and 1."""

        relevance_score = float(self.llm_judge.generate(user_prompt=relevance_prompt))
        
        # Factual consistency
        consistency_prompt = f"""Check if the generated answer is consistent with the provided context.
Does the answer make claims that contradict or go beyond the context?

Context: {context[:2000]}
Generated Answer: {generated_answer}

Rate factual consistency from 0 to 1 where:
0 = Major contradictions or unsupported claims
0.5 = Some minor inconsistencies
1 = Fully consistent with context

Output only a number between 0 and 1."""

        consistency_score = float(self.llm_judge.generate(user_prompt=consistency_prompt))
        
        # Completeness
        completeness_prompt = f"""Evaluate the completeness of the generated answer compared to the reference.

Question: {query}
Generated Answer: {generated_answer}
Reference Answer: {ground_truth_answer}

Rate completeness from 0 to 1 where:
0 = Missing most key information
0.5 = Covers some but not all important points
1 = Covers all key information from reference

Output only a number between 0 and 1."""

        completeness_score = float(self.llm_judge.generate(user_prompt=completeness_prompt))
        
        return GenerationMetrics(
            answer_relevance=relevance_score,
            factual_consistency=consistency_score,
            completeness=completeness_score
        )
    
    def evaluate_example(
        self,
        example: EvaluationExample,
        result: dict[str, Any]
    ) -> GraphRAGMetrics:
        """Evaluate a single example."""
        # Retrieval metrics
        retrieved_names = [e["name"] for e in result.get("retrieved_entities", [])]
        retrieval_metrics = self.evaluate_retrieval(
            retrieved_names,
            example.ground_truth_entities
        )
        
        # Generation metrics
        context_text = result.get("context", GraphContext()).to_text()
        generation_metrics = self.evaluate_generation(
            query=example.query,
            generated_answer=result.get("response", ""),
            ground_truth_answer=example.ground_truth_answer,
            context=context_text
        )
        
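        # Path accuracy: fraction of expected paths whose entity names appear,
        # in order, within at least one assembled reasoning chain. This is a
        # coarse string match, not a comparison of graph edges.
        path_accuracy = 0.0
        if example.ground_truth_paths:
            assembled = [p.lower() for p in getattr(result.get("context"), "paths", [])]
            
            def appears_in_order(names: list[str], text: str) -> bool:
                pos = 0
                for name in names:
                    needle = name.lower()
                    idx = text.find(needle, pos)
                    if idx < 0:
                        return False
                    pos = idx + len(needle)
                return True
            
            found = sum(
                any(appears_in_order(path, text) for text in assembled)
                for path in example.ground_truth_paths
            )
            path_accuracy = found / len(example.ground_truth_paths)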
        
        # Context efficiency
        retrieved_set = set(e.lower() for e in retrieved_names)
        ground_truth_set = set(e.lower() for e in example.ground_truth_entities)
        relevant_in_context = len(retrieved_set & ground_truth_set)
        total_in_context = len(retrieved_set)
        context_efficiency = relevant_in_context / total_in_context if total_in_context > 0 else 0
        
        return GraphRAGMetrics(
            retrieval=retrieval_metrics,
            generation=generation_metrics,
            path_accuracy=path_accuracy,
            context_efficiency=context_efficiency
        )
    
    def evaluate_dataset(
        self,
        examples: list[EvaluationExample],
        graph_rag: Any
    ) -> dict[str, float]:
        """Evaluate on a dataset and compute aggregate metrics."""
        all_metrics = []
        
        for example in examples:
            result = graph_rag.generate(example.query)
            metrics = self.evaluate_example(example, result)
            all_metrics.append(metrics)
        
        # Aggregate
        aggregated = {
            "retrieval_precision": np.mean([m.retrieval.precision for m in all_metrics]),
            "retrieval_recall": np.mean([m.retrieval.recall for m in all_metrics]),
            "retrieval_f1": np.mean([m.retrieval.f1 for m in all_metrics]),
            "retrieval_mrr": np.mean([m.retrieval.mrr for m in all_metrics]),
            "retrieval_ndcg": np.mean([m.retrieval.ndcg for m in all_metrics]),
            "generation_relevance": np.mean([m.generation.answer_relevance for m in all_metrics]),
            "generation_consistency": np.mean([m.generation.factual_consistency for m in all_metrics]),
            "generation_completeness": np.mean([m.generation.completeness for m in all_metrics]),
            "path_accuracy": np.mean([m.path_accuracy for m in all_metrics]),
            "context_efficiency": np.mean([m.context_efficiency for m in all_metrics])
        }
        
        return aggregated
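The following is a small worked example of binary-relevance nDCG, the retrieval metric `evaluate_retrieval` reports. It is written with `math` instead of NumPy so it runs standalone, and the function name `ndcg` is illustrative. Take retrieved entities `[a, x, b]` and ground truth `{a, b, c}`. Then DCG = 1/log2(2) + 1/log2(4) = 1.5. The ideal list places three relevant items first, so IDCG = 1 + 1/log2(3) + 1/2 ≈ 2.131, and nDCG ≈ 0.704. Normalizing by the best reordering of the retrieved list alone would instead give 1.5 / (1 + 1/log2(3)) ≈ 0.920, ignoring that `c` was never retrieved.

```python
import math


def ndcg(retrieved: list[str], relevant: set[str]) -> float:
    """Binary-relevance nDCG over the whole retrieved list.

    The ideal ranking places min(len(retrieved), len(relevant)) relevant
    items first, so relevant items that were never retrieved lower the score.
    A repeated item counts only at its first position.
    """
    seen: set[str] = set()
    dcg = 0.0
    for rank, item in enumerate(retrieved):
        if item in relevant and item not in seen:
            dcg += 1.0 / math.log2(rank + 2)  # rank 0 contributes 1 / log2(2) = 1
        seen.add(item)
    ideal_hits = min(len(retrieved), len(relevant))
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0
```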

218.7 Deployment Considerations

218.7.1 Service Architecture

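Production graph RAG systems typically decompose into separate services for indexing, retrieval, and generation, enabling independent scaling and optimization. The service class below is a simpler single-process facade. It runs the synchronous indexing and query paths on a thread pool and caps concurrent requests with an asyncio semaphore.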

Code
from dataclasses import dataclass
from typing import Any
from abc import ABC, abstractmethod
import asyncio
from concurrent.futures import ThreadPoolExecutor


@dataclass
class ServiceConfig:
    """Configuration for graph RAG service."""
    max_concurrent_requests: int = 100
    request_timeout_seconds: float = 30.0
    indexing_batch_size: int = 10
    retrieval_top_k: int = 10
    max_context_tokens: int = 4000
    enable_caching: bool = True
    cache_ttl_seconds: int = 3600


class GraphRAGService:
    """Production service for graph RAG."""
    
    def __init__(
        self,
        config: ServiceConfig,
        indexing_pipeline: IndexingPipeline,
        graph_rag: CachedGraphRAG,
        multi_hop_rag: MultiHopGraphRAG
    ):
        self.config = config
        self.indexing_pipeline = indexing_pipeline
        self.graph_rag = graph_rag
        self.multi_hop_rag = multi_hop_rag
        self.executor = ThreadPoolExecutor(max_workers=config.max_concurrent_requests)
        self._semaphore = asyncio.Semaphore(config.max_concurrent_requests)
    
    async def query(
        self,
        query: str,
        use_multi_hop: bool = False,
        **kwargs
    ) -> dict[str, Any]:
        """Process a query asynchronously within the concurrency limit and timeout."""
        async with self._semaphore:
            # get_running_loop is the supported call inside a coroutine
            loop = asyncio.get_running_loop()
            
            if use_multi_hop:
                future = loop.run_in_executor(
                    self.executor,
                    lambda: self.multi_hop_rag.generate(query)
                )
            else:
                future = loop.run_in_executor(
                    self.executor,
                    lambda: self.graph_rag.generate(query, **kwargs)
                )
            
            # Enforce the configured deadline: the caller gets asyncio.TimeoutError,
            # but the worker thread cannot be interrupted and finishes in the background
            return await asyncio.wait_for(
                future,
                timeout=self.config.request_timeout_seconds
            )
    
    async def index_documents(
        self,
        documents: list[Document]
    ) -> list[IndexingResult]:
        """Index documents asynchronously."""
        loop = asyncio.get_running_loop()
        
        results = await loop.run_in_executor(
            self.executor,
            lambda: list(self.indexing_pipeline.index_batch(documents))
        )
        
        return results
    
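    def health_check(self) -> dict[str, Any]:
        """Check service health."""
        # _value is a private attribute of asyncio.Semaphore holding the number of
        # free slots, so this counts requests currently executing, not queued ones
        return {
            "status": "healthy",
            "active_requests": self.config.max_concurrent_requests - self._semaphore._value,
            "cache_enabled": self.config.enable_caching
        }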
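`GraphRAGService` relies on a concurrency pattern that can be exercised without any graph components: an `asyncio.Semaphore` caps in-flight work, and `run_in_executor` moves blocking calls onto a thread pool. The sketch below wraps that pattern in a hypothetical `BoundedExecutor` class. It adds a per-call deadline through `asyncio.wait_for` and tracks in-flight calls with an explicit counter instead of reading the semaphore's private state. It assumes Python 3.10+ (matching the document's `X | None` annotations) and that the object is created inside a running event loop. The sketch also makes one limitation visible: on timeout, the caller receives `asyncio.TimeoutError`, but the worker thread keeps running until the blocking call returns.

```python
import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any


class BoundedExecutor:
    """Hypothetical sketch: run blocking calls off the event loop with a
    concurrency cap and a per-call deadline."""

    def __init__(self, max_concurrent: int, timeout_s: float) -> None:
        self._executor = ThreadPoolExecutor(max_workers=max_concurrent)
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._timeout_s = timeout_s
        self.in_flight = 0  # calls currently holding a semaphore slot

    async def run(self, fn: Callable[..., Any], *args: Any) -> Any:
        async with self._semaphore:
            self.in_flight += 1
            try:
                loop = asyncio.get_running_loop()
                future = loop.run_in_executor(self._executor, fn, *args)
                # On timeout wait_for raises asyncio.TimeoutError; the thread keeps running
                return await asyncio.wait_for(future, timeout=self._timeout_s)
            finally:
                self.in_flight -= 1

    def close(self) -> None:
        # Do not block on threads that are still finishing timed-out calls
        self._executor.shutdown(wait=False)
```

Because the semaphore slot is released on timeout while the thread may still be busy, the thread pool size, not the semaphore, is the hard bound on concurrent blocking work.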

218.7.2 Observability and Monitoring

Production systems require comprehensive observability to diagnose issues and optimize performance. Key metrics include retrieval latency, generation latency, cache hit rates, and retrieval quality indicators.
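Two of the signals named here, cache hit rate and latency percentiles, reduce to small computations. The sketch below implements both independently of `MetricsCollector`. A fixed-size `deque` keeps only the most recent observations, the same bounded-window idea as the collector's 1000-sample trim. Percentiles use the nearest-rank definition, which can differ slightly from `numpy.percentile`'s default linear interpolation. `LatencyWindow` and `cache_hit_rate` are hypothetical names.

```python
import math
from collections import deque


class LatencyWindow:
    """Hypothetical bounded window of latency samples with nearest-rank percentiles."""

    def __init__(self, max_samples: int = 1000) -> None:
        # deque(maxlen=...) drops the oldest sample automatically when full
        self._samples: deque[float] = deque(maxlen=max_samples)

    def observe(self, seconds: float) -> None:
        self._samples.append(seconds)

    def percentile(self, p: float) -> float:
        """Smallest sample such that at least p% of samples are <= it."""
        if not self._samples:
            raise ValueError("no samples recorded")
        ordered = sorted(self._samples)
        rank = max(1, math.ceil(p * len(ordered) / 100))  # 1-based rank
        return ordered[rank - 1]


def cache_hit_rate(hits: int, misses: int) -> float:
    """Fraction of lookups served from cache; 0.0 before any lookups."""
    total = hits + misses
    return hits / total if total else 0.0
```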

Code
from dataclasses import dataclass, field
from typing import Any, Callable
import time
from collections import defaultdict
import threading
from contextlib import contextmanager


@dataclass
class MetricPoint:
    """A single metric observation."""
    name: str
    value: float
    timestamp: float
    labels: dict[str, str] = field(default_factory=dict)


class MetricsCollector:
    """Collects and exposes metrics."""
    
    def __init__(self):
        self.counters: dict[str, float] = defaultdict(float)
        self.gauges: dict[str, float] = defaultdict(float)
        self.histograms: dict[str, list[float]] = defaultdict(list)
        self.lock = threading.Lock()
    
    def increment_counter(
        self, 
        name: str, 
        value: float = 1.0,
        labels: dict[str, str] | None = None
    ):
        """Increment a counter metric."""
        key = self._make_key(name, labels)
        with self.lock:
            self.counters[key] += value
    
    def set_gauge(
        self, 
        name: str, 
        value: float,
        labels: dict[str, str] | None = None
    ):
        """Set a gauge metric."""
        key = self._make_key(name, labels)
        with self.lock:
            self.gauges[key] = value
    
    def observe_histogram(
        self, 
        name: str, 
        value: float,
        labels: dict[str, str] | None = None
    ):
        """Add observation to histogram."""
        key = self._make_key(name, labels)
        with self.lock:
            self.histograms[key].append(value)
            # Keep only last 1000 observations
            if len(self.histograms[key]) > 1000:
                self.histograms[key] = self.histograms[key][-1000:]
    
    def _make_key(self, name: str, labels: dict[str, str] | None) -> str:
        """Create metric key from name and labels."""
        if not labels:
            return name
        label_str = ",".join(f"{k}={v}" for k, v in sorted(labels.items()))
        return f"{name}{{{label_str}}}"
    
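    @contextmanager
    def timer(self, name: str, labels: dict[str, str] | None = None):
        """Context manager for timing operations.

        Uses time.perf_counter(), a monotonic clock, so durations are not
        distorted by system clock adjustments.
        """
        start = time.perf_counter()
        try:
            yield
        finally:
            duration = time.perf_counter() - start
            self.observe_histogram(name, duration, labels)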
    
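    def get_metrics(self) -> dict[str, Any]:
        """Get all metrics; histogram stats cover only the retained window."""
        import numpy as np  # imported once here rather than inside the loop

        # Copy under the lock, then compute percentiles outside it so that
        # writers are not blocked during the comparatively slow numpy work.
        with self.lock:
            counters = dict(self.counters)
            gauges = dict(self.gauges)
            histograms = {k: list(v) for k, v in self.histograms.items() if v}

        histogram_stats = {}
        for key, values in histograms.items():
            p50, p95, p99 = np.percentile(values, [50, 95, 99])
            histogram_stats[key] = {
                "count": len(values),
                "mean": float(np.mean(values)),
                "p50": float(p50),
                "p95": float(p95),
                "p99": float(p99)
            }

        return {
            "counters": counters,
            "gauges": gauges,
            "histograms": histogram_stats
        }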


class InstrumentedGraphRAG:
    """Graph RAG with metrics instrumentation."""
    
    def __init__(
        self,
        graph_rag: CachedGraphRAG,
        metrics: MetricsCollector
    ):
        self.graph_rag = graph_rag
        self.metrics = metrics
    
    def generate(self, query: str, **kwargs) -> dict[str, Any]:
        """Generate with metrics collection."""
        self.metrics.increment_counter("graphrag_requests_total")
        
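        # The timer records latency even if generation raises; failures are
        # also counted so that an error rate can be derived.
        try:
            with self.metrics.timer("graphrag_latency_seconds"):
                result = self.graph_rag.generate(query, **kwargs)
        except Exception:
            self.metrics.increment_counter("graphrag_errors_total")
            raise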
        
        # Track cache hit/miss
        if result.get("cached"):
            self.metrics.increment_counter("graphrag_cache_hits_total")
        else:
            self.metrics.increment_counter("graphrag_cache_misses_total")
        
        # Track retrieval metrics
        num_entities = len(result.get("retrieved_entities", []))
        self.metrics.observe_histogram(
            "graphrag_entities_retrieved",
            num_entities
        )
        
        # Track context size
        context = result.get("context")
        if context:
            context_text = context.to_text()
            self.metrics.observe_histogram(
                "graphrag_context_chars",
                len(context_text)
            )
        
        return result

218.8 Advanced Topics

218.8.1 Temporal Knowledge Graphs

Many knowledge domains involve temporal dynamics where facts hold only during certain time periods. Temporal knowledge graphs augment triples with temporal scopes, requiring specialized representation and retrieval methods.

Code
from dataclasses import dataclass
from datetime import datetime
from typing import Any


@dataclass
class TemporalRelation:
    """A relation with temporal scope."""
    head: Entity
    relation_type: str
    tail: Entity
    valid_from: datetime | None = None
    valid_until: datetime | None = None
    confidence: float = 1.0
    
    def is_valid_at(self, timestamp: datetime) -> bool:
        """Check if relation is valid at given time (bounds inclusive)."""
        if self.valid_from is not None and timestamp < self.valid_from:
            return False
        if self.valid_until is not None and timestamp > self.valid_until:
            return False
        return True


class TemporalKnowledgeGraph:
    """Knowledge graph with temporal reasoning."""
    
    def __init__(self):
        self.entities: dict[str, Entity] = {}
        self.temporal_relations: list[TemporalRelation] = []
    
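    def add_entity(self, entity_id: str, entity: Entity):
        """Register an entity so get_entity_history can look it up by ID."""
        self.entities[entity_id] = entity
    
    def add_temporal_relation(self, relation: TemporalRelation):
        """Add a temporally scoped relation.

        The relation's endpoints must also be registered with add_entity
        for get_entity_history to find them.
        """
        self.temporal_relations.append(relation)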
    
    def get_snapshot(self, timestamp: datetime) -> list[TemporalRelation]:
        """Get all relations valid at a specific time."""
        return [
            rel for rel in self.temporal_relations
            if rel.is_valid_at(timestamp)
        ]
    
    def get_entity_history(
        self, 
        entity_id: str,
        relation_type: str | None = None
    ) -> list[tuple[TemporalRelation, str]]:
        """Get temporal history of an entity's relations."""
        entity = self.entities.get(entity_id)
        if not entity:
            return []
        
        history = []
        for rel in self.temporal_relations:
            if rel.head == entity or rel.tail == entity:
                if relation_type is None or rel.relation_type == relation_type:
                    direction = "outgoing" if rel.head == entity else "incoming"
                    history.append((rel, direction))
        
        # Sort by temporal validity
        history.sort(key=lambda x: x[0].valid_from or datetime.min)
        return history


class TemporalGraphRAG:
    """Graph RAG with temporal awareness."""
    
    def __init__(
        self,
        llm_client: Any,
        temporal_graph: TemporalKnowledgeGraph,
        embedding_model: Any
    ):
        self.llm = llm_client
        self.graph = temporal_graph
        self.embedding_model = embedding_model
    
    def generate(
        self,
        query: str,
        reference_time: datetime | None = None
    ) -> dict[str, Any]:
        """Generate with temporal context."""
        # Parse temporal expressions from query
        parsed_time = self._parse_temporal_expression(query, reference_time)
        
        # Get temporally valid snapshot
        if parsed_time:
            valid_relations = self.graph.get_snapshot(parsed_time)
        else:
            valid_relations = self.graph.temporal_relations
        
        # Build context from valid relations
        context_parts = []
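        # Cap context at 50 relations. This keeps insertion order; a fuller
        # implementation would first rank relations by relevance to the query
        # (e.g. with self.embedding_model, which this version does not use).
        for rel in valid_relations[:50]: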
            temporal_qualifier = ""
            if rel.valid_from or rel.valid_until:
                start = rel.valid_from.strftime("%Y-%m-%d") if rel.valid_from else "?"
                end = rel.valid_until.strftime("%Y-%m-%d") if rel.valid_until else "present"
                temporal_qualifier = f" (valid: {start} to {end})"
            
            context_parts.append(
                f"{rel.head.name} --[{rel.relation_type}]--> {rel.tail.name}{temporal_qualifier}"
            )
        
        context_text = "\n".join(context_parts)
        
        prompt = f"""Answer the following question using the temporally-scoped 
knowledge provided. Pay attention to the temporal validity of facts.

Temporal Knowledge:
{context_text}

Question: {query}

If the question involves a specific time period, ensure your answer reflects 
what was true during that period."""

        response = self.llm.generate(user_prompt=prompt)
        
        return {
            "response": response,
            "reference_time": parsed_time,
            "relations_used": len(valid_relations)
        }
    
    def _parse_temporal_expression(
        self,
        query: str,
        default: datetime | None
    ) -> datetime | None:
        """Parse temporal expressions from query."""
        # Simple heuristic parsing - production systems would use
        # more sophisticated temporal expression recognition.
        # Match whole words so that e.g. "know" does not trigger "now".
        query_lower = query.lower()
        words = [w.strip(".,;:!?()\"'") for w in query_lower.split()]
        
        if "currently" in words or "now" in words:
            return datetime.now()
        # Any explicit four-digit year maps to mid-year as a point estimate
        for word in words:
            if len(word) == 4 and word.isdecimal() and 1900 <= int(word) <= 2100:
                return datetime(int(word), 6, 15)
        if "last year" in query_lower:
            now = datetime.now()
            return datetime(now.year - 1, 6, 15)
        
        return default

218.8.2 Hybrid Vector-Graph Retrieval

Combining dense vector retrieval with graph-based retrieval leverages the complementary strengths of both approaches. Vector retrieval excels at semantic similarity matching, while graph retrieval captures structural and relational patterns.

Code
from dataclasses import dataclass
from typing import Any
import numpy as np


@dataclass
class HybridRetrievalResult:
    """Result from hybrid retrieval."""
    entity_id: str
    entity: Entity
    vector_score: float
    graph_score: float
    combined_score: float


class HybridRetriever:
    """Combines vector and graph-based retrieval."""
    
    def __init__(
        self,
        entity_retriever: EntityRetriever,
        graph: KnowledgeGraph,
        graph_weight: float = 0.3,
        personalization_weight: float = 0.1
    ):
        self.entity_retriever = entity_retriever
        self.graph = graph
        self.graph_weight = graph_weight
        self.personalization_weight = personalization_weight
    
    def retrieve(
        self,
        query: str,
        seed_entities: list[str] | None = None,
        top_k: int = 10
    ) -> list[HybridRetrievalResult]:
        """Perform hybrid retrieval."""
        # Vector retrieval
        vector_results = self.entity_retriever.retrieve(query, top_k=top_k * 2)
        vector_scores = {r.entity_id: r.score for r in vector_results}
        
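        # Graph-based scores using personalized PageRank. These sum to 1 over
        # the whole graph, so rescale by the maximum to put them on a [0, 1]
        # scale closer to that of (assumed [0, 1]) vector similarity scores.
        graph_scores = self._compute_personalized_pagerank(
            seed_entities or [r.entity_id for r in vector_results[:3]]
        )
        max_graph_score = max(graph_scores.values(), default=0.0)
        if max_graph_score > 0:
            graph_scores = {
                eid: score / max_graph_score for eid, score in graph_scores.items()
            }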
        
        # Combine scores
        all_entity_ids = set(vector_scores.keys()) | set(graph_scores.keys())
        
        combined_results = []
        for entity_id in all_entity_ids:
            v_score = vector_scores.get(entity_id, 0)
            g_score = graph_scores.get(entity_id, 0)
            
            combined_score = (
                (1 - self.graph_weight) * v_score +
                self.graph_weight * g_score
            )
            
            entity = self.graph.entities.get(entity_id)
            if entity:
                combined_results.append(HybridRetrievalResult(
                    entity_id=entity_id,
                    entity=entity,
                    vector_score=v_score,
                    graph_score=g_score,
                    combined_score=combined_score
                ))
        
        # Sort by combined score
        combined_results.sort(key=lambda x: x.combined_score, reverse=True)
        return combined_results[:top_k]
    
    def _compute_personalized_pagerank(
        self,
        seed_entities: list[str],
        damping: float = 0.85,
        max_iterations: int = 100,
        tolerance: float = 1e-6
    ) -> dict[str, float]:
        """Compute personalized PageRank from seed entities."""
        entity_ids = list(self.graph.entities.keys())
        n = len(entity_ids)
        
        if n == 0:
            return {}
        
        id_to_idx = {eid: i for i, eid in enumerate(entity_ids)}
        
        # Build adjacency matrix
        adj = np.zeros((n, n))
        for rel in self.graph.relations:
            head_id = self.graph._entity_id(rel.head)
            tail_id = self.graph._entity_id(rel.tail)
            
            if head_id in id_to_idx and tail_id in id_to_idx:
                adj[id_to_idx[head_id], id_to_idx[tail_id]] = 1
                adj[id_to_idx[tail_id], id_to_idx[head_id]] = 1  # Undirected
        
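        # Row-normalize into a transition matrix. Isolated nodes have all-zero
        # rows; remember them so their probability mass can be redistributed.
        row_sums = adj.sum(axis=1, keepdims=True)
        dangling = row_sums.ravel() == 0
        row_sums[row_sums == 0] = 1  # Avoid division by zero
        transition = adj / row_sums
        
        # Personalization vector
        personalization = np.zeros(n)
        for seed_id in seed_entities:
            if seed_id in id_to_idx:
                personalization[id_to_idx[seed_id]] = 1.0
        if personalization.sum() > 0:
            personalization /= personalization.sum()
        else:
            personalization = np.ones(n) / n
        
        # Power iteration. Mass sitting on dangling nodes is sent back to the
        # personalization distribution so scores keep summing to 1.
        scores = np.ones(n) / n
        for _ in range(max_iterations):
            dangling_mass = scores[dangling].sum()
            new_scores = (
                damping * (transition.T @ scores + dangling_mass * personalization) +
                (1 - damping) * personalization
            )
            converged = np.abs(new_scores - scores).sum() < tolerance
            scores = new_scores
            if converged:
                break
        
        return {entity_ids[i]: float(scores[i]) for i in range(n)}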

218.9 Summary

Graph-based retrieval-augmented generation extends the capabilities of language models by organizing knowledge into explicit relational structures. The graph representation enables multi-hop reasoning, relationship-aware retrieval, and structured context assembly, all of which flat vector retrieval over independent text chunks supports poorly.

The mathematical foundations rest on graph theory, spectral methods, and graph neural networks. Knowledge graph construction requires entity extraction, relation identification, coreference resolution, and entity linking, each presenting distinct technical challenges. Retrieval over graphs combines embedding-based similarity search with graph traversal algorithms, balancing coverage and relevance through techniques such as beam search and personalized PageRank.

Production deployment demands attention to indexing pipelines, caching strategies, service architecture, and comprehensive observability. Evaluation must cover both retrieval quality, using metrics such as precision, recall, and NDCG, and generation quality, using measures of relevance, consistency, and completeness.

The field continues to evolve with active research in temporal knowledge graphs, hybrid retrieval methods, and more sophisticated graph neural architectures. For practitioners, the key is selecting the appropriate level of complexity for the application domain: simple entity extraction and two-hop traversal suffice for many use cases, while others demand the full machinery of multi-hop reasoning, query decomposition, and temporal awareness.