Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions docs/optimizer_config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,155 @@
},
"title": "TokenizerConfig",
"type": "object"
},
"VllmEmbeddingConfig": {
"additionalProperties": false,
"description": "Configuration for vLLM-based embeddings.",
"properties": {
"default_prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Default prompt for the model. This is used when no task specific prompt is not provided.",
"title": "Default Prompt"
},
"classification_prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Prompt for classifier.",
"title": "Classification Prompt"
},
"cluster_prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Prompt for clustering.",
"title": "Cluster Prompt"
},
"sts_prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Prompt for finding most similar sentences.",
"title": "Sts Prompt"
},
"query_prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Prompt for query.",
"title": "Query Prompt"
},
"passage_prompt": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Prompt for passage.",
"title": "Passage Prompt"
},
"use_cache": {
"default": true,
"description": "Whether to use embeddings caching.",
"title": "Use Cache",
"type": "boolean"
},
"model_name": {
"default": "BAAI/bge-base-en-v1.5",
"description": "Name of the HuggingFace model to load via vLLM.",
"title": "Model Name",
"type": "string"
},
"batch_size": {
"default": 32,
"description": "Number of texts to encode per vLLM encode() call.",
"title": "Batch Size",
"type": "integer"
},
"max_model_len": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"description": "Maximum sequence length. Reduces VRAM usage for long-context models.",
"title": "Max Model Len"
},
"gpu_memory_utilization": {
"default": 0.9,
"description": "Fraction of GPU memory vLLM is allowed to use (0.0 to 1.0).",
"maximum": 1.0,
"minimum": 0.0,
"title": "Gpu Memory Utilization",
"type": "number"
},
"dtype": {
"default": "auto",
"description": "Data type for model weights: 'auto', 'float16', 'bfloat16', 'float32'.",
"title": "Dtype",
"type": "string"
},
"trust_remote_code": {
"default": false,
"description": "Whether to trust remote code when loading the model.",
"title": "Trust Remote Code",
"type": "boolean"
},
"extra_init_kwargs": {
"additionalProperties": true,
"description": "Extra keyword arguments passed to the vLLM LLM() constructor.",
"title": "Extra Init Kwargs",
"type": "object"
},
"extra_encode_kwargs": {
"additionalProperties": true,
"description": "Extra keyword arguments passed to llm.encode() at inference time (e.g. custom SamplingParams).",
"title": "Extra Encode Kwargs",
"type": "object"
}
},
"title": "VllmEmbeddingConfig",
"type": "object"
}
},
"description": "Configuration for the optimization process.\n\nOne can use it to customize optimization beyond choosing different preset.\nInstantiate it and pass to :py:meth:`autointent.Pipeline.from_optimization_config`.",
Expand Down Expand Up @@ -1019,6 +1168,9 @@
{
"$ref": "#/$defs/HashingVectorizerEmbeddingConfig"
},
{
"$ref": "#/$defs/VllmEmbeddingConfig"
},
{
"$ref": "#/$defs/BaseEmbedderConfig"
}
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ opensearch = [
openai = [
"openai (>=2,<3)",
]
vllm = [
"vllm>=0.20.0",
]

[tool.uv]
conflicts = [
Expand Down
2 changes: 2 additions & 0 deletions src/autointent/_wrappers/embedder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from .hashing_vectorizer import HashingVectorizerEmbeddingBackend
from .openai import OpenaiEmbeddingBackend
from .sentence_transformers import SentenceTransformerEmbeddingBackend
from .vllm import VllmEmbeddingBackend

__all__ = [
"BaseEmbeddingBackend",
"Embedder",
"HashingVectorizerEmbeddingBackend",
"OpenaiEmbeddingBackend",
"SentenceTransformerEmbeddingBackend",
"VllmEmbeddingBackend",
]
8 changes: 6 additions & 2 deletions src/autointent/_wrappers/embedder/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
HashingVectorizerEmbeddingConfig,
OpenaiEmbeddingConfig,
SentenceTransformerEmbeddingConfig,
VllmEmbeddingConfig,
)

from .hashing_vectorizer import HashingVectorizerEmbeddingBackend
from .openai import OpenaiEmbeddingBackend
from .sentence_transformers import SentenceTransformerEmbeddingBackend
from .vllm import VllmEmbeddingBackend

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -64,8 +66,8 @@ def _init_backend(self) -> BaseEmbeddingBackend:
return OpenaiEmbeddingBackend(self.config)
if isinstance(self.config, HashingVectorizerEmbeddingConfig):
return HashingVectorizerEmbeddingBackend(self.config)
# Check if it's exactly the abstract base config (not a subclass)

if isinstance(self.config, VllmEmbeddingConfig):
return VllmEmbeddingBackend(self.config)
msg = f"Cannot instantiate abstract EmbedderConfig: {self.config.__repr__()}"
raise TypeError(msg)

Expand Down Expand Up @@ -161,6 +163,8 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
instance._backend = OpenaiEmbeddingBackend.load(backend_path)
elif isinstance(config, HashingVectorizerEmbeddingConfig):
instance._backend = HashingVectorizerEmbeddingBackend.load(backend_path)
elif isinstance(config, VllmEmbeddingConfig):
instance._backend = VllmEmbeddingBackend.load(backend_path)
else:
msg = f"Cannot load abstract EmbedderConfig: {config.__repr__()}"
raise TypeError(msg)
Expand Down
Loading
Loading