"""Source code for ``web_poet._layouts``: layout switching for page objects."""

from __future__ import annotations

import inspect
import sys
import types
import typing
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Iterable

from itemadapter import ItemAdapter

from web_poet.fields import (
    _FIELDS_INFO_ATTRIBUTE_READ,
    FieldInfo,
    field,
    get_fields_dict,
)
from web_poet.pages import ItemPage, get_item_cls
from web_poet.utils import cached_method, ensure_awaitable


def layout_switch(
    cls: type[ItemPage] | None = None,
    *,
    switch_method: str = "get_layout",
    layouts: typing.Iterable[type[ItemPage]] | None = None,
):
    """Decorate a page object class to expose fields from its selected layout.

    The decorated class must define a method named by *switch_method*
    (``"get_layout"`` by default). The method can be synchronous or
    asynchronous and must return an :class:`~web_poet.pages.ItemPage`
    instance.

    By default, forwarded fields are inferred from the output item type field
    names. This keeps forwarding aligned with the declared item schema. In
    that mode, layout classes are discovered from the
    :class:`~web_poet.pages.ItemPage` subclasses referenced in the decorated
    class's type hints; they are used to decide whether each forwarded field
    must be asynchronous.

    For output item types that do not expose field names (for example, plain
    :class:`dict`), pass ``layouts`` explicitly. In that case, forwarded
    fields are the union of fields defined across the provided layout page
    object classes.

    If the decorated class already defines a field with the same name,
    :func:`layout_switch` gives priority to the selected layout field and
    falls back to the decorated class field when the selected layout does not
    define that field.

    The decorator can be used bare (``@layout_switch``) or with keyword
    arguments (``@layout_switch(switch_method=...)``). In the latter form the
    returned decorator may be applied to several classes.
    """
    # Materialize explicit layouts once. They are iterated several times
    # (once to collect field names, then once per forwarded field), which
    # would silently exhaust a one-shot iterable such as a generator.
    explicit_layouts = tuple(layouts) if layouts is not None else ()
    method_tag = switch_method.replace(".", "_")

    def wrap(page_cls: type[ItemPage]) -> type[ItemPage]:
        # Per-class state stays local. Rebinding the enclosing ``layouts``
        # (via ``nonlocal``) would leak the layouts discovered for one class
        # into the next class decorated by the same decorator instance.
        if explicit_layouts:
            layout_classes = explicit_layouts
            fields_to_forward = _layout_fields(layout_classes)
        else:
            fields_to_forward = _item_fields_from_page_cls(page_cls)
            layout_classes = _discover_layout_classes(page_cls)

        helper_names = _install_cached_layout_helpers(page_cls, switch_method)
        page_field_names = set(get_fields_dict(page_cls))
        missing = object()

        for field_name in fields_to_forward:
            fallback_attr_name = None
            fallback_attr = None
            if field_name in page_field_names:
                fallback_attr_name = (
                    f"_layout_switch_fallback_field__{method_tag}__{field_name}"
                )
                if not hasattr(page_cls, fallback_attr_name):
                    # Copy the raw class attribute (e.g. the field descriptor,
                    # not its evaluated value). It is about to be shadowed by
                    # the forwarding field, so it stays reachable under a
                    # private name.
                    original_attr = inspect.getattr_static(
                        page_cls, field_name, missing
                    )
                    if original_attr is not missing:
                        setattr(page_cls, fallback_attr_name, original_attr)
                fallback_attr = getattr(page_cls, fallback_attr_name, missing)
                if fallback_attr is missing:
                    # Nothing to fall back to. Let the forwarder raise its
                    # explicit "no fallback field" error instead of an
                    # AttributeError about the private attribute name.
                    fallback_attr_name = None
                    fallback_attr = None

            forward_async = _should_forward_async(
                page_cls,
                layout_classes,
                field_name,
                switch_method,
                fallback_attr=fallback_attr,
            )
            descriptor = _build_forwarding_field(
                field_name,
                async_helper_name=helper_names["async"],
                sync_helper_name=helper_names["sync"],
                forward_async=forward_async,
                fallback_attr_name=fallback_attr_name,
            )
            setattr(page_cls, field_name, descriptor)
            _register_field(page_cls, field_name)
        return page_cls

    if cls is None:
        return wrap
    return wrap(cls)
