feat: org wide ai features check

This commit is contained in:
swve 2024-01-14 11:58:09 +01:00
parent de93d56945
commit 077c26ce15
24 changed files with 573 additions and 163 deletions

View file

@ -11,6 +11,7 @@ from langchain_core.messages import SystemMessage
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)
from langchain_openai import OpenAIEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain.agents.agent_toolkits import (
create_retriever_tool,
@ -31,6 +32,8 @@ def ask_ai(
message_history,
text_reference: str,
message_for_the_prompt: str,
embedding_model_name: str,
openai_model_name: str,
):
# Get API Keys
LH_CONFIG = get_learnhouse_config()
@ -41,8 +44,20 @@ def ask_ai(
documents = text_splitter.create_documents([text_reference])
texts = text_splitter.split_documents(documents)
# create the open-source embedding function
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
embedding_models = {
"all-MiniLM-L6-v2": SentenceTransformerEmbeddings,
"text-embedding-ada-002": OpenAIEmbeddings,
}
embedding_function = None
if embedding_model_name in embedding_models:
if embedding_model_name == "text-embedding-ada-002":
embedding_function = embedding_models[embedding_model_name](model=embedding_model_name, api_key=openai_api_key)
if embedding_model_name == "all-MiniLM-L6-v2":
embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
else:
embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
# load it into Chroma and use it as a retriever
db = Chroma.from_documents(texts, embedding_function)
@ -53,12 +68,14 @@ def ask_ai(
)
tools = [tool]
llm = ChatOpenAI(temperature=0, api_key=openai_api_key, model_name="gpt-3.5-turbo")
llm = ChatOpenAI(
temperature=0, api_key=openai_api_key, model_name=openai_model_name
)
memory_key = "history"
memory = AgentTokenBufferMemory(
memory_key=memory_key, llm=llm, chat_memory=message_history, max_tokens=1000
memory_key=memory_key, llm=llm, chat_memory=message_history, max_token_limit=1000
)
system_message = SystemMessage(content=(message_for_the_prompt))