"""Relation management tools and utilities."""

from typing import Optional, Tuple
from collections import deque

from clients import get_index
from utils.logging import get_logger
from utils.validation import validate_document_id
from config import MAX_GRAPH_NODES

logger = get_logger(__name__)

# Upper bound on the traversal depth accepted by rag_graph.
_MAX_GRAPH_DEPTH = 3

# Forward relation type -> type stored on the target document.
# Symmetric types ('see_also', 'related') map to themselves.
_REVERSE_RELATIONS = {
    'depends_on': 'required_by',
    'required_by': 'depends_on',
    'part_of': 'contains',
    'contains': 'part_of',
    'see_also': 'see_also',
    'related': 'related',
    'blocks': 'blocked_by',
    'blocked_by': 'blocks',
    'extends': 'extended_by',
    'extended_by': 'extends',
    'updates': 'updated_by',
    'updated_by': 'updates',
}


# ============================================================
# Relation utility functions
# ============================================================

def parse_relation(rel_str: str) -> Tuple[str, str]:
    """
    Parse 'id:type' format to (id, type).

    The string is split on the first ':' only. A string without any ':'
    is treated as a bare document ID with the default type 'related'.

    Args:
        rel_str: Relation string in format "id:type"

    Returns:
        (id, type) tuple
    """
    doc_id, sep, rel_type = rel_str.partition(':')
    if not sep:
        return (rel_str, 'related')
    return (doc_id, rel_type)


def format_relation(doc_id: str, rel_type: str) -> str:
    """
    Format (id, type) to 'id:type' string.

    Args:
        doc_id: Document ID
        rel_type: Relation type

    Returns:
        Formatted relation string
    """
    return f"{doc_id}:{rel_type}"


def get_reverse_relation(rel_type: str) -> str:
    """
    Get reverse relation type.

    Args:
        rel_type: Relation type

    Returns:
        Reverse relation type; unknown types fall back to
        "reverse_<rel_type>".
    """
    return _REVERSE_RELATIONS.get(rel_type, f"reverse_{rel_type}")


def _targets_document(rel_str: str, doc_id: str) -> bool:
    """Internal: True when a stored relation entry points at doc_id.

    Accepts both the 'id:type' form and the bare 'id' form that
    parse_relation also understands.
    """
    return rel_str == doc_id or rel_str.startswith(f"{doc_id}:")


def _has_relation_type(rel_str: str, rel_type: str) -> bool:
    """Internal: True when a stored relation entry is of type rel_type.

    A bare 'id' entry counts as 'related', matching parse_relation.
    """
    if ':' not in rel_str:
        return rel_type == 'related'
    return rel_str.endswith(f":{rel_type}")


def _store_relations(index, doc_id: str, record, relations: list) -> None:
    """Internal: Write relations into a fetched record and upsert it.

    The record's existing vector values and remaining metadata are sent
    back unchanged; only the "relations" metadata key is replaced.
    """
    metadata = record["metadata"]
    metadata["relations"] = relations
    index.upsert(vectors=[{
        "id": doc_id,
        "values": record["values"],
        "metadata": metadata
    }])


def _add_reverse_relation(target_id: str, source_id: str, rel_type: str) -> bool:
    """
    Internal: Add reverse relation to target document.

    Args:
        target_id: Target document ID
        source_id: Source document ID
        rel_type: Relation type (forward)

    Returns:
        True if the reverse relation is present afterwards (added now or
        already there), False if the target is missing or an error occurred
    """
    try:
        index = get_index()
        result = index.fetch(ids=[target_id])

        if target_id not in result["vectors"]:
            logger.warning(f"Target document not found for reverse relation: {target_id}")
            return False

        record = result["vectors"][target_id]
        relations = record["metadata"].get("relations", [])
        reverse_rel = format_relation(source_id, get_reverse_relation(rel_type))

        # Already linked: nothing to write.
        if reverse_rel in relations:
            return True

        relations.append(reverse_rel)
        _store_relations(index, target_id, record, relations)
        logger.debug(f"Added reverse relation: {target_id} <- {source_id}")
        return True

    except Exception as e:
        logger.error(f"Failed to add reverse relation: {e}")
        return False


def _remove_reverse_relation(target_id: str, source_id: str) -> bool:
    """
    Internal: Remove reverse relation from target document.

    Every relation entry on the target that points at source_id is
    removed, regardless of its type.

    Args:
        target_id: Target document ID
        source_id: Source document ID

    Returns:
        True if no relation to source_id remains on the target, False if
        the target is missing or an error occurred
    """
    try:
        index = get_index()
        result = index.fetch(ids=[target_id])

        if target_id not in result["vectors"]:
            logger.warning(f"Target document not found for reverse relation removal: {target_id}")
            return False

        record = result["vectors"][target_id]
        relations = record["metadata"].get("relations", [])
        remaining = [r for r in relations if not _targets_document(r, source_id)]

        # Nothing pointed back at the source: skip the needless write.
        if len(remaining) == len(relations):
            return True

        _store_relations(index, target_id, record, remaining)
        logger.debug(f"Removed reverse relation: {target_id} <- {source_id}")
        return True

    except Exception as e:
        logger.error(f"Failed to remove reverse relation: {e}")
        return False


