Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import typing as t
  5from dataclasses import dataclass, field
  6
  7from sqlglot import Schema, exp, maybe_parse
  8from sqlglot.optimizer import Scope, build_scope, optimize
  9from sqlglot.optimizer.lower_identities import lower_identities
 10from sqlglot.optimizer.qualify_columns import qualify_columns
 11from sqlglot.optimizer.qualify_tables import qualify_tables
 12
 13if t.TYPE_CHECKING:
 14    from sqlglot.dialects.dialect import DialectType
 15
 16
 17@dataclass(frozen=True)
 18class Node:
 19    name: str
 20    expression: exp.Expression
 21    source: exp.Expression
 22    downstream: t.List[Node] = field(default_factory=list)
 23    alias: str = ""
 24
 25    def walk(self) -> t.Iterator[Node]:
 26        yield self
 27
 28        for d in self.downstream:
 29            if isinstance(d, Node):
 30                yield from d.walk()
 31            else:
 32                yield d
 33
 34    def to_html(self, **opts) -> LineageHTML:
 35        return LineageHTML(self, **opts)
 36
 37
 38def lineage(
 39    column: str | exp.Column,
 40    sql: str | exp.Expression,
 41    schema: t.Optional[t.Dict | Schema] = None,
 42    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 43    rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
 44    dialect: DialectType = None,
 45) -> Node:
 46    """Build the lineage graph for a column of a SQL query.
 47
 48    Args:
 49        column: The column to build the lineage for.
 50        sql: The SQL string or expression.
 51        schema: The schema of tables.
 52        sources: A mapping of queries which will be used to continue building lineage.
 53        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 54        dialect: The dialect of input SQL.
 55
 56    Returns:
 57        A lineage node.
 58    """
 59
 60    expression = maybe_parse(sql, dialect=dialect)
 61
 62    if sources:
 63        expression = exp.expand(
 64            expression,
 65            {
 66                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 67                for k, v in sources.items()
 68            },
 69        )
 70
 71    optimized = optimize(expression, schema=schema, rules=rules)
 72    scope = build_scope(optimized)
 73
 74    def to_node(
 75        column_name: str,
 76        scope: Scope,
 77        scope_name: t.Optional[str] = None,
 78        upstream: t.Optional[Node] = None,
 79        alias: t.Optional[str] = None,
 80    ) -> Node:
 81        aliases = {
 82            dt.alias: dt.comments[0].split()[1]
 83            for dt in scope.derived_tables
 84            if dt.comments and dt.comments[0].startswith("source: ")
 85        }
 86        if isinstance(scope.expression, exp.Union):
 87            for scope in scope.union_scopes:
 88                node = to_node(
 89                    column_name,
 90                    scope=scope,
 91                    scope_name=scope_name,
 92                    upstream=upstream,
 93                    alias=aliases.get(scope_name),
 94                )
 95            return node
 96
 97        # Find the specific select clause that is the source of the column we want.
 98        # This can either be a specific, named select or a generic `*` clause.
 99        select = next(
100            (select for select in scope.selects if select.alias_or_name == column_name),
101            exp.Star() if scope.expression.is_star else None,
102        )
103
104        if not select:
105            raise ValueError(f"Could not find {column_name} in {scope.expression}")
106
107        if isinstance(scope.expression, exp.Select):
108            # For better ergonomics in our node labels, replace the full select with
109            # a version that has only the column we care about.
110            #   "x", SELECT x, y FROM foo
111            #     => "x", SELECT x FROM foo
112            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
113        else:
114            source = scope.expression
115
116        # Create the node for this step in the lineage chain, and attach it to the previous one.
117        node = Node(
118            name=f"{scope_name}.{column_name}" if scope_name else column_name,
119            source=source,
120            expression=select,
121            alias=alias or "",
122        )
123        if upstream:
124            upstream.downstream.append(node)
125
126        # Find all columns that went into creating this one to list their lineage nodes.
127        for c in set(select.find_all(exp.Column)):
128            table = c.table
129            source = scope.sources.get(table)
130
131            if isinstance(source, Scope):
132                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
133                to_node(
134                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
135                )
136            else:
137                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
138                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
139                # is not passed into the `sources` map.
140                source = source or exp.Placeholder()
141                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
142
143        return node
144
145    return to_node(column if isinstance(column, str) else column.name, scope)
146
147
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/

    Every node reachable from the root lineage node becomes a vis.js node. Its label
    is the node's SQL, and its tooltip shows the source SQL with the node's expression
    in bold. Tooltip SQL is HTML-escaped. Everything embedded in the <script> block is
    escaped so that text such as "</script>" inside the SQL cannot break the page.
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """
        Args:
            node: The root lineage node.
            dialect: The dialect used to generate the SQL shown for each node.
            imports: Whether `__str__` includes the vis.js <script>/<link> tags.
            **opts: vis-network options; top-level keys replace the defaults below.
        """
        import html

        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # Markers placed around the highlighted expression while generating SQL. They
        # contain no character that html.escape rewrites, so they survive escaping and
        # are then swapped for real <b> tags.
        bold_start, bold_end = "\ue000", "\ue001"

        for lineage_node in node.walk():
            if isinstance(lineage_node.expression, exp.Table):
                table = lineage_node.expression.this
                label = f"FROM {table}"
                query = f"SELECT {lineage_node.name} FROM {table}"
                title = f"<pre>{html.escape(query)}</pre>"
                group = 1
            else:
                target = lineage_node.expression
                label = target.sql(pretty=True, dialect=dialect)
                original_parent = target.parent
                tags: t.List[exp.Tag] = []

                def highlight(n, target=target, tags=tags):
                    if n is target:
                        tag = exp.Tag(this=n, prefix=bold_start, postfix=bold_end)
                        tags.append(tag)
                        return tag
                    return n

                source_sql = lineage_node.source.transform(highlight, copy=False).sql(
                    pretty=True, dialect=dialect
                )

                # transform(copy=False) edits the tree in place: put the expression back
                # so rendering leaves the lineage graph unchanged and can be repeated.
                for tag in tags:
                    tag.replace(tag.this)
                target.parent = original_parent

                escaped = (
                    html.escape(source_sql).replace(bold_start, "<b>").replace(bold_end, "</b>")
                )
                title = f"<pre>{escaped}</pre>"
                group = 0

            node_id = id(lineage_node)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for downstream in lineage_node.downstream:
                self.edges.append({"from": node_id, "to": id(downstream)})

    def __str__(self):
        import uuid

        def to_js(value: t.Any) -> str:
            # "<", ">" and "&" can only occur inside JSON strings, where \uXXXX escapes
            # are equivalent. Escaping them stops text like "</script>" from closing the
            # surrounding <script> element early.
            return (
                json.dumps(value)
                .replace("<", "\\u003c")
                .replace(">", "\\u003e")
                .replace("&", "\\u0026")
            )

        nodes = to_js(list(self.nodes.values()))
        edges = to_js(self.edges)
        options = to_js(self.options)
        # Each rendering gets its own element id, so several graphs on one page (e.g.
        # in multiple notebook cells) each draw into their own container.
        container_id = f"sqlglot-lineage-{uuid.uuid4().hex}"
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{container_id}"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("{container_id}"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph.

    Although the dataclass is frozen, ``downstream`` is a list and is appended to
    while the graph is being built.
    """

    # Label of this step: "<scope_name>.<column>" or just the column for the root,
    # and the column's SQL for leaf entries.
    name: str
    # The expression this step is about (a projection, a table or a placeholder).
    expression: exp.Expression
    # The expression containing `expression`, used for display (e.g. a reduced SELECT).
    source: exp.Expression
    # The steps this one was derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name taken from a "source: <name>" comment on the derived table, or "".
    alias: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node and then its downstream entries, depth-first in pre-order.

        Downstream entries that are not `Node` instances are yielded unchanged.
        """
        yield self
        for child in self.downstream:
            yield from (child.walk() if isinstance(child, Node) else (child,))

    def to_html(self, **opts) -> LineageHTML:
        """Return a `LineageHTML` renderer for this graph; `opts` are forwarded to it."""
        renderer = LineageHTML(self, **opts)
        return renderer
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[sqlglot.lineage.Node] = <factory>, alias: str = '')
def walk(self) -> Iterator[sqlglot.lineage.Node]:
def walk(self) -> t.Iterator[Node]:
    """Yield this node and then its downstream entries, depth-first in pre-order.

    Downstream entries that are not `Node` instances are yielded unchanged.
    """
    yield self
    for child in self.downstream:
        yield from (child.walk() if isinstance(child, Node) else (child,))
