1
0
mirror of https://github.com/natekspencer/hacs-oasis_mini.git synced 2025-12-06 18:44:14 -05:00

Switch to using mqtt

This commit is contained in:
Nathan Spencer
2025-11-22 04:40:58 +00:00
parent 171a608314
commit 886d7598f3
34 changed files with 2036 additions and 1057 deletions

View File

@@ -0,0 +1,7 @@
"""Oasis control."""
from .clients import OasisCloudClient, OasisMqttClient
from .device import OasisDevice
from .exceptions import UnauthenticatedError
__all__ = ["OasisDevice", "OasisCloudClient", "OasisMqttClient", "UnauthenticatedError"]

View File

@@ -0,0 +1,7 @@
"""Oasis control clients."""
from .cloud_client import OasisCloudClient
from .http_client import OasisHttpClient
from .mqtt_client import OasisMqttClient
__all__ = ["OasisCloudClient", "OasisHttpClient", "OasisMqttClient"]

View File

@@ -0,0 +1,191 @@
"""Oasis cloud client."""
from __future__ import annotations
from datetime import timedelta
import logging
from typing import Any
from urllib.parse import urljoin
from aiohttp import ClientResponseError, ClientSession
from ..exceptions import UnauthenticatedError
from ..utils import now
_LOGGER = logging.getLogger(__name__)
BASE_URL = "https://app.grounded.so"
PLAYLISTS_REFRESH_LIMITER = timedelta(minutes=5)
class OasisCloudClient:
"""Cloud client for Oasis.
Responsibilities:
- Manage aiohttp session (optionally owned)
- Manage access token
- Provide async_* helpers for:
* login/logout
* user info
* devices
* tracks/playlists
* latest software metadata
"""
_session: ClientSession | None
_owns_session: bool
_access_token: str | None
# these are "cache" fields for tracks/playlists
_playlists_next_refresh: float
playlists: list[dict[str, Any]]
_playlist_details: dict[int, dict[str, str]]
def __init__(
self,
*,
session: ClientSession | None = None,
access_token: str | None = None,
) -> None:
self._session = session
self._owns_session = session is None
self._access_token = access_token
# simple in-memory caches
self._playlists_next_refresh = 0.0
self.playlists = []
self._playlist_details = {}
@property
def session(self) -> ClientSession:
"""Return (or lazily create) the aiohttp ClientSession."""
if self._session is None or self._session.closed:
self._session = ClientSession()
self._owns_session = True
return self._session
async def async_close(self) -> None:
"""Close owned session (call from HA unload / cleanup)."""
if self._session and not self._session.closed and self._owns_session:
await self._session.close()
@property
def access_token(self) -> str | None:
return self._access_token
@access_token.setter
def access_token(self, value: str | None) -> None:
self._access_token = value
async def async_login(self, email: str, password: str) -> None:
"""Login via the cloud and store the access token."""
response = await self._async_request(
"POST",
urljoin(BASE_URL, "api/auth/login"),
json={"email": email, "password": password},
)
token = response.get("access_token") if isinstance(response, dict) else None
self.access_token = token
_LOGGER.debug("Cloud login succeeded, token set: %s", bool(token))
async def async_logout(self) -> None:
"""Logout from the cloud."""
await self._async_auth_request("GET", "api/auth/logout")
self.access_token = None
async def async_get_user(self) -> dict:
"""Get current user info."""
return await self._async_auth_request("GET", "api/auth/user")
async def async_get_devices(self) -> list[dict[str, Any]]:
"""Get user devices (raw JSON from API)."""
return await self._async_auth_request("GET", "api/user/devices")
async def async_get_playlists(
self, personal_only: bool = False
) -> list[dict[str, Any]]:
"""Get playlists from the cloud (cached by PLAYLISTS_REFRESH_LIMITER)."""
if self._playlists_next_refresh <= now():
params = {"my_playlists": str(personal_only).lower()}
playlists = await self._async_auth_request(
"GET", "api/playlist", params=params
)
if playlists:
self.playlists = playlists
self._playlists_next_refresh = now() + PLAYLISTS_REFRESH_LIMITER
return self.playlists
async def async_get_track_info(self, track_id: int) -> dict[str, Any] | None:
"""Get single track info from the cloud."""
try:
return await self._async_auth_request("GET", f"api/track/{track_id}")
except ClientResponseError as err:
if err.status == 404:
return {"id": track_id, "name": f"Unknown Title (#{track_id})"}
except Exception as ex: # noqa: BLE001
_LOGGER.exception("Error fetching track %s: %s", track_id, ex)
return None
async def async_get_tracks(
self, tracks: list[int] | None = None
) -> list[dict[str, Any]]:
"""Get multiple tracks info from the cloud (handles pagination)."""
response = await self._async_auth_request(
"GET",
"api/track",
params={"ids[]": tracks or []},
)
if not response:
return []
track_details = response.get("data", [])
while next_page_url := response.get("next_page_url"):
response = await self._async_auth_request("GET", next_page_url)
track_details += response.get("data", [])
return track_details
async def async_get_latest_software_details(self) -> dict[str, int | str]:
"""Get latest software metadata from cloud."""
return await self._async_auth_request("GET", "api/software/last-version")
async def _async_auth_request(self, method: str, url: str, **kwargs: Any) -> Any:
"""Perform an authenticated cloud request."""
if not self.access_token:
raise UnauthenticatedError("Unauthenticated")
headers = kwargs.pop("headers", {}) or {}
headers["Authorization"] = f"Bearer {self.access_token}"
return await self._async_request(
method,
url if url.startswith("http") else urljoin(BASE_URL, url),
headers=headers,
**kwargs,
)
async def _async_request(self, method: str, url: str, **kwargs: Any) -> Any:
"""Low-level HTTP helper for both cloud and (if desired) device HTTP."""
session = self.session
_LOGGER.debug(
"%s %s",
method,
session._build_url(url).update_query( # pylint: disable=protected-access
kwargs.get("params"),
),
)
response = await session.request(method, url, **kwargs)
if response.status == 200:
if response.content_type == "application/json":
return await response.json()
if response.content_type == "text/plain":
return await response.text()
if response.content_type == "text/html" and BASE_URL in url:
text = await response.text()
if "login-page" in text:
raise UnauthenticatedError("Unauthenticated")
return None
if response.status == 401:
raise UnauthenticatedError("Unauthenticated")
response.raise_for_status()