# ============================================================
# MCP Tools
# ============================================================

def rag_link(from_id: str, to_id: str, relation_type: str = "related") -> str:
    """
    Create relation between two documents (bidirectional).

    The forward relation is written to the source document first, then the
    reverse relation is added to the target. If the reverse side cannot be
    written, the forward relation is kept and the returned message says so.

    Args:
        from_id: Source document ID
        to_id: Target document ID
        relation_type: Relation type (depends_on, part_of, see_also, blocks,
            extends, updates, related)

    Returns:
        Success message or error message
    """
    # Validate inputs
    is_valid, error_msg = validate_document_id(from_id)
    if not is_valid:
        logger.warning(f"Source ID validation failed: {error_msg}")
        return f"Error: {error_msg}"

    is_valid, error_msg = validate_document_id(to_id)
    if not is_valid:
        logger.warning(f"Target ID validation failed: {error_msg}")
        return f"Error: {error_msg}"

    try:
        index = get_index()
        result = index.fetch(ids=[from_id])

        if from_id not in result["vectors"]:
            logger.warning(f"Source document not found: {from_id}")
            return f"Error: Source document not found: {from_id}"

        record = result["vectors"][from_id]
        from_relations = record["metadata"].get("relations", [])
        new_rel = format_relation(to_id, relation_type)

        if new_rel in from_relations:
            logger.info(f"Relation already exists: {from_id} -> {to_id}")
            return f"Relation already exists: {from_id} --[{relation_type}]--> {to_id}"

        from_relations.append(new_rel)
        _store_relations(index, from_id, record, from_relations)

        reverse_ok = _add_reverse_relation(to_id, from_id, relation_type)
        logger.info(f"Linked: {from_id} --[{relation_type}]--> {to_id}")

        if not reverse_ok:
            # Do not claim a reverse edge that was never written.
            logger.warning(f"Reverse relation not added: {to_id} <- {from_id}")
            return (
                f"Linked: {from_id} --[{relation_type}]--> {to_id}\n"
                f"Warning: reverse relation was not added to {to_id} (see logs)"
            )

        reverse_type = get_reverse_relation(relation_type)
        return (
            f"Linked: {from_id} --[{relation_type}]--> {to_id}\n"
            f"Reverse: {to_id} --[{reverse_type}]--> {from_id}"
        )

    except Exception as e:
        logger.error(f"rag_link failed: {e}", exc_info=True)
        return f"Error: {e}"


def rag_unlink(from_id: str, to_id: str) -> str:
    """
    Remove relation between two documents (bidirectional).

    All relations from from_id to to_id are removed, whatever their type,
    followed by all relations from to_id back to from_id. If the reverse
    side cannot be cleaned up, the returned message says so.

    Args:
        from_id: Source document ID
        to_id: Target document ID

    Returns:
        Success message or error message
    """
    # Validate inputs
    is_valid, error_msg = validate_document_id(from_id)
    if not is_valid:
        logger.warning(f"Source ID validation failed: {error_msg}")
        return f"Error: {error_msg}"

    is_valid, error_msg = validate_document_id(to_id)
    if not is_valid:
        logger.warning(f"Target ID validation failed: {error_msg}")
        return f"Error: {error_msg}"

    try:
        index = get_index()
        result = index.fetch(ids=[from_id])

        if from_id not in result["vectors"]:
            logger.warning(f"Document not found: {from_id}")
            return f"Error: Document not found: {from_id}"

        record = result["vectors"][from_id]
        from_relations = record["metadata"].get("relations", [])
        remaining = [r for r in from_relations if not _targets_document(r, to_id)]

        if len(remaining) == len(from_relations):
            logger.info(f"No relation found: {from_id} -> {to_id}")
            return f"No relation found from {from_id} to {to_id}"

        _store_relations(index, from_id, record, remaining)

        reverse_ok = _remove_reverse_relation(to_id, from_id)

        if not reverse_ok:
            # Only the forward side was removed; say so instead of "<-->".
            logger.warning(f"Reverse relation not removed: {to_id} <- {from_id}")
            return (
                f"Unlinked: {from_id} --> {to_id}\n"
                f"Warning: reverse relation was not removed from {to_id} (see logs)"
            )

        logger.info(f"Unlinked: {from_id} <--> {to_id}")
        return f"Unlinked: {from_id} <--> {to_id}"

    except Exception as e:
        logger.error(f"rag_unlink failed: {e}", exc_info=True)
        return f"Error: {e}"


