fix: remove heavy local embedding model

swve 2024-01-28 20:39:42 +00:00
parent 1b57195a7a
commit 15a2c6ab4b
6 changed files with 13 additions and 860 deletions

View file

@@ -21,8 +21,8 @@ class AIConfig(BaseModel):
     enabled : bool = True
     limits: AILimitsSettings = AILimitsSettings()
     embeddings: Literal[
-        "text-embedding-ada-002", "all-MiniLM-L6-v2"
-    ] = "all-MiniLM-L6-v2"
+        "text-embedding-ada-002",
+    ] = "text-embedding-ada-002"
     ai_model: Literal["gpt-3.5-turbo", "gpt-4-1106-preview"] = "gpt-3.5-turbo"
     features: AIEnabledFeatures = AIEnabledFeatures()

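With the local SentenceTransformer option gone, the config schema only accepts the hosted OpenAI embedding. A minimal sketch of how the narrowed field behaves under Pydantic validation (the real AIConfig also carries limits and features fields, trimmed here for brevity; AIConfigSketch is an illustrative stand-in, not the class in this repo):

    from typing import Literal

    from pydantic import BaseModel, ValidationError


    class AIConfigSketch(BaseModel):
        # Only the hosted OpenAI embedding remains selectable after this commit.
        embeddings: Literal["text-embedding-ada-002"] = "text-embedding-ada-002"
        ai_model: Literal["gpt-3.5-turbo", "gpt-4-1106-preview"] = "gpt-3.5-turbo"


    print(AIConfigSketch().embeddings)  # -> "text-embedding-ada-002"

    try:
        AIConfigSketch(embeddings="all-MiniLM-L6-v2")
    except ValidationError as exc:
        # The removed local model is no longer a permitted value.
        print(exc)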
View file

@@ -2,7 +2,6 @@ from typing import Optional
 from uuid import uuid4
 from langchain.agents import AgentExecutor
 from langchain.text_splitter import CharacterTextSplitter
-from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain_community.vectorstores import Chroma
 from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
 from langchain.prompts import MessagesPlaceholder
@@ -45,7 +44,6 @@ def ask_ai(
     texts = text_splitter.split_documents(documents)
     embedding_models = {
-        "all-MiniLM-L6-v2": SentenceTransformerEmbeddings,
         "text-embedding-ada-002": OpenAIEmbeddings,
     }
@@ -53,11 +51,11 @@ def ask_ai(
     if embedding_model_name in embedding_models:
-        if embedding_model_name == "text-embedding-ada-002":
-            embedding_function = embedding_models[embedding_model_name](model=embedding_model_name, api_key=openai_api_key)
-        if embedding_model_name == "all-MiniLM-L6-v2":
-            embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
+        embedding_function = embedding_models[embedding_model_name](
+            model=embedding_model_name, api_key=openai_api_key
+        )
     else:
-        embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
+        raise Exception("Embedding model not found")
     # load it into Chroma and use it as a retriever
     db = Chroma.from_documents(texts, embedding_function)
@@ -75,7 +73,10 @@ def ask_ai(
     memory_key = "history"
     memory = AgentTokenBufferMemory(
-        memory_key=memory_key, llm=llm, chat_memory=message_history, max_token_limit=1000
+        memory_key=memory_key,
+        llm=llm,
+        chat_memory=message_history,
+        max_token_limit=1000,
     )
     system_message = SystemMessage(content=(message_for_the_prompt))

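After this change ask_ai has a single embedding path: look the requested name up, construct the OpenAI client, and load the split documents into Chroma; any other name raises. A rough, self-contained sketch of that flow follows. The OpenAIEmbeddings import path is an assumption (the hunk does not show it), and build_retriever is an illustrative wrapper rather than a function in this repo; the constructor call mirrors the model/api_key arguments from the hunk.

    from langchain_community.embeddings import OpenAIEmbeddings  # import path assumed, not shown in the hunk
    from langchain_community.vectorstores import Chroma


    def build_retriever(texts, embedding_model_name, openai_api_key):
        """Illustrative wrapper around the selection logic left after this commit."""
        # Only the hosted OpenAI embedding survives the cleanup.
        embedding_models = {"text-embedding-ada-002": OpenAIEmbeddings}

        if embedding_model_name in embedding_models:
            embedding_function = embedding_models[embedding_model_name](
                model=embedding_model_name, api_key=openai_api_key
            )
        else:
            raise Exception("Embedding model not found")

        # Load the split documents into Chroma and use it as a retriever, as ask_ai does.
        db = Chroma.from_documents(texts, embedding_function)
        return db.as_retriever()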
View file

@@ -169,7 +169,7 @@ async def create_org(
             limits_enabled=False,
             max_asks=0,
         ),
-        embeddings="all-MiniLM-L6-v2",
+        embeddings="text-embedding-ada-002",
         ai_model="gpt-3.5-turbo",
         features=AIEnabledFeatures(
             editor=False,
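For newly created orgs the stored default now points at the hosted embedding as well. Roughly, the default AI config assembled in create_org serializes to the shape below; the field names come from the hunks above, while the exact nesting of the Pydantic models is assumed:

    # Approximate shape of the per-org AI defaults after this commit (illustrative only).
    default_ai_config = {
        "enabled": True,
        "limits": {"limits_enabled": False, "max_asks": 0},
        "embeddings": "text-embedding-ada-002",
        "ai_model": "gpt-3.5-turbo",
        "features": {"editor": False},
    }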
@@ -531,8 +531,6 @@ async def get_org_join_mechanism(
     return signup_mechanism
 ## 🔒 RBAC Utils ##