Source code for neofoam.framework.dependency_resolver

# SPDX-License-Identifier: GPL-3.0-or-later
# SPDX-FileCopyrightText: 2025 NeoFOAM authors

"""
Runtime dependency resolution for Depends() markers.

Handles dependency injection during solver/model execution.
"""

import inspect
from typing import Annotated, Any, Callable, Optional, get_args, get_origin

from .context import Context
from .initialization.depends import Depends


class DependencyResolver:
    """Resolve ``Depends()`` markers and context-backed arguments.

    Resolved dependencies are memoised per scope (``"time_step"``,
    ``"iteration"``, ``"operation"``). Entries stay cached until
    :meth:`clear_scope` or :meth:`clear_all` is called.
    """

    def __init__(self) -> None:
        # scope name -> {cache key -> resolved value}
        self._cache: dict[str, dict[str, Any]] = {
            "time_step": {},
            "iteration": {},
            "operation": {},
        }

    def resolve_arguments(
        self,
        func: Callable[..., Any],
        ctx: Optional[Context] = None,
        **provided_kwargs: Any,
    ) -> dict[str, Any]:
        """Resolve function arguments from Depends markers and Context.

        For every parameter of *func* that is not already supplied, the first
        matching rule wins:

        1. ``self`` / ``cls`` are skipped (binding them is the caller's job).
        2. ``Annotated[..., Depends(...)]`` is resolved through the marker.
        3. An annotation that is exactly ``Context`` receives *ctx*.
        4. ``Annotated[..., "<name>"]`` reads ``ctx.<name>.get(param_name)``,
           e.g. the markers ``"fields"`` and ``"models"``.
        5. Otherwise, a field in ``ctx.fields`` with the parameter's name.

        Parameters that match no rule are left out of the result.

        Args:
            func: Callable whose signature is inspected.
            ctx: Context to resolve against; may be ``None``.
            **provided_kwargs: Explicit values that take precedence over any
                resolution.

        Returns:
            A new dict of keyword arguments (a copy of *provided_kwargs* plus
            everything that could be resolved).

        Raises:
            ValueError: A non-optional ``Depends`` resolved to ``None``, or
                the marker carries an unsupported dependency type.
        """
        sig = inspect.signature(func)
        kwargs = provided_kwargs.copy()

        for param_name, param in sig.parameters.items():
            if param_name in kwargs or param_name in ("self", "cls"):
                continue

            annotation = param.annotation

            depends = self._extract_depends(annotation)
            if depends is not None:
                value = self._resolve_dependency(depends, ctx)
                if value is None and not depends.optional:
                    raise ValueError(f"Required dependency '{param_name}' not found")
                kwargs[param_name] = value
                continue

            if annotation is Context:
                kwargs[param_name] = ctx
                continue

            # Explicit None checks (as in _resolve_path) so that a context
            # object which happens to be falsy is still consulted.
            marker = self._extract_marker(annotation)
            if marker is not None and ctx is not None:
                kwargs[param_name] = getattr(ctx, marker, {}).get(param_name)
                continue

            if ctx is not None and param_name in ctx.fields:
                kwargs[param_name] = ctx.fields[param_name]

        return kwargs

    @staticmethod
    def _extract_marker(annotation: Any) -> Optional[str]:
        """Return the string marker of ``Annotated[T, "<marker>"]``, if any."""
        if get_origin(annotation) is Annotated:
            args = get_args(annotation)
            if len(args) > 1 and isinstance(args[1], str):
                return args[1]
        return None

    def _extract_depends(self, annotation: Any) -> Optional[Depends]:
        """Return the first ``Depends`` in an ``Annotated`` annotation, if any."""
        if get_origin(annotation) is Annotated:
            for arg in get_args(annotation)[1:]:
                if isinstance(arg, Depends):
                    return arg
        return None

    def _resolve_dependency(self, depends: Depends, ctx: Optional[Context]) -> Any:
        """Resolve one ``Depends`` marker, consulting the scope cache first.

        A string dependency is treated as a context path, a callable as a
        provider whose own arguments are resolved recursively.

        Raises:
            ValueError: The dependency is neither a string nor callable.
            KeyError: Caching is enabled and the marker's scope is unknown.
        """
        dependency = depends.dependency
        cache_key = str(dependency)
        scope = getattr(depends, "scope", "time_step")
        use_cache = getattr(depends, "cache", True)

        if use_cache and cache_key in self._cache[scope]:
            return self._cache[scope][cache_key]

        if isinstance(dependency, str):
            value = self._resolve_path(dependency, ctx)
        elif callable(dependency):
            value = self._resolve_callable(dependency, ctx)
        else:
            raise ValueError(f"Invalid dependency type: {type(dependency)}")

        # None means "not found" here (see resolve_arguments). Caching it
        # would keep reporting the dependency as missing for the rest of the
        # scope even after it has been added to the context.
        if use_cache and value is not None:
            self._cache[scope][cache_key] = value
        return value

    def _resolve_path(self, path: str, ctx: Optional[Context]) -> Any:
        """Look up a dotted *path* on *ctx*; return ``None`` when not found.

        ``"fields.<name>"`` and ``"models.<name>"`` read the matching entry
        of ``ctx.fields`` / ``ctx.models``. Any other path is followed as a
        chain of attributes starting at *ctx*.
        """
        if ctx is None:
            return None

        parts = path.split(".")
        if parts[0] == "fields":
            return ctx.fields.get(parts[1]) if len(parts) > 1 else None
        if parts[0] == "models":
            return ctx.models.get(parts[1]) if len(parts) > 1 else None

        # Walk the attributes one by one: getattr() with the undivided dotted
        # string can never match, so nested paths used to resolve to None.
        value: Any = ctx
        for part in parts:
            value = getattr(value, part, None)
            if value is None:
                return None
        return value

    def _resolve_callable(
        self, provider: Callable[..., Any], ctx: Optional[Context]
    ) -> Any:
        """Call *provider* with its own arguments resolved against *ctx*."""
        kwargs = self.resolve_arguments(provider, ctx)
        return provider(**kwargs)

    def clear_scope(self, scope: str) -> None:
        """Drop every cached value of *scope*; unknown scopes are ignored."""
        if scope in self._cache:
            self._cache[scope].clear()

    def clear_all(self) -> None:
        """Drop the cached values of every scope."""
        for cached in self._cache.values():
            cached.clear()
