Skip to content

Commit fe1dcf3

Browse files
cheuktviambot
andauthored
RSDK-807: Extra params for Input controller (#180)
Co-authored-by: viambot <viambot@users.noreply.github.com>
1 parent cb9f2bb commit fe1dcf3

File tree

6 files changed

+115
-35
lines changed

6 files changed

+115
-35
lines changed

examples/server/v1/components.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def _execute_callback(self, event):
395395
if callback:
396396
callback(event)
397397

398-
async def get_controls(self, **kwargs) -> List[Control]:
398+
async def get_controls(self, extra: Optional[Dict[str, Any]] = None, **kwargs) -> List[Control]:
399399
return [
400400
Control.ABSOLUTE_X,
401401
Control.ABSOLUTE_Y,
@@ -418,7 +418,7 @@ async def get_controls(self, **kwargs) -> List[Control]:
418418
Control.BUTTON_MENU,
419419
]
420420

421-
async def get_events(self, **kwargs) -> Dict[Control, Event]:
421+
async def get_events(self, extra: Optional[Dict[str, Any]] = None, **kwargs) -> Dict[Control, Event]:
422422
with self.lock:
423423
return {key: value for (key, value) in self.last_events.items()}
424424

src/viam/components/input/client.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from time import time
44
from typing import Any, Dict, List, Optional
55
from grpclib import GRPCError, Status
6+
from google.protobuf.struct_pb2 import Struct
67

78
from grpclib.client import Channel
89
import viam
@@ -19,6 +20,7 @@
1920
StreamEventsResponse,
2021
TriggerEventRequest,
2122
)
23+
from viam.utils import dict_to_struct
2224

2325
from .input import Control, ControlFunction, Controller, Event, EventType
2426

@@ -36,19 +38,29 @@ def __init__(self, name: str, channel: Channel):
3638
self._stream_lock = Lock()
3739
self._is_streaming = False
3840
self._is_stream_ready = False
41+
self._callback_extra: Struct = dict_to_struct({})
3942
super().__init__(name)
4043

41-
async def get_controls(self, *, timeout: Optional[float] = None) -> List[Control]:
42-
request = GetControlsRequest(controller=self.name)
44+
async def get_controls(self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None) -> List[Control]:
45+
if extra is None:
46+
extra = {}
47+
request = GetControlsRequest(controller=self.name, extra=dict_to_struct(extra))
4348
response: GetControlsResponse = await self.client.GetControls(request, timeout=timeout)
4449
return [Control(control) for control in response.controls]
4550

46-
async def get_events(self, *, timeout: Optional[float] = None) -> Dict[Control, Event]:
47-
request = GetEventsRequest(controller=self.name)
51+
async def get_events(self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None) -> Dict[Control, Event]:
52+
if extra is None:
53+
extra = {}
54+
request = GetEventsRequest(controller=self.name, extra=dict_to_struct(extra))
4855
response: GetEventsResponse = await self.client.GetEvents(request, timeout=timeout)
4956
return {Control(event.control): Event.from_proto(event) for (event) in response.events}
5057

51-
def register_control_callback(self, control: Control, triggers: List[EventType], function: Optional[ControlFunction]):
58+
def register_control_callback(
59+
self, control: Control, triggers: List[EventType], function: Optional[ControlFunction], extra: Optional[Dict[str, Any]] = None
60+
):
61+
if extra is None:
62+
extra = {}
63+
self._callback_extra = dict_to_struct(extra)
5264
with self._lock:
5365
callbacks = self.callbacks.get(control, {})
5466
for trigger in triggers:
@@ -71,8 +83,10 @@ def handle_task_result(task: asyncio.Task):
7183
task = asyncio.create_task(self._stream_events(), name=f"{viam._TASK_PREFIX}-input_stream_events")
7284
task.add_done_callback(handle_task_result)
7385

74-
async def trigger_event(self, event: Event, *, timeout: Optional[float] = None):
75-
request = TriggerEventRequest(controller=self.name, event=event.proto)
86+
async def trigger_event(self, event: Event, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None):
87+
if extra is None:
88+
extra = {}
89+
request = TriggerEventRequest(controller=self.name, event=event.proto, extra=dict_to_struct(extra))
7690
try:
7791
await self.client.TriggerEvent(request, timeout=timeout)
7892
except GRPCError as e:
@@ -88,7 +102,7 @@ async def _stream_events(self):
88102
if not self.callbacks:
89103
return
90104

