Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion multimind/agents/agent_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def load_agent(
max_history=memory_config.get("max_history", 100)
)

# Create agen
# Create agent
agent = Agent(
model=model,
memory=memory,
Expand Down
59 changes: 50 additions & 9 deletions multimind/agents/prompt_correction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Callable, Any, Dict, List
import logging

logger = logging.getLogger(__name__)


class PromptCorrectionLayer:
"""
Observability and self-healing layer for LLM/agent pipelines.
Expand All @@ -19,23 +22,61 @@ def add_correction_hook(self, hook: Callable[[str, Dict], str]):
def add_adapter_update_hook(self, hook: Callable[[str, str], None]):
self.adapter_update_hooks.append(hook)

def _compute_issue_score(self, prompt: str, output: str, trace: Dict) -> float:
"""
Heuristic scoring function for potential issues / hallucinations.
Returns a score in [0, 1], where higher means more suspicious.
"""
text = output.lower()
score = 0.0

# Strong indicators
strong_markers = [
"[error]", "hallucination", "not based on real data",
"fabricated answer", "made this up"
]
if any(marker in text for marker in strong_markers):
score += 0.7

# Weaker indicators based on uncertainty phrases
weak_markers = [
"i am not sure", "i'm not sure", "i do not know",
"i don't know", "cannot verify", "not certain"
]
if any(marker in text for marker in weak_markers):
score += 0.2

# If trace provides an explicit model_score / confidence, incorporate it.
# Expecting trace.get("confidence") in [0, 1] where low is suspicious.
confidence = trace.get("confidence")
if isinstance(confidence, (int, float)):
confidence_clamped = max(0.0, min(1.0, float(confidence)))
score += (1.0 - confidence_clamped) * 0.3

return min(score, 1.0)

def monitor(self, prompt: str, output: str, trace: Dict = None) -> str:
"""
Monitor output for errors/hallucinations and apply corrections if needed.
Uses a heuristic score instead of a single string check.
"""
trace = trace or {}
try:
# Example: simple hallucination check (can be replaced with real logic)
if "[error]" in output or "hallucination" in output.lower():
self.logger.warning(f"Detected issue in output: {output}")
issue_score = self._compute_issue_score(prompt, output, trace)
threshold = trace.get("hallucination_threshold", 0.6)
if issue_score >= threshold:
self.logger.warning(
"Detected potential hallucination (score=%.2f, threshold=%.2f): %s",
issue_score,
threshold,
output,
)
for hook in self.error_hooks:
hook(prompt, Exception("Detected hallucination"), trace)
# Apply correction hooks to the *output* and return corrected output.
# (Correction hooks are expected to take a string and trace, and return a string.)
corrected_output = output
for hook in self.correction_hooks:
corrected_output = hook(corrected_output, trace)
self.logger.info(f"Corrected output: {corrected_output}")
self.logger.info("Corrected output: %s", corrected_output)
return corrected_output
return output
except Exception as e:
Expand All @@ -54,15 +95,15 @@ def update_adapter(self, adapter_key: str, new_adapter_path: str):
if __name__ == "__main__":
pcl = PromptCorrectionLayer()
def error_logger(prompt, exc, trace):
print(f"Error detected for prompt '{prompt}': {exc}")
logger.error("Error detected for prompt '%s': %s", prompt, exc)
def simple_correction(prompt, trace):
return prompt + " [CORRECTED]"
def adapter_updater(adapter_key, new_path):
print(f"Adapter {adapter_key} updated to {new_path}")
logger.info("Adapter %s updated to %s", adapter_key, new_path)
pcl.add_error_hook(error_logger)
pcl.add_correction_hook(simple_correction)
pcl.add_adapter_update_hook(adapter_updater)
# Simulate monitoring
corrected_output = pcl.monitor("What is the capital of France?", "[error] hallucination detected", {"step": 1})
print("Corrected output after correction:", corrected_output)
logger.info("Corrected output after correction: %s", corrected_output)
pcl.update_adapter("user123", "lora_adapter_v2")
22 changes: 18 additions & 4 deletions multimind/api/multi_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

import logging
from fastapi import FastAPI, HTTPException
import os
from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Union
import asyncio
Expand All @@ -16,6 +17,19 @@
app = FastAPI(title="Multi-Model API")
logger = logging.getLogger(__name__)

API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []


def verify_api_key(api_key: Optional[str] = Header(None, alias="X-API-Key")) -> bool:
    """Verify the API key from request header.

    Returns True when the request is authenticated; raises HTTP 401 otherwise.
    When no keys are configured, authentication is effectively disabled.
    """
    if API_KEYS:
        # Keys are configured: the header is mandatory and must match one of them.
        if not api_key:
            raise HTTPException(status_code=401, detail="API key required")
        if api_key not in API_KEYS:
            raise HTTPException(status_code=401, detail="Invalid API key")
    return True

# Reuse a single factory across requests to avoid re-loading env / re-allocating caches.
_MODEL_FACTORY = ModelFactory()

Expand Down Expand Up @@ -79,7 +93,7 @@ class EmbeddingsRequest(BaseModel):
model_weights: Optional[Dict[str, float]] = None

@app.post("/generate")
async def generate(request: GenerateRequest):
async def generate(request: GenerateRequest, authenticated: bool = Depends(verify_api_key)):
"""Generate text using the multi-model wrapper."""
try:
multi_model = await _get_multi_model(
Expand All @@ -99,7 +113,7 @@ async def generate(request: GenerateRequest):
raise HTTPException(status_code=500, detail="Internal server error")

@app.post("/chat")
async def chat(request: ChatRequest):
async def chat(request: ChatRequest, authenticated: bool = Depends(verify_api_key)):
"""Generate chat completion using the multi-model wrapper."""
try:
multi_model = await _get_multi_model(
Expand All @@ -119,7 +133,7 @@ async def chat(request: ChatRequest):
raise HTTPException(status_code=500, detail="Internal server error")

@app.post("/embeddings")
async def embeddings(request: EmbeddingsRequest):
async def embeddings(request: EmbeddingsRequest, authenticated: bool = Depends(verify_api_key)):
"""Generate embeddings using the multi-model wrapper."""
try:
multi_model = await _get_multi_model(
Expand Down
69 changes: 46 additions & 23 deletions multimind/api/unified_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Unified API endpoint for multi-modal processing with MoE support.
"""

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional, Union
import asyncio
Expand All @@ -19,9 +19,41 @@

app = FastAPI(title="Unified Multi-Modal API")

API_KEYS = os.getenv("API_KEYS", "").split(",") if os.getenv("API_KEYS") else []


def verify_api_key(api_key: Optional[str] = Header(None, alias="X-API-Key")) -> bool:
    """Verify the API key from request header.

    With no configured keys, all requests are allowed; otherwise a valid
    X-API-Key header is required and invalid/missing keys yield HTTP 401.
    """
    if not API_KEYS:
        return True
    if not api_key:
        detail = "API key required"
    elif api_key not in API_KEYS:
        detail = "Invalid API key"
    else:
        return True
    raise HTTPException(status_code=401, detail=detail)

# Reuse a single factory across requests to avoid re-creating model caches.
_MODEL_FACTORY = ModelFactory()

_ROUTER = None
_WORKFLOW_REGISTRY = None


def _get_router():
    """Lazily create and cache the shared MultiModalRouter instance."""
    global _ROUTER
    if _ROUTER is not None:
        return _ROUTER
    # Local import avoids a circular dependency at module import time.
    from ..router.multi_modal_router import MultiModalRouter
    _ROUTER = MultiModalRouter()
    return _ROUTER


def _get_workflow_registry():
    """Lazily create and cache the shared WorkflowRegistry instance."""
    global _WORKFLOW_REGISTRY
    if _WORKFLOW_REGISTRY is not None:
        return _WORKFLOW_REGISTRY
    # Local import avoids a circular dependency at module import time.
    from .mcp.registry import WorkflowRegistry
    _WORKFLOW_REGISTRY = WorkflowRegistry()
    return _WORKFLOW_REGISTRY


class _TextExpertAdapter(Expert):
"""Expert wrapper around a model instance for text."""
Expand Down Expand Up @@ -169,16 +201,13 @@ def _build_experts(modalities: List[str], router: Any) -> Dict[str, Expert]:
return experts

@app.post("/v1/process", response_model=UnifiedResponse)
async def process_request(request: UnifiedRequest):
async def process_request(request: UnifiedRequest, authenticated: bool = Depends(verify_api_key)):
"""Process multi-modal request using either MoE or router."""
try:
# Import here to avoid circular imports
from ..router.multi_modal_router import MultiModalRouter, MultiModalRequest
from .mcp.registry import WorkflowRegistry

# Initialize components
router = MultiModalRouter()
workflow_registry = WorkflowRegistry()
from ..router.multi_modal_router import MultiModalRequest

router = _get_router()
workflow_registry = _get_workflow_registry()

# Convert inputs to router format (support multiple inputs per modality)
content: Dict[str, Any] = {}
Expand Down Expand Up @@ -288,32 +317,26 @@ async def process_request(request: UnifiedRequest):
raise HTTPException(status_code=500, detail="Internal server error")

@app.get("/v1/models")
async def list_models():
async def list_models(authenticated: bool = Depends(verify_api_key)):
"""List available models and their capabilities."""
# Import here to avoid circular imports
from ..router.multi_modal_router import MultiModalRouter
router = MultiModalRouter()

router = _get_router()

models = {}
for modality, model_dict in router.modality_registry.items():
models[modality] = list(model_dict.keys())
return {"models": models}

@app.get("/v1/workflows")
async def list_workflows():
async def list_workflows(authenticated: bool = Depends(verify_api_key)):
"""List available MCP workflows."""
# Import here to avoid circular imports
from .mcp.registry import WorkflowRegistry
workflow_registry = WorkflowRegistry()
workflow_registry = _get_workflow_registry()
return {"workflows": workflow_registry.list_workflows()}

@app.get("/v1/metrics")
async def get_metrics():
async def get_metrics(authenticated: bool = Depends(verify_api_key)):
"""Get performance metrics for models."""
# Import here to avoid circular imports
from ..router.multi_modal_router import MultiModalRouter
router = MultiModalRouter()

router = _get_router()

return {
"costs": router.cost_tracker.costs,
"performance": router.performance_metrics.metrics
Expand Down
8 changes: 6 additions & 2 deletions multimind/client/federated_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Callable, Dict, Any
import logging
import time

logger = logging.getLogger(__name__)


class FederatedRouter:
"""
Routes between local (on-device) and cloud model clients based on context (input size, latency, privacy, etc.).
Expand Down Expand Up @@ -41,5 +45,5 @@ def generate(self, prompt, **kwargs):
local = DummyClient()
cloud = DummyClient()
router = FederatedRouter(local, cloud)
print(router.generate("short prompt"))
print(router.generate("This is a very long prompt that should go to the cloud..." * 20))
logger.info("%s", router.generate("short prompt"))
logger.info("%s", router.generate("This is a very long prompt that should go to the cloud..." * 20))
22 changes: 18 additions & 4 deletions multimind/compliance/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def update_epsilon(self, epsilon: float):
from datetime import datetime
import json
import asyncio
import hashlib
from pathlib import Path
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -291,11 +292,12 @@ async def check_and_heal(self, compliance_state: Dict[str, Any]) -> Dict[str, An

def _get_state_metadata(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""Get metadata for a compliance state."""
state_bytes = json.dumps(state, sort_keys=True, default=str).encode("utf-8")
return {
"status": state.get("status", "unknown"),
"timestamp": datetime.now().isoformat(),
"version": state.get("version", "1.0"),
"checksum": hash(str(state))
"checksum": hashlib.sha256(state_bytes).hexdigest()
}

def _create_rollback_point(self, state: Dict[str, Any]):
Expand Down Expand Up @@ -497,8 +499,13 @@ async def _extract_watermark(self, model: Any) -> str:

async def _generate_fingerprint(self, model: Any) -> str:
"""Generate fingerprint for the model."""
# Placeholder implementation: Replace with actual fingerprint generation logic
return f"fingerprint_{hash(str(model))}"
# Deterministic cryptographic fingerprint for model identity.
model_payload = {
"type": type(model).__name__,
"repr": repr(model),
}
model_bytes = json.dumps(model_payload, sort_keys=True, default=str).encode("utf-8")
return f"fingerprint_{hashlib.sha256(model_bytes).hexdigest()}"

async def verify_watermark(self, model) -> Dict[str, Any]:
"""Enhanced watermark verification with tamper detection."""
Expand Down Expand Up @@ -533,10 +540,17 @@ async def track_fingerprint(self, model: Any) -> Dict[str, Any]:
"""Track and return fingerprint information for a model."""
fingerprint = await self._generate_fingerprint(model)
await self.fingerprint_tracker.track(fingerprint)
model_id = hashlib.sha256(
json.dumps(
{"type": type(model).__name__, "repr": repr(model)},
sort_keys=True,
default=str
).encode("utf-8")
).hexdigest()[:16]
return {
"fingerprint": fingerprint,
"timestamp": datetime.now().isoformat(),
"model_id": str(hash(str(model)))
"model_id": model_id
}

class AdaptivePrivacy:
Expand Down
2 changes: 1 addition & 1 deletion multimind/compliance/advanced_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ComplianceShardConfig(BaseModel):
jurisdiction: str
epsilon: float = 1.0
rules: List[Dict[str, Any]]
metadata: Dict[str, Any] = {}
metadata: Dict[str, Any] = Field(default_factory=dict)
compliance_level: ComplianceLevel = ComplianceLevel.STANDARD
encryption_enabled: bool = True
metrics_tracking: bool = True
Expand Down
8 changes: 4 additions & 4 deletions multimind/compliance/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,16 @@ async def cleanup_old_events(self) -> int:

async def export_events(
self,
format: str = "json",
export_format: str = "json",
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None
) -> str:
"""Export audit events in specified format."""
events = await self.get_events(start_time=start_time, end_time=end_time)

if format == "json":
if export_format == "json":
return json.dumps([e.dict() for e in events], default=str)
elif format == "csv":
elif export_format == "csv":
import csv
import io
if not events:
Expand All @@ -156,7 +156,7 @@ async def export_events(
writer.writerow(row)
return output.getvalue()
else:
raise ValueError(f"Unsupported export format: {format}")
raise ValueError(f"Unsupported export format: {export_format}")

async def get_compliance_report(
self,
Expand Down
Loading
Loading