Files
rag-mcp/tools/relations.py
kappa 2858e0a344 Initial commit: RAG MCP Server with relationship graph
Features:
- Vector search with Pinecone + Vertex AI embeddings
- Document relationships (link, unlink, related, graph)
- Auto-link with LLM analysis
- Intelligent merge with Gemini

Modular structure:
- clients/: Pinecone, Vertex AI
- tools/: core, relations, stats
- utils/: validation, logging

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 11:05:45 +09:00

418 lines
14 KiB
Python

"""Relation management tools and utilities."""
from typing import Optional
from collections import deque
from clients import get_index
from utils.logging import get_logger
from utils.validation import validate_document_id
from config import MAX_GRAPH_NODES
logger = get_logger(__name__)
# ============================================================
# Relation utility functions
# ============================================================
def parse_relation(rel_str: str) -> tuple:
    """
    Parse an 'id:type' relation string into an (id, type) tuple.

    Args:
        rel_str: Relation string in format "id:type"

    Returns:
        (id, type) tuple; the type defaults to 'related' when no ':' is present
    """
    # partition splits on the FIRST colon only, so doc IDs stay intact
    # even when the type portion itself contains colons.
    doc_id, sep, rel_type = rel_str.partition(':')
    if sep:
        return (doc_id, rel_type)
    return (rel_str, 'related')
def format_relation(doc_id: str, rel_type: str) -> str:
    """
    Format an (id, type) pair as the canonical 'id:type' relation string.

    Args:
        doc_id: Document ID
        rel_type: Relation type

    Returns:
        Formatted relation string
    """
    return ':'.join((doc_id, rel_type))
def get_reverse_relation(rel_type: str) -> str:
    """
    Get the reverse (inverse) relation type for a forward relation type.

    Args:
        rel_type: Relation type

    Returns:
        Reverse relation type; unknown types get a 'reverse_' prefix
    """
    # Asymmetric pairs map to each other in both directions.
    inverse_pairs = (
        ('depends_on', 'required_by'),
        ('part_of', 'contains'),
        ('blocks', 'blocked_by'),
        ('extends', 'extended_by'),
        ('updates', 'updated_by'),
    )
    # Symmetric types are their own inverse.
    symmetric = ('see_also', 'related')
    mapping = {}
    for forward, backward in inverse_pairs:
        mapping[forward] = backward
        mapping[backward] = forward
    for name in symmetric:
        mapping[name] = name
    return mapping.get(rel_type, f"reverse_{rel_type}")
def _add_reverse_relation(target_id: str, source_id: str, rel_type: str) -> bool:
    """
    Internal: Add reverse relation to target document.

    Args:
        target_id: Target document ID
        source_id: Source document ID
        rel_type: Relation type (forward)

    Returns:
        True if successful (including when the relation already exists),
        False otherwise
    """
    try:
        index = get_index()
        fetched = index.fetch(ids=[target_id])
        vectors = fetched["vectors"]
        if target_id not in vectors:
            logger.warning(f"Target document not found for reverse relation: {target_id}")
            return False
        entry = vectors[target_id]
        meta = entry["metadata"]
        rels = meta.get("relations", [])
        reverse_rel = format_relation(source_id, get_reverse_relation(rel_type))
        # Idempotent: an existing reverse relation is left untouched.
        if reverse_rel in rels:
            return True
        rels.append(reverse_rel)
        meta["relations"] = rels
        # Re-upsert with the original embedding so only metadata changes.
        index.upsert(vectors=[{
            "id": target_id,
            "values": entry["values"],
            "metadata": meta
        }])
        logger.debug(f"Added reverse relation: {target_id} <- {source_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to add reverse relation: {str(e)}")
        return False
def _remove_reverse_relation(target_id: str, source_id: str) -> bool:
    """
    Internal: Remove reverse relation from target document.

    Drops every relation on target_id that points back at source_id
    (any relation type). Skips the index write entirely when nothing
    matched, avoiding a needless upsert of unchanged metadata.

    Args:
        target_id: Target document ID
        source_id: Source document ID

    Returns:
        True if successful (including the no-op case), False otherwise
    """
    try:
        index = get_index()
        result = index.fetch(ids=[target_id])
        if target_id not in result["vectors"]:
            logger.warning(f"Target document not found for reverse relation removal: {target_id}")
            return False
        metadata = result["vectors"][target_id]["metadata"]
        relations = metadata.get("relations", [])
        # Relations are stored as "id:type"; match any type for source_id.
        remaining = [r for r in relations if not r.startswith(f"{source_id}:")]
        if len(remaining) == len(relations):
            # Nothing pointed back at source_id; no write needed.
            return True
        metadata["relations"] = remaining
        vector = result["vectors"][target_id]["values"]
        index.upsert(vectors=[{
            "id": target_id,
            "values": vector,
            "metadata": metadata
        }])
        logger.debug(f"Removed reverse relation: {target_id} <- {source_id}")
        return True
    except Exception as e:
        logger.error(f"Failed to remove reverse relation: {str(e)}")
        return False
