"""Streamlit app that extracts a product price from an uploaded HTML page.

The uploaded page is wrapped in a single LlamaIndex ``Document``, indexed in
an in-memory ``VectorStoreIndex`` and queried through a query engine whose
structured output class is ``PriceModel``. Both the LLM and the embedding
model are served through the Hugging Face Inference API.
"""

import json
import os
import uuid
from io import StringIO
from typing import List

import streamlit as st
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.schema import Document
from llama_index.vector_stores.types import ExactMatchFilter, MetadataFilters
from pydantic import BaseModel

# NOTE: the secret name is spelled "INFRERENCE" here. It has to match the key
# configured in the Streamlit secrets store, so it is intentionally untouched.
inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]

# Disabled: let the user choose the models from the UI.
# embed_model_name = st.text_input(
#     'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
# llm_model_name = st.text_input(
#     'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")


class PriceModel(BaseModel):
    """Data model for price"""
    # NOTE: pydantic copies the docstring above into the model's JSON schema
    # description, so its wording is kept stable on purpose.

    # Price exactly as reported by the LLM (free-form text, not parsed).
    price: str


embed_model_name = "jinaai/jina-embedding-s-en-v1"
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

llm = HuggingFaceInferenceAPI(
    model_name=llm_model_name, token=inference_api_key)
embed_model = HuggingFaceInferenceAPIEmbedding(
    model_name=embed_model_name,
    token=inference_api_key,
    model_kwargs={"device": ""},
    encode_kwargs={"normalize_embeddings": True},
)
service_context = ServiceContext.from_defaults(
    embed_model=embed_model, llm=llm)


def _decode_html(raw: bytes) -> str:
    """Decode uploaded bytes as UTF-8 without crashing on bad input.

    Args:
        raw: Raw content of the uploaded file.

    Returns:
        The decoded text. If the bytes are not valid UTF-8, a warning is
        shown in the UI and undecodable bytes are replaced with U+FFFD so the
        rest of the page can still be processed.
    """
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError:
        st.warning(
            "The uploaded file is not valid UTF-8; "
            "undecodable characters were replaced.")
        return raw.decode("utf-8", errors="replace")


def _extract_price(html_text: str, question: str):
    """Index ``html_text`` and ask ``question`` with ``PriceModel`` output.

    Args:
        html_text: Decoded HTML of the uploaded page.
        question: Natural-language query to run against the page.

    Returns:
        The ``price`` value of the structured response, or ``None`` when the
        response carries no such value.
    """
    # A fresh id per run lets the metadata filter below target this document.
    document_id = str(uuid.uuid4())
    document = Document(text=html_text)
    document.metadata["id"] = document_id

    filters = MetadataFilters(
        filters=[ExactMatchFilter(key="id", value=document_id)])

    # NOTE(review): `metadata` is forwarded as an extra keyword argument;
    # confirm that the index actually makes use of it.
    index = VectorStoreIndex.from_documents(
        [document],
        show_progress=True,
        metadata={"source": "HTML"},
        service_context=service_context,
    )
    query_engine = index.as_query_engine(
        filters=filters,
        service_context=service_context,
        response_mode="tree_summarize",
        output_cls=PriceModel,
    )
    response = query_engine.query(question)
    # Tolerate a response object without a usable `price` instead of raising
    # AttributeError or rendering the literal text "None".
    return getattr(response, "price", None)


query = st.text_input('Query', "What is the price of the product?")
html_file = st.file_uploader("Upload a html file", type=["html"])

if html_file is not None:
    string_data = _decode_html(html_file.getvalue())
    with st.expander("Uploaded HTML"):
        st.code(string_data, language='html')

    # Skip the remote embedding/LLM calls when there is nothing to work with.
    if not string_data.strip():
        st.warning("The uploaded file is empty.")
    elif not query or not query.strip():
        st.warning("Please enter a query.")
    else:
        price = _extract_price(string_data, query)
        if price is None:
            st.warning("No price could be extracted from the document.")
        else:
            st.write(f'Price: {price}')

# ---------------------------------------------------------------------------
# Disabled variant: button-triggered pipeline that also shows ranked nodes.
# ---------------------------------------------------------------------------
# if st.button('Start Pipeline'):
#     if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
#         st.write('Running Pipeline')
#         llm = HuggingFaceInferenceAPI(
#             model_name=llm_model_name, token=inference_api_key)
#         embed_model = HuggingFaceInferenceAPIEmbedding(
#             model_name=embed_model_name,
#             token=inference_api_key,
#             model_kwargs={"device": ""},
#             encode_kwargs={"normalize_embeddings": True},
#         )
#         service_context = ServiceContext.from_defaults(
#             embed_model=embed_model, llm=llm)
#         stringio = StringIO(html_file.getvalue().decode("utf-8"))
#         string_data = stringio.read()
#         with st.expander("Uploaded HTML"):
#             st.write(string_data)
#         document_id = str(uuid.uuid4())
#         document = Document(text=string_data)
#         document.metadata["id"] = document_id
#         documents = [document]
#         filters = MetadataFilters(
#             filters=[ExactMatchFilter(key="id", value=document_id)])
#         index = VectorStoreIndex.from_documents(
#             documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
#         retriever = index.as_retriever()
#         ranked_nodes = retriever.retrieve(
#             query)
#         with st.expander("Ranked Nodes"):
#             for node in ranked_nodes:
#                 st.write(node.node.get_content(), "-> Score:", node.score)
#         query_engine = index.as_query_engine(
#             filters=filters, service_context=service_context)
#         response = query_engine.query(query)
#         st.write(response.response)
#         st.write(response.source_nodes)
#     else:
#         st.error('Please fill in all the fields')
# else:
#     st.write('Press start to begin')

# ---------------------------------------------------------------------------
# Disabled variant: run a fixed "all information" query right after upload.
# ---------------------------------------------------------------------------
# # if html_file is not None:
# #     stringio = StringIO(html_file.getvalue().decode("utf-8"))
# #     string_data = stringio.read()
# #     with st.expander("Uploaded HTML"):
# #         st.write(string_data)
# #     document_id = str(uuid.uuid4())
# #     document = Document(text=string_data)
# #     document.metadata["id"] = document_id
# #     documents = [document]
# #     filters = MetadataFilters(
# #         filters=[ExactMatchFilter(key="id", value=document_id)])
# #     index = VectorStoreIndex.from_documents(
# #         documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
# #     retriever = index.as_retriever()
# #     ranked_nodes = retriever.retrieve(
# #         "Get me all the information about the product")
# #     with st.expander("Ranked Nodes"):
# #         for node in ranked_nodes:
# #             st.write(node.node.get_content(), "-> Score:", node.score)
# #     query_engine = index.as_query_engine(
# #         filters=filters, service_context=service_context)
# #     response = query_engine.query(
# #         "Get me all the information about the product")
# #     st.write(response)