# synth-net/src/services/search_service.py
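# Semantic search service: query texts are embedded by a separate encoder
# service (see `encode`), item queries are polarity-aligned so reverse-keyed
# items do not cancel out the query centroid, and instrument/scale records
# are ranked by cosine similarity against centroids precomputed in
# `app.state.data`.
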
import logging
from typing import Literal

import httpx
import numpy as np
import polars as pl
from fastapi import FastAPI
from sentence_transformers import util
from sklearn.metrics.pairwise import cosine_similarity

from src.config import config
from src.utils.logging import context_logger

logger = logging.getLogger(__name__)


async def encode(texts: list[str], mode: Literal["item", "scale"] = "item") -> np.ndarray:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8001/encode",
            json={"texts": texts, "mode": mode},
            timeout=30.0
        )
        response.raise_for_status()
        result = np.array(response.json()['embeddings'])
    return result
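

# `align_embeddings` reduces the polarity gap between positively and
# negatively keyed items: it builds a polarity axis between the two keying
# centroids and reflects negatively keyed items that sit on the negative side
# of the positive centroid across the hyperplane halfway along that axis,
# i.e. x' = x - 2 * ((x - m) . u) * u with plane midpoint m and unit axis u.
# If either keying group is empty, the axis degenerates, or nothing needs
# reflecting, it returns NaN sentinels and the caller falls back to a plain
# (unaligned) centroid.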
def align_embeddings(item_embeddings, keying):
    item_embeddings_positive = item_embeddings[[x == "positive" for x in keying]]
    item_embeddings_negative = item_embeddings[[x == "negative" for x in keying]]
    if item_embeddings_positive.size == 0 or item_embeddings_negative.size == 0:
        return {
            'item_centroid_positive': np.nan,
            'item_centroid_negative': np.nan,
            'item_embeddings_aligned': np.nan,
            'item_centroid_aligned': np.nan
        }
    item_centroid_positive = item_embeddings_positive.mean(axis=0)
    item_centroid_negative = item_embeddings_negative.mean(axis=0)
    # Items whose similarity to the positive centroid is negative are treated
    # as semantically reversed, regardless of their declared keying.
    cosine_similarities = util.cos_sim(item_embeddings, item_centroid_positive).numpy().squeeze()
    synthetic_is_negative = cosine_similarities < 0
    polarity_axis = item_centroid_positive - item_centroid_negative
    axis_magnitude = np.sqrt(np.sum(polarity_axis**2))
    if not np.isfinite(axis_magnitude) or axis_magnitude <= 0 or not np.any(synthetic_is_negative):
        return {
            'item_centroid_positive': np.nan,
            'item_centroid_negative': np.nan,
            'item_embeddings_aligned': np.nan,
            'item_centroid_aligned': np.nan
        }
    polarity_unit_vector = polarity_axis / axis_magnitude
    reflection_plane_center = (item_centroid_positive + item_centroid_negative) / 2
    signed_distances_to_plane = np.dot(
        item_embeddings - reflection_plane_center,
        polarity_unit_vector
    )
    # Only reflect items that are both declared negative and look negative;
    # everything else gets a zero reflection distance and stays in place.
    items_to_align = np.array([x == "negative" for x in keying]) & synthetic_is_negative
    reflection_distances = np.where(items_to_align, signed_distances_to_plane, 0)
    item_embeddings_aligned = item_embeddings - 2 * np.outer(
        reflection_distances,
        polarity_unit_vector
    )
    item_centroid_aligned = item_embeddings_aligned.mean(axis=0)
    return {
        'item_centroid_positive': item_centroid_positive,
        'item_centroid_negative': item_centroid_negative,
        'item_embeddings_aligned': item_embeddings_aligned,
        'item_centroid_aligned': item_centroid_aligned
    }
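

# Item-level search: embed each query item, map its `reversed` flag onto a
# keying label, align the embeddings, and score the precomputed item
# centroids against the (aligned or fallback) query centroid.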
async def semantic_item_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    query_items = [q['text'] for q in queries]
    query_keys = [q['reversed'] for q in queries]
    with context_logger(f"Sending encoding requests for {len(query_items)} queries"):
        query_embeddings = await encode(texts=query_items, mode="item")
    with context_logger("Aligning item embeddings based on keying"):
        keying = ["negative" if x else "positive" for x in query_keys]
        query_embeddings_aligned = align_embeddings(query_embeddings, keying)
        query_centroid = query_embeddings_aligned['item_centroid_aligned']
        if np.any(np.isnan(query_centroid)):
            logger.info("Query embedding alignment failed, calculating centroid without alignment")
            query_centroid = query_embeddings.mean(axis=0)
    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['item_centroids'],
            Y=query_centroid.reshape(1, -1)
        ).ravel()
    return similarities
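

# Scale-level search: unlike item search there is no keying to align, so the
# scale embedding is compared directly against the scale centroids.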
async def semantic_scale_search(queries: list[dict], app: FastAPI) -> np.ndarray:
    query = [q['text'] for q in queries]
    with context_logger(f"Sending encoding requests for {len(query)} queries."):
        query_embeddings = await encode(texts=query, mode="scale")
        # In "scale" mode the encoder is expected to return a single pooled
        # embedding per request, so the result squeezes down to one vector.
        query_embeddings = query_embeddings.squeeze()
    with context_logger("Calculating cosine similarity"):
        similarities = cosine_similarity(
            X=app.state.data['scale_centroids'],
            Y=query_embeddings.reshape(1, -1)
        ).ravel()
    return similarities
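

# Aggregate item-level similarities into one result row per publication
# (meta_doi): scale names and similarities become list columns, and the
# boolean warn_* flags are folded into a list of warning codes (each
# when/then without an otherwise yields null when the flag is False, and
# drop_nulls discards those).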
async def compute_search_results(similarities: np.ndarray, app: FastAPI) -> pl.DataFrame:
    search_results = (
        app.state.data['meta'].clone()
        .with_columns(
            pl.Series("similarity", similarities).round(3)
        )
        .group_by("meta_doi")
        .agg([
            pl.col('scale_name'),
            pl.col('is_instrument'),
            pl.col([
                'meta_instrument_name',
                'warn_item_count_deviation',
                'warn_scale_count_deviation',
                'warn_item_text_deviation',
                'warn_keying_correction'
            ]).first(),
            pl.col('similarity'),
        ])
        .with_columns(
            pl.concat_list([
                pl.when(pl.col('warn_item_count_deviation'))
                .then(pl.lit('ITEM_COUNT_DEVIATION')),
                pl.when(pl.col('warn_scale_count_deviation'))
                .then(pl.lit('SCALE_COUNT_DEVIATION')),
                pl.when(pl.col('warn_item_text_deviation'))
                .then(pl.lit('ITEM_TEXT_DEVIATION')),
                pl.when(pl.col('warn_keying_correction'))
                .then(pl.lit('KEYING_CORRECTION')),
            ]).list.drop_nulls().alias('warning_codes')
        )
        .with_columns(
            pl.col('warning_codes').list.len().alias('warning_count'),
            max_similarity=pl.col("similarity").list.max(),
            # Maximum of the absolute similarities (not the absolute value of
            # the maximum), so strongly negative matches still surface.
            max_abs_similarity=pl.col("similarity").list.eval(pl.element().abs()).list.max(),
        )
        .drop([
            'warn_item_count_deviation',
            'warn_scale_count_deviation',
            'warn_item_text_deviation',
            'warn_keying_correction'
        ])
    )
    return search_results
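

# Substring filter over instrument and scale names. A null match (e.g. a
# missing instrument name) evaluates to null in the boolean mask, which
# `DataFrame.filter` treats as False.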
async def filter_search(df: pl.DataFrame, filter_string: str) -> pl.DataFrame:
    if filter_string:
        # Lowercase the needle to match the lowercased columns; note that
        # polars `str.contains` interprets the pattern as a regex by default.
        filter_string = filter_string.lower()
        in_instrument_name = df['meta_instrument_name'].str.to_lowercase().str.contains(filter_string)
        in_scale_names = (
            df['scale_name']
            .list.drop_nulls()  # Remove null values from each list
            .list.join(" ")     # Join list elements with a space separator
            .str.to_lowercase()
            .str.contains(filter_string)
        )
        return df.filter(in_instrument_name | in_scale_names)
    return df
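

# Sort and paginate the result frame; `page_index` is zero-based, and slicing
# past the end of the frame simply yields an empty page.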
async def refine_search(
    df: pl.DataFrame,
    sort_col: str,
    sort_descending: bool,
    page_index: int,
    page_size: int
) -> pl.DataFrame:
    sorted_result = df.sort(by=sort_col, descending=sort_descending)
    start_index = page_index * page_size
    end_index = start_index + page_size
    page_results = sorted_result[start_index:end_index]
    return page_results
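

# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, assuming no encoder service is running,
# so only the DataFrame stages are exercised. The `meta` schema below is
# hypothetical, inferred from the columns referenced above; `fake_app` is a
# duck-typed stand-in for the FastAPI app, not the real application state.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio
    from types import SimpleNamespace

    meta = pl.DataFrame({
        "meta_doi": ["10.0/a", "10.0/a", "10.0/b"],
        "scale_name": ["Anxiety", "Stress", "Focus"],
        "is_instrument": [True, False, True],
        "meta_instrument_name": ["DASS", "DASS", "FLOW"],
        "warn_item_count_deviation": [False, False, True],
        "warn_scale_count_deviation": [False, False, False],
        "warn_item_text_deviation": [False, True, False],
        "warn_keying_correction": [False, False, False],
    })
    fake_app = SimpleNamespace(state=SimpleNamespace(data={"meta": meta}))

    # One similarity per row of `meta`, as compute_search_results expects.
    similarities = np.array([0.81, 0.42, -0.17])
    results = asyncio.run(compute_search_results(similarities, fake_app))
    results = asyncio.run(filter_search(results, "dass"))
    page = asyncio.run(refine_search(results, "max_similarity", True, 0, 10))
    print(page)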