Spaces:
Running
Running
Fix ty warnings in lynxkite-graph-analytics.
Browse files
- lynxkite-core/src/lynxkite/core/workspace.py +1 -1
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/__init__.py +1 -1
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py +2 -2
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py +9 -4
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/ml_ops.py +1 -1
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py +36 -34
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_core.py +14 -10
- lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_ops.py +5 -5
- lynxkite-graph-analytics/tests/test_core.py +3 -3
- lynxkite-graph-analytics/tests/test_lynxkite_ops.py +10 -10
- lynxkite-graph-analytics/tests/test_pytorch_model_ops.py +5 -1
lynxkite-core/src/lynxkite/core/workspace.py
CHANGED
|
@@ -230,7 +230,7 @@ class Workspace(BaseConfig):
|
|
| 230 |
)
|
| 231 |
elif "title" in kwargs:
|
| 232 |
kwargs["data"] = WorkspaceNodeData(
|
| 233 |
-
title=kwargs["title"], op_id=kwargs["title"], params={}
|
| 234 |
)
|
| 235 |
kwargs.setdefault("type", "basic")
|
| 236 |
kwargs.setdefault("id", f"{kwargs['data'].title} {random_string}")
|
|
|
|
| 230 |
)
|
| 231 |
elif "title" in kwargs:
|
| 232 |
kwargs["data"] = WorkspaceNodeData(
|
| 233 |
+
title=kwargs["title"], op_id=kwargs["title"], params=kwargs.get("params", {})
|
| 234 |
)
|
| 235 |
kwargs.setdefault("type", "basic")
|
| 236 |
kwargs.setdefault("id", f"{kwargs['data'].title} {random_string}")
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/__init__.py
CHANGED
|
@@ -4,7 +4,7 @@ import os
|
|
| 4 |
import pandas as pd
|
| 5 |
|
| 6 |
if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
|
| 7 |
-
import cudf.pandas
|
| 8 |
|
| 9 |
cudf.pandas.install()
|
| 10 |
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
|
| 6 |
if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
|
| 7 |
+
import cudf.pandas # ty: ignore[unresolved-import]
|
| 8 |
|
| 9 |
cudf.pandas.install()
|
| 10 |
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py
CHANGED
|
@@ -174,13 +174,13 @@ def disambiguate_edges(ws: workspace.Workspace):
|
|
| 174 |
for edge in reversed(ws.edges):
|
| 175 |
dst_node = nodes[edge.target]
|
| 176 |
op = catalog.get(dst_node.data.op_id)
|
| 177 |
-
if op.get_input(edge.targetHandle).type == list[Bundle]:
|
| 178 |
# Takes multiple bundles as an input. No need to disambiguate.
|
| 179 |
continue
|
| 180 |
if (edge.target, edge.targetHandle) in seen:
|
| 181 |
i = ws.edges.index(edge)
|
| 182 |
del ws.edges[i]
|
| 183 |
-
if
|
| 184 |
del ws._crdt["edges"][i]
|
| 185 |
seen.add((edge.target, edge.targetHandle))
|
| 186 |
|
|
|
|
| 174 |
for edge in reversed(ws.edges):
|
| 175 |
dst_node = nodes[edge.target]
|
| 176 |
op = catalog.get(dst_node.data.op_id)
|
| 177 |
+
if not op or op.get_input(edge.targetHandle).type == list[Bundle]:
|
| 178 |
# Takes multiple bundles as an input. No need to disambiguate.
|
| 179 |
continue
|
| 180 |
if (edge.target, edge.targetHandle) in seen:
|
| 181 |
i = ws.edges.index(edge)
|
| 182 |
del ws.edges[i]
|
| 183 |
+
if ws._crdt:
|
| 184 |
del ws._crdt["edges"][i]
|
| 185 |
seen.add((edge.target, edge.targetHandle))
|
| 186 |
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py
CHANGED
|
@@ -8,7 +8,8 @@ from collections import deque
|
|
| 8 |
|
| 9 |
from . import core
|
| 10 |
import grandcypher
|
| 11 |
-
import matplotlib
|
|
|
|
| 12 |
import networkx as nx
|
| 13 |
import pandas as pd
|
| 14 |
import polars as pl
|
|
@@ -205,6 +206,7 @@ def _map_color(value):
|
|
| 205 |
else:
|
| 206 |
cmap = matplotlib.cm.get_cmap("Paired")
|
| 207 |
categories = pd.Index(value.unique())
|
|
|
|
| 208 |
colors = cmap.colors[: len(categories)]
|
| 209 |
return [
|
| 210 |
"#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
|
|
@@ -312,7 +314,7 @@ def view_tables(bundle: core.Bundle, *, _tables_open: str = "", limit: int = 100
|
|
| 312 |
view="graph_creation_view",
|
| 313 |
outputs=["output"],
|
| 314 |
)
|
| 315 |
-
def organize(bundles: list[core.Bundle], *, relations: str =
|
| 316 |
"""Merge multiple inputs and construct graphs from the tables.
|
| 317 |
|
| 318 |
To create a graph, import tables for edges and nodes, and combine them in this operation.
|
|
@@ -322,6 +324,9 @@ def organize(bundles: list[core.Bundle], *, relations: str = None) -> core.Bundl
|
|
| 322 |
bundle.dfs.update(b.dfs)
|
| 323 |
bundle.relations.extend(b.relations)
|
| 324 |
bundle.other.update(b.other)
|
| 325 |
-
if
|
| 326 |
-
bundle.relations = [
|
|
|
|
|
|
|
|
|
|
| 327 |
return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
|
|
|
|
| 8 |
|
| 9 |
from . import core
|
| 10 |
import grandcypher
|
| 11 |
+
import matplotlib.cm
|
| 12 |
+
import matplotlib.colors
|
| 13 |
import networkx as nx
|
| 14 |
import pandas as pd
|
| 15 |
import polars as pl
|
|
|
|
| 206 |
else:
|
| 207 |
cmap = matplotlib.cm.get_cmap("Paired")
|
| 208 |
categories = pd.Index(value.unique())
|
| 209 |
+
assert isinstance(cmap, matplotlib.colors.ListedColormap)
|
| 210 |
colors = cmap.colors[: len(categories)]
|
| 211 |
return [
|
| 212 |
"#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
|
|
|
|
| 314 |
view="graph_creation_view",
|
| 315 |
outputs=["output"],
|
| 316 |
)
|
| 317 |
+
def organize(bundles: list[core.Bundle], *, relations: str = ""):
|
| 318 |
"""Merge multiple inputs and construct graphs from the tables.
|
| 319 |
|
| 320 |
To create a graph, import tables for edges and nodes, and combine them in this operation.
|
|
|
|
| 324 |
bundle.dfs.update(b.dfs)
|
| 325 |
bundle.relations.extend(b.relations)
|
| 326 |
bundle.other.update(b.other)
|
| 327 |
+
if relations.strip():
|
| 328 |
+
bundle.relations = [
|
| 329 |
+
core.RelationDefinition(**r) # ty: ignore[missing-argument]
|
| 330 |
+
for r in json.loads(relations).values()
|
| 331 |
+
]
|
| 332 |
return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/ml_ops.py
CHANGED
|
@@ -216,7 +216,7 @@ def view_vectors(
|
|
| 216 |
metric: UMAPMetric = UMAPMetric.euclidean,
|
| 217 |
):
|
| 218 |
try:
|
| 219 |
-
from cuml.manifold.umap import UMAP
|
| 220 |
except ImportError:
|
| 221 |
from umap import UMAP
|
| 222 |
vec = np.stack(bundle.dfs[table_name][vector_column].to_numpy())
|
|
|
|
| 216 |
metric: UMAPMetric = UMAPMetric.euclidean,
|
| 217 |
):
|
| 218 |
try:
|
| 219 |
+
from cuml.manifold.umap import UMAP # ty: ignore[unresolved-import]
|
| 220 |
except ImportError:
|
| 221 |
from umap import UMAP
|
| 222 |
vec = np.stack(bundle.dfs[table_name][vector_column].to_numpy())
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
"""Automatically wraps all NetworkX functions as LynxKite operations."""
|
| 2 |
|
| 3 |
-
import collections
|
| 4 |
-
import types
|
| 5 |
from lynxkite.core import ops
|
|
|
|
|
|
|
| 6 |
import functools
|
| 7 |
import inspect
|
| 8 |
import networkx as nx
|
| 9 |
-
import re
|
| 10 |
-
|
| 11 |
import pandas as pd
|
|
|
|
|
|
|
| 12 |
|
| 13 |
ENV = "LynxKite Graph Analytics"
|
| 14 |
|
|
@@ -17,20 +17,22 @@ class UnsupportedParameterType(Exception):
|
|
| 17 |
pass
|
| 18 |
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
-
def doc_to_type(name: str, type_hint: str) -> type:
|
| 25 |
type_hint = type_hint.lower()
|
| 26 |
type_hint = re.sub("[(][^)]+[)]", "", type_hint).strip().strip(".")
|
| 27 |
if " " in name or "http" in name:
|
| 28 |
-
return
|
| 29 |
if type_hint.endswith(", optional"):
|
| 30 |
w = doc_to_type(name, type_hint.removesuffix(", optional").strip())
|
| 31 |
-
if w is
|
| 32 |
-
return
|
| 33 |
-
|
|
|
|
| 34 |
if type_hint in [
|
| 35 |
"a digraph or multidigraph",
|
| 36 |
"a graph g",
|
|
@@ -54,15 +56,15 @@ def doc_to_type(name: str, type_hint: str) -> type:
|
|
| 54 |
]:
|
| 55 |
return nx.DiGraph
|
| 56 |
elif type_hint == "node":
|
| 57 |
-
return
|
| 58 |
elif type_hint == '"node (optional)"':
|
| 59 |
-
return
|
| 60 |
elif type_hint == '"edge"':
|
| 61 |
-
return
|
| 62 |
elif type_hint == '"edge (optional)"':
|
| 63 |
-
return
|
| 64 |
elif type_hint in ["class", "data type"]:
|
| 65 |
-
return
|
| 66 |
elif type_hint in ["string", "str", "node label"]:
|
| 67 |
return str
|
| 68 |
elif type_hint in ["string or none", "none or string", "string, or none"]:
|
|
@@ -72,27 +74,27 @@ def doc_to_type(name: str, type_hint: str) -> type:
|
|
| 72 |
elif type_hint in ["bool", "boolean"]:
|
| 73 |
return bool
|
| 74 |
elif type_hint == "tuple":
|
| 75 |
-
return
|
| 76 |
elif type_hint == "set":
|
| 77 |
-
return
|
| 78 |
elif type_hint == "list of floats":
|
| 79 |
-
return
|
| 80 |
elif type_hint == "list of floats or float":
|
| 81 |
return float
|
| 82 |
elif type_hint in ["dict", "dictionary"]:
|
| 83 |
-
return
|
| 84 |
elif type_hint == "scalar or dictionary":
|
| 85 |
return float
|
| 86 |
elif type_hint == "none or dict":
|
| 87 |
-
return
|
| 88 |
elif type_hint in ["function", "callable"]:
|
| 89 |
-
return
|
| 90 |
elif type_hint in [
|
| 91 |
"collection",
|
| 92 |
"container of nodes",
|
| 93 |
"list of nodes",
|
| 94 |
]:
|
| 95 |
-
return
|
| 96 |
elif type_hint in [
|
| 97 |
"container",
|
| 98 |
"generator",
|
|
@@ -104,13 +106,13 @@ def doc_to_type(name: str, type_hint: str) -> type:
|
|
| 104 |
"list or tuple",
|
| 105 |
"list",
|
| 106 |
]:
|
| 107 |
-
return
|
| 108 |
elif type_hint == "generator of sets":
|
| 109 |
-
return
|
| 110 |
elif type_hint == "dict or a set of 2 or 3 tuples":
|
| 111 |
-
return
|
| 112 |
elif type_hint == "set of 2 or 3 tuples":
|
| 113 |
-
return
|
| 114 |
elif type_hint == "none, string or function":
|
| 115 |
return str | None
|
| 116 |
elif type_hint == "string or function" and name == "weight":
|
|
@@ -135,8 +137,8 @@ def doc_to_type(name: str, type_hint: str) -> type:
|
|
| 135 |
elif name == "weight":
|
| 136 |
return str
|
| 137 |
elif type_hint == "object":
|
| 138 |
-
return
|
| 139 |
-
return
|
| 140 |
|
| 141 |
|
| 142 |
def types_from_doc(doc: str) -> dict[str, type]:
|
|
@@ -186,13 +188,13 @@ def wrapped(name: str, func):
|
|
| 186 |
return wrapper
|
| 187 |
|
| 188 |
|
| 189 |
-
def _get_params(func) ->
|
| 190 |
sig = inspect.signature(func)
|
| 191 |
# Get types from docstring.
|
| 192 |
types = types_from_doc(func.__doc__)
|
| 193 |
# Always hide these.
|
| 194 |
for k in ["backend", "backend_kwargs", "create_using"]:
|
| 195 |
-
types[k] =
|
| 196 |
# Add in types based on signature.
|
| 197 |
for k, param in sig.parameters.items():
|
| 198 |
if k in types:
|
|
@@ -203,10 +205,10 @@ def _get_params(func) -> dict | None:
|
|
| 203 |
types[k] = int
|
| 204 |
params = []
|
| 205 |
for name, param in sig.parameters.items():
|
| 206 |
-
_type = types.get(name,
|
| 207 |
-
if _type is
|
| 208 |
raise UnsupportedParameterType(name)
|
| 209 |
-
if _type is
|
| 210 |
continue
|
| 211 |
p = ops.Parameter.basic(
|
| 212 |
name=name,
|
|
|
|
| 1 |
"""Automatically wraps all NetworkX functions as LynxKite operations."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from lynxkite.core import ops
|
| 4 |
+
import collections.abc
|
| 5 |
+
import enum
|
| 6 |
import functools
|
| 7 |
import inspect
|
| 8 |
import networkx as nx
|
|
|
|
|
|
|
| 9 |
import pandas as pd
|
| 10 |
+
import re
|
| 11 |
+
import types
|
| 12 |
|
| 13 |
ENV = "LynxKite Graph Analytics"
|
| 14 |
|
|
|
|
| 17 |
pass
|
| 18 |
|
| 19 |
|
| 20 |
+
class Failure(str, enum.Enum):
|
| 21 |
+
UNSUPPORTED = "unsupported" # This parameter will be hidden.
|
| 22 |
+
SKIP = "skip" # We have to skip the whole function.
|
| 23 |
|
| 24 |
|
| 25 |
+
def doc_to_type(name: str, type_hint: str) -> type | types.UnionType | Failure:
|
| 26 |
type_hint = type_hint.lower()
|
| 27 |
type_hint = re.sub("[(][^)]+[)]", "", type_hint).strip().strip(".")
|
| 28 |
if " " in name or "http" in name:
|
| 29 |
+
return Failure.UNSUPPORTED # Not a parameter type.
|
| 30 |
if type_hint.endswith(", optional"):
|
| 31 |
w = doc_to_type(name, type_hint.removesuffix(", optional").strip())
|
| 32 |
+
if w is Failure.UNSUPPORTED or w is Failure.SKIP:
|
| 33 |
+
return Failure.SKIP
|
| 34 |
+
assert not isinstance(w, Failure)
|
| 35 |
+
return w | None
|
| 36 |
if type_hint in [
|
| 37 |
"a digraph or multidigraph",
|
| 38 |
"a graph g",
|
|
|
|
| 56 |
]:
|
| 57 |
return nx.DiGraph
|
| 58 |
elif type_hint == "node":
|
| 59 |
+
return Failure.UNSUPPORTED
|
| 60 |
elif type_hint == '"node (optional)"':
|
| 61 |
+
return Failure.SKIP
|
| 62 |
elif type_hint == '"edge"':
|
| 63 |
+
return Failure.UNSUPPORTED
|
| 64 |
elif type_hint == '"edge (optional)"':
|
| 65 |
+
return Failure.SKIP
|
| 66 |
elif type_hint in ["class", "data type"]:
|
| 67 |
+
return Failure.UNSUPPORTED
|
| 68 |
elif type_hint in ["string", "str", "node label"]:
|
| 69 |
return str
|
| 70 |
elif type_hint in ["string or none", "none or string", "string, or none"]:
|
|
|
|
| 74 |
elif type_hint in ["bool", "boolean"]:
|
| 75 |
return bool
|
| 76 |
elif type_hint == "tuple":
|
| 77 |
+
return Failure.UNSUPPORTED
|
| 78 |
elif type_hint == "set":
|
| 79 |
+
return Failure.UNSUPPORTED
|
| 80 |
elif type_hint == "list of floats":
|
| 81 |
+
return Failure.UNSUPPORTED
|
| 82 |
elif type_hint == "list of floats or float":
|
| 83 |
return float
|
| 84 |
elif type_hint in ["dict", "dictionary"]:
|
| 85 |
+
return Failure.UNSUPPORTED
|
| 86 |
elif type_hint == "scalar or dictionary":
|
| 87 |
return float
|
| 88 |
elif type_hint == "none or dict":
|
| 89 |
+
return Failure.SKIP
|
| 90 |
elif type_hint in ["function", "callable"]:
|
| 91 |
+
return Failure.UNSUPPORTED
|
| 92 |
elif type_hint in [
|
| 93 |
"collection",
|
| 94 |
"container of nodes",
|
| 95 |
"list of nodes",
|
| 96 |
]:
|
| 97 |
+
return Failure.UNSUPPORTED
|
| 98 |
elif type_hint in [
|
| 99 |
"container",
|
| 100 |
"generator",
|
|
|
|
| 106 |
"list or tuple",
|
| 107 |
"list",
|
| 108 |
]:
|
| 109 |
+
return Failure.UNSUPPORTED
|
| 110 |
elif type_hint == "generator of sets":
|
| 111 |
+
return Failure.UNSUPPORTED
|
| 112 |
elif type_hint == "dict or a set of 2 or 3 tuples":
|
| 113 |
+
return Failure.UNSUPPORTED
|
| 114 |
elif type_hint == "set of 2 or 3 tuples":
|
| 115 |
+
return Failure.UNSUPPORTED
|
| 116 |
elif type_hint == "none, string or function":
|
| 117 |
return str | None
|
| 118 |
elif type_hint == "string or function" and name == "weight":
|
|
|
|
| 137 |
elif name == "weight":
|
| 138 |
return str
|
| 139 |
elif type_hint == "object":
|
| 140 |
+
return Failure.UNSUPPORTED
|
| 141 |
+
return Failure.SKIP
|
| 142 |
|
| 143 |
|
| 144 |
def types_from_doc(doc: str) -> dict[str, type]:
|
|
|
|
| 188 |
return wrapper
|
| 189 |
|
| 190 |
|
| 191 |
+
def _get_params(func) -> list[ops.Parameter | ops.ParameterGroup]:
|
| 192 |
sig = inspect.signature(func)
|
| 193 |
# Get types from docstring.
|
| 194 |
types = types_from_doc(func.__doc__)
|
| 195 |
# Always hide these.
|
| 196 |
for k in ["backend", "backend_kwargs", "create_using"]:
|
| 197 |
+
types[k] = Failure.SKIP
|
| 198 |
# Add in types based on signature.
|
| 199 |
for k, param in sig.parameters.items():
|
| 200 |
if k in types:
|
|
|
|
| 205 |
types[k] = int
|
| 206 |
params = []
|
| 207 |
for name, param in sig.parameters.items():
|
| 208 |
+
_type = types.get(name, Failure.UNSUPPORTED)
|
| 209 |
+
if _type is Failure.UNSUPPORTED:
|
| 210 |
raise UnsupportedParameterType(name)
|
| 211 |
+
if _type is Failure.SKIP or _type in [nx.Graph, nx.DiGraph]:
|
| 212 |
continue
|
| 213 |
p = ops.Parameter.basic(
|
| 214 |
name=name,
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_core.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
import copy
|
| 4 |
import graphlib
|
| 5 |
import io
|
| 6 |
-
import
|
| 7 |
import pydantic
|
| 8 |
from lynxkite.core import ops, workspace
|
| 9 |
import torch
|
|
@@ -36,8 +36,12 @@ def reg(name, inputs=[], outputs=None, params=[], **kwargs):
|
|
| 36 |
return ops.register_passive_op(
|
| 37 |
ENV,
|
| 38 |
name,
|
| 39 |
-
inputs=[
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
params=params,
|
| 42 |
**kwargs,
|
| 43 |
)
|
|
@@ -75,16 +79,16 @@ class ModelMapping(pydantic.BaseModel):
|
|
| 75 |
map: dict[str, ColumnSpec]
|
| 76 |
|
| 77 |
|
| 78 |
-
def _torch_save(data) ->
|
| 79 |
"""Saves PyTorch data (modules, tensors) as a string."""
|
| 80 |
buffer = io.BytesIO()
|
| 81 |
torch.save(data, buffer)
|
| 82 |
return buffer.getvalue()
|
| 83 |
|
| 84 |
|
| 85 |
-
def _torch_load(data:
|
| 86 |
"""Loads PyTorch data (modules, tensors) from a string."""
|
| 87 |
-
buffer = io.BytesIO(data)
|
| 88 |
return torch.load(buffer)
|
| 89 |
|
| 90 |
|
|
@@ -98,7 +102,7 @@ class ModelConfig:
|
|
| 98 |
input_output_names: list[str]
|
| 99 |
loss: torch.nn.Module
|
| 100 |
source_workspace_json: str
|
| 101 |
-
optimizer_parameters: dict[str,
|
| 102 |
optimizer: torch.optim.Optimizer | None = None
|
| 103 |
source_workspace: str | None = None
|
| 104 |
trained: bool = False
|
|
@@ -116,7 +120,7 @@ class ModelConfig:
|
|
| 116 |
values = {k: v for k, v in zip(self.model_outputs, output)}
|
| 117 |
return values
|
| 118 |
|
| 119 |
-
def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str,
|
| 120 |
"""Inference on a single batch."""
|
| 121 |
self.model.eval()
|
| 122 |
return self._forward(inputs)
|
|
@@ -422,7 +426,7 @@ class ModelBuilder:
|
|
| 422 |
op = self.catalog["Optimizer"]
|
| 423 |
cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
|
| 424 |
cfg["source_workspace_json"] = self.ws.model_dump_json()
|
| 425 |
-
return ModelConfig(**cfg)
|
| 426 |
|
| 427 |
def get_names(self, *ids: list[str]) -> dict[str, str]:
|
| 428 |
"""Returns a mapping from internal IDs to human-readable names."""
|
|
@@ -448,7 +452,7 @@ class ModelBuilder:
|
|
| 448 |
|
| 449 |
|
| 450 |
def to_batch_tensors(
|
| 451 |
-
b: core.Bundle, batch_size: int, batch_index: int, m: ModelMapping
|
| 452 |
) -> dict[str, torch.Tensor]:
|
| 453 |
"""Extracts tensors from a bundle for a specific batch using a model mapping."""
|
| 454 |
tensors = {}
|
|
|
|
| 3 |
import copy
|
| 4 |
import graphlib
|
| 5 |
import io
|
| 6 |
+
import typing
|
| 7 |
import pydantic
|
| 8 |
from lynxkite.core import ops, workspace
|
| 9 |
import torch
|
|
|
|
| 36 |
return ops.register_passive_op(
|
| 37 |
ENV,
|
| 38 |
name,
|
| 39 |
+
inputs=[
|
| 40 |
+
ops.Input(name=name, position=ops.Position.BOTTOM, type="tensor") for name in inputs
|
| 41 |
+
],
|
| 42 |
+
outputs=[
|
| 43 |
+
ops.Output(name=name, position=ops.Position.TOP, type="tensor") for name in outputs
|
| 44 |
+
],
|
| 45 |
params=params,
|
| 46 |
**kwargs,
|
| 47 |
)
|
|
|
|
| 79 |
map: dict[str, ColumnSpec]
|
| 80 |
|
| 81 |
|
| 82 |
+
def _torch_save(data) -> bytes:
|
| 83 |
"""Saves PyTorch data (modules, tensors) as a string."""
|
| 84 |
buffer = io.BytesIO()
|
| 85 |
torch.save(data, buffer)
|
| 86 |
return buffer.getvalue()
|
| 87 |
|
| 88 |
|
| 89 |
+
def _torch_load(data: bytes) -> typing.Any:
|
| 90 |
"""Loads PyTorch data (modules, tensors) from a string."""
|
| 91 |
+
buffer = io.BytesIO(data)
|
| 92 |
return torch.load(buffer)
|
| 93 |
|
| 94 |
|
|
|
|
| 102 |
input_output_names: list[str]
|
| 103 |
loss: torch.nn.Module
|
| 104 |
source_workspace_json: str
|
| 105 |
+
optimizer_parameters: dict[str, typing.Any]
|
| 106 |
optimizer: torch.optim.Optimizer | None = None
|
| 107 |
source_workspace: str | None = None
|
| 108 |
trained: bool = False
|
|
|
|
| 120 |
values = {k: v for k, v in zip(self.model_outputs, output)}
|
| 121 |
return values
|
| 122 |
|
| 123 |
+
def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 124 |
"""Inference on a single batch."""
|
| 125 |
self.model.eval()
|
| 126 |
return self._forward(inputs)
|
|
|
|
| 426 |
op = self.catalog["Optimizer"]
|
| 427 |
cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
|
| 428 |
cfg["source_workspace_json"] = self.ws.model_dump_json()
|
| 429 |
+
return ModelConfig(**cfg) # ty: ignore[missing-argument]
|
| 430 |
|
| 431 |
def get_names(self, *ids: list[str]) -> dict[str, str]:
|
| 432 |
"""Returns a mapping from internal IDs to human-readable names."""
|
|
|
|
| 452 |
|
| 453 |
|
| 454 |
def to_batch_tensors(
|
| 455 |
+
b: core.Bundle, batch_size: int, batch_index: int, m: ModelMapping
|
| 456 |
) -> dict[str, torch.Tensor]:
|
| 457 |
"""Extracts tensors from a bundle for a specific batch using a model mapping."""
|
| 458 |
tensors = {}
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_ops.py
CHANGED
|
@@ -117,7 +117,7 @@ def neural_ode_mlp(
|
|
| 117 |
|
| 118 |
@op("Attention", outputs=["outputs", "weights"])
|
| 119 |
def attention(query, key, value, *, embed_dim=1024, num_heads=1, dropout=0.0):
|
| 120 |
-
return torch.nn.
|
| 121 |
|
| 122 |
|
| 123 |
@op("LayerNorm", outputs=["outputs", "weights"])
|
|
@@ -250,8 +250,8 @@ reg(
|
|
| 250 |
ops.register_passive_op(
|
| 251 |
ENV,
|
| 252 |
"Repeat",
|
| 253 |
-
inputs=[ops.Input(name="input", position=
|
| 254 |
-
outputs=[ops.Output(name="output", position=
|
| 255 |
params=[
|
| 256 |
ops.Parameter.basic("times", 1, int),
|
| 257 |
ops.Parameter.basic("same_weights", False, bool),
|
|
@@ -261,8 +261,8 @@ ops.register_passive_op(
|
|
| 261 |
ops.register_passive_op(
|
| 262 |
ENV,
|
| 263 |
"Recurrent chain",
|
| 264 |
-
inputs=[ops.Input(name="input", position=
|
| 265 |
-
outputs=[ops.Output(name="output", position=
|
| 266 |
params=[],
|
| 267 |
)
|
| 268 |
|
|
|
|
| 117 |
|
| 118 |
@op("Attention", outputs=["outputs", "weights"])
|
| 119 |
def attention(query, key, value, *, embed_dim=1024, num_heads=1, dropout=0.0):
|
| 120 |
+
return torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
|
| 121 |
|
| 122 |
|
| 123 |
@op("LayerNorm", outputs=["outputs", "weights"])
|
|
|
|
| 250 |
ops.register_passive_op(
|
| 251 |
ENV,
|
| 252 |
"Repeat",
|
| 253 |
+
inputs=[ops.Input(name="input", position=ops.Position.TOP, type="tensor")],
|
| 254 |
+
outputs=[ops.Output(name="output", position=ops.Position.BOTTOM, type="tensor")],
|
| 255 |
params=[
|
| 256 |
ops.Parameter.basic("times", 1, int),
|
| 257 |
ops.Parameter.basic("same_weights", False, bool),
|
|
|
|
| 261 |
ops.register_passive_op(
|
| 262 |
ENV,
|
| 263 |
"Recurrent chain",
|
| 264 |
+
inputs=[ops.Input(name="input", position=ops.Position.TOP, type="tensor")],
|
| 265 |
+
outputs=[ops.Output(name="output", position=ops.Position.BOTTOM, type="tensor")],
|
| 266 |
params=[],
|
| 267 |
)
|
| 268 |
|
lynxkite-graph-analytics/tests/test_core.py
CHANGED
|
@@ -20,19 +20,19 @@ async def test_multi_input_box():
|
|
| 20 |
ws.add_node(
|
| 21 |
id="1",
|
| 22 |
type="node_type",
|
| 23 |
-
|
| 24 |
position=workspace.Position(x=0, y=0),
|
| 25 |
)
|
| 26 |
ws.add_node(
|
| 27 |
id="2",
|
| 28 |
type="node_type",
|
| 29 |
-
|
| 30 |
position=workspace.Position(x=0, y=0),
|
| 31 |
)
|
| 32 |
ws.add_node(
|
| 33 |
id="3",
|
| 34 |
type="node_type",
|
| 35 |
-
|
| 36 |
position=workspace.Position(x=0, y=0),
|
| 37 |
)
|
| 38 |
ws.edges = [
|
|
|
|
| 20 |
ws.add_node(
|
| 21 |
id="1",
|
| 22 |
type="node_type",
|
| 23 |
+
title="Create Bundle",
|
| 24 |
position=workspace.Position(x=0, y=0),
|
| 25 |
)
|
| 26 |
ws.add_node(
|
| 27 |
id="2",
|
| 28 |
type="node_type",
|
| 29 |
+
title="Create Bundle",
|
| 30 |
position=workspace.Position(x=0, y=0),
|
| 31 |
)
|
| 32 |
ws.add_node(
|
| 33 |
id="3",
|
| 34 |
type="node_type",
|
| 35 |
+
title="Multi input op",
|
| 36 |
position=workspace.Position(x=0, y=0),
|
| 37 |
)
|
| 38 |
ws.edges = [
|
lynxkite-graph-analytics/tests/test_lynxkite_ops.py
CHANGED
|
@@ -12,7 +12,7 @@ async def test_execute_operation_not_in_catalog():
|
|
| 12 |
ws.add_node(
|
| 13 |
id="1",
|
| 14 |
type="node_type",
|
| 15 |
-
|
| 16 |
position=workspace.Position(x=0, y=0),
|
| 17 |
)
|
| 18 |
await execute(ws)
|
|
@@ -78,25 +78,25 @@ async def test_execute_operation_inputs_correct_cast():
|
|
| 78 |
ws.add_node(
|
| 79 |
id="1",
|
| 80 |
type="node_type",
|
| 81 |
-
|
| 82 |
position=workspace.Position(x=0, y=0),
|
| 83 |
)
|
| 84 |
ws.add_node(
|
| 85 |
id="2",
|
| 86 |
type="node_type",
|
| 87 |
-
|
| 88 |
position=workspace.Position(x=100, y=0),
|
| 89 |
)
|
| 90 |
ws.add_node(
|
| 91 |
id="3",
|
| 92 |
type="node_type",
|
| 93 |
-
|
| 94 |
position=workspace.Position(x=200, y=0),
|
| 95 |
)
|
| 96 |
ws.add_node(
|
| 97 |
id="4",
|
| 98 |
type="node_type",
|
| 99 |
-
|
| 100 |
position=workspace.Position(x=300, y=0),
|
| 101 |
)
|
| 102 |
ws.edges = [
|
|
@@ -136,19 +136,19 @@ async def test_multiple_inputs():
|
|
| 136 |
ws.add_node(
|
| 137 |
id="one",
|
| 138 |
type="cool",
|
| 139 |
-
|
| 140 |
position=workspace.Position(x=0, y=0),
|
| 141 |
)
|
| 142 |
ws.add_node(
|
| 143 |
id="two",
|
| 144 |
type="cool",
|
| 145 |
-
|
| 146 |
position=workspace.Position(x=100, y=0),
|
| 147 |
)
|
| 148 |
ws.add_node(
|
| 149 |
id="smaller",
|
| 150 |
type="cool",
|
| 151 |
-
|
| 152 |
position=workspace.Position(x=200, y=0),
|
| 153 |
)
|
| 154 |
ws.edges = [
|
|
@@ -188,8 +188,8 @@ async def test_optional_inputs():
|
|
| 188 |
return a + (b or 0)
|
| 189 |
|
| 190 |
assert maybe_add.__op__.inputs == [
|
| 191 |
-
ops.Input(name="a", type=int, position=
|
| 192 |
-
ops.Input(name="b", type=int | None, position=
|
| 193 |
]
|
| 194 |
ws = workspace.Workspace(env="test", nodes=[], edges=[])
|
| 195 |
a = ws.add_node(one)
|
|
|
|
| 12 |
ws.add_node(
|
| 13 |
id="1",
|
| 14 |
type="node_type",
|
| 15 |
+
title="Non existing op",
|
| 16 |
position=workspace.Position(x=0, y=0),
|
| 17 |
)
|
| 18 |
await execute(ws)
|
|
|
|
| 78 |
ws.add_node(
|
| 79 |
id="1",
|
| 80 |
type="node_type",
|
| 81 |
+
title="Create Bundle",
|
| 82 |
position=workspace.Position(x=0, y=0),
|
| 83 |
)
|
| 84 |
ws.add_node(
|
| 85 |
id="2",
|
| 86 |
type="node_type",
|
| 87 |
+
title="Bundle to Graph",
|
| 88 |
position=workspace.Position(x=100, y=0),
|
| 89 |
)
|
| 90 |
ws.add_node(
|
| 91 |
id="3",
|
| 92 |
type="node_type",
|
| 93 |
+
title="Graph to Bundle",
|
| 94 |
position=workspace.Position(x=200, y=0),
|
| 95 |
)
|
| 96 |
ws.add_node(
|
| 97 |
id="4",
|
| 98 |
type="node_type",
|
| 99 |
+
title="Dataframe to Bundle",
|
| 100 |
position=workspace.Position(x=300, y=0),
|
| 101 |
)
|
| 102 |
ws.edges = [
|
|
|
|
| 136 |
ws.add_node(
|
| 137 |
id="one",
|
| 138 |
type="cool",
|
| 139 |
+
title="One",
|
| 140 |
position=workspace.Position(x=0, y=0),
|
| 141 |
)
|
| 142 |
ws.add_node(
|
| 143 |
id="two",
|
| 144 |
type="cool",
|
| 145 |
+
title="Two",
|
| 146 |
position=workspace.Position(x=100, y=0),
|
| 147 |
)
|
| 148 |
ws.add_node(
|
| 149 |
id="smaller",
|
| 150 |
type="cool",
|
| 151 |
+
title="Smaller?",
|
| 152 |
position=workspace.Position(x=200, y=0),
|
| 153 |
)
|
| 154 |
ws.edges = [
|
|
|
|
| 188 |
return a + (b or 0)
|
| 189 |
|
| 190 |
assert maybe_add.__op__.inputs == [
|
| 191 |
+
ops.Input(name="a", type=int, position=ops.Position.LEFT),
|
| 192 |
+
ops.Input(name="b", type=int | None, position=ops.Position.LEFT),
|
| 193 |
]
|
| 194 |
ws = workspace.Workspace(env="test", nodes=[], edges=[])
|
| 195 |
a = ws.add_node(one)
|
lynxkite-graph-analytics/tests/test_pytorch_model_ops.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from lynxkite.core import workspace
|
| 2 |
from lynxkite_graph_analytics.pytorch import pytorch_core
|
| 3 |
import torch
|
|
@@ -11,7 +12,8 @@ def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
|
|
| 11 |
del data["title"]
|
| 12 |
ws.add_node(
|
| 13 |
id=id,
|
| 14 |
-
|
|
|
|
| 15 |
)
|
| 16 |
ws.edges = [
|
| 17 |
workspace.WorkspaceEdge(
|
|
@@ -27,10 +29,12 @@ def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
|
|
| 27 |
|
| 28 |
|
| 29 |
def summarize_layers(m: pytorch_core.ModelConfig) -> str:
|
|
|
|
| 30 |
return "".join(str(e)[:2] for e in m.model)
|
| 31 |
|
| 32 |
|
| 33 |
def summarize_connections(m: pytorch_core.ModelConfig) -> str:
|
|
|
|
| 34 |
return " ".join(
|
| 35 |
"".join(n[0] for n in c.param_names) + "->" + "".join(n[0] for n in c.return_names)
|
| 36 |
for c in m.model._children
|
|
|
|
| 1 |
+
import torch_geometric.nn as pyg_nn
|
| 2 |
from lynxkite.core import workspace
|
| 3 |
from lynxkite_graph_analytics.pytorch import pytorch_core
|
| 4 |
import torch
|
|
|
|
| 12 |
del data["title"]
|
| 13 |
ws.add_node(
|
| 14 |
id=id,
|
| 15 |
+
title=title,
|
| 16 |
+
params=data,
|
| 17 |
)
|
| 18 |
ws.edges = [
|
| 19 |
workspace.WorkspaceEdge(
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
def summarize_layers(m: pytorch_core.ModelConfig) -> str:
|
| 32 |
+
assert isinstance(m.model, pyg_nn.Sequential)
|
| 33 |
return "".join(str(e)[:2] for e in m.model)
|
| 34 |
|
| 35 |
|
| 36 |
def summarize_connections(m: pytorch_core.ModelConfig) -> str:
|
| 37 |
+
assert isinstance(m.model, pyg_nn.Sequential)
|
| 38 |
return " ".join(
|
| 39 |
"".join(n[0] for n in c.param_names) + "->" + "".join(n[0] for n in c.return_names)
|
| 40 |
for c in m.model._children
|