darabos commited on
Commit
e99021b
·
1 Parent(s): 4ec5c08

Fix ty warnings in lynxkite-graph-analytics.

Browse files
lynxkite-core/src/lynxkite/core/workspace.py CHANGED
@@ -230,7 +230,7 @@ class Workspace(BaseConfig):
230
  )
231
  elif "title" in kwargs:
232
  kwargs["data"] = WorkspaceNodeData(
233
- title=kwargs["title"], op_id=kwargs["title"], params={}
234
  )
235
  kwargs.setdefault("type", "basic")
236
  kwargs.setdefault("id", f"{kwargs['data'].title} {random_string}")
 
230
  )
231
  elif "title" in kwargs:
232
  kwargs["data"] = WorkspaceNodeData(
233
+ title=kwargs["title"], op_id=kwargs["title"], params=kwargs.get("params", {})
234
  )
235
  kwargs.setdefault("type", "basic")
236
  kwargs.setdefault("id", f"{kwargs['data'].title} {random_string}")
lynxkite-graph-analytics/src/lynxkite_graph_analytics/__init__.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import pandas as pd
5
 
6
  if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
7
- import cudf.pandas
8
 
9
  cudf.pandas.install()
10
 
 
4
  import pandas as pd
5
 
6
  if os.environ.get("NX_CUGRAPH_AUTOCONFIG", "").strip().lower() == "true":
7
+ import cudf.pandas # ty: ignore[unresolved-import]
8
 
9
  cudf.pandas.install()
10
 
lynxkite-graph-analytics/src/lynxkite_graph_analytics/core.py CHANGED
@@ -174,13 +174,13 @@ def disambiguate_edges(ws: workspace.Workspace):
174
  for edge in reversed(ws.edges):
175
  dst_node = nodes[edge.target]
176
  op = catalog.get(dst_node.data.op_id)
177
- if op.get_input(edge.targetHandle).type == list[Bundle]:
178
  # Takes multiple bundles as an input. No need to disambiguate.
179
  continue
180
  if (edge.target, edge.targetHandle) in seen:
181
  i = ws.edges.index(edge)
182
  del ws.edges[i]
183
- if hasattr(ws, "_crdt"):
184
  del ws._crdt["edges"][i]
185
  seen.add((edge.target, edge.targetHandle))
186
 
 
174
  for edge in reversed(ws.edges):
175
  dst_node = nodes[edge.target]
176
  op = catalog.get(dst_node.data.op_id)
177
+ if not op or op.get_input(edge.targetHandle).type == list[Bundle]:
178
  # Takes multiple bundles as an input. No need to disambiguate.
179
  continue
180
  if (edge.target, edge.targetHandle) in seen:
181
  i = ws.edges.index(edge)
182
  del ws.edges[i]
183
+ if ws._crdt:
184
  del ws._crdt["edges"][i]
185
  seen.add((edge.target, edge.targetHandle))
186
 
lynxkite-graph-analytics/src/lynxkite_graph_analytics/lynxkite_ops.py CHANGED
@@ -8,7 +8,8 @@ from collections import deque
8
 
9
  from . import core
10
  import grandcypher
11
- import matplotlib
 
12
  import networkx as nx
13
  import pandas as pd
14
  import polars as pl
@@ -205,6 +206,7 @@ def _map_color(value):
205
  else:
206
  cmap = matplotlib.cm.get_cmap("Paired")
207
  categories = pd.Index(value.unique())
 
208
  colors = cmap.colors[: len(categories)]
