הראה קוד מקור ל http_async

from __future__ import annotations
import os
import asyncio
from typing import Optional, Any, Mapping
import time
import logging

from resilience import (
    DEFAULT_CIRCUIT_POLICY,
    DEFAULT_RETRY_POLICY,
    CircuitBreakerPolicy,
    RetryPolicy,
    compute_backoff_delay,
    get_circuit_breaker,
    note_request_duration,
    note_retry,
    resolve_labels,
)

# aiohttp is an optional dependency: when it cannot be imported the name is
# bound to None, and code in this module checks `aiohttp is None` before use.
try:
    import aiohttp  # type: ignore
except Exception:  # pragma: no cover
    aiohttp = None  # type: ignore

# Observability hooks are optional as well; without them, emitting an event is
# a no-op so call sites never need to guard the import themselves.
try:
    from observability import emit_event as _emit_event  # type: ignore
except Exception:  # pragma: no cover
    def _emit_event(_event: str, **_fields) -> None:  # type: ignore
        """No-op stand-in used when `observability.emit_event` is unavailable."""
        return None

logger = logging.getLogger(__name__)

# Process-wide shared ClientSession: created lazily by get_session() and reset
# to None by close_session() (or when get_session() decides to rebuild it).
_session: Optional["aiohttp.ClientSession"] = None  # type: ignore[name-defined]


class CircuitOpenError(RuntimeError):
    """Signal that a circuit breaker refused to let a request through.

    Attributes:
        service: the service label the breaker is keyed on.
        endpoint: the endpoint label the breaker is keyed on.
    """

    def __init__(self, service: str, endpoint: str):
        # Keep the labels available to callers that want to branch on them.
        self.service = service
        self.endpoint = endpoint
        super().__init__(f"Circuit open for {service}:{endpoint}")
# Status codes below 500 that are still treated as transient and retried:
# 408 Request Timeout, 425 Too Early, 429 Too Many Requests.
_RETRYABLE_STATUS_EXTRA = {408, 425, 429}


def _should_retry_status(status_code: int) -> bool:
    """Return True when an HTTP status code should trigger a retry.

    Every status >= 500 is retryable, as are the codes listed in
    ``_RETRYABLE_STATUS_EXTRA``. Values that ``int()`` cannot convert are
    treated as non-retryable.
    """
    try:
        code = int(status_code)
    except Exception:
        return False
    if code >= 500:
        return True
    return code in _RETRYABLE_STATUS_EXTRA


def _is_retryable_exception(exc: Exception) -> bool:
    """Return True if *exc* is a transient transport error worth retrying.

    ``asyncio.TimeoutError`` is always retryable; ``aiohttp.ClientError`` (and
    subclasses) are added when aiohttp is importable.
    """
    retryable = (asyncio.TimeoutError,)
    if aiohttp is not None:
        retryable = retryable + (aiohttp.ClientError,)  # type: ignore[attr-defined]
    return isinstance(exc, retryable)


async def _async_sleep_with_backoff(attempt: int, policy: RetryPolicy) -> None:
    """Sleep for the backoff delay computed for *attempt* under *policy*.

    Non-positive delays return immediately. Cancellation is propagated; any
    other error raised while sleeping is swallowed so a retry loop is never
    aborted by the sleep itself.
    """
    delay = compute_backoff_delay(attempt, policy)
    if delay <= 0:
        return
    try:
        await asyncio.sleep(delay)
    except asyncio.CancelledError:
        raise
    except Exception:
        return


def _int_env(name: str, default: int) -> int:
    """Read integer setting *name*.

    Lookup order: the environment variable (if set, non-empty and parseable),
    then attribute *name* on the optional ``config`` module, then *default*.
    """
    try:
        env_val = os.getenv(name)
        if env_val not in (None, ""):
            return int(env_val)
    except Exception:
        pass
    # Fall back to config if available.
    try:
        from config import config  # type: ignore
        return int(getattr(config, name, default))
    except Exception:
        return default


def _float_env(name: str, default: float) -> float:
    """Read float setting *name*.

    Uses the same lookup order as ``_int_env``: the environment variable (if
    set, non-empty and parseable), then attribute *name* on the optional
    ``config`` module, then *default*.
    """
    try:
        env_val = os.getenv(name)
        if env_val not in (None, ""):
            return float(env_val)
    except Exception:
        pass
    # Fall back to config, mirroring _int_env, so float knobs can be set the
    # same way as integer ones.
    try:
        from config import config  # type: ignore
        return float(getattr(config, name, default))
    except Exception:
        return default


