feat: init organizationconfig

This commit is contained in:
swve 2024-01-13 15:25:04 +01:00
parent 982ba037f5
commit de93d56945
6 changed files with 178 additions and 37 deletions

View file

@ -0,0 +1,68 @@
from json import JSONEncoder
import json
from typing import Literal, Optional
from click import Option
from pydantic import BaseModel
from sqlalchemy import JSON, BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
# AI
class AILimitsSettings(BaseModel):
limits_enabled: bool = False
max_asks: int = 0
class AIEnabledFeatures(BaseModel):
editor: bool = False
activity_ask: bool = False
course_ask: bool = False
global_ai_ask: bool = False
class AIConfig(BaseModel):
limits: AILimitsSettings = AILimitsSettings()
embeddings: Literal[
"text-embedding-ada-002", "all-MiniLM-L6-v2"
] = "all-MiniLM-L6-v2"
ai_model: Literal["gpt-3.5-turbo", "gpt-4-1106-preview"] = "gpt-3.5-turbo"
features: AIEnabledFeatures = AIEnabledFeatures()
class OrgUserConfig(BaseModel):
signup_mechanism: Literal["open", "inviteOnly"] = "open"
# Limits
class LimitSettings(BaseModel):
limits_enabled: bool = False
max_users: int = 0
max_storage: int = 0
max_staff: int = 0
# General
class GeneralConfig(BaseModel):
color: str = ""
limits: LimitSettings = LimitSettings()
users: OrgUserConfig = OrgUserConfig()
active: bool = True
class OrganizationConfigBase(SQLModel):
GeneralConfig: GeneralConfig
AIConfig: AIConfig
class OrganizationConfig(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
org_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
)
# TODO: fix this to use the correct type GeneralConfig
config: dict = Field(default={}, sa_column=Column(JSON))
creation_date: Optional[str]
update_date: Optional[str]

View file

@ -1,21 +0,0 @@
from typing import Optional
from sqlalchemy import BigInteger, Column, ForeignKey
from sqlmodel import Field, SQLModel
from enum import Enum
class HeaderTypeEnum(str, Enum):
LOGO_MENU_SETTINGS = "LOGO_MENU_SETTINGS"
MENU_LOGO_SETTINGS = "MENU_LOGO_SETTINGS"
class OrganizationSettings(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
org_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
)
logo_image: Optional[str] = ""
header_type: HeaderTypeEnum = HeaderTypeEnum.LOGO_MENU_SETTINGS
color: str = ""
creation_date: str
update_date: str

View file

@ -1,5 +1,6 @@
from typing import Optional
from sqlmodel import Field, SQLModel
from src.db.organization_config import OrganizationConfig
class OrganizationBase(SQLModel):
@ -16,9 +17,11 @@ class Organization(OrganizationBase, table=True):
creation_date: str = ""
update_date: str = ""
class OrganizationUpdate(OrganizationBase):
pass
class OrganizationCreate(OrganizationBase):
pass
@ -26,5 +29,6 @@ class OrganizationCreate(OrganizationBase):
class OrganizationRead(OrganizationBase):
id: int
org_uuid: str
config: OrganizationConfig | dict
creation_date: str
update_date: str

View file

@ -7,7 +7,6 @@ from src.core.events.database import get_db_session
from src.db.users import PublicUser
from src.db.activities import Activity, ActivityRead
from src.security.auth import get_current_user
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from src.services.ai.base import ask_ai, get_chat_session_history
from src.services.ai.schemas.ai import (

View file

@ -3,16 +3,15 @@ from uuid import uuid4
from langchain.agents import AgentExecutor
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain_core.messages import BaseMessage
from langchain_community.vectorstores import Chroma
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.prompts import MessagesPlaceholder
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import SystemMessage
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain.agents.agent_toolkits import (
create_retriever_tool,
)
@ -54,7 +53,7 @@ def ask_ai(
)
tools = [tool]
llm = ChatOpenAI(temperature=0, api_key=openai_api_key)
llm = ChatOpenAI(temperature=0, api_key=openai_api_key, model_name="gpt-3.5-turbo")
memory_key = "history"

View file

@ -1,7 +1,20 @@
import json
import logging
from datetime import datetime
from logging import config
from typing import Literal
from uuid import uuid4
from sqlmodel import Session, select
from src.db.organization_config import (
AIConfig,
AIEnabledFeatures,
AILimitsSettings,
GeneralConfig,
LimitSettings,
OrgUserConfig,
OrganizationConfig,
OrganizationConfigBase,
)
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
@ -23,7 +36,7 @@ async def get_organization(
org_id: str,
db_session: Session,
current_user: PublicUser | AnonymousUser,
):
) -> OrganizationRead:
statement = select(Organization).where(Organization.id == org_id)
result = db_session.exec(statement)
@ -38,7 +51,18 @@ async def get_organization(
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
org = OrganizationRead.from_orm(org)
# Get org config
statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id)
result = db_session.exec(statement)
org_config = result.first()
if org_config is None:
logging.error(f"Organization {org_id} has no config")
config = OrganizationConfig.from_orm(org_config) if org_config else {}
org = OrganizationRead(**org.dict(), config=config)
return org
@ -48,7 +72,7 @@ async def get_organization_by_slug(
org_slug: str,
db_session: Session,
current_user: PublicUser | AnonymousUser,
):
) -> OrganizationRead:
statement = select(Organization).where(Organization.slug == org_slug)
result = db_session.exec(statement)
@ -63,7 +87,18 @@ async def get_organization_by_slug(
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
org = OrganizationRead.from_orm(org)
# Get org config
statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id)
result = db_session.exec(statement)
org_config = result.first()
if org_config is None:
logging.error(f"Organization {org_slug} has no config")
config = OrganizationConfig.from_orm(org_config) if org_config else {}
org = OrganizationRead(**org.dict(), config=config)
return org
@ -115,7 +150,64 @@ async def create_org(
db_session.commit()
db_session.refresh(user_org)
return OrganizationRead.from_orm(org)
org_config = OrganizationConfigBase(
GeneralConfig=GeneralConfig(
color="#000000",
limits=LimitSettings(
limits_enabled=False,
max_users=0,
max_storage=0,
max_staff=0,
),
users=OrgUserConfig(
signup_mechanism="open",
),
active=True,
),
AIConfig=AIConfig(
limits=AILimitsSettings(
limits_enabled=False,
max_asks=0,
),
embeddings="all-MiniLM-L6-v2",
ai_model="gpt-3.5-turbo",
features=AIEnabledFeatures(
editor=False,
activity_ask=False,
course_ask=False,
global_ai_ask=False,
),
),
)
org_config = json.loads(org_config.json())
# OrgSettings
org_settings = OrganizationConfig(
org_id=int(org.id if org.id else 0),
config=org_config,
creation_date=str(datetime.now()),
update_date=str(datetime.now()),
)
db_session.add(org_settings)
db_session.commit()
db_session.refresh(org_settings)
# Get org config
statement = select(OrganizationConfig).where(OrganizationConfig.org_id == org.id)
result = db_session.exec(statement)
org_config = result.first()
if org_config is None:
logging.error(f"Organization {org.id} has no config")
config = OrganizationConfig.from_orm(org_config)
org = OrganizationRead(**org.dict(), config=config)
return org
async def update_org(