Skip to content

Commit

Permalink
CLIP feature module (#246)
Browse files Browse the repository at this point in the history
* Added feature modules for OpenAI CLIP

* Minor refactoring of utility classes

* Refactored common image preprocessing logic into helper class

Co-authored-by: Silvan Heller <[email protected]>
Co-authored-by: Florian Spiess <[email protected]>
Former-commit-id: c2315f0
  • Loading branch information
3 people authored Jan 21, 2022
1 parent 2bd8d51 commit b504855
Show file tree
Hide file tree
Showing 46 changed files with 698 additions and 94 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ allprojects {
group = 'org.vitrivr'

/* Our current version, on dev branch this should always be release+1-SNAPSHOT */
version = '3.6.2'
version = '3.6.3'

apply plugin: 'java-library'
apply plugin: 'maven-publish'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.vitrivr.cineast.core.db.DBSelector;
import org.vitrivr.cineast.core.db.dao.reader.TagReader;
import org.vitrivr.cineast.core.features.SegmentTags;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;
import org.vitrivr.cineast.standalone.config.Config;
import org.vitrivr.cineast.standalone.config.RetrievalRuntimeConfig;
import org.vitrivr.cineast.standalone.util.ContinuousRetrievalLogic;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;

public class Location implements ReadableFloatVector {

Expand Down
19 changes: 19 additions & 0 deletions cineast-core/src/main/java/org/vitrivr/cineast/core/data/Pair.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.vitrivr.cineast.core.data;

import java.util.Objects;

public class Pair<K, V> {

public K first;
Expand All @@ -10,4 +12,21 @@ public Pair(K first, V second) {
this.second = second;
}

@Override
public String toString() {
return "Pair(" + this.first + ", " + this.second + ")";
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Pair<?, ?> pair = (Pair<?, ?>) o;
return Objects.equals(first, pair.first) && Objects.equals(second, pair.second);
}

@Override
public int hashCode() {
return Objects.hash(first, second);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import org.vitrivr.cineast.core.data.entities.MediaObjectDescriptor;
import org.vitrivr.cineast.core.data.entities.MediaSegmentDescriptor;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;

/**
* An {@link AbstractQueryTermContainer} is the implementation of a {@link SegmentContainer} which is used in the online-phase (during retrieval).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,14 @@ public CottontailWrapper(DatabaseConfig config, boolean keepOpen) {
}
this.channel = builder.build();
this.client = new SimpleClient(this.channel);

boolean pingSuccessful = this.client.ping();
watch.stop();
LOGGER.info("Connected to Cottontail in {} ms at {}:{}", watch.getTime(TimeUnit.MILLISECONDS),
config.getHost(), config.getPort());
if (pingSuccessful) {
LOGGER.info("Connected to Cottontail in {} ms at {}:{}", watch.getTime(TimeUnit.MILLISECONDS), config.getHost(), config.getPort());
} else {
LOGGER.warn("Could not connect to Cottontail at {}:{}", config.getHost(), config.getPort());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.bytedeco.javacpp.avutil;
import org.bytedeco.javacpp.swscale;
import org.vitrivr.cineast.core.data.raw.images.MultiImage;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;

class VideoOutputStreamContainer extends AbstractAVStreamContainer {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.vitrivr.cineast.core.data.score.ScoreElement;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.features.abstracts.StagedFeatureModule;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;
import org.vitrivr.cineast.core.util.audio.HPCP;
import org.vitrivr.cineast.core.util.dsp.fft.FFTUtil;
import org.vitrivr.cineast.core.util.dsp.fft.STFT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.vitrivr.cineast.core.data.score.ScoreElement;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.features.abstracts.StagedFeatureModule;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;
import org.vitrivr.cineast.core.util.audio.HPCP;
import org.vitrivr.cineast.core.util.dsp.fft.FFTUtil;
import org.vitrivr.cineast.core.util.dsp.fft.STFT;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.vitrivr.cineast.core.features;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat16;
import org.vitrivr.cineast.core.config.QueryConfig;
import org.vitrivr.cineast.core.config.ReadableQueryConfig;
import org.vitrivr.cineast.core.data.FloatVectorImpl;
import org.vitrivr.cineast.core.data.frames.VideoFrame;
import org.vitrivr.cineast.core.data.score.ScoreElement;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.features.abstracts.AbstractFeatureModule;
import org.vitrivr.cineast.core.util.images.ImagePreprocessingHelper;

import java.awt.image.BufferedImage;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CLIPImage extends AbstractFeatureModule {

private static final Logger LOGGER = LogManager.getLogger();

private static final int EMBEDDING_SIZE = 512;
private static final String TABLE_NAME = "features_clip";
private static final ReadableQueryConfig.Distance DISTANCE = ReadableQueryConfig.Distance.cosine;

private static final int IMAGE_SIZE = 224;

private static final String RESOURCE_PATH = "resources/CLIP/";
private static final String EMBEDDING_MODEL = "clip-image-vit-32-tf";

private static final String EMBEDDING_INPUT = "input";
private static final String EMBEDDING_OUTPUT = "output";

private static final float[] MEAN = new float[]{0.48145466f, 0.4578275f, 0.40821073f};
private static final float[] STD = new float[]{0.26862954f, 0.26130258f, 0.27577711f};

private SavedModelBundle model;

public CLIPImage() {
super(TABLE_NAME, 1f, EMBEDDING_SIZE);
model = SavedModelBundle.load(RESOURCE_PATH + EMBEDDING_MODEL);
}

@Override
public void processSegment(SegmentContainer shot) {

if (shot.getMostRepresentativeFrame() == VideoFrame.EMPTY_VIDEO_FRAME) {
return;
}

float[] embeddingArray = embedImage(shot.getMostRepresentativeFrame().getImage().getBufferedImage());
this.persist(shot.getId(), new FloatVectorImpl(embeddingArray));

}

@Override
protected ReadableQueryConfig setQueryConfig(ReadableQueryConfig qc) {
return QueryConfig.clone(qc).setDistance(DISTANCE);
}

@Override
public List<ScoreElement> getSimilar(SegmentContainer sc, ReadableQueryConfig qc) {

if (sc.getMostRepresentativeFrame() == VideoFrame.EMPTY_VIDEO_FRAME) {
return Collections.emptyList();
}

QueryConfig queryConfig = QueryConfig.clone(qc);
queryConfig.setDistance(DISTANCE);

float[] embeddingArray = embedImage(sc.getMostRepresentativeFrame().getImage().getBufferedImage());

return getSimilar(embeddingArray, queryConfig);
}

private float[] embedImage(BufferedImage img) {

float[] rgb = prepareImage(img);

try (TFloat16 imageTensor = TFloat16.tensorOf(Shape.of(1, 3, IMAGE_SIZE, IMAGE_SIZE), DataBuffers.of(rgb))) {
HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(EMBEDDING_INPUT, imageTensor);

Map<String, Tensor> resultMap = model.call(inputMap);

try (TFloat16 encoding = (TFloat16) resultMap.get(EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
encoding.read(floatBuffer);

return embeddingArray;

}
}
}

private static float[] prepareImage(BufferedImage img) {
return ImagePreprocessingHelper.imageToCHWArray(
ImagePreprocessingHelper.squaredScaleCenterCrop(img, IMAGE_SIZE),
MEAN, STD);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package org.vitrivr.cineast.core.features;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TInt64;
import org.vitrivr.cineast.core.config.QueryConfig;
import org.vitrivr.cineast.core.config.ReadableQueryConfig;
import org.vitrivr.cineast.core.data.CorrespondenceFunction;
import org.vitrivr.cineast.core.data.distance.DistanceElement;
import org.vitrivr.cineast.core.data.distance.SegmentDistanceElement;
import org.vitrivr.cineast.core.data.providers.primitive.FloatArrayTypeProvider;
import org.vitrivr.cineast.core.data.providers.primitive.PrimitiveTypeProvider;
import org.vitrivr.cineast.core.data.providers.primitive.StringTypeProvider;
import org.vitrivr.cineast.core.data.score.ScoreElement;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.db.DBSelector;
import org.vitrivr.cineast.core.db.DBSelectorSupplier;
import org.vitrivr.cineast.core.db.setup.EntityCreator;
import org.vitrivr.cineast.core.features.retriever.Retriever;
import org.vitrivr.cineast.core.util.text.ClipTokenizer;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import static org.vitrivr.cineast.core.util.CineastConstants.FEATURE_COLUMN_QUALIFIER;
import static org.vitrivr.cineast.core.util.CineastConstants.GENERIC_ID_COLUMN_QUALIFIER;

public class CLIPText implements Retriever {

private static final Logger LOGGER = LogManager.getLogger();

private static final int EMBEDDING_SIZE = 512;
private static final String TABLE_NAME = "features_clip";
private static final ReadableQueryConfig.Distance DISTANCE = ReadableQueryConfig.Distance.cosine;

private static final String RESOURCE_PATH = "resources/CLIP/";
private static final String EMBEDDING_MODEL = "clip-text-vit-32-tf";

private static final String EMBEDDING_INPUT = "input";
private static final String EMBEDDING_OUTPUT = "output";

private static final CorrespondenceFunction CORRESPONDENCE = CorrespondenceFunction.linear(1f);

private static SavedModelBundle model;

private DBSelector selector;
private ClipTokenizer ct = new ClipTokenizer();

private static void init() {
if (model == null) {
model = SavedModelBundle.load(RESOURCE_PATH + EMBEDDING_MODEL);
}
}

public CLIPText() {
init();
}

@Override
public void initalizePersistentLayer(Supplier<EntityCreator> supply) {
supply.get().createFeatureEntity(TABLE_NAME, true, EMBEDDING_SIZE);
}

@Override
public void dropPersistentLayer(Supplier<EntityCreator> supply) {
supply.get().dropEntity(TABLE_NAME);
}

@Override
public void init(DBSelectorSupplier selectorSupply) {
this.selector = selectorSupply.get();
this.selector.open(TABLE_NAME);
}

@Override
public List<ScoreElement> getSimilar(SegmentContainer sc, ReadableQueryConfig qc) {

String text = sc.getText();

if (text == null || text.isBlank()) {
return Collections.emptyList();
}

return getSimilar(new FloatArrayTypeProvider(embedText(text)), qc);
}

private float[] embedText(String text) {

long[] tokens = ct.clipTokenize(text);

LongNdArray arr = NdArrays.ofLongs(Shape.of(1, tokens.length));
for (int i = 0; i < tokens.length; i++) {
arr.setLong(tokens[i], 0, i);
}

try (TInt64 textTensor = TInt64.tensorOf(arr)) {

HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(EMBEDDING_INPUT, textTensor);

Map<String, Tensor> resultMap = model.call(inputMap);

try (TFloat16 embedding = (TFloat16) resultMap.get(EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
embedding.read(floatBuffer);
return embeddingArray;

}
}
}

@Override
public List<ScoreElement> getSimilar(String segmentId, ReadableQueryConfig qc) {
List<PrimitiveTypeProvider> list = this.selector.getFeatureVectorsGeneric(GENERIC_ID_COLUMN_QUALIFIER, new StringTypeProvider(segmentId), FEATURE_COLUMN_QUALIFIER);
if (list.isEmpty()) {
LOGGER.warn("No feature vector for shotId {} found, returning empty result-list", segmentId);
return Collections.emptyList();
}
return getSimilar(list.get(0), qc);
}

private List<ScoreElement> getSimilar(PrimitiveTypeProvider queryProvider, ReadableQueryConfig qc) {
ReadableQueryConfig qcc = QueryConfig.clone(qc).setDistance(DISTANCE);
List<SegmentDistanceElement> distances = this.selector.getNearestNeighboursGeneric(qc.getResultsPerModule(), queryProvider, FEATURE_COLUMN_QUALIFIER, SegmentDistanceElement.class, qcc);
CorrespondenceFunction function = qcc.getCorrespondenceFunction().orElse(CORRESPONDENCE);
return DistanceElement.toScore(distances, function);
}

@Override
public void finish() {
if (this.selector != null) {
this.selector.close();
this.selector = null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.vitrivr.cineast.core.data.score.ScoreElement;
import org.vitrivr.cineast.core.data.segments.SegmentContainer;
import org.vitrivr.cineast.core.features.abstracts.AbstractFeatureModule;
import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;

/**
* see Efficient Use of MPEG-7 Edge Histogram Descriptor by Won '02 see http://stackoverflow.com/questions/909542/opencv-edge-extraction
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.vitrivr.cineast.core.features;

import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;

/**
* A Extraction and Retrieval module that uses HOG descriptors and a 256 word codebook based on Mirflickr 25K to obtain a histograms of codewords. These histograms ares used as feature-vectors.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.vitrivr.cineast.core.features;

import org.vitrivr.cineast.core.util.MathHelper;
import org.vitrivr.cineast.core.util.math.MathHelper;

/**
* A Extraction and Retrieval module that uses HOG descriptors and a 512 word codebook based on Mirflickr 25K to obtain a histograms of codewords. These histograms ares used as feature-vectors.
Expand Down
Loading

0 comments on commit b504855

Please sign in to comment.