def _build_session_kwargs() -> dict[str, Any]:
    """Build keyword arguments (timeout, connector) for ``aiohttp.ClientSession``.

    Settings read via ``_float_env``/``_int_env``:
        AIOHTTP_TIMEOUT_TOTAL (seconds, default 10),
        AIOHTTP_POOL_LIMIT (default 50),
        AIOHTTP_LIMIT_PER_HOST (default 0; values <= 0 mean "no per-host limit").

    Failures constructing either object are ignored and that key is omitted.
    Returns an empty dict when aiohttp is unavailable.
    """
    kwargs: dict[str, Any] = {}
    if aiohttp is None:  # pragma: no cover
        return kwargs
    # Parsed as a float: ClientTimeout accepts fractional seconds, whereas int()
    # would reject e.g. "2.5" and silently fall back to the default.
    total = _float_env("AIOHTTP_TIMEOUT_TOTAL", 10.0)
    limit = _int_env("AIOHTTP_POOL_LIMIT", 50)
    limit_per_host = _int_env("AIOHTTP_LIMIT_PER_HOST", 0)
    try:
        kwargs["timeout"] = aiohttp.ClientTimeout(total=total)  # type: ignore[attr-defined]
    except Exception:
        pass
    try:
        kwargs["connector"] = aiohttp.TCPConnector(  # type: ignore[attr-defined]
            limit=limit,
            # aiohttp's documented "no per-host limit" value is the int 0, not None.
            limit_per_host=max(0, limit_per_host),
            use_dns_cache=True,
            ttl_dns_cache=300,
            enable_cleanup_closed=True,
            force_close=False,
        )
    except Exception:
        pass
    return kwargs


def _normalize_headers(headers: Any) -> Mapping[str, str]:
    """Coerce *headers* into a plain ``dict[str, str]``.

    Accepts any Mapping, or anything ``dict()`` can consume (e.g. a list of
    pairs). Entries whose key or value is None, or cannot be converted with
    ``str()``, are dropped. Returns ``{}`` for None or unconvertible input.
    """
    if headers is None:
        return {}
    if isinstance(headers, Mapping):
        items = headers.items()
    else:
        try:
            items = dict(headers).items()  # type: ignore[arg-type]
        except Exception:
            return {}
    normalized: dict[str, str] = {}
    for key, value in items:
        if key is None or value is None:
            continue
        try:
            skey = str(key)
            svalue = str(value)
        except Exception:
            continue
        normalized[skey] = svalue
    return normalized


def _prepare_headers(headers: Any) -> Any:
    """Merge observability headers into the outgoing request *headers*.

    Calls the optional ``observability.prepare_outgoing_headers`` with the
    normalized headers (or None when there are none).

    Returns:
        - when that call fails or yields nothing: None if *headers* was None,
          otherwise the normalized dict if non-empty, otherwise *headers* as-is;
        - otherwise the merged headers, as a ``CIMultiDict`` when that import
          succeeds, else the merged mapping unchanged.
    """
    base = _normalize_headers(headers)
    try:
        from observability import prepare_outgoing_headers  # type: ignore
        merged = prepare_outgoing_headers(base or None)
    except Exception:
        merged = None
    if merged is None or merged == {}:
        if headers is None:
            return None
        if base:
            return base
        return headers
    try:
        if aiohttp is not None:
            # NOTE(review): confirm aiohttp re-exports CIMultiDict at top level;
            # if it does not, the ImportError is swallowed here and the plain
            # merged mapping is returned instead.
            from aiohttp import CIMultiDict  # type: ignore[attr-defined]
            return CIMultiDict(merged.items())
    except Exception:
        pass
    return merged


