diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index 7c7597c0c..9df1bc287 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -14,6 +14,7 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.v1.api.WeaviateClient; import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.Vectorizers; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGroup; import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGrouped; @@ -21,8 +22,6 @@ import io.weaviate.client6.v1.api.collections.aggregate.GroupBy; import io.weaviate.client6.v1.api.collections.aggregate.GroupedBy; import io.weaviate.client6.v1.api.collections.aggregate.IntegerAggregation; -import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; -import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; public class AggregationITest extends ConcurrentTest { @@ -36,7 +35,7 @@ public static void beforeAll() throws IOException { .properties( Property.text("category"), Property.integer("price")) - .vector(Hnsw.of(new NoneVectorizer()))); + .vectors(Vectorizers.none())); var things = client.collections.use(COLLECTION); for (var category : List.of("Shoes", "Hat", "Jacket")) { diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java index dcde8a399..8037deb84 100644 --- a/src/it/java/io/weaviate/integration/CollectionsITest.java +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -12,7 +12,8 @@ import io.weaviate.client6.v1.api.collections.InvertedIndex; import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Replication; -import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.Vectorizers; import io.weaviate.client6.v1.api.collections.config.Shard; import io.weaviate.client6.v1.api.collections.config.ShardStatus; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; @@ -28,18 +29,18 @@ public void testCreateGetDelete() throws IOException { client.collections.create(collectionName, col -> col .properties(Property.text("username"), Property.integer("age")) - .vector(Hnsw.of(new NoneVectorizer()))); + .vectors(Vectorizers.none())); var thingsCollection = client.collections.getConfig(collectionName); Assertions.assertThat(thingsCollection).get() .hasFieldOrPropertyWithValue("collectionName", collectionName) - .extracting(CollectionConfig::vectors, InstanceOfAssertFactories.map(String.class, VectorIndex.class)) + .extracting(CollectionConfig::vectors, InstanceOfAssertFactories.map(String.class, Vectorizer.class)) .as("default vector").extractingByKey("default") .satisfies(defaultVector -> { - Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) + Assertions.assertThat(defaultVector) .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); - Assertions.assertThat(defaultVector).extracting(VectorIndex::config) + Assertions.assertThat(defaultVector).extracting(Vectorizer::vectorIndex) .isInstanceOf(Hnsw.class); }); diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index e4d1338f9..827236163 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -11,6 +11,7 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.v1.api.WeaviateClient; import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.Vectorizers; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.data.BatchReference; @@ -20,8 +21,6 @@ import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryReference; import io.weaviate.client6.v1.api.collections.query.Where; -import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; -import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; public class DataITest extends ConcurrentTest { @@ -107,7 +106,7 @@ private static void createTestCollections() throws IOException { Property.integer("age")) .references( Property.reference("hasAwards", awardsGrammy, awardsOscar)) - .vectors(named -> named.vector(VECTOR_INDEX, Hnsw.of(new NoneVectorizer())))); + .vectors(Vectorizers.none(VECTOR_INDEX))); } @Test @@ -223,7 +222,7 @@ public void testUpdate() throws IOException { collection -> collection .properties(Property.text("title"), Property.integer("year")) .references(Property.reference("writtenBy", nsAuthors)) - .vector(Hnsw.of(new NoneVectorizer()))); + .vectors(Vectorizers.none())); var authors = client.collections.use(nsAuthors); var walter = authors.data.insert(Map.of("name", "walter scott")); diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index 4d69a82fa..3f67240a5 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -18,6 +18,7 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.v1.api.WeaviateClient; import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.Vectorizers; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; @@ -27,10 +28,6 @@ import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup; import io.weaviate.client6.v1.api.collections.query.Where; -import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; -import io.weaviate.client6.v1.api.collections.vectorizers.Img2VecNeuralVectorizer; -import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; -import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; import io.weaviate.containers.Container; import io.weaviate.containers.Container.ContainerGroup; import io.weaviate.containers.Contextionary; @@ -133,7 +130,7 @@ private static Map populateTest(int n) throws IOException { private static void createTestCollection() throws IOException { client.collections.create(COLLECTION, cfg -> cfg .properties(Property.text("category")) - .vector(VECTOR_INDEX, Hnsw.of(new NoneVectorizer()))); + .vectors(Vectorizers.none(VECTOR_INDEX))); } @Test @@ -142,7 +139,7 @@ public void testNearText() throws IOException { client.collections.create(nsSongs, col -> col .properties(Property.text("title")) - .vector(Hnsw.of(Text2VecContextionaryVectorizer.of()))); + .vectors(Vectorizers.text2vecContextionary())); var songs = client.collections.use(nsSongs); var submarine = songs.data.insert(Map.of("title", "Yellow Submarine")); @@ -164,13 +161,13 @@ public void testNearText() throws IOException { @Test public void testNearText_groupBy() throws IOException { - var vectorIndex = Hnsw.of(Text2VecContextionaryVectorizer.of()); + var vectorizer = Vectorizers.text2vecContextionary(); var nsArtists = ns("Artists"); client.collections.create(nsArtists, col -> col .properties(Property.text("name")) - .vector(vectorIndex)); + .vectors(vectorizer)); var artists = client.collections.use(nsArtists); var beatles = artists.data.insert(Map.of("name", "Beatles")); @@ -181,7 +178,7 @@ public void testNearText_groupBy() throws IOException { col -> col .properties(Property.text("title")) .references(Property.reference("performedBy", nsArtists)) - .vector(vectorIndex)); + .vectors(vectorizer)); var songs = client.collections.use(nsSongs); songs.data.insert(Map.of("title", "Yellow Submarine"), @@ -208,9 +205,8 @@ public void testNearImage() throws IOException { .properties( Property.text("breed"), Property.blob("img")) - .vector(Hnsw.of( - Img2VecNeuralVectorizer.of( - i2v -> i2v.imageFields("img"))))); + .vectors(Vectorizers.img2vecNeural( + i2v -> i2v.imageFields("img")))); var cats = client.collections.use(nsCats); cats.data.insert(Map.of( @@ -325,7 +321,7 @@ public void testNearObject() throws IOException { client.collections.create(nsAnimals, collection -> collection .properties(Property.text("kind")) - .vector(Hnsw.of(Text2VecContextionaryVectorizer.of()))); + .vectors(Vectorizers.text2vecContextionary())); var animals = client.collections.use(nsAnimals); @@ -354,7 +350,7 @@ public void testHybrid() throws IOException { client.collections.create(nsHobbies, collection -> collection .properties(Property.text("name"), Property.text("description")) - .vector(Hnsw.of(Text2VecContextionaryVectorizer.of()))); + .vectors(Vectorizers.text2vecContextionary())); var hobbies = client.collections.use(nsHobbies); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java index 39820b34f..8381623be 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionConfig.java @@ -27,7 +27,7 @@ public record CollectionConfig( @SerializedName("description") String description, @SerializedName("properties") List properties, List references, - @SerializedName("vectorConfig") Map vectors, + @SerializedName("vectorConfig") Map vectors, @SerializedName("multiTenancyConfig") MultiTenancy multiTenancy, @SerializedName("shardingConfig") Sharding sharding, @SerializedName("replicationConfig") Replication replication, @@ -88,7 +88,7 @@ public static class Builder implements ObjectBuilder { private String description; private Map properties = new HashMap<>(); private Map references = new HashMap<>(); - private Map vectors = new HashMap<>(); + private Map vectors = new HashMap<>(); private MultiTenancy multiTenancy; private Sharding sharding; private Replication replication; @@ -131,23 +131,14 @@ private List referenceList() { return this.references.values().stream().toList(); } - public Builder vector(VectorIndex vector) { - this.vectors.put(VectorIndex.DEFAULT_VECTOR_NAME, vector); - return this; - } - - public Builder vector(String name, VectorIndex vector) { - this.vectors.put(name, vector); - return this; - } - - public Builder vectors(Map vectors) { + public final Builder vectors(Map vectors) { this.vectors.putAll(vectors); return this; } - public Builder vectors(Function>> fn) { - this.vectors = fn.apply(new VectorsBuilder()).build(); + @SafeVarargs + public final Builder vectors(Map.Entry... vectors) { + this.vectors.putAll(Map.ofEntries(vectors)); return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java index 05f535ad9..4da02656c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java @@ -19,6 +19,7 @@ public interface VectorIndex { static final String DEFAULT_VECTOR_NAME = "default"; + static final VectorIndex DEFAULT_VECTOR_INDEX = Hnsw.of(); public enum Kind implements JsonEnum { HNSW("hnsw"), @@ -48,8 +49,6 @@ default String type() { return _kind().jsonValue(); } - Vectorizer vectorizer(); - Object config(); public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { @@ -79,7 +78,6 @@ public TypeAdapter create(Gson gson, TypeToken type) { init(gson); } - final var vectorizerAdapter = gson.getDelegateAdapter(this, TypeToken.get(Vectorizer.class)); final var writeAdapter = gson.getDelegateAdapter(this, TypeToken.get(rawType)); return (TypeAdapter) new TypeAdapter() { @@ -89,13 +87,11 @@ public void write(JsonWriter out, VectorIndex value) throws IOException { out.name("vectorIndexType"); out.value(value._kind().jsonValue()); - var config = writeAdapter.toJsonTree((T) value.config()); - config.getAsJsonObject().remove("vectorizer"); out.name("vectorIndexConfig"); + var config = writeAdapter.toJsonTree((T) value.config()); + config.getAsJsonObject().remove("name"); Streams.write(config, out); - out.name("vectorizer"); - vectorizerAdapter.write(out, value.vectorizer()); out.endObject(); } @@ -117,7 +113,6 @@ public VectorIndex read(JsonReader in) throws IOException { } var config = jsonObject.get("vectorIndexConfig").getAsJsonObject(); - config.add("vectorizer", jsonObject.get("vectorizer")); return adapter.fromJsonTree(config); } }.nullSafe(); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java index 7ca6568ab..b5b6c68bb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java @@ -3,13 +3,16 @@ import java.io.IOException; import java.util.EnumMap; import java.util.Map; +import java.util.function.Function; import com.google.gson.Gson; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import com.google.gson.TypeAdapter; import com.google.gson.TypeAdapterFactory; +import com.google.gson.internal.Streams; import com.google.gson.reflect.TypeToken; import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonToken; import com.google.gson.stream.JsonWriter; import io.weaviate.client6.v1.api.collections.vectorizers.Img2VecNeuralVectorizer; @@ -17,6 +20,7 @@ import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecWeaviateVectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.json.JsonEnum; public interface Vectorizer { @@ -48,6 +52,8 @@ public static Kind valueOfJson(String jsonValue) { Object _self(); + VectorIndex vectorIndex(); + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; @@ -84,28 +90,48 @@ public TypeAdapter create(Gson gson, TypeToken type) { public void write(JsonWriter out, Vectorizer value) throws IOException { TypeAdapter adapter = (TypeAdapter) delegateAdapters.get(value._kind()); - out.beginObject(); - out.name(value._kind().jsonValue()); - adapter.write(out, (T) value._self()); - out.endObject(); + // Serialize vectorizer config as { "vectorizer-kind": { ... } } + // and remove "vectorIndex" object which every vectorizer has. + var vectorizer = new JsonObject(); + var config = adapter.toJsonTree((T) value._self()); + + // This will create { "vectorIndexType": "", "vectorIndexConfig": { ... } } + // to which we just need to add "vectorizer": { ... } key. + var vectorIndex = config.getAsJsonObject().remove("vectorIndex"); + + vectorizer.add(value._kind().jsonValue(), config); + vectorIndex.getAsJsonObject().add("vectorizer", vectorizer); + + Streams.write(vectorIndex, out); } @Override public Vectorizer read(JsonReader in) throws IOException { - in.beginObject(); - var vectorizerName = in.nextName(); + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + + // VectorIndex.CustomTypeAdapterFactory expects keys + // ["vectorIndexType", "vectorIndexConfig"]. + var vectorIndex = new JsonObject(); + vectorIndex.add("vectorIndexType", jsonObject.get("vectorIndexType")); + vectorIndex.add("vectorIndexConfig", jsonObject.get("vectorIndexConfig")); + + var vectorizerObject = jsonObject.get("vectorizer").getAsJsonObject(); + var vectorizerName = vectorizerObject.keySet().iterator().next(); + + Vectorizer.Kind kind; try { - var kind = Vectorizer.Kind.valueOfJson(vectorizerName); - var adapter = delegateAdapters.get(kind); - return adapter.read(in); + kind = Vectorizer.Kind.valueOfJson(vectorizerName); } catch (IllegalArgumentException e) { return null; - } finally { - if (in.peek() == JsonToken.BEGIN_OBJECT) { - in.beginObject(); - } - in.endObject(); } + + var adapter = delegateAdapters.get(kind); + var concreteVectorizer = vectorizerObject.get(vectorizerName).getAsJsonObject(); + + // Each individual vectorizer has a `VectorIndex vectorIndex` field. + concreteVectorizer.add("vectorIndex", vectorIndex); + + return adapter.fromJsonTree(concreteVectorizer); } }.nullSafe(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizers.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizers.java new file mode 100644 index 000000000..9081b2674 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizers.java @@ -0,0 +1,105 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Map; +import java.util.function.Function; + +import io.weaviate.client6.v1.api.collections.vectorizers.Img2VecNeuralVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Multi2VecClipVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecWeaviateVectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +/** Static methods for creating instances of {@link Vectorizer}. */ +public final class Vectorizers { + + public static Map.Entry none() { + return none(VectorIndex.DEFAULT_VECTOR_NAME); + } + + public static Map.Entry none( + Function> fn) { + return none(VectorIndex.DEFAULT_VECTOR_NAME, fn); + } + + public static Map.Entry none(String vectorName) { + return Map.entry(vectorName, NoneVectorizer.of()); + } + + public static Map.Entry none(String vectorName, + Function> fn) { + return Map.entry(vectorName, NoneVectorizer.of(fn)); + } + + public static Map.Entry img2vecNeural() { + return img2vecNeural(VectorIndex.DEFAULT_VECTOR_NAME); + } + + public static Map.Entry img2vecNeural( + Function> fn) { + return img2vecNeural(VectorIndex.DEFAULT_VECTOR_NAME, fn); + } + + public static Map.Entry img2vecNeural(String vectorName) { + return Map.entry(vectorName, Img2VecNeuralVectorizer.of()); + } + + public static Map.Entry img2vecNeural(String vectorName, + Function> fn) { + return Map.entry(vectorName, Img2VecNeuralVectorizer.of(fn)); + } + + public static Map.Entry multi2vecClip() { + return multi2vecClip(VectorIndex.DEFAULT_VECTOR_NAME); + } + + public static Map.Entry multi2vecClip( + Function> fn) { + return multi2vecClip(VectorIndex.DEFAULT_VECTOR_NAME, fn); + } + + public static Map.Entry multi2vecClip(String vectorName) { + return Map.entry(vectorName, Multi2VecClipVectorizer.of()); + } + + public static Map.Entry multi2vecClip(String vectorName, + Function> fn) { + return Map.entry(vectorName, Multi2VecClipVectorizer.of(fn)); + } + + public static Map.Entry text2vecContextionary() { + return text2vecContextionary(VectorIndex.DEFAULT_VECTOR_NAME); + } + + public static Map.Entry text2vecContextionary( + Function> fn) { + return text2vecContextionary(VectorIndex.DEFAULT_VECTOR_NAME, fn); + } + + public static Map.Entry text2vecContextionary(String vectorName) { + return Map.entry(vectorName, Text2VecContextionaryVectorizer.of()); + } + + public static Map.Entry text2vecContextionary(String vectorName, + Function> fn) { + return Map.entry(vectorName, Text2VecContextionaryVectorizer.of(fn)); + } + + public static Map.Entry text2VecWeaviate() { + return text2VecWeaviate(VectorIndex.DEFAULT_VECTOR_NAME); + } + + public static Map.Entry text2VecWeaviate( + Function> fn) { + return text2VecWeaviate(VectorIndex.DEFAULT_VECTOR_NAME, fn); + } + + public static Map.Entry text2VecWeaviate(String vectorName) { + return Map.entry(vectorName, Text2VecWeaviateVectorizer.of()); + } + + public static Map.Entry text2VecWeaviate(String vectorName, + Function> fn) { + return Map.entry(vectorName, Text2VecWeaviateVectorizer.of(fn)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java index 3b233f01f..41b8b5aea 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/UpdateCollectionRequest.java @@ -12,7 +12,7 @@ import io.weaviate.client6.v1.api.collections.InvertedIndex; import io.weaviate.client6.v1.api.collections.Replication; import io.weaviate.client6.v1.api.collections.Reranker; -import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.json.JSON; import io.weaviate.client6.v1.internal.rest.Endpoint; @@ -98,9 +98,14 @@ public Builder generativeModule(Generative generativeModule) { return this; } + public final Builder vectors(Map vectors) { + this.newCollection.vectors(vectors); + return this; + } + @SafeVarargs - public final Builder vectors(Map.Entry... vectors) { - this.newCollection.vectors(Map.ofEntries(vectors)); + public final Builder vectors(Map.Entry... vectors) { + this.newCollection.vectors(vectors); return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/BaseVectorIndex.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/BaseVectorIndex.java deleted file mode 100644 index 49ed116c2..000000000 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/BaseVectorIndex.java +++ /dev/null @@ -1,19 +0,0 @@ -package io.weaviate.client6.v1.api.collections.vectorindex; - -import io.weaviate.client6.v1.api.collections.VectorIndex; -import io.weaviate.client6.v1.api.collections.Vectorizer; -import lombok.EqualsAndHashCode; - -@EqualsAndHashCode -abstract class BaseVectorIndex implements VectorIndex { - protected final Vectorizer vectorizer; - - @Override - public Vectorizer vectorizer() { - return this.vectorizer; - } - - public BaseVectorIndex(Vectorizer vectorizer) { - this.vectorizer = vectorizer; - } -} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java index 90ca1990c..92069553d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java @@ -5,16 +5,10 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.VectorIndex; -import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; -import lombok.EqualsAndHashCode; -import lombok.ToString; -@EqualsAndHashCode(callSuper = true) -@ToString -public class Flat extends BaseVectorIndex { - @SerializedName("vectorCacheMaxObjects") - Long vectorCacheMaxObjects; +public record Flat(@SerializedName("vectorCacheMaxObjects") Long vectorCacheMaxObjects) + implements VectorIndex { @Override public VectorIndex.Kind _kind() { @@ -26,29 +20,22 @@ public Object config() { return this; } - public static Flat of(Vectorizer vectorizer) { - return of(vectorizer, ObjectBuilder.identity()); + public static Flat of() { + return of(ObjectBuilder.identity()); } - public static Flat of(Vectorizer vectorizer, Function> fn) { - return fn.apply(new Builder(vectorizer)).build(); + public static Flat of(Function> fn) { + return fn.apply(new Builder()).build(); } public Flat(Builder builder) { - super(builder.vectorizer); - this.vectorCacheMaxObjects = builder.vectorCacheMaxObjects; + this(builder.vectorCacheMaxObjects); } public static class Builder implements ObjectBuilder { - // Required parameters. - private final Vectorizer vectorizer; private Long vectorCacheMaxObjects; - protected Builder(Vectorizer vectorizer) { - this.vectorizer = vectorizer; - } - public Builder vectorCacheMaxObjects(long vectorCacheMaxObjects) { this.vectorCacheMaxObjects = vectorCacheMaxObjects; return this; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java index 4538ad9b4..a06f1e652 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java @@ -5,39 +5,22 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.VectorIndex; -import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; -import lombok.EqualsAndHashCode; -import lombok.ToString; - -@EqualsAndHashCode(callSuper = true) -@ToString -public class Hnsw extends BaseVectorIndex { - @SerializedName("distance") - private final Distance distance; - @SerializedName("ef") - private final Integer ef; - @SerializedName("efConstruction") - private final Integer efConstruction; - @SerializedName("maxConnections") - private final Integer maxConnections; - @SerializedName("vectorCacheMaxObjects") - private final Long vectorCacheMaxObjects; - @SerializedName("cleanupIntervalSeconds") - private final Integer cleanupIntervalSeconds; - @SerializedName("filterStrategy") - private final FilterStrategy filterStrategy; - - @SerializedName("dynamicEfMin") - private final Integer dynamicEfMin; - @SerializedName("dynamicEfMax") - private final Integer dynamicEfMax; - @SerializedName("dynamicEfFactor") - private final Integer dynamicEfFactor; - @SerializedName("flatSearchCutoff") - private final Integer flatSearchCutoff; - @SerializedName("skip") - Boolean skipVectorization; + +public record Hnsw( + @SerializedName("distance") Distance distance, + @SerializedName("ef") Integer ef, + @SerializedName("efConstruction") Integer efConstruction, + @SerializedName("maxConnections") Integer maxConnections, + @SerializedName("vectorCacheMaxObjects") Long vectorCacheMaxObjects, + @SerializedName("cleanupIntervalSeconds") Integer cleanupIntervalSeconds, + @SerializedName("filterStrategy") FilterStrategy filterStrategy, + + @SerializedName("dynamicEfMin") Integer dynamicEfMin, + @SerializedName("dynamicEfMax") Integer dynamicEfMax, + @SerializedName("dynamicEfFactor") Integer dynamicEfFactor, + @SerializedName("flatSearchCutoff") Integer flatSearchCutoff, + @SerializedName("skip") Boolean skipVectorization) implements VectorIndex { @Override public VectorIndex.Kind _kind() { @@ -49,39 +32,31 @@ public Object config() { return this; } - @Override - public Vectorizer vectorizer() { - return this.vectorizer; - } - - public static Hnsw of(Vectorizer vectorizer) { - return of(vectorizer, ObjectBuilder.identity()); + public static Hnsw of() { + return of(ObjectBuilder.identity()); } - public static Hnsw of(Vectorizer vectorizer, Function> fn) { - return fn.apply(new Builder(vectorizer)).build(); + public static Hnsw of(Function> fn) { + return fn.apply(new Builder()).build(); } public Hnsw(Builder builder) { - super(builder.vectorizer); - this.distance = builder.distance; - this.ef = builder.ef; - this.efConstruction = builder.efConstruction; - this.maxConnections = builder.maxConnections; - this.vectorCacheMaxObjects = builder.vectorCacheMaxObjects; - this.cleanupIntervalSeconds = builder.cleanupIntervalSeconds; - this.filterStrategy = builder.filterStrategy; - this.dynamicEfMin = builder.dynamicEfMin; - this.dynamicEfMax = builder.dynamicEfMax; - this.dynamicEfFactor = builder.dynamicEfFactor; - this.flatSearchCutoff = builder.flatSearchCutoff; - this.skipVectorization = builder.skipVectorization; + this( + builder.distance, + builder.ef, + builder.efConstruction, + builder.maxConnections, + builder.vectorCacheMaxObjects, + builder.cleanupIntervalSeconds, + builder.filterStrategy, + builder.dynamicEfMin, + builder.dynamicEfMax, + builder.dynamicEfFactor, + builder.flatSearchCutoff, + builder.skipVectorization); } public static class Builder implements ObjectBuilder { - // Required parameters. - private final Vectorizer vectorizer; - private Distance distance; private Integer ef; private Integer efConstruction; @@ -96,10 +71,6 @@ public static class Builder implements ObjectBuilder { private Integer flatSearchCutoff; private Boolean skipVectorization; - public Builder(Vectorizer vectorizer) { - this.vectorizer = vectorizer; - } - public Builder distance(Distance distance) { this.distance = distance; return this; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java index 2d9ff6beb..7f5a28a8e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java @@ -7,11 +7,13 @@ import com.google.gson.annotations.SerializedName; +import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; public record Img2VecNeuralVectorizer( - @SerializedName("imageFields") List imageFields) implements Vectorizer { + @SerializedName("imageFields") List imageFields, + VectorIndex vectorIndex) implements Vectorizer { @Override public Vectorizer.Kind _kind() { @@ -32,10 +34,11 @@ public static Img2VecNeuralVectorizer of(Function { + private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX; private List imageFields = new ArrayList<>(); public Builder imageFields(List fields) { @@ -47,6 +50,11 @@ public Builder imageFields(String... fields) { return imageFields(Arrays.asList(fields)); } + public Builder vectorIndex(VectorIndex vectorIndex) { + this.vectorIndex = vectorIndex; + return this; + } + @Override public Img2VecNeuralVectorizer build() { return new Img2VecNeuralVectorizer(this); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java index 945984cc4..440ce7a31 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java @@ -8,6 +8,7 @@ import com.google.gson.annotations.SerializedName; +import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; @@ -16,7 +17,8 @@ public record Multi2VecClipVectorizer( @SerializedName("inferenceUrl") String inferenceUrl, @SerializedName("imageFields") List imageFields, @SerializedName("textFields") List textFields, - @SerializedName("weights") Weights weights) implements Vectorizer { + @SerializedName("weights") Weights weights, + VectorIndex vectorIndex) implements Vectorizer { private static record Weights( @SerializedName("imageWeights") List imageWeights, @@ -49,10 +51,12 @@ public Multi2VecClipVectorizer(Builder builder) { builder.textFields.keySet().stream().toList(), new Weights( builder.imageFields.values().stream().toList(), - builder.textFields.values().stream().toList())); + builder.textFields.values().stream().toList()), + builder.vectorIndex); } public static class Builder implements ObjectBuilder { + private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX; private boolean vectorizeCollectionName = false; private String inferenceUrl; private Map imageFields = new HashMap<>(); @@ -96,6 +100,11 @@ public Builder vectorizeCollectionName(boolean enable) { return this; } + public Builder vectorIndex(VectorIndex vectorIndex) { + this.vectorIndex = vectorIndex; + return this; + } + @Override public Multi2VecClipVectorizer build() { return new Multi2VecClipVectorizer(this); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java index 6449ba89b..c75f1c8dd 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java @@ -1,15 +1,13 @@ package io.weaviate.client6.v1.api.collections.vectorizers; -import java.io.IOException; - -import com.google.gson.TypeAdapter; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonToken; -import com.google.gson.stream.JsonWriter; +import java.util.function.Function; +import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.internal.ObjectBuilder; -public record NoneVectorizer() implements Vectorizer { +public record NoneVectorizer(VectorIndex vectorIndex) implements Vectorizer { @Override public Kind _kind() { return Vectorizer.Kind.NONE; @@ -20,26 +18,29 @@ public Object _self() { return this; } - public static final TypeAdapter TYPE_ADAPTER = new TypeAdapter() { + public static NoneVectorizer of() { + return of(ObjectBuilder.identity()); + } - @Override - public void write(JsonWriter out, NoneVectorizer value) throws IOException { - out.beginObject(); - out.name(value._kind().jsonValue()); - out.beginObject(); - out.endObject(); - out.endObject(); + public static NoneVectorizer of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public NoneVectorizer(Builder builder) { + this(builder.vectorIndex); + } + + public static class Builder implements ObjectBuilder { + private VectorIndex vectorIndex = Hnsw.of(); + + public Builder vectorIndex(VectorIndex vectorIndex) { + this.vectorIndex = vectorIndex; + return this; } @Override - public NoneVectorizer read(JsonReader in) throws IOException { - // NoneVectorizer expects no parameters, so we just skip to the closing bracket. - in.beginObject(); - while (in.peek() != JsonToken.END_OBJECT) { - in.skipValue(); - } - in.endObject(); - return new NoneVectorizer(); + public NoneVectorizer build() { + return new NoneVectorizer(this); } - }.nullSafe(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java index 7bbfc6c9c..aa2550e30 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java @@ -4,11 +4,13 @@ import com.google.gson.annotations.SerializedName; +import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; public record Text2VecContextionaryVectorizer( - @SerializedName("vectorizeClassName") boolean vectorizeCollectionName) implements Vectorizer { + @SerializedName("vectorizeClassName") boolean vectorizeCollectionName, + VectorIndex vectorIndex) implements Vectorizer { @Override public Vectorizer.Kind _kind() { @@ -30,10 +32,11 @@ public static Text2VecContextionaryVectorizer of( } public Text2VecContextionaryVectorizer(Builder builder) { - this(builder.vectorizeCollectionName); + this(builder.vectorizeCollectionName, builder.vectorIndex); } public static class Builder implements ObjectBuilder { + private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX; private boolean vectorizeCollectionName = false; public Builder vectorizeCollectionName(boolean enable) { @@ -41,6 +44,11 @@ public Builder vectorizeCollectionName(boolean enable) { return this; } + public Builder vectorIndex(VectorIndex vectorIndex) { + this.vectorIndex = vectorIndex; + return this; + } + public Text2VecContextionaryVectorizer build() { return new Text2VecContextionaryVectorizer(this); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java index 134a6513a..5d50ade0b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java @@ -4,6 +4,7 @@ import com.google.gson.annotations.SerializedName; +import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; import io.weaviate.client6.v1.internal.ObjectBuilder; @@ -11,7 +12,8 @@ public record Text2VecWeaviateVectorizer( @SerializedName("vectorizeClassName") boolean vectorizeCollectionName, @SerializedName("baseUrl") String inferenceUrl, @SerializedName("dimensions") Integer dimensions, - @SerializedName("model") String model) implements Vectorizer { + @SerializedName("model") String model, + VectorIndex vectorIndex) implements Vectorizer { @Override public Vectorizer.Kind _kind() { @@ -35,13 +37,15 @@ public Text2VecWeaviateVectorizer(Builder builder) { this(builder.vectorizeCollectionName, builder.inferenceUrl, builder.dimensions, - builder.model); + builder.model, + builder.vectorIndex); } public static final String SNOWFLAKE_ARCTIC_EMBED_L_20 = "Snowflake/snowflake-arctic-embed-l-v2.0"; public static final String SNOWFLAKE_ARCTIC_EMBED_M_15 = "Snowflake/snowflake-arctic-embed-m-v1.5"; public static class Builder implements ObjectBuilder { + private VectorIndex vectorIndex = VectorIndex.DEFAULT_VECTOR_INDEX; private boolean vectorizeCollectionName = false; private String inferenceUrl; private Integer dimensions; @@ -67,6 +71,11 @@ public Builder model(String model) { return this; } + public Builder vectorIndex(VectorIndex vectorIndex) { + this.vectorIndex = vectorIndex; + return this; + } + public Text2VecWeaviateVectorizer build() { return new Text2VecWeaviateVectorizer(this); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java index c58fa7072..ca52d4c2d 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -27,9 +27,6 @@ public final class JSON { io.weaviate.client6.v1.api.collections.Generative.CustomTypeAdapterFactory.INSTANCE); // TypeAdapters ----------------------------------------------------------- - gsonBuilder.registerTypeAdapter( - io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.class, - io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.TYPE_ADAPTER); gsonBuilder.registerTypeAdapter( io.weaviate.client6.v1.api.collections.data.Reference.class, io.weaviate.client6.v1.api.collections.data.Reference.TYPE_ADAPTER); diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index e76516bcd..82a205ae9 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -19,8 +19,8 @@ import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Reranker; -import io.weaviate.client6.v1.api.collections.VectorIndex; import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.Vectorizers; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.data.BatchReference; @@ -44,17 +44,29 @@ public static Object[][] testCases() { // Vectorizer.CustomTypeAdapterFactory { Vectorizer.class, - new NoneVectorizer(), - "{\"none\": {}}", + NoneVectorizer.of(), + """ + { + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "vectorizer": {"none": {}} + } + """, }, { Vectorizer.class, Img2VecNeuralVectorizer.of(i2v -> i2v.imageFields("jpeg", "png")), """ - {"img2vec-neural": { - "imageFields": ["jpeg", "png"] - }} - """, + { + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "vectorizer": { + "img2vec-neural": { + "imageFields": ["jpeg", "png"] + } + } + } + """, }, { Vectorizer.class, @@ -64,27 +76,39 @@ public static Object[][] testCases() { .textField("txt", 2f) .vectorizeCollectionName(true)), """ - {"multi2vec-clip": { - "inferenceUrl": "http://example.com", - "vectorizeClassName": true, - "imageFields": ["img"], - "textFields": ["txt"], - "weights": { - "imageWeights": [1.0], - "textWeights": [2.0] + { + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "vectorizer": { + "multi2vec-clip": { + "inferenceUrl": "http://example.com", + "vectorizeClassName": true, + "imageFields": ["img"], + "textFields": ["txt"], + "weights": { + "imageWeights": [1.0], + "textWeights": [2.0] + } + } } - }} - """, + } + """, }, { Vectorizer.class, Text2VecContextionaryVectorizer.of(t2v -> t2v .vectorizeCollectionName(true)), """ - {"text2vec-contextionary": { - "vectorizeClassName": true - }} - """, + { + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "vectorizer": { + "text2vec-contextionary": { + "vectorizeClassName": true + } + } + } + """, }, { Vectorizer.class, @@ -94,20 +118,27 @@ public static Object[][] testCases() { .model("very-good-model") .vectorizeCollectionName(true)), """ - {"text2vec-weaviate": { - "baseUrl": "http://example.com", - "vectorizeClassName": true, - "dimensions": 4, - "model": "very-good-model" - }} - """, + { + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "vectorizer": { + "text2vec-weaviate": { + "baseUrl": "http://example.com", + "vectorizeClassName": true, + "dimensions": 4, + "model": "very-good-model" + } + } + } + """, }, // VectorIndex.CustomTypeAdapterFactory { - VectorIndex.class, - Flat.of(new NoneVectorizer(), flat -> flat - .vectorCacheMaxObjects(100)), + Vectorizer.class, + NoneVectorizer.of(none -> none + .vectorIndex(Flat.of(flat -> flat + .vectorCacheMaxObjects(100)))), """ { "vectorIndexType": "flat", @@ -117,20 +148,21 @@ public static Object[][] testCases() { """, }, { - VectorIndex.class, - Hnsw.of(new NoneVectorizer(), hnsw -> hnsw - .distance(Distance.DOT) - .ef(1) - .efConstruction(2) - .maxConnections(3) - .vectorCacheMaxObjects(4) - .cleanupIntervalSeconds(5) - .dynamicEfMin(6) - .dynamicEfMax(7) - .dynamicEfFactor(8) - .flatSearchCutoff(9) - .skipVectorization(true) - .filterStrategy(Hnsw.FilterStrategy.ACORN)), + Vectorizer.class, + NoneVectorizer.of(none -> none + .vectorIndex(Hnsw.of(hnsw -> hnsw + .distance(Distance.DOT) + .ef(1) + .efConstruction(2) + .maxConnections(3) + .vectorCacheMaxObjects(4) + .cleanupIntervalSeconds(5) + .dynamicEfMin(6) + .dynamicEfMax(7) + .dynamicEfFactor(8) + .flatSearchCutoff(9) + .skipVectorization(true) + .filterStrategy(Hnsw.FilterStrategy.ACORN)))), """ { "vectorIndexType": "hnsw", @@ -197,9 +229,9 @@ public static Object[][] testCases() { Property.integer("size")) .references( Property.reference("owner", "Person", "Company")) - .vectors(named -> named - .vector("v-shape", Hnsw.of(Img2VecNeuralVectorizer.of( - i2v -> i2v.imageFields("img")))))), + .vectors( + Vectorizers.img2vecNeural("v-shape", + i2v -> i2v.imageFields("img")))), """ { "class": "Things",