File size: 1,479 Bytes
f51bb92
 
 
 
 
8f6647c
 
 
 
 
 
 
f51bb92
 
 
 
1e2550f
 
 
 
 
 
 
 
 
f51bb92
 
 
 
 
 
 
 
 
 
1e2550f
f51bb92
 
 
6581a76
f51bb92
 
 
 
 
 
 
8f6647c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from langchain_community.vectorstores import FAISS
from modules.vectorstore.base import VectorStoreBase
import os


class FAISS(FAISS):
    """LangChain ``FAISS`` vector store extended with ``len()`` support.

    The subclass deliberately reuses the name ``FAISS`` so that every later
    reference to ``FAISS`` in this module (e.g. ``FaissVectorStore``) picks up
    the length-aware version rather than the plain LangChain class.
    """

    def __len__(self):
        # ``ntotal`` is the wrapped faiss index's count of stored vectors.
        faiss_index = self.index
        return faiss_index.ntotal


class FaissVectorStore(VectorStoreBase):
    """FAISS-backed implementation of ``VectorStoreBase``.

    The index is persisted under
    ``<db_path>/db_<db_option>_<model>_<chunking_mode>``.

    Args:
        config: Nested configuration mapping. This class reads
            ``config["vectorstore"]["db_path"]``,
            ``config["vectorstore"]["db_option"]``,
            ``config["vectorstore"]["model"]`` and
            ``config["splitter_options"]["chunking_mode"]``.
    """

    def __init__(self, config):
        self.config = config
        self._init_vector_db()
        self.local_path = self._build_local_path()

    def _build_local_path(self):
        """Return the directory the FAISS index is saved to / loaded from."""
        vs_cfg = self.config["vectorstore"]
        chunking_mode = self.config["splitter_options"]["chunking_mode"]
        db_name = f"db_{vs_cfg['db_option']}_{vs_cfg['model']}_{chunking_mode}"
        return os.path.join(vs_cfg["db_path"], db_name)

    def _init_vector_db(self):
        """Create the placeholder ``FAISS`` object used as a factory.

        Within this class ``self.faiss`` is only used to call ``from_documents``
        and ``load_local`` (LangChain classmethods). It is built with
        ``index=0`` and empty mappings, so it is not a usable store itself.
        """
        self.faiss = FAISS(
            embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
        )

    def create_database(self, document_chunks, embedding_model):
        """Embed ``document_chunks``, build a FAISS store, and save it to disk.

        Args:
            document_chunks: Documents to index.
            embedding_model: Embedding model passed to ``from_documents``.

        Returns:
            The newly built vector store (also kept on ``self.vectorstore``).
        """
        self.vectorstore = self.faiss.from_documents(
            documents=document_chunks, embedding=embedding_model
        )
        self.vectorstore.save_local(self.local_path)
        return self.vectorstore

    def load_database(self, embedding_model):
        """Load the store previously saved at ``self.local_path``.

        Args:
            embedding_model: Embedding model used for subsequent queries.

        Returns:
            The loaded vector store (also kept on ``self.vectorstore``).
        """
        # SECURITY NOTE: LangChain's FAISS persistence pickles the docstore, so
        # loading it requires opting in to pickle deserialization. This is only
        # safe because ``local_path`` is our own output directory; never point
        # it at files from an untrusted source.
        self.vectorstore = self.faiss.load_local(
            self.local_path,
            embedding_model,
            allow_dangerous_deserialization=True,
        )
        return self.vectorstore

    def as_retriever(self, **kwargs):
        """Return a retriever over the current store.

        Keyword arguments (e.g. ``search_type``, ``search_kwargs``) are
        forwarded to the underlying ``as_retriever``. Requires a prior call to
        ``create_database`` or ``load_database``.
        """
        return self.vectorstore.as_retriever(**kwargs)

    def __len__(self):
        """Number of entries in the current store.

        Requires a prior call to ``create_database`` or ``load_database``;
        otherwise ``self.vectorstore`` is unset and ``AttributeError`` is raised.
        """
        return len(self.vectorstore)