fix: point id (#10)
* fix: point id

* ci: try release master only

* chore: bump version
Anush008 authored Feb 17, 2024
1 parent b2ac52d commit 8583ff2
Showing 6 changed files with 112 additions and 69 deletions.
.github/workflows/release.yml (10 additions, 2 deletions)

@@ -1,6 +1,10 @@
 name: Release

-on: [push, workflow_dispatch]
+on:
+  push:
+    branches:
+      - master
+  workflow_dispatch:

 jobs:
   build:
@@ -38,6 +42,10 @@ jobs:
           restore-keys: |
             ${{ runner.os }}-maven-
       - uses: actions/setup-node@v4
+        with:
+          node-version: 20

+      - name: "🔧 setup Bun"
+        uses: oven-sh/setup-bun@v1
@@ -54,4 +62,4 @@ jobs:
           GIT_COMMITTER_NAME: "github-actions[bot]"
           GIT_COMMITTER_EMAIL: "41898282+github-actions[bot]@users.noreply.github.com"
           GIT_AUTHOR_NAME: ${{ steps.author_info.outputs.AUTHOR_NAME }}
-          GIT_AUTHOR_EMAIL: ${{ steps.author_info.outputs.AUTHOR_EMAIL }}
\ No newline at end of file
+          GIT_AUTHOR_EMAIL: ${{ steps.author_info.outputs.AUTHOR_EMAIL }}
README.md (3 additions, 3 deletions)

@@ -30,7 +30,7 @@ For use with Java and Scala projects, the package can be found [here](https://ce
 <dependency>
     <groupId>io.qdrant</groupId>
     <artifactId>spark</artifactId>
-    <version>1.12</version>
+    <version>1.13</version>
 </dependency>
 ```

@@ -43,7 +43,7 @@ from pyspark.sql import SparkSession

 spark = SparkSession.builder.config(
     "spark.jars",
-    "spark-1.12-jar-with-dependencies.jar", # specify the downloaded JAR file
+    "spark-1.13-jar-with-dependencies.jar", # specify the downloaded JAR file
 )
 .master("local[*]")
 .appName("qdrant")
@@ -73,7 +73,7 @@ To load data into Qdrant, a collection has to be created beforehand with the app
 You can use the `qdrant-spark` connector as a library in Databricks to ingest data into Qdrant.
 - Go to the `Libraries` section in your cluster dashboard.
 - Select `Install New` to open the library installation modal.
-- Search for `io.qdrant:spark:1.12` in the Maven packages and click `Install`.
+- Search for `io.qdrant:spark:1.13` in the Maven packages and click `Install`.

 <img width="1064" alt="Screenshot 2024-01-05 at 17 20 01 (1)" src="https://github.com/qdrant/qdrant-spark/assets/46051506/d95773e0-c5c6-4ff2-bf50-8055bb08fd1b">
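The README's PySpark snippet above is truncated by the collapsed diff view. For orientation, here is a minimal, self-contained sketch of the same flow in Java, not part of this commit: the option keys `qdrant_url` and `embedding_field` appear in this repository's tests, while `collection_name`, `schema`, and the `io.qdrant.spark.Qdrant` format name are assumptions for illustration.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class QdrantWriteSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .config("spark.jars", "spark-1.13-jar-with-dependencies.jar") // the connector JAR
        .master("local[*]")
        .appName("qdrant")
        .getOrCreate();

    // Hypothetical input: a DataFrame with an array<float> "embedding" column.
    Dataset<Row> df = spark.read().parquet("embeddings.parquet");

    df.write()
        .format("io.qdrant.spark.Qdrant")      // assumed data source name
        .option("qdrant_url", "http://localhost:8080")
        .option("collection_name", "demo")     // assumed; collection must exist beforehand
        .option("embedding_field", "embedding")
        .option("schema", df.schema().json())  // assumed: schema passed as JSON
        .mode("append")
        .save();

    spark.stop();
  }
}
```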
pom.xml (1 addition, 1 deletion)

@@ -6,7 +6,7 @@
   <modelVersion>4.0.0</modelVersion>
   <groupId>io.qdrant</groupId>
   <artifactId>spark</artifactId>
-  <version>1.12</version>
+  <version>1.13</version>
   <name>qdrant-spark</name>
   <url>https://github.com/qdrant/qdrant-spark</url>
   <description>An Apache Spark connector for the Qdrant vector database</description>
src/main/java/io/qdrant/spark/ObjectFactory.java (new file, 78 additions)

@@ -0,0 +1,78 @@
+package io.qdrant.spark;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.ArrayType;
+
+class ObjectFactory {
+  public static Object object(InternalRow record, StructField field, int fieldIndex) {
+    DataType dataType = field.dataType();
+
+    switch (dataType.typeName()) {
+      case "integer":
+        return record.getInt(fieldIndex);
+      case "float":
+        return record.getFloat(fieldIndex);
+      case "double":
+        return record.getDouble(fieldIndex);
+      case "long":
+        return record.getLong(fieldIndex);
+      case "boolean":
+        return record.getBoolean(fieldIndex);
+      case "string":
+        return record.getString(fieldIndex);
+      case "array":
+        ArrayType arrayType = (ArrayType) dataType;
+        ArrayData arrayData = record.getArray(fieldIndex);
+        return object(arrayData, arrayType.elementType());
+      case "struct":
+        StructType structType = (StructType) dataType;
+        InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
+        return object(structData, structType);
+      default:
+        return null;
+    }
+  }
+
+  public static Object object(ArrayData arrayData, DataType elementType) {
+
+    switch (elementType.typeName()) {
+      case "string": {
+        int length = arrayData.numElements();
+        String[] result = new String[length];
+        for (int i = 0; i < length; i++) {
+          result[i] = arrayData.getUTF8String(i).toString();
+        }
+        return result;
+      }
+
+      case "struct": {
+        StructType structType = (StructType) elementType;
+        int length = arrayData.numElements();
+        Object[] result = new Object[length];
+        for (int i = 0; i < length; i++) {
+          InternalRow structData = arrayData.getStruct(i, structType.fields().length);
+          result[i] = object(structData, structType);
+        }
+        return result;
+      }
+      default:
+        return arrayData.toObjectArray(elementType);
+    }
+  }
+
+  public static Object object(InternalRow structData, StructType structType) {
+    Map<String, Object> result = new HashMap<>();
+    for (int i = 0; i < structType.fields().length; i++) {
+      StructField structField = structType.fields()[i];
+      result.put(structField.name(), object(structData, structField, i));
+    }
+    return result;
+  }
+}
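A brief usage sketch of the new class, hypothetical and not part of the commit: `ObjectFactory.object` maps a Catalyst `InternalRow` field to a plain Java value suitable for a Qdrant payload. The sketch sits in the `io.qdrant.spark` package because the class is package-private; the class name `ObjectFactorySketch` is invented for illustration.

```java
package io.qdrant.spark; // ObjectFactory is package-private

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;

class ObjectFactorySketch {
  public static void main(String[] args) {
    StructType schema = new StructType()
        .add("name", DataTypes.StringType)
        .add("score", DataTypes.DoubleType);

    // Catalyst stores strings internally as UTF8String.
    InternalRow row = new GenericInternalRow(
        new Object[] {UTF8String.fromString("alice"), 0.42d});

    System.out.println(ObjectFactory.object(row, schema.fields()[0], 0)); // alice
    System.out.println(ObjectFactory.object(row, schema.fields()[1], 1)); // 0.42
  }
}
```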
src/main/java/io/qdrant/spark/QdrantDataWriter.java (19 additions, 61 deletions)

@@ -3,20 +3,18 @@
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.Map;
 import java.util.UUID;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.connector.write.DataWriter;
 import org.apache.spark.sql.connector.write.WriterCommitMessage;
-import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import static io.qdrant.spark.ObjectFactory.object;
+
 /**
  * A DataWriter implementation that writes data to Qdrant, a vector search
  * engine. This class takes
@@ -53,12 +51,24 @@ public void write(InternalRow record) {
     for (StructField field : this.schema.fields()) {
       int fieldIndex = this.schema.fieldIndex(field.name());
       if (this.options.idField != null && field.name().equals(this.options.idField)) {
-        point.id = record.get(fieldIndex, field.dataType()).toString();
+
+        DataType dataType = field.dataType();
+        switch (dataType.typeName()) {
+          case "string":
+            point.id = record.getString(fieldIndex);
+            break;
+
+          case "integer":
+            point.id = record.getInt(fieldIndex);
+            break;
+          default:
+            throw new IllegalArgumentException("Point ID should be of type string or integer");
+        }
       } else if (field.name().equals(this.options.embeddingField)) {
-        float[] vector = record.getArray(fieldIndex).toFloatArray();
-        point.vector = vector;
+        point.vector = record.getArray(fieldIndex).toFloatArray();
+
       } else {
-        payload.put(field.name(), convertToJavaType(record, field, fieldIndex));
+        payload.put(field.name(), object(record, field, fieldIndex));
       }
     }
@@ -107,62 +117,10 @@ public void abort() {
   @Override
   public void close() {
   }
-
-  private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) {
-    DataType dataType = field.dataType();
-
-    if (dataType == DataTypes.StringType) {
-      return record.getString(fieldIndex);
-    } else if (dataType == DataTypes.DateType || dataType == DataTypes.TimestampType) {
-      return record.getString(fieldIndex);
-    } else if (dataType instanceof ArrayType) {
-      ArrayType arrayType = (ArrayType) dataType;
-      ArrayData arrayData = record.getArray(fieldIndex);
-      return convertArrayToJavaType(arrayData, arrayType.elementType());
-    } else if (dataType instanceof StructType) {
-      StructType structType = (StructType) dataType;
-      InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
-      return convertStructToJavaType(structData, structType);
-    }
-
-    // Fall back to the generic get method
-    return record.get(fieldIndex, dataType);
-  }
-
-  private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) {
-    if (elementType == DataTypes.StringType) {
-      int length = arrayData.numElements();
-      String[] result = new String[length];
-      for (int i = 0; i < length; i++) {
-        result[i] = arrayData.getUTF8String(i).toString();
-      }
-      return result;
-    } else if (elementType instanceof StructType) {
-      StructType structType = (StructType) elementType;
-      int length = arrayData.numElements();
-      Object[] result = new Object[length];
-      for (int i = 0; i < length; i++) {
-        InternalRow structData = arrayData.getStruct(i, structType.fields().length);
-        result[i] = convertStructToJavaType(structData, structType);
-      }
-      return result;
-    } else {
-      return arrayData.toObjectArray(elementType);
-    }
-  }
-
-  private Object convertStructToJavaType(InternalRow structData, StructType structType) {
-    Map<String, Object> result = new HashMap<>();
-    for (int i = 0; i < structType.fields().length; i++) {
-      StructField structField = structType.fields()[i];
-      result.put(structField.name(), convertToJavaType(structData, structField, i));
-    }
-    return result;
-  }
 }

 class Point implements Serializable {
-  public String id;
+  public Object id;
   public float[] vector;
   public HashMap<String, Object> payload;
 }
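With `Point.id` widened from `String` to `Object`, a point can now carry either ID form Qdrant accepts: an unsigned integer or a UUID string. An illustrative fragment, not from the commit:

```java
// Illustrative only: the two ID shapes the reworked write() can produce.
Point byNumber = new Point();
byNumber.id = 42; // integer ID, autoboxed to Integer

Point byUuid = new Point();
byUuid.id = "550e8400-e29b-41d4-a716-446655440000"; // UUID string
```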
src/test/java/io/qdrant/spark/TestQdrant.java (1 addition, 2 deletions)

@@ -54,8 +54,7 @@ public void testGetTable() {
     options.put("embedding_field", "embedding");
     options.put("qdrant_url", "http://localhost:8080");
     CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options);
-    var reader = qdrant.getTable(schema, null, dataSourceOptions);
-    Assert.assertTrue(reader instanceof QdrantCluster);
+    Assert.assertTrue(qdrant.getTable(schema, null, dataSourceOptions) instanceof QdrantCluster);
   }

  @Test()
