GamedAreS's picture
Added necessary files
82168a6
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Sepsis ICU RL Environment.
Defines the action and observation spaces for a 72-hour ICU sepsis simulation.
Reward shaping follows Surviving Sepsis Campaign (SSC) guidelines.
"""
from typing import List, Optional
from openenv.core.env_server.types import Action, Observation
from pydantic import Field
class SepsisAction(Action):
    """
    One hour's treatment decision for the simulated ICU patient.

    Every field is a required integer code in the inclusive range 0..3.
    The ``ge``/``le`` constraints make pydantic reject anything outside that
    range when the action is validated.

    Code meanings per field::

        field        0      1                       2              3
        -----------  -----  ----------------------  -------------  -------------------
        fluid_bolus  none   250 mL                  500 mL         1000 mL
        vasopressor  none   low-NE                  high-NE        vasopressin add-on
        antibiotic   none   broad-spectrum start    de-escalate    escalate
        ventilation  none   O2 supplementation      NIV            intubate
    """

    # Size of the IV fluid bolus given this hour.
    fluid_bolus: int = Field(
        ...,
        ge=0,
        le=3,
        description="IV fluid bolus level (0-3)",
    )
    # Vasopressor intensity; level 3 adds vasopressin on top of NE.
    vasopressor: int = Field(
        ...,
        ge=0,
        le=3,
        description="Vasopressor dose level (0-3)",
    )
    # Antibiotic management decision (start / narrow / broaden).
    antibiotic: int = Field(
        ...,
        ge=0,
        le=3,
        description="Antibiotic decision (0-3)",
    )
    # Respiratory support, escalating from none up to intubation.
    ventilation: int = Field(
        ...,
        ge=0,
        le=3,
        description="Ventilation support level (0-3)",
    )
class SepsisObservation(Observation):
    """
    Clinical snapshot emitted by the environment after every hourly step.

    ``state_data`` is documented as 15 normalised features, each in [0, 1],
    in this order::

        idx  feature            idx  feature        idx  feature
        ---  -----------------  ---  -------------  ---  -----------------
         0   heart_rate          5   temperature    10   platelets
         1   systolic_bp         6   lactate        11   fio2
         2   diastolic_bp        7   wbc            12   sofa_score
         3   respiratory_rate    8   creatinine     13   shock_index
         4   spo2                9   bilirubin      14   antibiotic_active

    The remaining fields carry raw clinical scalars, episode bookkeeping,
    running treatment totals and boolean safety flags. Ranges quoted in the
    field descriptions are informational only. No range constraints
    (``ge``/``le``/length) are declared on the fields below, and every field
    has a default.
    """

    # --- Core observation vector ------------------------------------------
    # Defaults to an empty list; the 15-element layout is given above.
    state_data: List[float] = Field(
        default_factory=list,
        description="15-dim normalised vector of vitals and labs",
    )

    # --- Scalar clinical indicators ---------------------------------------
    # Raw (un-normalised) SOFA score; the 0-24 range is not validated here.
    sofa_score: float = Field(0.0, description="SOFA score 0-24 (raw)")
    in_septic_shock: bool = Field(
        False,
        description="True if MAP < 65 and vasopressor required",
    )

    # --- Episode metadata -------------------------------------------------
    # Hour index within the episode.
    step_count: int = Field(0, description="Current hour within the episode")
    # Optional free-text description; None when no summary is supplied.
    clinical_summary: Optional[str] = Field(
        None,
        description="Human-readable summary of current clinical state",
    )

    # --- Cumulative treatment accumulators --------------------------------
    # Running totals over the episode; units are stated in each description.
    cumulative_fluid_mL: float = Field(
        0.0,
        description="Total IV fluid given this episode (mL)",
    )
    cumulative_vasopressor_units: float = Field(
        0.0,
        description="Cumulative vasopressor dose (µg/kg/min · hours)",
    )

    # --- Safety flags -----------------------------------------------------
    # This model only stores the flags; it does not compute them. The
    # descriptions state the intended trigger criteria.
    fluid_overload_risk: bool = Field(
        False,
        description="True if cumulative fluid > 6 L and RR elevated",
    )
    aki_risk: bool = Field(
        False,
        description="True if acute kidney injury criteria met",
    )
    ards_risk: bool = Field(
        False,
        description="True if ARDS criteria met alongside fluid overload",
    )