From 059531e89dc32a29b4401c901c95cce50d92360b Mon Sep 17 00:00:00 2001 From: drizk1 Date: Wed, 24 Jul 2024 11:52:02 -0400 Subject: [PATCH 1/6] adds 7 pkg extensions --- Project.toml | 28 +++- README.md | 2 +- docs/examples/UserGuide/getting_started.jl | 21 +-- docs/src/index.md | 2 +- ext/AWSExt.jl | 115 ++++++++++++++++ ext/CHExt.jl | 75 +++++++++++ ext/GBQExt.jl | 121 +++++++++++++++++ ext/LibPQExt.jl | 67 ++++++++++ ext/MySQLExt.jl | 75 +++++++++++ ext/ODBCExt.jl | 74 +++++++++++ ext/SQLiteExt.jl | 67 ++++++++++ src/TBD_macros.jl | 60 +++------ src/TidierDB.jl | 148 +++------------------ src/parsing_athena.jl | 69 ---------- src/parsing_gbq.jl | 100 -------------- 15 files changed, 663 insertions(+), 361 deletions(-) create mode 100644 ext/AWSExt.jl create mode 100644 ext/CHExt.jl create mode 100644 ext/GBQExt.jl create mode 100644 ext/LibPQExt.jl create mode 100644 ext/MySQLExt.jl create mode 100644 ext/ODBCExt.jl create mode 100644 ext/SQLiteExt.jl diff --git a/Project.toml b/Project.toml index 2498edd..327b94c 100644 --- a/Project.toml +++ b/Project.toml @@ -4,22 +4,33 @@ authors = ["Daniel Rizk and contributors"] version = "0.2.4" [deps] -AWS = "fbe9abb3-538b-5e4e-ba9e-bc94f4f92ebc" Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc" -ClickHouse = "82f2e89e-b495-11e9-1d9d-fb40d7cf2130" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DuckDB = "d2f5444f-75bc-4fdf-ac35-56f514c445e1" GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -GoogleCloud = "55e21f81-8b0a-565e-b5ad-6816892a5ee7" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -LibPQ = "194296ae-ab2e-5f79-8cd4-7183a0a5a0d1" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -MySQL = "39abe10b-433b-5dbd-92d4-e302a9df00cd" ODBC = "be6f12e9-ca4f-5eb2-a339-a4f995cc0291" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" + +[weakdeps] SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9" +LibPQ = "194296ae-ab2e-5f79-8cd4-7183a0a5a0d1" +GoogleCloud = "55e21f81-8b0a-565e-b5ad-6816892a5ee7" +AWS = "fbe9abb3-538b-5e4e-ba9e-bc94f4f92ebc" +MySQL = "39abe10b-433b-5dbd-92d4-e302a9df00cd" +ClickHouse = "82f2e89e-b495-11e9-1d9d-fb40d7cf2130" + + +[extensions] +SQLiteExt = "SQLite" +LibPQExt = "LibPQ" +GBQExt = "GoogleCloud" +AWSExt = "AWS" +MySQLExt = "MySQL" +CHExt = "ClickHouse" [compat] AWS = "1.9" @@ -44,6 +55,13 @@ julia = "1.9" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9" +LibPQ = "194296ae-ab2e-5f79-8cd4-7183a0a5a0d1" +GoogleCloud = "55e21f81-8b0a-565e-b5ad-6816892a5ee7" +AWS = "fbe9abb3-538b-5e4e-ba9e-bc94f4f92ebc" +MySQL = "39abe10b-433b-5dbd-92d4-e302a9df00cd" +ClickHouse = "82f2e89e-b495-11e9-1d9d-fb40d7cf2130" + [targets] test = ["Documenter", "Test"] diff --git a/README.md b/README.md index c5b53fa..e8d91c0 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ The main goal of TidierDB.jl is to bring the syntax of Tidier.jl to multiple SQL - Snowflake `set_sql_mode(:snowflake)` - Google Big Query `set_sql_mode(:gbq)` - Oracle `set_sql_mode(:oracle)` -- Databricks +- Databricks `set_sql_mode(:databricks)` The style of SQL that is generated can be modified using `set_sql_mode()`. 
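The README hunk above describes the interface this first patch still uses: `connect` opens a connection, and `set_sql_mode` only changes the flavor of SQL that is generated. A minimal sketch of that workflow, assuming the in-memory DuckDB default shown in the docs:

```julia
using TidierDB

# connect(:duckdb) opens an in-memory DuckDB database, the default backend
con = connect(:duckdb)

# set_sql_mode() switches which SQL dialect TidierDB generates;
# it does not open or change the connection itself
set_sql_mode(:postgres)
set_sql_mode(:duckdb)   # back to the default
```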
diff --git a/docs/examples/UserGuide/getting_started.jl b/docs/examples/UserGuide/getting_started.jl
index d933afb..37a3fd8 100644
--- a/docs/examples/UserGuide/getting_started.jl
+++ b/docs/examples/UserGuide/getting_started.jl
@@ -9,7 +9,7 @@
 
 # Alternatively, `using Tidier` will import TidierDB in the above manner for you, where TidierDB functions and macros will be available as `DB.@mutate()` and so on, and the TidierData equivalent would be `@mutate()`.
 
-# There are two ways to connect to the database. You can use `connect` without any need to load any additional packages. However, Oracle and Athena do not support this method yet and will require you to load in ODBC.jl or AWS.jl respectively.
+# To connect to a database, you can use the `connect` function as shown below, or establish your own connection through the respective libraries.
 
 # For example
 # Connecting to MySQL
@@ -21,14 +21,15 @@
 # conn = connect(:duckdb)
 # ```
 
-# Alternatively, you can use the packages outlined below to establish a connection through their respective methods.
+# ## Package Extensions
+# The following backends utilize package extensions. To use one of the backends listed below, you will need to load the corresponding package with `using Library`.
 
-# - ClickHouse: ClickHouse.jl
-# - MySQL and MariaDB: MySQL.jl
-# - MSSQL: ODBC.jl
-# - Postgres: LibPQ.jl
-# - SQLite: SQLite.jl
-# - Athena: AWS.jl
-# - Oracle: ODBC.jl
+# - ClickHouse: `using ClickHouse`
+# - MySQL and MariaDB: `using MySQL`
+# - MSSQL: `using ODBC`
+# - Postgres: `using LibPQ`
+# - SQLite: `using SQLite`
+# - Athena: `using AWS`
+# - Oracle: `using ODBC`
+# - Google BigQuery: `using GoogleCloud`
 
-# For DuckDB, SQLite, and MySQL, `copy_to()` lets you copy data to the database and query there. ClickHouse, MSSQL, and Postgres support for `copy_to()` has not been added yet.
diff --git a/docs/src/index.md b/docs/src/index.md
index f6fb99c..38080dc 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -18,7 +18,7 @@ The main goal of TidierDB.jl is to bring the syntax of Tidier.jl to multiple SQL
 - Snowflake `set_sql_mode(:snowflake)`
 - Google Big Query `set_sql_mode(:gbq)`
 - Oracle `set_sql_mode(:oracle)`
-- Databricks
+- Databricks `set_sql_mode(:databricks)`
 
 The style of SQL that is generated can be modified using `set_sql_mode()`.
 
diff --git a/ext/AWSExt.jl b/ext/AWSExt.jl
new file mode 100644
index 0000000..6bcd846
--- /dev/null
+++ b/ext/AWSExt.jl
@@ -0,0 +1,115 @@
+module AWSExt
+
+using TidierDB
+using DataFrames
+using AWS, HTTP, JSON3
+__init__() = println("Extension was loaded!")
+
+
+
+function collect_athena(result)
+    # Extract column names and types from the result set metadata
+    column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
+    column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]]
+
+    # Process data rows, starting from the second row to skip header information
+    data_rows = result["ResultSet"]["Rows"]
+    filtered_column_names = filter(c -> !isempty(c), column_names)
+    num_columns = length(filtered_column_names)
+
+    data_for_df = [
+        [get(col, "VarCharValue", missing) for col in row["Data"]] for row in data_rows[2:end]
+    ]
+
+    # Ensure each row has the correct number of elements
+    adjusted_data_for_df = [
+        length(row) == num_columns ?
row : resize!(copy(row), num_columns) for row in data_for_df + ] + + # Pad rows with missing values if they are shorter than expected + for row in adjusted_data_for_df + if length(row) < num_columns + append!(row, fill(missing, num_columns - length(row))) + end + end + + # Transpose the data to match DataFrame constructor requirements + data_transposed = permutedims(hcat(adjusted_data_for_df...)) + + # Create the DataFrame + df = DataFrame(data_transposed, Symbol.(filtered_column_names)) + TidierDB.parse_athena_df(df, column_types) + # Return the DataFrame + return df +end + +@service Athena + +function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String, athena_params) + schema, table = split(table_name, '.') # Ensure this correctly parses your input + query = """SELECT * FROM $schema.$table limit 0;""" + # println(query) + # try + exe_query = Athena.start_query_execution(query, athena_params; aws_config = AWS_GLOBAL_CONFIG) + + # Poll Athena to check if the query has completed + status = "RUNNING" + while status in ["RUNNING", "QUEUED"] + sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API + query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG) + status = query_status["QueryExecution"]["Status"]["State"] + if status == "FAILED" + error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"]) + elseif status == "CANCELLED" + error("Query was cancelled.") + end + end + + # Fetch the results once the query completes + result = Athena.get_query_results(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG) + + column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] + column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] + df = DataFrame(name = column_names, type = column_types) + df[!, :current_selxn] .= 1 + df[!, :table_name] .= table_name + + return select(df, 1 => :name, 2 => :type, :current_selxn, :table_name) +end + + +function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) + if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :athena + final_query = TidierDB.finalize_query(sqlquery) + exe_query = Athena.start_query_execution(final_query, sqlquery.athena_params; aws_config = sqlquery.db) + status = "RUNNING" + while status in ["RUNNING", "QUEUED"] + sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API + query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) + status = query_status["QueryExecution"]["Status"]["State"] + if status == "FAILED" + error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"]) + elseif status == "CANCELLED" + error("Query was cancelled.") + end + end + result = Athena.get_query_results(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) + return collect_athena(result) + end + +end + + +end + + + + + + + + diff --git a/ext/CHExt.jl b/ext/CHExt.jl new file mode 100644 index 0000000..dfd4d43 --- /dev/null +++ b/ext/CHExt.jl @@ -0,0 +1,75 @@ +module CHExt + +using TidierDB +using DataFrames +import ClickHouse +__init__() = 
println("Extension was loaded!") + +function TidierDB.connect(backend::Symbol; kwargs...) + if backend == :Clickhouse || backend == :clickhouse + set_sql_mode(:clickhouse) + if haskey(kwargs, :host) && haskey(kwargs, :port) + return ClickHouse.connect(kwargs[:host], kwargs[:port]; (k => v for (k, v) in kwargs if k ∉ [:host, :port])...) + else + throw(ArgumentError("Missing required positional arguments 'host' and 'port' for ClickHouse.")) + end + + elseif backend == :DuckDB || backend == :duckdb + set_sql_mode(:duckdb) + db = DBInterface.connect(DuckDB.DB, ":memory:") + DBInterface.execute(db, "SET autoinstall_known_extensions=1;") + DBInterface.execute(db, "SET autoload_known_extensions=1;") + + # Install and load the httpfs extension + DBInterface.execute(db, "INSTALL httpfs;") + DBInterface.execute(db, "LOAD httpfs;") + return db + else + throw(ArgumentError("Unsupported backend: $backend")) + end +end + + + # ClickHouse + function TidierDB.get_table_metadata(conn::ClickHouse.ClickHouseSock, table_name::String) + # Query to get column names and types from INFORMATION_SCHEMA + query = """ + SELECT + name AS column_name, + type AS data_type + FROM system.columns + WHERE table = '$table_name' AND database = 'default' + """ + result = ClickHouse.select_df(conn,query) + + result[!, :current_selxn] .= 1 + result[!, :table_name] .= table_name + # Adjust the select statement to include the new table_name column + return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name) +end + + + +function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) + if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :clickhouse + final_query = TidierDB.finalize_query(sqlquery) + df_result = ClickHouse.select_df(sqlquery.db, final_query) + selected_columns_order = sqlquery.metadata[sqlquery.metadata.current_selxn .== 1, :name] + df_result = df_result[:, selected_columns_order] + return df_result + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_databricks(sqlquery.db, final_query) + return DataFrame(result) + end +end + +end diff --git a/ext/GBQExt.jl b/ext/GBQExt.jl new file mode 100644 index 0000000..83c7f6c --- /dev/null +++ b/ext/GBQExt.jl @@ -0,0 +1,121 @@ +module GBQExt + +using TidierDB +using DataFrames +using GoogleCloud, HTTP, JSON3 +__init__() = println("Extension was loaded!") + +mutable struct GBQ + projectname::String + session::GoogleSession + bigquery_resource + bigquery_method +end + +function TidierDB.connect(type::Symbol, json_key_path::String, project_id::String) + # Expand the user's path to the JSON key + creds_path = expanduser(json_key_path) + set_sql_mode(:gbq) + # Create credentials and session for Google Cloud + creds = JSONCredentials(creds_path) + session = GoogleSession(creds, ["https://www.googleapis.com/auth/bigquery"]) + + # Define the API method for BigQuery + bigquery_method = GoogleCloud.api.APIMethod( + :POST, + "https://bigquery.googleapis.com/bigquery/v2/projects/$(project_id)/queries", + "Run query", + 
Dict{Symbol, Any}();
+        transform=(x, t) -> x
+    )
+
+    # Define the API resource for BigQuery
+    bigquery_resource = GoogleCloud.api.APIResource(
+        "https://bigquery.googleapis.com/bigquery/v2",
+        ;query=bigquery_method # Pass the method as a named argument
+    )
+
+    # Store all data in a global GBQ instance
+    global gbq_instance = GBQ(project_id, session, bigquery_resource, bigquery_method)
+
+    # Return only the session
+    return session
+end
+
+function collect_gbq(conn, query)
+    query_data = Dict(
+    "query" => query,
+    "useLegacySql" => false,
+    "location" => "US")
+
+    response = GoogleCloud.api.execute(
+        conn,
+        gbq_instance.bigquery_resource, # Use the resource from GBQ
+        gbq_instance.bigquery_method,
+        data=query_data
+    )
+    response_string = String(response)
+    response_data = JSON3.read(response_string)
+    rows = get(response_data, "rows", [])
+
+    # Convert rows to DataFrame
+    # First, extract column names from the schema
+    column_names = [field["name"] for field in response_data["schema"]["fields"]]
+    column_types = [field["type"] for field in response_data["schema"]["fields"]]
+    if !isempty(rows)
+        # Convert each row's data (currently nested inside dicts with key "v") into a matrix of values
+        data = [get(row["f"][i], "v", missing) for row in rows, i in 1:length(column_names)]
+        df = DataFrame(data, Symbol.(column_names))
+        df = TidierDB.parse_gbq_df(df, column_types)
+        return df
+    else
+        # Return an empty DataFrame with the correct columns but 0 rows
+        df = DataFrame([Vector{Union{Missing, Any}}(undef, 0) for _ in column_names], Symbol.(column_names))
+        df = TidierDB.parse_gbq_df(df, column_types)
+        return df
+    end
+
+    return df
+end
+function TidierDB.get_table_metadata(conn::GoogleSession{JSONCredentials}, table_name::String)
+    query = " SELECT * FROM
+    $table_name LIMIT 0
+    ;"
+    query_data = Dict(
+    "query" => query,
+    "useLegacySql" => false,
+    "location" => "US")
+    # Define the API resource
+
+    response = GoogleCloud.api.execute(
+        conn,
+        gbq_instance.bigquery_resource,
+        gbq_instance.bigquery_method,
+        data=query_data
+    )
+    response_string = String(response)
+    response_data = JSON3.read(response_string)
+    column_names = [field["name"] for field in response_data["schema"]["fields"]]
+    column_types = [field["type"] for field in response_data["schema"]["fields"]]
+    result = DataFrame(name = column_names, type = column_types)
+    result[!, :current_selxn] .= 1
+    result[!, :table_name] .= table_name
+
+    return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name)
+end
+
+
+function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery)
+    if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql
+        final_query = TidierDB.finalize_query(sqlquery)
+        result = DBInterface.execute(sqlquery.db, final_query)
+        return DataFrame(result)
+    elseif TidierDB.current_sql_mode[] == :gbq
+        final_query = TidierDB.finalize_query(sqlquery)
+        return collect_gbq(sqlquery.db, final_query)
+
+    end
+end
+
+end
diff --git a/ext/LibPQExt.jl b/ext/LibPQExt.jl
new file mode 100644
index 0000000..1851920
--- /dev/null
+++ b/ext/LibPQExt.jl
@@ -0,0 +1,67 @@
+module LibPQExt
+
+using TidierDB
+using DataFrames
+using LibPQ
+__init__() = println("Extension was loaded!")
+
+function TidierDB.connect(backend::Symbol; kwargs...)
+ if backend == :Postgres || backend == :postgres + set_sql_mode(:postgres) + # Construct a connection string from kwargs for LibPQ + conn_str = join(["$(k)=$(v)" for (k, v) in kwargs], " ") + return LibPQ.Connection(conn_str) + elseif backend == :DuckDB || backend == :duckdb + set_sql_mode(:duckdb) + db = DBInterface.connect(DuckDB.DB, ":memory:") + DBInterface.execute(db, "SET autoinstall_known_extensions=1;") + DBInterface.execute(db, "SET autoload_known_extensions=1;") + + # Install and load the httpfs extension + DBInterface.execute(db, "INSTALL httpfs;") + DBInterface.execute(db, "LOAD httpfs;") + return db + else + throw(ArgumentError("Unsupported backend: $backend")) + end +end + + +function TidierDB.get_table_metadata(conn::LibPQ.Connection, table_name::String) + query = """ + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '$table_name' + ORDER BY ordinal_position; + """ + result = LibPQ.execute(conn, query) |> DataFrame + result[!, :current_selxn] .= 1 + result[!, :table_name] .= table_name + # Adjust the select statement to include the new table_name column + return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name) +end + + +# In SQLiteExt.jl +function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) + if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :mariadb + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_databricks(sqlquery.db, final_query) + return DataFrame(result) + end +end + +# In DuckDBExt.jl + + + + +end diff --git a/ext/MySQLExt.jl b/ext/MySQLExt.jl new file mode 100644 index 0000000..2699260 --- /dev/null +++ b/ext/MySQLExt.jl @@ -0,0 +1,75 @@ +module MySQLExt + +using TidierDB +using DataFrames +using MySQL +__init__() = println("Extension was loaded!") + +function TidierDB.connect(backend::Symbol; kwargs...) 
+ if backend == :MySQL || backend == :mysql + set_sql_mode(:mysql) + + # Required parameters by MySQL.jl: host and user + host = get(kwargs, :host, "localhost") + user = get(kwargs, :user, "") + password = get(kwargs, :password, "") + # Extract other optional parameters + db = get(kwargs, :db, nothing) + port = get(kwargs, :port, nothing) + return DBInterface.connect(MySQL.Connection, host, user, password; db=db, port=port) + + elseif backend == :DuckDB || backend == :duckdb + set_sql_mode(:duckdb) + db = DBInterface.connect(DuckDB.DB, ":memory:") + DBInterface.execute(db, "SET autoinstall_known_extensions=1;") + DBInterface.execute(db, "SET autoload_known_extensions=1;") + + # Install and load the httpfs extension + DBInterface.execute(db, "INSTALL httpfs;") + DBInterface.execute(db, "LOAD httpfs;") + return db + else + throw(ArgumentError("Unsupported backend: $backend")) + end +end + + +# MySQL +function TidierDB.get_table_metadata(conn::MySQL.Connection, table_name::String) + # Query to get column names and types from INFORMATION_SCHEMA + query = """ + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '$table_name' + AND TABLE_SCHEMA = '$(conn.db)' + ORDER BY ordinal_position; + """ + + result = DBInterface.execute(conn, query) |> DataFrame + result[!, 2] = map(x -> String(x), result[!, 2]) + result[!, :current_selxn] .= 1 + result[!, :table_name] .= table_name + # Adjust the select statement to include the new table_name column + return DataFrames.select(result, :1 => :name, 2 => :type, :current_selxn, :table_name) +end + + +function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) + if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :mariadb + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_databricks(sqlquery.db, final_query) + return DataFrame(result) + end +end + + + +end diff --git a/ext/ODBCExt.jl b/ext/ODBCExt.jl new file mode 100644 index 0000000..ca3ed53 --- /dev/null +++ b/ext/ODBCExt.jl @@ -0,0 +1,74 @@ +module ODBCExt + +using TidierDB +using DataFrames +using ODBC +__init__() = println("Extension was loaded!") + +function TidierDB.connect(backend::Symbol; kwargs...) 
+ if backend == :SQLite || backend == :lite + db_path = get(kwargs, :db, ":memory:") + set_sql_mode(:lite) + return SQLite.DB(db_path) + elseif backend == :DuckDB || backend == :duckdb + set_sql_mode(:duckdb) + db = DBInterface.connect(DuckDB.DB, ":memory:") + DBInterface.execute(db, "SET autoinstall_known_extensions=1;") + DBInterface.execute(db, "SET autoload_known_extensions=1;") + + # Install and load the httpfs extension + DBInterface.execute(db, "INSTALL httpfs;") + DBInterface.execute(db, "LOAD httpfs;") + return db + else + throw(ArgumentError("Unsupported backend: $backend")) + end +end + + + +# MSSQL +function TidierDB.get_table_metadata(conn::ODBC.Connection, table_name::String) + if current_sql_mode[] == :oracle + table_name = uppercase(table_name) + query = """ + SELECT column_name, data_type + FROM all_tab_columns + WHERE table_name = '$table_name' + ORDER BY column_id + """ + else + query = """ + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '$table_name' + ORDER BY ordinal_position; + """ + end + + result = DBInterface.execute(conn, query) |> DataFrame + result[!, :current_selxn] .= 1 + result[!, :table_name] .= table_name + # Adjust the select statement to include the new table_name column + return select(result, :column_name => :name, :data_type => :type, :current_selxn, :table_name) +end + + +function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) + if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :oracle + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_databricks(sqlquery.db, final_query) + return DataFrame(result) + end +end + + +end diff --git a/ext/SQLiteExt.jl b/ext/SQLiteExt.jl new file mode 100644 index 0000000..2486947 --- /dev/null +++ b/ext/SQLiteExt.jl @@ -0,0 +1,67 @@ +module SQLiteExt + +using TidierDB +using DataFrames +using SQLite +__init__() = println("Extension was loaded!") + +function TidierDB.connect(backend::Symbol; kwargs...) 
+ if backend == :SQLite || backend == :lite + db_path = get(kwargs, :db, ":memory:") + set_sql_mode(:lite) + return SQLite.DB(db_path) + elseif backend == :DuckDB || backend == :duckdb + set_sql_mode(:duckdb) + db = DBInterface.connect(DuckDB.DB, ":memory:") + DBInterface.execute(db, "SET autoinstall_known_extensions=1;") + DBInterface.execute(db, "SET autoload_known_extensions=1;") + + # Install and load the httpfs extension + DBInterface.execute(db, "INSTALL httpfs;") + DBInterface.execute(db, "LOAD httpfs;") + return db + else + throw(ArgumentError("Unsupported backend: $backend")) + end +end + + + +function TidierDB.get_table_metadata(db::SQLite.DB, table_name::String) + query = "PRAGMA table_info($table_name);" + result = SQLite.DBInterface.execute(db, query) |> DataFrame + result[!, :current_selxn] .= 1 + resize!(result.current_selxn, nrow(result)) + result[!, :table_name] .= table_name + # Adjust the select statement to include the new table_name column + return DataFrames.select(result, 2 => :name, 3 => :type, :current_selxn, :table_name) +end + +function TidierDB.copy_to(conn::SQLite.DB, df::DataFrame, name::String) + SQLite.load!(df, conn, name) +end + + +# In SQLiteExt.jl +function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) + if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_databricks(sqlquery.db, final_query) + return DataFrame(result) + end +end + +# In DuckDBExt.jl + + + + +end diff --git a/src/TBD_macros.jl b/src/TBD_macros.jl index 7099da0..b191895 100644 --- a/src/TBD_macros.jl +++ b/src/TBD_macros.jl @@ -653,49 +653,25 @@ macro show_query(sqlquery) end end -macro collect(sqlquery) - return quote - # Extract the database connection from the SQLQuery object - db = $(esc(sqlquery)).db - sq = $(esc(sqlquery)) - # Finalize the query to get the SQL string - final_query = finalize_query($(esc(sqlquery))) - df_result = DataFrame() - # Determine the type of db and execute the query accordingly - if db isa DatabricksConnection - df_result = execute_databricks(db, final_query) - elseif db isa SQLite.DB || db isa LibPQ.Connection || db isa DuckDB.DB || db isa MySQL.Connection || db isa ODBC.Connection - result = DBInterface.execute(db, final_query) - df_result = DataFrame(result) - elseif current_sql_mode[] == :clickhouse - df_result = ClickHouse.select_df(db, final_query) - selected_columns_order = sq.metadata[sq.metadata.current_selxn .== 1, :name] - df_result = df_result[:, selected_columns_order] - elseif db isa GoogleSession{JSONCredentials} - df_result = collect_gbq(sq.db, final_query) +function final_collect(sqlquery::TidierDB.SQLQuery) + if current_sql_mode[] ==:duckdb + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) elseif current_sql_mode[] == :snowflake - df_result = execute_snowflake(db, final_query) - elseif current_sql_mode[] == :athena - exe_query = 
Athena.start_query_execution(final_query, sq.athena_params; aws_config = db) - status = "RUNNING" - while status in ["RUNNING", "QUEUED"] - sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API - query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], sq.athena_params; aws_config = db) - status = query_status["QueryExecution"]["Status"]["State"] - if status == "FAILED" - error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"]) - elseif status == "CANCELLED" - error("Query was cancelled.") - end + final_query = TidierDB.finalize_query(sqlquery) + result = execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = execute_databricks(sqlquery.db, final_query) + return DataFrame(result) end - - # Fetch the results once the query completes - result = Athena.get_query_results(exe_query["QueryExecutionId"], sq.athena_params; aws_config = db) - df_result = collect_athena(result) - else - error("Unsupported database type: $(typeof(db))") - end - df_result - end end + +macro collect(sqlquery) + return quote + final_collect($(esc(sqlquery))) + end +end \ No newline at end of file diff --git a/src/TidierDB.jl b/src/TidierDB.jl index 82f4f18..7a5a160 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -1,19 +1,18 @@ module TidierDB -using LibPQ +#using LibPQ using DataFrames using MacroTools using Chain -using SQLite +#using SQLite using Reexport using DuckDB -using MySQL -using ODBC -import ClickHouse +#using MySQL +#using ODBC +#import ClickHouse using Arrow -using AWS -using JSON3 -using GoogleCloud +#using AWS +#using GoogleCloud using HTTP using JSON3 using GZip @@ -21,8 +20,7 @@ using GZip @reexport using DataFrames: DataFrame @reexport using Chain @reexport using DuckDB -import DuckDB: open as duckdb_open -import DuckDB: connect as duckdb_connect + #using TidierDB export db_table, set_sql_mode, @arrange, @group_by, @filter, @select, @mutate, @summarize, @summarise, @@ -149,29 +147,9 @@ function finalize_query(sqlquery::SQLQuery) end -function get_table_metadata(db::SQLite.DB, table_name::String) - query = "PRAGMA table_info($table_name);" - result = SQLite.DBInterface.execute(db, query) |> DataFrame - result[!, :current_selxn] .= 1 - resize!(result.current_selxn, nrow(result)) - result[!, :table_name] .= table_name - # Adjust the select statement to include the new table_name column - return DataFrames.select(result, 2 => :name, 3 => :type, :current_selxn, :table_name) -end -function get_table_metadata(conn::LibPQ.Connection, table_name::String) - query = """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = '$table_name' - ORDER BY ordinal_position; - """ - result = LibPQ.execute(conn, query) |> DataFrame - result[!, :current_selxn] .= 1 - result[!, :table_name] .= table_name - # Adjust the select statement to include the new table_name column - return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name) -end + + # DuckDB @@ -193,68 +171,10 @@ function get_table_metadata(conn::DuckDB.DB, table_name::String) return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name) end -# MySQL -function get_table_metadata(conn::MySQL.Connection, table_name::String) - # Query to get column names and types from INFORMATION_SCHEMA - query = """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = '$table_name' - AND 
TABLE_SCHEMA = '$(conn.db)' - ORDER BY ordinal_position; - """ - - result = DBInterface.execute(conn, query) |> DataFrame - result[!, 2] = map(x -> String(x), result[!, 2]) - result[!, :current_selxn] .= 1 - result[!, :table_name] .= table_name - # Adjust the select statement to include the new table_name column - return DataFrames.select(result, :1 => :name, 2 => :type, :current_selxn, :table_name) -end -# MSSQL -function get_table_metadata(conn::ODBC.Connection, table_name::String) - if current_sql_mode[] == :oracle - table_name = uppercase(table_name) - query = """ - SELECT column_name, data_type - FROM all_tab_columns - WHERE table_name = '$table_name' - ORDER BY column_id - """ - else - query = """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = '$table_name' - ORDER BY ordinal_position; - """ - end - result = DBInterface.execute(conn, query) |> DataFrame - result[!, :current_selxn] .= 1 - result[!, :table_name] .= table_name - # Adjust the select statement to include the new table_name column - return select(result, :column_name => :name, :data_type => :type, :current_selxn, :table_name) -end - # ClickHouse -function get_table_metadata(conn::ClickHouse.ClickHouseSock, table_name::String) - # Query to get column names and types from INFORMATION_SCHEMA - query = """ - SELECT - name AS column_name, - type AS data_type - FROM system.columns - WHERE table = '$table_name' AND database = 'default' - """ - result = ClickHouse.select_df(conn,query) - result[!, :current_selxn] .= 1 - result[!, :table_name] .= table_name - # Adjust the select statement to include the new table_name column - return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name) -end """ $docstring_db_table @@ -283,8 +203,8 @@ function db_table(db, table, athena_params::Any=nothing; iceberg::Bool=false, de metadata = get_table_metadata(db, table_name) end elseif current_sql_mode[] == :athena - metadata = get_table_metadata_athena(db, table_name, athena_params) - elseif current_sql_mode[] == :snowflake + metadata = get_table_metadata(db, table_name, athena_params) + elseif current_sql_mode[] == :snowflake || current_sql_mode[] == :databricks metadata = get_table_metadata(db, table_name) else error("Unsupported SQL mode: $(current_sql_mode[])") @@ -292,7 +212,7 @@ function db_table(db, table, athena_params::Any=nothing; iceberg::Bool=false, de formatted_table_name = if current_sql_mode[] == :snowflake "$(db.database).$(db.schema).$table_name" - elseif db isa DatabricksConnection + elseif db isa DatabricksConnection || current_sql_mode[] == :databricks "$(db.database).$(db.schema).$table_name" elseif iceberg "iceberg_scan('$table_name', allow_moved_paths = true)" @@ -315,12 +235,6 @@ function copy_to(conn, df_or_path::Union{DataFrame, AbstractString}, name::Strin if isa(df_or_path, DataFrame) if current_sql_mode[] == :duckdb DuckDB.register_data_frame(conn, df_or_path, name) - elseif current_sql_mode[] == :lite - SQLite.load!(df_or_path, conn, name) - elseif current_sql_mode[] == :mysql - MySQL.load(df_or_path, conn, name) - else - error("Unsupported SQL mode: $(current_sql_mode[])") end # If the input is not a DataFrame, treat it as a file path elseif isa(df_or_path, AbstractString) @@ -364,39 +278,7 @@ end $docstring_connect """ function connect(backend::Symbol; kwargs...) 
- if backend == :MySQL || backend == :mysql - set_sql_mode(:mysql) - - # Required parameters by MySQL.jl: host and user - host = get(kwargs, :host, "localhost") - user = get(kwargs, :user, "") - password = get(kwargs, :password, "") - # Extract other optional parameters - db = get(kwargs, :db, nothing) - port = get(kwargs, :port, nothing) - return DBInterface.connect(MySQL.Connection, host, user, password; db=db, port=port) - elseif backend == :Postgres || backend == :postgres - set_sql_mode(:postgres) - # Construct a connection string from kwargs for LibPQ - conn_str = join(["$(k)=$(v)" for (k, v) in kwargs], " ") - return LibPQ.Connection(conn_str) - elseif backend == :MsSQL || backend == :mssql - set_sql_mode(:mssql) - # Construct a connection string for ODBC if required for MsSQL - conn_str = join(["$(k)=$(v)" for (k, v) in kwargs], ";") - return ODBC.Connection(conn_str) - elseif backend == :Clickhouse || backend == :clickhouse - set_sql_mode(:clickhouse) - if haskey(kwargs, :host) && haskey(kwargs, :port) - return ClickHouse.connect(kwargs[:host], kwargs[:port]; (k => v for (k, v) in kwargs if k ∉ [:host, :port])...) - else - throw(ArgumentError("Missing required positional arguments 'host' and 'port' for ClickHouse.")) - end - elseif backend == :SQLite || backend == :lite - db_path = get(kwargs, :db, ":memory:") - set_sql_mode(:lite) - return SQLite.DB(db_path) - elseif backend == :DuckDB || backend == :duckdb + if backend == :DuckDB || backend == :duckdb set_sql_mode(:duckdb) db = DBInterface.connect(DuckDB.DB, ":memory:") DBInterface.execute(db, "SET autoinstall_known_extensions=1;") @@ -459,7 +341,7 @@ function connect(backend_type::Symbol, db_type::Symbol; access_key::String="", s end function connect(symbol, token::String) - if token == "md:" + if token == "md:" return DBInterface.connect(DuckDB.DB, "md:") else return DBInterface.connect(DuckDB.DB, "md:$token") diff --git a/src/parsing_athena.jl b/src/parsing_athena.jl index 54d5c89..0ea4e97 100644 --- a/src/parsing_athena.jl +++ b/src/parsing_athena.jl @@ -177,72 +177,3 @@ function parse_athena_df(df, column_types) end -function collect_athena(result) - # Extract column names and types from the result set metadata - column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] - column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] - - # Process data rows, starting from the second row to skip header information - data_rows = result["ResultSet"]["Rows"] - filtered_column_names = filter(c -> !isempty(c), column_names) - num_columns = length(filtered_column_names) - - data_for_df = [ - [get(col, "VarCharValue", missing) for col in row["Data"]] for row in data_rows[2:end] - ] - - # Ensure each row has the correct number of elements - adjusted_data_for_df = [ - length(row) == num_columns ? 
row : resize!(copy(row), num_columns) for row in data_for_df - ] - - # Pad rows with missing values if they are shorter than expected - for row in adjusted_data_for_df - if length(row) < num_columns - append!(row, fill(missing, num_columns - length(row))) - end - end - - # Transpose the data to match DataFrame constructor requirements - data_transposed = permutedims(hcat(adjusted_data_for_df...)) - - # Create the DataFrame - df = DataFrame(data_transposed, Symbol.(filtered_column_names)) - parse_athena_df(df, column_types) - # Return the DataFrame - return df -end - -@service Athena - -function get_table_metadata_athena(AWS_GLOBAL_CONFIG, table_name::String, athena_params) - schema, table = split(table_name, '.') # Ensure this correctly parses your input - query = """SELECT * FROM $schema.$table limit 0;""" - # println(query) - # try - exe_query = Athena.start_query_execution(query, athena_params; aws_config = AWS_GLOBAL_CONFIG) - - # Poll Athena to check if the query has completed - status = "RUNNING" - while status in ["RUNNING", "QUEUED"] - sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API - query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG) - status = query_status["QueryExecution"]["Status"]["State"] - if status == "FAILED" - error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"]) - elseif status == "CANCELLED" - error("Query was cancelled.") - end - end - - # Fetch the results once the query completes - result = Athena.get_query_results(exe_query["QueryExecutionId"], athena_params; aws_config = AWS_GLOBAL_CONFIG) - - column_names = [col["Label"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] - column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] - df = DataFrame(name = column_names, type = column_types) - df[!, :current_selxn] .= 1 - df[!, :table_name] .= table_name - - return select(df, 1 => :name, 2 => :type, :current_selxn, :table_name) -end diff --git a/src/parsing_gbq.jl b/src/parsing_gbq.jl index fd64a4c..f570f3b 100644 --- a/src/parsing_gbq.jl +++ b/src/parsing_gbq.jl @@ -1,78 +1,4 @@ -mutable struct GBQ - projectname::String - session::GoogleSession - bigquery_resource - bigquery_method -end - -function connect(type::Symbol, json_key_path::String, project_id::String) - # Expand the user's path to the JSON key - creds_path = expanduser(json_key_path) - set_sql_mode(:gbq) - # Create credentials and session for Google Cloud - creds = JSONCredentials(creds_path) - session = GoogleSession(creds, ["https://www.googleapis.com/auth/bigquery"]) - - # Define the API method for BigQuery - bigquery_method = GoogleCloud.api.APIMethod( - :POST, - "https://bigquery.googleapis.com/bigquery/v2/projects/$(project_id)/queries", - "Run query", - Dict{Symbol, Any}(); - transform=(x, t) -> x - ) - - # Define the API resource for BigQuery - bigquery_resource = GoogleCloud.api.APIResource( - "https://bigquery.googleapis.com/bigquery/v2", - ;query=bigquery_method # Pass the method as a named argument - ) - - # Store all data in a global GBQ instance - global gbq_instance = GBQ(project_id, session, bigquery_resource, bigquery_method) - - # Return only the session - return session -end - -function collect_gbq(conn, query) - query_data = Dict( - "query" => query, - "useLegacySql" => false, - "location" => "US") - - response = GoogleCloud.api.execute( - conn, - gbq_instance.bigquery_resource, # Use the 
resource from GBQ
-        gbq_instance.bigquery_method,
-        data=query_data
-    )
-    response_string = String(response)
-    response_data = JSON3.read(response_string)
-    rows = get(response_data, "rows", [])
-
-    # Convert rows to DataFrame
-    # First, extract column names from the schema
-    column_names = [field["name"] for field in response_data["schema"]["fields"]]
-    column_types = [field["type"] for field in response_data["schema"]["fields"]]
-    # Then, convert each row's data (currently nested inside dicts with key "v") into arrays of dicts
-    if !isempty(rows)
-        # Return an empty DataFrame with the correct columns but 0 rows
-        data = [get(row["f"][i], "v", missing) for row in rows, i in 1:length(column_names)]
-        df = DataFrame(data, Symbol.(column_names))
-        df = parse_gbq_df(df, column_types)
-        return df
-    else
-        # Convert each row's data (nested inside dicts with key "v") into arrays of dicts
-        df =DataFrame([Vector{Union{Missing, Any}}(undef, 0) for _ in column_names], Symbol.(column_names))
-        df = parse_gbq_df(df, column_types)
-        return df
-    end
-
-    return df
-end
-
 function apply_type_conversion_gbq(df, col_index, col_type)
 
 if col_type == "FLOAT"
@@ -100,32 +26,6 @@ function parse_gbq_df(df, column_types)
     return df
 end
 
-function get_table_metadata(conn::GoogleSession{JSONCredentials}, table_name::String)
-    query = " SELECT * FROM
-    $table_name LIMIT 0
-    ;"
-    query_data = Dict(
-    "query" => query,
-    "useLegacySql" => false,
-    "location" => "US")
-    # Define the API resource
-
-    response = GoogleCloud.api.execute(
-        conn,
-        gbq_instance.bigquery_resource,
-        gbq_instance.bigquery_method,
-        data=query_data
-    )
-    response_string = String(response)
-    response_data = JSON3.read(response_string)
-    column_names = [field["name"] for field in response_data["schema"]["fields"]]
-    column_types = [field["type"] for field in response_data["schema"]["fields"]]
-    result = DataFrame(name = column_names, type = column_types)
-    result[!, :current_selxn] .= 1
-    result[!, :table_name] .= table_name
-
-    return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name)
-end
 
 function expr_to_sql_gbq(expr, sq; from_summarize::Bool)
     expr = parse_char_matching(expr)

From ac1571d49e5bc717af4ae14c5c5ae518cb0b630c Mon Sep 17 00:00:00 2001
From: drizk1
Date: Wed, 24 Jul 2024 15:51:09 -0400
Subject: [PATCH 2/6] fix a couple small bugs, clickhouse connect

---
 NEWS.md         |  5 +++++
 ext/AWSExt.jl   |  8 ++++++++
 ext/CHExt.jl    | 10 +++++++++-
 ext/GBQExt.jl   |  9 ++++++++-
 ext/LibPQExt.jl |  5 -----
 src/TidierDB.jl |  2 ++
 6 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/NEWS.md b/NEWS.md
index ec1cdb4..0059c81 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,10 @@
 # TidierDB.jl updates
 
+## v0.3.0 - 2024-07-25
+- Introduces package extensions for
+  - Postgres, ClickHouse, MySQL, MsSQL, SQLite, Oracle, Athena, and Google BigQuery
+  - [Documentation](https://tidierorg.github.io/TidierDB.jl/latest/examples/generated/UserGuide/getting_started/) updated for using these backends.
+ ## v0.2.4 - 2024-07-12 - Switches to DuckDB to 1.0 version - Adds support for `iceberg` tables via DuckDB to read iceberg paths in `db_table` when `iceberg = true` diff --git a/ext/AWSExt.jl b/ext/AWSExt.jl index 6bcd846..4b30486 100644 --- a/ext/AWSExt.jl +++ b/ext/AWSExt.jl @@ -99,6 +99,14 @@ function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) end result = Athena.get_query_results(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) return collect_athena(result) + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = execute_databricks(sqlquery.db, final_query) + return DataFrame(result) end end diff --git a/ext/CHExt.jl b/ext/CHExt.jl index dfd4d43..09f3f6a 100644 --- a/ext/CHExt.jl +++ b/ext/CHExt.jl @@ -9,7 +9,15 @@ function TidierDB.connect(backend::Symbol; kwargs...) if backend == :Clickhouse || backend == :clickhouse set_sql_mode(:clickhouse) if haskey(kwargs, :host) && haskey(kwargs, :port) - return ClickHouse.connect(kwargs[:host], kwargs[:port]; (k => v for (k, v) in kwargs if k ∉ [:host, :port])...) + kwargs_filtered = Dict{Symbol, Any}() + for (k, v) in kwargs + if k == :user + kwargs_filtered[:username] = v + elseif k ∉ [:host, :port] + kwargs_filtered[k] = v + end + end + return ClickHouse.connect(kwargs[:host], kwargs[:port]; kwargs_filtered...) else throw(ArgumentError("Missing required positional arguments 'host' and 'port' for ClickHouse.")) end diff --git a/ext/GBQExt.jl b/ext/GBQExt.jl index 83c7f6c..92697b7 100644 --- a/ext/GBQExt.jl +++ b/ext/GBQExt.jl @@ -114,7 +114,14 @@ function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) elseif TidierDB.current_sql_mode[] == :gbq final_query = TidierDB.finalize_query(sqlquery) return collect_gbq(sqlquery.db, final_query) - + elseif TidierDB.current_sql_mode[] == :snowflake + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) + elseif TidierDB.current_sql_mode[] == :databricks + final_query = TidierDB.finalize_query(sqlquery) + result = TidierDB.execute_databricks(sqlquery.db, final_query) + return DataFrame(result) end end diff --git a/ext/LibPQExt.jl b/ext/LibPQExt.jl index 1851920..4ea033a 100644 --- a/ext/LibPQExt.jl +++ b/ext/LibPQExt.jl @@ -42,7 +42,6 @@ function TidierDB.get_table_metadata(conn::LibPQ.Connection, table_name::String) end -# In SQLiteExt.jl function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :mariadb final_query = TidierDB.finalize_query(sqlquery) @@ -59,9 +58,5 @@ function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) end end -# In DuckDBExt.jl - - - end diff --git a/src/TidierDB.jl b/src/TidierDB.jl index 7a5a160..d658586 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -76,6 +76,8 @@ function expr_to_sql(expr, sq; from_summarize::Bool = false) return expr_to_sql_oracle(expr, sq; from_summarize=from_summarize) elseif current_sql_mode[] == :snowflake return expr_to_sql_snowflake(expr, sq; from_summarize=from_summarize) + elseif current_sql_mode[] == :databricks + return 
expr_to_sql_duckdb(expr, sq; from_summarize=from_summarize) else error("Unsupported SQL mode: $(current_sql_mode[])") end From a0c7d3e885b1812aade00f91a08aec1aa933008f Mon Sep 17 00:00:00 2001 From: drizk1 Date: Wed, 24 Jul 2024 17:20:27 -0400 Subject: [PATCH 3/6] couple other small fixes --- ext/AWSExt.jl | 6 +++--- ext/GBQExt.jl | 3 +-- src/parsing_snowflake.jl | 1 + 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/AWSExt.jl b/ext/AWSExt.jl index 4b30486..4bcf370 100644 --- a/ext/AWSExt.jl +++ b/ext/AWSExt.jl @@ -72,7 +72,7 @@ function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String, athe column_types = [col["Type"] for col in result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]] df = DataFrame(name = column_names, type = column_types) df[!, :current_selxn] .= 1 - df[!, :table_name] .= table_name + df[!, :table_name] .= split(table_name, ".")[2] return select(df, 1 => :name, 2 => :type, :current_selxn, :table_name) end @@ -101,11 +101,11 @@ function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) return collect_athena(result) elseif TidierDB.current_sql_mode[] == :snowflake final_query = TidierDB.finalize_query(sqlquery) - result = execute_snowflake(sqlquery.db, final_query) + result = TidierDB.execute_snowflake(sqlquery.db, final_query) return DataFrame(result) elseif TidierDB.current_sql_mode[] == :databricks final_query = TidierDB.finalize_query(sqlquery) - result = execute_databricks(sqlquery.db, final_query) + result = TidierDB.execute_databricks(sqlquery.db, final_query) return DataFrame(result) end diff --git a/ext/GBQExt.jl b/ext/GBQExt.jl index 92697b7..f3788fd 100644 --- a/ext/GBQExt.jl +++ b/ext/GBQExt.jl @@ -100,8 +100,7 @@ function TidierDB.get_table_metadata(conn::GoogleSession{JSONCredentials}, table column_types = [field["type"] for field in response_data["schema"]["fields"]] result = DataFrame(name = column_names, type = column_types) result[!, :current_selxn] .= 1 - result[!, :table_name] .= table_name - + result[!, :table_name] .= split(table_name, ".")[2] return select(result, 1 => :name, 2 => :type, :current_selxn, :table_name) end diff --git a/src/parsing_snowflake.jl b/src/parsing_snowflake.jl index 4e6485b..87419fe 100644 --- a/src/parsing_snowflake.jl +++ b/src/parsing_snowflake.jl @@ -238,6 +238,7 @@ function execute_snowflake(conn::SnowflakeConnection, sql_query::String) end function get_table_metadata(conn::SnowflakeConnection, table_name::String) + table_name = uppercase(table_name) query = """ SELECT COLUMN_NAME, DATA_TYPE FROM $(conn.database).INFORMATION_SCHEMA.COLUMNS From 10391de79e5d90716a3e1429bf51ef79d1dabb01 Mon Sep 17 00:00:00 2001 From: drizk1 Date: Thu, 25 Jul 2024 09:58:11 -0400 Subject: [PATCH 4/6] change api to types from symbols --- README.md | 36 +++---- docs/examples/UserGuide/Snowflake.jl | 1 + docs/examples/UserGuide/athena.jl | 2 +- docs/examples/UserGuide/databricks.jl | 25 ++--- docs/examples/UserGuide/from_queryex.jl | 6 +- docs/examples/UserGuide/getting_started.jl | 4 +- docs/examples/UserGuide/key_differences.jl | 2 +- docs/examples/UserGuide/s3viaduckdb.jl | 4 +- docs/src/index.md | 36 +++---- ext/AWSExt.jl | 43 +++----- ext/CHExt.jl | 45 ++------ ext/GBQExt.jl | 25 ++--- ext/LibPQExt.jl | 37 +------ ext/MySQLExt.jl | 39 ++----- ext/ODBCExt.jl | 44 ++------ ext/SQLiteExt.jl | 40 ++----- src/TBD_macros.jl | 64 +++++++---- src/TidierDB.jl | 120 +++++++++++---------- src/db_parsing.jl | 8 +- src/docstrings.jl | 74 ++++++------- src/joins_sq.jl | 2 +- src/structs.jl | 6 -- 22 files changed, 263 
insertions(+), 400 deletions(-)

diff --git a/README.md b/README.md
index e8d91c0..582066f 100644
--- a/README.md
+++ b/README.md
@@ -14,19 +14,19 @@ The main goal of TidierDB.jl is to bring the syntax of Tidier.jl to multiple SQL
 
 ## Currently supported backends include:
 
-- DuckDB (the default) `set_sql_mode(:duckdb)`
-- ClickHouse `set_sql_mode(:clickhouse)`
-- SQLite `set_sql_mode(:lite)`
-- MySQL and MariaDB `set_sql_mode(:mysql)`
-- MSSQL `set_sql_mode(:mssql)`
-- Postgres `set_sql_mode(:postgres)`
-- Athena `set_sql_mode(:athena)`
-- Snowflake `set_sql_mode(:snowflake)`
-- Google Big Query `set_sql_mode(:gbq)`
-- Oracle `set_sql_mode(:oracle)`
-- Databricks `set_sql_mode(:databricks)`
-
-The style of SQL that is generated can be modified using `set_sql_mode()`.
+- DuckDB (the default) `duckdb()`
+- ClickHouse `clickhouse()`
+- SQLite `sqlite()`
+- MySQL and MariaDB `mysql()`
+- MSSQL `mssql()`
+- Postgres `postgres()`
+- Athena `athena()`
+- Snowflake `snowflake()`
+- Google BigQuery `gbq()`
+- Oracle `oracle()`
+- Databricks `databricks()`
+
+Change the backend using `set_sql_mode()`; for example, `set_sql_mode(databricks())`.
 
 ## Installation
@@ -96,9 +96,9 @@ using TidierData
 import TidierDB as DB
 
 db = DB.connect(:duckdb);
-path = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"
+path_or_name = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"
 
-@chain DB.db_table(db, path) begin
+@chain DB.db_table(db, path_or_name) begin
     DB.@filter(!starts_with(model, "M"))
     DB.@group_by(cyl)
     DB.@summarize(mpg = mean(mpg))
@@ -128,7 +128,7 @@ end
 We cannot do this using TidierDB. However, we can call `@pivot_longer()` from TidierData *after* the result of the query has been instantiated as a DataFrame, like this:
 
 ```julia
-@chain DB.db_table(db, :mtcars) begin
+@chain DB.db_table(db, path_or_name) begin
     DB.@filter(!starts_with(model, "M"))
     DB.@group_by(cyl)
     DB.@summarize(mpg = mean(mpg))
@@ -167,7 +167,7 @@ end
 We can replace `DB.collect()` with `DB.@show_query` to reveal the underlying SQL query being generated by TidierDB. To handle complex queries, TidierDB makes heavy use of Common Table Expressions (CTE), which are a useful tool to organize long queries.
 
 ```julia
-@chain DB.db_table(db, :mtcars) begin
+@chain DB.db_table(db, path_or_name) begin
     DB.@filter(!starts_with(model, "M"))
     DB.@group_by(cyl)
     DB.@summarize(mpg = mean(mpg))
@@ -207,7 +207,7 @@ SELECT *
 
 ## TidierDB is already quite fully-featured, supporting advanced TidierData functions like `across()` for multi-column selection.
 
 ```julia
-@chain DB.db_table(db, :mtcars) begin
+@chain DB.db_table(db, path_or_name) begin
     DB.@group_by(cyl)
     DB.@summarize(across((starts_with("a"), ends_with("s")), (mean, sum)))
     DB.@collect
diff --git a/docs/examples/UserGuide/Snowflake.jl b/docs/examples/UserGuide/Snowflake.jl
index 760846c..0b6a15a 100644
--- a/docs/examples/UserGuide/Snowflake.jl
+++ b/docs/examples/UserGuide/Snowflake.jl
@@ -15,6 +15,7 @@
 # - Allow you to build a SQL query and `@show_query` even if the OAuth_token has expired. To `@collect`, you will have to reconnect and rerun `db_table` if your OAuth token has expired.
 
 # ```julia
+# set_sql_mode(snowflake())
 # ac_id = "string_id"
 # token = "OAuth_token_string"
 # con = connect(:snowflake, ac_id, token, "DEMODB", "PUBLIC", "COMPUTE_WH")
diff --git a/docs/examples/UserGuide/athena.jl b/docs/examples/UserGuide/athena.jl
index daf6d8b..03c65a8 100644
--- a/docs/examples/UserGuide/athena.jl
+++ b/docs/examples/UserGuide/athena.jl
@@ -5,7 +5,7 @@
 
 # ```julia
 # using TidierDB, AWS
-# set_sql_mode(:athena)
+# set_sql_mode(athena())
 # # Replace your credentials as needed below
 # aws_access_key_id = get(ENV,"AWS_ACCESS_KEY_ID","key")
 # aws_secret_access_key = get(ENV, "AWS_SECRET_ACCESS_KEY","secret_key")
diff --git a/docs/examples/UserGuide/databricks.jl b/docs/examples/UserGuide/databricks.jl
index 0c7f8e7..be861db 100644
--- a/docs/examples/UserGuide/databricks.jl
+++ b/docs/examples/UserGuide/databricks.jl
@@ -12,6 +12,7 @@
 
 # Since each time `db_table` runs, it runs a query to pull the metadata, you may choose to run `db_table` once, save the results, and use these results with `from_query()`. This will reduce the number of queries to your database and is illustrated below.
 
 # ```julia
+# set_sql_mode(databricks())
 # instance_id = "string_id"
 # token = "string_token"
 # warehouse_id = "e673cd4f387f964a"
@@ -26,18 +27,18 @@
 # end
 # ```
 # ```
-# 32×2 DataFrame
+# 32×2 DataFrame
 #  Row │ wt       test
-#      │ Float64  Float64
+#      │ Float64   Float64
 # ─────┼──────────────────
-#    1 │ 2.62     5.24
-#    2 │ 2.875    5.75
-#    3 │ 2.32     4.64
-#    4 │ 3.215    6.43
-#    ⋮  │   ⋮        ⋮
-#   29 │ 3.17     6.34
-#   30 │ 2.77     5.54
-#   31 │ 3.57     7.14
-#   32 │ 2.78     5.56
-# 24 rows omitted
+#    1 │   2.62      5.24
+#    2 │   2.875     5.75
+#    3 │   2.32      4.64
+#    4 │   3.215     6.43
+#    ⋮  │    ⋮         ⋮
+#   29 │   3.17      6.34
+#   30 │   2.77      5.54
+#   31 │   3.57      7.14
+#   32 │   2.78      5.56
+#           24 rows omitted
 # ```
\ No newline at end of file
diff --git a/docs/examples/UserGuide/from_queryex.jl b/docs/examples/UserGuide/from_queryex.jl
index 7a68d89..785233a 100644
--- a/docs/examples/UserGuide/from_queryex.jl
+++ b/docs/examples/UserGuide/from_queryex.jl
@@ -2,13 +2,13 @@
 
 # ```julia
 # import TidierDB as DB
-# con = DB.connect(:duckdb)
-# DB.copy_to(con, "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv", "mtcars2")
+# con = DB.connect(duckdb())
+# mtcars_path = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"
 # ```
 
 # Start a query to analyze fuel efficiency by number of cylinders.
However, to further build on this query later, end the chain without using `@show_query` or `@collect`. # ```julia -# query = DB.@chain DB.db_table(con, :mtcars2) begin +# query = DB.@chain DB.db_table(con, mtcars_path) begin # DB.@group_by cyl # DB.@summarize begin # across(mpg, (mean, minimum, maximum)) diff --git a/docs/examples/UserGuide/getting_started.jl b/docs/examples/UserGuide/getting_started.jl index 37a3fd8..5397feb 100644 --- a/docs/examples/UserGuide/getting_started.jl +++ b/docs/examples/UserGuide/getting_started.jl @@ -14,11 +14,11 @@ # For example # Connecting to MySQL # ```julia -# conn = connect(:mysql; host="localhost", user="root", password="password", db="mydb") +# conn = connect(mysql(); host="localhost", user="root", password="password", db="mydb") # ``` # versus connecting to DuckDB # ```julia -# conn = connect(:duckdb) +# conn = connect(duckdb()) # ``` # ## Package Extensions diff --git a/docs/examples/UserGuide/key_differences.jl b/docs/examples/UserGuide/key_differences.jl index ece29ee..ae50ca3 100644 --- a/docs/examples/UserGuide/key_differences.jl +++ b/docs/examples/UserGuide/key_differences.jl @@ -11,7 +11,7 @@ df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -db = connect(:duckdb); +db = connect(duckdb()); copy_to(db, df, "df_mem"); # copying over the data frame to an in-memory database diff --git a/docs/examples/UserGuide/s3viaduckdb.jl b/docs/examples/UserGuide/s3viaduckdb.jl index 1c36309..85daed5 100644 --- a/docs/examples/UserGuide/s3viaduckdb.jl +++ b/docs/examples/UserGuide/s3viaduckdb.jl @@ -8,10 +8,10 @@ # Using TidierDB # # #Connect to Google Cloud via DuckDB -# #google_db = connect(:duckdb, :gbq, access_key="string", secret_key="string") +# #google_db = connect(duckdb(), :gbq, access_key="string", secret_key="string") # #Connect to AWS via DuckDB -# aws_db = connect(:duckdb, :aws, aws_access_key_id= "string", +# aws_db = connect(duckdb(), :aws, aws_access_key_id= "string", # aws_secret_access_key= "string", # aws_region="us-east-1") # s3_csv_path = "s3://path/to_data.csv" diff --git a/docs/src/index.md b/docs/src/index.md index 38080dc..54c4ad4 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,19 +8,19 @@ The main goal of TidierDB.jl is to bring the syntax of Tidier.jl to multiple SQL ## Currently supported backends include: -- DuckDB (the default) `set_sql_mode(:duckdb)` -- ClickHouse `set_sql_mode(:clickhouse)` -- SQLite `set_sql_mode(:lite)` -- MySQL and MariaDB `set_sql_mode(:mysql)` -- MSSQL `set_sql_mode(:mssql)` -- Postgres `set_sql_mode(:postgres)` -- Athena `set_sql_mode(:athena)` -- Snowflake `set_sql_mode(:snowflake)` -- Google Big Query `set_sql_mode(:gbq)` -- Oracle `set_sql_mode(:oracle)` -- Databricks `set_sql_mode(:databricks)` - -The style of SQL that is generated can be modified using `set_sql_mode()`. 
+- DuckDB (the default) `duckdb()` +- ClickHouse `clickhouse()` +- SQLite `sqlite()` +- MySQL and MariaDB `mysql()` +- MSSQL `mssql()` +- Postgres `postgres()` +- Athena `athena()` +- Snowflake `snowflake()` +- Google Big Query `gbq()` +- Oracle `oracle()` +- Databricks `databricks()` + +Change the backend using `set_sql_mode()`, for example `set_sql_mode(databricks())`. ## Installation @@ -90,9 +90,9 @@ using TidierData import TidierDB as DB db = DB.connect(:duckdb); -path = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv" +path_or_name = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv" -@chain DB.db_table(db, path) begin +@chain DB.db_table(db, path_or_name) begin DB.@filter(!starts_with(model, "M")) DB.@group_by(cyl) DB.@summarize(mpg = mean(mpg)) @@ -122,7 +122,7 @@ end We cannot do this using TidierDB. However, we can call `@pivot_longer()` from TidierData *after* the result of the query has been instantiated as a DataFrame, like this: ```julia -@chain DB.db_table(db, :mtcars) begin +@chain DB.db_table(db, path_or_name) begin DB.@filter(!starts_with(model, "M")) DB.@group_by(cyl) DB.@summarize(mpg = mean(mpg)) @@ -161,7 +161,7 @@ end We can replace `DB.collect()` with `DB.@show_query` to reveal the underlying SQL query being generated by TidierDB. To handle complex queries, TidierDB makes heavy use of Common Table Expressions (CTE), which are a useful tool to organize long queries. ```julia -@chain DB.db_table(db, :mtcars) begin +@chain DB.db_table(db, path_or_name) begin DB.@filter(!starts_with(model, "M")) DB.@group_by(cyl) DB.@summarize(mpg = mean(mpg)) @@ -201,7 +201,7 @@ SELECT * ## TidierDB is already quite fully-featured, supporting advanced TidierData functions like `across()` for multi-column selection. 
```julia -@chain DB.db_table(db, :mtcars) begin +@chain DB.db_table(db, path_or_name) begin DB.@group_by(cyl) DB.@summarize(across((starts_with("a"), ends_with("s")), (mean, sum))) DB.@collect diff --git a/ext/AWSExt.jl b/ext/AWSExt.jl index 4bcf370..6947f2c 100644 --- a/ext/AWSExt.jl +++ b/ext/AWSExt.jl @@ -78,37 +78,22 @@ function TidierDB.get_table_metadata(AWS_GLOBAL_CONFIG, table_name::String, athe end -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :athena - final_query = TidierDB.finalize_query(sqlquery) - exe_query = Athena.start_query_execution(final_query, sqlquery.athena_params; aws_config = sqlquery.db) - status = "RUNNING" - while status in ["RUNNING", "QUEUED"] - sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API - query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) - status = query_status["QueryExecution"]["Status"]["State"] - if status == "FAILED" - error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"]) - elseif status == "CANCELLED" - error("Query was cancelled.") - end +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:athena}) + final_query = TidierDB.finalize_query(sqlquery) + exe_query = Athena.start_query_execution(final_query, sqlquery.athena_params; aws_config = sqlquery.db) + status = "RUNNING" + while status in ["RUNNING", "QUEUED"] + sleep(1) # Wait for 1 second before checking the status again to avoid flooding the API + query_status = Athena.get_query_execution(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) + status = query_status["QueryExecution"]["Status"]["State"] + if status == "FAILED" + error("Query failed: ", query_status["QueryExecution"]["Status"]["StateChangeReason"]) + elseif status == "CANCELLED" + error("Query was cancelled.") end - result = Athena.get_query_results(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) - return collect_athena(result) - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) end - + result = Athena.get_query_results(exe_query["QueryExecutionId"], sqlquery.athena_params; aws_config = sqlquery.db) + return collect_athena(result) end diff --git a/ext/CHExt.jl b/ext/CHExt.jl index 09f3f6a..decdc5a 100644 --- a/ext/CHExt.jl +++ b/ext/CHExt.jl @@ -5,9 +5,8 @@ using DataFrames import ClickHouse __init__() = println("Extension was loaded!") -function TidierDB.connect(backend::Symbol; kwargs...) - if backend == :Clickhouse || backend == :clickhouse - set_sql_mode(:clickhouse) +function TidierDB.connect(::clickhouse; kwargs...) + set_sql_mode(clickhouse()) if haskey(kwargs, :host) && haskey(kwargs, :port) kwargs_filtered = Dict{Symbol, Any}() for (k, v) in kwargs @@ -21,20 +20,6 @@ function TidierDB.connect(backend::Symbol; kwargs...) 
else throw(ArgumentError("Missing required positional arguments 'host' and 'port' for ClickHouse.")) end - - elseif backend == :DuckDB || backend == :duckdb - set_sql_mode(:duckdb) - db = DBInterface.connect(DuckDB.DB, ":memory:") - DBInterface.execute(db, "SET autoinstall_known_extensions=1;") - DBInterface.execute(db, "SET autoload_known_extensions=1;") - - # Install and load the httpfs extension - DBInterface.execute(db, "INSTALL httpfs;") - DBInterface.execute(db, "LOAD httpfs;") - return db - else - throw(ArgumentError("Unsupported backend: $backend")) - end end @@ -58,26 +43,12 @@ end -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :clickhouse - final_query = TidierDB.finalize_query(sqlquery) - df_result = ClickHouse.select_df(sqlquery.db, final_query) - selected_columns_order = sqlquery.metadata[sqlquery.metadata.current_selxn .== 1, :name] - df_result = df_result[:, selected_columns_order] - return df_result - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end +function TidierDB.final_collect(sqlquery, ::Type{<:clickhouse}) + final_query = TidierDB.finalize_query(sqlquery) + df_result = ClickHouse.select_df(sqlquery.db, final_query) + selected_columns_order = sqlquery.metadata[sqlquery.metadata.current_selxn .== 1, :name] + df_result = df_result[:, selected_columns_order] + return df_result end end diff --git a/ext/GBQExt.jl b/ext/GBQExt.jl index f3788fd..56d79c9 100644 --- a/ext/GBQExt.jl +++ b/ext/GBQExt.jl @@ -12,10 +12,10 @@ mutable struct GBQ bigquery_method end -function TidierDB.connect(type::Symbol, json_key_path::String, project_id::String) +function TidierDB.connect(::gbq, json_key_path::String, project_id::String) # Expand the user's path to the JSON key creds_path = expanduser(json_key_path) - set_sql_mode(:gbq) + set_sql_mode(gbq()) # Create credentials and session for Google Cloud creds = JSONCredentials(creds_path) session = GoogleSession(creds, ["https://www.googleapis.com/auth/bigquery"]) @@ -105,23 +105,10 @@ function TidierDB.get_table_metadata(conn::GoogleSession{JSONCredentials}, table end -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :gbq - final_query = TidierDB.finalize_query(sqlquery) - return collect_gbq(sqlquery.db, final_query) - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - 
result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end + +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:gbq}) + final_query = TidierDB.finalize_query(sqlquery) + return collect_gbq(sqlquery.db, final_query) end end diff --git a/ext/LibPQExt.jl b/ext/LibPQExt.jl index 4ea033a..6c597d9 100644 --- a/ext/LibPQExt.jl +++ b/ext/LibPQExt.jl @@ -5,25 +5,9 @@ using DataFrames using LibPQ __init__() = println("Extension was loaded!") -function TidierDB.connect(backend::Symbol; kwargs...) - if backend == :Postgres || backend == :postgres - set_sql_mode(:postgres) - # Construct a connection string from kwargs for LibPQ +function TidierDB.connect(::postgres; kwargs...) conn_str = join(["$(k)=$(v)" for (k, v) in kwargs], " ") return LibPQ.Connection(conn_str) - elseif backend == :DuckDB || backend == :duckdb - set_sql_mode(:duckdb) - db = DBInterface.connect(DuckDB.DB, ":memory:") - DBInterface.execute(db, "SET autoinstall_known_extensions=1;") - DBInterface.execute(db, "SET autoload_known_extensions=1;") - - # Install and load the httpfs extension - DBInterface.execute(db, "INSTALL httpfs;") - DBInterface.execute(db, "LOAD httpfs;") - return db - else - throw(ArgumentError("Unsupported backend: $backend")) - end end @@ -42,21 +26,10 @@ function TidierDB.get_table_metadata(conn::LibPQ.Connection, table_name::String) end -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :mariadb - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:postgres}) + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) end - end diff --git a/ext/MySQLExt.jl b/ext/MySQLExt.jl index 2699260..53e899c 100644 --- a/ext/MySQLExt.jl +++ b/ext/MySQLExt.jl @@ -5,10 +5,8 @@ using DataFrames using MySQL __init__() = println("Extension was loaded!") -function TidierDB.connect(backend::Symbol; kwargs...) - if backend == :MySQL || backend == :mysql - set_sql_mode(:mysql) - +function TidierDB.connect(::mysql; kwargs...) + set_sql_mode(mysql()) # Required parameters by MySQL.jl: host and user host = get(kwargs, :host, "localhost") user = get(kwargs, :user, "") @@ -17,20 +15,6 @@ function TidierDB.connect(backend::Symbol; kwargs...) 
db = get(kwargs, :db, nothing) port = get(kwargs, :port, nothing) return DBInterface.connect(MySQL.Connection, host, user, password; db=db, port=port) - - elseif backend == :DuckDB || backend == :duckdb - set_sql_mode(:duckdb) - db = DBInterface.connect(DuckDB.DB, ":memory:") - DBInterface.execute(db, "SET autoinstall_known_extensions=1;") - DBInterface.execute(db, "SET autoload_known_extensions=1;") - - # Install and load the httpfs extension - DBInterface.execute(db, "INSTALL httpfs;") - DBInterface.execute(db, "LOAD httpfs;") - return db - else - throw(ArgumentError("Unsupported backend: $backend")) - end end @@ -54,22 +38,13 @@ function TidierDB.get_table_metadata(conn::MySQL.Connection, table_name::String) end -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :mariadb - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:mysql}) + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) end + end diff --git a/ext/ODBCExt.jl b/ext/ODBCExt.jl index ca3ed53..3477b91 100644 --- a/ext/ODBCExt.jl +++ b/ext/ODBCExt.jl @@ -5,27 +5,6 @@ using DataFrames using ODBC __init__() = println("Extension was loaded!") -function TidierDB.connect(backend::Symbol; kwargs...) 
- if backend == :SQLite || backend == :lite - db_path = get(kwargs, :db, ":memory:") - set_sql_mode(:lite) - return SQLite.DB(db_path) - elseif backend == :DuckDB || backend == :duckdb - set_sql_mode(:duckdb) - db = DBInterface.connect(DuckDB.DB, ":memory:") - DBInterface.execute(db, "SET autoinstall_known_extensions=1;") - DBInterface.execute(db, "SET autoload_known_extensions=1;") - - # Install and load the httpfs extension - DBInterface.execute(db, "INSTALL httpfs;") - DBInterface.execute(db, "LOAD httpfs;") - return db - else - throw(ArgumentError("Unsupported backend: $backend")) - end -end - - # MSSQL function TidierDB.get_table_metadata(conn::ODBC.Connection, table_name::String) @@ -54,21 +33,16 @@ function TidierDB.get_table_metadata(conn::ODBC.Connection, table_name::String) end -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql || TidierDB.current_sql_mode[] == :oracle - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:mssql}) + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) end +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:oracle}) + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) +end end diff --git a/ext/SQLiteExt.jl b/ext/SQLiteExt.jl index 2486947..1bbe699 100644 --- a/ext/SQLiteExt.jl +++ b/ext/SQLiteExt.jl @@ -5,24 +5,10 @@ using DataFrames using SQLite __init__() = println("Extension was loaded!") -function TidierDB.connect(backend::Symbol; kwargs...) - if backend == :SQLite || backend == :lite +function TidierDB.connect(::sqlite; kwargs...) 
db_path = get(kwargs, :db, ":memory:") - set_sql_mode(:lite) + set_sql_mode(sqlite()) return SQLite.DB(db_path) - elseif backend == :DuckDB || backend == :duckdb - set_sql_mode(:duckdb) - db = DBInterface.connect(DuckDB.DB, ":memory:") - DBInterface.execute(db, "SET autoinstall_known_extensions=1;") - DBInterface.execute(db, "SET autoload_known_extensions=1;") - - # Install and load the httpfs extension - DBInterface.execute(db, "INSTALL httpfs;") - DBInterface.execute(db, "LOAD httpfs;") - return db - else - throw(ArgumentError("Unsupported backend: $backend")) - end end @@ -43,25 +29,11 @@ end # In SQLiteExt.jl -function TidierDB.final_collect(sqlquery::TidierDB.SQLQuery) - if TidierDB.current_sql_mode[] == :duckdb || TidierDB.current_sql_mode[] == :lite || TidierDB.current_sql_mode[] == :postgres || TidierDB.current_sql_mode[] == :mysql || TidierDB.current_sql_mode[] == :mssql - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif TidierDB.current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = TidierDB.execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end +function TidierDB.final_collect(sqlquery::SQLQuery, ::Type{<:sqlite}) + final_query = TidierDB.finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) end -# In DuckDBExt.jl - - - end diff --git a/src/TBD_macros.jl b/src/TBD_macros.jl index b191895..df9cf2a 100644 --- a/src/TBD_macros.jl +++ b/src/TBD_macros.jl @@ -158,7 +158,7 @@ end function process_mutate_expression(expr, sq, select_expressions, cte_name) if isa(expr, Expr) && expr.head == :(=) && isa(expr.args[1], Symbol) col_name = string(expr.args[1]) - if current_sql_mode[] == :snowflake + if current_sql_mode[] == snowflake() col_name = uppercase(col_name) end col_expr = expr_to_sql(expr.args[2], sq) # Convert to SQL expression @@ -373,7 +373,7 @@ function process_summary_expression(expr, sq, summary_str) summary_operation = string(summary_operation) summary_column = expr_to_sql(expr.args[1], sq, from_summarize = true) summary_column = string(summary_column) - if current_sql_mode[] == :snowflake + if current_sql_mode[] == snowflake() summary_column = uppercase(summary_column) end push!(sq.metadata, Dict("name" => summary_column, "type" => "UNKNOWN", "current_selxn" => 1, "table_name" => sq.from)) @@ -653,25 +653,53 @@ macro show_query(sqlquery) end end -function final_collect(sqlquery::TidierDB.SQLQuery) - if current_sql_mode[] ==:duckdb - final_query = TidierDB.finalize_query(sqlquery) - result = DBInterface.execute(sqlquery.db, final_query) - return DataFrame(result) - elseif current_sql_mode[] == :snowflake - final_query = TidierDB.finalize_query(sqlquery) - result = execute_snowflake(sqlquery.db, final_query) - return DataFrame(result) - elseif current_sql_mode[] == :databricks - final_query = TidierDB.finalize_query(sqlquery) - result = execute_databricks(sqlquery.db, final_query) - return DataFrame(result) - end + + +function final_collect(sqlquery::SQLQuery, ::Type{<:duckdb}) + final_query = finalize_query(sqlquery) + result = DBInterface.execute(sqlquery.db, final_query) + return DataFrame(result) end +function final_collect(sqlquery::SQLQuery, ::Type{<:databricks}) + final_query 
= finalize_query(sqlquery) + result = execute_databricks(sqlquery.db, final_query) + return DataFrame(result) +end + +function final_collect(sqlquery::SQLQuery, ::Type{<:snowflake}) + final_query = finalize_query(sqlquery) + result = execute_snowflake(sqlquery.db, final_query) + return DataFrame(result) +end macro collect(sqlquery) return quote - final_collect($(esc(sqlquery))) + backend = current_sql_mode[] + if backend == duckdb() + final_collect($(esc(sqlquery)), duckdb) + elseif backend == clickhouse() + final_collect($(esc(sqlquery)), clickhouse) + elseif backend == sqlite() + final_collect($(esc(sqlquery)), sqlite) + elseif backend == mysql() + final_collect($(esc(sqlquery)), mysql) + elseif backend == mssql() + final_collect($(esc(sqlquery)), mssql) + elseif backend == postgres() + final_collect($(esc(sqlquery)), postgres) + elseif backend == athena() + final_collect($(esc(sqlquery)), athena) + elseif backend == snowflake() + final_collect($(esc(sqlquery)), snowflake) + elseif backend == gbq() + final_collect($(esc(sqlquery)), gbq) + elseif backend == oracle() + final_collect($(esc(sqlquery)), oracle) + elseif backend == databricks() + final_collect($(esc(sqlquery)), databricks) + else + throw(ArgumentError("Unsupported SQL mode: $backend")) + end end -end \ No newline at end of file +end diff --git a/src/TidierDB.jl b/src/TidierDB.jl index d658586..3c0fc1b 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -1,18 +1,11 @@ module TidierDB -#using LibPQ using DataFrames using MacroTools using Chain -#using SQLite using Reexport using DuckDB -#using MySQL -#using ODBC -#import ClickHouse using Arrow -#using AWS -#using GoogleCloud using HTTP using JSON3 using GZip @@ -21,12 +14,33 @@ using GZip @reexport using Chain @reexport using DuckDB -#using TidierDB export db_table, set_sql_mode, @arrange, @group_by, @filter, @select, @mutate, @summarize, @summarise, @distinct, @left_join, @right_join, @inner_join, @count, @window_order, @window_frame, @show_query, @collect, @slice_max, @slice_min, @slice_sample, @rename, copy_to, duckdb_open, duckdb_connect, @semi_join, @full_join, - @anti_join, connect, from_query, @interpolate, add_interp_parameter!, update_con + @anti_join, connect, from_query, @interpolate, add_interp_parameter!, update_con, + clickhouse, duckdb, sqlite, mysql, mssql, postgres, athena, snowflake, gbq, oracle, databricks, SQLQuery + + abstract type SQLBackend end + + struct clickhouse <: SQLBackend end + struct duckdb <: SQLBackend end + struct sqlite <: SQLBackend end + struct mysql <: SQLBackend end + struct mssql <: SQLBackend end + struct postgres <: SQLBackend end + struct athena <: SQLBackend end + struct snowflake <: SQLBackend end + struct gbq <: SQLBackend end + struct oracle <: SQLBackend end + struct databricks <: SQLBackend end + + current_sql_mode = Ref{SQLBackend}(duckdb()) + + function set_sql_mode(mode::SQLBackend) + current_sql_mode[] = mode + end + include("docstrings.jl") include("structs.jl") @@ -47,36 +61,31 @@ include("joins_sq.jl") include("slices_sq.jl") -current_sql_mode = Ref(:duckdb) -# Function to switch modes -function set_sql_mode(mode::Symbol) - current_sql_mode[] = mode -end # Unified expr_to_sql function to use right mode function expr_to_sql(expr, sq; from_summarize::Bool = false) - if current_sql_mode[] == :lite + if current_sql_mode[] == sqlite() return expr_to_sql_lite(expr, sq, from_summarize=from_summarize) - elseif current_sql_mode[] == :postgres + elseif current_sql_mode[] == postgres() return expr_to_sql_postgres(expr, sq; 
from_summarize=from_summarize) - elseif current_sql_mode[] == :duckdb + elseif current_sql_mode[] == duckdb() return expr_to_sql_duckdb(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :mysql + elseif current_sql_mode[] == mysql() return expr_to_sql_mysql(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :mssql + elseif current_sql_mode[] == mssql() return expr_to_sql_mssql(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :clickhouse + elseif current_sql_mode[] == clickhouse() return expr_to_sql_clickhouse(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :athena + elseif current_sql_mode[] == athena() return expr_to_sql_trino(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :gbq + elseif current_sql_mode[] == gbq() return expr_to_sql_gbq(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :oracle + elseif current_sql_mode[] == oracle() return expr_to_sql_oracle(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :snowflake + elseif current_sql_mode[] == snowflake() return expr_to_sql_snowflake(expr, sq; from_summarize=from_summarize) - elseif current_sql_mode[] == :databricks + elseif current_sql_mode[] == databricks() return expr_to_sql_duckdb(expr, sq; from_summarize=from_summarize) else error("Unsupported SQL mode: $(current_sql_mode[])") @@ -141,7 +150,7 @@ function finalize_query(sqlquery::SQLQuery) "FROM )" => ")" , "SELECT SELECT " => "SELECT ", "SELECT SELECT " => "SELECT ", "DISTINCT SELECT " => "DISTINCT ", "SELECT SELECT SELECT " => "SELECT ", "PARTITION BY GROUP BY" => "PARTITION BY", "GROUP BY GROUP BY" => "GROUP BY", "HAVING HAVING" => "HAVING", ) - if current_sql_mode[] == :postgres || current_sql_mode[] == :duckdb || current_sql_mode[] == :mysql || current_sql_mode[] == :mssql || current_sql_mode[] == :clickhouse || current_sql_mode[] == :athena || current_sql_mode[] == :gbq || current_sql_mode[] == :oracle || current_sql_mode[] == :snowflake + if current_sql_mode[] == postgres() || current_sql_mode[] == duckdb() || current_sql_mode[] == mysql() || current_sql_mode[] == mssql() || current_sql_mode[] == clickhouse() || current_sql_mode[] == athena() || current_sql_mode[] == gbq() || current_sql_mode[] == oracle() || current_sql_mode[] == snowflake() || current_sql_mode[] == databricks() complete_query = replace(complete_query, "\"" => "'", "==" => "=") end @@ -184,9 +193,9 @@ $docstring_db_table function db_table(db, table, athena_params::Any=nothing; iceberg::Bool=false, delta::Bool=false) table_name = string(table) - if current_sql_mode[] == :lite + if current_sql_mode[] == sqlite() metadata = get_table_metadata(db, table_name) - elseif current_sql_mode[] == :postgres ||current_sql_mode[] == :duckdb || current_sql_mode[] == :mysql || current_sql_mode[] == :mssql || current_sql_mode[] == :clickhouse || current_sql_mode[] == :gbq ||current_sql_mode[] == :oracle + elseif current_sql_mode[] == postgres() ||current_sql_mode[] == duckdb() || current_sql_mode[] == mysql() || current_sql_mode[] == mssql() || current_sql_mode[] == clickhouse() || current_sql_mode[] == gbq() ||current_sql_mode[] == oracle() if iceberg DBInterface.execute(db, "INSTALL iceberg;") DBInterface.execute(db, "LOAD iceberg;") @@ -204,17 +213,17 @@ function db_table(db, table, athena_params::Any=nothing; iceberg::Bool=false, de else metadata = get_table_metadata(db, table_name) end - elseif current_sql_mode[] == :athena + elseif current_sql_mode[] == 
athena() metadata = get_table_metadata(db, table_name, athena_params) - elseif current_sql_mode[] == :snowflake || current_sql_mode[] == :databricks + elseif current_sql_mode[] == snowflake() || current_sql_mode[] == databricks() metadata = get_table_metadata(db, table_name) else error("Unsupported SQL mode: $(current_sql_mode[])") end - formatted_table_name = if current_sql_mode[] == :snowflake + formatted_table_name = if current_sql_mode[] == snowflake() "$(db.database).$(db.schema).$table_name" - elseif db isa DatabricksConnection || current_sql_mode[] == :databricks + elseif db isa DatabricksConnection || current_sql_mode[] == databricks() "$(db.database).$(db.schema).$table_name" elseif iceberg "iceberg_scan('$table_name', allow_moved_paths = true)" @@ -235,12 +244,12 @@ $docstring_copy_to function copy_to(conn, df_or_path::Union{DataFrame, AbstractString}, name::String) # Check if the input is a DataFrame if isa(df_or_path, DataFrame) - if current_sql_mode[] == :duckdb + if current_sql_mode[] == duckdb() DuckDB.register_data_frame(conn, df_or_path, name) end # If the input is not a DataFrame, treat it as a file path elseif isa(df_or_path, AbstractString) - if current_sql_mode[] != :duckdb + if current_sql_mode[] != duckdb() error("Direct file loading is only supported for DuckDB in this implementation.") end # Determine the file type based on the extension @@ -279,40 +288,33 @@ end """ $docstring_connect """ -function connect(backend::Symbol; kwargs...) - if backend == :DuckDB || backend == :duckdb - set_sql_mode(:duckdb) - db = DBInterface.connect(DuckDB.DB, ":memory:") - DBInterface.execute(db, "SET autoinstall_known_extensions=1;") - DBInterface.execute(db, "SET autoload_known_extensions=1;") +function connect(::duckdb; kwargs...) + set_sql_mode(duckdb()) + db = DBInterface.connect(DuckDB.DB, ":memory:") + DBInterface.execute(db, "SET autoinstall_known_extensions=1;") + DBInterface.execute(db, "SET autoload_known_extensions=1;") - # Install and load the httpfs extension - DBInterface.execute(db, "INSTALL httpfs;") - DBInterface.execute(db, "LOAD httpfs;") - return db - else - throw(ArgumentError("Unsupported backend: $backend")) - end + # Install and load the httpfs extension + DBInterface.execute(db, "INSTALL httpfs;") + DBInterface.execute(db, "LOAD httpfs;") + return db end -function connect(backend::Symbol, identifier::String, auth_token::String, database::String, schema::String, warehouse::String) - if backend == :snowflake - # Snowflake specific settings - set_sql_mode(:snowflake) + +function connect(::snowflake, identifier::String, auth_token::String, database::String, schema::String, warehouse::String) + set_sql_mode(snowflake()) api_url = "https://$identifier.snowflakecomputing.com/api/v2/statements" return SnowflakeConnection(identifier, auth_token, database, schema, warehouse, api_url) - elseif backend == :databricks - # Databricks specific settings - # Remove any leading slash from workspace_id +end + +function connect(::databricks, identifier::String, auth_token::String, database::String, schema::String, warehouse::String) + set_sql_mode(databricks()) identifier = lstrip(identifier, '/') api_url = "https://$(identifier).cloud.databricks.com/api/2.0/sql/statements" return DatabricksConnection(identifier, auth_token, database, schema, warehouse, api_url) - else - error("Unsupported backend type: $backend") - end end -function connect(backend_type::Symbol, db_type::Symbol; access_key::String="", secret_key::String="", aws_access_key_id::String="", 
aws_secret_access_key::String="", aws_region::String="") +function connect(::duckdb, db_type::Symbol; access_key::String="", secret_key::String="", aws_access_key_id::String="", aws_secret_access_key::String="", aws_region::String="") # Connect to the DuckDB database mem = DuckDB.open(":memory:") db = DuckDB.connect(mem) @@ -342,7 +344,7 @@ function connect(backend_type::Symbol, db_type::Symbol; access_key::String="", s return db end -function connect(symbol, token::String) +function connect(::duckdb, token::String) if token == "md:" return DBInterface.connect(DuckDB.DB, "md:") else diff --git a/src/db_parsing.jl b/src/db_parsing.jl index 0a7d6ed..758bb8b 100644 --- a/src/db_parsing.jl +++ b/src/db_parsing.jl @@ -40,11 +40,11 @@ function parse_tidy_db(exprs, metadata::DataFrame) if actual_expr.args[1] == :(:) # Handle range expression start_col = string(actual_expr.args[2]) - if current_sql_mode[] == :snowflake + if current_sql_mode[] == snowflake() start_col = uppercase(start_col) end end_col = string(actual_expr.args[3]) - if current_sql_mode[] == :snowflake + if current_sql_mode[] == snowflake() end_col = uppercase(end_col) end start_idx = findfirst(==(start_col), all_columns) @@ -61,7 +61,7 @@ function parse_tidy_db(exprs, metadata::DataFrame) elseif actual_expr.args[1] == :starts_with || actual_expr.args[1] == :ends_with || actual_expr.args[1] == :contains # Handle starts_with, ends_with, and contains substring = actual_expr.args[2] - if current_sql_mode[] == :snowflake + if current_sql_mode[] == snowflake() substring = uppercase(substring) end match_columns = filter(col -> @@ -85,7 +85,7 @@ function parse_tidy_db(exprs, metadata::DataFrame) end col_name = isa(actual_expr, Symbol) ? string(actual_expr) : actual_expr - if current_sql_mode[] == :snowflake + if current_sql_mode[] == snowflake() col_name = uppercase(col_name) end if is_excluded diff --git a/src/docstrings.jl b/src/docstrings.jl index f143d92..2b9f6b7 100644 --- a/src/docstrings.jl +++ b/src/docstrings.jl @@ -16,7 +16,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -81,7 +81,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -135,7 +135,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -169,7 +169,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -210,7 +210,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -257,7 +257,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -307,7 +307,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 
0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -348,7 +348,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -387,7 +387,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -421,7 +421,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -463,7 +463,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -499,7 +499,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -566,7 +566,7 @@ julia> df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"], category = ["X", "Y", "X", "Y", "X", "Y", "X"], score = [88, 92, 77, 83, 95, 68, 74]); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -619,7 +619,7 @@ julia> df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"], category = ["X", "Y", "X", "Y", "X", "Y", "X"], score = [88, 92, 77, 83, 95, 68, 74]); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -669,7 +669,7 @@ julia> df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"], category = ["X", "Y", "X", "Y", "X", "Y", "X"], score = [88, 92, 77, 83, 95, 68, 74]); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -716,7 +716,7 @@ julia> df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"], category = ["X", "Y", "X", "Y", "X", "Y", "X"], score = [88, 92, 77, 83, 95, 68, 74]); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -771,7 +771,7 @@ julia> df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"], category = ["X", "Y", "X", "Y", "X", "Y", "X"], score = [88, 92, 77, 83, 95, 68, 74]); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -819,7 +819,7 @@ julia> df2 = DataFrame(id2 = ["AA", "AC", "AE", "AG", "AI", "AK", "AM"], category = ["X", "Y", "X", "Y", "X", "Y", "X"], score = [88, 92, 77, 83, 95, 68, 74]); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -859,7 +859,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -900,7 +900,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "test"); ``` @@ -922,7 +922,7 @@ julia> df = 
DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); ``` @@ -946,7 +946,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); ``` @@ -954,13 +954,13 @@ julia> copy_to(db, df, "df_mem"); const docstring_connect = """ - connect(backend::Symbol; kwargs...) + connect(backend; kwargs...) This function establishes a database connection based on the specified backend and connection parameters and sets the SQL mode # Arguments -- `backend`: A symbol specifying the database backend to connect to. Supported backends are: - - `:duckdb`, `:lite`(SQLite), `:mssql`, `mysql`(for MariaDB and MySQL), `:clickhouse`, `:postgres` +- `backend`: type specifying the database backend to connect to. Supported backends are: + - `duckdb()`, `sqlite()`(SQLite), `mssql()`, `mysql()`(for MariaDB and MySQL), `clickhouse()`, `postgres()` - `kwargs`: Keyword arguments specifying the connection parameters for the selected backend. The required parameters vary depending on the backend: - MySQL: - `host`: The host name or IP address of the MySQL server. Default is "localhost". @@ -975,25 +975,25 @@ This function establishes a database connection based on the specified backend a # Examples ```julia # Connect to MySQL -# conn = connect(:mysql; host="localhost", user="root", password="password", db="mydb") +# conn = connect(mysql(); host="localhost", user="root", password="password", db="mydb") # Connect to PostgreSQL using LibPQ -# conn = connect(:postgres; host="localhost", dbname="mydb", user="postgres", password="password") +# conn = connect(postgres(); host="localhost", dbname="mydb", user="postgres", password="password") # Connect to ClickHouse -# conn = connect(:clickhouse; host="localhost", port=9000, database="mydb", user="default", password="") +# conn = connect(clickhouse(); host="localhost", port=9000, database="mydb", user="default", password="") # Connect to SQLite -# conn = connect(:lite) +# conn = connect(sqlite()) # Connect to Google Big Query -# conn = connect(:gbq, "json_user_key_path", "project_id") +# conn = connect(gbq(), "json_user_key_path", "project_id") # Connect to Snowflake -# conn = connect(:snowflake, "ac_id", "token", "Database_name", "Schema_name", "warehouse_name") +# conn = connect(snowflake(), "ac_id", "token", "Database_name", "Schema_name", "warehouse_name") # Connect to DuckDB # connect to Google Cloud via DuckDB -# google_db = connect(:duckdb, :gbq, access_key="string", secret_key="string") +# google_db = connect(duckdb(), :gbq, access_key="string", secret_key="string") # Connect to AWS via DuckDB -# aws_db = connect2(:duckdb, :aws, aws_access_key_id=get(ENV, "AWS_ACCESS_KEY_ID", "access_key"), aws_secret_access_key=get(ENV, "AWS_SECRET_ACCESS_KEY", "secret_access key"), aws_region=get(ENV, "AWS_DEFAULT_REGION", "us-east-1")) +# aws_db = connect2(duckdb(), :aws, aws_access_key_id=get(ENV, "AWS_ACCESS_KEY_ID", "access_key"), aws_secret_access_key=get(ENV, "AWS_SECRET_ACCESS_KEY", "secret_access key"), aws_region=get(ENV, "AWS_DEFAULT_REGION", "us-east-1")) # Connect to MotherDuck -# connect(:duckdb, "token") for first connection, vs connect(:duckdb, "md:") for reconnection -julia> db = connect(:duckdb) +# connect(duckdb(), "token") for first connection, vs 
connect(:duckdb, "md:") for reconnection +julia> db = connect(duckdb()) DuckDB.Connection(":memory:") ``` """ @@ -1012,7 +1012,7 @@ Interpolate parameters into expressions for database queries. # Example ```julia -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); @@ -1071,7 +1071,7 @@ julia> df = DataFrame(id = [string('A' + i ÷ 26, 'A' + i % 26) for i in 0:9], value = repeat(1:5, 2), percent = 0.1:0.1:1.0); -julia> db = connect(:duckdb); +julia> db = connect(duckdb()); julia> copy_to(db, df, "df_mem"); diff --git a/src/joins_sq.jl b/src/joins_sq.jl index 7ff2f2f..91829d3 100644 --- a/src/joins_sq.jl +++ b/src/joins_sq.jl @@ -1,7 +1,7 @@ function gbq_join_parse(input) input = string(input) parts = split(input, ".") - if current_sql_mode[] == :gbq && length(parts) >=2 + if current_sql_mode[] == gbq() && length(parts) >=2 return join(parts[2:end], ".") else return input diff --git a/src/structs.jl b/src/structs.jl index bb1b4a5..1cc0c50 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -5,12 +5,6 @@ mutable struct CTE where::String groupBy::String having::String - # Additional fields as necessary - - # Default constructor - #CTE() = new("", "", "", "", "", "") - - # Custom constructor accepting keyword arguments function CTE(;name::String="", select::String="", from::String="", where::String="", groupBy::String="", having::String="") new(name, select, from, where, groupBy, having) end From fb108ce17fc0a3192db0213c94f9bf31c36f8956 Mon Sep 17 00:00:00 2001 From: drizk1 Date: Thu, 25 Jul 2024 10:19:13 -0400 Subject: [PATCH 5/6] fix duckdb s3 connection, bump version, update news --- NEWS.md | 3 ++- Project.toml | 2 +- src/TidierDB.jl | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/NEWS.md b/NEWS.md index 0059c81..36d9784 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,9 +1,10 @@ # TidierDB.jl updates ## v0.3. - 2024-07-25 -- Introduces package extensions for +- Introduces package extensions for: - Postgres, ClickHouse, MySQL, MsSQL, SQLite, Oracle, Athena, and Google BigQuery - (Documentation)[https://tidierorg.github.io/TidierDB.jl/latest/examples/generated/UserGuide/getting_started/] updated for using these backends. 
+- Change `set_sql_mode()` to use types instead of symbols (e.g. `set_sql_mode(snowflake())` rather than `set_sql_mode(:snowflake)`) ## v0.2.4 - 2024-07-12 - Switches to DuckDB version 1.0 diff --git a/Project.toml b/Project.toml index 327b94c..556feb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TidierDB" uuid = "86993f9b-bbba-4084-97c5-ee15961ad48b" authors = ["Daniel Rizk and contributors"] -version = "0.2.4" +version = "0.3.0" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" diff --git a/src/TidierDB.jl b/src/TidierDB.jl index 3c0fc1b..7423f56 100644 --- a/src/TidierDB.jl +++ b/src/TidierDB.jl @@ -316,8 +316,7 @@ end function connect(::duckdb, db_type::Symbol; access_key::String="", secret_key::String="", aws_access_key_id::String="", aws_secret_access_key::String="", aws_region::String="") # Connect to the DuckDB database -    mem = DuckDB.open(":memory:") -    db = DuckDB.connect(mem) +    db = DBInterface.connect(DuckDB.DB, ":memory:") # Enable auto-install and auto-load of known extensions DBInterface.execute(db, "SET autoinstall_known_extensions=1;") From ba209a2cfbcb812216b4949ab321238186a0b3f5 Mon Sep 17 00:00:00 2001 From: drizk1 Date: Thu, 25 Jul 2024 10:26:32 -0400 Subject: [PATCH 6/6] couple api changes missed in examples --- README.md | 2 +- docs/examples/UserGuide/getting_started.jl | 2 +- docs/src/index.md | 2 +- src/docstrings.jl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 582066f..5ddba6c 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Even though the code reads similarly to TidierData, note that no computational w using TidierData import TidierDB as DB -db = DB.connect(:duckdb); +db = DB.connect(duckdb()); path_or_name = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv" @chain DB.db_table(db, path_or_name) begin diff --git a/docs/examples/UserGuide/getting_started.jl b/docs/examples/UserGuide/getting_started.jl index 5397feb..0d2299a 100644 --- a/docs/examples/UserGuide/getting_started.jl +++ b/docs/examples/UserGuide/getting_started.jl @@ -27,7 +27,7 @@ # - ClickHouse: `using ClickHouse` # - MySQL and MariaDB: `using MySQL` # - MSSQL: `using ODBC` -# - Postgres: `using LibPQ`` +# - Postgres: `using LibPQ` # - SQLite: `using SQLite` # - Athena: `using AWS` # - Oracle: `using ODBC` diff --git a/docs/src/index.md b/docs/src/index.md index 54c4ad4..c97d151 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -89,7 +89,7 @@ Even though the code reads similarly to TidierData, note that no computational w using TidierData import TidierDB as DB -db = DB.connect(:duckdb); +db = DB.connect(duckdb()); path_or_name = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv" @chain DB.db_table(db, path_or_name) begin diff --git a/src/docstrings.jl b/src/docstrings.jl index 2b9f6b7..3f60071 100644 --- a/src/docstrings.jl +++ b/src/docstrings.jl @@ -992,7 +992,7 @@ This function establishes a database connection based on the specified backend a # Connect to AWS via DuckDB # aws_db = connect2(duckdb(), :aws, aws_access_key_id=get(ENV, "AWS_ACCESS_KEY_ID", "access_key"), aws_secret_access_key=get(ENV, "AWS_SECRET_ACCESS_KEY", "secret_access key"), aws_region=get(ENV, "AWS_DEFAULT_REGION", "us-east-1")) # Connect to MotherDuck -# connect(duckdb(), "token") for first connection, vs 
connect(duckdb(), "md:") for reconnection julia> db = connect(duckdb()) DuckDB.Connection(":memory:") ```
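
Taken together, these patches replace symbol-based backend selection with dispatch on backend type instances. The sketch below is a minimal, illustrative walk-through of the resulting API, not part of the patch itself; the CSV URL is the one used in the README examples, and the Postgres connection parameters are the placeholder values from the `connect` docstring.

```julia
# A minimal sketch of the instance-based API introduced in this patch series.
# Assumes network access to the mtcars CSV used in the README examples.
using TidierDB

db = connect(duckdb())  # backend instances such as duckdb() or postgres() replace symbols like :duckdb

path_or_name = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv"

@chain db_table(db, path_or_name) begin
    @group_by(cyl)
    @summarize(mpg = mean(mpg))
    @collect  # reads current_sql_mode[] and dispatches to final_collect(query, duckdb)
end

# Extension-backed backends activate only after their driver package is loaded:
# using LibPQ                      # loads the LibPQExt package extension
# set_sql_mode(postgres())
# con = connect(postgres(); host="localhost", dbname="mydb", user="postgres", password="password")
```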