209
  return [
210
  "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
@@ -312,7 +314,7 @@ def view_tables(bundle: core.Bundle, *, _tables_open: str = "", limit: int = 100
312
  view="graph_creation_view",
313
  outputs=["output"],
314
  )
315
- def organize(bundles: list[core.Bundle], *, relations: str = None) -> core.Bundle:
316
  """Merge multiple inputs and construct graphs from the tables.
317
 
318
  To create a graph, import tables for edges and nodes, and combine them in this operation.
@@ -322,6 +324,9 @@ def organize(bundles: list[core.Bundle], *, relations: str = None) -> core.Bundl
322
  bundle.dfs.update(b.dfs)
323
  bundle.relations.extend(b.relations)
324
  bundle.other.update(b.other)
325
- if not (relations is None or relations.strip() == ""):
326
- bundle.relations = [core.RelationDefinition(**r) for r in json.loads(relations).values()]
 
 
 
327
  return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
 
8
 
9
  from . import core
10
  import grandcypher
11
+ import matplotlib.cm
12
+ import matplotlib.colors
13
  import networkx as nx
14
  import pandas as pd
15
  import polars as pl
 
206
  else:
207
  cmap = matplotlib.cm.get_cmap("Paired")
208
  categories = pd.Index(value.unique())
209
+ assert isinstance(cmap, matplotlib.colors.ListedColormap)
210
  colors = cmap.colors[: len(categories)]
211
  return [
212
  "#{:02x}{:02x}{:02x}".format(int(r * 255), int(g * 255), int(b * 255))
 
314
  view="graph_creation_view",
315
  outputs=["output"],
316
  )
317
+ def organize(bundles: list[core.Bundle], *, relations: str = ""):
318
  """Merge multiple inputs and construct graphs from the tables.
319
 
320
  To create a graph, import tables for edges and nodes, and combine them in this operation.
 
324
  bundle.dfs.update(b.dfs)
325
  bundle.relations.extend(b.relations)
326
  bundle.other.update(b.other)
327
+ if relations.strip():
328
+ bundle.relations = [
329
+ core.RelationDefinition(**r) # ty: ignore[missing-argument]
330
+ for r in json.loads(relations).values()
331
+ ]
332
  return ops.Result(output=bundle, display=bundle.to_dict(limit=100))
lynxkite-graph-analytics/src/lynxkite_graph_analytics/ml_ops.py CHANGED
@@ -216,7 +216,7 @@ def view_vectors(
216
  metric: UMAPMetric = UMAPMetric.euclidean,
217
  ):
218
  try:
219
- from cuml.manifold.umap import UMAP
220
  except ImportError:
221
  from umap import UMAP
222
  vec = np.stack(bundle.dfs[table_name][vector_column].to_numpy())
 
216
  metric: UMAPMetric = UMAPMetric.euclidean,
217
  ):
218
  try:
219
+ from cuml.manifold.umap import UMAP # ty: ignore[unresolved-import]
220
  except ImportError:
221
  from umap import UMAP
222
  vec = np.stack(bundle.dfs[table_name][vector_column].to_numpy())
lynxkite-graph-analytics/src/lynxkite_graph_analytics/networkx_ops.py CHANGED
@@ -1,14 +1,14 @@
1
  """Automatically wraps all NetworkX functions as LynxKite operations."""
2
 
3
- import collections
4
- import types
5
  from lynxkite.core import ops
 
 
6
  import functools
7
  import inspect
8
  import networkx as nx
9
- import re
10
-
11
  import pandas as pd
 
 
12
 
13
  ENV = "LynxKite Graph Analytics"
14
 
@@ -17,20 +17,22 @@ class UnsupportedParameterType(Exception):
17
  pass
18
 
19
 
20
- _UNSUPPORTED = object()
21
- _SKIP = object()
 
22
 
23
 
24
- def doc_to_type(name: str, type_hint: str) -> type:
25
  type_hint = type_hint.lower()
26
  type_hint = re.sub("[(][^)]+[)]", "", type_hint).strip().strip(".")
27
  if " " in name or "http" in name:
28
- return _UNSUPPORTED # Not a parameter type.
29
  if type_hint.endswith(", optional"):
30
  w = doc_to_type(name, type_hint.removesuffix(", optional").strip())
31
- if w is _UNSUPPORTED:
32
- return _SKIP
33
- return w if w is _SKIP else w | None
 