# ============================================================
# MCP Tools
# ============================================================
def rag_link(from_id: str, to_id: str, relation_type: str = "related") -> str:
    """
    Create relation between two documents (bidirectional).

    Verifies that BOTH documents exist before writing anything, so a
    missing target can never leave a dangling one-way relation.

    Args:
        from_id: Source document ID
        to_id: Target document ID
        relation_type: Relation type (depends_on, part_of, see_also, blocks, extends, updates, related)

    Returns:
        Success message or error message
    """
    # Validate inputs
    is_valid, error_msg = validate_document_id(from_id)
    if not is_valid:
        logger.warning(f"Source ID validation failed: {error_msg}")
        return f"Error: {error_msg}"
    is_valid, error_msg = validate_document_id(to_id)
    if not is_valid:
        logger.warning(f"Target ID validation failed: {error_msg}")
        return f"Error: {error_msg}"
    # A self-referential link is never meaningful and would also collide
    # with its own reverse relation.
    if from_id == to_id:
        logger.warning(f"Attempted self-link: {from_id}")
        return f"Error: Cannot link a document to itself: {from_id}"
    try:
        index = get_index()
        # Fetch both endpoints in one round trip and fail fast before any
        # write if either side is missing (prevents dangling forward links).
        result = index.fetch(ids=[from_id, to_id])
        if from_id not in result["vectors"]:
            logger.warning(f"Source document not found: {from_id}")
            return f"Error: Source document not found: {from_id}"
        if to_id not in result["vectors"]:
            logger.warning(f"Target document not found: {to_id}")
            return f"Error: Target document not found: {to_id}"
        from_metadata = result["vectors"][from_id]["metadata"]
        from_relations = from_metadata.get("relations", [])
        new_rel = format_relation(to_id, relation_type)
        if new_rel in from_relations:
            logger.info(f"Relation already exists: {from_id} -> {to_id}")
            return f"Relation already exists: {from_id} --[{relation_type}]--> {to_id}"
        from_relations.append(new_rel)
        from_metadata["relations"] = from_relations
        from_vector = result["vectors"][from_id]["values"]
        index.upsert(vectors=[{
            "id": from_id,
            "values": from_vector,
            "metadata": from_metadata
        }])
        # Mirror the link on the target. Warn (but don't fail) if the
        # reverse write did not succeed: the forward link is committed.
        if not _add_reverse_relation(to_id, from_id, relation_type):
            logger.warning(f"Forward link written but reverse link failed: {to_id} <- {from_id}")
        reverse_type = get_reverse_relation(relation_type)
        logger.info(f"Linked: {from_id} --[{relation_type}]--> {to_id}")
        return f"Linked: {from_id} --[{relation_type}]--> {to_id}\nReverse: {to_id} --[{reverse_type}]--> {from_id}"
    except Exception as e:
        logger.error(f"rag_link failed: {str(e)}", exc_info=True)
        return f"Error: {str(e)}"
def rag_unlink(from_id: str, to_id: str) -> str:
    """
    Remove relation between two documents (bidirectional).

    Args:
        from_id: Source document ID
        to_id: Target document ID

    Returns:
        Success message or error message
    """
    # Reject malformed IDs up front.
    for label, doc_id in (("Source", from_id), ("Target", to_id)):
        ok, error_msg = validate_document_id(doc_id)
        if not ok:
            logger.warning(f"{label} ID validation failed: {error_msg}")
            return f"Error: {error_msg}"
    try:
        index = get_index()
        fetched = index.fetch(ids=[from_id])
        if from_id not in fetched["vectors"]:
            logger.warning(f"Document not found: {from_id}")
            return f"Error: Document not found: {from_id}"
        entry = fetched["vectors"][from_id]
        metadata = entry["metadata"]
        before = metadata.get("relations", [])
        # Drop every relation pointing at to_id, regardless of type.
        prefix = f"{to_id}:"
        after = [rel for rel in before if not rel.startswith(prefix)]
        if len(after) == len(before):
            logger.info(f"No relation found: {from_id} -> {to_id}")
            return f"No relation found from {from_id} to {to_id}"
        metadata["relations"] = after
        index.upsert(vectors=[{
            "id": from_id,
            "values": entry["values"],
            "metadata": metadata
        }])
        # Best-effort cleanup of the mirrored relation on the target.
        _remove_reverse_relation(to_id, from_id)
        logger.info(f"Unlinked: {from_id} <--> {to_id}")
        return f"Unlinked: {from_id} <--> {to_id}"
    except Exception as e:
        logger.error(f"rag_unlink failed: {str(e)}", exc_info=True)
        return f"Error: {str(e)}"
def rag_related(id: str, relation_type: Optional[str] = None, include_content: bool = False) -> str:
    """
    Query related documents.

    Args:
        id: Document ID to query
        relation_type: Filter by specific relation type (default: all)
        include_content: Include document content (default: False)

    Returns:
        Formatted relation list or error message
    """
    ok, error_msg = validate_document_id(id)
    if not ok:
        logger.warning(f"Document ID validation failed: {error_msg}")
        return f"Error: {error_msg}"
    try:
        index = get_index()
        fetched = index.fetch(ids=[id])
        if id not in fetched["vectors"]:
            logger.warning(f"Document not found: {id}")
            return f"Error: Document not found: {id}"
        relations = fetched["vectors"][id]["metadata"].get("relations", [])
        if not relations:
            logger.info(f"No relations found for: {id}")
            return f"No relations found for: {id}"
        if relation_type:
            # Relations are "id:type" strings, so the type is the suffix.
            suffix = f":{relation_type}"
            relations = [r for r in relations if r.endswith(suffix)]
            if not relations:
                logger.info(f"No '{relation_type}' relations found for: {id}")
                return f"No '{relation_type}' relations found for: {id}"
        output = [f"Relations for {id}:"]
        if not include_content:
            for rel_str in relations:
                target_id, rel_type = parse_relation(rel_str)
                output.append(f"  --[{rel_type}]--> {target_id}")
        else:
            # Batch-fetch all targets once, then render each relation.
            related_docs = index.fetch(ids=[parse_relation(r)[0] for r in relations])
            for rel_str in relations:
                target_id, rel_type = parse_relation(rel_str)
                entry = f"\n  --[{rel_type}]--> {target_id}"
                doc = related_docs["vectors"].get(target_id)
                if doc is not None:
                    text = doc["metadata"].get("text", "")
                    tag = doc["metadata"].get("tag", "")
                    preview = text[:200] + "..." if len(text) > 200 else text
                    entry += f"\n    Tag: {tag}"
                    entry += f"\n    {preview}"
                output.append(entry)
        logger.info(f"Found {len(relations)} relations for: {id}")
        return "\n".join(output)
    except Exception as e:
        logger.error(f"rag_related failed: {str(e)}", exc_info=True)
        return f"Error: {str(e)}"
def rag_graph(id: str, depth: int = 1) -> str:
    """
    Explore relation graph from a document using BFS.

    Args:
        id: Starting document ID
        depth: Search depth (default: 1, clamped to the range 0..3)

    Returns:
        Formatted graph (nodes grouped by depth, then an edge list)
        or error message
    """
    # Validate input
    is_valid, error_msg = validate_document_id(id)
    if not is_valid:
        logger.warning(f"Document ID validation failed: {error_msg}")
        return f"Error: {error_msg}"
    try:
        index = get_index()
        # Clamp to [0, 3]: a negative depth would skip every node and
        # surface a misleading "Document not found" error.
        depth = max(0, min(depth, 3))
        visited = set()
        nodes = {}
        edges = []
        queue = deque([(id, 0)])  # BFS queue with (doc_id, current_depth)
        while queue and len(visited) < MAX_GRAPH_NODES:
            doc_id, current_depth = queue.popleft()
            if doc_id in visited or current_depth > depth:
                continue
            visited.add(doc_id)
            result = index.fetch(ids=[doc_id])
            if doc_id not in result["vectors"]:
                # Dangling relation target; skip silently.
                continue
            metadata = result["vectors"][doc_id]["metadata"]
            text = metadata.get("text", "")
            tag = metadata.get("tag", "")
            relations = metadata.get("relations", [])
            nodes[doc_id] = {
                "tag": tag,
                "preview": text[:100] + "..." if len(text) > 100 else text,
                "depth": current_depth
            }
            for rel_str in relations:
                target_id, rel_type = parse_relation(rel_str)
                edges.append({
                    "from": doc_id,
                    "to": target_id,
                    "type": rel_type
                })
                if target_id not in visited and len(visited) < MAX_GRAPH_NODES:
                    queue.append((target_id, current_depth + 1))
        if not nodes:
            logger.warning(f"Document not found: {id}")
            return f"Error: Document not found: {id}"
        output = [f"=== Graph from {id} (depth: {depth}) ==="]
        if len(visited) >= MAX_GRAPH_NODES:
            output.append(f"\n[WARNING] Graph limited to {MAX_GRAPH_NODES} nodes\n")
        output.append("Nodes:")
        # Render closest nodes first, indented by their BFS depth.
        for node_id, info in sorted(nodes.items(), key=lambda x: x[1]["depth"]):
            indent = "  " * info["depth"]
            output.append(f"{indent}[{info['tag']}] {node_id}")
            output.append(f"{indent}  {info['preview']}")
        output.append(f"\nEdges ({len(edges)}):")
        for edge in edges:
            output.append(f"  {edge['from'][:8]}... --[{edge['type']}]--> {edge['to'][:8]}...")
        logger.info(f"Graph traversed: {len(nodes)} nodes, {len(edges)} edges")
        return "\n".join(output)
    except Exception as e:
        logger.error(f"rag_graph failed: {str(e)}", exc_info=True)
        return f"Error: {str(e)}"