Coverage for /opt/hostedtoolcache/Python/3.10.17/x64/lib/python3.10/site-packages/vfx_seqtools/decorators.py: 94%
32 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-30 00:30 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-30 00:30 +0000
1"""Permit options to be shared between multiple Typer scripts."""
3# type: ignore
5# https://github.com/fastapi/typer/discussions/742
6import functools
7import inspect
8import typing
9from typing import Any, Callable, TypeVar
11from typing_extensions import Concatenate, ParamSpec
13RHook = TypeVar("RHook")
14ParamsHook = ParamSpec("ParamsHook")
15RSource = TypeVar("RSource")
16ParamsSource = ParamSpec("ParamsSource")
class BadHookError(TypeError):
    """Raised when a hook function cannot be attached to a source function.

    ``attach_hook`` raises it when a hook argument without a default value has the same
    name as a parameter of the decorated source function.
    """
def attach_hook(
    hook_func: Callable[ParamsHook, RHook], hook_output_kwarg: str = ""
) -> Callable[..., Any]:
    """
    Decorate a source function so that a hook function is executed before it.

    The hook function's output is passed to the source function as a specified keyword
    argument. Keyword arguments whose names belong to the hook's own (hook-only) parameters
    are routed to the hook; all other arguments are passed to the source function. The
    wrapper's signature and annotations are updated to the combined list of arguments,
    excluding the internally managed ``hook_output_kwarg``.

    A hook parameter that has a default value and shares its name with a source parameter is
    not exposed on the wrapper: the value is routed to the source function and the hook is
    called with its default.

    The motivation for this utility is to allow combining groups of shared options for Typer
    CLI scripts. Typer infers command-line arguments from a function's type annotations, so
    sharing common groups of arguments between multiple scripts requires merging the
    parameter lists of several functions.

    Usage Examples:

        common.py
        ```
        def logging_options(
            log_level: Annotated[int, typer.Option(help="Log level. Must be between 0 and 9.")],
            log_to_file: Annotated[Optional[pathlib.Path], typer.Option(help="A file to stream logs to.")] = None,
        ):
            if log_level < 0 or log_level > 9:
                raise ValueError("log_level must be between 0 and 9.")
            ...
            # create logger
            ...
            return logger
        ```

        main1.py
        ```
        @attach_hook(common.logging_options, hook_output_kwarg="logger")
        def foo(size: int, logger: Logger):
            ...

        if __name__ == "__main__":
            typer.run(foo)
        ```

        main2.py
        ```
        @attach_hook(common.logging_options, hook_output_kwarg="logger")
        def bar(color: str, logger: Logger):
            ...

        if __name__ == "__main__":
            typer.run(bar)
        ```

        In the example above, both the main1 and main2 CLIs let users specify the shared
        logging arguments on the command line, in addition to each script's own arguments.

    Args:
        hook_func: The hook function to execute before the source function. All required
            arguments must be passable as keyword arguments.
        hook_output_kwarg: The keyword argument name under which the hook's output is passed
            to the source function. If empty (the default) or None, defaults to the hook
            function's name.

    Raises:
        BadHookError: When the returned decorator is applied, if the hook function has an
            argument with no default value whose name collides with a parameter of the
            source function.

    Returns:
        A decorator that chains the hook function with the source function, excluding the
        hook_output_kwarg from the wrapper's external signature.
    """
    if not hook_output_kwarg:
        hook_output_kwarg = hook_func.__name__

    def decorator(
        source_func: Callable[Concatenate[RHook, ParamsSource], RSource],  # type: ignore
    ) -> Callable[Concatenate[ParamsSource, ParamsHook], RSource]:  # type: ignore
        source_signature = inspect.signature(source_func)
        source_params = source_signature.parameters
        hook_all_params = inspect.signature(hook_func).parameters

        # Names shared with the source are routed to the source, never to the hook, so a
        # required hook argument with such a name could never be supplied.
        dup_params = [
            name
            for name, param in hook_all_params.items()
            if name in source_params and param.default is inspect.Parameter.empty
        ]
        if dup_params:
            raise BadHookError(
                f"The following non-default arguments of the hook function (`{hook_func.__name__}`) collide with the source func (`{source_func.__name__}`): {dup_params}"
            )

        # Hook-only parameters become keyword-only on the wrapper so they never interfere
        # with the source function's positional arguments.
        hook_params = {
            name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY)
            for name, param in hook_all_params.items()
            if name not in source_params
        }

        @functools.wraps(source_func)
        def wrapper(*args: list, **kwargs: dict) -> RSource:
            # Route keyword arguments accepted only by the hook to the hook.
            hook_kwargs = {k: v for k, v in kwargs.items() if k in hook_params}
            hook_result = hook_func(**hook_kwargs)  # type: ignore

            # Everything else (plus all positional args) goes to the source function, along
            # with the hook's output under hook_output_kwarg.
            source_kwargs = {k: v for k, v in kwargs.items() if k not in hook_kwargs}
            # mypy bug: https://github.com/python/mypy/issues/18481
            return source_func(  # type: ignore
                *args,  # type: ignore
                **source_kwargs,  # type: ignore
                **{hook_output_kwarg: hook_result},  # type: ignore
            )  # type: ignore

        # Combine signatures, dropping hook_output_kwarg (it is supplied internally).
        # Keyword-only hook parameters must precede a `**kwargs` parameter of the source,
        # otherwise inspect.Signature rejects the parameter order.
        source_visible = [
            param for param in source_params.values() if param.name != hook_output_kwarg
        ]
        var_keyword = [
            param for param in source_visible if param.kind is inspect.Parameter.VAR_KEYWORD
        ]
        leading = [
            param
            for param in source_visible
            if param.kind is not inspect.Parameter.VAR_KEYWORD
        ]
        combined_params = leading + list(hook_params.values()) + var_keyword
        # mypy bug: https://github.com/python/mypy/issues/12472
        wrapper.__signature__ = source_signature.replace(  # type: ignore
            parameters=combined_params
        )

        # Combine annotations consistently with the signature. Only hook parameters exposed
        # on the wrapper contribute hints, so a defaulted hook parameter that shares a name
        # with a source parameter cannot override the source's annotation, and the hook's
        # return type never replaces the source's.
        annotations = dict(typing.get_type_hints(source_func))
        hook_hints = typing.get_type_hints(hook_func)
        annotations.update(
            {name: hint for name, hint in hook_hints.items() if name in hook_params}
        )
        annotations.pop(hook_output_kwarg, None)
        wrapper.__annotations__ = annotations

        return wrapper  # type: ignore

    return decorator