Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions linux_voice_assistant/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
from .mpv_player import MpvMediaPlayer
from .util import call_all, get_mac, is_arm

from .event_bus import EventBus
from .event_led import LedEvent

_LOGGER = logging.getLogger(__name__)
_MODULE_DIR = Path(__file__).parent
_REPO_DIR = _MODULE_DIR.parent
Expand Down Expand Up @@ -78,6 +81,8 @@ class ServerState:
tts_player: MpvMediaPlayer
wakeup_sound: str
timer_finished_sound: str
loop: asyncio.AbstractEventLoop
event_bus: EventBus
media_player_entity: Optional[MediaPlayerEntity] = None
satellite: "Optional[VoiceSatelliteProtocol]" = None

Expand Down Expand Up @@ -110,11 +115,16 @@ def __init__(self, state: ServerState) -> None:
self._continue_conversation = False
self._timer_finished = False

self.state.event_bus.publish('ready', {})
_LOGGER.info('System is ready!')

def handle_voice_event(
self, event_type: VoiceAssistantEventType, data: Dict[str, str]
) -> None:
_LOGGER.debug("Voice event: type=%s, data=%s", event_type.name, data)

self.state.event_bus.publish(f'voice_{event_type.name}', data)

if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START:
self._tts_url = data.get("url")
self._tts_played = False
Expand Down Expand Up @@ -157,6 +167,7 @@ def handle_timer_event(
self._play_timer_finished()

def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
_LOGGER.debug(f'message {msg.__name__}')
if isinstance(msg, VoiceAssistantEventResponse):
# Pipeline event
data: Dict[str, str] = {}
Expand Down Expand Up @@ -197,14 +208,11 @@ def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
| VoiceAssistantFeature.TIMERS
),
)
elif isinstance(
msg,
(
ListEntitiesRequest,
SubscribeHomeAssistantStatesRequest,
MediaPlayerCommandRequest,
),
):
elif isinstance(msg, (
ListEntitiesRequest,
SubscribeHomeAssistantStatesRequest,
MediaPlayerCommandRequest,
),):
for entity in self.state.entities:
yield from entity.handle_message(msg)

Expand Down Expand Up @@ -245,13 +253,13 @@ def handle_message(self, msg: message.Message) -> Iterable[message.Message]:
break

def handle_audio(self, audio_chunk: bytes) -> None:

if not self._is_streaming_audio:
return

self.send_messages([VoiceAssistantAudio(data=audio_chunk)])

def wakeup(self) -> None:
# Why are we stopping the timer? Wouldn't it be better to delay it?
if self._timer_finished:
# Stop timer instead
self._timer_finished = False
Expand All @@ -264,6 +272,10 @@ def wakeup(self) -> None:
self.send_messages(
[VoiceAssistantRequest(start=True, wake_word_phrase=wake_word_phrase)]
)

self.state.event_bus.publish('voice_wakeword', {'wake_word_phrase': wake_word_phrase})


self.duck()
self._is_streaming_audio = True
self.state.tts_player.play(self.state.wakeup_sound)
Expand All @@ -286,6 +298,8 @@ def play_tts(self) -> None:
self._tts_played = True
_LOGGER.debug("Playing TTS response: %s", self._tts_url)

self.state.event_bus.publish('voice_play_tts', {})

self.state.stop_word.is_active = True
self.state.tts_player.play(self._tts_url, done_callback=self._tts_finished)

Expand All @@ -301,6 +315,9 @@ def _tts_finished(self) -> None:
self.state.stop_word.is_active = False
self.send_messages([VoiceAssistantAnnounceFinished()])

# Actual time the TTS stops speaking
self.state.event_bus.publish('voice__tts_finished', {})

if self._continue_conversation:
self.send_messages([VoiceAssistantRequest(start=True)])
self._is_streaming_audio = True
Expand Down Expand Up @@ -432,6 +449,8 @@ async def main() -> None:
stop_config_path = wake_word_dir / f"{args.stop_model}.json"
_LOGGER.debug("Loading stop model: %s", stop_config_path)
stop_model = MicroWakeWord.from_config(stop_config_path, libtensorflowlite_c_path)

loop = asyncio.get_running_loop()

state = ServerState(
name=args.name,
Expand All @@ -441,12 +460,16 @@ async def main() -> None:
available_wake_words=available_wake_words,
wake_word=wake_model,
stop_word=stop_model,
event_bus=EventBus(),
loop=loop,
music_player=MpvMediaPlayer(device=args.audio_output_device),
tts_player=MpvMediaPlayer(device=args.audio_output_device),
wakeup_sound=args.wakeup_sound,
timer_finished_sound=args.timer_finished_sound,
)

LedEvent(state)

process_audio_thread = threading.Thread(
target=process_audio, args=(state,), daemon=True
)
Expand All @@ -455,7 +478,6 @@ async def main() -> None:
def sd_callback(indata, _frames, _time, _status):
state.audio_queue.put_nowait(bytes(indata))

loop = asyncio.get_running_loop()
server = await loop.create_server(
lambda: VoiceSatelliteProtocol(state), host=args.host, port=args.port
)
Expand Down
67 changes: 67 additions & 0 deletions linux_voice_assistant/event_bus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
from typing import Any, Callable, Dict, List, Optional

_LOGGER = logging.getLogger(__name__)


class EventBus:
"""A simple synchronous publish/subscribe event bus."""

def __init__(self):
# A dictionary to hold listeners for specific string topics
self.topics: Dict[str, List[Callable[[Any], None]]] = {}

def subscribe(self, topic: str, listener: Callable[[Any], None]) -> None:
"""
Subscribes a listener to a topic.
"""

# _LOGGER.debug(f'EventBus subscribe {topic}')

if topic not in self.topics:
self.topics[topic] = []
self.topics[topic].append(listener)

def publish(self, topic: str, data: [dict, None]) -> None:
"""
Publishes an event to all subscribed listeners.
"""

# _LOGGER.debug(f'EventBus publish {topic}')

data['__topic'] = topic

listeners = self.topics.get(topic, [])
for listener in listeners:
listener(data)
Comment on lines +25 to +36

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Suggestion] Based on below suggestion:

Suggested change
def publish(self, topic: str, data: [dict, None]) -> None:
"""
Publishes an event to all subscribed listeners.
"""
# _LOGGER.debug(f'EventBus publish {topic}')
data['__topic'] = topic
listeners = self.topics.get(topic, [])
for listener in listeners:
listener(data)
async def publish(self, topic: str, data: [dict, None]) -> None:
"""
Publishes an event to all subscribed listeners.
"""
# _LOGGER.debug(f'EventBus publish {topic}')
data['__topic'] = topic
listeners = self.topics.get(topic, [])
for listener in sorted(listeners, key=lambda l: l._event_bus_priority):
if inspect.iscoroutine(listener):
await listener(data)
else:
listener(data)

When publishing events, use asyncio.create_task to publish the event.


# Client helpers for subscriptions

# The decorator to mark methods for subscription.
def subscribe(func: Callable) -> Callable:
"""Decorator to mark a method for event bus subscription."""
func._event_bus_subscribe = True
return func
Comment on lines +41 to +44

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Suggestion] Configurable handler metadata:

Suggested change
def subscribe(func: Callable) -> Callable:
"""Decorator to mark a method for event bus subscription."""
func._event_bus_subscribe = True
return func
@dataclass
class EventBusMetaData:
priority: int = 0
# Any other info related to the handler
def subscribe(priority: int=0, **kwargs):
"""Decorator to mark a method for event bus subscription."""
def wrapper(func: Callable) -> Callable:
"""Function wrapper."""
func._event_bus_subscribe = True
func._event_bus_meta = EventBusMetaData(
priority=priority,
...
)
return func
return wrapper

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Nitpick] An event bus is based off events - not topics, it would scale far more easily to have event handlers react to VoiceAssistantEvents


class EventHandler:
"""
A base class for components that subscribe to events.

Subclasses should define event handlers as methods decorated with `@subscribe`.
The method name will automatically be used as the event topic.
"""

def __init__(self, state: Any):
self.state = state
self._subscribe_all_methods()
_LOGGER.debug(f"EventHandler {self.__class__.__name__} has subscribed to all decorated methods.")

def _subscribe_all_methods(self):
"""Finds and subscribes all methods decorated with @subscribe."""
for method_name in dir(self):
method = getattr(self, method_name)

if hasattr(method, '_event_bus_subscribe'):
# The topic is the name of the method itself.
self.state.event_bus.subscribe(method_name, method)
_LOGGER.debug(f"Subscribed method '{method_name}' to topic '{method_name}'")
Loading