34
  if type_hint in [
35
  "a digraph or multidigraph",
36
  "a graph g",
@@ -54,15 +56,15 @@ def doc_to_type(name: str, type_hint: str) -> type:
54
  ]:
55
  return nx.DiGraph
56
  elif type_hint == "node":
57
- return _UNSUPPORTED
58
  elif type_hint == '"node (optional)"':
59
- return _SKIP
60
  elif type_hint == '"edge"':
61
- return _UNSUPPORTED
62
  elif type_hint == '"edge (optional)"':
63
- return _SKIP
64
  elif type_hint in ["class", "data type"]:
65
- return _UNSUPPORTED
66
  elif type_hint in ["string", "str", "node label"]:
67
  return str
68
  elif type_hint in ["string or none", "none or string", "string, or none"]:
@@ -72,27 +74,27 @@ def doc_to_type(name: str, type_hint: str) -> type:
72
  elif type_hint in ["bool", "boolean"]:
73
  return bool
74
  elif type_hint == "tuple":
75
- return _UNSUPPORTED
76
  elif type_hint == "set":
77
- return _UNSUPPORTED
78
  elif type_hint == "list of floats":
79
- return _UNSUPPORTED
80
  elif type_hint == "list of floats or float":
81
  return float
82
  elif type_hint in ["dict", "dictionary"]:
83
- return _UNSUPPORTED
84
  elif type_hint == "scalar or dictionary":
85
  return float
86
  elif type_hint == "none or dict":
87
- return _SKIP
88
  elif type_hint in ["function", "callable"]:
89
- return _UNSUPPORTED
90
  elif type_hint in [
91
  "collection",
92
  "container of nodes",
93
  "list of nodes",
94
  ]:
95
- return _UNSUPPORTED
96
  elif type_hint in [
97
  "container",
98
  "generator",
@@ -104,13 +106,13 @@ def doc_to_type(name: str, type_hint: str) -> type:
104
  "list or tuple",
105
  "list",
106
  ]:
107
- return _UNSUPPORTED
108
  elif type_hint == "generator of sets":
109
- return _UNSUPPORTED
110
  elif type_hint == "dict or a set of 2 or 3 tuples":
111
- return _UNSUPPORTED
112
  elif type_hint == "set of 2 or 3 tuples":
113
- return _UNSUPPORTED
114
  elif type_hint == "none, string or function":
115
  return str | None
116
  elif type_hint == "string or function" and name == "weight":
@@ -135,8 +137,8 @@ def doc_to_type(name: str, type_hint: str) -> type:
135
  elif name == "weight":
136
  return str
137
  elif type_hint == "object":
138
- return _UNSUPPORTED
139
- return _SKIP
140
 
141
 
142
  def types_from_doc(doc: str) -> dict[str, type]:
@@ -186,13 +188,13 @@ def wrapped(name: str, func):
186
  return wrapper
187
 
188
 
189
- def _get_params(func) -> dict | None:
190
  sig = inspect.signature(func)
191
  # Get types from docstring.
192
  types = types_from_doc(func.__doc__)
193
  # Always hide these.
194
  for k in ["backend", "backend_kwargs", "create_using"]:
195
- types[k] = _SKIP
196
  # Add in types based on signature.
197
  for k, param in sig.parameters.items():
198
  if k in types:
@@ -203,10 +205,10 @@ def _get_params(func) -> dict | None:
203
  types[k] = int
204
  params = []
205
  for name, param in sig.parameters.items():
206
- _type = types.get(name, _UNSUPPORTED)
207
- if _type is _UNSUPPORTED:
208
  raise UnsupportedParameterType(name)
209
- if _type is _SKIP or _type in [nx.Graph, nx.DiGraph]:
210
  continue
