This codelab illustrates how to build a text-video retrieval engine from scratch using Milvus and Towhee.

What is Text-Video Retrieval?

Simply put, text-video retrieval means: given a text query and a pool of candidate videos, select the video that corresponds to the text query.
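
As a rough illustration of this definition, the sketch below ranks a pool of candidate videos by cosine similarity to a query. The 3-dimensional vectors, the video names, and the `retrieve` helper are all made up for illustration; in the rest of this codelab the embeddings come from a real model and the search is done by Milvus.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def retrieve(query_vec, video_vecs, top_k=1):
    """Return the ids of the top_k videos most similar to the query."""
    ranked = sorted(
        video_vecs,
        key=lambda vid: cosine_similarity(query_vec, video_vecs[vid]),
        reverse=True,
    )
    return ranked[:top_k]

# Made-up embeddings, purely for illustration.
video_vecs = {
    'video_dog': [0.9, 0.1, 0.0],
    'video_party': [0.1, 0.9, 0.1],
    'video_cartoon': [0.0, 0.2, 0.9],
}
query_vec = [0.8, 0.2, 0.1]  # pretend embedding of a text query about a dog

print(retrieve(query_vec, video_vecs, top_k=2))
```

The query vector points in roughly the same direction as `video_dog`, so that video is ranked first.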

What are Milvus & Towhee?

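Milvus is an open-source vector database built for embedding similarity search. Towhee is an open-source framework for building data processing pipelines that turn unstructured data such as video and text into embeddings.

We'll walk through the video retrieval procedure and evaluate its performance. The core functionality takes only a few lines of code, which you can use as a starting point for your own video retrieval engine.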

Install Dependencies

First, make sure you have installed the required Python packages:

$ python -m pip install -q pymilvus towhee towhee.models pillow ipython gradio

Prepare the data

MSR-VTT (Microsoft Research Video to Text) is a dataset for open-domain video captioning that consists of 10,000 video clips.

This codelab uses the MSR-VTT-1kA test set, which contains just 1,000 videos.

The data is organized as follows: the video clips are in ./test_1k_compress, one <video_id>.mp4 file per clip, and the caption sentence for each video is in ./MSRVTT_JSFUSION_test.csv.

First, download the dataset and unzip it:

$ curl -L https://github.com/towhee-io/examples/releases/download/data/text_video_search.zip -O
$ unzip -q -o text_video_search.zip

Let's take a quick look:

import pandas as pd
import os

raw_video_path = './test_1k_compress' # 1k test video path.
test_csv_path = './MSRVTT_JSFUSION_test.csv' # 1k video caption csv.

test_sample_csv_path = './MSRVTT_JSFUSION_test_sample.csv'

sample_num = 1000 # you can change this sample_num to be smaller, so that this notebook will be faster.
test_df = pd.read_csv(test_csv_path)
print('length of all test set is {}'.format(len(test_df)))
sample_df = test_df.sample(sample_num, random_state=42)

sample_df['video_path'] = sample_df.apply(lambda x:os.path.join(raw_video_path, x['video_id']) + '.mp4', axis=1)

sample_df.to_csv(test_sample_csv_path)
print('random sample {} examples'.format(sample_num))

df = pd.read_csv(test_sample_csv_path)

df[['video_id', 'video_path', 'sentence']].head()

Define some helper functions that convert videos to GIFs so that we can look at these video-text pairs.

from IPython import display
from pathlib import Path
import towhee
from PIL import Image

def display_gif(video_path_list, text_list):
    html = ''
    for video_path, text in zip(video_path_list, text_list):
        html_line = '<img src="{}"> {} <br/>'.format(video_path, text)
        html += html_line
    return display.HTML(html)

    
def convert_video2gif(video_path, output_gif_path, num_samples=16):
    frames = (
        towhee.glob(video_path)
              .video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': num_samples})
              .to_list()[0]
    )
    imgs = [Image.fromarray(frame) for frame in frames]
    imgs[0].save(fp=output_gif_path, format='GIF', append_images=imgs[1:], save_all=True, loop=0)


def display_gifs_from_video(video_path_list, text_list, tmpdirname = './tmp_gifs'):
    Path(tmpdirname).mkdir(exist_ok=True)
    gif_path_list = []
    for video_path in video_path_list:
        video_name = str(Path(video_path).name).split('.')[0]
        gif_path = Path(tmpdirname) / (video_name + '.gif')
        convert_video2gif(video_path, gif_path)
        gif_path_list.append(gif_path)
    return display_gif(gif_path_list, text_list)
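
The helpers above rely on the `uniform_temporal_subsample` sample type to pick a fixed number of frames from each video. A minimal sketch of the idea, assuming the frames are simply spread evenly from the first to the last frame, could look like this. The function name `uniform_sample_indices` is hypothetical, and the operator's actual index selection may differ in detail.

```python
def uniform_sample_indices(num_frames, num_samples):
    """Pick num_samples frame indices spread evenly from first to last frame."""
    if num_frames <= 0 or num_samples <= 0:
        return []
    if num_samples == 1:
        return [0]
    # Integer arithmetic avoids floating-point rounding at the last index.
    return [(i * (num_frames - 1)) // (num_samples - 1) for i in range(num_samples)]

# A 300-frame clip reduced to 16 frames, as in convert_video2gif's default.
indices = uniform_sample_indices(num_frames=300, num_samples=16)
print(indices)
```

Sampling this way keeps the GIF (and later the model input) the same length regardless of how long the source video is.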

Take a look at the ground-truth video-text pairs.

sample_show_df = sample_df[:3]
video_path_list = sample_show_df['video_path'].to_list()
text_list = sample_show_df['sentence'].to_list()
tmpdirname = './tmp_gifs'
display_gifs_from_video(video_path_list, text_list, tmpdirname=tmpdirname)

a girl wearing red top and black trouser is putting a sweater on a dog

young people sit around the edges of a room clapping and raising their arms while others dance in the center during a party

cartoon people are eating at a restaurant

Create a Milvus Collection

Before getting started, make sure Milvus is installed and running. Let's first create a text_video_retrieval collection that uses the L2 distance metric and an IVF_FLAT index.

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

connections.connect(host='127.0.0.1', port='19530')

def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video retrieval')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2', #IP
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

collection = create_milvus_collection('text_video_retrieval', 512)
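
To give some intuition for what an IVF_FLAT index does, here is a toy sketch in plain Python. It assumes the usual inverted-file idea: vectors are grouped into buckets around centroids (the number of buckets corresponds to `nlist`), and a search scans only the few buckets closest to the query (`nprobe` in Milvus search parameters). The helper names `build_ivf` and `ivf_search` are hypothetical, and the centroids are fixed by hand here rather than trained, so this is not how Milvus is implemented internally.

```python
def l2_sq(a, b):
    """Squared Euclidean (L2) distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def build_ivf(vectors, centroids):
    """Assign each vector id to the bucket of its nearest centroid."""
    buckets = {c: [] for c in range(len(centroids))}
    for vid, vec in vectors.items():
        nearest = min(range(len(centroids)), key=lambda c: l2_sq(vec, centroids[c]))
        buckets[nearest].append(vid)
    return buckets

def ivf_search(query, vectors, centroids, buckets, nprobe=1, limit=2):
    """Scan only the nprobe buckets whose centroids are closest to the query."""
    probe = sorted(range(len(centroids)), key=lambda c: l2_sq(query, centroids[c]))[:nprobe]
    candidates = [vid for c in probe for vid in buckets[c]]
    return sorted(candidates, key=lambda vid: l2_sq(query, vectors[vid]))[:limit]

# Four made-up 2-d vectors forming two obvious groups.
vectors = {1: [0.0, 0.1], 2: [0.2, 0.0], 3: [5.0, 5.1], 4: [5.2, 4.9]}
centroids = [[0.0, 0.0], [5.0, 5.0]]  # hand-picked for illustration

buckets = build_ivf(vectors, centroids)
print(buckets)
print(ivf_search([0.1, 0.1], vectors, centroids, buckets, nprobe=1, limit=2))
```

Within each probed bucket the vectors are compared exhaustively, which is the "FLAT" part of the name. Probing fewer buckets is faster but can miss neighbors that fall in an unprobed bucket.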

Load Video Embeddings into Milvus

We first extract embeddings from the videos with the CLIP4Clip model and insert them into Milvus for indexing. Towhee provides a method-chaining style API so that users can assemble a data processing pipeline from operators.

Before you run the clip4clip operator, make sure git and git-lfs are installed. To install git (skip this step if it is already installed):

$ sudo apt-get install git

To install git-lfs (required for downloading the checkpoint weights in the background):

$ sudo apt-get install git-lfs
$ git lfs install

CLIP4Clip is a video-text retrieval model based on CLIP (ViT-B). The Towhee clip4clip operator with pretrained weights can extract video embeddings and text embeddings with a few lines of code.

import os
import towhee

device = 'cuda:2'  # GPU to run on; change the index to match your machine
# device = 'cpu'

# The first time you run this pipeline it will take a while,
# because Towhee downloads the operator and its weights in the background.
dc = (
    towhee.read_csv(test_sample_csv_path)
      .runas_op['video_id', 'id'](func=lambda x: int(x[-4:]))
      .video_decode.ffmpeg['video_path', 'frames'](sample_type='uniform_temporal_subsample', args={'num_samples': 12})
      .runas_op['frames', 'frames'](func=lambda x: [y for y in x])
      .video_text_embedding.clip4clip['frames', 'vec'](model_name='clip_vit_b32', modality='video', device=device)
      .to_milvus['id', 'vec'](collection=collection, batch=30)
)
print('Total number of inserted data is {}.'.format(collection.num_entities))

Total number of inserted data is 1000.

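Here is a detailed explanation of each line of the code:

towhee.read_csv(test_sample_csv_path): reads the sampled rows from the CSV file.
.runas_op['video_id', 'id'](...): converts the last four characters of video_id to an integer and stores it in the id column.
.video_decode.ffmpeg['video_path', 'frames'](...): decodes each video and uniformly subsamples 12 frames from it.
.runas_op['frames', 'frames'](...): collects the decoded frames into a list.
.video_text_embedding.clip4clip['frames', 'vec'](...): extracts a video embedding from the frames with the CLIP4Clip model.
.to_milvus['id', 'vec'](...): inserts the ids and embeddings into the Milvus collection in batches of 30.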
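
The clip4clip step turns a list of frames into a single vector per video. One simple way to do that, and the parameter-free option described for CLIP4Clip, is to mean-pool the per-frame embeddings. The sketch below shows that pooling on made-up 2-d frame vectors. It is an assumption that the Towhee operator pools exactly this way, and the helper names are hypothetical.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def mean_pool(frame_vecs):
    """Average per-frame embeddings into one video embedding, then normalize."""
    normalized = [l2_normalize(v) for v in frame_vecs]
    dim = len(normalized[0])
    mean = [sum(v[d] for v in normalized) / len(normalized) for d in range(dim)]
    return l2_normalize(mean)

# Three made-up frame embeddings; the real ones have 512 dimensions.
frame_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
video_vec = mean_pool(frame_vecs)
print(video_vec)
```

Because the result is a single fixed-size vector, every video can be stored in the same `embedding` field of the collection regardless of its length.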

Search in Milvus with Towhee

dc = (
    towhee.read_csv(test_sample_csv_path).unstream()
      .video_text_embedding.clip4clip['sentence','text_vec'](model_name='clip_vit_b32', modality='text', device=device)
      .milvus_search['text_vec', 'top10_raw_res'](collection=collection, limit=10)
      .runas_op['video_id', 'ground_truth'](func=lambda x : [int(x[-4:])])
      .runas_op['top10_raw_res', 'top1'](func=lambda res: [x.id for i, x in enumerate(res) if i < 1])
      .runas_op['top10_raw_res', 'top5'](func=lambda res: [x.id for i, x in enumerate(res) if i < 5])
      .runas_op['top10_raw_res', 'top10'](func=lambda res: [x.id for i, x in enumerate(res) if i < 10])
)

dc.select['video_id', 'sentence', 'ground_truth', 'top10_raw_res', 'top1', 'top5', 'top10']().show()
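
Both pipelines derive the integer id from the last four characters of video_id (`int(x[-4:])`), which assumes the numeric part always has exactly four digits. If your own data uses ids with a different number of digits, a slightly more defensive variant is to strip the prefix instead. The helper below is a hypothetical sketch, not part of Towhee.

```python
def video_id_to_int(video_id, prefix='video'):
    """Turn an id such as 'video7020' into the integer 7020."""
    if not video_id.startswith(prefix):
        raise ValueError('unexpected video id: {}'.format(video_id))
    return int(video_id[len(prefix):])

print(video_id_to_int('video7020'))  # same result as int('video7020'[-4:])
print(video_id_to_int('video12'))    # int('video12'[-4:]) would raise ValueError
```

Whichever conversion you use, it has to be the same at insert time and at search time, because the ground truth is compared against the ids stored in Milvus.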

We have finished the core functionality of the text-video retrieval engine. However, we don't know whether it achieves a reasonable performance. We need to evaluate the retrieval engine against the ground truth so that we know if there is any room to improve it.

In this section, we'll evaluate our text-video retrieval engine using recall@k, the proportion of relevant items that appear in the top-k results. For example, a recall@10 of 40% means that 40% of the relevant items appear in the top-10 results. Since each query here has exactly one ground-truth video, recall@k is simply the fraction of queries whose ground-truth video appears among the top k.
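
As a small worked example of the metric, the sketch below computes recall@k by hand for four made-up queries, each with a single ground-truth id. The ids and the `recall_at_k` helper are illustrative only; the actual evaluation in this codelab uses Towhee's built-in metrics.

```python
def recall_at_k(ground_truths, predictions, k):
    """Fraction of queries whose ground-truth id appears in the top-k results.

    With one relevant video per query, recall@k equals the hit ratio.
    """
    hits = sum(1 for gt, pred in zip(ground_truths, predictions) if gt in pred[:k])
    return hits / len(ground_truths)

# Made-up ground truths and ranked search results for four queries.
ground_truths = [7020, 7021, 7022, 7023]
predictions = [
    [7020, 8000, 8001],  # hit at rank 1
    [8002, 7021, 8003],  # hit at rank 2
    [8004, 8005, 8006],  # miss
    [8007, 8008, 7023],  # hit at rank 3
]

for k in (1, 2, 3):
    print('recall@{} = {:.2f}'.format(k, recall_at_k(ground_truths, predictions, k)))
```

Recall can only stay the same or grow as k increases, since a larger k only adds candidates to each result list.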

benchmark = (
    dc.with_metrics(['mean_hit_ratio',])
        .evaluate['ground_truth', 'top1'](name='recall_at_1')
        .evaluate['ground_truth', 'top5'](name='recall_at_5')
        .evaluate['ground_truth', 'top10'](name='recall_at_10')
        .report()
)

This result is almost identical to the recall metrics reported in the paper. You can find more details about the metrics on Papers with Code.

We've learnt how to build a text-video retrieval engine. Now it's time to add an interface and release a showcase. Towhee provides towhee.api() to wrap a data processing pipeline as a function with .as_function(), so we can build a quick demo around milvus_search_function with Gradio.

import gradio

show_num = 3
with towhee.api() as api:
    milvus_search_function = (
         api.clip4clip(model_name='clip_vit_b32', modality='text', device=device)
            .milvus_search(collection=collection, limit=show_num)
            .runas_op(func=lambda res: [os.path.join(raw_video_path, 'video' + str(x.id) + '.mp4') for x in res])
            .as_function()
    )

interface = gradio.Interface(milvus_search_function, 
                             inputs=[gradio.Textbox()],
                             outputs=[gradio.Video(format='mp4') for _ in range(show_num)]
                            )

interface.launch(inline=True, share=True)