View File

@@ -0,0 +1,215 @@
"""Oasis HTTP client (per-device)."""
from __future__ import annotations
import logging
from typing import Any
from aiohttp import ClientSession
from ..const import AUTOPLAY_MAP
from ..device import OasisDevice
from ..utils import _bit_to_bool, _parse_int
from .transport import OasisClientProtocol
_LOGGER = logging.getLogger(__name__)
class OasisHttpClient(OasisClientProtocol):
"""HTTP-based Oasis transport.
This client is typically used per-device (per host/IP).
It implements the OasisClientProtocol so OasisDevice can delegate
all commands through it.
"""
def __init__(self, host: str, session: ClientSession | None = None) -> None:
self._host = host
self._session: ClientSession | None = session
self._owns_session: bool = session is None
@property
def session(self) -> ClientSession:
if self._session is None or self._session.closed:
self._session = ClientSession()
self._owns_session = True
return self._session
async def async_close(self) -> None:
"""Close owned session."""
if self._session and not self._session.closed and self._owns_session:
await self._session.close()
@property
def url(self) -> str:
# These devices are plain HTTP, no TLS
return f"http://{self._host}/"
async def _async_request(self, method: str, url: str, **kwargs: Any) -> Any:
"""Low-level HTTP helper."""
session = self.session
_LOGGER.debug(
"%s %s",
method,
session._build_url(url).update_query( # pylint: disable=protected-access
kwargs.get("params"),
),
)
resp = await session.request(method, url, **kwargs)
if resp.status == 200:
if resp.content_type == "text/plain":
return await resp.text()
if resp.content_type == "application/json":
return await resp.json()
return None
resp.raise_for_status()
async def _async_get(self, **kwargs: Any) -> str | None:
return await self._async_request("GET", self.url, **kwargs)
async def _async_command(self, **kwargs: Any) -> str | None:
result = await self._async_get(**kwargs)
_LOGGER.debug("Result: %s", result)
return result
async def async_get_mac_address(self, device: OasisDevice) -> str | None:
"""Fetch MAC address via HTTP GETMAC."""
try:
mac = await self._async_get(params={"GETMAC": ""})
if isinstance(mac, str):
return mac.strip()
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to get MAC address via HTTP for %s", device.serial_number
)
return None
async def async_send_ball_speed_command(
self,
device: OasisDevice,
speed: int,
) -> None:
await self._async_command(params={"WRIOASISSPEED": speed})
async def async_send_led_command(
self,
device: OasisDevice,
led_effect: str,
color: str,
led_speed: int,
brightness: int,
) -> None:
payload = f"{led_effect};0;{color};{led_speed};{brightness}"
await self._async_command(params={"WRILED": payload})
async def async_send_sleep_command(self, device: OasisDevice) -> None:
await self._async_command(params={"CMDSLEEP": ""})
async def async_send_move_job_command(
self,
device: OasisDevice,
from_index: int,
to_index: int,
) -> None:
await self._async_command(params={"MOVEJOB": f"{from_index};{to_index}"})
async def async_send_change_track_command(
self,
device: OasisDevice,
index: int,
) -> None:
await self._async_command(params={"CMDCHANGETRACK": index})
async def async_send_add_joblist_command(
self,
device: OasisDevice,
tracks: list[int],
) -> None:
# The old code passed the list directly; if the device expects CSV:
await self._async_command(params={"ADDJOBLIST": ",".join(map(str, tracks))})
async def async_send_set_playlist_command(
self,
device: OasisDevice,
playlist: list[int],
) -> None:
await self._async_command(params={"WRIJOBLIST": ",".join(map(str, playlist))})
# optional: optimistic state update
device.update_from_status_dict({"playlist": playlist})
async def async_send_set_repeat_playlist_command(
self,
device: OasisDevice,
repeat: bool,
) -> None:
await self._async_command(params={"WRIREPEATJOB": 1 if repeat else 0})
async def async_send_set_autoplay_command(
self,
device: OasisDevice,
option: str,
) -> None:
await self._async_command(params={"WRIWAITAFTER": option})
async def async_send_upgrade_command(
self,
device: OasisDevice,
beta: bool,
) -> None:
await self._async_command(params={"CMDUPGRADE": 1 if beta else 0})
async def async_send_play_command(self, device: OasisDevice) -> None:
await self._async_command(params={"CMDPLAY": ""})
async def async_send_pause_command(self, device: OasisDevice) -> None:
await self._async_command(params={"CMDPAUSE": ""})
async def async_send_stop_command(self, device: OasisDevice) -> None:
await self._async_command(params={"CMDSTOP": ""})
async def async_send_reboot_command(self, device: OasisDevice) -> None:
await self._async_command(params={"CMDBOOT": ""})
async def async_get_status(self, device: OasisDevice) -> None:
"""Fetch status via GETSTATUS and update the device."""
raw_status = await self._async_get(params={"GETSTATUS": ""})
if raw_status is None:
return
_LOGGER.debug("Status for %s: %s", device.serial_number, raw_status)
values = raw_status.split(";")
if len(values) < 7:
_LOGGER.warning(
"Unexpected status format for %s: %s", device.serial_number, values
)
return
playlist = [_parse_int(track) for track in values[3].split(",") if track]
shift = len(values) - 18 if len(values) > 17 else 0
try:
status: dict[str, Any] = {
"status_code": _parse_int(values[0]),
"error": _parse_int(values[1]),
"ball_speed": _parse_int(values[2]),
"playlist": playlist,
"playlist_index": min(_parse_int(values[4]), len(playlist)),
"progress": _parse_int(values[5]),
"led_effect": values[6],
"led_speed": _parse_int(values[8]),
"brightness": _parse_int(values[9]),
"color": values[10] if "#" in values[10] else None,
"busy": _bit_to_bool(values[11 + shift]),
"download_progress": _parse_int(values[12 + shift]),
"max_brightness": _parse_int(values[13 + shift]),
"repeat_playlist": _bit_to_bool(values[15 + shift]),
"autoplay": AUTOPLAY_MAP.get(value := values[16 + shift], value),
}
except Exception: # noqa: BLE001
_LOGGER.exception("Error parsing HTTP status for %s", device.serial_number)
return
device.update_from_status_dict(status)