211
  p = ops.Parameter.basic(
212
  name=name,
 
1
  """Automatically wraps all NetworkX functions as LynxKite operations."""
2
 
 
 
3
  from lynxkite.core import ops
4
+ import collections.abc
5
+ import enum
6
  import functools
7
  import inspect
8
  import networkx as nx
 
 
9
  import pandas as pd
10
+ import re
11
+ import types
12
 
13
  ENV = "LynxKite Graph Analytics"
14
 
 
17
  pass
18
 
19
 
20
+ class Failure(str, enum.Enum):
21
+ UNSUPPORTED = "unsupported" # This parameter will be hidden.
22
+ SKIP = "skip" # We have to skip the whole function.
23
 
24
 
25
+ def doc_to_type(name: str, type_hint: str) -> type | types.UnionType | Failure:
26
  type_hint = type_hint.lower()
27
  type_hint = re.sub("[(][^)]+[)]", "", type_hint).strip().strip(".")
28
  if " " in name or "http" in name:
29
+ return Failure.UNSUPPORTED # Not a parameter type.
30
  if type_hint.endswith(", optional"):
31
  w = doc_to_type(name, type_hint.removesuffix(", optional").strip())
32
+ if w is Failure.UNSUPPORTED or w is Failure.SKIP:
33
+ return Failure.SKIP
34
+ assert not isinstance(w, Failure)
35
+ return w | None
36
  if type_hint in [
37
  "a digraph or multidigraph",
38
  "a graph g",
 
56
  ]:
57
  return nx.DiGraph
58
  elif type_hint == "node":
59
+ return Failure.UNSUPPORTED
60
  elif type_hint == '"node (optional)"':
61
+ return Failure.SKIP
62
  elif type_hint == '"edge"':
63
+ return Failure.UNSUPPORTED
64
  elif type_hint == '"edge (optional)"':
65
+ return Failure.SKIP
66
  elif type_hint in ["class", "data type"]:
67
+ return Failure.UNSUPPORTED
68
  elif type_hint in ["string", "str", "node label"]:
69
  return str
70
  elif type_hint in ["string or none", "none or string", "string, or none"]:
 
74
  elif type_hint in ["bool", "boolean"]:
75
  return bool
76
  elif type_hint == "tuple":
77
+ return Failure.UNSUPPORTED
78
  elif type_hint == "set":
79
+ return Failure.UNSUPPORTED
80
  elif type_hint == "list of floats":
81
+ return Failure.UNSUPPORTED
82
  elif type_hint == "list of floats or float":
83
  return float
84
  elif type_hint in ["dict", "dictionary"]:
85
+ return Failure.UNSUPPORTED
86
  elif type_hint == "scalar or dictionary":
87
  return float
88
  elif type_hint == "none or dict":
89
+ return Failure.SKIP
90
  elif type_hint in ["function", "callable"]:
91
+ return Failure.UNSUPPORTED
92
  elif type_hint in [
93
  "collection",
94
  "container of nodes",
95
  "list of nodes",
96
  ]:
97
+ return Failure.UNSUPPORTED
98
  elif type_hint in [
99
  "container",
100
  "generator",
 
106
  "list or tuple",
107
  "list",
108
  ]:
109
+ return Failure.UNSUPPORTED
110
  elif type_hint == "generator of sets":
111
+ return Failure.UNSUPPORTED
112
  elif type_hint == "dict or a set of 2 or 3 tuples":
113
+ return Failure.UNSUPPORTED
114
  elif type_hint == "set of 2 or 3 tuples":
115
+ return Failure.UNSUPPORTED
116
  elif type_hint == "none, string or function":
117
  return str | None
118
  elif type_hint == "string or function" and name == "weight":
 
137
  elif name == "weight":
138
  return str
139
  elif type_hint == "object":
140
+ return Failure.UNSUPPORTED
141
+ return Failure.SKIP
142
 
143
 
144
  def types_from_doc(doc: str) -> dict[str, type]:
 
188
  return wrapper
189
 
190
 
191
+ def _get_params(func) -> list[ops.Parameter | ops.ParameterGroup]:
192
  sig = inspect.signature(func)
193
  # Get types from docstring.
194
  types = types_from_doc(func.__doc__)
195
  # Always hide these.
196
  for k in ["backend", "backend_kwargs", "create_using"]:
197
+ types[k] = Failure.SKIP
198
  # Add in types based on signature.
199
  for k, param in sig.parameters.items():
200
  if k in types:
 
205
  types[k] = int
206
  params = []
207
  for name, param in sig.parameters.items():
208
+ _type = types.get(name, Failure.UNSUPPORTED)
209
+ if _type is Failure.UNSUPPORTED:
210
  raise UnsupportedParameterType(name)
211
+ if _type is Failure.SKIP or _type in [nx.Graph, nx.DiGraph]:
212
  continue
213
  p = ops.Parameter.basic(
214
  name=name,
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_core.py CHANGED
@@ -3,7 +3,7 @@
3
  import copy
4
  import graphlib
5
  import io
6
- import numpy as np
7
  import pydantic
8
  from lynxkite.core import ops, workspace
9
  import torch
@@ -36,8 +36,12 @@ def reg(name, inputs=[], outputs=None, params=[], **kwargs):
36
  return ops.register_passive_op(
37
  ENV,
38
  name,
39
- inputs=[ops.Input(name=name, position="bottom", type="tensor") for name in inputs],
40
- outputs=[ops.Output(name=name, position="top", type="tensor") for name in outputs],
 
 
 
 
41
  params=params,
42
  **kwargs,
43
  )
@@ -75,16 +79,16 @@ class ModelMapping(pydantic.BaseModel):
75
  map: dict[str, ColumnSpec]
76
 
77
 
78
- def _torch_save(data) -> str:
79
  """Saves PyTorch data (modules, tensors) as a string."""
80
  buffer = io.BytesIO()
81
  torch.save(data, buffer)
82
  return buffer.getvalue()
83
 
84
 
85
- def _torch_load(data: str) -> any:
86
  """Loads PyTorch data (modules, tensors) from a string."""
87
- buffer = io.BytesIO(data) # noqa: F821
88
  return torch.load(buffer)
89
 
90
 
@@ -98,7 +102,7 @@ class ModelConfig:
98
  input_output_names: list[str]
99
  loss: torch.nn.Module
100
  source_workspace_json: str
101
- optimizer_parameters: dict[str, any]
102
  optimizer: torch.optim.Optimizer | None = None
103
  source_workspace: str | None = None
104
  trained: bool = False
@@ -116,7 +120,7 @@ class ModelConfig:
116
  values = {k: v for k, v in zip(self.model_outputs, output)}
117
  return values
118
 
119
- def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
120
  """Inference on a single batch."""
121
  self.model.eval()
122
  return self._forward(inputs)
@@ -422,7 +426,7 @@ class ModelBuilder:
422
  op = self.catalog["Optimizer"]
423
  cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
424
  cfg["source_workspace_json"] = self.ws.model_dump_json()
425
- return ModelConfig(**cfg)
426
 
427
  def get_names(self, *ids: list[str]) -> dict[str, str]:
428
  """Returns a mapping from internal IDs to human-readable names."""
@@ -448,7 +452,7 @@ class ModelBuilder:
448
 
449
 
450
  def to_batch_tensors(
451
- b: core.Bundle, batch_size: int, batch_index: int, m: ModelMapping | None
452
  ) -> dict[str, torch.Tensor]:
453
  """Extracts tensors from a bundle for a specific batch using a model mapping."""
454
  tensors = {}
 
3
  import copy
4
  import graphlib
5
  import io
6
+ import typing
7
  import pydantic
8
  from lynxkite.core import ops, workspace
9
  import torch
 
36
  return ops.register_passive_op(
37
  ENV,
38
  name,
39
+ inputs=[
40
+ ops.Input(name=name, position=ops.Position.BOTTOM, type="tensor") for name in inputs
41
+ ],
42
+ outputs=[
43
+ ops.Output(name=name, position=ops.Position.TOP, type="tensor") for name in outputs
44
+ ],
45
  params=params,
46
  **kwargs,
47
  )
 
79
  map: dict[str, ColumnSpec]
80
 
81
 
82
+ def _torch_save(data) -> bytes:
83
  """Saves PyTorch data (modules, tensors) as a string."""
84
  buffer = io.BytesIO()
85
  torch.save(data, buffer)
86
  return buffer.getvalue()
87
 
88
 
89
+ def _torch_load(data: bytes) -> typing.Any:
90
  """Loads PyTorch data (modules, tensors) from a string."""
91
+ buffer = io.BytesIO(data)
92
  return torch.load(buffer)
93
 
94
 
 
102
  input_output_names: list[str]
103
  loss: torch.nn.Module
104
  source_workspace_json: str
105
+ optimizer_parameters: dict[str, typing.Any]
106
  optimizer: torch.optim.Optimizer | None = None
107
  source_workspace: str | None = None
108
  trained: bool = False
 
120
  values = {k: v for k, v in zip(self.model_outputs, output)}
121
  return values
122
 
123
+ def inference(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
124
  """Inference on a single batch."""
125
  self.model.eval()
126
  return self._forward(inputs)
 
426
  op = self.catalog["Optimizer"]
427
  cfg["optimizer_parameters"] = op.convert_params(self.nodes[self.optimizer].data.params)
428
  cfg["source_workspace_json"] = self.ws.model_dump_json()
429
+ return ModelConfig(**cfg) # ty: ignore[missing-argument]
430
 
431
  def get_names(self, *ids: list[str]) -> dict[str, str]:
432
  """Returns a mapping from internal IDs to human-readable names."""
 
452
 
453
 
454
  def to_batch_tensors(
455
+ b: core.Bundle, batch_size: int, batch_index: int, m: ModelMapping
456
  ) -> dict[str, torch.Tensor]:
457
  """Extracts tensors from a bundle for a specific batch using a model mapping."""
458
  tensors = {}
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch/pytorch_ops.py CHANGED
@@ -117,7 +117,7 @@ def neural_ode_mlp(
117
 
118
  @op("Attention", outputs=["outputs", "weights"])
119
  def attention(query, key, value, *, embed_dim=1024, num_heads=1, dropout=0.0):
120
- return torch.nn.MultiHeadAttention(embed_dim, num_heads, dropout=dropout, need_weights=True)
121
 
122
 
123
  @op("LayerNorm", outputs=["outputs", "weights"])
@@ -250,8 +250,8 @@ reg(
250
  ops.register_passive_op(
251
  ENV,
252
  "Repeat",
253
- inputs=[ops.Input(name="input", position="top", type="tensor")],
254
- outputs=[ops.Output(name="output", position="bottom", type="tensor")],
255
  params=[
256
  ops.Parameter.basic("times", 1, int),
257
  ops.Parameter.basic("same_weights", False, bool),
@@ -261,8 +261,8 @@ ops.register_passive_op(
261
  ops.register_passive_op(
262
  ENV,
263
  "Recurrent chain",
264
- inputs=[ops.Input(name="input", position="top", type="tensor")],
265
- outputs=[ops.Output(name="output", position="bottom", type="tensor")],
266
  params=[],
267
  )
268
 
 
117
 
118
  @op("Attention", outputs=["outputs", "weights"])
119
  def attention(query, key, value, *, embed_dim=1024, num_heads=1, dropout=0.0):
120
+ return torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
121
 
122
 
123
  @op("LayerNorm", outputs=["outputs", "weights"])
 
250
  ops.register_passive_op(
251
  ENV,
252
  "Repeat",
253
+ inputs=[ops.Input(name="input", position=ops.Position.TOP, type="tensor")],
254
+ outputs=[ops.Output(name="output", position=ops.Position.BOTTOM, type="tensor")],
255
  params=[
256
  ops.Parameter.basic("times", 1, int),
257
  ops.Parameter.basic("same_weights", False, bool),
 
261
  ops.register_passive_op(
262
  ENV,
263
  "Recurrent chain",
264
+ inputs=[ops.Input(name="input", position=ops.Position.TOP, type="tensor")],
265
+ outputs=[ops.Output(name="output", position=ops.Position.BOTTOM, type="tensor")],
266
  params=[],
267
  )
268
 
lynxkite-graph-analytics/tests/test_core.py CHANGED
@@ -20,19 +20,19 @@ async def test_multi_input_box():
20
  ws.add_node(
21
  id="1",
22
  type="node_type",
23
- data=workspace.WorkspaceNodeData(title="Create Bundle", params={}),
24
  position=workspace.Position(x=0, y=0),
25
  )
26
  ws.add_node(
27
  id="2",
28
  type="node_type",
29
- data=workspace.WorkspaceNodeData(title="Create Bundle", params={}),
30
  position=workspace.Position(x=0, y=0),
31
  )
32
  ws.add_node(
33
  id="3",
34
  type="node_type",
35
- data=workspace.WorkspaceNodeData(title="Multi input op", params={}),
36
  position=workspace.Position(x=0, y=0),
37
  )
38
  ws.edges = [
 
20
  ws.add_node(
21
  id="1",
22
  type="node_type",
23
+ title="Create Bundle",
24
  position=workspace.Position(x=0, y=0),
25
  )
26
  ws.add_node(
27
  id="2",
28
  type="node_type",
29
+ title="Create Bundle",
30
  position=workspace.Position(x=0, y=0),
31
  )
32
  ws.add_node(
33
  id="3",
34
  type="node_type",
35
+ title="Multi input op",
36
  position=workspace.Position(x=0, y=0),
37
  )
38
  ws.edges = [
lynxkite-graph-analytics/tests/test_lynxkite_ops.py CHANGED
@@ -12,7 +12,7 @@ async def test_execute_operation_not_in_catalog():
12
  ws.add_node(
13
  id="1",
14
  type="node_type",
15
- data=workspace.WorkspaceNodeData(title="Non existing op", params={}),
16
  position=workspace.Position(x=0, y=0),
17
  )
18
  await execute(ws)
@@ -78,25 +78,25 @@ async def test_execute_operation_inputs_correct_cast():
78
  ws.add_node(
79
  id="1",
80
  type="node_type",
81
- data=workspace.WorkspaceNodeData(title="Create Bundle", params={}),
82
  position=workspace.Position(x=0, y=0),
83
  )
84
  ws.add_node(
85
  id="2",
86
  type="node_type",
87
- data=workspace.WorkspaceNodeData(title="Bundle to Graph", params={}),
88
  position=workspace.Position(x=100, y=0),
89
  )
90
  ws.add_node(
91
  id="3",
92
  type="node_type",
93
- data=workspace.WorkspaceNodeData(title="Graph to Bundle", params={}),
94
  position=workspace.Position(x=200, y=0),
95
  )
96
  ws.add_node(
97
  id="4",
98
  type="node_type",
99
- data=workspace.WorkspaceNodeData(title="Dataframe to Bundle", params={}),
100
  position=workspace.Position(x=300, y=0),
101
  )
102
  ws.edges = [
@@ -136,19 +136,19 @@ async def test_multiple_inputs():
136
  ws.add_node(
137
  id="one",
138
  type="cool",
139
- data=workspace.WorkspaceNodeData(title="One", params={}),
140
  position=workspace.Position(x=0, y=0),
141
  )
142
  ws.add_node(
143
  id="two",
144
  type="cool",
145
- data=workspace.WorkspaceNodeData(title="Two", params={}),
146
  position=workspace.Position(x=100, y=0),
147
  )
148
  ws.add_node(
149
  id="smaller",
150
  type="cool",
151
- data=workspace.WorkspaceNodeData(title="Smaller?", params={}),
152
  position=workspace.Position(x=200, y=0),
153
  )
154
  ws.edges = [
@@ -188,8 +188,8 @@ async def test_optional_inputs():
188
  return a + (b or 0)
189
 
190
  assert maybe_add.__op__.inputs == [
191
- ops.Input(name="a", type=int, position="left"),
192
- ops.Input(name="b", type=int | None, position="left"),
193
  ]
194
  ws = workspace.Workspace(env="test", nodes=[], edges=[])
195
  a = ws.add_node(one)
 
12
  ws.add_node(
13
  id="1",
14
  type="node_type",
15
+ title="Non existing op",
16
  position=workspace.Position(x=0, y=0),
17
  )
18
  await execute(ws)
 
78
  ws.add_node(
79
  id="1",
80
  type="node_type",
81
+ title="Create Bundle",
82
  position=workspace.Position(x=0, y=0),
83
  )
84
  ws.add_node(
85
  id="2",
86
  type="node_type",
87
+ title="Bundle to Graph",
88
  position=workspace.Position(x=100, y=0),
89
  )
90
  ws.add_node(
91
  id="3",
92
  type="node_type",
93
+ title="Graph to Bundle",
94
  position=workspace.Position(x=200, y=0),
95
  )
96
  ws.add_node(
97
  id="4",
98
  type="node_type",
99
+ title="Dataframe to Bundle",
100
  position=workspace.Position(x=300, y=0),
101
  )
102
  ws.edges = [
 
136
  ws.add_node(
137
  id="one",
138
  type="cool",
139
+ title="One",
140
  position=workspace.Position(x=0, y=0),
141
  )
142
  ws.add_node(
143
  id="two",
144
  type="cool",
145
+ title="Two",
146
  position=workspace.Position(x=100, y=0),
147
  )
148
  ws.add_node(
149
  id="smaller",
150
  type="cool",
151
+ title="Smaller?",
152
  position=workspace.Position(x=200, y=0),
153
  )
154
  ws.edges = [
 
188
  return a + (b or 0)
189
 
190
  assert maybe_add.__op__.inputs == [
191
+ ops.Input(name="a", type=int, position=ops.Position.LEFT),
192
+ ops.Input(name="b", type=int | None, position=ops.Position.LEFT),
193
  ]
194
  ws = workspace.Workspace(env="test", nodes=[], edges=[])
195
  a = ws.add_node(one)
lynxkite-graph-analytics/tests/test_pytorch_model_ops.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from lynxkite.core import workspace
2
  from lynxkite_graph_analytics.pytorch import pytorch_core
3
  import torch
@@ -11,7 +12,8 @@ def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
11
  del data["title"]
12
  ws.add_node(
13
  id=id,
14
- data=workspace.WorkspaceNodeData(title=title, params=data),
 
15
  )
16
  ws.edges = [
17
  workspace.WorkspaceEdge(
@@ -27,10 +29,12 @@ def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
27
 
28
 
29
  def summarize_layers(m: pytorch_core.ModelConfig) -> str:
 
30
  return "".join(str(e)[:2] for e in m.model)
31
 
32
 
33
  def summarize_connections(m: pytorch_core.ModelConfig) -> str:
 
34
  return " ".join(
35
  "".join(n[0] for n in c.param_names) + "->" + "".join(n[0] for n in c.return_names)
36
  for c in m.model._children
 
1
+ import torch_geometric.nn as pyg_nn
2
  from lynxkite.core import workspace
3
  from lynxkite_graph_analytics.pytorch import pytorch_core
4
  import torch
 
12
  del data["title"]
13
  ws.add_node(
14
  id=id,
15
+ title=title,
16
+ params=data,
17
  )
18
  ws.edges = [
19
  workspace.WorkspaceEdge(
 
29
 
30
 
31
  def summarize_layers(m: pytorch_core.ModelConfig) -> str:
32
+ assert isinstance(m.model, pyg_nn.Sequential)
33
  return "".join(str(e)[:2] for e in m.model)
34
 
35
 
36
  def summarize_connections(m: pytorch_core.ModelConfig) -> str:
37
+ assert isinstance(m.model, pyg_nn.Sequential)
38
  return " ".join(
39
  "".join(n[0] for n in c.param_names) + "->" + "".join(n[0] for n in c.return_names)
40
  for c in m.model._children