Skip to content

Commit

Permalink
Make ChunkAsyncStreamIterator an aiohttp helper (#134843)
Browse files Browse the repository at this point in the history
make ChunkAsyncStreamIterator a generic aiohttp helper
  • Loading branch information
mib1185 authored Jan 6, 2025
1 parent bc22e34 commit acd9597
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 28 deletions.
30 changes: 3 additions & 27 deletions homeassistant/components/cloud/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from collections.abc import AsyncIterator, Callable, Coroutine, Mapping
import hashlib
import logging
from typing import Any, Self
from typing import Any

from aiohttp import ClientError, ClientTimeout, StreamReader
from aiohttp import ClientError, ClientTimeout
from hass_nabucasa import Cloud, CloudError
from hass_nabucasa.cloud_api import (
async_files_delete_file,
Expand All @@ -19,6 +19,7 @@

from homeassistant.components.backup import AgentBackup, BackupAgent, BackupAgentError
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.aiohttp_client import ChunkAsyncStreamIterator
from homeassistant.helpers.dispatcher import async_dispatcher_connect

from .client import CloudClient
Expand Down Expand Up @@ -73,31 +74,6 @@ def handle_event(data: Mapping[str, Any]) -> None:
return unsub


class ChunkAsyncStreamIterator:
"""Async iterator for chunked streams.
Based on aiohttp.streams.ChunkTupleAsyncStreamIterator, but yields
bytes instead of tuple[bytes, bool].
"""

__slots__ = ("_stream",)

def __init__(self, stream: StreamReader) -> None:
"""Initialize."""
self._stream = stream

def __aiter__(self) -> Self:
"""Iterate."""
return self

async def __anext__(self) -> bytes:
"""Yield next chunk."""
rv = await self._stream.readchunk()
if rv == (b"", False):
raise StopAsyncIteration
return rv[0]


class CloudBackupAgent(BackupAgent):
"""Cloud backup agent."""

Expand Down
27 changes: 26 additions & 1 deletion homeassistant/helpers/aiohttp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ssl import SSLContext
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Self

import aiohttp
from aiohttp import web
Expand Down Expand Up @@ -82,6 +82,31 @@ async def json(
return await super().json(*args, loads=loads, **kwargs)


class ChunkAsyncStreamIterator:
"""Async iterator for chunked streams.
Based on aiohttp.streams.ChunkTupleAsyncStreamIterator, but yields
bytes instead of tuple[bytes, bool].
"""

__slots__ = ("_stream",)

def __init__(self, stream: aiohttp.StreamReader) -> None:
"""Initialize."""
self._stream = stream

def __aiter__(self) -> Self:
"""Iterate."""
return self

async def __anext__(self) -> bytes:
"""Yield next chunk."""
rv = await self._stream.readchunk()
if rv == (b"", False):
raise StopAsyncIteration
return rv[0]


@callback
@bind_hass
def async_get_clientsession(
Expand Down

0 comments on commit acd9597

Please sign in to comment.