diff --git a/alembic/versions/2025_10_03_1608-4324550d9c83_add_user_page.py b/alembic/versions/2025_10_03_1608-4324550d9c83_add_user_page.py new file mode 100644 index 0000000..2f2ca3d --- /dev/null +++ b/alembic/versions/2025_10_03_1608-4324550d9c83_add_user_page.py @@ -0,0 +1,36 @@ +"""add user page + +Revision ID: 4324550d9c83 +Revises: 31359fcda8a7 +Create Date: 2025-10-03 16:08:13.898990 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4324550d9c83' +down_revision: Union[str, Sequence[str], None] = '31359fcda8a7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_session', schema=None) as batch_op: + batch_op.add_column(sa.Column('page_id', sa.String(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('user_session', schema=None) as batch_op: + batch_op.drop_column('page_id') + + # ### end Alembic commands ### diff --git a/src/api/v1/router.py b/src/api/v1/router.py index e145e9c..09c1122 100644 --- a/src/api/v1/router.py +++ b/src/api/v1/router.py @@ -16,6 +16,7 @@ async def insurance_chat(request: models.InsuranceChatRequest): """Handle insurance chat requests""" try: current_page = None + page_id = None if request.context and request.context.page: page_id = str(request.context.page).lower() current_page = await get_page_description(page_id) @@ -32,6 +33,7 @@ async def insurance_chat(request: models.InsuranceChatRequest): uid=str(request.userId) if request.userId else None, current_page=current_page, application=application, + page_id=page_id ) return models.InsuranceChatResponse( @@ -57,8 +59,15 @@ async def init_chat(request: models.InitializeChatRequest): name = None if request.context and request.context.name: name = request.context.name.first_name + first_visit = True + if request.context: + first_visit = request.context.isFirstVisit - result = await chat_service.initialize_chat(str(request.userId), application, name) + page_id = None + if request.context and request.context.page: + page_id = str(request.context.page).lower() + + result = await chat_service.initialize_chat(str(request.userId), application, name, first_visit, page_id) return models.InitializeChatResponse( session_id=result["session_id"], answer=result["answer"], @@ -69,14 +78,36 @@ async def init_chat(request: models.InitializeChatRequest): async def estimate(request: models.EstimationRequest): """Handle insurance estimation requests""" try: - if not request.applicants or not request.plans: + if not request.applicants: raise HTTPException( status_code=400, detail="Missing required applicants or plans" ) - + + print("estimation request: ", request) + + has_primary = False + has_spouse = False + has_dependents = False + for applicant in request.applicants: + if applicant.applicant == 1: + has_primary = True + elif applicant.applicant == 2: + has_spouse = True + elif applicant.applicant == 3: + has_dependents = True + + if has_primary and not has_spouse and not has_dependents: + coverage = 1 + elif has_primary and has_spouse and not has_dependents: + coverage = 2 + elif has_primary and not has_spouse and has_dependents: + coverage = 3 + else: + coverage = 4 + estimation_service = EstimationService() - estimation_response = await estimation_service.estimate_insurance(request.applicants, request.phq, request.plans) + estimation_response = await estimation_service.estimate_insurance(request.applicants, request.phq, coverage) return estimation_response diff --git a/src/database.py b/src/database.py index 8cbb5f1..0b6e9ef 100644 --- a/src/database.py +++ b/src/database.py @@ -40,6 +40,7 @@ class UserSession(Base): id = Column(BigInteger, primary_key=True, index=True) user_id = Column(String, index=True) session_id = Column(String) + page_id = Column(String, default=None, nullable=True) engine = create_engine(settings.DATABASE_URL) diff --git a/src/models.py b/src/models.py index d3e8089..3f56a35 100644 --- a/src/models.py +++ b/src/models.py @@ -17,7 +17,7 @@ class Applicant(BaseModel): class Plan(BaseModel): id: int - priceId: int + priceId: int | None = None class Medication(BaseModel): applicant: int @@ -61,7 +61,6 @@ class Address(BaseModel): class EstimationRequest(BaseModel): userId: str | int | None = Field(None, description="Unique identifier") applicants: List[Applicant] - plans: List[Plan] phq: PHQ income: float address: Address @@ -69,7 +68,8 @@ class EstimationRequest(BaseModel): class EstimationDetails(BaseModel): dtq: bool reason: str - price_id: int + # price_id: int = -1 + tier: str class EstimationResult(BaseModel): name: str @@ -83,7 +83,7 @@ class EstimationResult(BaseModel): class EstimationResponse(BaseModel): status: str details: EstimationDetails - results: List[EstimationResult] + # results: List[EstimationResult] class UserNameContext(BaseModel): first_name: str @@ -94,6 +94,7 @@ class InsuranceChatContext(BaseModel): application: dict | None = None applicationDTO: str | None = None name: UserNameContext | None = None + isFirstVisit: bool = True class InsuranceChatRequest(BaseModel): userId: str | int | None = None @@ -124,9 +125,12 @@ class PlansParam(BaseModel): class ApplicantParam(BaseModel): applicants: list[Applicant] +class PageParam(BaseModel): + page: str + class ChatHook(BaseModel): tool: str - params: PlansParam | ApplicantParam + params: PlansParam | ApplicantParam | PageParam class AIChatResponse(BaseModel): answer: str diff --git a/src/services/chat_service.py b/src/services/chat_service.py index e563ca3..7d83e16 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -4,7 +4,7 @@ from typing import Dict, Any, List, Optional import httpx -from src.models import ApplicantParam, ChatHook, PlansParam +from src.models import ApplicantParam, ChatHook, PlansParam, PageParam from .session_service import session_service from ..api.v1.models import Source, HistoryItem from ..config import settings @@ -86,7 +86,16 @@ class ChatService: "chat_session_id": session_id, "message": f"I'm sorry, I'm experiencing technical difficulties. Please try again later. Error: {str(e)}" } - + + async def add_message_to_history(self, session_id: str, message: list): + async with await self.get_client() as client: + response = await client.post( + "/messages/", + params={"chat_session_id": session_id,}, + json={"content": message} + ) + return response.json() + async def get_chat_history(self, session_id: str) -> List[HistoryItem]: """Get chat history for a session and format it properly""" async with await self.get_client() as client: @@ -145,7 +154,6 @@ class ChatService: return None - def _extract_sources_from_response(self, response_text: str) -> List[Source]: """Extract sources from RAG search results if available""" # This is a placeholder - in a real implementation, you would: @@ -159,10 +167,10 @@ class ChatService: async def get_user_session(self, uid: str) -> str | None: with Session() as session: - session = session.query(UserSession).filter(UserSession.user_id == uid).first() - if not session: + user_session = session.query(UserSession).filter(UserSession.user_id == uid).first() + if not user_session: return None - return session.session_id + return user_session.session_id async def create_user_session(self, uid: str, session_id: str): with Session() as session: @@ -173,51 +181,68 @@ class ChatService: session.add(user_session) session.commit() - async def initialize_chat(self, uid: str, application, name): - session_id = await self.get_user_session(uid) - if not session_id or not await session_service.validate_session(session_id): - session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) - try: - await self.create_user_session(uid, session_id) - except: - pass + async def update_user_page(self, uid: str, page_id: str | None): + with Session() as session: + user_session = session.query(UserSession).filter(UserSession.user_id == uid).first() + old_page_id = user_session.page_id + user_session.page_id = page_id + session.add(user_session) + session.commit() + return old_page_id + async def get_user_page(self, uid: str): + with Session() as session: + user_session = session.query(UserSession).filter(UserSession.user_id == uid).first() + old_page_id = user_session.page_id + return old_page_id + + def get_welcome_message(self, application, name, first_visit): try: if not name: name = application["applicants"][0]["firstName"] except: - return { - "session_id": session_id, - "answer": INIT_MESSAGES[0], - } + return INIT_MESSAGES[0] - last_message = self.get_last_chat_message_date(session_id) - if not last_message: - return { - "session_id": session_id, - "answer": INIT_MESSAGES[0], - } + if first_visit: + return INIT_MESSAGES[0] try: applicant = application["applicants"][0] except: - return { - "session_id": session_id, - "answer": INIT_MESSAGES[2].format(name=name) - } + return INIT_MESSAGES[2].format(name=name) + if not applicant.get("gender") or not applicant.get("dob") or applicant.get("dob") == "-01-" or not applicant.get("weight") or applicant.get("heightFt") is None or applicant.get("heightIn") is None: - return { - "session_id": session_id, - "answer": INIT_MESSAGES[2].format(name=name) - } + return INIT_MESSAGES[2].format(name=name) + + return INIT_MESSAGES[1].format(name=name) + + + async def initialize_chat(self, uid: str, application, name, first_visit, page_id): + session_id = await self.get_user_session(uid) + if not session_id or not await session_service.validate_session(session_id): + session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) + try: + await self.create_user_session(uid, session_id) + except: + pass + # await self.update_user_page(uid, page_id) + + welcome_msg = self.get_welcome_message(application, name, first_visit) + + msg_history_item = [ + {'parts': [{'content': 'Hi', 'part_kind': 'user-prompt'}],'kind': 'request'}, + {'parts': [{'content': welcome_msg, 'part_kind': 'text'}], 'kind': 'response'} + ] + + await self.add_message_to_history(session_id, msg_history_item) return { "session_id": session_id, - "answer": INIT_MESSAGES[1].format(name=name) + "answer": welcome_msg } - async def process_insurance_chat(self, message: str, session_id: Optional[str] = None, uid: Optional[str] = None, current_page: Optional[str] = None, application: Optional[dict] = None) -> Dict[str, Any]: + async def process_insurance_chat(self, message: str, session_id: Optional[str] = None, uid: Optional[str] = None, current_page: Optional[str] = None, application: Optional[dict] = None, page_id = None) -> Dict[str, Any]: """Process an insurance chat request""" try: if not session_id or not await session_service.validate_session(session_id): @@ -225,13 +250,17 @@ class ChatService: session_id = await self.get_user_session(uid) if not session_id or not await session_service.validate_session(session_id): session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) - try: - await self.create_user_session(uid, session_id) - except: - pass + try: + await self.create_user_session(uid, session_id) + except: + pass else: session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) + if page_id.lower() != 'welcome': + old_page_id = await self.update_user_page(uid, page_id) + else: + old_page_id = await self.get_user_page(uid) instructions = "" if uid: @@ -266,6 +295,7 @@ class ChatService: compare_plans = ai_response.get("compare_plans") show_plans = ai_response.get("show_plans") update_applicants = ai_response.get("update_applicants") + show_page = ai_response.get("show_page") hooks = [] if update_applicants: hooks.append(ChatHook( @@ -284,7 +314,13 @@ class ChatService: tool="show_plans", params=PlansParam(plans=show_plans) )) - + elif show_page and old_page_id: + hooks.append(ChatHook( + tool="show_page", + params=PageParam(page=old_page_id) + )) + + return { "session_id": session_id, "answer": ai_message, diff --git a/src/services/estimation_service_v2.py b/src/services/estimation_service_v2.py index 553c85f..a9f3552 100644 --- a/src/services/estimation_service_v2.py +++ b/src/services/estimation_service_v2.py @@ -458,7 +458,7 @@ class EstimationService: return None - async def estimate_insurance(self, applicants: list[Applicant], phq: PHQ, plans: list[Plan]): + async def estimate_insurance(self, applicants: list[Applicant], phq: PHQ, plan_coverage: int): estimation_results = [] is_review = False review_reasons = [] @@ -478,7 +478,6 @@ class EstimationService: base_tier = tier break - plan_coverage = self.get_plan_coverage(plans[0]) rx_spend = 0 for applicant_id, applicant in enumerate(applicants): applicant_review_reasons = [] @@ -559,8 +558,8 @@ class EstimationService: message=final_reason, ) ) - - plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) + + # plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) if is_dtq: reason = "\n".join(dtq_reasons) @@ -569,9 +568,10 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - price_id=plan_price_id, + # price_id=plan_price_id, + tier=f"Tier {base_tier.value}", ), - results=estimation_results + # results=estimation_results ) if is_review: @@ -581,9 +581,10 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - price_id=plan_price_id, + # price_id=plan_price_id, + tier=f"Tier {base_tier.value}", ), - results=estimation_results + # results=estimation_results ) new_tier, tier_reason = self.get_tier(plan_coverage, rx_spend) @@ -596,15 +597,16 @@ class EstimationService: details=EstimationDetails( dtq=True, reason=reason, - price_id=plan_price_id, + # price_id=plan_price_id, + tier=f"Tier {base_tier.value}", ), - results=estimation_results + # results=estimation_results ) if new_tier > base_tier: base_tier = new_tier - plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) + # plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) if base_tier is not None: reason = "\n".join(accept_reasons) @@ -613,9 +615,10 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - price_id=plan_price_id, + # price_id=plan_price_id, + tier=f"Tier {base_tier.value}", ), - results=estimation_results + # results=estimation_results ) else: reason = "\n".join(dtq_reasons) @@ -624,8 +627,8 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - price_id=plan_price_id, + # price_id=plan_price_id, + tier=f"Tier {base_tier.value}", ), - results=estimation_results + # results=estimation_results ) - \ No newline at end of file