# Complex MCP Server Implementation
import asyncio
import json
import logging
import logging.handlers
import sys
import hashlib
import jwt
import time
import os
import yaml
import threading
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Set, Callable
from dataclasses import dataclass, field
from urllib.parse import urljoin
from contextlib import asynccontextmanager
from functools import wraps
import httpx
import duckdb
import redis
from cryptography.fernet import Fernet
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from mcp.server.stdio import stdio_server
from mcp.server import Server
from mcp.server.models import InitializationOptions
from mcp.types import (
CallToolRequest, CallToolResult, Tool, TextContent,
ListToolsRequest, ListToolsResult
)
# Configure comprehensive logging with multiple handlers
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
handlers=[
logging.FileHandler('earthquake_server.log'),
logging.StreamHandler(),
logging.handlers.RotatingFileHandler(
'earthquake_server_detailed.log',
maxBytes=10*1024*1024,
backupCount=5
)
]
)
logger = logging.getLogger(__name__)
audit_logger = logging.getLogger('audit')
security_logger = logging.getLogger('security')
# Structured logging handler
class StructuredLogHandler(logging.Handler):
def emit(self, record):
log_entry = {
'timestamp': datetime.now(timezone.utc).isoformat(),
'level': record.levelname,
'message': record.getMessage(),
'module': record.module,
'function': record.funcName,
'line': record.lineno,
'request_id': getattr(record, 'request_id', None),
'user_id': getattr(record, 'user_id', None),
'action': getattr(record, 'action', None)
}
        # Write to stderr so structured logs never mix with the MCP stdio protocol stream
        print(json.dumps(log_entry), file=sys.stderr)
structured_handler = StructuredLogHandler()
audit_logger.addHandler(structured_handler)
security_logger.addHandler(structured_handler)
@dataclass
class AuthConfig:
jwt_secret: str
jwt_algorithm: str = "HS256"
token_expiry_hours: int = 24
max_login_attempts: int = 5
lockout_duration_minutes: int = 30
@dataclass
class PolicyRule:
name: str
condition: str
action: str
priority: int = 0
enabled: bool = True
@dataclass
class UserSession:
user_id: str
token: str
created_at: datetime
last_activity: datetime
permissions: Set[str] = field(default_factory=set)
request_count: int = 0
@dataclass
class DriftMetric:
metric_name: str
current_value: float
baseline_value: float
threshold: float
timestamp: datetime
drift_detected: bool = False
class SecurityManager:
def __init__(self, config: AuthConfig):
self.config = config
self.failed_attempts = {}
self.blocked_ips = {}
self.active_sessions = {}
self.fernet = Fernet(Fernet.generate_key())
    def authenticate_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Validate and decode a JWT token."""
        try:
            # jwt.decode verifies the signature and the 'exp' claim by default
            payload = jwt.decode(
                token,
                self.config.jwt_secret,
                algorithms=[self.config.jwt_algorithm]
            )
            user_id = payload.get('user_id')
            if user_id in self.active_sessions:
                self.active_sessions[user_id].last_activity = datetime.now(timezone.utc)
            return payload
        except jwt.ExpiredSignatureError:
            security_logger.warning("Expired token attempted", extra={'action': 'token_expired'})
            return None
        except jwt.InvalidTokenError as e:
            security_logger.error(f"Invalid token: {e}", extra={'action': 'invalid_token'})
            return None
def generate_token(self, user_id: str, permissions: Set[str]) -> str:
"""Generate JWT token for user"""
payload = {
'user_id': user_id,
'permissions': list(permissions),
'iat': time.time(),
'exp': time.time() + (self.config.token_expiry_hours * 3600)
}
token = jwt.encode(payload, self.config.jwt_secret, algorithm=self.config.jwt_algorithm)
# Store session
self.active_sessions[user_id] = UserSession(
user_id=user_id,
token=token,
created_at=datetime.now(timezone.utc),
last_activity=datetime.now(timezone.utc),
permissions=permissions
)
audit_logger.info(f"Token generated for user {user_id}",
extra={'user_id': user_id, 'action': 'token_generated'})
return token
def check_rate_limit(self, user_id: str, max_requests: int = 100, window_minutes: int = 60) -> bool:
"""Check if user exceeds rate limit"""
if user_id in self.active_sessions:
session = self.active_sessions[user_id]
session.request_count += 1
# Reset counter if window expired
time_diff = (datetime.now(timezone.utc) - session.created_at).total_seconds() / 60
if time_diff > window_minutes:
session.request_count = 1
session.created_at = datetime.now(timezone.utc)
if session.request_count > max_requests:
security_logger.warning(f"Rate limit exceeded for user {user_id}",
extra={'user_id': user_id, 'action': 'rate_limit_exceeded'})
return False
return True
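# Illustrative usage of SecurityManager (not executed here); the user id is made up
# and a real deployment would supply the secret via configuration, never a literal:
#
#   auth = SecurityManager(AuthConfig(jwt_secret=os.environ["JWT_SECRET"]))
#   token = auth.generate_token("analyst-1", {"read_earthquake_data"})
#   claims = auth.authenticate_token(token)   # payload dict, or None if invalid/expired
#   allowed = auth.check_rate_limit("analyst-1")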
class PolicyEngine:
def __init__(self):
self.rules = []
self.load_policies()
def load_policies(self):
"""Load policy rules from configuration"""
try:
with open('policies.yaml', 'r') as f:
policy_config = yaml.safe_load(f)
for rule_config in policy_config.get('rules', []):
rule = PolicyRule(**rule_config)
self.rules.append(rule)
logger.info(f"Loaded policy rule: {rule.name}")
except FileNotFoundError:
# Default policies
self.rules = [
PolicyRule(
name="magnitude_threshold",
condition="min_magnitude > 8.0",
action="require_admin_approval"
),
PolicyRule(
name="data_retention",
condition="query_scope == 'historical'",
action="audit_log_required"
),
PolicyRule(
name="rate_limiting",
condition="request_count > 50",
action="throttle_requests"
)
]
def evaluate_request(self, request_data: Dict[str, Any], user_context: Dict[str, Any]) -> Dict[str, Any]:
"""Evaluate request against all policies"""
policy_results = {
'allowed': True,
'actions_required': [],
'violations': []
}
for rule in self.rules:
if not rule.enabled:
continue
try:
                # Simple condition evaluation; in production use a safer evaluator
                # (see the safe_eval_condition sketch after this class)
context = {**request_data, **user_context}
if eval(rule.condition, {"__builtins__": {}}, context):
if rule.action == "deny":
policy_results['allowed'] = False
policy_results['violations'].append(rule.name)
else:
policy_results['actions_required'].append(rule.action)
audit_logger.info(f"Policy rule triggered: {rule.name}",
extra={'action': 'policy_triggered', 'rule': rule.name})
except Exception as e:
logger.error(f"Error evaluating policy rule {rule.name}: {e}")
return policy_results
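# Sketch of a safer alternative to the eval() call in PolicyEngine.evaluate_request
# above. Illustrative only and not wired into the class: it is a minimal whitelist
# evaluator that supports comparisons, and/or, names, and literal constants, which
# covers conditions like "min_magnitude > 8.0" or "query_scope == 'historical'".
# In a real module these imports would sit at the top of the file.
import ast
import operator as _op

_ALLOWED_COMPARE_OPS = {
    ast.Gt: _op.gt, ast.GtE: _op.ge, ast.Lt: _op.lt,
    ast.LtE: _op.le, ast.Eq: _op.eq, ast.NotEq: _op.ne,
}

def safe_eval_condition(condition: str, context: Dict[str, Any]) -> bool:
    """Evaluate a simple policy condition without exposing eval()."""
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.BoolOp):
            values = [_eval(v) for v in node.values]
            return all(values) if isinstance(node.op, ast.And) else any(values)
        if isinstance(node, ast.Compare):
            left = _eval(node.left)
            for op, comparator in zip(node.ops, node.comparators):
                cmp_fn = _ALLOWED_COMPARE_OPS.get(type(op))
                if cmp_fn is None:
                    raise ValueError(f"Unsupported comparison: {type(op).__name__}")
                right = _eval(comparator)
                if not cmp_fn(left, right):
                    return False
                left = right
            return True
        if isinstance(node, ast.Name):
            return context[node.id]
        if isinstance(node, ast.Constant):
            return node.value
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")
    return bool(_eval(ast.parse(condition, mode='eval')))

# e.g. safe_eval_condition("min_magnitude > 8.0", {"min_magnitude": 9.1}) -> True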
class DriftDetector:
def __init__(self):
self.baselines = {}
self.metrics_history = []
self.drift_threshold = 0.2 # 20% drift threshold
self.lock = threading.Lock()
def establish_baseline(self, metric_name: str, values: List[float]):
"""Establish baseline for a metric"""
with self.lock:
baseline = sum(values) / len(values)
self.baselines[metric_name] = baseline
logger.info(f"Baseline established for {metric_name}: {baseline}")
def detect_drift(self, metric_name: str, current_value: float) -> DriftMetric:
"""Detect drift in a metric"""
baseline = self.baselines.get(metric_name)
if baseline is None:
# Auto-establish baseline with current value
self.baselines[metric_name] = current_value
baseline = current_value
# Calculate drift percentage
drift_percentage = abs(current_value - baseline) / baseline if baseline != 0 else 0
drift_detected = drift_percentage > self.drift_threshold
metric = DriftMetric(
metric_name=metric_name,
current_value=current_value,
baseline_value=baseline,
threshold=self.drift_threshold,
timestamp=datetime.now(timezone.utc),
drift_detected=drift_detected
)
self.metrics_history.append(metric)
if drift_detected:
logger.warning(f"Drift detected in {metric_name}: {drift_percentage:.2%} change",
extra={'action': 'drift_detected', 'metric': metric_name})
return metric
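    # Worked example, using the baselines set in initialize_baselines() below:
    # the api_response_time baseline is (0.5+0.6+0.4+0.7+0.5)/5 = 0.54 s, so a
    # current value of 0.70 s gives |0.70 - 0.54| / 0.54 ≈ 0.30, which exceeds
    # the 0.2 threshold and flags drift.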
def get_drift_summary(self) -> Dict[str, Any]:
"""Get summary of all drift metrics"""
recent_metrics = [m for m in self.metrics_history
if (datetime.now(timezone.utc) - m.timestamp).total_seconds() < 3600]
return {
'total_metrics': len(self.baselines),
'recent_drift_count': len([m for m in recent_metrics if m.drift_detected]),
'metrics': {name: baseline for name, baseline in self.baselines.items()}
}
class ConfigurationWatcher(FileSystemEventHandler):
def __init__(self, server_instance):
self.server = server_instance
def on_modified(self, event):
if event.src_path.endswith(('policies.yaml', 'config.yaml')):
logger.info(f"Configuration file changed: {event.src_path}")
# Reload configuration
self.server.reload_configuration()
def require_auth(permissions: Optional[Set[str]] = None):
"""Decorator to require authentication and permissions"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(self, *args, **kwargs):
request = args[0] if args else None
# Extract token from request (simplified)
token = getattr(request, 'token', None)
if not token:
security_logger.warning("Unauthenticated request attempted",
extra={'action': 'unauthenticated_request'})
raise Exception("Authentication required")
# Validate token
auth_data = self.security_manager.authenticate_token(token)
if not auth_data:
raise Exception("Invalid or expired token")
# Check permissions
user_permissions = set(auth_data.get('permissions', []))
if permissions and not permissions.issubset(user_permissions):
security_logger.warning(f"Insufficient permissions for user {auth_data['user_id']}",
extra={'user_id': auth_data['user_id'], 'action': 'insufficient_permissions'})
raise Exception("Insufficient permissions")
# Check rate limit
if not self.security_manager.check_rate_limit(auth_data['user_id']):
raise Exception("Rate limit exceeded")
return await func(self, *args, **kwargs)
return wrapper
return decorator
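# Illustrative usage (require_auth is defined above but not applied in this file):
# a handler method on a class that exposes self.security_manager could be protected with
#
#   @require_auth({"read_earthquake_data"})
#   async def handle_earthquake_query(self, request, request_id): ...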
class EarthquakeDataHandler:
def __init__(self, drift_detector: DriftDetector):
self.usgs_endpoint = "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary/all_day.geojson"
self.connection = duckdb.connect()
self.drift_detector = drift_detector
self.redis_client = redis.Redis(host='localhost', port=6379, db=0, decode_responses=True)
self.cache_ttl = 300 # 5 minutes
async def fetch_earthquake_data(self, request_id: str = None) -> Dict[str, Any]:
"""Fetch earthquake data from USGS API with caching and monitoring"""
cache_key = f"earthquake_data:{datetime.now().strftime('%Y%m%d%H%M')}"
# Try cache first
try:
cached_data = self.redis_client.get(cache_key)
if cached_data:
logger.info("Returning cached earthquake data",
extra={'request_id': request_id, 'action': 'cache_hit'})
return json.loads(cached_data)
except Exception as e:
logger.warning(f"Cache retrieval failed: {e}")
# Fetch from API
async with httpx.AsyncClient(timeout=30.0) as client:
try:
start_time = time.time()
response = await client.get(self.usgs_endpoint)
response.raise_for_status()
data = response.json()
# Monitor response time for drift detection
response_time = time.time() - start_time
self.drift_detector.detect_drift('api_response_time', response_time)
# Monitor data quality
feature_count = len(data.get('features', []))
self.drift_detector.detect_drift('feature_count', feature_count)
# Cache the result
try:
self.redis_client.setex(cache_key, self.cache_ttl, json.dumps(data))
except Exception as e:
logger.warning(f"Cache storage failed: {e}")
logger.info(f"Fetched {feature_count} earthquake features",
extra={'request_id': request_id, 'action': 'data_fetched',
'feature_count': feature_count, 'response_time': response_time})
return data
except httpx.RequestError as e:
logger.error(f"Error fetching earthquake data: {e}",
extra={'request_id': request_id, 'action': 'fetch_error'})
raise Exception(f"Failed to fetch earthquake data: {e}")
def process_earthquake_data(self, data: Dict[str, Any], min_magnitude: float,
request_id: str = None) -> List[Dict[str, Any]]:
"""Process and filter earthquake data with comprehensive monitoring"""
try:
start_time = time.time()
earthquakes = []
total_features = len(data.get('features', []))
for feature in data.get('features', []):
properties = feature.get('properties', {})
geometry = feature.get('geometry', {})
magnitude = properties.get('mag')
if magnitude is None or magnitude < min_magnitude:
continue
                coords = geometry.get('coordinates') or []
                earthquake = {
                    'magnitude': float(magnitude),
                    'location': properties.get('place', 'Unknown'),
                    'time': properties.get('time', 0),
                    'coordinates': coords,
                    # Guard against features with fewer than three coordinate components
                    'depth': coords[2] if len(coords) > 2 else None,
                    'significance': properties.get('sig', 0),
                    'tsunami': properties.get('tsunami', 0)
                }
earthquakes.append(earthquake)
# Sort by magnitude (descending)
earthquakes.sort(key=lambda x: x['magnitude'], reverse=True)
# Monitor processing metrics
processing_time = time.time() - start_time
filtered_count = len(earthquakes)
filter_ratio = filtered_count / total_features if total_features > 0 else 0
self.drift_detector.detect_drift('processing_time', processing_time)
self.drift_detector.detect_drift('filter_ratio', filter_ratio)
logger.info(f"Processed {total_features} features, filtered to {filtered_count} earthquakes",
extra={'request_id': request_id, 'action': 'data_processed',
'total_features': total_features, 'filtered_count': filtered_count,
'processing_time': processing_time})
return earthquakes
except Exception as e:
logger.error(f"Error processing earthquake data: {e}",
extra={'request_id': request_id, 'action': 'processing_error'})
raise Exception(f"Failed to process earthquake data: {e}")
class EarthquakeServer:
def __init__(self):
self.server = Server("earthquake-server")
self.security_manager = SecurityManager(AuthConfig(
jwt_secret=os.getenv('JWT_SECRET', 'default-secret-key'),
jwt_algorithm="HS256",
token_expiry_hours=24
))
self.policy_engine = PolicyEngine()
self.drift_detector = DriftDetector()
self.data_handler = EarthquakeDataHandler(self.drift_detector)
# Setup configuration watcher
self.observer = Observer()
self.observer.schedule(ConfigurationWatcher(self), '.', recursive=False)
self.observer.start()
# Initialize drift baselines
self.initialize_baselines()
self.setup_handlers()
def initialize_baselines(self):
"""Initialize baseline metrics for drift detection"""
# Historical baselines (in production, load from database)
self.drift_detector.establish_baseline('api_response_time', [0.5, 0.6, 0.4, 0.7, 0.5])
self.drift_detector.establish_baseline('feature_count', [150, 200, 180, 170, 190])
self.drift_detector.establish_baseline('processing_time', [0.1, 0.12, 0.09, 0.11, 0.1])
self.drift_detector.establish_baseline('filter_ratio', [0.3, 0.35, 0.28, 0.32, 0.3])
def reload_configuration(self):
"""Reload server configuration"""
logger.info("Reloading server configuration")
self.policy_engine.load_policies()
# Reload other configurations as needed
def setup_handlers(self):
"""Setup MCP server handlers with full security and monitoring"""
@self.server.list_tools()
async def handle_list_tools() -> ListToolsResult:
request_id = f"list_tools_{int(time.time())}"
logger.info("Listing available tools", extra={'request_id': request_id, 'action': 'list_tools'})
return ListToolsResult(
tools=[
Tool(
name="query_recent_earthquakes",
description="Query earthquakes over a given magnitude threshold with full security and monitoring",
inputSchema={
"type": "object",
"properties": {
"min_magnitude": {
"type": "number",
"description": "Minimum magnitude threshold",
"default": 2.5,
"minimum": 0.0,
"maximum": 10.0
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 10,
"minimum": 1,
"maximum": 100
},
"include_details": {
"type": "boolean",
"description": "Include detailed earthquake information",
"default": False
}
},
"required": ["min_magnitude"]
}
),
Tool(
name="get_server_health",
description="Get server health and monitoring information",
inputSchema={"type": "object", "properties": {}}
),
Tool(
name="get_drift_summary",
description="Get drift detection summary",
inputSchema={"type": "object", "properties": {}}
)
]
)
@self.server.call_tool()
async def handle_call_tool(request: CallToolRequest) -> CallToolResult:
request_id = f"tool_{request.name}_{int(time.time())}"
# Comprehensive request logging
audit_logger.info(f"Tool request received: {request.name}",
extra={'request_id': request_id, 'action': 'tool_request',
'tool_name': request.name, 'arguments': request.arguments})
try:
if request.name == "query_recent_earthquakes":
return await self.handle_earthquake_query(request, request_id)
elif request.name == "get_server_health":
return await self.handle_health_check(request, request_id)
elif request.name == "get_drift_summary":
return await self.handle_drift_summary(request, request_id)
else:
logger.warning(f"Unknown tool requested: {request.name}",
extra={'request_id': request_id, 'action': 'unknown_tool'})
return CallToolResult(
content=[TextContent(type="text", text=f"Unknown tool: {request.name}")],
isError=True
)
except Exception as e:
logger.error(f"Tool execution error: {e}",
extra={'request_id': request_id, 'action': 'tool_error',
'error': str(e)})
return CallToolResult(
content=[TextContent(type="text", text=f"Error: {str(e)}")],
isError=True
)
async def handle_earthquake_query(self, request: CallToolRequest, request_id: str) -> CallToolResult:
"""Handle earthquake query with full policy enforcement and monitoring"""
# Extract and validate parameters
min_magnitude = request.arguments.get("min_magnitude", 2.5)
max_results = request.arguments.get("max_results", 10)
include_details = request.arguments.get("include_details", False)
# Policy evaluation
policy_context = {
'min_magnitude': min_magnitude,
'max_results': max_results,
'request_count': 1, # Would track actual count
'query_scope': 'current' if min_magnitude < 5.0 else 'significant'
}
user_context = {
'user_id': 'anonymous', # Would extract from auth
'permissions': {'read_earthquake_data'}
}
policy_result = self.policy_engine.evaluate_request(policy_context, user_context)
if not policy_result['allowed']:
security_logger.warning(f"Request denied by policy: {policy_result['violations']}",
extra={'request_id': request_id, 'action': 'policy_denied'})
return CallToolResult(
content=[TextContent(type="text", text="Request denied by security policy")],
isError=True
)
# Process policy actions
for action in policy_result['actions_required']:
if action == 'audit_log_required':
audit_logger.info(f"High-impact query executed",
extra={'request_id': request_id, 'action': 'high_impact_query'})
# Fetch and process data
raw_data = await self.data_handler.fetch_earthquake_data(request_id)
earthquakes = self.data_handler.process_earthquake_data(raw_data, min_magnitude, request_id)
# Limit results
limited_earthquakes = earthquakes[:max_results]
# Format response
if include_details:
result_text = f"Found {len(earthquakes)} earthquakes with magnitude >= {min_magnitude}\n"
result_text += f"Showing top {len(limited_earthquakes)} results:\n\n"
for i, eq in enumerate(limited_earthquakes, 1):
result_text += f"{i}. Magnitude {eq['magnitude']:.1f}: {eq['location']}\n"
                result_text += f" Time: {datetime.fromtimestamp(eq['time'] / 1000, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')} UTC\n"
result_text += f" Depth: {eq.get('depth', 'Unknown')} km\n"
result_text += f" Significance: {eq.get('significance', 0)}\n"
result_text += f" Tsunami Risk: {'Yes' if eq.get('tsunami') else 'No'}\n\n"
else:
result_text = f"Found {len(earthquakes)} earthquakes with magnitude >= {min_magnitude}\n\n"
for eq in limited_earthquakes:
result_text += f"Magnitude {eq['magnitude']}: {eq['location']}\n"
# Add monitoring summary
drift_summary = self.drift_detector.get_drift_summary()
result_text += f"\n--- System Health ---\n"
result_text += f"Recent drift alerts: {drift_summary['recent_drift_count']}\n"
audit_logger.info(f"Earthquake query completed successfully",
extra={'request_id': request_id, 'action': 'query_completed',
'results_count': len(limited_earthquakes)})
return CallToolResult(
content=[TextContent(type="text", text=result_text)],
isError=False
)
async def handle_health_check(self, request: CallToolRequest, request_id: str) -> CallToolResult:
"""Handle server health check"""
health_data = {
'status': 'healthy',
'timestamp': datetime.now(timezone.utc).isoformat(),
'drift_summary': self.drift_detector.get_drift_summary(),
'active_sessions': len(self.security_manager.active_sessions),
'policy_rules': len(self.policy_engine.rules)
}
return CallToolResult(
content=[TextContent(type="text", text=json.dumps(health_data, indent=2))],
isError=False
)
async def handle_drift_summary(self, request: CallToolRequest, request_id: str) -> CallToolResult:
"""Handle drift detection summary request"""
drift_summary = self.drift_detector.get_drift_summary()
return CallToolResult(
content=[TextContent(type="text", text=json.dumps(drift_summary, indent=2))],
isError=False
)
def shutdown(self):
"""Graceful shutdown"""
logger.info("Shutting down earthquake server")
self.observer.stop()
self.observer.join()
async def main():
    """Main server entry point with comprehensive setup"""
    logger.info("Starting Earthquake MCP Server with full enterprise features")
    server = None
    try:
        server = EarthquakeServer()
        # Run the MCP server over the stdio transport; initialization options
        # (name, version, capabilities) are derived from the Server instance
        async with stdio_server() as (read_stream, write_stream):
            logger.info("Server initialized successfully")
            await server.server.run(
                read_stream,
                write_stream,
                server.server.create_initialization_options()
            )
    except KeyboardInterrupt:
        logger.info("Server shutdown requested")
    except Exception as e:
        logger.error(f"Server startup failed: {e}")
        raise
    finally:
        if server is not None:
            server.shutdown()
if __name__ == "__main__":
asyncio.run(main())
# Additional enterprise features to implement:
# + Database connection pooling and connection management (20+ lines)
# + Distributed caching with Redis Cluster (15+ lines)
# + Metrics collection and Prometheus integration (25+ lines)
# + Circuit breaker pattern for external API calls (20+ lines; see the sketch below)
# + Request tracing and distributed logging (30+ lines)
# + Configuration management with environment variables (15+ lines)
# + Health check endpoints and readiness probes (20+ lines)
# + Graceful shutdown handling (15+ lines)
# + Unit tests and integration tests (100+ lines)
# + Docker containerization and Kubernetes deployment (50+ lines)
# + CI/CD pipeline configuration (30+ lines)
# + Performance monitoring and alerting (25+ lines)
# + Data validation and sanitization (20+ lines)
# + Backup and disaster recovery procedures (40+ lines)
# + Load balancing and scaling configuration (25+ lines)
# Total additional infrastructure: 450+ lines of production code
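# Sketch of the circuit-breaker item listed above. Illustrative only: the class name,
# thresholds, and the self.breaker attribute in the wiring comment are assumptions and
# are not part of the server code above. The idea is that a run of consecutive failures
# opens the breaker and short-circuits calls for a cooldown period, protecting the USGS
# endpoint and the event loop from repeated timeouts.
class CircuitBreaker:
    def __init__(self, failure_threshold: int = 5, reset_timeout_seconds: float = 60.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout_seconds = reset_timeout_seconds
        self.failure_count = 0
        self.opened_at: Optional[float] = None
    def allow_request(self) -> bool:
        """Return False while the breaker is open and the cooldown has not elapsed."""
        if self.opened_at is None:
            return True
        if time.time() - self.opened_at >= self.reset_timeout_seconds:
            # Half-open: let one trial request through after the cooldown
            return True
        return False
    def record_success(self) -> None:
        self.failure_count = 0
        self.opened_at = None
    def record_failure(self) -> None:
        self.failure_count += 1
        if self.failure_count >= self.failure_threshold:
            self.opened_at = time.time()
# Possible wiring inside EarthquakeDataHandler.fetch_earthquake_data (sketch):
#
#   if not self.breaker.allow_request():
#       raise Exception("USGS API circuit breaker is open; try again later")
#   try:
#       response = await client.get(self.usgs_endpoint)
#       response.raise_for_status()
#       self.breaker.record_success()
#   except httpx.RequestError:
#       self.breaker.record_failure()
#       raise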