This repository has been archived on 2026-03-20. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
datavaultgenerator-1.1.5/DataVaultGenerator/Entities/SubDag.py

85 lines
2.7 KiB
Python

from DataVaultGenerator.Dag import DagNode
from DataVaultGenerator.Components import ErrorCollection, GeneratorEntity
class SubDag(GeneratorEntity):
    """A named subset of the model's DAG, selected by walking from entrypoints.

    Reads these keys from the definition dict:

    * ``entrypoints`` -- entity names the walk starts from. When empty, the
      nodes returned by ``self.model.dag.get_roots()`` are used instead.
    * ``key`` -- identifier of this SubDag. Falls back to ``name``.
    * ``excludes`` -- entity names passed as ``excludes`` to the forward walk.

    The walk direction comes from ``self.subtype`` (``'forward'`` or
    ``'backward'``). That attribute is not set in this class. Presumably
    ``GeneratorEntity`` provides it -- TODO confirm.
    """

    def __init__(self, model, filename, definition: dict = None):
        GeneratorEntity.__init__(self, model, filename, definition)
        # ``definition`` defaults to None. Treat that as an empty definition
        # instead of failing on ``None.get`` below.
        if definition is None:
            definition = {}
        self.entrypoints = definition.get('entrypoints', [])
        self.key = definition.get('key', definition.get('name'))
        self.excludes = definition.get('excludes', [])
        # Result of the last successful get_nodes() call; empty until then.
        self.tree = []

    def validate(self):
        """Check that every entrypoint and exclude names a model entity.

        Returns:
            ErrorCollection: ``errors.add`` is called once per name for which
            ``self.model.get_entity`` returns None. Entrypoints are checked
            first, then excludes.
        """
        errors = ErrorCollection()
        # Validating entity references: (label used in the message, names).
        checks = (('Entrypoint', self.entrypoints), ('Exclude', self.excludes))
        for label, references in checks:
            for ref in references:
                if self.model.get_entity(ref) is None:
                    errors.add("VALIDATION ERROR",
                               (self.filename, "SubDag", "<" + self.name + ">"),
                               f'{label} <{ref}> not found')
        return errors

    def get_entrypoints_nodes(self):
        """Return the DAG nodes the walk starts from.

        Uses ``dag.get_node`` on each configured entrypoint name when any are
        configured, otherwise all nodes from ``dag.get_roots()``. The names
        are not checked here; validate() reports unknown ones.
        """
        if self.entrypoints:
            return [self.model.dag.get_node(name) for name in self.entrypoints]
        return list(self.model.dag.get_roots())

    def get_tree(self):
        """Alias for get_nodes()."""
        return self.get_nodes()

    def get_nodes(self):
        """Compute, store in ``self.tree`` and return this SubDag's nodes.

        Calls ``dag.reset()`` first, then gathers nodes from every entrypoint:

        * ``'forward'``: ``dag.get_forward_tree`` with ``self.excludes``.
        * ``'backward'``: ``dag.get_backward_tree``, with the combined result
          passed through ``dag.reverse_level``.

        Either result is de-duplicated with dedup_tree(). Any other subtype
        returns ``[]`` and leaves ``self.tree`` unchanged.
        """
        dag = self.model.dag
        dag.reset()
        if self.subtype == 'forward':
            collected = []
            for node in self.get_entrypoints_nodes():
                collected.extend(dag.get_forward_tree(node, excludes=self.excludes))
        elif self.subtype == 'backward':
            collected = []
            for node in self.get_entrypoints_nodes():
                # NOTE(review): unlike the forward walk, excludes are not
                # passed here. Confirm whether get_backward_tree supports them.
                collected.extend(dag.get_backward_tree(node))
            collected = dag.reverse_level(collected)
        else:
            return []
        self.tree = self.dedup_tree(collected)
        return self.tree

    def dedup_tree(self, tree: list) -> list:
        """Return *tree* with at most one node per ``name``.

        When a name occurs more than once, the occurrence with the highest
        ``level`` is kept; on ties the earliest one wins. Results are ordered
        by each name's first appearance in *tree*, because re-assigning an
        existing dict key keeps its original position.
        """
        by_name = {}
        for node in tree:
            # Replace if the kept node's level is lower than this node's level.
            if node.name not in by_name or by_name[node.name].level < node.level:
                by_name[node.name] = node
        return list(by_name.values())

    def get_leveldict(self, nodes: list) -> dict:
        """Group *nodes* by their ``level`` attribute.

        Returns:
            dict: each level maps to the list of nodes at that level, in input
            order. Keys appear in order of first occurrence.
        """
        levels = {}
        for node in nodes:
            levels.setdefault(node.level, []).append(node)
        return levels