Source code for locopy.snowflake

# SPDX-Copyright: Copyright (c) Capital One Services, LLC
# SPDX-License-Identifier: Apache-2.0
# Copyright 2018 Capital One Services, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Snowflake Module
Module to wrap a database adapter into a Snowflake class which can be used to connect
to Snowflake, and run arbitrary code.
"""
import os
from pathlib import PurePath

from .database import Database
from .errors import DBError, S3CredentialsError
from .logger import INFO, get_logger
from .s3 import S3
from .utility import find_column_type

logger = get_logger(__name__, INFO)

COPY_FORMAT_OPTIONS = {
    "csv": {
        "compression",
        "record_delimiter",
        "field_delimiter",
        "skip_header",
        "date_format",
        "time_format",
        "timestamp_format",
        "binary_format",
        "escape",
        "escape_unenclosed_field",
        "trim_space",
        "field_optionally_enclosed_by",
        "null_if",
        "error_on_column_count_mismatch",
        "validate_utf8",
        "empty_field_as_null",
        "skip_byte_order_mark",
        "encoding",
    },
    "json": {
        "compression",
        "file_extension",
        "enable_octal",
        "allow_duplicate",
        "strip_outer_array",
        "strip_null_values",
        "ignore_utf8_errors",
        "skip_byte_order_mark",
    },
    "parquet": {"binary_as_text"},
}


UNLOAD_FORMAT_OPTIONS = {
    "csv": {
        "compression",
        "record_delimiter",
        "field_delimiter",
        "file_extension",
        "date_format",
        "time_format",
        "timestamp_format",
        "binary_format",
        "escape",
        "escape_unenclosed_field",
        "field_optionally_enclosed_by",
        "null_if",
    },
    "json": {"compression", "file_extension"},
    "parquet": {"snappy_compression"},
}


def combine_options(options=None):
    """Return the ``copy_options`` or ``format_options`` attribute as a single
    string with spaces in between. If ``options`` is ``None`` then return an
    empty string.

    Parameters
    ----------
    options : list, optional
        List of strings which is to be converted into a single string with
        spaces in between. Defaults to ``None``

    Returns
    -------
    str
        ``options`` attribute with spaces in between
    """
    return " ".join(options) if options is not None else ""
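
# Illustrative sketch (not part of the library API): ``combine_options`` simply joins
# the option strings that ``copy``/``unload`` splice into the generated SQL, e.g.
#
#     combine_options(["FIELD_DELIMITER='|'", "SKIP_HEADER=1"])  # "FIELD_DELIMITER='|' SKIP_HEADER=1"
#     combine_options(None)                                      # ""
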
class Snowflake(S3, Database):
    """Locopy class which manages connections to Snowflake. Inherits
    ``Database`` and implements the specific ``COPY INTO`` functionality.

    Parameters
    ----------
    profile : str, optional
        The name of the AWS profile to use, which is typically stored in the
        ``credentials`` file. You can also set the environment variable
        ``AWS_DEFAULT_PROFILE``, which would be used instead.

    kms_key : str, optional
        The KMS key to use for encryption. If ``kms_key`` is ``None`` (the
        default), AES-256 server-side encryption will be used.

    dbapi : DBAPI 2 module, optional
        A database adapter which is Python DB API 2.0 compliant
        (``snowflake.connector``)

    config_yaml : str, optional
        String representing the YAML file location of the database connection
        keyword arguments. It is worth noting that this should only contain
        valid arguments for the database connector you plan on using. It will
        throw an exception if something is passed through which isn't valid.

    **kwargs
        Database connection keyword arguments.

    Attributes
    ----------
    profile : str
        String representing the AWS profile for authentication

    kms_key : str
        String representing the S3 KMS key

    session : boto3.Session
        Holds the AWS session credentials / info

    s3 : botocore.client.S3
        Holds the S3 client object which is used to upload/delete files to S3

    dbapi : DBAPI 2 module
        Database adapter which is Python DB API 2.0 compliant
        (``snowflake.connector``)

    connection : dict
        Dictionary of database connection items

    conn : dbapi.connection
        DBAPI connection instance

    cursor : dbapi.cursor
        DBAPI cursor instance

    Raises
    ------
    CredentialsError
        Database credentials are not provided or valid

    S3Error
        Error initializing AWS Session (ex: invalid profile)

    S3CredentialsError
        Issue with AWS credentials

    S3InitializationError
        Issue initializing S3 session
    """

    def __init__(
        self, profile=None, kms_key=None, dbapi=None, config_yaml=None, **kwargs
    ):
        try:
            S3.__init__(self, profile, kms_key)
        except S3CredentialsError:
            logger.warning(
                "S3 credentials were not found. S3 functionality is disabled"
            )
            logger.warning("Only internal stages are available")
        Database.__init__(self, dbapi, config_yaml, **kwargs)
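
    # A minimal usage sketch (hypothetical file and key names; assumes a
    # ``snowflake.yml`` holding arguments accepted by ``snowflake.connector.connect``,
    # such as ``account``, ``user`` and ``password``):
    #
    #     import snowflake.connector
    #     import locopy
    #
    #     with locopy.Snowflake(dbapi=snowflake.connector, config_yaml="snowflake.yml") as sf:
    #         sf.execute("SELECT CURRENT_VERSION()")
    #         print(sf.to_dataframe())
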
    def connect(self):
        """Creates a connection to the Snowflake cluster by setting the values
        of the ``conn`` and ``cursor`` attributes.

        Raises
        ------
        DBError
            If there is a problem establishing a connection to Snowflake.
        """
        super(Snowflake, self).connect()

        if self.connection.get("warehouse") is not None:
            self.execute("USE WAREHOUSE {0}".format(self.connection["warehouse"]))
        if self.connection.get("database") is not None:
            self.execute("USE DATABASE {0}".format(self.connection["database"]))
        if self.connection.get("schema") is not None:
            self.execute("USE SCHEMA {0}".format(self.connection["schema"]))
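
    # Sketch of how the optional context is applied (hypothetical values): if any of
    # ``warehouse``, ``database`` or ``schema`` appear in the connection arguments,
    # the corresponding ``USE`` statement is issued after connecting.
    #
    #     sf = Snowflake(dbapi=snowflake.connector, account="my_account",
    #                    user="my_user", password="my_password",
    #                    warehouse="ANALYTICS_WH", database="MY_DB", schema="PUBLIC")
    #     sf.connect()   # also runs USE WAREHOUSE / USE DATABASE / USE SCHEMA
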
    def upload_to_internal(
        self, local, stage, parallel=4, auto_compress=True, overwrite=True
    ):
        """
        Upload file(s) to an internal stage via the ``PUT`` command.

        Parameters
        ----------
        local : str
            The local directory path to the file to upload. Wildcard characters
            (``*``, ``?``) are supported to enable uploading multiple files in
            a directory. Otherwise it must be the absolute path.

        stage : str
            Internal stage location to load the file to.

        parallel : int, optional
            Specifies the number of threads to use for uploading files.

        auto_compress : bool, optional
            Specifies if Snowflake uses gzip to compress files during upload.
            If ``True``, the files are compressed (if they are not already
            compressed). If ``False``, the files are uploaded as-is.

        overwrite : bool, optional
            Specifies whether Snowflake overwrites an existing file with the
            same name during upload. If ``True``, an existing file with the
            same name is overwritten. If ``False``, an existing file with the
            same name is not overwritten.
        """
        local_uri = PurePath(local).as_posix()
        self.execute(
            "PUT 'file://{0}' {1} PARALLEL={2} AUTO_COMPRESS={3} OVERWRITE={4}".format(
                local_uri, stage, parallel, auto_compress, overwrite
            )
        )
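
    # Hypothetical example (path and stage are placeholders): stage a local CSV in the
    # user stage ahead of a ``copy`` call.
    #
    #     sf.upload_to_internal("/tmp/local_file.csv", "@~/staged/")
    #     # issues: PUT 'file:///tmp/local_file.csv' @~/staged/ PARALLEL=4
    #     #         AUTO_COMPRESS=True OVERWRITE=True
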
    def download_from_internal(self, stage, local=None, parallel=10):
        """
        Download file(s) from an internal stage via the ``GET`` command.

        Parameters
        ----------
        stage : str
            Internal stage location to download the file from.

        local : str, optional
            The local directory path where files will be downloaded to.
            Defaults to the current working directory (``os.getcwd()``).
            Otherwise it must be the absolute path.

        parallel : int, optional
            Specifies the number of threads to use for downloading files.
        """
        if local is None:
            local = os.getcwd()
        local_uri = PurePath(local).as_posix()
        self.execute(
            "GET {0} 'file://{1}' PARALLEL={2}".format(stage, local_uri, parallel)
        )
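
    # Hypothetical example (stage and path are placeholders): pull unloaded files back
    # down to a local directory.
    #
    #     sf.download_from_internal("@~/unload/", local="/tmp/downloads", parallel=10)
    #     # issues: GET @~/unload/ 'file:///tmp/downloads' PARALLEL=10
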
    def copy(
        self, table_name, stage, file_type="csv", format_options=None, copy_options=None
    ):
        """Executes the ``COPY INTO <table>`` command to load files from a
        stage into a Snowflake table. If ``file_type == csv`` and
        ``format_options == None``, ``format_options`` will default to:
        ``["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]``

        Parameters
        ----------
        table_name : str
            The Snowflake table name which is being loaded. Must be fully
            qualified: ``<namespace>.<table_name>``

        stage : str
            Stage location of the load file. This can be an internal or
            external stage.

        file_type : str
            The file type. One of ``csv``, ``json``, or ``parquet``

        format_options : list
            List of strings of format options to provide to the ``COPY INTO``
            command. The options will typically be in the format of
            ``["a=b", "c=d"]``

        copy_options : list
            List of strings of copy options to provide to the ``COPY INTO``
            command.

        Raises
        ------
        DBError
            If there is a problem executing the COPY command, or a connection
            has not been initialized.
        """
        if not self._is_connected():
            raise DBError("No Snowflake connection object is present.")

        if file_type not in COPY_FORMAT_OPTIONS:
            raise ValueError(
                "Invalid file_type. Must be one of {0}".format(
                    list(COPY_FORMAT_OPTIONS.keys())
                )
            )

        if format_options is None and file_type == "csv":
            format_options = ["FIELD_DELIMITER='|'", "SKIP_HEADER=0"]

        format_options_text = combine_options(format_options)
        copy_options_text = combine_options(copy_options)
        base_copy_string = (
            "COPY INTO {0} FROM '{1}' " "FILE_FORMAT = (TYPE='{2}' {3}) {4}"
        )
        try:
            sql = base_copy_string.format(
                table_name, stage, file_type, format_options_text, copy_options_text
            )
            self.execute(sql, commit=True)
        except Exception as e:
            logger.error("Error running COPY on Snowflake. err: %s", e)
            raise DBError("Error running COPY on Snowflake.")
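
    # Hypothetical example (table, stage and options are placeholders): load the staged
    # file, overriding the CSV defaults.
    #
    #     sf.copy(
    #         "my_schema.my_table",
    #         "@~/staged/local_file.csv.gz",
    #         file_type="csv",
    #         format_options=["FIELD_DELIMITER=','", "SKIP_HEADER=1"],
    #         copy_options=["PURGE = TRUE"],
    #     )
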
    def unload(
        self,
        stage,
        table_name,
        file_type="csv",
        format_options=None,
        header=False,
        copy_options=None,
    ):
        """Executes the ``COPY INTO <location>`` command to export a
        query/table from Snowflake to a stage. If ``file_type == csv`` and
        ``format_options == None``, ``format_options`` will default to:
        ``["FIELD_DELIMITER='|'"]``

        Parameters
        ----------
        stage : str
            Stage location (internal or external) where the data files are
            unloaded

        table_name : str
            The Snowflake table name which is being unloaded. Must be fully
            qualified: ``<namespace>.<table_name>``

        file_type : str
            The file type. One of ``csv``, ``json``, or ``parquet``

        format_options : list
            List of strings of format options to provide to the ``COPY INTO``
            command.

        header : bool, optional
            Boolean flag if header is included in the file(s)

        copy_options : list
            List of strings of copy options to provide to the ``COPY INTO``
            command.

        Raises
        ------
        DBError
            If there is a problem executing the UNLOAD command, or a connection
            has not been initialized.
        """
        if not self._is_connected():
            raise DBError("No Snowflake connection object is present")

        # validate against the unload-specific format options so the check matches
        # the error message
        if file_type not in UNLOAD_FORMAT_OPTIONS:
            raise ValueError(
                "Invalid file_type. Must be one of {0}".format(
                    list(UNLOAD_FORMAT_OPTIONS.keys())
                )
            )

        if format_options is None and file_type == "csv":
            format_options = ["FIELD_DELIMITER='|'"]

        format_options_text = combine_options(format_options)
        copy_options_text = combine_options(copy_options)
        base_unload_string = (
            "COPY INTO {0} FROM {1} " "FILE_FORMAT = (TYPE='{2}' {3}) HEADER={4} {5}"
        )

        try:
            sql = base_unload_string.format(
                stage,
                table_name,
                file_type,
                format_options_text,
                header,
                copy_options_text,
            )
            self.execute(sql, commit=True)
        except Exception as e:
            logger.error("Error running UNLOAD on Snowflake. err: %s", e)
            raise DBError("Error running UNLOAD on Snowflake.")
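
    # Hypothetical example (stage, table and options are placeholders): export a table
    # to the user stage as pipe-delimited CSV with a header row.
    #
    #     sf.unload(
    #         stage="@~/unload/my_table_",
    #         table_name="my_schema.my_table",
    #         file_type="csv",
    #         header=True,
    #         copy_options=["MAX_FILE_SIZE = 104857600"],
    #     )
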
    def insert_dataframe_to_table(
        self, dataframe, table_name, columns=None, create=False, metadata=None
    ):
        """
        Insert a Pandas dataframe to an existing table or a new table.

        In newer versions of the python snowflake connector (v2.1.2+) users can
        call the ``write_pandas`` method from the cursor directly.
        ``insert_dataframe_to_table`` is a custom implementation and does not
        use ``write_pandas``. Instead of using ``COPY INTO``, the method builds
        a list of tuples to insert directly into the table. There are also
        options to create the table if it doesn't exist and to use your own
        metadata. If your data is significantly large then using
        ``COPY INTO <table>`` is more appropriate.

        Parameters
        ----------
        dataframe : pandas.DataFrame
            The pandas dataframe which needs to be inserted.

        table_name : str
            The name of the Snowflake table which is being inserted into.

        columns : list, optional
            The list of columns which will be uploaded.

        create : bool, optional
            Boolean flag indicating whether a new table needs to be created
            and inserted into.

        metadata : dict, optional
            If ``None``, the metadata will be generated based on the data.
        """
        import pandas as pd

        if columns:
            dataframe = dataframe[columns]

        all_columns = columns or list(dataframe.columns)
        column_sql = "(" + ",".join(all_columns) + ")"
        string_join = "(" + ",".join(["%s"] * len(all_columns)) + ")"

        # create a list of tuples for insert
        to_insert = []
        for row in dataframe.itertuples(index=False):
            none_row = tuple([None if pd.isnull(val) else str(val) for val in row])
            to_insert.append(none_row)

        if not create and metadata:
            logger.warning("Metadata will not be used because create is set to False.")

        if create:
            if not metadata:
                logger.info("Metadata is missing. Generating metadata ...")
                metadata = find_column_type(dataframe, "snowflake")
                logger.info("Metadata is complete. Creating new table ...")

            create_join = (
                "("
                + ",".join(
                    [
                        list(metadata.keys())[i] + " " + list(metadata.values())[i]
                        for i in range(len(metadata))
                    ]
                )
                + ")"
            )
            column_sql = "(" + ",".join(list(metadata.keys())) + ")"
            create_query = "CREATE TABLE {table_name} {create_join}".format(
                table_name=table_name, create_join=create_join
            )
            self.execute(create_query)
            logger.info("New table has been created")

        insert_query = """INSERT INTO {table_name} {columns} VALUES {values}""".format(
            table_name=table_name, columns=column_sql, values=string_join
        )

        logger.info("Inserting records...")
        self.execute(insert_query, params=to_insert, many=True)
        logger.info("Table insertion has completed")
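
    # Hypothetical example (names are placeholders): insert a small dataframe, letting
    # the method create the table from inferred metadata.
    #
    #     import pandas as pd
    #
    #     df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
    #     sf.insert_dataframe_to_table(df, "my_schema.my_table", create=True)
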
    def to_dataframe(self, size=None):
        """Return a dataframe of the last query results.

        This is just a convenience method. This method overrides the base
        class's implementation in favour of the Snowflake connector's built-in
        ``fetch_pandas_all`` when ``size == None``. If ``size != None`` then we
        will continue to use the existing functionality where we iterate
        through the cursor and build the dataframe.

        Parameters
        ----------
        size : int, optional
            Chunk size to fetch. Defaults to ``None``.

        Returns
        -------
        pandas.DataFrame
            Dataframe with lowercase column names. Returns ``None`` if there
            are no fetched results.
        """
        if size is None and self.cursor._query_result_format == "arrow":
            return self.cursor.fetch_pandas_all()
        else:
            return super(Snowflake, self).to_dataframe(size)
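
    # Hypothetical example (query is a placeholder): fetch the last query's results as
    # a dataframe.
    #
    #     sf.execute("SELECT id, name FROM my_schema.my_table")
    #     df = sf.to_dataframe()    # full result, via fetch_pandas_all() when possible
    #     # or, to pull only a chunk instead:
    #     # df_chunk = sf.to_dataframe(size=100)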