def _item_fields_from_page_cls(page_cls: type[ItemPage]) -> list[str]: item_cls = get_item_cls(page_cls, default=dict) item_field_names = ItemAdapter.get_field_names_from_class(item_cls) if item_field_names is None: raise ValueError( "The output item type does not expose field names. " "Pass explicit layout classes via layout_switch(layouts=[...])." ) return list(item_field_names) def _discover_layout_classes(page_cls: type[ItemPage]) -> Iterable[type[ItemPage]]: hints = _get_type_hints(page_cls) classes: list[type[ItemPage]] = [] for annotation in hints.values(): for layout_cls in _get_item_page_classes(annotation): if layout_cls not in classes: classes.append(layout_cls) return classes def _get_type_hints(page_cls: type) -> dict[str, Any]: module = sys.modules.get(page_cls.__module__) globalns = vars(module) if module else None try: return typing.get_type_hints(page_cls, globalns=globalns, include_extras=True) except Exception: return dict(getattr(page_cls, "__annotations__", {})) def _get_item_page_classes(annotation: Any) -> Iterable[type[ItemPage]]: if isinstance(annotation, str): return [] origin = typing.get_origin(annotation) if origin is typing.Annotated: return _get_item_page_classes(typing.get_args(annotation)[0]) if origin in {typing.Union, types.UnionType}: result: list[type[ItemPage]] = [] for arg in typing.get_args(annotation): result.extend(_get_item_page_classes(arg)) return result if isinstance(annotation, type) and issubclass(annotation, ItemPage): return [annotation] return [] def _layout_fields(layout_classes: Iterable[type[ItemPage]]) -> list[str]: names: list[str] = [] for layout_cls in layout_classes: for field_name in get_fields_dict(layout_cls): if field_name not in names: names.append(field_name) return names def _install_cached_layout_helpers( page_cls: type[ItemPage], switch_method: str ) -> dict[str, str]: method_tag = switch_method.replace(".", "_") sync_name = f"_layout_switch_cached_layout_sync__{method_tag}" async_name = 
f"_layout_switch_cached_layout_async__{method_tag}" if not hasattr(page_cls, sync_name): @cached_method def _cached_layout_sync(self): return getattr(self, switch_method)() setattr(page_cls, sync_name, _cached_layout_sync) if not hasattr(page_cls, async_name): @cached_method async def _cached_layout_async(self): return await ensure_awaitable(getattr(self, switch_method)()) setattr(page_cls, async_name, _cached_layout_async) return {"sync": sync_name, "async": async_name} def _should_forward_async( page_cls: type[ItemPage], layout_classes: Iterable[type[ItemPage]], field_name: str, switch_method: str, fallback_attr: Any, ) -> bool: if inspect.iscoroutinefunction(getattr(page_cls, switch_method)): return True for layout_cls in layout_classes: attr = getattr(layout_cls, field_name, None) original_method = getattr(attr, "original_method", None) if original_method is not None and inspect.iscoroutinefunction(original_method): return True fallback_original_method = getattr(fallback_attr, "original_method", None) return bool( fallback_original_method is not None and inspect.iscoroutinefunction(fallback_original_method) ) def _build_forwarding_field( field_name: str, *, async_helper_name: str, sync_helper_name: str, forward_async: bool, fallback_attr_name: str | None, ): def _resolve_field_target(page, layout): layout_fields = get_fields_dict(layout) if field_name in layout_fields: return layout, field_name if fallback_attr_name is not None: return page, fallback_attr_name raise AttributeError( f"Selected layout {layout.__class__.__name__} does not define field " f"'{field_name}' and no fallback field is defined in " f"{page.__class__.__name__}" ) if forward_async: async def forwarded(self): layout = await getattr(self, async_helper_name)() target, attr_name = _resolve_field_target(self, layout) return await ensure_awaitable(getattr(target, attr_name)) else: def forwarded(self): layout = getattr(self, sync_helper_name)() target, attr_name = _resolve_field_target(self, 
layout) return getattr(target, attr_name) forwarded.__name__ = field_name return field(forwarded) def _register_field(page_cls: type[ItemPage], field_name: str) -> None: fields = dict(get_fields_dict(page_cls)) if field_name in fields: return fields[field_name] = FieldInfo(name=field_name) setattr(page_cls, _FIELDS_INFO_ATTRIBUTE_READ, fields)