View File

@@ -0,0 +1,517 @@
"""Oasis MQTT client (multi-device)."""
from __future__ import annotations
import asyncio
import base64
from datetime import UTC, datetime
import logging
import ssl
from typing import Any, Final
import aiomqtt
from ..const import AUTOPLAY_MAP
from ..device import OasisDevice
from ..utils import _bit_to_bool
from .transport import OasisClientProtocol
_LOGGER = logging.getLogger(__name__)
# mqtt connection parameters
HOST: Final = "mqtt.grounded.so"
PORT: Final = 8084
PATH: Final = "mqtt"
USERNAME: Final = "YXBw"
PASSWORD: Final = "RWdETFlKMDczfi4t"
RECONNECT_INTERVAL: Final = 4
class OasisMqttClient(OasisClientProtocol):
"""MQTT-based Oasis transport using WSS.
Responsibilities:
- Maintain a single MQTT connection to:
wss://mqtt.grounded.so:8084/mqtt
- Subscribe only to <serial>/STATUS/# for devices it knows about.
- Publish commands to <serial>/COMMAND/CMD
- Map MQTT payloads to OasisDevice.update_from_status_dict()
"""
def __init__(self) -> None:
# MQTT connection state
self._client: aiomqtt.Client | None = None
self._loop_task: asyncio.Task | None = None
self._connected_at: datetime | None = None
self._connected_event: asyncio.Event = asyncio.Event()
self._stop_event: asyncio.Event = asyncio.Event()
# Known devices by serial
self._devices: dict[str, OasisDevice] = {}
# Per-device events
self._first_status_events: dict[str, asyncio.Event] = {}
self._mac_events: dict[str, asyncio.Event] = {}
# Subscription bookkeeping
self._subscribed_serials: set[str] = set()
self._subscription_lock = asyncio.Lock()
def register_device(self, device: OasisDevice) -> None:
"""Register a device so MQTT messages can be routed to it."""
if not device.serial_number:
raise ValueError("Device must have serial_number set before registration")
serial = device.serial_number
self._devices[serial] = device
# Ensure we have per-device events
self._first_status_events.setdefault(serial, asyncio.Event())
self._mac_events.setdefault(serial, asyncio.Event())
# If we're already connected, subscribe to this device's topics
if self._client is not None:
try:
loop = asyncio.get_running_loop()
loop.create_task(self._subscribe_serial(serial))
except RuntimeError:
# No running loop (unlikely in HA), so just log
_LOGGER.debug(
"Could not schedule subscription for %s (no running loop)", serial
)
if not device.client:
device.attach_client(self)
def unregister_device(self, device: OasisDevice) -> None:
serial = device.serial_number
if not serial:
return
self._devices.pop(serial, None)
self._first_status_events.pop(serial, None)
self._mac_events.pop(serial, None)
# If connected and we were subscribed, unsubscribe
if self._client is not None and serial in self._subscribed_serials:
try:
loop = asyncio.get_running_loop()
loop.create_task(self._unsubscribe_serial(serial))
except RuntimeError:
_LOGGER.debug(
"Could not schedule unsubscription for %s (no running loop)",
serial,
)
async def _subscribe_serial(self, serial: str) -> None:
"""Subscribe to STATUS topics for a single device."""
if not self._client:
return
async with self._subscription_lock:
if not self._client or serial in self._subscribed_serials:
return
topic = f"{serial}/STATUS/#"
await self._client.subscribe([(topic, 1)])
self._subscribed_serials.add(serial)
_LOGGER.info("Subscribed to %s", topic)
async def _unsubscribe_serial(self, serial: str) -> None:
"""Unsubscribe from STATUS topics for a single device."""
if not self._client:
return
async with self._subscription_lock:
if not self._client or serial not in self._subscribed_serials:
return
topic = f"{serial}/STATUS/#"
await self._client.unsubscribe(topic)
self._subscribed_serials.discard(serial)
_LOGGER.info("Unsubscribed from %s", topic)
async def _resubscribe_all(self) -> None:
"""Resubscribe to all known devices after (re)connect."""
self._subscribed_serials.clear()
for serial in list(self._devices):
await self._subscribe_serial(serial)
def start(self) -> None:
"""Start MQTT connection loop."""
if self._loop_task is None or self._loop_task.done():
self._stop_event.clear()
loop = asyncio.get_running_loop()
self._loop_task = loop.create_task(self._mqtt_loop())
async def async_close(self) -> None:
"""Close connection loop and MQTT client."""
await self.stop()
async def stop(self) -> None:
"""Stop MQTT connection loop."""
self._stop_event.set()
if self._loop_task:
self._loop_task.cancel()
try:
await self._loop_task
except asyncio.CancelledError:
pass
if self._client:
try:
await self._client.disconnect()
except Exception:
_LOGGER.exception("Error disconnecting MQTT client")
finally:
self._client = None
async def wait_until_ready(
self, device: OasisDevice, timeout: float = 10.0, request_status: bool = True
) -> bool:
"""
Wait until:
1. MQTT client is connected
2. Device sends at least one STATUS message
If request_status=True, a request status command is sent *after* connection.
"""
serial = device.serial_number
if not serial:
raise RuntimeError("Device has no serial_number set")
first_status_event = self._first_status_events.setdefault(
serial, asyncio.Event()
)
# Wait for MQTT connection
try:
await asyncio.wait_for(self._connected_event.wait(), timeout=timeout)
except asyncio.TimeoutError:
_LOGGER.debug(
"Timeout (%.1fs) waiting for MQTT connection (device %s)",
timeout,
serial,
)
return False
# Optionally request a status refresh
if request_status:
try:
first_status_event.clear()
await self.async_get_status(device)
except Exception:
_LOGGER.debug(
"Could not request status for %s (not fully connected yet?)",
serial,
)
# Wait for first status
try:
await asyncio.wait_for(first_status_event.wait(), timeout=timeout)
return True
except asyncio.TimeoutError:
_LOGGER.debug(
"Timeout (%.1fs) waiting for first STATUS message from %s",
timeout,
serial,
)
return False
async def async_get_mac_address(self, device: OasisDevice) -> str | None:
"""For MQTT, GETSTATUS causes MAC_ADDRESS to be published."""
# If already known on the device, return it
if device.mac_address:
return device.mac_address
serial = device.serial_number
if not serial:
raise RuntimeError("Device has no serial_number set")
mac_event = self._mac_events.setdefault(serial, asyncio.Event())
mac_event.clear()
# Ask device to refresh status (including MAC_ADDRESS)
await self.async_get_status(device)
try:
await asyncio.wait_for(mac_event.wait(), timeout=3.0)
except asyncio.TimeoutError:
_LOGGER.debug("Timed out waiting for MAC_ADDRESS for %s", serial)
return device.mac_address
async def async_send_ball_speed_command(
self,
device: OasisDevice,
speed: int,
) -> None:
payload = f"WRIOASISSPEED={speed}"
await self._publish_command(device, payload)
async def async_send_led_command(
self,
device: OasisDevice,
led_effect: str,
color: str,
led_speed: int,
brightness: int,
) -> None:
payload = f"WRILED={led_effect};0;{color};{led_speed};{brightness}"
await self._publish_command(device, payload)
async def async_send_sleep_command(self, device: OasisDevice) -> None:
await self._publish_command(device, "CMDSLEEP")
async def async_send_move_job_command(
self,
device: OasisDevice,
from_index: int,
to_index: int,
) -> None:
payload = f"MOVEJOB={from_index};{to_index}"
await self._publish_command(device, payload)
async def async_send_change_track_command(
self,
device: OasisDevice,
index: int,
) -> None:
payload = f"CMDCHANGETRACK={index}"
await self._publish_command(device, payload)
async def async_send_add_joblist_command(
self,
device: OasisDevice,
tracks: list[int],
) -> None:
track_str = ",".join(map(str, tracks))
payload = f"ADDJOBLIST={track_str}"
await self._publish_command(device, payload)
async def async_send_set_playlist_command(
self,
device: OasisDevice,
playlist: list[int],
) -> None:
track_str = ",".join(map(str, playlist))
payload = f"WRIJOBLIST={track_str}"
await self._publish_command(device, payload)
# local state optimistic update
device.update_from_status_dict({"playlist": playlist})
async def async_send_set_repeat_playlist_command(
self,
device: OasisDevice,
repeat: bool,
) -> None:
payload = f"WRIREPEATJOB={1 if repeat else 0}"
await self._publish_command(device, payload)
async def async_send_set_autoplay_command(
self,
device: OasisDevice,
option: str,
) -> None:
payload = f"WRIWAITAFTER={option}"
await self._publish_command(device, payload)
async def async_send_upgrade_command(
self,
device: OasisDevice,
beta: bool,
) -> None:
payload = f"CMDUPGRADE={1 if beta else 0}"
await self._publish_command(device, payload)
async def async_send_play_command(self, device: OasisDevice) -> None:
await self._publish_command(device, "CMDPLAY")
async def async_send_pause_command(self, device: OasisDevice) -> None:
await self._publish_command(device, "CMDPAUSE")
async def async_send_stop_command(self, device: OasisDevice) -> None:
await self._publish_command(device, "CMDSTOP")
async def async_send_reboot_command(self, device: OasisDevice) -> None:
await self._publish_command(device, "CMDBOOT")
async def async_get_status(self, device: OasisDevice) -> None:
"""Ask device to publish STATUS topics."""
await self._publish_command(device, "GETSTATUS")
async def _publish_command(self, device: OasisDevice, payload: str) -> None:
if not self._client:
raise RuntimeError("MQTT client not connected yet")
serial = device.serial_number
if not serial:
raise RuntimeError("Device has no serial_number set")
topic = f"{serial}/COMMAND/CMD"
_LOGGER.debug("MQTT publish %s => %s", topic, payload)
await self._client.publish(topic, payload.encode(), qos=1)
async def _mqtt_loop(self) -> None:
loop = asyncio.get_running_loop()
tls_context = await loop.run_in_executor(None, ssl.create_default_context)
while not self._stop_event.is_set():
try:
_LOGGER.debug(
"Connecting MQTT WSS to wss://%s:%s/%s",
HOST,
PORT,
PATH,
)
async with aiomqtt.Client(
hostname=HOST,
port=PORT,
transport="websockets",
tls_context=tls_context,
username=base64.b64decode(USERNAME).decode(),
password=base64.b64decode(PASSWORD).decode(),
keepalive=30,
websocket_path=f"/{PATH}",
) as client:
self._client = client
self._connected_event.set()
self._connected_at = datetime.now(UTC)
_LOGGER.info("Connected to MQTT broker")
# Subscribe only to STATUS topics for known devices
await self._resubscribe_all()
async for msg in client.messages:
if self._stop_event.is_set():
break
await self._handle_status_message(msg)
except asyncio.CancelledError:
break
except Exception:
_LOGGER.debug("MQTT connection error")
finally:
if self._connected_event.is_set():
self._connected_event.clear()
if self._connected_at:
_LOGGER.debug(
"MQTT was connected for %s",
datetime.now(UTC) - self._connected_at,
)
self._connected_at = None
self._client = None
self._subscribed_serials.clear()
if not self._stop_event.is_set():
_LOGGER.debug(
"Disconnected from broker, retrying in %.1fs", RECONNECT_INTERVAL
)
await asyncio.sleep(RECONNECT_INTERVAL)
async def _handle_status_message(self, msg: aiomqtt.Message) -> None:
"""Map MQTT STATUS topics → OasisDevice.update_from_status_dict payloads."""
topic_str = str(msg.topic) if msg.topic is not None else ""
payload = msg.payload.decode(errors="replace")
parts = topic_str.split("/")
# Expect: "<serial>/STATUS/<STATUS_NAME>"
if len(parts) < 3:
return
serial, _, status_name = parts[:3]
device = self._devices.get(serial)
if not device:
# Ignore devices we don't know about
_LOGGER.debug("Received MQTT for unknown device %s: %s", serial, topic_str)
return
data: dict[str, Any] = {}
try:
if status_name == "OASIS_STATUS":
data["status_code"] = int(payload)
elif status_name == "OASIS_ERROR":
data["error"] = int(payload)
elif status_name == "OASIS_SPEEED":
data["ball_speed"] = int(payload)
elif status_name == "JOBLIST":
data["playlist"] = [int(x) for x in payload.split(",") if x]
elif status_name == "CURRENTJOB":
data["playlist_index"] = int(payload)
elif status_name == "CURRENTLINE":
data["progress"] = int(payload)
elif status_name == "LED_EFFECT":
data["led_effect"] = payload
elif status_name == "LED_EFFECT_COLOR":
data["led_effect_color"] = payload
elif status_name == "LED_SPEED":
data["led_speed"] = int(payload)
elif status_name == "LED_BRIGHTNESS":
data["brightness"] = int(payload)
elif status_name == "LED_MAX":
data["max_brightness"] = int(payload)
elif status_name == "LED_EFFECT_PARAM":
data["color"] = payload if payload.startswith("#") else None
elif status_name == "SYSTEM_BUSY":
data["busy"] = payload in ("1", "true", "True")
elif status_name == "DOWNLOAD_PROGRESS":
data["download_progress"] = int(payload)
elif status_name == "REPEAT_JOB":
data["repeat_playlist"] = payload in ("1", "true", "True")
elif status_name == "WAIT_AFTER_JOB":
data["autoplay"] = AUTOPLAY_MAP.get(payload, payload)
elif status_name == "AUTO_CLEAN":
data["auto_clean"] = payload in ("1", "true", "True")
elif status_name == "SOFTWARE_VER":
data["software_version"] = payload
elif status_name == "MAC_ADDRESS":
data["mac_address"] = payload
mac_event = self._mac_events.setdefault(serial, asyncio.Event())
mac_event.set()
elif status_name == "WIFI_SSID":
data["wifi_ssid"] = payload
elif status_name == "WIFI_IP":
data["wifi_ip"] = payload
elif status_name == "WIFI_PDNS":
data["wifi_pdns"] = payload
elif status_name == "WIFI_SDNS":
data["wifi_sdns"] = payload
elif status_name == "WIFI_GATE":
data["wifi_gate"] = payload
elif status_name == "WIFI_SUB":
data["wifi_sub"] = payload
elif status_name == "WIFI_STATUS":
data["wifi_connected"] = _bit_to_bool(payload)
elif status_name == "SCHEDULE":
data["schedule"] = payload
elif status_name == "ENVIRONMENT":
data["environment"] = payload
else:
_LOGGER.warning(
"Unknown status received for %s: %s=%s",
serial,
status_name,
payload,
)
except Exception: # noqa: BLE001
_LOGGER.exception(
"Error parsing MQTT payload for %s %s: %r", serial, status_name, payload
)
return
if data:
device.update_from_status_dict(data)
first_status_event = self._first_status_events.setdefault(
serial, asyncio.Event()
)
if not first_status_event.is_set():
first_status_event.set()

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from typing import Protocol, runtime_checkable
from ..device import OasisDevice
@runtime_checkable
class OasisClientProtocol(Protocol):
"""Transport/client interface for an Oasis device.
Concrete implementations:
- MQTT client (remote connection)
- HTTP client (direct LAN)
"""
async def async_get_mac_address(self, device: OasisDevice) -> str | None: ...
async def async_send_ball_speed_command(
self,
device: OasisDevice,
speed: int,
) -> None: ...
async def async_send_led_command(
self,
device: OasisDevice,
led_effect: str,
color: str,
led_speed: int,
brightness: int,
) -> None: ...
async def async_send_sleep_command(self, device: OasisDevice) -> None: ...
async def async_send_move_job_command(
self,
device: OasisDevice,
from_index: int,
to_index: int,
) -> None: ...
async def async_send_change_track_command(
self,
device: OasisDevice,
index: int,
) -> None: ...
async def async_send_add_joblist_command(
self,
device: OasisDevice,
tracks: list[int],
) -> None: ...
async def async_send_set_playlist_command(
self,
device: OasisDevice,
playlist: list[int],
) -> None: ...
async def async_send_set_repeat_playlist_command(
self,
device: OasisDevice,
repeat: bool,
) -> None: ...
async def async_send_set_autoplay_command(
self,
device: OasisDevice,
option: str,
) -> None: ...
async def async_send_upgrade_command(
self,
device: OasisDevice,
beta: bool,
) -> None: ...
async def async_send_play_command(self, device: OasisDevice) -> None: ...
async def async_send_pause_command(self, device: OasisDevice) -> None: ...
async def async_send_stop_command(self, device: OasisDevice) -> None: ...
async def async_send_reboot_command(self, device: OasisDevice) -> None: ...

