Skip to content

Commit

Permalink
Add support for overriding urls and refactor configuration (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
edenhaus authored Jan 27, 2024
1 parent 95bb092 commit 3cb18de
Show file tree
Hide file tree
Showing 12 changed files with 324 additions and 196 deletions.
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ from deebot_client.authentication import Authenticator
from deebot_client.commands import *
from deebot_client.commands.clean import CleanAction
from deebot_client.events import BatteryEvent
from deebot_client.models import Configuration
from deebot_client.mqtt_client import MqttClient, MqttConfiguration
from deebot_client.configuration import Configuration
from deebot_client.mqtt_client import MqttClient
from deebot_client.util import md5
from deebot_client.device import Device

Expand All @@ -42,17 +42,16 @@ country = "de"
async def main():
async with aiohttp.ClientSession() as session:
logging.basicConfig(level=logging.DEBUG)
config = Configuration(session, device_id=device_id, country=country)
config = create_config(session, device_id=device_id, country=country)

authenticator = Authenticator(config, account_id, password_hash)
authenticator = Authenticator(config.rest, account_id, password_hash)
api_client = ApiClient(authenticator)

devices_ = await api_client.get_devices()

bot = Device(devices_[0], authenticator)

mqtt_config = MqttConfiguration(config=config)
mqtt = MqttClient(mqtt_config, authenticator)
mqtt = MqttClient(config.mqtt, authenticator)
await bot.initialize(mqtt)

async def on_battery(event: BatteryEvent):
Expand Down
55 changes: 21 additions & 34 deletions deebot_client/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,26 @@

from aiohttp import ClientResponseError, hdrs

from .const import REALM
from .const import COUNTRY_CHINA, REALM
from .exceptions import ApiError, AuthenticationError, InvalidAuthenticationError
from .logging_filter import get_logger
from .models import Configuration, Credentials
from .models import Credentials
from .util import cancel, create_task, md5

if TYPE_CHECKING:
from collections.abc import Callable, Coroutine, Mapping

from .configuration import RestConfiguration


_LOGGER = get_logger(__name__)

_CLIENT_KEY = "1520391301804"
_CLIENT_SECRET = "6c319b2a5cd3e66e39159c2e28f2fce9" # noqa: S105
_AUTH_CLIENT_KEY = "1520391491841"
_AUTH_CLIENT_SECRET = "77ef58ce3afbe337da74aa8c5ab963a9" # noqa: S105
_USER_LOGIN_URL_FORMAT = (
"https://gl-{country}-api.ecovacs.{tld}/v1/private/{country}/{lang}/{deviceId}/{appCode}/"
"{appVersion}/{channel}/{deviceType}/user/login"
)
_GLOBAL_AUTHCODE_URL_FORMAT = (
"https://gl-{country}-openapi.ecovacs.{tld}/v1/global/auth/getAuthCode"
)
_USER_LOGIN_PATH_FORMAT = "/v1/private/{country}/{lang}/{deviceId}/{appCode}/{appVersion}/{channel}/{deviceType}/user/login"
_GLOBAL_AUTHCODE_PATH = "/v1/global/auth/getAuthCode"
_PATH_USERS_USER = "users/user.do"
_META = {
"lang": "EN",
Expand All @@ -42,28 +40,22 @@
MAX_RETRIES = 3


def _get_portal_url(config: Configuration, path: str) -> str:
subdomain = f"portal-{config.continent}" if config.country != "cn" else "portal"
return urljoin(f"https://{subdomain}.ecouser.net/api/", path)


class _AuthClient:
"""Ecovacs auth client."""

def __init__(
self,
config: Configuration,
config: RestConfiguration,
account_id: str,
password_hash: str,
) -> None:
self._config = config
self._account_id = account_id
self._password_hash = password_hash
self._tld = "com" if self._config.country != "cn" else "cn"

self._meta: dict[str, str] = {
**_META,
"country": self._config.country,
"country": self._config.country.lower(),
"deviceId": self._config.device_id,
}

Expand Down Expand Up @@ -102,9 +94,7 @@ async def login(self) -> Credentials:
async def __do_auth_response(
self, url: str, params: dict[str, Any]
) -> dict[str, Any]:
async with self._config.session.get(
url, params=params, timeout=60, ssl=self._config.verify_ssl
) as res:
async with self._config.session.get(url, params=params, timeout=60) as res:
res.raise_for_status()

# ecovacs returns a json but content_type header is set to text
Expand Down Expand Up @@ -134,9 +124,11 @@ async def __call_login_api(
"authTimeZone": "GMT-8",
}

url = _USER_LOGIN_URL_FORMAT.format(**self._meta, tld=self._tld)
url = urljoin(
self._config.login_url, _USER_LOGIN_PATH_FORMAT.format(**self._meta)
)

if self._config.country == "cn":
if self._config.country == COUNTRY_CHINA:
url += "CheckMobile"

return await self.__do_auth_response(
Expand Down Expand Up @@ -170,7 +162,7 @@ async def __call_auth_api(self, access_token: str, user_id: str) -> str:
"authTimespan": int(time.time() * 1000),
}

url = _GLOBAL_AUTHCODE_URL_FORMAT.format(**self._meta, tld=self._tld)
url = urljoin(self._config.auth_code_url, _GLOBAL_AUTHCODE_PATH)

res = await self.__do_auth_response(
url,
Expand All @@ -189,10 +181,10 @@ async def __call_login_by_it_token(
"token": auth_code,
"realm": REALM,
"resource": self._config.device_id,
"org": "ECOWW" if self._config.country != "cn" else "ECOCN",
"org": "ECOWW" if self._config.country != COUNTRY_CHINA else "ECOCN",
"last": "",
"country": self._config.country.upper()
if self._config.country != "cn"
"country": self._config.country
if self._config.country != COUNTRY_CHINA
else "Chinese",
"todo": "loginByItToken",
}
Expand Down Expand Up @@ -224,7 +216,7 @@ async def post(
credentials: Credentials | None = None,
) -> dict[str, Any]:
"""Perform a post request."""
url = _get_portal_url(self._config, path)
url = urljoin(self._config.portal_url, "api/" + path)
logger_requst_params = f"url={url}, params={query_params}, json={json}"

if credentials is not None:
Expand All @@ -250,12 +242,7 @@ async def post(

try:
async with self._config.session.post(
url,
json=json,
params=query_params,
headers=headers,
timeout=60,
ssl=self._config.verify_ssl,
url, json=json, params=query_params, headers=headers, timeout=60
) as res:
if res.status == HTTPStatus.OK:
response_data: dict[str, Any] = await res.json()
Expand Down Expand Up @@ -303,7 +290,7 @@ class Authenticator:

def __init__(
self,
config: Configuration,
config: RestConfiguration,
account_id: str,
password_hash: str,
) -> None:
Expand Down
106 changes: 106 additions & 0 deletions deebot_client/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Deebot configuration."""
from __future__ import annotations

from dataclasses import dataclass
import ssl
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from deebot_client.const import COUNTRY_CHINA
from deebot_client.exceptions import DeebotError
from deebot_client.util.continents import get_continent_url_postfix

if TYPE_CHECKING:
from aiohttp import ClientSession


@dataclass(frozen=True, kw_only=True)
class MqttConfiguration:
"""Mqtt configuration."""

hostname: str
port: int
ssl_context: ssl.SSLContext | None
device_id: str


@dataclass(frozen=True, kw_only=True)
class RestConfiguration:
"""Rest configuration."""

session: ClientSession
device_id: str
country: str
portal_url: str
login_url: str
auth_code_url: str


@dataclass(frozen=True)
class Configuration:
"""Configuration representation."""

rest: RestConfiguration
mqtt: MqttConfiguration


def create_config(
session: ClientSession,
device_id: str,
country: str,
*,
override_mqtt_url: str | None = None,
override_rest_url: str | None = None,
) -> Configuration:
"""Create configuration."""
continent_postfix = get_continent_url_postfix(country)
if override_rest_url:
portal_url = login_url = auth_code_url = override_rest_url
else:
portal_url = f"https://portal{continent_postfix}.ecouser.net"
tld = "com" if country != COUNTRY_CHINA else "cn"
country = country.lower()
login_url = f"https://gl-{country}-api.ecovacs.{tld}"
auth_code_url = f"https://gl-{country}-openapi.ecovacs.{tld}"

rest_config = RestConfiguration(
session=session,
device_id=device_id,
country=country,
portal_url=portal_url,
login_url=login_url,
auth_code_url=auth_code_url,
)

if override_mqtt_url:
url = urlparse(override_mqtt_url)
match url.scheme:
case "mqtt":
default_port = 1883
ssl_ctx = None
case "mqtts":
default_port = 8883
ssl_ctx = ssl.create_default_context()
case _:
raise DeebotError("Invalid scheme. Expecting mqtt or mqtts")

if not url.hostname:
raise DeebotError("Hostame is required")

hostname = url.hostname
port = url.port or default_port
else:
hostname = f"mq{continent_postfix}.ecouser.net"
port = 443
ssl_ctx = ssl.create_default_context()
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE

mqtt_config = MqttConfiguration(
hostname=hostname,
port=port,
ssl_context=ssl_ctx,
device_id=device_id,
)

return Configuration(rest_config, mqtt_config)
1 change: 1 addition & 0 deletions deebot_client/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
REQUEST_HEADERS = {
"User-Agent": "Dalvik/2.1.0 (Linux; U; Android 5.1.1; A5010 Build/LMY48Z)",
}
COUNTRY_CHINA = "CN"


class DataType(StrEnum):
Expand Down
48 changes: 0 additions & 48 deletions deebot_client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Required, TypedDict

from deebot_client.util.continents import get_continent

if TYPE_CHECKING:
from aiohttp import ClientSession

from deebot_client.capabilities import Capabilities
from deebot_client.const import DataType

Expand Down Expand Up @@ -164,47 +160,3 @@ def _str_to_bool_or_cert(value: bool | str) -> bool | str:

msg = f'Cannot convert "{value}" to a bool or certificate path'
raise ValueError(msg)


class Configuration:
"""Configuration representation."""

def __init__(
self,
session: ClientSession,
*,
device_id: str,
country: str,
continent: str | None = None,
verify_ssl: bool | str = True,
) -> None:
self._session = session
self._device_id = device_id
self._country = country.lower()
self._continent = (continent or get_continent(country)).lower()
self._verify_ssl = _str_to_bool_or_cert(verify_ssl)

@property
def session(self) -> ClientSession:
"""Client session."""
return self._session

@property
def device_id(self) -> str:
"""Device id."""
return self._device_id

@property
def country(self) -> str:
"""Country code."""
return self._country

@property
def continent(self) -> str:
"""Continent code."""
return self._continent

@property
def verify_ssl(self) -> bool | str:
"""Return bool or path to cert."""
return self._verify_ssl
Loading

0 comments on commit 3cb18de

Please sign in to comment.