def _instrument_session(session: "aiohttp.ClientSession") -> None:  # type: ignore[name-defined]
    """Wrap ``session._request`` so every call passes through ``_prepare_headers``.

    Idempotent: the session is flagged with ``_codebot_ctx_headers`` once
    wrapped. Sessions without a callable ``_request`` are left untouched.
    Presumably ``session.request()`` routes through this attribute; verify
    against the aiohttp version in use.
    """
    if session is None:
        return
    if getattr(session, "_codebot_ctx_headers", False):
        return
    original_request = getattr(session, "_request", None)  # type: ignore[attr-defined]
    if original_request is None or not callable(original_request):
        # Nothing to wrap if the instance has no internal _request
        # (the case for some mocks).
        return

    async def _request(method: str, url: str, **kwargs):  # type: ignore[override]
        # Header preparation is best-effort: on any failure the request goes
        # out with the caller's original headers.
        try:
            prepared = _prepare_headers(kwargs.get("headers"))
            if prepared is not None:
                kwargs["headers"] = prepared
        except Exception:
            pass
        return await original_request(method, url, **kwargs)

    session._request = _request  # type: ignore[assignment]
    setattr(session, "_codebot_ctx_headers", True)


class _ResilientRequestContext:
    """Async context manager that performs one HTTP call with retries and circuit breaking.

    ``__aenter__`` checks the circuit breaker, sends the request (retrying on
    retryable statuses and retryable exceptions with backoff), records
    metrics/events, and returns the response. ``__aexit__`` releases the
    response and never suppresses exceptions.
    """

    def __init__(
        self,
        method: str,
        url: str,
        *,
        session: Optional["aiohttp.ClientSession"],  # type: ignore[name-defined]
        service: Optional[str],
        endpoint: Optional[str],
        retry_policy: Optional[RetryPolicy],
        circuit_policy: Optional[CircuitBreakerPolicy],
        max_attempts_override: Optional[int],
        backoff_base_override: Optional[float],
        backoff_cap_override: Optional[float],
        jitter_override: Optional[float],
        **kwargs,
    ) -> None:
        """Resolve policies, metric labels and the circuit breaker for one request.

        Args:
            method: HTTP method; upper-cased.
            url: target URL; converted with ``str()``.
            session: session to use; the shared ``get_session()`` one when falsy.
            service / endpoint: label hints passed to ``resolve_labels``.
            retry_policy: base retry policy (``DEFAULT_RETRY_POLICY`` when falsy).
            circuit_policy: breaker policy (``DEFAULT_CIRCUIT_POLICY`` when falsy).
            *_override: per-call values replacing fields of the retry policy;
                values that fail int()/float() conversion keep the policy value,
                and max_attempts is clamped to at least 1.
            **kwargs: forwarded unchanged to ``session.request()``.
        """
        self.method = str(method).upper()
        self.url = str(url)
        self._session = session
        self._request_kwargs = dict(kwargs)
        self._service_hint = service
        self._endpoint_hint = endpoint
        # Threshold (ms) above which a request is logged as slow; 0 disables it.
        self._slow_ms = _float_env("HTTP_SLOW_MS", 0.0)
        base_policy = retry_policy or DEFAULT_RETRY_POLICY

        def _maybe_int(value, fallback):
            try:
                return max(1, int(value))
            except Exception:
                return fallback

        def _maybe_float(value, fallback):
            try:
                return float(value)
            except Exception:
                return fallback

        if (
            max_attempts_override is not None
            or backoff_base_override is not None
            or backoff_cap_override is not None
            or jitter_override is not None
        ):
            base_policy = RetryPolicy(
                max_attempts=_maybe_int(max_attempts_override, base_policy.max_attempts),
                backoff_base=_maybe_float(backoff_base_override, base_policy.backoff_base),
                backoff_cap=_maybe_float(backoff_cap_override, base_policy.backoff_cap),
                jitter=_maybe_float(jitter_override, base_policy.jitter),
            )
        self._retry_policy = base_policy
        self._circuit_policy = circuit_policy or DEFAULT_CIRCUIT_POLICY
        (
            self._service_label,
            self._endpoint_label,
            self._display_service,
            self._display_endpoint,
        ) = resolve_labels(self.url, self._service_hint, self._endpoint_hint)
        self._breaker = get_circuit_breaker(
            self._service_label,
            self._endpoint_label,
            display_service=self._display_service,
            display_endpoint=self._display_endpoint,
            policy=self._circuit_policy,
        )
        self._response: Optional["aiohttp.ClientResponse"] = None  # type: ignore[name-defined]
        self._error: Optional[Exception] = None
        self._retries = 0
        self._last_error_signature: Optional[str] = None
        self._last_duration_seconds: float = 0.0

    async def __aenter__(self):
        """Send the request and return the response.

        On a retryable status, the response is released and the request is
        retried until attempts run out; the last such response is then
        returned to the caller. On a retryable exception, the request is
        retried likewise; otherwise the exception is re-raised.

        Raises:
            RuntimeError: aiohttp is not installed.
            CircuitOpenError: the breaker rejected the request up front.
            Exception: the last exception from ``session.request()`` when it is
                non-retryable or attempts are exhausted.
        """
        if aiohttp is None:  # pragma: no cover
            raise RuntimeError("aiohttp is not available in this environment")
        session = self._session or get_session()
        if not self._breaker.allow_request():
            self._breaker.record_skip()
            note_request_duration(self._service_label, self._endpoint_label, "circuit_open", 0.0)
            try:
                _emit_event(
                    "circuit_open_block",
                    severity="warning",
                    service=self._display_service,
                    endpoint=self._display_endpoint,
                )
            except Exception:
                pass
            raise CircuitOpenError(self._display_service, self._display_endpoint)
        max_attempts = max(1, self._retry_policy.max_attempts)
        for attempt in range(1, max_attempts + 1):
            start = time.perf_counter()
            try:
                response = await session.request(self.method, self.url, **self._request_kwargs)
                duration = time.perf_counter() - start
                self._last_duration_seconds = duration
                status_code = int(getattr(response, "status", 0) or 0)
                if self._slow_ms and (duration * 1000.0) > self._slow_ms:
                    try:
                        logger.warning(
                            "slow_http_async",
                            extra={
                                "method": self.method,
                                "url": self.url,
                                "status": status_code,
                                "ms": round(duration * 1000.0, 1),
                            },
                        )
                    except Exception:
                        pass
                if _should_retry_status(status_code):
                    self._breaker.record_failure()
                    note_request_duration(self._service_label, self._endpoint_label, "http_error", duration)
                    if attempt >= max_attempts:
                        # Out of attempts: hand the failing response to the caller.
                        self._response = response
                        self._last_error_signature = "HTTPStatus"
                        try:
                            _emit_event(
                                "external_request_failure",
                                severity="error",
                                service=self._display_service,
                                endpoint=self._display_endpoint,
                                error_signature="HTTPStatus",
                                retries=self._retries,
                            )
                        except Exception:
                            pass
                        break
                    # Release the discarded response before retrying.
                    try:
                        await response.release()
                    except Exception:
                        pass
                    self._retries += 1
                    note_retry(self._service_label, self._endpoint_label)
                    await _async_sleep_with_backoff(attempt, self._retry_policy)
                    continue
                self._breaker.record_success()
                note_request_duration(self._service_label, self._endpoint_label, "success", duration)
                self._retries = attempt - 1
                self._response = response
                return response
            except Exception as exc:
                duration = time.perf_counter() - start
                self._last_duration_seconds = duration
                self._error = exc
                self._last_error_signature = type(exc).__name__
                note_request_duration(self._service_label, self._endpoint_label, "exception", duration)
                self._breaker.record_failure()
                if not _is_retryable_exception(exc) or attempt >= max_attempts:
                    break
                self._retries += 1
                note_retry(self._service_label, self._endpoint_label)
                await _async_sleep_with_backoff(attempt, self._retry_policy)
                # The error is being retried, so it is no longer the outcome.
                self._error = None
                continue
        if self._response is not None:
            return self._response
        if self._error is not None:
            try:
                _emit_event(
                    "external_request_failure",
                    severity="error",
                    service=self._display_service,
                    endpoint=self._display_endpoint,
                    error_signature=self._last_error_signature or type(self._error).__name__,
                    retries=self._retries,
                )
            except Exception:
                pass
            raise self._error
        # NOTE(review): every exit from the loop above sets either _response or
        # _error, so this branch looks unreachable today; it would only fire if
        # a mid-retry circuit check were added. Kept as a defensive fallback.
        try:
            _emit_event(
                "external_request_failure",
                severity="error",
                service=self._display_service,
                endpoint=self._display_endpoint,
                error_signature="CircuitOpen",
                retries=self._retries,
            )
        except Exception:
            pass
        raise CircuitOpenError(self._display_service, self._display_endpoint)

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        """Release the response (best-effort); never suppress exceptions."""
        if self._response is not None:
            try:
                await self._response.release()
            except Exception:
                pass
        return False