View File

@@ -0,0 +1,106 @@
"""Constants."""
from __future__ import annotations
import json
import os
from typing import Any, Final
__TRACKS_FILE = os.path.join(os.path.dirname(__file__), "tracks.json")
try:
with open(__TRACKS_FILE, "r", encoding="utf8") as file:
TRACKS: Final[dict[int, dict[str, Any]]] = {
int(k): v for k, v in json.load(file).items()
}
except Exception: # ignore: broad-except
TRACKS = {}
AUTOPLAY_MAP: Final[dict[str, str]] = {
"0": "on",
"1": "off",
"2": "5 minutes",
"3": "10 minutes",
"4": "30 minutes",
"5": "24 hours",
}
ERROR_CODE_MAP: Final[dict[int, str]] = {
0: "None",
1: "Error has occurred while reading the flash memory",
2: "Error while starting the Wifi",
3: "Error when starting DNS settings for your machine",
4: "Failed to open the file to write",
5: "Not enough memory to perform the upgrade",
6: "Error while trying to upgrade your system",
7: "Error while trying to download the new version of the software",
8: "Error while reading the upgrading file",
9: "Failed to start downloading the upgrade file",
10: "Error while starting downloading the job file",
11: "Error while opening the file folder",
12: "Failed to delete a file",
13: "Error while opening the job file",
14: "You have wrong power adapter",
15: "Failed to update the device IP on Oasis Server",
16: "Your device failed centering itself",
17: "There appears to be an issue with your Oasis Device",
18: "Error while downloading the job file",
}
LED_EFFECTS: Final[dict[str, str]] = {
"0": "Solid",
"1": "Rainbow",
"2": "Glitter",
"3": "Confetti",
"4": "Sinelon",
"5": "BPM",
"6": "Juggle",
"7": "Theater",
"8": "Color Wipe",
"9": "Sparkle",
"10": "Comet",
"11": "Follow Ball",
"12": "Follow Rainbow",
"13": "Chasing Comet",
"14": "Gradient Follow",
"15": "Cumulative Fill",
"16": "Multi Comets A",
"17": "Rainbow Chaser",
"18": "Twinkle Lights",
"19": "Tennis Game",
"20": "Breathing Exercise 4-7-8",
"21": "Cylon Scanner",
"22": "Palette Mode",
"23": "Aurora Flow",
"24": "Colorful Drops",
"25": "Color Snake",
"26": "Flickering Candles",
"27": "Digital Rain",
"28": "Center Explosion",
"29": "Rainbow Plasma",
"30": "Comet Race",
"31": "Color Waves",
"32": "Meteor Storm",
"33": "Firefly Flicker",
"34": "Ripple",
"35": "Jelly Bean",
"36": "Forest Rain",
"37": "Multi Comets",
"38": "Multi Comets with Background",
"39": "Rainbow Fill",
"40": "White Red Comet",
"41": "Color Comets",
}
STATUS_CODE_MAP: Final[dict[int, str]] = {
0: "booting",
2: "stopped",
3: "centering",
4: "playing",
5: "paused",
6: "sleeping",
9: "error",
11: "updating",
13: "downloading",
14: "busy",
15: "live",
}

