import os

from ragatouille import RAGPretrainedModel

from modules.vectorstore.base import VectorStoreBase


class ColbertVectorStore(VectorStoreBase):
    """Vector store that builds and loads a ColBERT index through RAGatouille.

    The store reads two settings from ``config["vectorstore"]``:

    * ``db_path``   -- base directory for all vector databases.
    * ``db_option`` -- name of this backend; the index root is the
      sub-directory ``db_<db_option>`` of ``db_path``.
    """

    # Checkpoint handed to ``RAGPretrainedModel.from_pretrained``.
    _PRETRAINED_MODEL = "colbert-ir/colbertv2.0"
    # Index name shared by ``create_database`` and ``load_database`` so that
    # the index that gets built is the same one that gets loaded.
    _INDEX_NAME = "new_idx"

    def __init__(self, config):
        """Store *config* and instantiate the pretrained ColBERT model.

        Args:
            config: Mapping that provides ``config["vectorstore"]["db_path"]``
                and ``config["vectorstore"]["db_option"]``.
        """
        self.config = config
        # Populated by ``load_database``; kept as None until then so that
        # ``as_retriever`` can report a missing load explicitly.
        self.vectorstore = None
        self._init_vector_db()

    def _index_root(self):
        """Return the directory under which this store keeps its index."""
        vectorstore_cfg = self.config["vectorstore"]
        return os.path.join(
            vectorstore_cfg["db_path"],
            "db_" + vectorstore_cfg["db_option"],
        )

    def _init_vector_db(self):
        """Create the pretrained model used for indexing as ``self.colbert``."""
        self.colbert = RAGPretrainedModel.from_pretrained(
            self._PRETRAINED_MODEL,
            index_root=self._index_root(),
        )

    def create_database(self, documents, document_names, document_metadata):
        """Index *documents* with the pretrained ColBERT model.

        Args:
            documents: Collection of documents, passed to RAGatouille as
                ``collection``.
            document_names: Identifiers for the documents, passed as
                ``document_ids`` (presumably one per document -- confirm
                against the RAGatouille version in use).
            document_metadata: Metadata for the documents, passed as
                ``document_metadatas``.
        """
        self.colbert.index(
            index_name=self._INDEX_NAME,
            collection=documents,
            document_ids=document_names,
            document_metadatas=document_metadata,
        )

    def load_database(self):
        """Load the index from disk and return the loaded model.

        The index is read from ``<index root>/colbert/indexes/new_idx``.

        Returns:
            The object returned by ``RAGPretrainedModel.from_index``, which
            is also stored as ``self.vectorstore``.
        """
        index_path = os.path.join(
            self._index_root(), "colbert", "indexes", self._INDEX_NAME
        )
        self.vectorstore = RAGPretrainedModel.from_index(index_path)
        return self.vectorstore

    def as_retriever(self):
        """Return a retriever built from the loaded index.

        Raises:
            RuntimeError: If ``load_database`` has not been called yet.
        """
        if self.vectorstore is None:
            raise RuntimeError(
                "ColBERT index is not loaded; call load_database() before "
                "as_retriever()."
            )
        # NOTE(review): RAGatouille appears to expose this functionality as
        # ``as_langchain_retriever()``; confirm that ``as_retriever`` exists
        # in the installed version before relying on this call.
        return self.vectorstore.as_retriever()