def to_html(self, **opts) -> sqlglot.lineage.LineageHTML:
def to_html(self, **opts) -> LineageHTML:
    """Return a `LineageHTML` renderer for this graph; `opts` are forwarded to it."""
    renderer = LineageHTML(self, **opts)
    return renderer
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Subqueryable]] = None, rules: Sequence[Callable] = (lower_identities, qualify_tables, qualify_columns), dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.lineage.Node:
 39def lineage(
 40    column: str | exp.Column,
 41    sql: str | exp.Expression,
 42    schema: t.Optional[t.Dict | Schema] = None,
 43    sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
 44    rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
 45    dialect: DialectType = None,
 46) -> Node:
 47    """Build the lineage graph for a column of a SQL query.
 48
 49    Args:
 50        column: The column to build the lineage for.
 51        sql: The SQL string or expression.
 52        schema: The schema of tables.
 53        sources: A mapping of queries which will be used to continue building lineage.
 54        rules: Optimizer rules to apply, by default only qualifying tables and columns.
 55        dialect: The dialect of input SQL.
 56
 57    Returns:
 58        A lineage node.
 59    """
 60
 61    expression = maybe_parse(sql, dialect=dialect)
 62
 63    if sources:
 64        expression = exp.expand(
 65            expression,
 66            {
 67                k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect))
 68                for k, v in sources.items()
 69            },
 70        )
 71
 72    optimized = optimize(expression, schema=schema, rules=rules)
 73    scope = build_scope(optimized)
 74
 75    def to_node(
 76        column_name: str,
 77        scope: Scope,
 78        scope_name: t.Optional[str] = None,
 79        upstream: t.Optional[Node] = None,
 80        alias: t.Optional[str] = None,
 81    ) -> Node:
 82        aliases = {
 83            dt.alias: dt.comments[0].split()[1]
 84            for dt in scope.derived_tables
 85            if dt.comments and dt.comments[0].startswith("source: ")
 86        }
 87        if isinstance(scope.expression, exp.Union):
 88            for scope in scope.union_scopes:
 89                node = to_node(
 90                    column_name,
 91                    scope=scope,
 92                    scope_name=scope_name,
 93                    upstream=upstream,
 94                    alias=aliases.get(scope_name),
 95                )
 96            return node
 97
 98        # Find the specific select clause that is the source of the column we want.
 99        # This can either be a specific, named select or a generic `*` clause.