View File

@@ -0,0 +1,327 @@
"""Oasis device."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Final, Iterable
from .const import ERROR_CODE_MAP, LED_EFFECTS, STATUS_CODE_MAP, TRACKS
if TYPE_CHECKING: # avoid runtime circular imports
from .clients.transport import OasisClientProtocol
_LOGGER = logging.getLogger(__name__)
BALL_SPEED_MAX: Final = 400
BALL_SPEED_MIN: Final = 100
LED_SPEED_MAX: Final = 90
LED_SPEED_MIN: Final = -90
_STATE_FIELDS = (
"autoplay",
"ball_speed",
"brightness",
"busy",
"color",
"download_progress",
"error",
"led_effect",
"led_speed",
"mac_address",
"max_brightness",
"playlist",
"playlist_index",
"progress",
"repeat_playlist",
"serial_number",
"software_version",
"status_code",
)
class OasisDevice:
"""Oasis device model + behavior.
Transport-agnostic; all I/O is delegated to an attached
OasisClientProtocol (MQTT, HTTP, etc.) via `attach_client`.
"""
manufacturer: Final = "Kinetic Oasis"
def __init__(
self,
*,
model: str | None = None,
serial_number: str | None = None,
ssid: str | None = None,
ip_address: str | None = None,
client: OasisClientProtocol | None = None,
) -> None:
# Transport
self._client: OasisClientProtocol | None = client
self._listeners: list[Callable[[], None]] = []
# Details
self.model: str | None = model
self.serial_number: str | None = serial_number
self.ssid: str | None = ssid
self.ip_address: str | None = ip_address
# Status
self.auto_clean: bool = False
self.autoplay: str = "off"
self.ball_speed: int = BALL_SPEED_MIN
self.brightness: int = 0
self.busy: bool = False
self.color: str | None = None
self.download_progress: int = 0
self.error: int = 0
self.led_effect: str = "0"
self.led_speed: int = 0
self.mac_address: str | None = None
self.max_brightness: int = 200
self.playlist: list[int] = []
self.playlist_index: int = 0
self.progress: int = 0
self.repeat_playlist: bool = False
self.software_version: str | None = None
self.status_code: int = 0
self.wifi_connected: bool = False
self.wifi_ip: str | None = None
self.wifi_ssid: str | None = None
self.wifi_pdns: str | None = None
self.wifi_sdns: str | None = None
self.wifi_gate: str | None = None
self.wifi_sub: str | None = None
self.environment: str | None = None
self.schedule: Any | None = None
# Track metadata cache (used if you hydrate from cloud)
self._track: dict | None = None
def attach_client(self, client: OasisClientProtocol) -> None:
"""Attach a transport client (MQTT, HTTP, etc.) to this device."""
self._client = client
@property
def client(self) -> OasisClientProtocol | None:
"""Return the current transport client, if any."""
return self._client
def _require_client(self) -> OasisClientProtocol:
"""Return the attached client or raise if missing."""
if self._client is None:
raise RuntimeError(
f"No client/transport attached for device {self.serial_number!r}"
)
return self._client
def _update_field(self, name: str, value: Any) -> bool:
old = getattr(self, name, None)
if old != value:
_LOGGER.debug(
"%s changed: '%s' -> '%s'",
name.replace("_", " ").capitalize(),
old,
value,
)
setattr(self, name, value)
return True
return False
def update_from_status_dict(self, data: dict[str, Any]) -> None:
"""Update device fields from a status payload (from any transport)."""
changed = False
for key, value in data.items():
if hasattr(self, key):
if self._update_field(key, value):
changed = True
else:
_LOGGER.warning("Unknown field: %s=%s", key, value)
if changed:
self._notify_listeners()
def as_dict(self) -> dict[str, Any]:
"""Return core state as a dict."""
return {field: getattr(self, field) for field in _STATE_FIELDS}
@property
def error_message(self) -> str | None:
"""Return the error message, if any."""
if self.status_code == 9:
return ERROR_CODE_MAP.get(self.error, f"Unknown ({self.error})")
return None
@property
def status(self) -> str:
"""Return human-readable status from status_code."""
return STATUS_CODE_MAP.get(self.status_code, f"Unknown ({self.status_code})")
@property
def track_id(self) -> int | None:
if not self.playlist:
return None
i = self.playlist_index
return self.playlist[0] if i >= len(self.playlist) else self.playlist[i]
@property
def track(self) -> dict | None:
"""Return cached track info if it matches the current `track_id`."""
if self._track and self._track.get("id") == self.track_id:
return self._track
if track := TRACKS.get(self.track_id):
self._track = track
return self._track
return None
@property
def drawing_progress(self) -> float | None:
"""Return drawing progress percentage for the current track."""
# if not (self.track and (svg_content := self.track.get("svg_content"))):
# return None
# svg_content = decrypt_svg_content(svg_content)
# paths = svg_content.split("L")
total = self.track.get("reduced_svg_content_new", 0) # or len(paths)
percent = (100 * self.progress) / total
return percent
@property
def playlist_details(self) -> dict[int, dict[str, str]]:
"""Basic playlist details using built-in TRACKS metadata."""
return {
track_id: TRACKS.get(
track_id,
{"name": f"Unknown Title (#{track_id})"},
)
for track_id in self.playlist
}
def add_update_listener(self, listener: Callable[[], None]) -> Callable[[], None]:
"""Register a callback for state changes.
Returns an unsubscribe function.
"""
self._listeners.append(listener)
def _unsub() -> None:
try:
self._listeners.remove(listener)
except ValueError:
pass
return _unsub
def _notify_listeners(self) -> None:
"""Call all registered listeners."""
for listener in list(self._listeners):
try:
listener()
except Exception: # noqa: BLE001
_LOGGER.exception("Error in update listener")
async def async_get_mac_address(self) -> str | None:
"""Return the device MAC address, refreshing via transport if needed."""
if self.mac_address:
return self.mac_address
client = self._require_client()
mac = await client.async_get_mac_address(self)
if mac:
self._update_field("mac_address", mac)
return mac
async def async_set_ball_speed(self, speed: int) -> None:
if not BALL_SPEED_MIN <= speed <= BALL_SPEED_MAX:
raise ValueError("Invalid speed specified")
client = self._require_client()
await client.async_send_ball_speed_command(self, speed)
async def async_set_led(
self,
*,
led_effect: str | None = None,
color: str | None = None,
led_speed: int | None = None,
brightness: int | None = None,
) -> None:
"""Set the Oasis Mini LED (shared validation & attribute updates)."""
if led_effect is None:
led_effect = self.led_effect
if color is None:
color = self.color or "#ffffff"
if led_speed is None:
led_speed = self.led_speed
if brightness is None:
brightness = self.brightness
if led_effect not in LED_EFFECTS:
raise ValueError("Invalid led effect specified")
if not LED_SPEED_MIN <= led_speed <= LED_SPEED_MAX:
raise ValueError("Invalid led speed specified")
if not 0 <= brightness <= self.max_brightness:
raise ValueError("Invalid brightness specified")
client = self._require_client()
await client.async_send_led_command(
self, led_effect, color, led_speed, brightness
)
async def async_sleep(self) -> None:
client = self._require_client()
await client.async_send_sleep_command(self)
async def async_move_track(self, from_index: int, to_index: int) -> None:
client = self._require_client()
await client.async_send_move_job_command(self, from_index, to_index)
async def async_change_track(self, index: int) -> None:
client = self._require_client()
await client.async_send_change_track_command(self, index)
async def async_add_track_to_playlist(self, track: int | Iterable[int]) -> None:
if isinstance(track, int):
tracks = [track]
else:
tracks = list(track)
client = self._require_client()
await client.async_send_add_joblist_command(self, tracks)
async def async_set_playlist(self, playlist: int | Iterable[int]) -> None:
if isinstance(playlist, int):
playlist_list = [playlist]
else:
playlist_list = list(playlist)
client = self._require_client()
await client.async_send_set_playlist_command(self, playlist_list)
async def async_set_repeat_playlist(self, repeat: bool) -> None:
client = self._require_client()
await client.async_send_set_repeat_playlist_command(self, repeat)
async def async_set_autoplay(self, option: bool | int | str) -> None:
"""Set autoplay / wait-after behavior."""
if isinstance(option, bool):
option = 0 if option else 1
client = self._require_client()
await client.async_send_set_autoplay_command(self, str(option))
async def async_upgrade(self, beta: bool = False) -> None:
client = self._require_client()
await client.async_send_upgrade_command(self, beta)
async def async_play(self) -> None:
client = self._require_client()
await client.async_send_play_command(self)
async def async_pause(self) -> None:
client = self._require_client()
await client.async_send_pause_command(self)
async def async_stop(self) -> None:
client = self._require_client()
await client.async_send_stop_command(self)
async def async_reboot(self) -> None:
client = self._require_client()
await client.async_send_reboot_command(self)

View File

@@ -0,0 +1,5 @@
"""Exceptions."""
class UnauthenticatedError(Exception):
"""Unauthenticated."""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,184 @@
"""Oasis Mini utils."""
from __future__ import annotations
import base64
from datetime import UTC, datetime
import logging
import math
from xml.etree.ElementTree import Element, SubElement, tostring
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
_LOGGER = logging.getLogger(__name__)
APP_KEY = "5joW8W4Usk4xUXu5bIIgGiHloQmzMZUMgz6NWQnNI04="
BACKGROUND_FILL = ("#CCC9C4", "#28292E")
COLOR_DARK = ("#28292E", "#F4F5F8")
COLOR_LIGHT = ("#FFFFFF", "#222428")
COLOR_LIGHT_SHADE = ("#FFFFFF", "#86888F")
COLOR_MEDIUM_SHADE = ("#E5E2DE", "#86888F")
COLOR_MEDIUM_TINT = ("#B8B8B8", "#FFFFFF")
def _bit_to_bool(val: str) -> bool:
"""Convert a bit string to bool."""
return val == "1"
def _parse_int(val: str) -> int:
"""Convert an int string to int."""
try:
return int(val)
except Exception:
return 0
def draw_svg(track: dict, progress: int, model_id: str) -> str | None:
"""Draw SVG."""
if track and (svg_content := track.get("svg_content")):
try:
if progress is not None:
svg_content = decrypt_svg_content(svg_content)
paths = svg_content.split("L")
total = track.get("reduced_svg_content_new", 0) or len(paths)
percent = min((100 * progress) / total, 100)
progress = math.floor((percent / 100) * (len(paths) - 1))
svg = Element(
"svg",
{
"title": "OasisStatus",
"version": "1.1",
"viewBox": "-25 -25 250 250",
"xmlns": "http://www.w3.org/2000/svg",
"class": "svg-status",
},
)
style = SubElement(svg, "style")
style.text = f"""
circle.background {{ fill: {BACKGROUND_FILL[0]}; }}
circle.ball {{ stroke: {COLOR_DARK[0]}; fill: {COLOR_LIGHT[0]}; }}
path.progress_arc {{ stroke: {COLOR_MEDIUM_SHADE[0]}; }}
path.progress_arc_complete {{ stroke: {COLOR_DARK[0]}; }}
path.track {{ stroke: {COLOR_LIGHT_SHADE[0]}; }}
path.track_complete {{ stroke: {COLOR_MEDIUM_TINT[0]}; }}
@media (prefers-color-scheme: dark) {{
circle.background {{ fill: {BACKGROUND_FILL[1]}; }}
circle.ball {{ stroke: {COLOR_DARK[1]}; fill: {COLOR_LIGHT[1]}; }}
path.progress_arc {{ stroke: {COLOR_MEDIUM_SHADE[1]}; }}
path.progress_arc_complete {{ stroke: {COLOR_DARK[1]}; }}
path.track {{ stroke: {COLOR_LIGHT_SHADE[1]}; }}
path.track_complete {{ stroke: {COLOR_MEDIUM_TINT[1]}; }}
}}""".replace("\n", " ").strip()
group = SubElement(
svg,
"g",
{"stroke-linecap": "round", "fill": "none", "fill-rule": "evenodd"},
)
progress_arc = "M37.85,203.55L32.85,200.38L28.00,196.97L23.32,193.32L18.84,189.45L14.54,185.36L10.45,181.06L6.58,176.58L2.93,171.90L-0.48,167.05L-3.65,162.05L-6.57,156.89L-9.24,151.59L-11.64,146.17L-13.77,140.64L-15.63,135.01L-17.22,129.30L-18.51,123.51L-19.53,117.67L-20.25,111.79L-20.69,105.88L-20.84,99.95L-20.69,94.02L-20.25,88.11L-19.53,82.23L-18.51,76.39L-17.22,70.60L-15.63,64.89L-13.77,59.26L-11.64,53.73L-9.24,48.31L-6.57,43.01L-3.65,37.85L-0.48,32.85L2.93,28.00L6.58,23.32L10.45,18.84L14.54,14.54L18.84,10.45L23.32,6.58L28.00,2.93L32.85,-0.48L37.85,-3.65L43.01,-6.57L48.31,-9.24L53.73,-11.64L59.26,-13.77L64.89,-15.63L70.60,-17.22L76.39,-18.51L82.23,-19.53L88.11,-20.25L94.02,-20.69L99.95,-20.84L105.88,-20.69L111.79,-20.25L117.67,-19.53L123.51,-18.51L129.30,-17.22L135.01,-15.63L140.64,-13.77L146.17,-11.64L151.59,-9.24L156.89,-6.57L162.05,-3.65L167.05,-0.48L171.90,2.93L176.58,6.58L181.06,10.45L185.36,14.54L189.45,18.84L193.32,23.32L196.97,28.00L200.38,32.85L203.55,37.85L206.47,43.01L209.14,48.31L211.54,53.73L213.67,59.26L215.53,64.89L217.12,70.60L218.41,76.39L219.43,82.23L220.15,88.11L220.59,94.02L220.73,99.95L220.59,105.88L220.15,111.79L219.43,117.67L218.41,123.51L217.12,129.30L215.53,135.01L213.67,140.64L211.54,146.17L209.14,151.59L206.47,156.89L203.55,162.05L200.38,167.05L196.97,171.90L193.32,176.58L189.45,181.06L185.36,185.36L181.06,189.45L176.58,193.32L171.90,196.97L167.05,200.38"
SubElement(
group,
"path",
{
"class": "progress_arc",
"stroke-width": "2",
"d": progress_arc,
},
)
progress_arc_paths = progress_arc.split("L")
paths_to_draw = math.floor((percent * len(progress_arc_paths)) / 100)
SubElement(
group,
"path",
{
"class": "progress_arc_complete",
"stroke-width": "4",
"d": "L".join(progress_arc_paths[:paths_to_draw]),
},
)
SubElement(
group,
"circle",
{
"class": "background",
"r": "100",
"cx": "100",
"cy": "100",
"opacity": "0.3",
},
)
SubElement(
group,
"path",
{
"class": "track",
"stroke-width": "1.4",
"d": svg_content,
},
)
SubElement(
group,
"path",
{
"class": "track_complete",
"stroke-width": "1.8",
"d": "L".join(paths[:progress]),
},
)
_cx, _cy = map(float, paths[progress].replace("M", "").split(","))
SubElement(
group,
"circle",
{
"class": "ball",
"stroke-width": "1",
"cx": f"{_cx:.2f}",
"cy": f"{_cy:.2f}",
"r": "5",
},
)
return tostring(svg).decode()
except Exception as e:
_LOGGER.exception(e)
return None
def decrypt_svg_content(svg_content: dict[str, str]):
"""Decrypt SVG content using AES CBC mode."""
if decrypted := svg_content.get("decrypted"):
return decrypted
# decode base64-encoded data
key = base64.b64decode(APP_KEY)
iv = base64.b64decode(svg_content["iv"])
ciphertext = base64.b64decode(svg_content["content"])
# create the cipher and decrypt the ciphertext
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
decryptor = cipher.decryptor()
decrypted = decryptor.update(ciphertext) + decryptor.finalize()
# remove PKCS7 padding
pad_len = decrypted[-1]
decrypted = decrypted[:-pad_len].decode("utf-8")
# save decrypted data so we don't have to do this each time
svg_content["decrypted"] = decrypted
return decrypted
def now() -> datetime:
return datetime.now(UTC)