from DataVaultGenerator.Dag import DagNode
from DataVaultGenerator.Components import ErrorCollection, GeneratorEntity


class SubDag(GeneratorEntity):
    """A named subset of the model DAG, selected by walking from entrypoints.

    Definition keys read here:
      * ``entrypoints`` - entity names to start the walk from; when empty the
        DAG roots are used instead.
      * ``key`` - identifier of the sub-DAG; defaults to ``name``.
      * ``excludes`` - entity names passed to the forward walk.

    The walk direction comes from ``self.subtype`` ('forward' or 'backward'),
    which is presumably populated by ``GeneratorEntity`` - TODO confirm.
    """

    def __init__(self, model, filename, definition: dict = None):
        GeneratorEntity.__init__(self, model, filename, definition)
        # The signature allows definition=None; without this guard the
        # lookups below would raise AttributeError. The base class still
        # receives the argument exactly as passed.
        spec = definition if definition is not None else {}
        self.entrypoints = spec.get('entrypoints', [])
        self.key = spec.get('key', spec.get('name'))
        self.excludes = spec.get('excludes', [])
        # Cached result of the last get_nodes() call.
        self.tree = []

    def validate(self):
        """Check that every entrypoint and exclude names an existing entity.

        Returns:
            ErrorCollection containing one "VALIDATION ERROR" per name that
            ``model.get_entity`` cannot resolve (entrypoints first, then
            excludes).
        """
        errors = ErrorCollection()
        # Validating entity references:
        for label, refs in (('Entrypoint', self.entrypoints),
                            ('Exclude', self.excludes)):
            for ref in refs:
                if self.model.get_entity(ref) is None:
                    errors.add("VALIDATION ERROR",
                               (self.filename, "SubDag", "<" + self.name + ">"),
                               f'{label} <{ref}> not found')
        return errors

    def get_entrypoints_nodes(self):
        """Return the DAG nodes the walk starts from.

        Uses the configured entrypoints when any are given, otherwise every
        root of the model DAG.
        """
        if self.entrypoints:
            # Unresolvable names are reported by validate(); whatever
            # dag.get_node returns for them is passed through unchanged.
            return [self.model.dag.get_node(name) for name in self.entrypoints]
        return list(self.model.dag.get_roots())

    def get_tree(self):
        """Alias for get_nodes()."""
        return self.get_nodes()

    def get_nodes(self):
        """Collect, deduplicate and cache the nodes reachable from the entrypoints.

        The DAG is reset first. For subtype 'forward' the forward trees of all
        entrypoints (honouring ``excludes``) are concatenated; for 'backward'
        the backward trees are concatenated and then passed through
        ``dag.reverse_level``. In both cases the result is deduplicated by
        name, stored in ``self.tree`` and returned. Any other subtype returns
        ``[]`` and leaves ``self.tree`` untouched.
        """
        self.model.dag.reset()
        if self.subtype == 'forward':
            collected = []
            for node in self.get_entrypoints_nodes():
                collected.extend(
                    self.model.dag.get_forward_tree(node, excludes=self.excludes))
        elif self.subtype == 'backward':
            collected = []
            for node in self.get_entrypoints_nodes():
                # NOTE(review): excludes are not applied to the backward
                # walk, unlike the forward one - confirm this is intended.
                collected.extend(self.model.dag.get_backward_tree(node))
            collected = self.model.dag.reverse_level(collected)
        else:
            return []
        self.tree = self.dedup_tree(collected)
        return self.tree

    def dedup_tree(self, tree: list):
        """Collapse nodes sharing a ``name``, keeping the highest ``level``.

        Names keep their first-seen order; on equal levels the first
        occurrence wins.
        """
        dedup = {}
        for e in tree:
            if e.name not in dedup:
                dedup[e.name] = e
            elif dedup[e.name].level < e.level:
                # Replace if the existing element's level is lower than the
                # current element's level.
                dedup[e.name] = e
        return list(dedup.values())

    def get_leveldict(self, nodes: list) -> dict:
        """Group nodes by their ``level`` attribute.

        Returns:
            dict mapping each level to the list of nodes at that level, in
            input order; levels appear in first-seen order.
        """
        ld = {}
        for n in nodes:
            ld.setdefault(n.level, []).append(n)
        return ld