def wrap_with_dependency_resolution(
    func: Callable[..., Any],
    instance: Any,
    dependency_resolver: DependencyResolver,
) -> Callable[[Context], Any]:
    """Wrap *func* so it can be called with just a Context.

    Dependencies are resolved via *dependency_resolver*. If the function
    signature includes a ``self`` parameter it is bound to *instance*.
    If the return value is a ``FieldUpdates`` the context is updated
    automatically.

    This is the canonical implementation shared by ``SolverSpec`` and
    ``ModelSpec`` — avoids duplicating the same wrapper in every factory.

    Args:
        func: Callable to wrap.
        instance: Object passed as ``self`` when *func* declares such a
            parameter and the resolver did not supply it.
        dependency_resolver: Resolver used to build the keyword arguments on
            every call.

    Returns:
        A callable taking only the context. It returns ``None`` after merging
        a ``FieldUpdates`` result into ``ctx.fields``, and the unchanged
        result of *func* otherwise.
    """
    from functools import wraps

    # The signature does not change between calls, so decide once whether
    # *instance* has to be bound instead of re-inspecting on every call.
    try:
        needs_self = "self" in inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Not introspectable: keep wrapping cheap and let the resolver report
        # the problem when the wrapper is actually called.
        needs_self = False

    # NOTE: wraps() sets __wrapped__, so inspect.signature(wrapper) reports
    # the signature of *func* rather than ``(ctx)``.
    @wraps(func)
    def wrapper(ctx: Context) -> Any:
        kwargs = dependency_resolver.resolve_arguments(func, ctx)
        if needs_self and "self" not in kwargs:
            kwargs["self"] = instance

        result = func(**kwargs)

        # Imported at call time, exactly as before.
        from .context import FieldUpdates

        if isinstance(result, FieldUpdates):
            ctx.fields.update(result)
            return None
        return result

    return wrapper