def _close_session_on_loop(session: Any, loop: Any, loop_is_closed: bool) -> None:
    """Best-effort close of *session* on the event loop it was bound to.

    Never raises on scheduling failures; the close coroutine is closed
    explicitly if it could not be handed to a loop, avoiding
    "coroutine was never awaited" warnings.
    """
    if loop_is_closed:
        # A closed loop can never run session.close(); creating the coroutine
        # anyway would only leak it un-awaited.
        return
    try:
        is_running = bool(getattr(loop, "is_running", lambda: False)())
    except Exception:
        is_running = False
    closer = session.close()  # type: ignore[call-arg]
    try:
        if is_running:
            # The loop is running but is not this thread's current loop, so it
            # is driven by another thread; loop.create_task() is not
            # thread-safe, run_coroutine_threadsafe() is.
            asyncio.run_coroutine_threadsafe(closer, loop)
        else:
            loop.run_until_complete(closer)
    except Exception:
        try:
            closer.close()
        except Exception:
            pass


def get_session() -> "aiohttp.ClientSession":  # type: ignore[name-defined]
    """Return the shared ``aiohttp.ClientSession``, creating it on demand.

    If a cached, still-open session is bound to an event loop that is closed,
    or to a loop other than the one ``asyncio.get_event_loop()`` returns here,
    the old session is closed best-effort on its own loop and discarded. A new
    session is then built from ``_build_session_kwargs()`` and passed through
    ``_instrument_session()``.

    Raises:
        RuntimeError: aiohttp is not installed.
    """
    global _session
    if aiohttp is None:  # pragma: no cover
        raise RuntimeError("aiohttp is not available in this environment")
    # If a session exists but belongs to another/closed loop, build a new one.
    try:
        current_loop = asyncio.get_event_loop()
    except Exception:
        current_loop = None  # type: ignore[assignment]
    if _session is not None and not getattr(_session, "closed", False):
        # Recover the loop the session (or its connector) was created on; these
        # are private aiohttp attributes, hence the defensive getattr chain.
        session_loop = getattr(_session, "_loop", None)
        if session_loop is None:
            connector = getattr(_session, "_connector", None)
            session_loop = getattr(connector, "_loop", None)
        try:
            loop_is_closed = (
                bool(getattr(session_loop, "is_closed", lambda: False)()) if session_loop else False
            )
        except Exception:
            loop_is_closed = False
        if session_loop is not None and (
            loop_is_closed or (current_loop is not None and session_loop is not current_loop)
        ):
            try:
                _close_session_on_loop(_session, session_loop, loop_is_closed)
            except Exception:
                # Never block the rebuild because closing the old session failed.
                pass
            finally:
                _session = None
    if _session is None or getattr(_session, "closed", False):
        kwargs = _build_session_kwargs()
        _session = aiohttp.ClientSession(**kwargs)
        _instrument_session(_session)
    return _session
