sqlglot.schema
from __future__ import annotations

import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import ParseError, SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql.types import StructType
    from sqlglot.dialects.dialect import DialectType

    ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

TABLE_ARGS = ("this", "db", "catalog")

T = t.TypeVar("T")


class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or a string representing the table.
            only_visible: whether to return only the visible columns.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`
        """
        raise NotImplementedError

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty."""
        return True


class AbstractMappingSchema(t.Generic[T]):
    """
    Lookup structure backed by a nested `mapping` plus a trie of its key paths.

    The trie stores each path innermost key first, which is the order `table_parts`
    produces, so a partially qualified table can be matched as a prefix.

    Args:
        mapping: the nested mapping to index. Defaults to an empty mapping.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        # `flatten_schema` yields paths outermost key first; reverse them for the trie.
        self.mapping_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the mapping is empty."""
        return not self.mapping

    def _depth(self) -> int:
        """Return the nesting depth of `mapping` as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The prefix of `TABLE_ARGS` that matches the depth of the mapping.

        The value is computed on first access with a non-empty mapping and then cached.

        Raises:
            SchemaError: if the mapping depth is outside the range supported by `TABLE_ARGS`.
        """
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty name parts of `table` in `TABLE_ARGS` order (table, db, catalog)."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Return the mapping entry for `table`, or None if it can't be resolved.

        Args:
            table: the table to look up.
            trie: the trie to search. Defaults to `self.mapping_trie`.
            raise_on_missing: whether an ambiguous or missing entry raises instead of returning None.

        Raises:
            SchemaError: if `table` matches several entries and `raise_on_missing` is set.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
            return None
        elif value == 1:
            # `parts` is only a prefix: it is usable when exactly one completion exists.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(possibility) for possibility in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Walk `d` (or `self.mapping` when `d` is falsy) along `parts`.

        Args:
            parts: the table parts, innermost first, as returned by `table_parts`.
            d: the nested dictionary to search.
            raise_on_missing: whether a missing key raises instead of returning None.
        """
        # Keys are visited outermost first, so the labels are reversed as well. This keeps
        # each error label ("catalog", "db", "table") attached to the level being searched.
        return nested_get(
            d or self.mapping,
            *zip(reversed(self.supported_table_args), reversed(parts)),
            raise_on_missing=raise_on_missing,
        )


class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` has to be set before `_normalize` runs, because name normalization reads it.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a new `MappingSchema` built from this one; `kwargs` override constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        normalized_table = self._normalize_table(self._ensure_table(table))
        normalized_column_mapping = {
            self._normalize_name(key): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed outermost first, whereas the trie is keyed innermost first.
        nested_set(
            self.mapping,
            tuple(reversed(parts)),
            normalized_column_mapping,
        )
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible`. This has
                no effect when no visibility mapping was provided.

        Returns:
            The list of column names, or an empty list if the table can't be found.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or as a name.

        Returns:
            The resulting column type, or the "unknown" type if the table can't be found.

        Raises:
            SchemaError: if the table is known but the column is not, or if the stored
                type is neither a string nor a `DataType`.
        """
        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if table_schema:
            if column_name not in table_schema:
                # Report the missing column itself rather than "Unknown column type 'None'".
                table_name = ".".join(reversed(self.table_parts(table_)))
                raise SchemaError(f"Unknown column '{column_name}' in table '{table_name}'")

            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper())
            raise SchemaError(f"Unknown column type '{column_type}'")

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # Each key doubles as its own error label for the lookup.
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose name parts have been normalized."""
        normalized_table = table.copy()
        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """Lowercase `name` unless it is quoted; names that fail to parse are returned as they are."""
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """Return `table` as an `exp.Table`, parsing it with the schema's dialect if it is a string."""
        if isinstance(table, exp.Table):
            return table

        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type. Results are cached per type string.

        Raises:
            SchemaError: if parsing the type raises an `AttributeError`.
            ValueError: if the parser returns None.
        """
        if schema_type not in self._type_mapping_cache:
            # Only the parse call is guarded, so unrelated failures are not relabelled.
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
            except AttributeError as error:
                raise SchemaError(f"Failed to convert type {schema_type}") from error

            if expression is None:
                raise ValueError(f"Could not parse {schema_type}")
            self._type_mapping_cache[schema_type] = expression  # type: ignore

        return self._type_mapping_cache[schema_type]


def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
    """Return `schema` unchanged if it is a `Schema`, otherwise wrap it in a `MappingSchema`."""
    if isinstance(schema, Schema):
        return schema

    return MappingSchema(schema, dialect=dialect)


def _split_top_level(text: str, sep: str = ",") -> t.List[str]:
    """Split `text` on `sep`, ignoring separators nested inside (), [] or <>."""
    pieces = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "([<":
            depth += 1
        elif char in ")]>":
            depth = max(depth - 1, 0)
        elif char == sep and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - a dict, which is returned as is;
            - a string such as `"a: int, b: decimal(10, 2)"`;
            - an object exposing `simpleString` (looks like a DataFrame StructType);
            - a list of column names, whose types are set to None;
            - None, which produces an empty mapping.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry lacks a ":".
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        # Commas nested in brackets belong to the type, e.g. `decimal(10, 2)`.
        for entry in _split_top_level(mapping):
            # Only the first colon separates name and type, e.g. `a: struct<b: int>`.
            name, separator, type_ = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column mapping entry '{entry.strip()}': expected 'name: type'"
                )
            column_mapping[name.strip()] = type_.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of `schema` that are exactly `depth` levels deep.

    Args:
        schema: the nested dictionary to flatten.
        depth: how many levels of keys each path should contain.
        keys: the path prefix accumulated by outer recursive calls.

    Returns:
        The key paths, outermost key first. Empty if `depth` is less than 1.
    """
    tables: t.List[t.List[str]] = []
    keys = keys or []

    for k, v in schema.items():
        if depth >= 2:
            tables.extend(flatten_schema(v, depth - 1, keys + [k]))
        elif depth == 1:
            tables.append(keys + [k])
    return tables


def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Get a value for a nested dictionary.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    for name, key in path:
        d = d.get(key)  # type: ignore
        if d is None:
            if raise_on_missing:
                name = "table" if name == "this" else name
                raise ValueError(f"Unknown {name}: {key}")
            return None
    return d


def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary.
    """
    if not keys:
        return d

    *parents, last = keys

    subd = d
    for key in parents:
        # `setdefault` both creates a missing level and returns an existing one.
        subd = subd.setdefault(key, {})

    subd[last] = value
    return d
class Schema(abc.ABC):
    """Abstract base class for database schemas"""

    @abc.abstractmethod
    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Some implementing classes may require column information to also be provided.

        Args:
            table: table expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """

    @abc.abstractmethod
    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or a string representing the table.
            only_visible: whether to return only the visible columns.

        Returns:
            The list of column names.
        """

    @abc.abstractmethod
    def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column.

        Returns:
            The resulting column type.
        """

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        Table arguments this schema support, e.g. `("this", "db", "catalog")`

        Not abstract: the base implementation raises `NotImplementedError`.
        """
        raise NotImplementedError

    @property
    def empty(self) -> bool:
        """Returns whether or not the schema is empty. The base implementation always returns True."""
        return True
Abstract base class for database schemas
@abc.abstractmethod
def add_table(
    self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
    """
    Register or update a table. Some implementing classes may require column information to also be provided.

    Args:
        table: table expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
            Defaults to None, i.e. no column information.
    """
Register or update a table. Some implementing classes may require column information to also be provided.
Arguments:
- table: table expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
@abc.abstractmethod
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or a string representing the table.
        only_visible: whether to return only the visible columns.

    Returns:
        The list of column names.
    """
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or a string representing the table.
- only_visible: whether to return only the visible columns.
Returns:
The list of column names.
@abc.abstractmethod
def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
    """
    Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

    Args:
        table: the source table, as a `Table` expression or a string.
        column: the target column.

    Returns:
        The resulting column type.
    """
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
Returns:
The resulting column type.
class AbstractMappingSchema(t.Generic[T]):
    """
    Lookup structure backed by a nested `mapping` plus a trie of its key paths.

    The trie stores each path innermost key first, which is the order `table_parts`
    produces, so a partially qualified table can be matched as a prefix.

    Args:
        mapping: the nested mapping to index. Defaults to an empty mapping.
    """

    def __init__(
        self,
        mapping: dict | None = None,
    ) -> None:
        self.mapping = mapping or {}
        # `flatten_schema` yields paths outermost key first; reverse them for the trie.
        self.mapping_trie = new_trie(
            tuple(reversed(keys)) for keys in flatten_schema(self.mapping, depth=self._depth())
        )
        self._supported_table_args: t.Tuple[str, ...] = tuple()

    @property
    def empty(self) -> bool:
        """Returns whether or not the mapping is empty."""
        return not self.mapping

    def _depth(self) -> int:
        """Return the nesting depth of `mapping` as reported by `dict_depth`."""
        return dict_depth(self.mapping)

    @property
    def supported_table_args(self) -> t.Tuple[str, ...]:
        """
        The prefix of `TABLE_ARGS` that matches the depth of the mapping.

        The value is computed on first access with a non-empty mapping and then cached.

        Raises:
            SchemaError: if the mapping depth is outside the range supported by `TABLE_ARGS`.
        """
        if not self._supported_table_args and self.mapping:
            depth = self._depth()

            if not depth:  # None
                self._supported_table_args = tuple()
            elif 1 <= depth <= 3:
                self._supported_table_args = TABLE_ARGS[:depth]
            else:
                raise SchemaError(f"Invalid mapping shape. Depth: {depth}")

        return self._supported_table_args

    def table_parts(self, table: exp.Table) -> t.List[str]:
        """Return the non-empty name parts of `table` in `TABLE_ARGS` order (table, db, catalog)."""
        if isinstance(table.this, exp.ReadCSV):
            return [table.this.name]
        return [table.text(part) for part in TABLE_ARGS if table.text(part)]

    def find(
        self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
    ) -> t.Optional[T]:
        """
        Return the mapping entry for `table`, or None if it can't be resolved.

        Args:
            table: the table to look up.
            trie: the trie to search. Defaults to `self.mapping_trie`.
            raise_on_missing: whether an ambiguous or missing entry raises instead of returning None.

        Raises:
            SchemaError: if `table` matches several entries and `raise_on_missing` is set.
        """
        parts = self.table_parts(table)[0 : len(self.supported_table_args)]
        value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)

        if value == 0:
            return None
        elif value == 1:
            # `parts` is only a prefix: it is usable when exactly one completion exists.
            possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
            if len(possibilities) == 1:
                parts.extend(possibilities[0])
            else:
                message = ", ".join(".".join(possibility) for possibility in possibilities)
                if raise_on_missing:
                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
                return None
        return self.nested_get(parts, raise_on_missing=raise_on_missing)

    def nested_get(
        self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
    ) -> t.Optional[t.Any]:
        """
        Walk `d` (or `self.mapping` when `d` is falsy) along `parts`.

        Args:
            parts: the table parts, innermost first, as returned by `table_parts`.
            d: the nested dictionary to search.
            raise_on_missing: whether a missing key raises instead of returning None.
        """
        # Keys are visited outermost first, so the labels are reversed as well. This keeps
        # each error label ("catalog", "db", "table") attached to the level being searched.
        return nested_get(
            d or self.mapping,
            *zip(reversed(self.supported_table_args), reversed(parts)),
            raise_on_missing=raise_on_missing,
        )
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.
This class can then be used as follows::
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
def find(
    self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
) -> t.Optional[T]:
    """
    Resolve `table` against the trie and return its entry from the mapping.

    Args:
        table: the table to look up.
        trie: the trie to search. Defaults to `self.mapping_trie`.
        raise_on_missing: whether an ambiguous or missing entry raises instead of returning None.

    Returns:
        The mapping entry, or None when the table is unknown, or when it is ambiguous
        and `raise_on_missing` is False.

    Raises:
        SchemaError: if `table` matches several entries and `raise_on_missing` is set.
    """
    parts = self.table_parts(table)[0 : len(self.supported_table_args)]
    search_trie = self.mapping_trie if trie is None else trie
    value, subtrie = in_trie(search_trie, parts)

    if value == 0:
        return None

    if value == 1:
        # Only a prefix matched: complete it when there is a single candidate.
        candidates = flatten_schema(subtrie, depth=dict_depth(subtrie) - 1)
        if len(candidates) != 1:
            message = ", ".join(".".join(candidate) for candidate in candidates)
            if raise_on_missing:
                raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
            return None
        parts.extend(candidates[0])

    return self.nested_get(parts, raise_on_missing=raise_on_missing)
class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
    """
    Schema based on a nested mapping.

    Args:
        schema (dict): Mapping in one of the following forms:
            1. {table: {col: type}}
            2. {db: {table: {col: type}}}
            3. {catalog: {db: {table: {col: type}}}}
            4. None - Tables will be added later
        visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
            are assumed to be visible. The nesting should mirror that of the schema:
            1. {table: set(*cols)}
            2. {db: {table: set(*cols)}}
            3. {catalog: {db: {table: set(*cols)}}}
        dialect (str): The dialect to be used for custom type mappings.
    """

    def __init__(
        self,
        schema: t.Optional[t.Dict] = None,
        visible: t.Optional[t.Dict] = None,
        dialect: DialectType = None,
    ) -> None:
        # `dialect` has to be set before `_normalize` runs, because name normalization reads it.
        self.dialect = dialect
        self.visible = visible or {}
        self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
        super().__init__(self._normalize(schema or {}))

    @classmethod
    def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
        """Build a `MappingSchema` from the mapping, visibility and dialect of another one."""
        return MappingSchema(
            schema=mapping_schema.mapping,
            visible=mapping_schema.visible,
            dialect=mapping_schema.dialect,
        )

    def copy(self, **kwargs) -> MappingSchema:
        """Return a new `MappingSchema` built from this one; `kwargs` override constructor arguments."""
        return MappingSchema(
            **{  # type: ignore
                "schema": self.mapping.copy(),
                "visible": self.visible.copy(),
                "dialect": self.dialect,
                **kwargs,
            }
        )

    def add_table(
        self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
    ) -> None:
        """
        Register or update a table. Updates are only performed if a new column mapping is provided.

        Args:
            table: the `Table` expression instance or string representing the table.
            column_mapping: a column mapping that describes the structure of the table.
        """
        normalized_table = self._normalize_table(self._ensure_table(table))
        normalized_column_mapping = {
            self._normalize_name(key): value
            for key, value in ensure_column_mapping(column_mapping).items()
        }

        schema = self.find(normalized_table, raise_on_missing=False)
        if schema and not normalized_column_mapping:
            return

        parts = self.table_parts(normalized_table)

        # The mapping is keyed outermost first, whereas the trie is keyed innermost first.
        nested_set(
            self.mapping,
            tuple(reversed(parts)),
            normalized_column_mapping,
        )
        new_trie([parts], self.mapping_trie)

    def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
        """
        Get the column names for a table.

        Args:
            table: the `Table` expression instance or string representing the table.
            only_visible: whether to return only the columns listed in `visible`. This has
                no effect when no visibility mapping was provided.

        Returns:
            The list of column names, or an empty list if the table can't be found.
        """
        table_ = self._normalize_table(self._ensure_table(table))
        schema = self.find(table_)

        if schema is None:
            return []

        if not only_visible or not self.visible:
            return list(schema)

        visible = self.nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
        """
        Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

        Args:
            table: the source table.
            column: the target column, as an expression or as a name.

        Returns:
            The resulting column type, or the "unknown" type if the table can't be found.

        Raises:
            SchemaError: if the table is known but the column is not, or if the stored
                type is neither a string nor a `DataType`.
        """
        column_name = self._normalize_name(column if isinstance(column, str) else column.this)
        table_ = self._normalize_table(self._ensure_table(table))

        table_schema = self.find(table_, raise_on_missing=False)
        if table_schema:
            if column_name not in table_schema:
                # Report the missing column itself rather than "Unknown column type 'None'".
                table_name = ".".join(reversed(self.table_parts(table_)))
                raise SchemaError(f"Unknown column '{column_name}' in table '{table_name}'")

            column_type = table_schema.get(column_name)

            if isinstance(column_type, exp.DataType):
                return column_type
            elif isinstance(column_type, str):
                return self._to_data_type(column_type.upper())
            raise SchemaError(f"Unknown column type '{column_type}'")

        return exp.DataType.build("unknown")

    def _normalize(self, schema: t.Dict) -> t.Dict:
        """
        Converts all identifiers in the schema into lowercase, unless they're quoted.

        Args:
            schema: the schema to normalize.

        Returns:
            The normalized schema mapping.
        """
        flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)

        normalized_mapping: t.Dict = {}
        for keys in flattened_schema:
            # Each key doubles as its own error label for the lookup.
            columns = nested_get(schema, *zip(keys, keys))
            assert columns is not None

            normalized_keys = [self._normalize_name(key) for key in keys]
            for column_name, column_type in columns.items():
                nested_set(
                    normalized_mapping,
                    normalized_keys + [self._normalize_name(column_name)],
                    column_type,
                )

        return normalized_mapping

    def _normalize_table(self, table: exp.Table) -> exp.Table:
        """Return a copy of `table` whose name parts have been normalized."""
        normalized_table = table.copy()
        for arg in TABLE_ARGS:
            value = normalized_table.args.get(arg)
            if isinstance(value, (str, exp.Identifier)):
                normalized_table.set(arg, self._normalize_name(value))

        return normalized_table

    def _normalize_name(self, name: str | exp.Identifier) -> str:
        """Lowercase `name` unless it is quoted; names that fail to parse are returned as they are."""
        try:
            identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier)
        except ParseError:
            return name if isinstance(name, str) else name.name

        return identifier.name if identifier.quoted else identifier.name.lower()

    def _depth(self) -> int:
        # The columns themselves are a mapping, but we don't want to include those
        return super()._depth() - 1

    def _ensure_table(self, table: exp.Table | str) -> exp.Table:
        """Return `table` as an `exp.Table`, parsing it with the schema's dialect if it is a string."""
        if isinstance(table, exp.Table):
            return table

        table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table)
        if not table_:
            raise SchemaError(f"Not a valid table '{table}'")

        return table_

    def _to_data_type(self, schema_type: str) -> exp.DataType:
        """
        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.

        Args:
            schema_type: the type we want to convert.

        Returns:
            The resulting expression type. Results are cached per type string.

        Raises:
            SchemaError: if parsing the type raises an `AttributeError`.
            ValueError: if the parser returns None.
        """
        if schema_type not in self._type_mapping_cache:
            # Only the parse call is guarded, so unrelated failures are not relabelled.
            try:
                expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
            except AttributeError as error:
                raise SchemaError(f"Failed to convert type {schema_type}") from error

            if expression is None:
                raise ValueError(f"Could not parse {schema_type}")
            self._type_mapping_cache[schema_type] = expression  # type: ignore

        return self._type_mapping_cache[schema_type]
Schema based on a nested mapping.
Arguments:
- schema (dict): Mapping in one of the following forms:
- {table: {col: type}}
- {db: {table: {col: type}}}
- {catalog: {db: {table: {col: type}}}}
- None - Tables will be added later
- visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns
are assumed to be visible. The nesting should mirror that of the schema:
- `{table: set(*cols)}`
- `{db: {table: set(*cols)}}`
- `{catalog: {db: {table: set(*cols)}}}`
- dialect (str): The dialect to be used for custom type mappings.
def __init__(
    self,
    schema: t.Optional[t.Dict] = None,
    visible: t.Optional[t.Dict] = None,
    dialect: DialectType = None,
) -> None:
    """
    Create the schema from an optional nested mapping.

    Args:
        schema: the nested mapping of tables to columns. It is normalized before being stored.
        visible: optional mapping of which columns are visible. Defaults to an empty mapping.
        dialect: the dialect used when parsing names and types.
    """
    # `dialect` has to be set before `_normalize` runs, because name normalization reads it.
    self.dialect = dialect
    self.visible = visible or {}
    # Cache of type strings already converted by `_to_data_type`.
    self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
    super().__init__(self._normalize(schema or {}))
def add_table(
    self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
    """
    Register a table, or replace its columns when a column mapping is given.

    An already registered table is left untouched if no columns are provided.

    Args:
        table: the `Table` expression instance or string representing the table.
        column_mapping: a column mapping that describes the structure of the table.
    """
    normalized_table = self._normalize_table(self._ensure_table(table))

    raw_columns = ensure_column_mapping(column_mapping)
    normalized_column_mapping = {
        self._normalize_name(name): column_type for name, column_type in raw_columns.items()
    }

    existing = self.find(normalized_table, raise_on_missing=False)
    if existing and not normalized_column_mapping:
        return

    # `parts` is innermost first: the trie uses it as is, the mapping needs it reversed.
    parts = self.table_parts(normalized_table)
    mapping_path = tuple(reversed(parts))

    nested_set(self.mapping, mapping_path, normalized_column_mapping)
    new_trie([parts], self.mapping_trie)
Register or update a table. Updates are only performed if a new column mapping is provided.
Arguments:
- table: the `Table` expression instance or string representing the table.
- column_mapping: a column mapping that describes the structure of the table.
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
    """
    Get the column names for a table.

    Args:
        table: the `Table` expression instance or string representing the table.
        only_visible: whether to return only the columns listed in `visible`. This has
            no effect when no visibility mapping was provided.

    Returns:
        The list of column names, or an empty list if the table can't be found.
    """
    table_ = self._normalize_table(self._ensure_table(table))
    schema = self.find(table_)

    if schema is None:
        return []

    if only_visible and self.visible:
        visible = self.nested_get(self.table_parts(table_), self.visible)
        return [col for col in schema if col in visible]  # type: ignore

    return list(schema)
Get the column names for a table.
Arguments:
- table: the `Table` expression instance or a string representing the table.
- only_visible: whether to return only the visible columns.
Returns:
The list of column names.
def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
    """
    Get the :class:`sqlglot.exp.DataType` type of a column in the schema.

    Args:
        table: the source table.
        column: the target column, as an expression or as a name.

    Returns:
        The resulting column type, or the "unknown" type if the table can't be found.

    Raises:
        SchemaError: if the table is known but the column is not, or if the stored
            type is neither a string nor a `DataType`.
    """
    column_name = self._normalize_name(column if isinstance(column, str) else column.this)
    table_ = self._normalize_table(self._ensure_table(table))

    table_schema = self.find(table_, raise_on_missing=False)
    if table_schema:
        if column_name not in table_schema:
            # Report the missing column itself rather than "Unknown column type 'None'".
            table_name = ".".join(reversed(self.table_parts(table_)))
            raise SchemaError(f"Unknown column '{column_name}' in table '{table_name}'")

        column_type = table_schema.get(column_name)

        if isinstance(column_type, exp.DataType):
            return column_type
        elif isinstance(column_type, str):
            return self._to_data_type(column_type.upper())
        raise SchemaError(f"Unknown column type '{column_type}'")

    return exp.DataType.build("unknown")
Get the `sqlglot.exp.DataType` type of a column in the schema.
Arguments:
- table: the source table.
- column: the target column.
Returns:
The resulting column type.
Inherited Members
def _split_top_level(text: str, sep: str = ",") -> t.List[str]:
    """Split `text` on `sep`, ignoring separators nested inside (), [] or <>."""
    pieces = []
    depth = 0
    start = 0
    for index, char in enumerate(text):
        if char in "([<":
            depth += 1
        elif char in ")]>":
            depth = max(depth - 1, 0)
        elif char == sep and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict:
    """
    Convert the supported column mapping representations into a dictionary.

    Args:
        mapping: one of
            - a dict, which is returned as is;
            - a string such as `"a: int, b: decimal(10, 2)"`;
            - an object exposing `simpleString` (looks like a DataFrame StructType);
            - a list of column names, whose types are set to None;
            - None, which produces an empty mapping.

    Returns:
        The column mapping as a dictionary.

    Raises:
        ValueError: if `mapping` has an unsupported type, or a string entry lacks a ":".
    """
    if isinstance(mapping, dict):
        return mapping
    elif isinstance(mapping, str):
        column_mapping = {}
        # Commas nested in brackets belong to the type, e.g. `decimal(10, 2)`.
        for entry in _split_top_level(mapping):
            # Only the first colon separates name and type, e.g. `a: struct<b: int>`.
            name, separator, type_ = entry.partition(":")
            if not separator:
                raise ValueError(
                    f"Invalid column mapping entry '{entry.strip()}': expected 'name: type'"
                )
            column_mapping[name.strip()] = type_.strip()
        return column_mapping
    # Check if mapping looks like a DataFrame StructType
    elif hasattr(mapping, "simpleString"):
        return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}  # type: ignore
    elif isinstance(mapping, list):
        return {x.strip(): None for x in mapping}
    elif mapping is None:
        return {}
    raise ValueError(f"Invalid mapping provided: {type(mapping)}")
def flatten_schema(
    schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
) -> t.List[t.List[str]]:
    """
    List the key paths of `schema` that are exactly `depth` levels deep.

    Args:
        schema: the nested dictionary to flatten.
        depth: how many levels of keys each path should contain.
        keys: the path prefix accumulated by outer recursive calls.

    Returns:
        The key paths, outermost key first. Empty if `depth` is less than 1.
    """
    prefix = keys or []
    entries = schema.items()

    if depth == 1:
        return [prefix + [name] for name, _ in entries]

    flattened: t.List[t.List[str]] = []
    if depth >= 2:
        for name, nested in entries:
            flattened += flatten_schema(nested, depth - 1, prefix + [name])
    return flattened
def nested_get(
    d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
    """
    Look up a value in a nested dictionary by following a path of keys.

    Args:
        d: the dictionary to search.
        *path: tuples of (name, key), where:
            `key` is the key in the dictionary to get.
            `name` is a string to use in the error if `key` isn't found.
        raise_on_missing: whether a missing key raises instead of returning None.

    Returns:
        The value or None if it doesn't exist.

    Raises:
        ValueError: if a key is missing and `raise_on_missing` is set.
    """
    current = d
    for name, key in path:
        current = current.get(key)  # type: ignore
        if current is not None:
            continue

        if not raise_on_missing:
            return None

        # "this" is the internal name of the table argument.
        label = "table" if name == "this" else name
        raise ValueError(f"Unknown {label}: {key}")
    return current
Get a value for a nested dictionary.
Arguments:
- d: the dictionary to search.
- *path: tuples of (name, key), where `key` is the key in the dictionary to get and `name` is a string to use in the error if `key` isn't found.
Returns:
The value or None if it doesn't exist.
def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict:
    """
    In-place set a value for a nested dictionary

    Example:
        >>> nested_set({}, ["top_key", "second_key"], "value")
        {'top_key': {'second_key': 'value'}}

        >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
        {'top_key': {'third_key': 'third_value', 'second_key': 'value'}}

    Args:
        d: dictionary to update.
        keys: the keys that make up the path to `value`.
        value: the value to set in the dictionary for the given key path.

    Returns:
        The (possibly) updated dictionary. It is returned unchanged if `keys` is empty.
    """
    if not keys:
        return d

    *parents, last = keys

    subd = d
    for key in parents:
        # `setdefault` both creates a missing level and returns an existing one.
        subd = subd.setdefault(key, {})

    subd[last] = value
    return d
In-place set a value for a nested dictionary
Example:
>>> nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}

>>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
Arguments:
- d: dictionary to update.
- keys: the keys that make up the path to `value`.
- value: the value to set in the dictionary for the given key path.
Returns:
The (possibly) updated dictionary.