Coverage for /opt/hostedtoolcache/Python/3.10.17/x64/lib/python3.10/site-packages/vfx_seqtools/decorators.py: 94%

32 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-30 00:30 +0000

1"""Permit options to be shared between multiple Typer scripts.""" 

2 

3# type: ignore 

4 

5# https://github.com/fastapi/typer/discussions/742 

6import functools 

7import inspect 

8import typing 

9from typing import Any, Callable, TypeVar 

10 

11from typing_extensions import Concatenate, ParamSpec 

12 

13RHook = TypeVar("RHook") 

14ParamsHook = ParamSpec("ParamsHook") 

15RSource = TypeVar("RSource") 

16ParamsSource = ParamSpec("ParamsSource") 

17 

18 

class BadHookError(TypeError):
    """Raised by ``attach_hook`` when a hook's non-default arguments collide with the source function's."""

20 

21 

def attach_hook(
    hook_func: Callable[ParamsHook, RHook],
    hook_output_kwarg: typing.Optional[str] = "",
) -> Callable[..., Any]:
    """
    Decorate a source function so that a pre-execution hook function runs before it.

    The hook function's output is passed to the source function as the keyword argument named by
    ``hook_output_kwarg``. When the decorated function is called, keyword arguments matching the
    hook's exposed parameters are routed to the hook. All positional arguments and the remaining
    keyword arguments are passed to the source function. The wrapper's signature combines the source
    function's parameters (excluding the internally managed ``hook_output_kwarg``) with the hook's
    parameters, which are exposed as keyword-only arguments.

    Hook parameters that share a name with a source parameter are not exposed. Such a parameter must
    have a default value, which the hook always receives; the caller's value goes to the source function.

    The motivation for this utility is to allow combining groups of shared options for Typer CLI
    scripts. Typer infers the command-line arguments from a function's type annotations. Sharing
    common groups of arguments between multiple scripts therefore requires merging the parameter
    lists of several functions.

    Usage Examples:

    common.py
    ```
    def logging_options(
        log_level: Annotated[int, typer.Option(help="Log level. Must be between 0 and 9.")],
        log_to_file: Annotated[Optional[pathlib.Path], typer.Option(help="A file to stream logs to.")] = None,
    ):
        if log_level < 0 or log_level > 9:
            raise ValueError("log_level must be between 0 and 9.")
        ...
        # create logger
        ...
        return logger
    ```

    main1.py
    ```
    @attach_hook(common.logging_options, hook_output_kwarg="logger")
    def foo(size: int, logger: Logger):
        ...

    if __name__ == "__main__":
        typer.run(foo)
    ```

    main2.py
    ```
    @attach_hook(common.logging_options, hook_output_kwarg="logger")
    def bar(color: str, logger: Logger):
        ...

    if __name__ == "__main__":
        typer.run(bar)
    ```

    In the example above, both the main1 and main2 CLIs accept the shared logging arguments from the
    command line, in addition to each script's own arguments.

    Args:
        hook_func: The hook function to execute before the source function. All of its required
            arguments must be passable as keyword arguments.
        hook_output_kwarg: The keyword argument name under which the hook's output is passed to the
            source function. If None or empty, defaults to the hook function's name.

    Raises:
        BadHookError: Raised when the returned decorator is applied, if the hook function has an
            argument with no default value whose name collides with an argument of the source function.

    Returns:
        A decorator that chains the hook function with the source function. The wrapper's external
        signature and ``__annotations__`` exclude ``hook_output_kwarg`` and keep the source
        function's return annotation.
    """
    if not hook_output_kwarg:
        hook_output_kwarg = hook_func.__name__

    def decorator(
        source_func: Callable[Concatenate[RHook, ParamsSource], RSource],  # type: ignore
    ) -> Callable[Concatenate[ParamsSource, ParamsHook], RSource]:  # type: ignore
        source_signature = inspect.signature(source_func)
        source_params = source_signature.parameters
        hook_signature_params = inspect.signature(hook_func).parameters

        # A required hook argument sharing a name with a source argument could never be supplied:
        # the caller's value is routed to the source function. Refuse such hooks up front.
        # `is` (not `==`) avoids invoking a default value's custom __eq__ (e.g. numpy arrays).
        dup_params = [
            name
            for name, param in hook_signature_params.items()
            if name in source_params and param.default is inspect.Parameter.empty
        ]
        if dup_params:
            raise BadHookError(
                f"The following non-default arguments of the hook function (`{hook_func.__name__}`) collide with the source func (`{source_func.__name__}`): {dup_params}"
            )

        # Hook parameters that do not collide with the source are exposed as keyword-only arguments.
        hook_params = {
            name: param.replace(kind=inspect.Parameter.KEYWORD_ONLY)
            for name, param in hook_signature_params.items()
            if name not in source_params
        }

        @functools.wraps(source_func)
        def wrapper(*args: Any, **kwargs: Any) -> RSource:
            # Filter kwargs for those accepted by the hook function.
            hook_kwargs = {k: v for k, v in kwargs.items() if k in hook_params}

            # Execute the hook function with its specific kwargs.
            hook_result = hook_func(**hook_kwargs)  # type: ignore

            # The remaining kwargs belong to the source function.
            source_kwargs = {k: v for k, v in kwargs.items() if k not in hook_kwargs}

            # Execute the source function with the original positional args, passing the hook's
            # output as the configured keyword argument.
            # mypy bug: https://github.com/python/mypy/issues/18481
            return source_func(  # type: ignore
                *args,  # type: ignore
                **source_kwargs,  # type: ignore
                **{hook_output_kwarg: hook_result},  # type: ignore
            )  # type: ignore

        # Combine signatures, dropping hook_output_kwarg (it is supplied internally).
        exposed_source_params = [
            param for param in source_params.values() if param.name != hook_output_kwarg
        ]
        # inspect.Signature requires a **kwargs parameter to come last, so the keyword-only hook
        # parameters are inserted before a VAR_KEYWORD source parameter rather than after it.
        var_keyword_params = [
            param
            for param in exposed_source_params
            if param.kind is inspect.Parameter.VAR_KEYWORD
        ]
        combined_params = (
            [
                param
                for param in exposed_source_params
                if param.kind is not inspect.Parameter.VAR_KEYWORD
            ]
            + list(hook_params.values())
            + var_keyword_params
        )
        # mypy bug: https://github.com/python/mypy/issues/12472
        wrapper.__signature__ = source_signature.replace(  # type: ignore
            parameters=combined_params
        )

        # Combine annotations to mirror the combined signature. Only hook hints for exposed hook
        # parameters are added, so the hook's "return" hint and hints of colliding parameters do
        # not overwrite the source function's annotations. hook_output_kwarg is removed.
        annotations = dict(typing.get_type_hints(source_func))
        annotations.update(
            {
                name: hint
                for name, hint in typing.get_type_hints(hook_func).items()
                if name in hook_params
            }
        )
        annotations.pop(hook_output_kwarg, None)
        wrapper.__annotations__ = annotations

        return wrapper  # type: ignore

    return decorator