mirror of https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: org wide ai features check
This commit is contained in:
parent de93d56945
commit 077c26ce15

24 changed files with 573 additions and 163 deletions
@@ -11,6 +11,7 @@ from langchain_core.messages import SystemMessage
 from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
+from langchain_openai import OpenAIEmbeddings
 from langchain_community.chat_models import ChatOpenAI
 from langchain.agents.agent_toolkits import (
     create_retriever_tool,

@@ -31,6 +32,8 @@ def ask_ai(
     message_history,
     text_reference: str,
     message_for_the_prompt: str,
+    embedding_model_name: str,
+    openai_model_name: str,
 ):
     # Get API Keys
     LH_CONFIG = get_learnhouse_config()

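The two new arguments are plain strings, so any checking happens at the call site or further down in the function. As a hedged, standalone sketch (the helper name, error messages, and the idea of a caller-side guard are assumptions, not part of this commit), validation could look like:

    # Illustrative guard for the new string parameters; the supported set
    # mirrors the embedding dispatch added later in this commit.
    SUPPORTED_EMBEDDING_MODELS = {"all-MiniLM-L6-v2", "text-embedding-ada-002"}

    def validate_ai_model_choice(embedding_model_name: str, openai_model_name: str) -> None:
        """Reject model names the embedding dispatch does not know about."""
        if embedding_model_name not in SUPPORTED_EMBEDDING_MODELS:
            raise ValueError(
                f"Unsupported embedding model: {embedding_model_name!r}; "
                f"expected one of {sorted(SUPPORTED_EMBEDDING_MODELS)}"
            )
        if not openai_model_name:
            raise ValueError("openai_model_name must be a non-empty model identifier")
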
@@ -41,8 +44,20 @@ def ask_ai(
     documents = text_splitter.create_documents([text_reference])
     texts = text_splitter.split_documents(documents)

-    # create the open-source embedding function
-    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+    embedding_models = {
+        "all-MiniLM-L6-v2": SentenceTransformerEmbeddings,
+        "text-embedding-ada-002": OpenAIEmbeddings,
+    }
+
+    embedding_function = None
+
+    if embedding_model_name in embedding_models:
+        if embedding_model_name == "text-embedding-ada-002":
+            embedding_function = embedding_models[embedding_model_name](model=embedding_model_name, api_key=openai_api_key)
+        if embedding_model_name == "all-MiniLM-L6-v2":
+            embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)
+    else:
+        embedding_function = embedding_models[embedding_model_name](model_name=embedding_model_name)

     # load it into Chroma and use it as a retriever
     db = Chroma.from_documents(texts, embedding_function)

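The new dictionary maps each supported name to its embeddings class, and the branch picks the right constructor keyword for each: the hosted OpenAI model takes model= plus the API key, the local sentence-transformers model takes model_name= and no key. Below is a compact standalone sketch of the same selection; the helper name, the import paths, and the fall-back-to-local behaviour are assumptions (the committed code instead indexes the dictionary again in its else branch):

    # Sketch only: helper name, imports, and fallback are illustrative assumptions.
    from langchain_community.embeddings import SentenceTransformerEmbeddings
    from langchain_openai import OpenAIEmbeddings

    def build_embedding_function(embedding_model_name: str, openai_api_key: str):
        """Return an embeddings object for the requested model name."""
        if embedding_model_name == "text-embedding-ada-002":
            # Hosted OpenAI embeddings: `model` kwarg, needs the API key.
            return OpenAIEmbeddings(model=embedding_model_name, api_key=openai_api_key)
        # Local sentence-transformers model: `model_name` kwarg, no key needed.
        return SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
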
@@ -53,12 +68,14 @@ def ask_ai(
     )
     tools = [tool]

-    llm = ChatOpenAI(temperature=0, api_key=openai_api_key, model_name="gpt-3.5-turbo")
+    llm = ChatOpenAI(
+        temperature=0, api_key=openai_api_key, model_name=openai_model_name
+    )

     memory_key = "history"

     memory = AgentTokenBufferMemory(
-        memory_key=memory_key, llm=llm, chat_memory=message_history, max_tokens=1000
+        memory_key=memory_key, llm=llm, chat_memory=message_history, max_token_limit=1000
     )

     system_message = SystemMessage(content=(message_for_the_prompt))

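Two things change in this hunk: the chat model name now comes from the new openai_model_name parameter instead of being hard-coded to gpt-3.5-turbo, and the memory keyword is corrected to max_token_limit, matching AgentTokenBufferMemory's parameter name. A minimal standalone sketch of the two objects wired together, assuming a fresh ChatMessageHistory for chat_memory and importing ChatOpenAI from langchain_openai (the module itself imports it from langchain_community.chat_models):

    # Sketch under the assumptions above; variable and function names are illustrative.
    from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
        AgentTokenBufferMemory,
    )
    from langchain_community.chat_message_histories import ChatMessageHistory
    from langchain_openai import ChatOpenAI

    def build_llm_and_memory(openai_api_key: str, openai_model_name: str):
        """Build the chat model and a token-bounded conversation memory for the agent."""
        llm = ChatOpenAI(temperature=0, api_key=openai_api_key, model_name=openai_model_name)
        memory = AgentTokenBufferMemory(
            memory_key="history",
            llm=llm,
            chat_memory=ChatMessageHistory(),
            max_token_limit=1000,  # bounds how many tokens of history are kept
        )
        return llm, memory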