100        select = next(
101            (select for select in scope.selects if select.alias_or_name == column_name),
102            exp.Star() if scope.expression.is_star else None,
103        )
104
105        if not select:
106            raise ValueError(f"Could not find {column_name} in {scope.expression}")
107
108        if isinstance(scope.expression, exp.Select):
109            # For better ergonomics in our node labels, replace the full select with
110            # a version that has only the column we care about.
111            #   "x", SELECT x, y FROM foo
112            #     => "x", SELECT x FROM foo
113            source = t.cast(exp.Expression, scope.expression.select(select, append=False))
114        else:
115            source = scope.expression
116
117        # Create the node for this step in the lineage chain, and attach it to the previous one.
118        node = Node(
119            name=f"{scope_name}.{column_name}" if scope_name else column_name,
120            source=source,
121            expression=select,
122            alias=alias or "",
123        )
124        if upstream:
125            upstream.downstream.append(node)
126
127        # Find all columns that went into creating this one to list their lineage nodes.
128        for c in set(select.find_all(exp.Column)):
129            table = c.table
130            source = scope.sources.get(table)
131
132            if isinstance(source, Scope):
133                # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
134                to_node(
135                    c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table)
136                )
137            else:
138                # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
139                # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
140                # is not passed into the `sources` map.
141                source = source or exp.Placeholder()
142                node.downstream.append(Node(name=c.sql(), source=source, expression=source))
143
144        return node
145
146    return to_node(column if isinstance(column, str) else column.name, scope)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • rules: Optimizer rules to apply; by default lower_identities, qualify_tables and qualify_columns.
  • dialect: The dialect of input SQL.
Returns:

A lineage node.