91-
request = StreamEventsRequest(controller=self.name, events=[])
105+
request = StreamEventsRequest(controller=self.name, events=[], extra=self._callback_extra)
92106
with self._lock:
93107
for (control, callbacks) in self.callbacks.items():
94108
event = StreamEventsRequest.Events(

src/viam/components/input/input.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from datetime import datetime
44
from enum import Enum
5-
from typing import Callable, Dict, List, Optional
5+
from typing import Any, Callable, Dict, List, Optional
66

77
from google.protobuf.timestamp_pb2 import Timestamp
88
from typing_extensions import Self
@@ -136,7 +136,7 @@ class Controller(ComponentBase):
136136
"""
137137

138138
@abc.abstractmethod
139-
async def get_controls(self, *, timeout: Optional[float] = None, **kwargs) -> List[Control]:
139+
async def get_controls(self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs) -> List[Control]:
140140
"""
141141
Returns a list of Controls provided by the Controller
142142
@@ -146,7 +146,9 @@ async def get_controls(self, *, timeout: Optional[float] = None, **kwargs) -> Li
146146
...
147147

148148
@abc.abstractmethod
149-
async def get_events(self, *, timeout: Optional[float] = None, **kwargs) -> Dict[Control, Event]:
149+
async def get_events(
150+
self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs
151+
) -> Dict[Control, Event]:
150152
"""
151153
Returns the most recent Event for each input
152154
(which should be the current state)
@@ -157,7 +159,15 @@ async def get_events(self, *, timeout: Optional[float] = None, **kwargs) -> Dict
157159
...
158160

159161
@abc.abstractmethod
160-
def register_control_callback(self, control: Control, triggers: List[EventType], function: Optional[ControlFunction], **kwargs):
162+
def register_control_callback(
163+
self,
164+
control: Control,
165+
triggers: List[EventType],
166+
function: Optional[ControlFunction],
167+
*,
168+
extra: Optional[Dict[str, Any]] = None,
169+
**kwargs,
170+
):
161171
"""
162172
Register a function that will fire on given EventTypes for a given
163173
Control
@@ -171,7 +181,7 @@ def register_control_callback(self, control: Control, triggers: List[EventType],
171181
"""
172182
...
173183

174-
async def trigger_event(self, event: Event, *, timeout: Optional[float] = None, **kwargs):
184+
async def trigger_event(self, event: Event, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs):
175185
"""Directly send an Event (such as a button press) from external code
176186
177187
Args:

src/viam/components/input/service.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
TriggerEventRequest,
1919
TriggerEventResponse,
2020
)
21+
from viam.utils import struct_to_dict
2122

22-
from .input import Control, Controller, Event, EventType
2323

24+
from .input import Control, Controller, Event, EventType
2425

2526
LOGGER = viam.logging.getLogger(__name__)
2627

@@ -41,7 +42,7 @@ async def GetControls(self, stream: Stream[GetControlsRequest, GetControlsRespon
4142
except ComponentNotFoundError as e:
4243
raise e.grpc_error
4344
timeout = stream.deadline.time_remaining() if stream.deadline else None
44-
controls = await controller.get_controls(timeout=timeout)
45+
controls = await controller.get_controls(extra=struct_to_dict(request.extra), timeout=timeout)
4546
response = GetControlsResponse(controls=[c.value for c in controls])
4647
await stream.send_message(response)
4748

@@ -54,7 +55,7 @@ async def GetEvents(self, stream: Stream[GetEventsRequest, GetEventsResponse]) -
5455
except ComponentNotFoundError as e:
5556
raise e.grpc_error
5657
timeout = stream.deadline.time_remaining() if stream.deadline else None
57-
events = await controller.get_events(timeout=timeout)
58+
events = await controller.get_events(extra=struct_to_dict(request.extra), timeout=timeout)
5859
pb_events = [e.proto for e in events.values()]
5960
response = GetEventsResponse(events=pb_events)
6061
await stream.send_message(response)
@@ -85,11 +86,21 @@ def ctrlFunc(event: Event):
8586
for event in request.events:
8687
triggers = [EventType(et) for et in event.events]
8788
if len(triggers):
88-
controller.register_control_callback(Control(event.control), triggers, ctrlFunc)
89+
controller.register_control_callback(
90+
Control(event.control),
91+
triggers,
92+
ctrlFunc,
93+
extra=struct_to_dict(request.extra),
94+
)
8995

9096
cancelled_triggers = [EventType(et) for et in event.cancelled_events]
9197
if len(cancelled_triggers):
92-
controller.register_control_callback(Control(event.control), cancelled_triggers, None)
98+
controller.register_control_callback(
99+
Control(event.control),
100+
cancelled_triggers,
101+
None,
102+
extra=struct_to_dict(request.extra),
103+
)
93104

94105
# Asynchronously wait for messages to come over the read pipe and run the READ function whenever the pipe is ready.
95106
def read():
@@ -130,7 +141,12 @@ def unregister_pipe_callbacks():
130141
for event in request.events:
131142
triggers = [EventType(et) for et in event.events]
132143
if len(triggers):
133-
controller.register_control_callback(Control(event.control), triggers, None)
144+
controller.register_control_callback(
145+
Control(event.control),
146+
triggers,
147+
None,
148+
extra=struct_to_dict(request.extra),
149+
)
134150

135151
async def TriggerEvent(self, stream: Stream[TriggerEventRequest, TriggerEventResponse]) -> None:
136152
request = await stream.recv_message()
@@ -141,7 +157,7 @@ async def TriggerEvent(self, stream: Stream[TriggerEventRequest, TriggerEventRes
141157
controller = self.get_component(name)
142158
pb_event = request.event
143159
event = Event.from_proto(pb_event)
144-
await controller.trigger_event(event, timeout=timeout)
160+
await controller.trigger_event(event, extra=struct_to_dict(request.extra), timeout=timeout)
145161
except ComponentNotFoundError as e:
146162
raise e.grpc_error
147163
except NotSupportedError as e:

tests/mocks/components.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,11 @@ def __init__(self, name: str):
473473
self.events: Dict[Control, Event] = {}
474474
self.callbacks: Dict[Control, Dict[EventType, Optional[ControlFunction]]] = {}
475475
self.timeout: Optional[float] = None
476+
self.extra = None
477+
self.reg_extra = None
476478

477-
async def get_controls(self, *, timeout: Optional[float] = None, **kwargs) -> List[Control]:
479+
async def get_controls(self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs) -> List[Control]:
480+
self.extra = extra
478481
self.timeout = timeout
479482
return [
480483
Control.ABSOLUTE_X,
@@ -500,15 +503,27 @@ async def get_controls(self, *, timeout: Optional[float] = None, **kwargs) -> Li
500503
Control.BUTTON_E_STOP,
501504
]
502505

503-
async def get_events(self, *, timeout: Optional[float] = None, **kwargs) -> Dict[Control, Event]:
506+
async def get_events(
507+
self, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs
508+
) -> Dict[Control, Event]:
509+
self.extra = extra
504510
self.timeout = timeout
505511
return self.events
506512

507-
def register_control_callback(self, control: Control, triggers: List[EventType], function: Optional[ControlFunction], **kwargs):
513+
def register_control_callback(
514+
self,
515+
control: Control,
516+
triggers: List[EventType],
517+
function: Optional[ControlFunction],
518+
extra: Optional[Dict[str, Any]] = None,
519+
**kwargs,
520+
):
508521
self.callbacks[control] = {trigger: function for trigger in triggers}
522+
self.reg_extra = extra
509523

510-
async def trigger_event(self, event: Event, *, timeout: Optional[float] = None, **kwargs):
524+
async def trigger_event(self, event: Event, *, extra: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None, **kwargs):
511525
self.events[event.control] = event
526+
self.extra = extra
512527
self.timeout = timeout
513528
callback = self.callbacks.get(event.control, {}).get(event.event)
514529
if callback:

0 commit comments

Comments
 (0)