Skip to content

Commit 9833e9d

Browse files
committed
address comments.
1 parent 1d19262 commit 9833e9d

File tree

6 files changed

+52
-98
lines changed

6 files changed

+52
-98
lines changed

python/flink_agents/api/event.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#################################################################################
18-
import json
1918
from abc import ABC
2019
from typing import Any
2120
from uuid import UUID, uuid4
@@ -36,13 +35,15 @@ class Event(BaseModel, ABC, extra="allow"):
3635

3736
@model_validator(mode='after')
3837
def validate_extra(self) -> 'Event':
39-
"""Make sure all extra properties are serializable."""
40-
for value in self.__pydantic_extra__.values():
41-
if isinstance(value, BaseModel):
42-
continue
43-
json.dumps(value)
38+
"""Ensure init fields is serializable."""
39+
self.model_dump_json()
4440
return self
4541

42+
def __setattr__(self, name: str, value: Any) -> None:
43+
super().__setattr__(name, value)
44+
# Ensure added property can be serialized.
45+
self.model_dump_json()
46+
4647

4748
class InputEvent(Event):
4849
"""Event generated by the framework, carrying an input data that

python/flink_agents/api/tests/test_event.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,25 @@
1818
from typing import Type
1919

2020
import pytest
21+
from pydantic import ValidationError
22+
from pydantic_core import PydanticSerializationError
2123

22-
from flink_agents.api.event import Event
24+
from flink_agents.api.event import Event, InputEvent, OutputEvent
2325

2426

25-
def test_event_serializable() -> None: #noqa D103
26-
event = Event(a="1")
27-
event.model_dump_json()
27+
def test_event_serializable_valid() -> None: #noqa D103
28+
Event(a=1, b=InputEvent(input=1), c=OutputEvent(output='111'))
2829

29-
with pytest.raises(TypeError):
30-
Event(a=Type[Event])
30+
def test_event_serializable_invalid() -> None: #noqa D103
31+
with pytest.raises(ValidationError):
32+
Event(a=1, b=Type[InputEvent])
33+
34+
def test_event_add_field() -> None: #noqa D103
35+
event = Event(a=1)
36+
event.c = Event()
37+
38+
def test_event_add_field_invalid() -> None: #noqa D103
39+
event = Event(a=1)
40+
with pytest.raises(PydanticSerializationError):
41+
event.c = Type[InputEvent]
3142

32-
event = Event(a=Event())
33-
event.model_dump_json()

python/flink_agents/plan/action.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
#################################################################################
18-
import inspect
19-
from typing import Callable, List, Type
18+
from typing import List, Type
2019

2120
from pydantic import BaseModel
2221

@@ -41,6 +40,7 @@ class Action(BaseModel):
4140
"""
4241

4342
name: str
43+
#TODO: Raise a warning when the action has a return value, as it will be ignored.
4444
exec: Function
4545
listen_event_types: List[Type[Event]]
4646

@@ -52,19 +52,5 @@ def __init__(
5252
) -> None:
5353
"""Action will check function signature when init."""
5454
super().__init__(name=name, exec=exec, listen_event_types=listen_event_types)
55-
exec.check_signature(self.check_signature)
56-
57-
@classmethod
58-
def check_signature(cls, func: Callable) -> None:
59-
"""" Checker for action function signature."""
60-
#TODO: update check logic after import State and RunnerContext.
61-
params = inspect.signature(func).parameters
62-
if len(params) != 1:
63-
err_msg = "Action function must have exactly 1 parameter"
64-
raise TypeError(err_msg)
65-
for i, param in enumerate(params.values()):
66-
if i == 0:
67-
if not issubclass(param.annotation, Event):
68-
err_msg = "Action function first parameter must be Event"
69-
raise TypeError(err_msg)
55+
exec.check_signature([Event])
7056

python/flink_agents/plan/function.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Function(BaseModel, ABC):
2828
"""Base interface for user defined functions, includes python and java."""
2929

3030
@abstractmethod
31-
def check_signature(self, checker: Callable) -> None:
31+
def check_signature(self, *args: Tuple[Any, ...]) -> None:
3232
"""Check function signature is legal or not."""
3333

3434
@abstractmethod
@@ -76,9 +76,16 @@ def from_callable(func: Callable) -> Function:
7676
__func=func,
7777
)
7878

79-
def check_signature(self, checker: Callable) -> None:
80-
"""Apply external check logic to function signature."""
81-
checker(self.__get_func())
79+
def check_signature(self, *args: Tuple[Any, ...]) -> None:
80+
"""Check function signature."""
81+
params = inspect.signature(self.__get_func()).parameters
82+
annotations = [param.annotation for param in params.values()]
83+
err_msg = f"Expect {self.qualname} have signature {args}, but got {annotations}."
84+
if len(params) != args.__len__():
85+
raise TypeError(err_msg)
86+
for i, annotation in enumerate(annotations):
87+
if not issubclass(annotation, *args[i]):
88+
raise TypeError(err_msg)
8289

8390
def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
8491
"""Execute the stored function with provided arguments.
@@ -119,11 +126,12 @@ def __get_func(self) -> Callable:
119126
return self.__func
120127

121128

129+
#TODO: Implement JavaFunction.
122130
class JavaFunction(Function):
123131
"""Descriptor for a java callable function."""
124132

125133
def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
126134
"""Execute the stored function with provided arguments."""
127135

128-
def check_signature(self, checker: Callable) -> None:
136+
def check_signature(self, *args: Tuple[Any, ...]) -> None:
129137
"""Check function signature is legal or not."""

python/flink_agents/plan/tests/test_action.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,30 @@
1717
#################################################################################
1818
import pytest
1919

20-
from flink_agents.api.event import InputEvent, OutputEvent
20+
from flink_agents.api.event import InputEvent
2121
from flink_agents.plan.action import Action
2222
from flink_agents.plan.function import PythonFunction
2323

2424

25-
def increment(event: InputEvent) -> OutputEvent: # noqa: D103
26-
value = event.input
27-
value += 1
28-
return OutputEvent(output=value)
25+
def legal_signature(event: InputEvent) -> None: # noqa: D103
26+
pass
2927

30-
def decrement(value: int) -> OutputEvent: # noqa: D103
31-
value -= 1
32-
return OutputEvent(output=value)
28+
def illegal_signature(value: int) -> None: # noqa: D103
29+
pass
3330

34-
def test_action_signature() -> None: # noqa: D103
31+
def test_action_signature_legal() -> None: # noqa: D103
3532
Action(
36-
name="increment",
37-
exec=PythonFunction.from_callable(increment),
33+
name="legal",
34+
exec=PythonFunction.from_callable(legal_signature),
3835
listen_event_types=[InputEvent],
3936
)
4037

38+
def test_action_signature_illegal() -> None: # noqa: D103
4139
with pytest.raises(TypeError):
4240
Action(
43-
name="decrement",
44-
exec=PythonFunction.from_callable(decrement),
41+
name="illegal",
42+
exec=PythonFunction.from_callable(illegal_signature),
4543
listen_event_types=[InputEvent],
4644
)
4745

46+

python/flink_agents/plan/tests/test_workflow_plan.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

0 commit comments

Comments
 (0)