107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
|
|
import json
import logging

import httpx
from pydantic import BaseModel, ConfigDict
from sqlalchemy.exc import IntegrityError

from src.config import settings
from src.database import Drug, Session
from src.drug_price_parser import DrugPriceParser, DrugPriceResponse
from src.services.session_service import session_service
|
|
|
|
|
|
class DrugFull(BaseModel):
    """Price record for one drug at one dosage.

    Built from the items in the TaleStorm agent's ``"result"`` list and from
    ``Drug`` ORM rows read back out of the database.
    """

    # Pydantic v2 only reads attributes from arbitrary objects (such as
    # SQLAlchemy ``Drug`` rows passed to ``model_validate``) when
    # ``from_attributes`` is enabled; without it validation of ORM rows fails.
    model_config = ConfigDict(from_attributes=True)

    name: str
    dosage: float
    dosage_unit: str
    unit_price: float
    description: str | None = None
|
|
|
|
|
|
async def convert_drug_result(drug: DrugPriceResponse) -> list[DrugFull]:
    """Send scraped price data to the TaleStorm drug agent and parse its answer.

    A new agent chat session is created for every call. The ``drug`` payload
    is serialised to JSON and sent as the user message to ``/chat/``. The
    agent's ``message`` is expected to be a JSON document with a top-level
    ``"result"`` list whose items match the ``DrugFull`` schema.

    Args:
        drug: Price data as returned by ``DrugPriceParser.get_drug_prices``.

    Returns:
        One ``DrugFull`` per item of the agent's ``"result"`` list.

    Raises:
        httpx.HTTPStatusError: if ``/chat/`` answers with a 4xx/5xx status.
        httpx.HTTPError: on transport errors or timeouts.
        json.JSONDecodeError: if the agent's ``message`` is not valid JSON.
        KeyError: if the response has no ``message`` or the payload has no
            ``result``.
        pydantic.ValidationError: if a result item does not fit ``DrugFull``.
    """
    session_id = await session_service.create_session(
        agent_id=settings.TALESTORM_DRUG_AGENT_ID
    )

    # ``async with`` closes the client's connection pool even if the
    # request raises; the previous version never closed it.
    async with httpx.AsyncClient(
        base_url=settings.TALESTORM_API_BASE_URL,
        headers={"X-API-Key": settings.TALESTORM_API_KEY},
        # 10 s to establish the connection; 60 s for each read/write/pool wait
        # (httpx has no single "total" timeout).
        timeout=httpx.Timeout(60.0, connect=10.0),
    ) as client:
        response = await client.post(
            "/chat/",
            json={
                "chat_session_id": session_id,
                "user_message": drug.model_dump_json(),
            },
        )

    # Fail loudly on HTTP errors instead of with a confusing KeyError below.
    response.raise_for_status()

    # The agent replies with JSON encoded as a string inside ``message``.
    agent_reply = json.loads(response.json()["message"])
    return [DrugFull.model_validate(item) for item in agent_reply["result"]]
|
|
|
|
|
|
async def download_drug(drug_name: str):
    """Scrape and store prices for ``drug_name`` unless it is already stored.

    If at least one ``Drug`` row with exactly this name exists, nothing is
    done. Otherwise the prices are fetched with ``get_drug`` and persisted
    with ``store_drug``.
    """
    with Session() as session:
        # Only existence matters, so stop at the first matching row instead
        # of loading every one of them.
        already_stored = (
            session.query(Drug).filter(Drug.name == drug_name).first() is not None
        )

    if already_stored:
        return

    # Scraping and the agent call happen outside the DB session so it is not
    # held open during network I/O.
    drugs = await get_drug(drug_name)
    await store_drug(drugs)
|
|
|
|
|
|
async def get_drug(drug_name: str) -> list[DrugFull]:
    """Scrape current prices for ``drug_name`` and convert them to ``DrugFull``.

    This function itself neither reads from nor writes to the database.
    """
    scraped = DrugPriceParser().get_drug_prices(drug_name)
    return await convert_drug_result(scraped)
|
|
|
|
|
|
async def store_drug(drugs: list[DrugFull]):
    """Insert each record as a ``Drug`` row, committing one row at a time.

    Because every row gets its own commit, a row that raises
    ``IntegrityError`` (presumably a duplicate/constraint violation — depends
    on the ``Drug`` table's constraints) is rolled back and skipped without
    affecting the rows committed before or after it.
    """
    with Session() as db:
        for item in drugs:
            try:
                row = Drug(
                    name=item.name,
                    dosage=item.dosage,
                    dosage_unit=item.dosage_unit,
                    unit_price=item.unit_price,
                    description=item.description,
                )
                db.add(row)
                db.commit()
            except IntegrityError:
                # Undo only this item and continue with the next one.
                db.rollback()
|
|
|
|
|
|
async def fetch_drug_with_dosage(drug_name: str, dosage: float) -> DrugFull | None:
    """Return price data for ``drug_name`` at ``dosage``, using the DB as a cache.

    Lookup order:

    1. A stored ``Drug`` row with this name and exactly this dosage.
    2. Any stored ``Drug`` row with this name (its dosage may differ from the
       requested one).
    3. A fresh scrape via ``get_drug``. The results are stored (best effort)
       and the first one whose dosage equals ``dosage`` is returned.

    Args:
        drug_name: Drug name, matched exactly against stored/scraped names.
        dosage: Requested dosage, compared with ``==``.

    Returns:
        The matching ``DrugFull``. Every path either returns a record or
        raises; ``None`` is never actually returned.

    Raises:
        Exception: if the fresh scrape has no entry with the requested dosage.
        Any exception raised by ``get_drug`` (HTTP, JSON, validation errors).
    """
    log = logging.getLogger(__name__)

    try:
        with Session() as session:
            row = (
                session.query(Drug)
                .filter(Drug.name == drug_name, Drug.dosage == dosage)
                .first()
            )
            if row is None:
                # NOTE(review): falls back to any stored dosage of the same
                # drug rather than scraping the exact one — confirm intended.
                row = session.query(Drug).filter(Drug.name == drug_name).first()
            if row is not None:
                # ``row`` is an ORM object, not a dict: Pydantic v2 only reads
                # its attributes with ``from_attributes=True``. Without it this
                # always raised and the cache was never used.
                return DrugFull.model_validate(row, from_attributes=True)
    except Exception:
        # The DB is only a cache here, so a failed lookup falls through to a
        # fresh scrape. It is logged instead of silently swallowed, and
        # (unlike a bare ``except:``) cancellation is not intercepted.
        log.exception("Database lookup failed for drug %s", drug_name)

    drugs = await get_drug(drug_name)
    log.info("Drug %s found %s", drug_name, drugs)

    # Persisting is best effort: the scraped data is still usable even if it
    # could not be saved.
    try:
        await store_drug(drugs)
    except Exception:
        log.exception("Error storing drug %s", drug_name)

    # First scraped entry with exactly the requested dosage.
    match = next((item for item in drugs if item.dosage == dosage), None)
    if match is not None:
        # Already a validated DrugFull; re-validating would be a no-op.
        return match

    raise Exception(f"Drug {drug_name} with dosage {dosage} not found")
|
|
|