93 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			93 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """API for traversing the AST nodes. Implemented by the compiler and
 | |
| meta introspection.
 | |
| """
 | |
| 
 | |
| import typing as t
 | |
| 
 | |
| from .nodes import Node
 | |
| 
 | |
if t.TYPE_CHECKING:
    # Evaluated only by static type checkers: at runtime TYPE_CHECKING is
    # False, so neither ``te`` nor ``VisitCallable`` exists (which is why
    # ``NodeVisitor.get_visitor`` spells its return annotation as a string).
    import typing_extensions as te

    class VisitCallable(te.Protocol):
        """Call signature of the visitor methods returned by ``get_visitor``."""

        def __call__(
            self,
            node: Node,
            *args: t.Any,
            **kwargs: t.Any,
        ) -> t.Any:
            ...
 | |
| 
 | |
| 
 | |
class NodeVisitor:
    """Walk an abstract syntax tree and call a visitor method for each node.

    By default the method used for a node is named ``'visit_'`` followed by
    the node's class name, so a `TryFinally` node is handled by
    `visit_TryFinally`.  Whatever a visitor method returns is passed back to
    the caller of `visit`.  Override `get_visitor` to change how the method
    is chosen.  When no visitor method is found (`get_visitor` returns
    `None`), `generic_visit` is used instead.
    """

    def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
        """Return this visitor's ``visit_<ClassName>`` method for *node*.

        ``None`` is returned when no such method exists; `visit` then falls
        back to `generic_visit`.
        """
        method_name = "visit_" + type(node).__name__
        return getattr(self, method_name, None)

    def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Visit *node* and return whatever its visitor method returns.

        Extra positional and keyword arguments are forwarded unchanged to
        the chosen visitor (or to `generic_visit`).
        """
        handler = self.get_visitor(node)
        if handler is None:
            handler = self.generic_visit
        return handler(node, *args, **kwargs)

    def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Fallback used when *node* has no dedicated visitor method.

        Visits each child of *node* in turn (forwarding the extra arguments)
        and discards the results, so the return value is always ``None``.
        """
        for child in node.iter_child_nodes():
            self.visit(child, *args, **kwargs)
 | |
| 
 | |
| 
 | |
class NodeTransformer(NodeVisitor):
    """Walks the abstract syntax tree and allows modifications of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor functions to replace or remove the old node.  If the return
    value of the visitor function is `None` the node will be removed
    from the previous location, otherwise it's replaced with the return
    value.  The return value may be the original node, in which case no
    replacement takes place.  For nodes stored in a list field, a visitor
    may also return an iterable of nodes, which are spliced into the list
    in place of the original node.
    """

    def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
        """Visit every child of *node* and write the visitors' results back.

        List fields are rewritten in place (the same list object is kept):

        - A ``None`` result drops the child.
        - A node result replaces the child.
        - Any other result is treated as an iterable of replacement nodes and
          spliced in.
        - Non-node items already in the list are kept unchanged.

        A single-node field is set to the visitor's result, or deleted from
        *node* with ``delattr`` when the visitor returns ``None``.

        :return: *node* itself, after its fields have been updated.
        """
        for field, old_value in node.iter_fields():
            if isinstance(old_value, list):
                new_values: t.List[t.Any] = []
                for value in old_value:
                    if isinstance(value, Node):
                        value = self.visit(value, *args, **kwargs)
                        if value is None:
                            continue
                        if not isinstance(value, Node):
                            # The visitor returned several nodes; splice them in.
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                # Slice-assign so every other reference to this list object
                # observes the rewritten contents.
                old_value[:] = new_values
            elif isinstance(old_value, Node):
                new_node = self.visit(old_value, *args, **kwargs)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node

    def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
        """Visit *node* and always return the result as a list of nodes.

        A transformer's visitor may return ``None``, a single node, or
        several nodes (the same results `generic_visit` accepts).  This
        method normalizes them:

        - ``None`` (the node was removed) becomes an empty list.
        - A list is returned unchanged (the same object).
        - A single node is wrapped in a one-element list.
        - Any other iterable of nodes is converted to a list.
        """
        # NOTE(review): with the default get_visitor, which looks up
        # ``visit_<type name>``, calling visit() on a plain Python ``list``
        # dispatches to this method and recurses without end.  Only pass
        # nodes here.
        rv = self.visit(node, *args, **kwargs)

        if rv is None:
            return []

        if isinstance(rv, list):
            return rv

        if isinstance(rv, Node):
            return [rv]

        return list(rv)
 |