Skip to content

Instantly share code, notes, and snippets.

@mystcb
Created December 7, 2021 21:55
Show Gist options
  • Save mystcb/e86e7a56e8d4ffde6d283797acf032f3 to your computer and use it in GitHub Desktop.
Save mystcb/e86e7a56e8d4ffde6d283797acf032f3 to your computer and use it in GitHub Desktop.
Home Assistant - TP-Link Plugin - KLAP Auth Hacks
"""Component to embed TP-Link smart home devices."""
from __future__ import annotations
import asyncio
from datetime import timedelta
from typing import Any
from kasa import SmartDevice, SmartDeviceException
from kasa.discover import Discover
from kasa.auth import Auth
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components import network
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
from homeassistant.config_entries import ConfigEntry, ConfigEntryNotReady
from homeassistant.const import (
CONF_HOST,
CONF_MAC,
CONF_NAME,
EVENT_HOMEASSISTANT_STARTED,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType
from .const import (
CONF_DIMMER,
CONF_DISCOVERY,
CONF_LIGHT,
CONF_STRIP,
CONF_SWITCH,
DOMAIN,
PLATFORMS,
)
from .coordinator import TPLinkDataUpdateCoordinator
from .migration import (
async_migrate_entities_devices,
async_migrate_legacy_entries,
async_migrate_yaml_entries,
)
# Period between the repeated background discovery scans scheduled in
# async_setup.
DISCOVERY_INTERVAL = timedelta(minutes=15)

# One YAML device entry: a mapping that must provide a host string.
TPLINK_HOST_SCHEMA = vol.Schema({vol.Required(CONF_HOST): cv.string})

# Deprecated YAML configuration for the integration.  The four per-type
# sections all have the same shape (an optional list of host entries that
# defaults to empty), so they are generated from the section keys.
CONFIG_SCHEMA = vol.Schema(
    vol.All(
        cv.deprecated(DOMAIN),
        {
            DOMAIN: vol.Schema(
                {
                    **{
                        vol.Optional(section, default=[]): vol.All(
                            cv.ensure_list, [TPLINK_HOST_SCHEMA]
                        )
                        for section in (
                            CONF_LIGHT,
                            CONF_SWITCH,
                            CONF_STRIP,
                            CONF_DIMMER,
                        )
                    },
                    vol.Optional(CONF_DISCOVERY, default=True): cv.boolean,
                }
            )
        },
    ),
    extra=vol.ALLOW_EXTRA,
)

# Credentials object passed as ``authentication=`` to the kasa discovery calls
# in this module.  Both fields are blank here.
# NOTE(review): presumably meant to be filled in locally with real account
# credentials - confirm before relying on it.
authentication = Auth(user="", password="")
@callback
def async_trigger_discovery(
    hass: HomeAssistant,
    discovered_devices: dict[str, SmartDevice],
) -> None:
    """Start a discovery config flow for each discovered device.

    ``discovered_devices`` maps a formatted MAC address to the device found
    at that address.  One flow-init task is scheduled per item; nothing is
    awaited here.
    """
    for mac, found in discovered_devices.items():
        flow_data = {
            CONF_NAME: found.alias,
            CONF_HOST: found.host,
            CONF_MAC: mac,
        }
        flow_init = hass.config_entries.flow.async_init(
            DOMAIN,
            context={"source": config_entries.SOURCE_DISCOVERY},
            data=flow_data,
        )
        hass.async_create_task(flow_init)
async def async_discover_devices(hass: HomeAssistant) -> dict[str, SmartDevice]:
    """Run kasa discovery against every IPv4 broadcast address.

    The broadcast addresses come from the network integration.  All
    discoveries run concurrently and their results are merged into a single
    mapping keyed by formatted MAC address; if the same MAC is reported more
    than once, the last result wins.
    """
    targets = await network.async_get_ipv4_broadcast_addresses(hass)
    results = await asyncio.gather(
        *(
            Discover.discover(target=str(target), authentication=authentication)
            for target in targets
        )
    )
    return {
        dr.format_mac(found.mac): found
        for result in results
        for found in result.values()
    }
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the TP-Link integration.

    Resets the integration's storage in ``hass.data``, runs the legacy and
    YAML migrations, triggers config flows for the devices found by an
    initial discovery, and schedules discovery to run again once Home
    Assistant has started and then every ``DISCOVERY_INTERVAL``.
    """
    yaml_conf = config.get(DOMAIN)
    hass.data[DOMAIN] = {}

    # Split the existing entries into the shared legacy entry (if any) and
    # the per-device entries indexed by their unique id.
    legacy_entry = None
    entries_by_mac = {}
    for existing in hass.config_entries.async_entries(DOMAIN):
        if async_entry_is_legacy(existing):
            legacy_entry = existing
            continue
        if existing.unique_id:
            entries_by_mac[existing.unique_id] = existing

    found = await async_discover_devices(hass)
    hosts_by_mac = {mac: dev.host for mac, dev in found.items()}

    if legacy_entry:
        async_migrate_legacy_entries(hass, hosts_by_mac, entries_by_mac, legacy_entry)
        # The legacy entry's data is the YAML config that was imported earlier.
        async_migrate_yaml_entries(hass, legacy_entry.data)

    if yaml_conf is not None:
        async_migrate_yaml_entries(hass, yaml_conf)

    if found:
        async_trigger_discovery(hass, found)

    async def _async_discovery(*_: Any) -> None:
        """Re-run discovery and start flows for whatever is found."""
        rediscovered = await async_discover_devices(hass)
        if rediscovered:
            async_trigger_discovery(hass, rediscovered)

    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _async_discovery)
    async_track_time_interval(hass, _async_discovery, DISCOVERY_INTERVAL)
    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up one TP-Link device from its config entry.

    The legacy shared entry has nothing to set up and returns immediately.
    For a per-device entry, entities and devices are first migrated from the
    legacy entry when one exists, then the device is contacted and a data
    update coordinator is stored under the entry id.

    Raises:
        ConfigEntryNotReady: if the device cannot be reached.
    """
    if async_entry_is_legacy(entry):
        return True

    # First legacy entry among all entries of this domain, if there is one.
    legacy_entry: ConfigEntry | None = next(
        (
            candidate
            for candidate in hass.config_entries.async_entries(DOMAIN)
            if async_entry_is_legacy(candidate)
        ),
        None,
    )
    if legacy_entry is not None:
        await async_migrate_entities_devices(hass, legacy_entry.entry_id, entry)

    try:
        device: SmartDevice = await Discover.discover_single(
            host=entry.data[CONF_HOST], authentication=authentication
        )
    except SmartDeviceException as ex:
        raise ConfigEntryNotReady from ex

    if device.is_dimmer:
        async_fix_dimmer_unique_id(hass, entry, device)

    coordinator = TPLinkDataUpdateCoordinator(hass, device)
    hass.data[DOMAIN][entry.entry_id] = coordinator
    hass.config_entries.async_setup_platforms(entry, PLATFORMS)
    return True
@callback
def async_fix_dimmer_unique_id(
    hass: HomeAssistant, entry: ConfigEntry, device: SmartDevice
) -> None:
    """Move a dimmer's light entity back to its legacy unique id.

    Dimmers used to use the switch format since pyHS100 treated them as
    SmartPlug, but the old code created them as lights:
    https://github.com/home-assistant/core/blob/2021.9.7/homeassistant/components/tplink/common.py#L86
    """
    # Unique id used before the 2021.0/2021.1 rollout.
    legacy_unique_id = legacy_device_id(device)
    # Unique id used during the 2021.0/2021.1 rollout: the MAC address,
    # upper-cased and without colons.
    rollout_unique_id = device.mac.replace(":", "").upper()

    registry = er.async_get(hass)
    rollout_entity_id = registry.async_get_entity_id(
        LIGHT_DOMAIN, DOMAIN, rollout_unique_id
    )
    legacy_entity_id = registry.async_get_entity_id(
        LIGHT_DOMAIN, DOMAIN, legacy_unique_id
    )

    # Only rewrite when the rollout entity exists and the legacy one is gone;
    # otherwise the user would end up with another "_2" entity.
    if not rollout_entity_id or legacy_entity_id:
        return
    registry.async_update_entity(rollout_entity_id, new_unique_id=legacy_unique_id)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload a config entry and close the connection to its device.

    Entries with nothing stored for them in ``hass.data`` report success
    without doing any work.  Otherwise the platforms are unloaded, the stored
    coordinator is dropped only if that succeeded, and the device protocol is
    closed in either case.  Returns the platform-unload result.
    """
    domain_data: dict[str, Any] = hass.data[DOMAIN]
    entry_id = entry.entry_id
    if entry_id not in domain_data:
        return True

    device: SmartDevice = domain_data[entry_id].device
    unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
    if unload_ok:
        domain_data.pop(entry_id)
    await device.protocol.close()
    return unload_ok
@callback
def async_entry_is_legacy(entry: ConfigEntry) -> bool:
    """Return True when *entry* is the legacy shared config entry.

    The legacy entry has either no unique id at all or the integration
    domain as its unique id.
    """
    unique_id = entry.unique_id
    if unique_id is None:
        return True
    return unique_id == DOMAIN
def legacy_device_id(device: SmartDevice) -> str:
    """Return the device id in the form used by the original integration.

    Plugs are prefixed with the mac in python-kasa but not in pyHS100, so the
    prefix has to be stripped.  An id without an underscore is returned
    unchanged; otherwise the second underscore-separated field is returned.
    """
    fields = device.device_id.split("_")
    # A single field means there was no underscore, hence no prefix to drop.
    return fields[1] if len(fields) > 1 else fields[0]
"""Config flow for TP-Link."""
from __future__ import annotations
import logging
from typing import Any
from kasa import SmartDevice, SmartDeviceException
from kasa.discover import Discover
from kasa.auth import Auth
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components.dhcp import IP_ADDRESS, MAC_ADDRESS
from homeassistant.const import CONF_DEVICE, CONF_HOST, CONF_MAC, CONF_NAME
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.typing import DiscoveryInfoType
from . import async_discover_devices, async_entry_is_legacy
from .const import DOMAIN
# Module logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Credentials object passed as ``authentication=`` to Discover.discover_single
# when the flow connects to a host.  Both fields are blank here.
# NOTE(review): presumably meant to be filled in locally with real account
# credentials - confirm before relying on it.
authentication = Auth(user="", password="")
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Config flow that creates one config entry per TP-Link device."""

    VERSION = 1

    def __init__(self) -> None:
        """Start with no discovery results cached."""
        # Devices offered in the pick_device step, keyed by formatted MAC.
        self._discovered_devices: dict[str, SmartDevice] = {}
        # Device waiting for confirmation in the discovery_confirm step.
        self._discovered_device: SmartDevice | None = None

    async def async_step_dhcp(self, discovery_info: DiscoveryInfoType) -> FlowResult:
        """Handle a device reported by DHCP discovery."""
        host = discovery_info[IP_ADDRESS]
        mac = discovery_info[MAC_ADDRESS]
        return await self._async_handle_discovery(host, mac)

    async def async_step_discovery(
        self, discovery_info: DiscoveryInfoType
    ) -> FlowResult:
        """Handle a device reported by the integration's own discovery."""
        host = discovery_info[CONF_HOST]
        mac = discovery_info[CONF_MAC]
        return await self._async_handle_discovery(host, mac)

    async def _async_handle_discovery(self, host: str, mac: str) -> FlowResult:
        """Run the checks shared by every discovery source, then confirm.

        Aborts when another flow for the same host is already in progress or
        when the device cannot be reached; the unique-id and host-match
        helpers are also given the chance to stop the flow first.
        """
        await self.async_set_unique_id(dr.format_mac(mac))
        self._abort_if_unique_id_configured(updates={CONF_HOST: host})
        self._async_abort_entries_match({CONF_HOST: host})

        # Record the host in the flow context so other flows can detect it.
        self.context[CONF_HOST] = host
        host_already_in_progress = any(
            flow.get("context", {}).get(CONF_HOST) == host
            for flow in self._async_in_progress()
        )
        if host_already_in_progress:
            return self.async_abort(reason="already_in_progress")

        try:
            found = await self._async_try_connect(host, raise_on_progress=True)
        except SmartDeviceException:
            return self.async_abort(reason="cannot_connect")
        self._discovered_device = found
        return await self.async_step_discovery_confirm()

    async def async_step_discovery_confirm(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Ask the user to confirm the discovered device, then create it."""
        device = self._discovered_device
        assert device is not None

        if user_input is not None:
            return self._async_create_entry_from_device(device)

        self._set_confirm_only()
        placeholders = {
            "name": device.alias,
            "model": device.model,
            "host": device.host,
        }
        self.context["title_placeholders"] = placeholders
        return self.async_show_form(
            step_id="discovery_confirm", description_placeholders=placeholders
        )

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Handle the user-initiated step.

        An empty host hands over to the pick_device step.  A non-empty host
        is contacted directly and, on failure, the form is shown again with
        a ``cannot_connect`` error.
        """
        errors: dict[str, str] = {}
        if user_input is not None:
            host = user_input[CONF_HOST]
            if not host:
                return await self.async_step_pick_device()
            try:
                device = await self._async_try_connect(host, raise_on_progress=False)
            except SmartDeviceException:
                errors["base"] = "cannot_connect"
            else:
                return self._async_create_entry_from_device(device)

        host_schema = vol.Schema({vol.Optional(CONF_HOST, default=""): str})
        return self.async_show_form(
            step_id="user",
            data_schema=host_schema,
            errors=errors,
        )

    async def async_step_pick_device(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Let the user choose one of the discovered, unconfigured devices."""
        if user_input is not None:
            chosen_mac = user_input[CONF_DEVICE]
            await self.async_set_unique_id(chosen_mac, raise_on_progress=False)
            chosen_device = self._discovered_devices[chosen_mac]
            return self._async_create_entry_from_device(chosen_device)

        # Unique ids of the per-device entries that already exist.
        already_configured = {
            existing.unique_id
            for existing in self._async_current_entries()
            if not async_entry_is_legacy(existing)
        }
        self._discovered_devices = await async_discover_devices(self.hass)
        choices = {
            mac: f"{found.alias} {found.model} ({found.host}) {mac}"
            for mac, found in self._discovered_devices.items()
            if mac not in already_configured
        }
        # Nothing left to offer once configured devices are filtered out.
        if not choices:
            return self.async_abort(reason="no_devices_found")

        device_schema = vol.Schema({vol.Required(CONF_DEVICE): vol.In(choices)})
        return self.async_show_form(
            step_id="pick_device",
            data_schema=device_schema,
        )

    async def async_step_migration(self, migration_input: dict[str, Any]) -> FlowResult:
        """Create a per-device entry while migrating from the legacy entry."""
        formatted_mac = dr.format_mac(migration_input[CONF_MAC])
        await self.async_set_unique_id(formatted_mac, raise_on_progress=False)
        self._abort_if_unique_id_configured()
        entry_data = {
            CONF_HOST: migration_input[CONF_HOST],
        }
        return self.async_create_entry(
            title=migration_input[CONF_NAME],
            data=entry_data,
        )

    @callback
    def _async_create_entry_from_device(self, device: SmartDevice) -> FlowResult:
        """Create a config entry titled and addressed after *device*."""
        host = device.host
        self._abort_if_unique_id_configured(updates={CONF_HOST: host})
        return self.async_create_entry(
            title=f"{device.alias} {device.model}",
            data={
                CONF_HOST: host,
            },
        )

    async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
        """Import one host; abort and log an error if it is unreachable."""
        host = user_input[CONF_HOST]
        try:
            device = await self._async_try_connect(host, raise_on_progress=False)
        except SmartDeviceException:
            _LOGGER.error("Failed to import %s: cannot connect", host)
            return self.async_abort(reason="cannot_connect")
        return self._async_create_entry_from_device(device)

    async def _async_try_connect(
        self, host: str, raise_on_progress: bool = True
    ) -> SmartDevice:
        """Contact the device at *host* and adopt its MAC as the unique id.

        The host-match check runs before the device is contacted.  A
        ``SmartDeviceException`` from the connection attempt is not caught
        here; every caller in this class handles it.
        """
        self._async_abort_entries_match({CONF_HOST: host})
        device: SmartDevice = await Discover.discover_single(
            host=host, authentication=authentication
        )
        formatted_mac = dr.format_mac(device.mac)
        await self.async_set_unique_id(
            formatted_mac, raise_on_progress=raise_on_progress
        )
        return device
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment