85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
from DataVaultGenerator.Dag import DagNode
|
|
from DataVaultGenerator.Components import ErrorCollection, GeneratorEntity
|
|
|
|
|
|
class SubDag(GeneratorEntity):
    """A subset of the model's DAG, selected by walking from entrypoint nodes.

    The walk direction is chosen by ``self.subtype`` (``'forward'`` or
    ``'backward'``); any other subtype yields an empty node list.
    """

    def __init__(self, model, filename, definition: dict = None):
        """Create a SubDag from its parsed definition.

        Args:
            model: Owning model; the methods below use ``model.get_entity()``
                and ``model.dag``.
            filename: File the definition came from; used in validation
                messages.
            definition: Mapping with optional ``'entrypoints'``,
                ``'excludes'``, ``'key'`` and ``'name'`` entries. ``None`` is
                treated as an empty mapping for the lookups done here.
        """
        GeneratorEntity.__init__(self, model, filename, definition)

        # The signature defaults `definition` to None; without this, the
        # .get() calls below would raise AttributeError.
        if definition is None:
            definition = {}

        # `or []` also covers keys that are present but null-valued, which
        # would otherwise break iteration in validate().
        self.entrypoints = definition.get('entrypoints') or []
        self.key = definition.get('key', definition.get('name'))
        self.excludes = definition.get('excludes') or []
        self.tree = []

    def _report_missing_entities(self, errors, names, label):
        """Add a validation error to *errors* for each name the model lacks."""
        location = (self.filename, "SubDag", f"<{self.name}>")
        for ref in names:
            if self.model.get_entity(ref) is None:
                errors.add("VALIDATION ERROR", location,
                           f'{label} <{ref}> not found')

    def validate(self):
        """Check that every entrypoint and exclude names a known entity.

        Returns:
            An ErrorCollection holding one entry per unknown reference.
            Entrypoint errors come first, then exclude errors.
        """
        errors = ErrorCollection()
        self._report_missing_entities(errors, self.entrypoints, 'Entrypoint')
        self._report_missing_entities(errors, self.excludes, 'Exclude')
        return errors

    def get_entrypoints_nodes(self):
        """Return the DAG nodes the walk starts from.

        These are the configured entrypoints if any exist; otherwise the
        DAG's roots.
        """
        dag = self.model.dag
        if self.entrypoints:
            return [dag.get_node(name) for name in self.entrypoints]
        return list(dag.get_roots())

    def get_tree(self):
        """Alias for get_nodes()."""
        return self.get_nodes()

    def get_nodes(self):
        """Walk the DAG from the entrypoints and return the deduplicated nodes.

        The DAG is reset first. The result is also stored in ``self.tree``.
        For an unrecognised subtype, [] is returned and ``self.tree`` is left
        untouched.
        """
        dag = self.model.dag
        dag.reset()

        if self.subtype == 'forward':
            collected = []
            for node in self.get_entrypoints_nodes():
                collected.extend(dag.get_forward_tree(node, excludes=self.excludes))
            self.tree = self.dedup_tree(collected)
            return self.tree

        if self.subtype == 'backward':
            collected = []
            for node in self.get_entrypoints_nodes():
                # NOTE(review): excludes are not passed here, so they only
                # affect 'forward' SubDags -- confirm this is intended.
                collected.extend(dag.get_backward_tree(node))
            collected = dag.reverse_level(collected)
            self.tree = self.dedup_tree(collected)
            return self.tree

        return []

    def dedup_tree(self, tree: list) -> list:
        """Collapse nodes sharing a ``name``, keeping the one with the highest level.

        Output order follows the first occurrence of each name. On a level
        tie, the earlier node is kept.
        """
        deduped = {}
        for node in tree:
            current = deduped.get(node.name)
            if current is None or current.level < node.level:
                deduped[node.name] = node
        return list(deduped.values())

    def get_leveldict(self, nodes: list) -> dict:
        """Group *nodes* by their ``level`` attribute.

        Returns:
            A dict mapping each level to the list of nodes at that level, in
            input order.
        """
        by_level = {}
        for node in nodes:
            by_level.setdefault(node.level, []).append(node)
        return by_level
|