From de93d569453a58e56c8822bdffdddd9af6a8a58f Mon Sep 17 00:00:00 2001 From: swve Date: Sat, 13 Jan 2024 15:25:04 +0100 Subject: [PATCH] feat: init organizationconfig --- apps/api/src/db/organization_config.py | 68 +++++++++++++++ apps/api/src/db/organization_settings.py | 21 ----- apps/api/src/db/organizations.py | 12 ++- apps/api/src/services/ai/ai.py | 1 - apps/api/src/services/ai/base.py | 9 +- apps/api/src/services/orgs/orgs.py | 104 +++++++++++++++++++++-- 6 files changed, 178 insertions(+), 37 deletions(-) create mode 100644 apps/api/src/db/organization_config.py delete mode 100644 apps/api/src/db/organization_settings.py diff --git a/apps/api/src/db/organization_config.py b/apps/api/src/db/organization_config.py new file mode 100644 index 00000000..3608666b --- /dev/null +++ b/apps/api/src/db/organization_config.py @@ -0,0 +1,68 @@ +from json import JSONEncoder +import json +from typing import Literal, Optional +from click import Option +from pydantic import BaseModel +from sqlalchemy import JSON, BigInteger, Column, ForeignKey +from sqlmodel import Field, SQLModel + + +# AI +class AILimitsSettings(BaseModel): + limits_enabled: bool = False + max_asks: int = 0 + + +class AIEnabledFeatures(BaseModel): + editor: bool = False + activity_ask: bool = False + course_ask: bool = False + global_ai_ask: bool = False + + +class AIConfig(BaseModel): + limits: AILimitsSettings = AILimitsSettings() + embeddings: Literal[ + "text-embedding-ada-002", "all-MiniLM-L6-v2" + ] = "all-MiniLM-L6-v2" + ai_model: Literal["gpt-3.5-turbo", "gpt-4-1106-preview"] = "gpt-3.5-turbo" + features: AIEnabledFeatures = AIEnabledFeatures() + + +class OrgUserConfig(BaseModel): + signup_mechanism: Literal["open", "inviteOnly"] = "open" + + +# Limits +class LimitSettings(BaseModel): + limits_enabled: bool = False + max_users: int = 0 + max_storage: int = 0 + max_staff: int = 0 + + +# General +class GeneralConfig(BaseModel): + color: str = "" + limits: LimitSettings = LimitSettings() + users: OrgUserConfig = OrgUserConfig() + active: bool = True + + +class OrganizationConfigBase(SQLModel): + GeneralConfig: GeneralConfig + AIConfig: AIConfig + +class OrganizationConfig(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + org_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")) + ) + # TODO: fix this to use the correct type GeneralConfig + config: dict = Field(default={}, sa_column=Column(JSON)) + creation_date: Optional[str] + update_date: Optional[str] + + + + diff --git a/apps/api/src/db/organization_settings.py b/apps/api/src/db/organization_settings.py deleted file mode 100644 index babdef08..00000000 --- a/apps/api/src/db/organization_settings.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional -from sqlalchemy import BigInteger, Column, ForeignKey -from sqlmodel import Field, SQLModel -from enum import Enum - - -class HeaderTypeEnum(str, Enum): - LOGO_MENU_SETTINGS = "LOGO_MENU_SETTINGS" - MENU_LOGO_SETTINGS = "MENU_LOGO_SETTINGS" - - -class OrganizationSettings(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - org_id: int = Field( - sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")) - ) - logo_image: Optional[str] = "" - header_type: HeaderTypeEnum = HeaderTypeEnum.LOGO_MENU_SETTINGS - color: str = "" - creation_date: str - update_date: str diff --git a/apps/api/src/db/organizations.py b/apps/api/src/db/organizations.py index c16d3809..57b5bbf8 100644 --- a/apps/api/src/db/organizations.py +++ b/apps/api/src/db/organizations.py @@ -1,13 +1,14 @@ from typing import Optional from sqlmodel import Field, SQLModel +from src.db.organization_config import OrganizationConfig class OrganizationBase(SQLModel): name: str - description: Optional[str] + description: Optional[str] slug: str email: str - logo_image: Optional[str] + logo_image: Optional[str] class Organization(OrganizationBase, table=True): @@ -16,9 +17,11 @@ class Organization(OrganizationBase, table=True): creation_date: str = "" update_date: str = "" + class OrganizationUpdate(OrganizationBase): pass + class OrganizationCreate(OrganizationBase): pass @@ -26,5 +29,6 @@ class OrganizationCreate(OrganizationBase): class OrganizationRead(OrganizationBase): id: int org_uuid: str - creation_date: str - update_date: str + config: OrganizationConfig | dict + creation_date: str + update_date: str diff --git a/apps/api/src/services/ai/ai.py b/apps/api/src/services/ai/ai.py index e0e21e3e..3ad79a56 100644 --- a/apps/api/src/services/ai/ai.py +++ b/apps/api/src/services/ai/ai.py @@ -7,7 +7,6 @@ from src.core.events.database import get_db_session from src.db.users import PublicUser from src.db.activities import Activity, ActivityRead from src.security.auth import get_current_user -from langchain.memory.chat_message_histories import RedisChatMessageHistory from src.services.ai.base import ask_ai, get_chat_session_history from src.services.ai.schemas.ai import ( diff --git a/apps/api/src/services/ai/base.py b/apps/api/src/services/ai/base.py index f2356fa2..798f2c06 100644 --- a/apps/api/src/services/ai/base.py +++ b/apps/api/src/services/ai/base.py @@ -3,16 +3,15 @@ from uuid import uuid4 from langchain.agents import AgentExecutor from langchain.text_splitter import CharacterTextSplitter from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings -from langchain.vectorstores import Chroma -from langchain_core.messages import BaseMessage +from langchain_community.vectorstores import Chroma from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.prompts import MessagesPlaceholder -from langchain.memory.chat_message_histories import RedisChatMessageHistory +from langchain_community.chat_message_histories import RedisChatMessageHistory from langchain_core.messages import SystemMessage from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( AgentTokenBufferMemory, ) -from langchain.chat_models import ChatOpenAI +from langchain_community.chat_models import ChatOpenAI from langchain.agents.agent_toolkits import ( create_retriever_tool, ) @@ -54,7 +53,7 @@ def ask_ai( ) tools = [tool] - llm = ChatOpenAI(temperature=0, api_key=openai_api_key) + llm = ChatOpenAI(temperature=0, api_key=openai_api_key, model_name="gpt-3.5-turbo") memory_key = "history" diff --git a/apps/api/src/services/orgs/orgs.py b/apps/api/src/services/orgs/orgs.py index b3735390..3b673dd5 100644 --- a/apps/api/src/services/orgs/orgs.py +++ b/apps/api/src/services/orgs/orgs.py @@ -1,7 +1,20 @@ +import json +import logging from datetime import datetime +from logging import config from typing import Literal from uuid import uuid4 from sqlmodel import Session, select +from src.db.organization_config import ( + AIConfig, + AIEnabledFeatures, + AILimitsSettings, + GeneralConfig, + LimitSettings, + OrgUserConfig, + OrganizationConfig, + OrganizationConfigBase, +) from src.security.rbac.rbac import ( authorization_verify_based_on_roles_and_authorship, authorization_verify_if_user_is_anon, @@ -23,7 +36,7 @@ async def get_organization( org_id: str, db_session: Session, current_user: PublicUser | AnonymousUser, -): +) -> OrganizationRead: statement = select(Organization).where(Organization.id == org_id) result = db_session.exec(statement) @@ -38,7 +51,18 @@ async def get_organization( # RBAC check await rbac_check(request, org.org_uuid, current_user, "read", db_session) - org = OrganizationRead.from_orm(org) + # Get org config + statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id) + result = db_session.exec(statement) + + org_config = result.first() + + if org_config is None: + logging.error(f"Organization {org_id} has no config") + + config = OrganizationConfig.from_orm(org_config) if org_config else {} + + org = OrganizationRead(**org.dict(), config=config) return org @@ -48,7 +72,7 @@ async def get_organization_by_slug( org_slug: str, db_session: Session, current_user: PublicUser | AnonymousUser, -): +) -> OrganizationRead: statement = select(Organization).where(Organization.slug == org_slug) result = db_session.exec(statement) @@ -63,7 +87,18 @@ async def get_organization_by_slug( # RBAC check await rbac_check(request, org.org_uuid, current_user, "read", db_session) - org = OrganizationRead.from_orm(org) + # Get org config + statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id) + result = db_session.exec(statement) + + org_config = result.first() + + if org_config is None: + logging.error(f"Organization {org_slug} has no config") + + config = OrganizationConfig.from_orm(org_config) if org_config else {} + + org = OrganizationRead(**org.dict(), config=config) return org @@ -87,7 +122,7 @@ async def create_org( org = Organization.from_orm(org_object) - if isinstance(current_user,AnonymousUser): + if isinstance(current_user, AnonymousUser): raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="You should be logged in to be able to achieve this action", @@ -115,7 +150,64 @@ async def create_org( db_session.commit() db_session.refresh(user_org) - return OrganizationRead.from_orm(org) + org_config = OrganizationConfigBase( + GeneralConfig=GeneralConfig( + color="#000000", + limits=LimitSettings( + limits_enabled=False, + max_users=0, + max_storage=0, + max_staff=0, + ), + users=OrgUserConfig( + signup_mechanism="open", + ), + active=True, + ), + AIConfig=AIConfig( + limits=AILimitsSettings( + limits_enabled=False, + max_asks=0, + ), + embeddings="all-MiniLM-L6-v2", + ai_model="gpt-3.5-turbo", + features=AIEnabledFeatures( + editor=False, + activity_ask=False, + course_ask=False, + global_ai_ask=False, + ), + ), + ) + + org_config = json.loads(org_config.json()) + + # OrgSettings + org_settings = OrganizationConfig( + org_id=int(org.id if org.id else 0), + config=org_config, + creation_date=str(datetime.now()), + update_date=str(datetime.now()), + ) + + db_session.add(org_settings) + db_session.commit() + db_session.refresh(org_settings) + + # Get org config + statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id) + result = db_session.exec(statement) + + org_config = result.first() + + if org_config is None: + logging.error(f"Organization {org.id} has no config") + + config = OrganizationConfig.from_orm(org_config) + + org = OrganizationRead(**org.dict(), config=config) + + return org async def update_org(