-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathstart_local_embedding.py
More file actions
90 lines (76 loc) · 2.18 KB
/
start_local_embedding.py
File metadata and controls
90 lines (76 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
"""
Start the local EmbeddingGemma model service.
Compatible with the OpenAI Embedding API format.
"""
import asyncio
import os
import sys
import threading
from contextlib import asynccontextmanager
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel, validator
# Location of the GGUF model file, resolved under the current user's home.
MODEL_PATH = os.path.expanduser("~/.node-llama-cpp/models/embeddinggemma-300M-Q8_0.gguf")

# Fail fast with a readable message (on stderr, exit status 1) when the model
# file is missing, instead of surfacing a loader error later.
if not os.path.exists(MODEL_PATH):
    print(f"Error: Model not found at {MODEL_PATH}", file=sys.stderr)
    sys.exit(1)

print(f"Loading model: {MODEL_PATH}")

# Only the import is guarded: an ImportError raised while the model itself is
# being constructed must not be reported as "package not installed".
try:
    from llama_cpp import Llama
except ImportError:
    print("Error: llama-cpp-python not installed", file=sys.stderr)
    print("Run: pip install llama-cpp-python", file=sys.stderr)
    sys.exit(1)

# Shared model instance used by the /v1/embeddings handler; embedding=True
# loads it in embedding mode and verbose=False silences the loader output.
llm = Llama(model_path=MODEL_PATH, embedding=True, verbose=False)
print("Model loaded successfully!")
# FastAPI application and its startup/shutdown hook
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Print one message when the service starts and one when it stops.

    The function is handed to ``FastAPI(lifespan=...)``: the part before
    ``yield`` runs at startup, the part after it at shutdown.
    """
    print("Local embedding service started on http://localhost:8088")
    yield
    print("Shutting down...")


app = FastAPI(
    title="Local Embedding Service",
    lifespan=lifespan,
)
class EmbeddingRequest(BaseModel):
    """Request body of ``POST /v1/embeddings``.

    ``input`` may be sent either as a single string or as a list of strings
    (the OpenAI embeddings API allows both); after validation it is always a
    list, so the handler can iterate over it uniformly.
    """

    # Model name supplied by the client; the handler echoes it in the response.
    model: str = "embedding-gemma"
    # Texts to embed, one embedding per entry.
    input: List[str]

    # NOTE(review): `validator` is the pydantic v1-style decorator; pydantic v2
    # still accepts it but marks it deprecated in favour of `field_validator`.
    @validator("input", pre=True)
    def wrap_single_string(cls, value):
        """Promote a bare string to a one-element list before type validation."""
        if isinstance(value, str):
            return [value]
        return value
class EmbeddingResponse(BaseModel):
    """Response body of ``POST /v1/embeddings``."""

    # Type tag of the envelope; defaults to "list".
    object: str = "list"
    # One dict per input text (the handler fills in "object", "embedding"
    # and "index" for each of them).
    data: List[dict]
    # Model name reported back to the client.
    model: str
# Guards the shared model. Inference is off-loaded to worker threads, and
# llama-cpp-python does not advertise a Llama instance as safe for concurrent
# use, so embedding stays strictly one request at a time.
_EMBED_LOCK = threading.Lock()


def _embed_texts(texts):
    """Embed every text in order with the shared model, holding the model lock."""
    with _EMBED_LOCK:
        return [llm.embed(text) for text in texts]


@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest):
    """OpenAI-compatible embedding endpoint.

    Embeds each entry of ``request.input`` and returns the vectors in input
    order, each tagged with its position in ``index``. ``request.model`` is
    echoed back unchanged.

    The model call is blocking and CPU-bound, so it runs in the event loop's
    default executor; awaiting it keeps the loop free to serve other requests
    (for example ``/health``) while inference is in progress.
    """
    loop = asyncio.get_running_loop()
    vectors = await loop.run_in_executor(None, _embed_texts, request.input)
    return EmbeddingResponse(
        data=[
            {"object": "embedding", "embedding": vector, "index": position}
            for position, vector in enumerate(vectors)
        ],
        model=request.model,
    )
@app.get("/v1/models")
async def list_models():
    """Return the model catalogue: a list envelope holding one fixed entry."""
    gemma_entry = {
        "id": "embedding-gemma",
        "object": "model",
        "owned_by": "local",
    }
    return {"object": "list", "data": [gemma_entry]}
@app.get("/health")
async def health():
    """Health-check endpoint; always answers with a fixed "ok" status."""
    status_payload = {"status": "ok"}
    return status_payload
if __name__ == "__main__":
    # uvicorn is imported only when the file is executed as a script.
    import uvicorn

    print("Starting local embedding server on port 8088...")
    # NOTE(review): 0.0.0.0 listens on every network interface, although the
    # startup message mentions localhost — confirm this exposure is intended.
    bind_host = "0.0.0.0"
    bind_port = 8088
    uvicorn.run(app, host=bind_host, port=bind_port)