-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathembeddings_example.py
More file actions
46 lines (34 loc) · 1.38 KB
/
embeddings_example.py
File metadata and controls
46 lines (34 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""
embeddings_example.py: CLI for embeddings and cosine similarity with ai-sdk.
Usage: embeddings_example.py --values VAL1 VAL2 [VAL3 ...] [--model MODEL_ID]
"""
import os
import argparse
try:
    from dotenv import load_dotenv  # type: ignore
except ImportError:  # pragma: no cover
    # python-dotenv is an optional dependency: without it the script still
    # runs, using only variables already present in the process environment.
    def load_dotenv(*args, **kwargs) -> bool:  # type: ignore
        """Stand-in for ``dotenv.load_dotenv`` that loads nothing.

        Any positional or keyword arguments are accepted and ignored so that
        calls written against the real function keep working.

        Returns:
            Always ``False``, since no environment variable was set.
        """
        return False
from ai_sdk import openai, embed_many, cosine_similarity # type: ignore[attr-defined]
# Runs at import time, before main() reads AI_SDK_EMBED_MODEL and
# OPENAI_API_KEY via os.getenv; does nothing when the fallback stub is in use.
load_dotenv()
def main(argv=None) -> None:
    """Embed the given values and print embedding sizes and one similarity.

    Parses the command line, embeds every ``--values`` entry with the selected
    OpenAI embedding model, and prints the length of each embedding. When more
    than one embedding is returned, the cosine similarity of the first two is
    printed as well.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original command-line
            behaviour; passing a list allows calling ``main`` from other code.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
    """
    parser = argparse.ArgumentParser(description="Embeddings CLI using ai-sdk.")
    parser.add_argument(
        "--values", nargs="+", required=True, help="List of values to embed."
    )
    parser.add_argument(
        "--model",
        # Environment override first, then the built-in default model ID.
        default=os.getenv("AI_SDK_EMBED_MODEL", "text-embedding-3-small"),
        help="Embedding model ID.",
    )
    args = parser.parse_args(argv)

    model = openai.embedding(args.model, api_key=os.getenv("OPENAI_API_KEY"))  # type: ignore[attr-defined]
    res = embed_many(model=model, values=args.values)

    # Pair each input with its embedding instead of indexing by position.
    for val, embedding in zip(args.values, res.embeddings):
        print(f"Value '{val}' embedding length: {len(embedding)}")

    # Only the first pair is compared; further values are embedded but not scored.
    if len(res.embeddings) > 1:
        sim = cosine_similarity(res.embeddings[0], res.embeddings[1])
        print(f"Similarity between '{args.values[0]}' and '{args.values[1]}': {sim}")


if __name__ == "__main__":
    main()