From 2858e0a344c0a5ef8a7a26094cf55531cb3ac46b Mon Sep 17 00:00:00 2001
From: kappa
Date: Tue, 3 Feb 2026 11:05:45 +0900
Subject: [PATCH] Initial commit: RAG MCP Server with relationship graph

Features:
- Vector search with Pinecone + Vertex AI embeddings
- Document relationships (link, unlink, related, graph)
- Auto-link with LLM analysis
- Intelligent merge with Gemini

Modular structure:
- clients/: Pinecone, Vertex AI
- tools/: core, relations, stats
- utils/: validation, logging

Co-Authored-By: Claude Opus 4.5
---
 .env.example        |  19 ++
 .gitignore          |  43 +++++
 README.md           |  68 ++++++++
 clients/__init__.py |  10 ++
 clients/pinecone.py |  27 +++
 clients/vertex.py   | 237 +++++++++++++++++++++++++
 config.py           |  38 ++++
 deploy.sh           |  33 ++++
 pyproject.toml      |  26 +++
 server.py           |  43 +++++
 tools/__init__.py   |  16 ++
 tools/core.py       | 304 ++++++++++++++++++++++++++++++++
 tools/relations.py  | 417 ++++++++++++++++++++++++++++++++++++++++++++
 tools/stats.py      |  37 ++++
 utils/__init__.py   |  11 ++
 utils/logging.py    |  30 ++++
 utils/validation.py |  91 ++++++++++
 17 files changed, 1450 insertions(+)
 create mode 100644 .env.example
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 clients/__init__.py
 create mode 100644 clients/pinecone.py
 create mode 100644 clients/vertex.py
 create mode 100644 config.py
 create mode 100755 deploy.sh
 create mode 100644 pyproject.toml
 create mode 100644 server.py
 create mode 100644 tools/__init__.py
 create mode 100644 tools/core.py
 create mode 100644 tools/relations.py
 create mode 100644 tools/stats.py
 create mode 100644 utils/__init__.py
 create mode 100644 utils/logging.py
 create mode 100644 utils/validation.py

diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..8cf28f8
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,19 @@
+# RAG MCP Server Configuration
+# API keys stored in Vault (api/*)
+
+# Vault: api/pinecone
+PINECONE_API_KEY=your-pinecone-api-key
+
+# Vault: api/vertex
+VERTEX_API_KEY=your-vertex-api-key
+GOOGLE_CLOUD_PROJECT=your-gcp-project-id
+GOOGLE_CLOUD_LOCATION=us-central1
+
+# Service Configuration
+PINECONE_INDEX_NAME=memory-index
+FASTMCP_HOST=0.0.0.0
+FASTMCP_PORT=8000
+
+# Auto-link settings
+AUTO_LINK_THRESHOLD=0.75
+AUTO_LINK_TOP_K=5
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..412dca2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,43 @@
+# Environment
+.env
+.env.local
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+.venv/
+venv/
+ENV/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Backup files
+*.bak
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..15a0eb1
--- /dev/null
+++ b/README.md
@@ -0,0 +1,68 @@
+# RAG MCP Server
+
+An MCP server providing a vector-DB-backed long-term memory system with relationship graph support
+
+## Features
+
+### Core tools
+- `rag_save` - Save information (auto_link can create relations automatically)
+- `rag_retrieve` - Vector similarity search
+- `rag_update` - LLM-based intelligent merge (new information wins on conflict)
+- `rag_delete` - Delete a document (its relations are cleaned up automatically)
+
+### Relationship graph
+- `rag_link` - Create a relation between two documents (bidirectional)
+- `rag_unlink` - Remove a relation
+- `rag_related` - List documents related to a given document
+- `rag_graph` - Traverse the relation graph (configurable depth)
+- `rag_stats` - Show overall statistics
+
+### Relation types
+| Relation | Reverse | Purpose |
+|------|--------|------|
+| `depends_on` | `required_by` | Dependency |
+| `part_of` | `contains` | Containment |
+| `updates` | `updated_by` | Update |
+| `see_also` | `see_also` | Reference (symmetric) |
+| `extends` | `extended_by` | Extension |
+| `contradicts` | `contradicts` | Contradiction (symmetric) |
+| `related` | `related` | Generic relation (symmetric) |
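+
+### Usage sketch
+
+A minimal sketch of the save/link/graph flow, calling the tool functions directly in Python rather than over MCP. The document IDs below are placeholders; in practice you would parse the UUID out of the string that `rag_save` returns:
+
+```python
+from tools import rag_save, rag_link, rag_graph
+
+# Save a document; the returned string contains the generated UUID
+print(rag_save("The search service exposes a /v1/search endpoint", tag="api-docs"))
+
+# Save a related document and let the LLM auto-create relations
+print(rag_save("Client library for the /v1/search endpoint",
+               tag="api-docs", auto_link=True))
+
+# Or link two known documents explicitly (placeholder IDs)
+print(rag_link("11111111-1111-4111-8111-111111111111",
+               "22222222-2222-4222-8222-222222222222",
+               relation_type="depends_on"))
+
+# Traverse the relation graph up to two hops from a document
+print(rag_graph("11111111-1111-4111-8111-111111111111", depth=2))
+```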
+
+## Stack
+- **FastMCP** - MCP server framework
+- **Pinecone** - Vector database
+- **Vertex AI** - Embeddings (text-embedding-004) + LLM (Gemini)
+
+## Setup
+
+```bash
+cp .env.example .env
+# Set your API keys in the .env file
+```
+
+### Environment variables
+| Variable | Description | Default |
+|------|------|--------|
+| `VERTEX_API_KEY` | Google Vertex AI API key | (required) |
+| `PINECONE_API_KEY` | Pinecone API key | (required) |
+| `PINECONE_INDEX_NAME` | Pinecone index name | `memory-index` |
+| `GOOGLE_CLOUD_PROJECT` | GCP project ID | - |
+| `GOOGLE_CLOUD_LOCATION` | GCP region | `us-central1` |
+| `AUTO_LINK_THRESHOLD` | Similarity threshold for auto-created relations | `0.75` |
+| `AUTO_LINK_TOP_K` | Number of candidate documents fetched for auto-link analysis | `5` |
+| `FASTMCP_HOST` | Server host | `0.0.0.0` |
+| `FASTMCP_PORT` | Server port | `8000` |
+
+## Running
+
+```bash
+# Install dependencies
+pip install -e .
+
+# Run the server
+python server.py
+```
+
+## Deployment
+- Container: `jp1:rag-mcp`
+- IP: `10.253.100.107:8000`
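+
+### Connecting a client
+
+A minimal sketch of calling a tool over HTTP from Python. It assumes FastMCP's bundled `Client` class and the library's default `/mcp` endpoint path; adjust both to match your FastMCP version:
+
+```python
+import asyncio
+from fastmcp import Client
+
+async def main():
+    # Connect to the HTTP transport started by server.py
+    async with Client("http://localhost:8000/mcp") as client:
+        stats = await client.call_tool("rag_stats", {})
+        print(stats)
+
+asyncio.run(main())
+```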
diff --git a/clients/__init__.py b/clients/__init__.py
new file mode 100644
index 0000000..3279b28
--- /dev/null
+++ b/clients/__init__.py
@@ -0,0 +1,10 @@
+"""Client modules for external services."""
+from .pinecone import get_index
+from .vertex import get_embedding, merge_with_llm, analyze_relations_with_llm
+
+__all__ = [
+    "get_index",
+    "get_embedding",
+    "merge_with_llm",
+    "analyze_relations_with_llm"
+]
diff --git a/clients/pinecone.py b/clients/pinecone.py
new file mode 100644
index 0000000..083c15e
--- /dev/null
+++ b/clients/pinecone.py
@@ -0,0 +1,27 @@
+"""Pinecone client singleton."""
+from pinecone import Pinecone
+from config import PINECONE_API_KEY, PINECONE_INDEX_NAME
+from utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Pinecone singleton
+_pc_client = None
+_index = None
+
+def get_index():
+    """
+    Get Pinecone index instance (singleton pattern).
+
+    Returns:
+        Pinecone index instance
+    """
+    global _pc_client, _index
+
+    if _index is None:
+        logger.info("Initializing Pinecone client")
+        _pc_client = Pinecone(api_key=PINECONE_API_KEY)
+        _index = _pc_client.Index(PINECONE_INDEX_NAME)
+        logger.info(f"Connected to Pinecone index: {PINECONE_INDEX_NAME}")
+
+    return _index
diff --git a/clients/vertex.py b/clients/vertex.py
new file mode 100644
index 0000000..732a011
--- /dev/null
+++ b/clients/vertex.py
@@ -0,0 +1,237 @@
+"""Vertex AI API client with HTTP session pooling."""
+import json
+import requests
+from typing import List
+from config import (
+    PROJECT_ID,
+    LOCATION,
+    VERTEX_API_KEY,
+    REQUEST_TIMEOUT
+)
+from utils.logging import get_logger
+from utils.validation import sanitize_for_prompt
+
+logger = get_logger(__name__)
+
+# HTTP session for connection pooling
+_session = None
+
+def _get_session() -> requests.Session:
+    """Get or create HTTP session for connection pooling."""
+    global _session
+    if _session is None:
+        _session = requests.Session()
+        _session.headers.update({
+            "Content-Type": "application/json"
+        })
+    return _session
+
+def get_embedding(text: str) -> List[float]:
+    """
+    Get text embedding from Vertex AI.
+
+    Args:
+        text: Text to embed
+
+    Returns:
+        Embedding vector
+
+    Raises:
+        Exception: If API call fails
+    """
+    url = (
+        f"https://{LOCATION}-aiplatform.googleapis.com/v1/"
+        f"projects/{PROJECT_ID}/locations/{LOCATION}/"
+        f"publishers/google/models/text-embedding-004:predict"
+    )
+
+    try:
+        session = _get_session()
+        response = session.post(
+            url,
+            params={"key": VERTEX_API_KEY},
+            json={"instances": [{"content": text}]},
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            logger.error(f"Vertex AI embedding error: {response.status_code} - {response.text}")
+            raise Exception(f"Vertex AI API Error: {response.text}")
+
+        result = response.json()
+        return result["predictions"][0]["embeddings"]["values"]
+
+    except requests.exceptions.Timeout:
+        logger.error("Vertex AI embedding request timeout")
+        raise Exception("Vertex AI request timeout")
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Vertex AI embedding request failed: {str(e)}")
+        raise Exception(f"Vertex AI request failed: {str(e)}")
+
+def merge_with_llm(old: str, new: str) -> str:
+    """
+    Merge two texts intelligently using Vertex AI Gemini.
+
+    Args:
+        old: Existing text
+        new: New text to merge
+
+    Returns:
+        Merged text
+
+    Raises:
+        Exception: If API call fails
+    """
+    url = (
+        f"https://{LOCATION}-aiplatform.googleapis.com/v1/"
+        f"projects/{PROJECT_ID}/locations/{LOCATION}/"
+        f"publishers/google/models/gemini-2.0-flash:generateContent"
+    )
+
+    # Sanitize inputs
+    old_sanitized = sanitize_for_prompt(old)
+    new_sanitized = sanitize_for_prompt(new)
+
+    prompt = f"""Merge the existing information with the new information.
+Rules:
+1. Where they conflict, the new information takes precedence.
+2. Remove duplication and keep the result concise.
+3. Integrate complementary details naturally.
+4. Output only the result, with no explanation.
+
+[Existing information]
+{old_sanitized}
+
+[New information]
+{new_sanitized}
+
+[Merged result]"""
+
+    try:
+        session = _get_session()
+        response = session.post(
+            url,
+            params={"key": VERTEX_API_KEY},
+            json={"contents": [{"role": "user", "parts": [{"text": prompt}]}]},
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            logger.error(f"Gemini merge error: {response.status_code} - {response.text}")
+            raise Exception(f"Gemini API Error: {response.text}")
+
+        result = response.json()
+        merged = result["candidates"][0]["content"]["parts"][0]["text"].strip()
+        logger.info(f"Successfully merged texts (old: {len(old)} chars, new: {len(new)} chars)")
+        return merged
+
+    except requests.exceptions.Timeout:
+        logger.error("Gemini merge request timeout")
+        raise Exception("Gemini request timeout")
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Gemini merge request failed: {str(e)}")
+        raise Exception(f"Gemini request failed: {str(e)}")
+
+def analyze_relations_with_llm(new_doc: str, new_tag: str, similar_docs: list) -> list:
+    """
+    Analyze relations between new document and existing documents using Gemini.
+
+    Args:
+        new_doc: New document content
+        new_tag: New document tag
+        similar_docs: List of similar documents with id, text, tag, score
+
+    Returns:
+        List of relations: [{"id": doc_id, "relation": relation_type}, ...]
+    """
+    if not similar_docs:
+        return []
+
+    url = (
+        f"https://{LOCATION}-aiplatform.googleapis.com/v1/"
+        f"projects/{PROJECT_ID}/locations/{LOCATION}/"
+        f"publishers/google/models/gemini-2.0-flash:generateContent"
+    )
+
+    # Build document summary
+    docs_summary = []
+    for i, doc in enumerate(similar_docs):
+        preview = doc["text"][:300] + "..." if len(doc["text"]) > 300 else doc["text"]
+        preview_sanitized = sanitize_for_prompt(preview)
+        docs_summary.append(
+            f"[{i+1}] ID: {doc['id']}, Tag: {doc['tag']}\nContent: {preview_sanitized}"
+        )
+
+    new_doc_sanitized = sanitize_for_prompt(new_doc[:500])
+
+    prompt = f"""Analyze the relations between the new document and the existing documents.
+
+[New document]
+Tag: {new_tag}
+Content: {new_doc_sanitized}
+
+[Existing documents]
+{chr(10).join(docs_summary)}
+
+[Relation types]
+- depends_on: the new document depends on the existing one (uses its API, references the library, etc.)
+- part_of: the new document is part of the existing one (same project, sub-feature, etc.)
+- extends: the new document extends the existing one (added features, version bump, etc.)
+- see_also: related reference material (similar topic, worth consulting)
+- updates: the new document is an updated/revised version of the existing one
+- none: no relation (even if similarity is high, there is no actual relationship)
+
+[Output format]
+Output a JSON array only, with no explanation.
+Include only documents that have a relation.
+
+Example:
+[{{"id": "document-id", "relation": "depends_on"}}, {{"id": "document-id", "relation": "see_also"}}]
+
+If there are no relations, output an empty array:
+[]
+
+[Analysis result]"""
+
+    try:
+        session = _get_session()
+        response = session.post(
+            url,
+            params={"key": VERTEX_API_KEY},
+            json={"contents": [{"role": "user", "parts": [{"text": prompt}]}]},
+            timeout=REQUEST_TIMEOUT
+        )
+
+        if response.status_code != 200:
+            logger.warning(f"Gemini relation analysis error: {response.status_code}")
+            return []
+
+        result_text = response.json()["candidates"][0]["content"]["parts"][0]["text"].strip()
+
+        # Remove JSON code block markers
+        if result_text.startswith("```"):
+            lines = result_text.split("\n")
+            result_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
+
+        relations = json.loads(result_text)
+
+        # Filter valid relations
+        valid_relations = []
+        valid_types = {"depends_on", "part_of", "extends", "see_also", "updates", "related"}
+        for rel in relations:
+            if isinstance(rel, dict) and "id" in rel and "relation" in rel:
+                if rel["relation"] in valid_types:
+                    valid_relations.append(rel)
+
+        logger.info(f"Analyzed relations: found {len(valid_relations)} valid relations")
+        return valid_relations
+
+    except (json.JSONDecodeError, KeyError, IndexError) as e:
+        logger.warning(f"Failed to parse relation analysis: {str(e)}")
+        return []
+    except requests.exceptions.Timeout:
+        logger.warning("Gemini relation analysis timeout")
+        return []
+    except requests.exceptions.RequestException as e:
+        logger.warning(f"Gemini relation analysis request failed: {str(e)}")
+        return []
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..10512a6
--- /dev/null
+++ b/config.py
@@ -0,0 +1,38 @@
+"""
+Configuration module for RAG system.
+Handles environment variable loading, validation, and constants.
+"""
+import os
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Constants
+MAX_CONTENT_LENGTH = 50000
+MAX_TAG_LENGTH = 100
+MAX_GRAPH_NODES = 100
+REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "30"))
+AUTO_LINK_THRESHOLD = float(os.getenv("AUTO_LINK_THRESHOLD", "0.75"))
+AUTO_LINK_TOP_K = int(os.getenv("AUTO_LINK_TOP_K", "5"))
+
+# Environment variables
+PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT", "hypnotic-tenure-swz46")
+LOCATION = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
+VERTEX_API_KEY = os.getenv("VERTEX_API_KEY")
+PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
+PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "memory-index")
+
+# Server settings
+FASTMCP_HOST = os.getenv("FASTMCP_HOST", "0.0.0.0")
+FASTMCP_PORT = int(os.getenv("FASTMCP_PORT", "8000"))
+
+def validate_env() -> None:
+    """Validate required environment variables are present."""
+    required = ["VERTEX_API_KEY", "PINECONE_API_KEY", "PINECONE_INDEX_NAME"]
+    missing = [k for k in required if not os.getenv(k)]
+    if missing:
+        raise RuntimeError(f"Missing required environment variables: {missing}")
+
+# Validate on import
+validate_env()
diff --git a/deploy.sh b/deploy.sh
new file mode 100755
index 0000000..e4b01f2
--- /dev/null
+++ b/deploy.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# RAG MCP Server Deploy Script
+
+HOST="root@10.253.100.107"
+REMOTE_DIR="/root/rag-mcp"
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
+echo "📦 Copying project files to rag-mcp..."
+
+# Create directories
+ssh ${HOST} "mkdir -p ${REMOTE_DIR}/{clients,tools,utils}"
+
+# Copy files
+scp "${SCRIPT_DIR}/server.py" "${SCRIPT_DIR}/config.py" "${SCRIPT_DIR}/pyproject.toml" ${HOST}:${REMOTE_DIR}/
+scp "${SCRIPT_DIR}/clients/"*.py ${HOST}:${REMOTE_DIR}/clients/
+scp "${SCRIPT_DIR}/tools/"*.py ${HOST}:${REMOTE_DIR}/tools/
+scp "${SCRIPT_DIR}/utils/"*.py ${HOST}:${REMOTE_DIR}/utils/
+
+echo "🔄 Restarting service..."
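+# Stop any previous server instance, then relaunch it detached from this
+# SSH session (setsid), with stdout/stderr redirected to /tmp/rag-mcp.log.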
+ssh ${HOST} << 'EOF'
+cd /root/rag-mcp
+PID=$(pgrep -f "uv run.*server.py" | head -1)
+if [ -n "$PID" ]; then
+    kill $PID 2>/dev/null
+    sleep 2
+fi
+setsid uv run --with fastmcp --with pinecone --with python-dotenv --with requests python server.py > /tmp/rag-mcp.log 2>&1 &
+sleep 3
+ps aux | grep "server.py" | grep -v grep && echo "✅ Service started" || echo "⚠️ Service not running"
+echo ""
+echo "📋 Recent logs:"
+tail -15 /tmp/rag-mcp.log 2>/dev/null || echo "No logs yet"
+EOF
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..df3820b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,26 @@
+[project]
+name = "rag-mcp"
+version = "1.0.0"
+description = "RAG Memory MCP Server with relationship graph support"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "fastmcp>=0.1.0",
+    "pinecone>=5.0.0",
+    "python-dotenv>=1.0.0",
+    "requests>=2.31.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=8.0.0",
+    "ruff>=0.1.0",
+]
+
+[tool.ruff]
+line-length = 100
+target-version = "py311"
+
+[tool.ruff.lint]
+select = ["E", "F", "I", "W"]
+ignore = ["E501"]
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..767dffa
--- /dev/null
+++ b/server.py
@@ -0,0 +1,43 @@
+"""
+RAG System MCP Server Entry Point
+
+Refactored modular structure with:
+- config.py: Environment variables and constants
+- clients/: Pinecone and Vertex AI clients
+- tools/: MCP tool implementations
+- utils/: Validation and logging
+"""
+from fastmcp import FastMCP
+from config import FASTMCP_HOST, FASTMCP_PORT
+from utils import setup_logging
+from tools import (
+    rag_save,
+    rag_retrieve,
+    rag_update,
+    rag_delete,
+    rag_link,
+    rag_unlink,
+    rag_related,
+    rag_graph,
+    rag_stats
+)
+
+# Setup logging
+setup_logging()
+
+# Initialize FastMCP server
+mcp = FastMCP("RAG")
+
+# Register tools
+mcp.tool()(rag_save)
+mcp.tool()(rag_retrieve)
+mcp.tool()(rag_update)
+mcp.tool()(rag_delete)
+mcp.tool()(rag_link)
+mcp.tool()(rag_unlink)
+mcp.tool()(rag_related)
+mcp.tool()(rag_graph)
+mcp.tool()(rag_stats)
+
+if __name__ == "__main__":
+    mcp.run(transport="http", host=FASTMCP_HOST, port=FASTMCP_PORT)
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..8b114c1
--- /dev/null
+++ b/tools/__init__.py
@@ -0,0 +1,16 @@
+"""MCP tools for RAG operations."""
+from .core import rag_save, rag_retrieve, rag_update, rag_delete
+from .relations import rag_link, rag_unlink, rag_related, rag_graph
+from .stats import rag_stats
+
+__all__ = [
+    "rag_save",
+    "rag_retrieve",
+    "rag_update",
+    "rag_delete",
+    "rag_link",
+    "rag_unlink",
+    "rag_related",
+    "rag_graph",
+    "rag_stats"
+]
diff --git a/tools/core.py b/tools/core.py
new file mode 100644
index 0000000..b8553ae
--- /dev/null
+++ b/tools/core.py
@@ -0,0 +1,304 @@
+"""Core RAG tools: save, retrieve, update, delete."""
+import threading
+import time
+import uuid
+from typing import Optional
+from clients import get_index, get_embedding, merge_with_llm, analyze_relations_with_llm
+from utils.logging import get_logger
+from utils.validation import validate_content, validate_tag, validate_document_id
+from config import AUTO_LINK_THRESHOLD, AUTO_LINK_TOP_K
+from .relations import (
+    parse_relation,
+    format_relation,
+    _add_reverse_relation,
+    _remove_reverse_relation
+)
+
+logger = get_logger(__name__)
+
+# Rate limiting (in-memory counter)
+_rate_limiter = {
+    "requests": [],
+    "lock": threading.Lock()
+}
+MAX_REQUESTS_PER_MINUTE = 60
+
+def rate_limit_check() -> bool:
+    """
+    Check if rate limit is exceeded.
+
+    Returns:
+        True if allowed, False if rate limit exceeded
+    """
+    with _rate_limiter["lock"]:
+        now = time.time()
+        # Keep a sliding one-minute window: drop requests older than 60s
+        _rate_limiter["requests"] = [
+            req_time for req_time in _rate_limiter["requests"]
+            if now - req_time < 60
+        ]
+
+        if len(_rate_limiter["requests"]) >= MAX_REQUESTS_PER_MINUTE:
+            return False
+
+        _rate_limiter["requests"].append(now)
+        return True
+
+def rag_save(content: str, tag: Optional[str] = "general", relations: Optional[str] = None, auto_link: bool = False) -> str:
+    """
+    Save important information to vector database.
+
+    Args:
+        content: Text content to save
+        tag: Tag for filtering (default: general)
+        relations: Comma-separated relations (e.g., "id1:depends_on,id2:see_also")
+        auto_link: Auto-create relations using LLM analysis
+
+    Returns:
+        Success message with document ID or error message
+    """
+    if not rate_limit_check():
+        logger.warning("Rate limit exceeded for rag_save")
+        return "Error: Rate limit exceeded. Please wait before retrying."
+
+    # Validate inputs
+    is_valid, error_msg = validate_content(content)
+    if not is_valid:
+        logger.warning(f"Content validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    is_valid, error_msg = validate_tag(tag)
+    if not is_valid:
+        logger.warning(f"Tag validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+        vector = get_embedding(content)
+        doc_id = str(uuid.uuid4())
+
+        # Parse manual relations
+        rel_list = []
+        if relations:
+            rel_list = [r.strip() for r in relations.split(',') if r.strip()]
+
+        # Auto-create relations
+        auto_relations = []
+        if auto_link:
+            logger.info(f"Auto-linking document with threshold={AUTO_LINK_THRESHOLD}, top_k={AUTO_LINK_TOP_K}")
+
+            # Search similar documents
+            similar_results = index.query(
+                vector=vector,
+                top_k=AUTO_LINK_TOP_K,
+                include_metadata=True
+            )
+
+            # Filter by threshold
+            similar_docs = []
+            for match in similar_results.get("matches", []):
+                if match.get("score", 0) >= AUTO_LINK_THRESHOLD:
+                    similar_docs.append({
+                        "id": match["id"],
+                        "text": match["metadata"].get("text", ""),
+                        "tag": match["metadata"].get("tag", ""),
+                        "score": match["score"]
+                    })
+
+            # LLM relation analysis
+            if similar_docs:
+                analyzed = analyze_relations_with_llm(content, tag, similar_docs)
+                for rel in analyzed:
+                    rel_str = format_relation(rel["id"], rel["relation"])
+                    if rel_str not in rel_list:
+                        rel_list.append(rel_str)
+                        auto_relations.append(rel_str)
+
+        metadata = {
+            "text": content,
+            "tag": tag,
+            "relations": rel_list
+        }
+
+        index.upsert(vectors=[{
+            "id": doc_id,
+            "values": vector,
+            "metadata": metadata
+        }])
+
+        # Create bidirectional relations
+        for rel_str in rel_list:
+            target_id, rel_type = parse_relation(rel_str)
+            _add_reverse_relation(target_id, doc_id, rel_type)
+
+        # Format result
+        result = f"Saved with ID: {doc_id}"
+        if auto_relations:
+            result += f"\n\nAuto-linked ({len(auto_relations)}):"
+            for rel in auto_relations:
+                target_id, rel_type = parse_relation(rel)
+                result += f"\n  --[{rel_type}]--> {target_id[:8]}..."
+        if relations:
+            result += f"\nManual relations: {relations}"
+
+        logger.info(f"Document saved: {doc_id}, tag={tag}, relations={len(rel_list)}")
+        return result
+
+    except Exception as e:
+        logger.error(f"rag_save failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
+
+def rag_retrieve(query: str, top_k: int = 3, tag: Optional[str] = None) -> str:
+    """
+    Retrieve relevant information from vector database.
+
+    Args:
+        query: Search query
+        top_k: Number of results to return (default: 3)
+        tag: Filter by specific tag (default: None, search all)
+
+    Returns:
+        Formatted search results or error message
+    """
+    if not rate_limit_check():
+        logger.warning("Rate limit exceeded for rag_retrieve")
+        return "Error: Rate limit exceeded. Please wait before retrying."
+
+    # Validate query
+    is_valid, error_msg = validate_content(query)
+    if not is_valid:
+        logger.warning(f"Query validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+        query_vector = get_embedding(query)
+        filter_dict = {"tag": {"$eq": tag}} if tag else None
+
+        results = index.query(
+            vector=query_vector,
+            top_k=top_k,
+            include_metadata=True,
+            filter=filter_dict
+        )
+
+        if not results["matches"]:
+            logger.info("No matching documents found")
+            return "No relevant information found."
+
+        formatted = []
+        for i, res in enumerate(results["matches"], 1):
+            if "metadata" in res:
+                text = res["metadata"]["text"]
+                tag_val = res["metadata"].get("tag", "")
+                relations = res["metadata"].get("relations", [])
+                doc_id = res["id"]
+                score = res.get("score", 0)
+
+                entry = f"[{i}] ID: {doc_id} (score: {score:.3f}, tag: {tag_val})"
+                if relations:
+                    entry += f"\n    Relations: {relations}"
+                entry += f"\n    {text}"
+                formatted.append(entry)
+
+        logger.info(f"Retrieved {len(formatted)} documents for query")
+        return "Search results:\n" + "\n---\n".join(formatted)
+
+    except Exception as e:
+        logger.error(f"rag_retrieve failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
+
+def rag_update(id: str, new_info: str) -> str:
+    """
+    Intelligently merge existing and new information.
+
+    Args:
+        id: Document ID to update
+        new_info: New information to add or merge
+
+    Returns:
+        Update summary or error message
+    """
+    if not rate_limit_check():
+        logger.warning("Rate limit exceeded for rag_update")
+        return "Error: Rate limit exceeded. Please wait before retrying."
+
+    # Validate inputs
+    is_valid, error_msg = validate_document_id(id)
+    if not is_valid:
+        logger.warning(f"Document ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    is_valid, error_msg = validate_content(new_info)
+    if not is_valid:
+        logger.warning(f"New info validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+
+        result = index.fetch(ids=[id])
+        if id not in result["vectors"]:
+            logger.warning(f"Document not found: {id}")
+            return f"Error: Not found: {id}"
+
+        old_metadata = result["vectors"][id]["metadata"]
+        old_text = old_metadata.get("text", "")
+        tag = old_metadata.get("tag", "general")
+        relations = old_metadata.get("relations", [])
+
+        merged = merge_with_llm(old_text, new_info)
+
+        vector = get_embedding(merged)
+        index.upsert(vectors=[{
+            "id": id,
+            "values": vector,
+            "metadata": {"text": merged, "tag": tag, "relations": relations}
+        }])
+
+        logger.info(f"Document updated: {id}")
+        return f"Updated: {id}\n\n[Existing]\n{old_text}\n\n[New]\n{new_info}\n\n[Merged]\n{merged}"
+
+    except Exception as e:
+        logger.error(f"rag_update failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
+
+def rag_delete(id: str) -> str:
+    """
+    Delete document by ID.
+
+    Args:
+        id: Document ID to delete
+
+    Returns:
+        Success message or error message
+    """
+    if not rate_limit_check():
+        logger.warning("Rate limit exceeded for rag_delete")
+        return "Error: Rate limit exceeded. Please wait before retrying."
+
+    # Validate input
+    is_valid, error_msg = validate_document_id(id)
+    if not is_valid:
+        logger.warning(f"Document ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+
+        # Remove reverse relations
+        result = index.fetch(ids=[id])
+        if id in result["vectors"]:
+            relations = result["vectors"][id]["metadata"].get("relations", [])
+            for rel_str in relations:
+                target_id, _ = parse_relation(rel_str)
+                _remove_reverse_relation(target_id, id)
+
+        index.delete(ids=[id])
+
+        logger.info(f"Document deleted: {id}")
+        return f"Deleted: {id}"
+
+    except Exception as e:
+        logger.error(f"rag_delete failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
diff --git a/tools/relations.py b/tools/relations.py
new file mode 100644
index 0000000..493c6bd
--- /dev/null
+++ b/tools/relations.py
@@ -0,0 +1,417 @@
+"""Relation management tools and utilities."""
+from typing import Optional
+from collections import deque
+from clients import get_index
+from utils.logging import get_logger
+from utils.validation import validate_document_id
+from config import MAX_GRAPH_NODES
+
+logger = get_logger(__name__)
+
+# ============================================================
+# Relation utility functions
+# ============================================================
+
+def parse_relation(rel_str: str) -> tuple:
+    """
+    Parse 'id:type' format to (id, type).
+
+    Args:
+        rel_str: Relation string in format "id:type"
+
+    Returns:
+        (id, type) tuple
+    """
+    if ':' in rel_str:
+        parts = rel_str.split(':', 1)
+        return (parts[0], parts[1])
+    return (rel_str, 'related')
+
+def format_relation(doc_id: str, rel_type: str) -> str:
+    """
+    Format (id, type) to 'id:type' string.
+
+    Args:
+        doc_id: Document ID
+        rel_type: Relation type
+
+    Returns:
+        Formatted relation string
+    """
+    return f"{doc_id}:{rel_type}"
+
+def get_reverse_relation(rel_type: str) -> str:
+    """
+    Get reverse relation type.
+
+    Args:
+        rel_type: Relation type
+
+    Returns:
+        Reverse relation type
+    """
+    reverse_map = {
+        'depends_on': 'required_by',
+        'required_by': 'depends_on',
+        'part_of': 'contains',
+        'contains': 'part_of',
+        'see_also': 'see_also',
+        'related': 'related',
+        'blocks': 'blocked_by',
+        'blocked_by': 'blocks',
+        'extends': 'extended_by',
+        'extended_by': 'extends',
+        'updates': 'updated_by',
+        'updated_by': 'updates',
+        'contradicts': 'contradicts',
+    }
+    return reverse_map.get(rel_type, f"reverse_{rel_type}")
+
+def _add_reverse_relation(target_id: str, source_id: str, rel_type: str) -> bool:
+    """
+    Internal: Add reverse relation to target document.
+
+    Args:
+        target_id: Target document ID
+        source_id: Source document ID
+        rel_type: Relation type (forward)
+
+    Returns:
+        True if successful, False otherwise
+    """
+    try:
+        index = get_index()
+        result = index.fetch(ids=[target_id])
+        if target_id not in result["vectors"]:
+            logger.warning(f"Target document not found for reverse relation: {target_id}")
+            return False
+
+        metadata = result["vectors"][target_id]["metadata"]
+        relations = metadata.get("relations", [])
+        reverse_rel = format_relation(source_id, get_reverse_relation(rel_type))
+
+        if reverse_rel not in relations:
+            relations.append(reverse_rel)
+            metadata["relations"] = relations
+            vector = result["vectors"][target_id]["values"]
+            index.upsert(vectors=[{
+                "id": target_id,
+                "values": vector,
+                "metadata": metadata
+            }])
+            logger.debug(f"Added reverse relation: {target_id} <- {source_id}")
+            return True
+
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to add reverse relation: {str(e)}")
+        return False
+
+def _remove_reverse_relation(target_id: str, source_id: str) -> bool:
+    """
+    Internal: Remove reverse relation from target document.
+
+    Args:
+        target_id: Target document ID
+        source_id: Source document ID
+
+    Returns:
+        True if successful, False otherwise
+    """
+    try:
+        index = get_index()
+        result = index.fetch(ids=[target_id])
+        if target_id not in result["vectors"]:
+            logger.warning(f"Target document not found for reverse relation removal: {target_id}")
+            return False
+
+        metadata = result["vectors"][target_id]["metadata"]
+        relations = metadata.get("relations", [])
+        relations = [r for r in relations if not r.startswith(f"{source_id}:")]
+        metadata["relations"] = relations
+
+        vector = result["vectors"][target_id]["values"]
+        index.upsert(vectors=[{
+            "id": target_id,
+            "values": vector,
+            "metadata": metadata
+        }])
+
+        logger.debug(f"Removed reverse relation: {target_id} <- {source_id}")
+        return True
+
+    except Exception as e:
+        logger.error(f"Failed to remove reverse relation: {str(e)}")
+        return False
+
+# ============================================================
+# MCP Tools
+# ============================================================
+
+def rag_link(from_id: str, to_id: str, relation_type: str = "related") -> str:
+    """
+    Create relation between two documents (bidirectional).
+
+    Args:
+        from_id: Source document ID
+        to_id: Target document ID
+        relation_type: Relation type (depends_on, part_of, see_also, blocks, extends, updates, related)
+
+    Returns:
+        Success message or error message
+    """
+    # Validate inputs
+    is_valid, error_msg = validate_document_id(from_id)
+    if not is_valid:
+        logger.warning(f"Source ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    is_valid, error_msg = validate_document_id(to_id)
+    if not is_valid:
+        logger.warning(f"Target ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+
+        result = index.fetch(ids=[from_id])
+        if from_id not in result["vectors"]:
+            logger.warning(f"Source document not found: {from_id}")
+            return f"Error: Source document not found: {from_id}"
+
+        from_metadata = result["vectors"][from_id]["metadata"]
+        from_relations = from_metadata.get("relations", [])
+        new_rel = format_relation(to_id, relation_type)
+
+        if new_rel in from_relations:
+            logger.info(f"Relation already exists: {from_id} -> {to_id}")
+            return f"Relation already exists: {from_id} --[{relation_type}]--> {to_id}"
+
+        from_relations.append(new_rel)
+        from_metadata["relations"] = from_relations
+
+        from_vector = result["vectors"][from_id]["values"]
+        index.upsert(vectors=[{
+            "id": from_id,
+            "values": from_vector,
+            "metadata": from_metadata
+        }])
+
+        _add_reverse_relation(to_id, from_id, relation_type)
+
+        reverse_type = get_reverse_relation(relation_type)
+        logger.info(f"Linked: {from_id} --[{relation_type}]--> {to_id}")
+        return f"Linked: {from_id} --[{relation_type}]--> {to_id}\nReverse: {to_id} --[{reverse_type}]--> {from_id}"
+
+    except Exception as e:
+        logger.error(f"rag_link failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
+
+def rag_unlink(from_id: str, to_id: str) -> str:
+    """
+    Remove relation between two documents (bidirectional).
+
+    Args:
+        from_id: Source document ID
+        to_id: Target document ID
+
+    Returns:
+        Success message or error message
+    """
+    # Validate inputs
+    is_valid, error_msg = validate_document_id(from_id)
+    if not is_valid:
+        logger.warning(f"Source ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    is_valid, error_msg = validate_document_id(to_id)
+    if not is_valid:
+        logger.warning(f"Target ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+
+        result = index.fetch(ids=[from_id])
+        if from_id not in result["vectors"]:
+            logger.warning(f"Document not found: {from_id}")
+            return f"Error: Document not found: {from_id}"
+
+        from_metadata = result["vectors"][from_id]["metadata"]
+        from_relations = from_metadata.get("relations", [])
+        original_count = len(from_relations)
+
+        from_relations = [r for r in from_relations if not r.startswith(f"{to_id}:")]
+
+        if len(from_relations) == original_count:
+            logger.info(f"No relation found: {from_id} -> {to_id}")
+            return f"No relation found from {from_id} to {to_id}"
+
+        from_metadata["relations"] = from_relations
+        from_vector = result["vectors"][from_id]["values"]
+        index.upsert(vectors=[{
+            "id": from_id,
+            "values": from_vector,
+            "metadata": from_metadata
+        }])
+
+        _remove_reverse_relation(to_id, from_id)
+
+        logger.info(f"Unlinked: {from_id} <--> {to_id}")
+        return f"Unlinked: {from_id} <--> {to_id}"
+
+    except Exception as e:
+        logger.error(f"rag_unlink failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
+
+def rag_related(id: str, relation_type: Optional[str] = None, include_content: bool = False) -> str:
+    """
+    Query related documents.
+
+    Args:
+        id: Document ID to query
+        relation_type: Filter by specific relation type (default: all)
+        include_content: Include document content (default: False)
+
+    Returns:
+        Formatted relation list or error message
+    """
+    # Validate input
+    is_valid, error_msg = validate_document_id(id)
+    if not is_valid:
+        logger.warning(f"Document ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+
+        result = index.fetch(ids=[id])
+        if id not in result["vectors"]:
+            logger.warning(f"Document not found: {id}")
+            return f"Error: Document not found: {id}"
+
+        metadata = result["vectors"][id]["metadata"]
+        relations = metadata.get("relations", [])
+
+        if not relations:
+            logger.info(f"No relations found for: {id}")
+            return f"No relations found for: {id}"
+
+        if relation_type:
+            relations = [r for r in relations if r.endswith(f":{relation_type}")]
+
+            if not relations:
+                logger.info(f"No '{relation_type}' relations found for: {id}")
+                return f"No '{relation_type}' relations found for: {id}"
+
+        output = [f"Relations for {id}:"]
+
+        if include_content:
+            related_ids = [parse_relation(r)[0] for r in relations]
+            related_docs = index.fetch(ids=related_ids)
+
+            for rel_str in relations:
+                target_id, rel_type = parse_relation(rel_str)
+                entry = f"\n  --[{rel_type}]--> {target_id}"
+                if target_id in related_docs["vectors"]:
+                    text = related_docs["vectors"][target_id]["metadata"].get("text", "")
+                    tag = related_docs["vectors"][target_id]["metadata"].get("tag", "")
+                    preview = text[:200] + "..." if len(text) > 200 else text
+                    entry += f"\n    Tag: {tag}"
+                    entry += f"\n    {preview}"
+                output.append(entry)
+        else:
+            for rel_str in relations:
+                target_id, rel_type = parse_relation(rel_str)
+                output.append(f"  --[{rel_type}]--> {target_id}")
+
+        logger.info(f"Found {len(relations)} relations for: {id}")
+        return "\n".join(output)
+
+    except Exception as e:
+        logger.error(f"rag_related failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
+
+def rag_graph(id: str, depth: int = 1) -> str:
+    """
+    Explore relation graph from a document using BFS.
+
+    Args:
+        id: Starting document ID
+        depth: Search depth (default: 1, max: 3)
+
+    Returns:
+        Formatted graph or error message
+    """
+    # Validate input
+    is_valid, error_msg = validate_document_id(id)
+    if not is_valid:
+        logger.warning(f"Document ID validation failed: {error_msg}")
+        return f"Error: {error_msg}"
+
+    try:
+        index = get_index()
+        depth = min(depth, 3)
+
+        visited = set()
+        nodes = {}
+        edges = []
+        queue = deque([(id, 0)])  # BFS queue with (doc_id, current_depth)
+
+        while queue and len(visited) < MAX_GRAPH_NODES:
+            doc_id, current_depth = queue.popleft()
+
+            if doc_id in visited or current_depth > depth:
+                continue
+
+            visited.add(doc_id)
+
+            result = index.fetch(ids=[doc_id])
+            if doc_id not in result["vectors"]:
+                continue
+
+            metadata = result["vectors"][doc_id]["metadata"]
+            text = metadata.get("text", "")
+            tag = metadata.get("tag", "")
+            relations = metadata.get("relations", [])
+
+            nodes[doc_id] = {
+                "tag": tag,
+                "preview": text[:100] + "..." if len(text) > 100 else text,
+                "depth": current_depth
+            }
+
+            for rel_str in relations:
+                target_id, rel_type = parse_relation(rel_str)
+                edges.append({
+                    "from": doc_id,
+                    "to": target_id,
+                    "type": rel_type
+                })
+                if target_id not in visited and len(visited) < MAX_GRAPH_NODES:
+                    queue.append((target_id, current_depth + 1))
+
+        if not nodes:
+            logger.warning(f"Document not found: {id}")
+            return f"Error: Document not found: {id}"
+
+        output = [f"=== Graph from {id} (depth: {depth}) ==="]
+        if len(visited) >= MAX_GRAPH_NODES:
+            output.append(f"\n[WARNING] Graph limited to {MAX_GRAPH_NODES} nodes\n")
+
+        output.append("Nodes:")
+        for node_id, info in sorted(nodes.items(), key=lambda x: x[1]["depth"]):
+            indent = "  " * info["depth"]
+            output.append(f"{indent}[{info['tag']}] {node_id}")
+            output.append(f"{indent}  {info['preview']}")
+
+        output.append(f"\nEdges ({len(edges)}):")
+        for edge in edges:
+            output.append(f"  {edge['from'][:8]}... --[{edge['type']}]--> {edge['to'][:8]}...")
+
+        logger.info(f"Graph traversed: {len(nodes)} nodes, {len(edges)} edges")
+        return "\n".join(output)
+
+    except Exception as e:
+        logger.error(f"rag_graph failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
diff --git a/tools/stats.py b/tools/stats.py
new file mode 100644
index 0000000..6cbec19
--- /dev/null
+++ b/tools/stats.py
@@ -0,0 +1,37 @@
+"""Statistics and monitoring tools."""
+from clients import get_index
+from utils.logging import get_logger
+from config import AUTO_LINK_THRESHOLD, AUTO_LINK_TOP_K
+
+logger = get_logger(__name__)
+
+def rag_stats() -> str:
+    """
+    Return RAG database statistics.
+
+    Returns:
+        Formatted statistics or error message
+    """
+    try:
+        index = get_index()
+        stats = index.describe_index_stats()
+
+        total = stats.get("total_vector_count", 0)
+
+        output = [
+            "=== RAG Statistics ===",
+            f"Total documents: {total}",
+            f"Dimension: {stats.get('dimension', 'N/A')}",
+            f"Index fullness: {stats.get('index_fullness', 'N/A')}",
+            "",
+            "=== Auto-link Settings ===",
+            f"Threshold: {AUTO_LINK_THRESHOLD}",
+            f"Top-K candidates: {AUTO_LINK_TOP_K}"
+        ]
+
+        logger.info(f"Stats retrieved: {total} documents")
+        return "\n".join(output)
+
+    except Exception as e:
+        logger.error(f"rag_stats failed: {str(e)}", exc_info=True)
+        return f"Error: {str(e)}"
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..617b61f
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,11 @@
+"""Utility modules for RAG system."""
+from .validation import validate_content, validate_tag, validate_document_id
+from .logging import setup_logging, get_logger
+
+__all__ = [
+    "validate_content",
+    "validate_tag",
+    "validate_document_id",
+    "setup_logging",
+    "get_logger"
+]
diff --git a/utils/logging.py b/utils/logging.py
new file mode 100644
index 0000000..92f56d3
--- /dev/null
+++ b/utils/logging.py
@@ -0,0 +1,30 @@
+"""Logging configuration for RAG system."""
+import logging
+import sys
+
+def setup_logging(level: str = "INFO") -> None:
+    """
+    Setup logging configuration.
+
+    Args:
+        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+    """
+    logging.basicConfig(
+        level=getattr(logging, level.upper(), logging.INFO),
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.StreamHandler(sys.stdout)
+        ]
+    )
+
+def get_logger(name: str) -> logging.Logger:
+    """
+    Get logger instance.
+
+    Args:
+        name: Logger name (usually __name__)
+
+    Returns:
+        Logger instance
+    """
+    return logging.getLogger(name)
diff --git a/utils/validation.py b/utils/validation.py
new file mode 100644
index 0000000..0860589
--- /dev/null
+++ b/utils/validation.py
@@ -0,0 +1,91 @@
+"""Input validation functions for RAG system."""
+import re
+from config import MAX_CONTENT_LENGTH, MAX_TAG_LENGTH
+
+def validate_content(content: str) -> tuple[bool, str]:
+    """
+    Validate content string.
+
+    Args:
+        content: Content to validate
+
+    Returns:
+        (is_valid, error_message)
+    """
+    if not content or not content.strip():
+        return False, "Content cannot be empty"
+
+    if len(content) > MAX_CONTENT_LENGTH:
+        return False, f"Content exceeds maximum length of {MAX_CONTENT_LENGTH} characters"
+
+    return True, ""
+
+def validate_tag(tag: str) -> tuple[bool, str]:
+    """
+    Validate tag string.
+
+    Args:
+        tag: Tag to validate
+
+    Returns:
+        (is_valid, error_message)
+    """
+    if not tag or not tag.strip():
+        return False, "Tag cannot be empty"
+
+    if len(tag) > MAX_TAG_LENGTH:
+        return False, f"Tag exceeds maximum length of {MAX_TAG_LENGTH} characters"
+
+    # Only allow alphanumeric, underscore, hyphen
+    if not re.match(r'^[a-zA-Z0-9_-]+$', tag):
+        return False, "Tag must contain only alphanumeric characters, underscores, and hyphens"
+
+    return True, ""
+
+def validate_document_id(doc_id: str) -> tuple[bool, str]:
+    """
+    Validate document ID format.
+
+    Args:
+        doc_id: Document ID to validate
+
+    Returns:
+        (is_valid, error_message)
+    """
+    if not doc_id or not doc_id.strip():
+        return False, "Document ID cannot be empty"
+
+    # Basic UUID format check (flexible for various ID formats)
+    if len(doc_id) < 8 or len(doc_id) > 100:
+        return False, "Document ID must be between 8 and 100 characters"
+
+    return True, ""
+
+def sanitize_for_prompt(text: str) -> str:
+    """
+    Sanitize text to prevent prompt injection attacks.
+
+    Args:
+        text: Text to sanitize
+
+    Returns:
+        Sanitized text
+    """
+    # Remove common prompt injection patterns
+    dangerous_patterns = [
+        r'(?i)ignore\s+previous\s+instructions',
+        r'(?i)ignore\s+all\s+previous',
+        r'(?i)disregard\s+previous',
+        r'(?i)you\s+are\s+now',
+        r'(?i)system\s+prompt',
+        r'(?i)bypass\s+filters',
+    ]
+
+    sanitized = text
+    for pattern in dangerous_patterns:
+        sanitized = re.sub(pattern, '[FILTERED]', sanitized)
+
+    # Limit repeated characters
+    sanitized = re.sub(r'(.)\1{20,}', r'\1' * 20, sanitized)
+
+    return sanitized