| """Gymnasium-compatible wrapper around ``BioExperimentEnvironment``.
|
|
|
| Provides ``BioExperimentGymEnv`` which wraps the OpenEnv environment for
|
| local in-process RL training (no HTTP/WebSocket overhead).
|
|
|
| Observation and action spaces are represented as ``gymnasium.spaces.Dict``
|
| so that standard RL libraries (SB3, CleanRL, etc.) can ingest them.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| from typing import Any, Dict, Optional, Tuple
|
|
|
| import gymnasium as gym
|
| import numpy as np
|
| from gymnasium import spaces
|
|
|
| from models import ActionType, ExperimentAction, ExperimentObservation
|
| from server.hackathon_environment import BioExperimentEnvironment, MAX_STEPS
|
|
|
|
|
# Index-addressable view of the ``ActionType`` enum: position ``i`` in this
# list is the action type selected by integer ``i`` in the Gym action space.
ACTION_TYPE_LIST = [*ActionType]
_N_ACTION_TYPES = len(ACTION_TYPE_LIST)

# Both caps are tied to the simulator's per-episode step limit.
_MAX_OUTPUTS = _MAX_HISTORY = MAX_STEPS

# NOTE(review): ``_MAX_HISTORY`` and ``_VEC_DIM`` are not referenced by the
# code visible in this module; kept because other modules may import them.
_VEC_DIM = 64
|
|
|
|
|
class BioExperimentGymEnv(gym.Env):
    """Gymnasium ``Env`` backed by the in-process simulator.

    Observations are flattened into a dictionary of NumPy scalars/arrays
    suitable for RL policy networks.  Actions are integer-indexed action
    types (an index into ``ACTION_TYPE_LIST``) with a continuous confidence
    scalar.

    Every value placed in an observation is clamped to the bounds declared
    in ``observation_space`` so that the emitted observation is always a
    member of that space.

    For LLM-based agents or planners that prefer structured
    ``ExperimentAction`` objects, use the underlying
    ``BioExperimentEnvironment`` directly instead.
    """

    metadata = {"render_modes": ["human"]}

    # Attribute names read from the simulator's progress object, in the
    # fixed order in which they are laid out in the "progress_flags" vector.
    _PROGRESS_FLAG_NAMES = (
        "samples_collected", "cohort_selected", "cells_cultured",
        "library_prepared", "perturbation_applied", "cells_sequenced",
        "qc_performed", "data_filtered", "data_normalized",
        "batches_integrated", "cells_clustered", "de_performed",
        "trajectories_inferred", "pathways_analyzed",
        "networks_inferred", "markers_discovered",
        "markers_validated", "conclusion_reached",
    )

    # Largest violation count reported; "n_violations" is Discrete(20).
    _MAX_VIOLATIONS = 19

    # Symmetric bound of the "cumulative_reward" Box.
    _REWARD_BOUND = 100.0

    def __init__(self, render_mode: Optional[str] = None):
        """Create the wrapped simulator and declare the Gym spaces.

        Args:
            render_mode: Stored on the instance; ``render`` only produces
                output when this equals ``"human"``.
        """
        super().__init__()
        self._env = BioExperimentEnvironment()
        self.render_mode = render_mode

        def unit_box() -> spaces.Box:
            # Scalar float32 feature bounded to [0, 1].
            return spaces.Box(0.0, 1.0, shape=(), dtype=np.float32)

        self.action_space = spaces.Dict({
            "action_type": spaces.Discrete(_N_ACTION_TYPES),
            "confidence": unit_box(),
        })

        self.observation_space = spaces.Dict({
            "step_index": spaces.Discrete(MAX_STEPS + 1),
            "budget_remaining_frac": unit_box(),
            "time_remaining_frac": unit_box(),
            "progress_flags": spaces.MultiBinary(
                len(self._PROGRESS_FLAG_NAMES)
            ),
            "latest_quality": unit_box(),
            "latest_uncertainty": unit_box(),
            "avg_quality": unit_box(),
            "avg_uncertainty": unit_box(),
            "n_violations": spaces.Discrete(self._MAX_VIOLATIONS + 1),
            "n_outputs": spaces.Discrete(_MAX_OUTPUTS + 1),
            "cumulative_reward": spaces.Box(
                -self._REWARD_BOUND,
                self._REWARD_BOUND,
                shape=(),
                dtype=np.float32,
            ),
        })

        # Most recent structured observation; consumed by ``render``.
        self._last_obs: Optional[ExperimentObservation] = None

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Start a new episode.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` for the first step of the episode.
        """
        super().reset(seed=seed)
        # NOTE(review): the wrapped simulator is reset without arguments, so
        # ``seed`` does not reach it here — confirm whether
        # ``BioExperimentEnvironment.reset`` can accept a seed.
        obs = self._env.reset()
        self._last_obs = obs
        return self._vectorise(obs), self._info(obs)

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Apply one action and advance the simulator.

        Args:
            action: Mapping with an ``"action_type"`` index into
                ``ACTION_TYPE_LIST`` and an optional ``"confidence"``
                (defaults to ``0.5`` when absent).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            IndexError: If ``action_type`` is not in
                ``[0, len(ACTION_TYPE_LIST))``.
        """
        action_idx = int(action["action_type"])
        # Reject negative indices explicitly: plain list indexing would
        # silently wrap them around to a different action type.
        if not 0 <= action_idx < _N_ACTION_TYPES:
            raise IndexError(
                f"action_type index {action_idx} out of range "
                f"[0, {_N_ACTION_TYPES})"
            )
        confidence = float(action.get("confidence", 0.5))

        experiment_action = ExperimentAction(
            action_type=ACTION_TYPE_LIST[action_idx],
            confidence=confidence,
        )
        obs = self._env.step(experiment_action)
        self._last_obs = obs

        terminated = bool(obs.done)
        # Only flag truncation when the step limit is hit without the
        # simulator itself reporting the episode as done.
        truncated = bool(obs.step_index >= MAX_STEPS and not terminated)
        # A missing reward is reported as 0.0 so callers always get a float.
        reward = float(obs.reward) if obs.reward is not None else 0.0

        return (
            self._vectorise(obs),
            reward,
            terminated,
            truncated,
            self._info(obs),
        )

    def render(self) -> Optional[str]:
        """Print and return a short text summary of the latest observation.

        Returns:
            The rendered text, or ``None`` when ``render_mode`` is not
            ``"human"`` or no observation has been recorded yet.
        """
        if self.render_mode != "human" or self._last_obs is None:
            return None
        obs = self._last_obs
        lines = [
            f"Step {obs.step_index}",
            f" Task: {obs.task.problem_statement[:80]}",
            f" Budget: ${obs.resource_usage.budget_remaining:,.0f} remaining",
            f" Time: {obs.resource_usage.time_remaining_days:.0f} days remaining",
        ]
        if obs.latest_output:
            lines.append(f" Latest: {obs.latest_output.summary}")
        if obs.rule_violations:
            lines.append(f" Violations: {obs.rule_violations}")
        text = "\n".join(lines)
        print(text)
        return text

    @staticmethod
    def _bounded(value: Any, low: float, high: float) -> np.float32:
        """Return ``value`` as a float32 scalar clipped to ``[low, high]``."""
        # Convert first so the input is handled exactly as ``np.float32``
        # handles it, then clip the resulting scalar.
        return np.float32(np.clip(np.float32(value), low, high))

    def _vectorise(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Flatten a structured observation into the Gym observation dict.

        Each entry is clamped to the bounds of the matching entry in
        ``observation_space``.

        Args:
            obs: Structured observation returned by the simulator.

        Returns:
            Dict keyed identically to ``observation_space``.
        """
        # Progress lives on the simulator's latent state rather than on the
        # observation, hence the access to the wrapped env's private field.
        latent = self._env._latent
        progress = latent.progress if latent else None

        names = self._PROGRESS_FLAG_NAMES
        if progress is not None:
            # ``bool`` keeps every entry strictly 0/1 for the MultiBinary
            # space; attributes missing from ``progress`` count as unset.
            flags = np.array(
                [bool(getattr(progress, name, False)) for name in names],
                dtype=np.int8,
            )
        else:
            flags = np.zeros(len(names), dtype=np.int8)

        unc = obs.uncertainty_summary
        lo = obs.latest_output
        metadata = obs.metadata

        # ``max(..., 1)`` guards the divisions against a zero limit.
        budget_frac = (
            obs.resource_usage.budget_remaining
            / max(obs.task.budget_limit, 1)
        )
        time_frac = (
            obs.resource_usage.time_remaining_days
            / max(obs.task.time_limit_days, 1)
        )
        cumulative = (
            metadata.get("cumulative_reward", 0.0) if metadata else 0.0
        )

        return {
            "step_index": min(max(int(obs.step_index), 0), MAX_STEPS),
            "budget_remaining_frac": self._bounded(budget_frac, 0.0, 1.0),
            "time_remaining_frac": self._bounded(time_frac, 0.0, 1.0),
            "progress_flags": flags,
            "latest_quality": self._bounded(
                lo.quality_score if lo else 0.0, 0.0, 1.0
            ),
            "latest_uncertainty": self._bounded(
                lo.uncertainty if lo else 0.0, 0.0, 1.0
            ),
            "avg_quality": self._bounded(
                unc.get("avg_quality", 0.0), 0.0, 1.0
            ),
            "avg_uncertainty": self._bounded(
                unc.get("avg_uncertainty", 0.0), 0.0, 1.0
            ),
            "n_violations": min(
                len(obs.rule_violations), self._MAX_VIOLATIONS
            ),
            "n_outputs": min(len(obs.all_outputs), _MAX_OUTPUTS),
            "cumulative_reward": self._bounded(
                cumulative, -self._REWARD_BOUND, self._REWARD_BOUND
            ),
        }

    def _info(self, obs: ExperimentObservation) -> Dict[str, Any]:
        """Build the Gym ``info`` dict for an observation.

        Returns:
            Dict holding the original structured observation and the
            ``episode_id`` from its metadata (``None`` when the metadata is
            empty or missing).
        """
        return {
            "structured_obs": obs,
            "episode_id": obs.metadata.get("episode_id") if obs.metadata else None,
        }
|
|
|