diff --git a/alembic/versions/2025_10_03_1608-4324550d9c83_add_user_page.py b/alembic/versions/2025_10_03_1608-4324550d9c83_add_user_page.py deleted file mode 100644 index 2f2ca3d..0000000 --- a/alembic/versions/2025_10_03_1608-4324550d9c83_add_user_page.py +++ /dev/null @@ -1,36 +0,0 @@ -"""add user page - -Revision ID: 4324550d9c83 -Revises: 31359fcda8a7 -Create Date: 2025-10-03 16:08:13.898990 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '4324550d9c83' -down_revision: Union[str, Sequence[str], None] = '31359fcda8a7' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('user_session', schema=None) as batch_op: - batch_op.add_column(sa.Column('page_id', sa.String(), nullable=True)) - - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('user_session', schema=None) as batch_op: - batch_op.drop_column('page_id') - - # ### end Alembic commands ### diff --git a/src/api/v1/router.py b/src/api/v1/router.py index 09c1122..e145e9c 100644 --- a/src/api/v1/router.py +++ b/src/api/v1/router.py @@ -16,7 +16,6 @@ async def insurance_chat(request: models.InsuranceChatRequest): """Handle insurance chat requests""" try: current_page = None - page_id = None if request.context and request.context.page: page_id = str(request.context.page).lower() current_page = await get_page_description(page_id) @@ -33,7 +32,6 @@ async def insurance_chat(request: models.InsuranceChatRequest): uid=str(request.userId) if request.userId else None, current_page=current_page, application=application, - page_id=page_id ) return models.InsuranceChatResponse( @@ -59,15 +57,8 @@ async def init_chat(request: models.InitializeChatRequest): name = None if request.context and request.context.name: name = request.context.name.first_name - first_visit = True - if request.context: - first_visit = request.context.isFirstVisit - page_id = None - if request.context and request.context.page: - page_id = str(request.context.page).lower() - - result = await chat_service.initialize_chat(str(request.userId), application, name, first_visit, page_id) + result = await chat_service.initialize_chat(str(request.userId), application, name) return models.InitializeChatResponse( session_id=result["session_id"], answer=result["answer"], @@ -78,36 +69,14 @@ async def init_chat(request: models.InitializeChatRequest): async def estimate(request: models.EstimationRequest): """Handle insurance estimation requests""" try: - if not request.applicants: + if not request.applicants or not request.plans: raise HTTPException( status_code=400, detail="Missing required applicants or plans" ) - - print("estimation request: ", request) - - has_primary = False - has_spouse = False - has_dependents = False - for applicant in request.applicants: - if applicant.applicant == 1: - has_primary = True - elif applicant.applicant == 2: - has_spouse = True - elif applicant.applicant == 3: - has_dependents = True - - if has_primary and not has_spouse and not has_dependents: - coverage = 1 - elif has_primary and has_spouse and not has_dependents: - coverage = 2 - elif has_primary and not has_spouse and has_dependents: - coverage = 3 - else: - coverage = 4 - + estimation_service = EstimationService() - estimation_response = await estimation_service.estimate_insurance(request.applicants, request.phq, coverage) + estimation_response = await estimation_service.estimate_insurance(request.applicants, request.phq, request.plans) return estimation_response diff --git a/src/database.py b/src/database.py index 0b6e9ef..8cbb5f1 100644 --- a/src/database.py +++ b/src/database.py @@ -40,7 +40,6 @@ class UserSession(Base): id = Column(BigInteger, primary_key=True, index=True) user_id = Column(String, index=True) session_id = Column(String) - page_id = Column(String, default=None, nullable=True) engine = create_engine(settings.DATABASE_URL) diff --git a/src/models.py b/src/models.py index 3f56a35..d3e8089 100644 --- a/src/models.py +++ b/src/models.py @@ -17,7 +17,7 @@ class Applicant(BaseModel): class Plan(BaseModel): id: int - priceId: int | None = None + priceId: int class Medication(BaseModel): applicant: int @@ -61,6 +61,7 @@ class Address(BaseModel): class EstimationRequest(BaseModel): userId: str | int | None = Field(None, description="Unique identifier") applicants: List[Applicant] + plans: List[Plan] phq: PHQ income: float address: Address @@ -68,8 +69,7 @@ class EstimationRequest(BaseModel): class EstimationDetails(BaseModel): dtq: bool reason: str - # price_id: int = -1 - tier: str + price_id: int class EstimationResult(BaseModel): name: str @@ -83,7 +83,7 @@ class EstimationResult(BaseModel): class EstimationResponse(BaseModel): status: str details: EstimationDetails - # results: List[EstimationResult] + results: List[EstimationResult] class UserNameContext(BaseModel): first_name: str @@ -94,7 +94,6 @@ class InsuranceChatContext(BaseModel): application: dict | None = None applicationDTO: str | None = None name: UserNameContext | None = None - isFirstVisit: bool = True class InsuranceChatRequest(BaseModel): userId: str | int | None = None @@ -125,12 +124,9 @@ class PlansParam(BaseModel): class ApplicantParam(BaseModel): applicants: list[Applicant] -class PageParam(BaseModel): - page: str - class ChatHook(BaseModel): tool: str - params: PlansParam | ApplicantParam | PageParam + params: PlansParam | ApplicantParam class AIChatResponse(BaseModel): answer: str diff --git a/src/services/chat_service.py b/src/services/chat_service.py index 7d83e16..e563ca3 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -4,7 +4,7 @@ from typing import Dict, Any, List, Optional import httpx -from src.models import ApplicantParam, ChatHook, PlansParam, PageParam +from src.models import ApplicantParam, ChatHook, PlansParam from .session_service import session_service from ..api.v1.models import Source, HistoryItem from ..config import settings @@ -86,16 +86,7 @@ class ChatService: "chat_session_id": session_id, "message": f"I'm sorry, I'm experiencing technical difficulties. Please try again later. Error: {str(e)}" } - - async def add_message_to_history(self, session_id: str, message: list): - async with await self.get_client() as client: - response = await client.post( - "/messages/", - params={"chat_session_id": session_id,}, - json={"content": message} - ) - return response.json() - + async def get_chat_history(self, session_id: str) -> List[HistoryItem]: """Get chat history for a session and format it properly""" async with await self.get_client() as client: @@ -154,6 +145,7 @@ class ChatService: return None + def _extract_sources_from_response(self, response_text: str) -> List[Source]: """Extract sources from RAG search results if available""" # This is a placeholder - in a real implementation, you would: @@ -167,10 +159,10 @@ class ChatService: async def get_user_session(self, uid: str) -> str | None: with Session() as session: - user_session = session.query(UserSession).filter(UserSession.user_id == uid).first() - if not user_session: + session = session.query(UserSession).filter(UserSession.user_id == uid).first() + if not session: return None - return user_session.session_id + return session.session_id async def create_user_session(self, uid: str, session_id: str): with Session() as session: @@ -181,68 +173,51 @@ class ChatService: session.add(user_session) session.commit() - async def update_user_page(self, uid: str, page_id: str | None): - with Session() as session: - user_session = session.query(UserSession).filter(UserSession.user_id == uid).first() - old_page_id = user_session.page_id - user_session.page_id = page_id - session.add(user_session) - session.commit() - return old_page_id + async def initialize_chat(self, uid: str, application, name): + session_id = await self.get_user_session(uid) + if not session_id or not await session_service.validate_session(session_id): + session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) + try: + await self.create_user_session(uid, session_id) + except: + pass - async def get_user_page(self, uid: str): - with Session() as session: - user_session = session.query(UserSession).filter(UserSession.user_id == uid).first() - old_page_id = user_session.page_id - return old_page_id - - def get_welcome_message(self, application, name, first_visit): try: if not name: name = application["applicants"][0]["firstName"] except: - return INIT_MESSAGES[0] + return { + "session_id": session_id, + "answer": INIT_MESSAGES[0], + } - if first_visit: - return INIT_MESSAGES[0] + last_message = self.get_last_chat_message_date(session_id) + if not last_message: + return { + "session_id": session_id, + "answer": INIT_MESSAGES[0], + } try: applicant = application["applicants"][0] except: - return INIT_MESSAGES[2].format(name=name) - + return { + "session_id": session_id, + "answer": INIT_MESSAGES[2].format(name=name) + } if not applicant.get("gender") or not applicant.get("dob") or applicant.get("dob") == "-01-" or not applicant.get("weight") or applicant.get("heightFt") is None or applicant.get("heightIn") is None: - return INIT_MESSAGES[2].format(name=name) - - return INIT_MESSAGES[1].format(name=name) - - - async def initialize_chat(self, uid: str, application, name, first_visit, page_id): - session_id = await self.get_user_session(uid) - if not session_id or not await session_service.validate_session(session_id): - session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) - try: - await self.create_user_session(uid, session_id) - except: - pass - # await self.update_user_page(uid, page_id) - - welcome_msg = self.get_welcome_message(application, name, first_visit) - - msg_history_item = [ - {'parts': [{'content': 'Hi', 'part_kind': 'user-prompt'}],'kind': 'request'}, - {'parts': [{'content': welcome_msg, 'part_kind': 'text'}], 'kind': 'response'} - ] - - await self.add_message_to_history(session_id, msg_history_item) + return { + "session_id": session_id, + "answer": INIT_MESSAGES[2].format(name=name) + } return { "session_id": session_id, - "answer": welcome_msg + "answer": INIT_MESSAGES[1].format(name=name) } - async def process_insurance_chat(self, message: str, session_id: Optional[str] = None, uid: Optional[str] = None, current_page: Optional[str] = None, application: Optional[dict] = None, page_id = None) -> Dict[str, Any]: + async def process_insurance_chat(self, message: str, session_id: Optional[str] = None, uid: Optional[str] = None, current_page: Optional[str] = None, application: Optional[dict] = None) -> Dict[str, Any]: """Process an insurance chat request""" try: if not session_id or not await session_service.validate_session(session_id): @@ -250,17 +225,13 @@ class ChatService: session_id = await self.get_user_session(uid) if not session_id or not await session_service.validate_session(session_id): session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) - try: - await self.create_user_session(uid, session_id) - except: - pass + try: + await self.create_user_session(uid, session_id) + except: + pass else: session_id = await session_service.create_session(agent_id=settings.TALESTORM_AGENT_ID) - if page_id.lower() != 'welcome': - old_page_id = await self.update_user_page(uid, page_id) - else: - old_page_id = await self.get_user_page(uid) instructions = "" if uid: @@ -295,7 +266,6 @@ class ChatService: compare_plans = ai_response.get("compare_plans") show_plans = ai_response.get("show_plans") update_applicants = ai_response.get("update_applicants") - show_page = ai_response.get("show_page") hooks = [] if update_applicants: hooks.append(ChatHook( @@ -314,13 +284,7 @@ class ChatService: tool="show_plans", params=PlansParam(plans=show_plans) )) - elif show_page and old_page_id: - hooks.append(ChatHook( - tool="show_page", - params=PageParam(page=old_page_id) - )) - - + return { "session_id": session_id, "answer": ai_message, diff --git a/src/services/estimation_service_v2.py b/src/services/estimation_service_v2.py index a9f3552..553c85f 100644 --- a/src/services/estimation_service_v2.py +++ b/src/services/estimation_service_v2.py @@ -458,7 +458,7 @@ class EstimationService: return None - async def estimate_insurance(self, applicants: list[Applicant], phq: PHQ, plan_coverage: int): + async def estimate_insurance(self, applicants: list[Applicant], phq: PHQ, plans: list[Plan]): estimation_results = [] is_review = False review_reasons = [] @@ -478,6 +478,7 @@ class EstimationService: base_tier = tier break + plan_coverage = self.get_plan_coverage(plans[0]) rx_spend = 0 for applicant_id, applicant in enumerate(applicants): applicant_review_reasons = [] @@ -558,8 +559,8 @@ class EstimationService: message=final_reason, ) ) - - # plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) + + plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) if is_dtq: reason = "\n".join(dtq_reasons) @@ -568,10 +569,9 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - # price_id=plan_price_id, - tier=f"Tier {base_tier.value}", + price_id=plan_price_id, ), - # results=estimation_results + results=estimation_results ) if is_review: @@ -581,10 +581,9 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - # price_id=plan_price_id, - tier=f"Tier {base_tier.value}", + price_id=plan_price_id, ), - # results=estimation_results + results=estimation_results ) new_tier, tier_reason = self.get_tier(plan_coverage, rx_spend) @@ -597,16 +596,15 @@ class EstimationService: details=EstimationDetails( dtq=True, reason=reason, - # price_id=plan_price_id, - tier=f"Tier {base_tier.value}", + price_id=plan_price_id, ), - # results=estimation_results + results=estimation_results ) if new_tier > base_tier: base_tier = new_tier - # plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) + plan_price_id = self.get_plan_price(plans[0], base_tier, plan_coverage) if base_tier is not None: reason = "\n".join(accept_reasons) @@ -615,10 +613,9 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - # price_id=plan_price_id, - tier=f"Tier {base_tier.value}", + price_id=plan_price_id, ), - # results=estimation_results + results=estimation_results ) else: reason = "\n".join(dtq_reasons) @@ -627,8 +624,8 @@ class EstimationService: details=EstimationDetails( dtq=is_dtq, reason=reason, - # price_id=plan_price_id, - tier=f"Tier {base_tier.value}", + price_id=plan_price_id, ), - # results=estimation_results + results=estimation_results ) + \ No newline at end of file