feat: init organizationconfig

This commit is contained in:
swve 2024-01-13 15:25:04 +01:00
parent 982ba037f5
commit de93d56945
6 changed files with 178 additions and 37 deletions

View file

@ -7,7 +7,6 @@ from src.core.events.database import get_db_session
from src.db.users import PublicUser
from src.db.activities import Activity, ActivityRead
from src.security.auth import get_current_user
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from src.services.ai.base import ask_ai, get_chat_session_history
from src.services.ai.schemas.ai import (

View file

@ -3,16 +3,15 @@ from uuid import uuid4
from langchain.agents import AgentExecutor
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain_core.messages import BaseMessage
from langchain_community.vectorstores import Chroma
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.prompts import MessagesPlaceholder
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import SystemMessage
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain.agents.agent_toolkits import (
create_retriever_tool,
)
@ -54,7 +53,7 @@ def ask_ai(
)
tools = [tool]
llm = ChatOpenAI(temperature=0, api_key=openai_api_key)
llm = ChatOpenAI(temperature=0, api_key=openai_api_key, model_name="gpt-3.5-turbo")
memory_key = "history"

View file

@ -1,7 +1,20 @@
import json
import logging
from datetime import datetime
from logging import config
from typing import Literal
from uuid import uuid4
from sqlmodel import Session, select
from src.db.organization_config import (
AIConfig,
AIEnabledFeatures,
AILimitsSettings,
GeneralConfig,
LimitSettings,
OrgUserConfig,
OrganizationConfig,
OrganizationConfigBase,
)
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
@ -23,7 +36,7 @@ async def get_organization(
org_id: str,
db_session: Session,
current_user: PublicUser | AnonymousUser,
):
) -> OrganizationRead:
statement = select(Organization).where(Organization.id == org_id)
result = db_session.exec(statement)
@ -38,7 +51,18 @@ async def get_organization(
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
org = OrganizationRead.from_orm(org)
# Get org config
statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id)
result = db_session.exec(statement)
org_config = result.first()
if org_config is None:
logging.error(f"Organization {org_id} has no config")
config = OrganizationConfig.from_orm(org_config) if org_config else {}
org = OrganizationRead(**org.dict(), config=config)
return org
@ -48,7 +72,7 @@ async def get_organization_by_slug(
org_slug: str,
db_session: Session,
current_user: PublicUser | AnonymousUser,
):
) -> OrganizationRead:
statement = select(Organization).where(Organization.slug == org_slug)
result = db_session.exec(statement)
@ -63,7 +87,18 @@ async def get_organization_by_slug(
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
org = OrganizationRead.from_orm(org)
# Get org config
statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id)
result = db_session.exec(statement)
org_config = result.first()
if org_config is None:
logging.error(f"Organization {org_slug} has no config")
config = OrganizationConfig.from_orm(org_config) if org_config else {}
org = OrganizationRead(**org.dict(), config=config)
return org
@ -87,7 +122,7 @@ async def create_org(
org = Organization.from_orm(org_object)
if isinstance(current_user,AnonymousUser):
if isinstance(current_user, AnonymousUser):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="You should be logged in to be able to achieve this action",
@ -115,7 +150,64 @@ async def create_org(
db_session.commit()
db_session.refresh(user_org)
return OrganizationRead.from_orm(org)
org_config = OrganizationConfigBase(
GeneralConfig=GeneralConfig(
color="#000000",
limits=LimitSettings(
limits_enabled=False,
max_users=0,
max_storage=0,
max_staff=0,
),
users=OrgUserConfig(
signup_mechanism="open",
),
active=True,
),
AIConfig=AIConfig(
limits=AILimitsSettings(
limits_enabled=False,
max_asks=0,
),
embeddings="all-MiniLM-L6-v2",
ai_model="gpt-3.5-turbo",
features=AIEnabledFeatures(
editor=False,
activity_ask=False,
course_ask=False,
global_ai_ask=False,
),
),
)
org_config = json.loads(org_config.json())
# OrgSettings
org_settings = OrganizationConfig(
org_id=int(org.id if org.id else 0),
config=org_config,
creation_date=str(datetime.now()),
update_date=str(datetime.now()),
)
db_session.add(org_settings)
db_session.commit()
db_session.refresh(org_settings)
# Get org config
statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id)
result = db_session.exec(statement)
org_config = result.first()
if org_config is None:
logging.error(f"Organization {org.id} has no config")
config = OrganizationConfig.from_orm(org_config)
org = OrganizationRead(**org.dict(), config=config)
return org
async def update_org(