class LineageHTML:
class LineageHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/

    Every node reachable from the root lineage node becomes a vis.js node. Its label
    is the node's SQL, and its tooltip shows the source SQL with the node's expression
    in bold. Tooltip SQL is HTML-escaped. Everything embedded in the <script> block is
    escaped so that text such as "</script>" inside the SQL cannot break the page.
    """

    def __init__(
        self,
        node: Node,
        dialect: DialectType = None,
        imports: bool = True,
        **opts: t.Any,
    ):
        """
        Args:
            node: The root lineage node.
            dialect: The dialect used to generate the SQL shown for each node.
            imports: Whether `__str__` includes the vis.js <script>/<link> tags.
            **opts: vis-network options; top-level keys replace the defaults below.
        """
        import html

        self.node = node
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **opts,
        }

        self.nodes = {}
        self.edges = []

        # Markers placed around the highlighted expression while generating SQL. They
        # contain no character that html.escape rewrites, so they survive escaping and
        # are then swapped for real <b> tags.
        bold_start, bold_end = "\ue000", "\ue001"

        for lineage_node in node.walk():
            if isinstance(lineage_node.expression, exp.Table):
                table = lineage_node.expression.this
                label = f"FROM {table}"
                query = f"SELECT {lineage_node.name} FROM {table}"
                title = f"<pre>{html.escape(query)}</pre>"
                group = 1
            else:
                target = lineage_node.expression
                label = target.sql(pretty=True, dialect=dialect)
                original_parent = target.parent
                tags: t.List[exp.Tag] = []

                def highlight(n, target=target, tags=tags):
                    if n is target:
                        tag = exp.Tag(this=n, prefix=bold_start, postfix=bold_end)
                        tags.append(tag)
                        return tag
                    return n

                source_sql = lineage_node.source.transform(highlight, copy=False).sql(
                    pretty=True, dialect=dialect
                )

                # transform(copy=False) edits the tree in place: put the expression back
                # so rendering leaves the lineage graph unchanged and can be repeated.
                for tag in tags:
                    tag.replace(tag.this)
                target.parent = original_parent

                escaped = (
                    html.escape(source_sql).replace(bold_start, "<b>").replace(bold_end, "</b>")
                )
                title = f"<pre>{escaped}</pre>"
                group = 0

            node_id = id(lineage_node)

            self.nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            for downstream in lineage_node.downstream:
                self.edges.append({"from": node_id, "to": id(downstream)})

    def __str__(self):
        import uuid

        def to_js(value: t.Any) -> str:
            # "<", ">" and "&" can only occur inside JSON strings, where \uXXXX escapes
            # are equivalent. Escaping them stops text like "</script>" from closing the
            # surrounding <script> element early.
            return (
                json.dumps(value)
                .replace("<", "\\u003c")
                .replace(">", "\\u003e")
                .replace("&", "\\u0026")
            )

        nodes = to_js(list(self.nodes.values()))
        edges = to_js(self.edges)
        options = to_js(self.options)
        # Each rendering gets its own element id, so several graphs on one page (e.g.
        # in multiple notebook cells) each draw into their own container.
        container_id = f"sqlglot-lineage-{uuid.uuid4().hex}"
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{container_id}"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("{container_id}"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

LineageHTML( node: sqlglot.lineage.Node, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, imports: bool = True, **opts: Any)
def __init__(
    self,
    node: Node,
    dialect: DialectType = None,
    imports: bool = True,
    **opts: t.Any,
):
    """Collect vis.js nodes and edges for every node reachable from `node`.

    Args:
        node: The root lineage node.
        dialect: The dialect used to generate the SQL shown for each node.
        imports: Whether `__str__` includes the vis.js <script>/<link> tags.
        **opts: vis-network options; top-level keys replace the defaults below.
    """
    import html

    self.node = node
    self.imports = imports

    self.options = {
        "height": "500px",
        "width": "100%",
        "layout": {
            "hierarchical": {
                "enabled": True,
                "nodeSpacing": 200,
                "sortMethod": "directed",
            },
        },
        "interaction": {
            "dragNodes": False,
            "selectable": False,
        },
        "physics": {
            "enabled": False,
        },
        "edges": {
            "arrows": "to",
        },
        "nodes": {
            "font": "20px monaco",
            "shape": "box",
            "widthConstraint": {
                "maximum": 300,
            },
        },
        **opts,
    }

    self.nodes = {}
    self.edges = []

    # Markers placed around the highlighted expression while generating SQL. They
    # contain no character that html.escape rewrites, so they survive escaping and
    # are then swapped for real <b> tags.
    bold_start, bold_end = "\ue000", "\ue001"

    for lineage_node in node.walk():
        if isinstance(lineage_node.expression, exp.Table):
            table = lineage_node.expression.this
            label = f"FROM {table}"
            query = f"SELECT {lineage_node.name} FROM {table}"
            title = f"<pre>{html.escape(query)}</pre>"
            group = 1
        else:
            target = lineage_node.expression
            label = target.sql(pretty=True, dialect=dialect)
            original_parent = target.parent
            tags: t.List[exp.Tag] = []

            def highlight(n, target=target, tags=tags):
                if n is target:
                    tag = exp.Tag(this=n, prefix=bold_start, postfix=bold_end)
                    tags.append(tag)
                    return tag
                return n

            source_sql = lineage_node.source.transform(highlight, copy=False).sql(
                pretty=True, dialect=dialect
            )

            # transform(copy=False) edits the tree in place: put the expression back
            # so rendering leaves the lineage graph unchanged and can be repeated.
            for tag in tags:
                tag.replace(tag.this)
            target.parent = original_parent

            escaped = (
                html.escape(source_sql).replace(bold_start, "<b>").replace(bold_end, "</b>")
            )
            title = f"<pre>{escaped}</pre>"
            group = 0

        node_id = id(lineage_node)

        self.nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        for downstream in lineage_node.downstream:
            self.edges.append({"from": node_id, "to": id(downstream)})