def request(method: str, url: str, **kwargs) -> _ResilientRequestContext:
    """Create a resilient request context, to be used with ``async with``.

    The resilience options listed below are removed from ``kwargs`` (each
    defaults to None when absent); everything left over is forwarded to the
    session's ``request()`` call.

    Resilience options:
        session: explicit ``aiohttp.ClientSession``; the shared one when falsy.
        retry_policy, circuit_policy: policies replacing the module defaults.
        service, endpoint: label hints for metrics and the circuit breaker.
        max_attempts, backoff_base, backoff_cap, jitter: per-call overrides
            layered on top of the retry policy.
    """
    option_names = (
        "session",
        "retry_policy",
        "circuit_policy",
        "service",
        "endpoint",
        "max_attempts",
        "backoff_base",
        "backoff_cap",
        "jitter",
    )
    opts = {name: kwargs.pop(name, None) for name in option_names}
    return _ResilientRequestContext(
        method,
        url,
        session=opts["session"],
        service=opts["service"],
        endpoint=opts["endpoint"],
        retry_policy=opts["retry_policy"],
        circuit_policy=opts["circuit_policy"],
        max_attempts_override=opts["max_attempts"],
        backoff_base_override=opts["backoff_base"],
        backoff_cap_override=opts["backoff_cap"],
        jitter_override=opts["jitter"],
        **kwargs,
    )
async def close_session() -> None:
    """Close the shared session if it is still open, then forget it.

    The module-level reference is reset to None even when closing raises.
    """
    global _session
    try:
        current = _session
        if current is None or getattr(current, "closed", False):
            return
        await current.close()
    finally:
        _session = None