mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
fix: remove heavy local embedding model
This commit is contained in:
parent
1b57195a7a
commit
15a2c6ab4b
6 changed files with 13 additions and 860 deletions
|
|
@ -2,7 +2,6 @@ from typing import Optional
|
|||
from uuid import uuid4
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.prompts import MessagesPlaceholder
|
||||
|
|
@ -45,7 +44,6 @@ def ask_ai(
|
|||
texts = text_splitter.split_documents(documents)
|
||||
|
||||
embedding_models = {
|
||||
"all-MiniLM-L6-v2": SentenceTransformerEmbeddings,
|
||||
"text-embedding-ada-002": OpenAIEmbeddings,
|
||||
}
|
||||
|
||||
|
|
@ -53,11 +51,11 @@ def ask_ai(
|
|||
|
||||
if embedding_model_name in embedding_models:
|
||||
if embedding_model_name == "text-embedding-ada-002":
|
||||
embedding_function = embedding_models[embedding_model_name](model=embedding_model_name, api_key=openai_api_key)
|
||||
if embedding_model_name == "all-MiniLM-L6-v2":
|
||||
embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
|
||||
embedding_function = embedding_models[embedding_model_name](
|
||||
model=embedding_model_name, api_key=openai_api_key
|
||||
)
|
||||
else:
|
||||
embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
|
||||
raise Exception("Embedding model not found")
|
||||
|
||||
# load it into Chroma and use it as a retriever
|
||||
db = Chroma.from_documents(texts, embedding_function)
|
||||
|
|
@ -75,7 +73,10 @@ def ask_ai(
|
|||
memory_key = "history"
|
||||
|
||||
memory = AgentTokenBufferMemory(
|
||||
memory_key=memory_key, llm=llm, chat_memory=message_history, max_token_limit=1000
|
||||
memory_key=memory_key,
|
||||
llm=llm,
|
||||
chat_memory=message_history,
|
||||
max_token_limit=1000,
|
||||
)
|
||||
|
||||
system_message = SystemMessage(content=(message_for_the_prompt))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue