diff --git a/dbt/adapters/sqlserver/sqlserver_column.py b/dbt/adapters/sqlserver/sqlserver_column.py index 68ef98e3..a207fefd 100644 --- a/dbt/adapters/sqlserver/sqlserver_column.py +++ b/dbt/adapters/sqlserver/sqlserver_column.py @@ -1,4 +1,5 @@ from dbt.adapters.fabric import FabricColumn +from dbt_common.exceptions import DbtRuntimeError class SQLServerColumn(FabricColumn): @@ -20,3 +21,31 @@ def is_integer(self) -> bool: "serial8", "int", ] + + def is_string(self) -> bool: + return self.dtype.lower() in [ + "text", + "character varying", + "character", + "varchar", + "nvarchar", + ] + + def string_size(self) -> int: + if not self.is_string(): + raise DbtRuntimeError("Called string_size() on non-string field!") + + if self.dtype == "text" or self.char_size is None: + # char_size should never be None. Handle it reasonably just in case + return 256 + elif self.dtype.lower() == "nvarchar": + # char_size is doubled for nvarchar + return int(self.char_size // 2) + else: + return int(self.char_size) + + def string_type(self, size: int) -> str: + if self.dtype: + return f"{self.dtype}({size if size > 0 else '8000'})" + else: + return f"varchar({size if size > 0 else '8000'})" diff --git a/dbt/include/sqlserver/macros/adapter/columns.sql b/dbt/include/sqlserver/macros/adapter/columns.sql index a98750e7..59d19803 100644 --- a/dbt/include/sqlserver/macros/adapter/columns.sql +++ b/dbt/include/sqlserver/macros/adapter/columns.sql @@ -48,3 +48,36 @@ {% do run_query(rename_column) %} {% endmacro %} + +{% macro sqlserver__get_columns_in_relation(relation) -%} + {% set query_label = apply_label() %} + {% call statement('get_columns_in_relation', fetch_result=True) %} + {{ get_use_database_sql(relation.database) }} + with mapping as ( + select + row_number() over (partition by object_name(c.object_id) order by c.column_id) as ordinal_position, + c.name collate database_default as column_name, + t.name as data_type, + c.max_length as character_maximum_length, + c.precision as numeric_precision, + c.scale as numeric_scale + from sys.columns c {{ information_schema_hints() }} + inner join sys.types t {{ information_schema_hints() }} + on c.user_type_id = t.user_type_id + where c.object_id = object_id('{{ 'tempdb..' ~ relation.include(database=false, schema=false) if '#' in relation.identifier else relation }}') + ) + + select + column_name, + data_type, + character_maximum_length, + numeric_precision, + numeric_scale + from mapping + order by ordinal_position + {{ query_label }} + + {% endcall %} + {% set table = load_result('get_columns_in_relation').table %} + {{ return(sql_convert_columns_in_relation(table)) }} +{% endmacro %}