def rag_related(id: str, relation_type: Optional[str] = None, include_content: bool = False) -> str:
    """
    Query related documents.

    Args:
        id: Document ID to query
        relation_type: Filter by specific relation type (default: all)
        include_content: Include each related document's tag and a preview
            of up to 200 characters of its text (default: False)

    Returns:
        Formatted relation list or error message
    """
    # Validate input
    is_valid, error_msg = validate_document_id(id)
    if not is_valid:
        logger.warning(f"Document ID validation failed: {error_msg}")
        return f"Error: {error_msg}"

    try:
        index = get_index()
        result = index.fetch(ids=[id])

        if id not in result["vectors"]:
            logger.warning(f"Document not found: {id}")
            return f"Error: Document not found: {id}"

        relations = result["vectors"][id]["metadata"].get("relations", [])

        if not relations:
            logger.info(f"No relations found for: {id}")
            return f"No relations found for: {id}"

        if relation_type:
            relations = [r for r in relations if _has_relation_type(r, relation_type)]
            if not relations:
                logger.info(f"No '{relation_type}' relations found for: {id}")
                return f"No '{relation_type}' relations found for: {id}"

        output = [f"Relations for {id}:"]

        if include_content:
            # One batched fetch for all related documents.
            related_ids = [parse_relation(r)[0] for r in relations]
            related_docs = index.fetch(ids=related_ids)

            for rel_str in relations:
                target_id, rel_type = parse_relation(rel_str)
                entry = f"\n --[{rel_type}]--> {target_id}"

                # Targets missing from the index are listed without content.
                if target_id in related_docs["vectors"]:
                    target_meta = related_docs["vectors"][target_id]["metadata"]
                    text = target_meta.get("text", "")
                    tag = target_meta.get("tag", "")
                    preview = text[:200] + "..." if len(text) > 200 else text
                    entry += f"\n Tag: {tag}"
                    entry += f"\n {preview}"

                output.append(entry)
        else:
            for rel_str in relations:
                target_id, rel_type = parse_relation(rel_str)
                output.append(f" --[{rel_type}]--> {target_id}")

        logger.info(f"Found {len(relations)} relations for: {id}")
        return "\n".join(output)

    except Exception as e:
        logger.error(f"rag_related failed: {e}", exc_info=True)
        return f"Error: {e}"


def rag_graph(id: str, depth: int = 1) -> str:
    """
    Explore relation graph from a document using BFS.

    Traversal stops once MAX_GRAPH_NODES documents have been visited.
    Edges are recorded for every relation of every visited document, even
    when the edge's target lies beyond the depth limit.

    Args:
        id: Starting document ID
        depth: Search depth (default: 1, max: 3); negative values are
            treated as 0, i.e. only the starting document is shown

    Returns:
        Formatted graph or error message
    """
    # Validate input
    is_valid, error_msg = validate_document_id(id)
    if not is_valid:
        logger.warning(f"Document ID validation failed: {error_msg}")
        return f"Error: {error_msg}"

    try:
        index = get_index()
        # Clamp both ends: a negative depth would otherwise skip the start
        # node and be reported as "Document not found".
        depth = max(0, min(depth, _MAX_GRAPH_DEPTH))

        visited = set()
        nodes = {}
        edges = []
        queue = deque([(id, 0)])  # BFS queue with (doc_id, current_depth)

        while queue and len(visited) < MAX_GRAPH_NODES:
            doc_id, current_depth = queue.popleft()

            if doc_id in visited:
                continue
            visited.add(doc_id)

            result = index.fetch(ids=[doc_id])
            if doc_id not in result["vectors"]:
                # Dangling relation: counted as visited but not shown as a node.
                continue

            metadata = result["vectors"][doc_id]["metadata"]
            text = metadata.get("text", "")
            tag = metadata.get("tag", "")
            relations = metadata.get("relations", [])

            nodes[doc_id] = {
                "tag": tag,
                "preview": text[:100] + "..." if len(text) > 100 else text,
                "depth": current_depth
            }

            for rel_str in relations:
                target_id, rel_type = parse_relation(rel_str)
                edges.append({
                    "from": doc_id,
                    "to": target_id,
                    "type": rel_type
                })

                # Only enqueue targets that are still within the depth limit.
                if (
                    current_depth < depth
                    and target_id not in visited
                    and len(visited) < MAX_GRAPH_NODES
                ):
                    queue.append((target_id, current_depth + 1))

        if not nodes:
            logger.warning(f"Document not found: {id}")
            return f"Error: Document not found: {id}"

        output = [f"=== Graph from {id} (depth: {depth}) ==="]

        if len(visited) >= MAX_GRAPH_NODES:
            output.append(f"\n[WARNING] Graph limited to {MAX_GRAPH_NODES} nodes\n")

        output.append("Nodes:")
        for node_id, info in sorted(nodes.items(), key=lambda x: x[1]["depth"]):
            indent = " " * info["depth"]
            output.append(f"{indent}[{info['tag']}] {node_id}")
            output.append(f"{indent} {info['preview']}")

        output.append(f"\nEdges ({len(edges)}):")
        for edge in edges:
            # IDs are shortened to their first 8 characters for readability.
            output.append(f" {edge['from'][:8]}... --[{edge['type']}]--> {edge['to'][:8]}...")

        logger.info(f"Graph traversed: {len(nodes)} nodes, {len(edges)} edges")
        return "\n".join(output)

    except Exception as e:
        logger.error(f"rag_graph failed: {e}", exc_info=True)
        return f"Error: {e}"