Commit fd87610

Closes #927, #928: Schema refresh improvements.

Marina Samuel committed Apr 4, 2019
1 parent 99a49b1 commit fd87610
Showing 6 changed files with 154 additions and 92 deletions.
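
Note: at a high level, this commit replaces the all-data-sources refresh_schemas
task with a per-data-source refresh_schema(data_source_id) task, batches table
and column metadata writes with bulk_insert_mappings/bulk_update_mappings, and
gates table sampling behind the data source's 'samples' option. The enqueue call
changes roughly as follows (a sketch condensed from the diff below, not verbatim):

    # before: one task refreshed every data source
    refresh_schemas.apply_async(queue=settings.SCHEMAS_REFRESH_QUEUE)
    # after: one task per data source
    refresh_schema.apply_async(args=(data_source.id,),
                               queue=settings.SCHEMAS_REFRESH_QUEUE)
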
13 changes: 6 additions & 7 deletions redash/handlers/data_sources.py
@@ -10,7 +10,7 @@
 from redash.handlers.base import BaseResource, get_object_or_404, require_fields
 from redash.permissions import (require_access, require_admin,
                                 require_permission, view_only)
-from redash.tasks.queries import refresh_schemas
+from redash.tasks.queries import refresh_schema
 from redash.query_runner import (get_configuration_schema_for_query_runner_type,
                                  query_runners, NotSupported)
 from redash.utils import filter_none
@@ -54,7 +54,7 @@ def post(self, data_source_id):
         models.db.session.add(data_source)
 
         # Refresh the stored schemas when a data source is updated
-        refresh_schemas.apply_async(queue=settings.SCHEMAS_REFRESH_QUEUE)
+        refresh_schema.apply_async(args=(data_source.id,), queue=settings.SCHEMAS_REFRESH_QUEUE)
 
         try:
             models.db.session.commit()
@@ -133,7 +133,7 @@ def post(self):
             models.db.session.commit()
 
             # Refresh the stored schemas when a new data source is added to the list
-            refresh_schemas.apply_async(queue=settings.SCHEMAS_REFRESH_QUEUE)
+            refresh_schema.apply_async(args=(datasource.id,), queue=settings.SCHEMAS_REFRESH_QUEUE)
         except IntegrityError as e:
             models.db.session.rollback()
             if req['name'] in e.message:
@@ -158,10 +158,9 @@ def get(self, data_source_id):
 
         response = {}
         try:
-            current_schema = data_source.get_schema()
-            if refresh or len(current_schema) == 0:
-                refresh_schemas.apply(queue=settings.SCHEMAS_REFRESH_QUEUE)
-            response['schema'] = current_schema
+            if refresh:
+                refresh_schema.apply_async(args=(data_source.id,), queue=settings.SCHEMAS_REFRESH_QUEUE)
+            response['schema'] = data_source.get_schema()
         except NotSupported:
             response['error'] = {
                 'code': 1,
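
Note: GET /api/data_sources/<id>/schema no longer refreshes inline. With
?refresh=true the handler only enqueues the Celery task and immediately returns
whatever schema is already stored, so a caller that needs fresh results must
re-fetch after the background task completes. A hypothetical client-side sketch
(the endpoint and Authorization header are standard Redash; the polling loop and
names are illustrative):

    import time

    import requests

    def fetch_fresh_schema(base_url, ds_id, api_key):
        url = '{}/api/data_sources/{}/schema'.format(base_url, ds_id)
        headers = {'Authorization': 'Key {}'.format(api_key)}
        # Kick off a background refresh, then poll the stored schema.
        requests.get(url, params={'refresh': 'true'}, headers=headers)
        for _ in range(10):
            schema = requests.get(url, headers=headers).json().get('schema')
            if schema:  # returns the first non-empty schema seen
                return schema
            time.sleep(5)
        return []
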
2 changes: 1 addition & 1 deletion redash/models/__init__.py
@@ -80,7 +80,7 @@ class TableMetadata(TimestampMixin, db.Model):
     __tablename__ = 'table_metadata'
 
     def __str__(self):
-        return text_type(self.table_name)
+        return text_type(self.name)
 
     def to_dict(self):
         return {
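
Note: TableMetadata's column is name, not table_name, so the old __str__ appears
to have raised AttributeError on any str() call. A minimal illustration (the
constructor argument is just an example):

    table = TableMetadata(name='events')
    str(table)   # 'events'; previously this read the nonexistent self.table_name
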
3 changes: 0 additions & 3 deletions redash/query_runner/__init__.py
@@ -127,9 +127,6 @@ def _run_query_internal(self, query):
         return json_loads(results)['rows']
 
     def get_table_sample(self, table_name):
-        if not self.configuration.get('samples', False):
-            return {}
-
         if self.data_sample_query is None:
             raise NotImplementedError()
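
Note: the 'samples' guard is not dropped, it moves to the caller: refresh_schema
in redash/tasks/queries.py now checks
ds.query_runner.configuration.get('samples', False) before enqueueing
get_table_sample_data. A query runner still opts in by supplying a sample query;
a rough sketch (the class name and query are hypothetical, and the exact
templating of data_sample_query is assumed from context):

    class ExampleRunner(BaseQueryRunner):
        # without this, get_table_sample() raises NotImplementedError
        data_sample_query = "SELECT * FROM {table} LIMIT 1"
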
2 changes: 1 addition & 1 deletion redash/settings/__init__.py
@@ -46,7 +46,7 @@ def all_settings():
 QUERY_RESULTS_CLEANUP_COUNT = int(os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_COUNT", "100"))
 QUERY_RESULTS_CLEANUP_MAX_AGE = int(os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_MAX_AGE", "7"))
 
-SCHEMAS_REFRESH_SCHEDULE = int(os.environ.get("REDASH_SCHEMAS_REFRESH_SCHEDULE", 30))
+SCHEMAS_REFRESH_SCHEDULE = int(os.environ.get("REDASH_SCHEMAS_REFRESH_SCHEDULE", 360))
 SCHEMAS_REFRESH_QUEUE = os.environ.get("REDASH_SCHEMAS_REFRESH_QUEUE", "celery")
 
 AUTH_TYPE = os.environ.get("REDASH_AUTH_TYPE", "api_key")
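
Note: the default refresh cadence drops from every 30 minutes to every 360
minutes (the schedule is read in minutes, so this is every 6 hours). Both knobs
remain environment-driven; a sketch (the queue name is just an example, and the
variables must be set before redash.settings is imported):

    import os

    os.environ["REDASH_SCHEMAS_REFRESH_SCHEDULE"] = "60"    # back to hourly
    os.environ["REDASH_SCHEMAS_REFRESH_QUEUE"] = "schemas"  # dedicated Celery queue
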
210 changes: 138 additions & 72 deletions redash/tasks/queries.py
@@ -232,31 +232,42 @@ def cleanup_query_results():
     models.db.session.commit()
     logger.info("Deleted %d unused query results.", deleted_count)
 
+def truncate_long_string(original_str, max_length):
+    new_str = original_str
+    if original_str and len(original_str) > max_length:
+        new_str = u'{}...'.format(original_str[:max_length])
+    return new_str
+
 @celery.task(name="redash.tasks.get_table_sample_data")
-def get_table_sample_data(data_source_id, table, table_id):
+def get_table_sample_data(existing_columns, data_source_id, table_name, table_id):
     ds = models.DataSource.get_by_id(data_source_id)
-    sample = ds.query_runner.get_table_sample(table['name'])
+    sample = ds.query_runner.get_table_sample(table_name)
     if not sample:
         return
 
-    # If a column exists, add a sample to it.
-    for i, column in enumerate(table['columns']):
-        persisted_column = ColumnMetadata.query.filter(
-            ColumnMetadata.name == column,
-            ColumnMetadata.table_id == table_id,
-        ).options(load_only('id')).first()
-
-        if persisted_column:
-            column_example = str(sample.get(column, None))
-            if column_example and len(column_example) > 4000:
-                column_example = u'{}...'.format(column_example[:4000])
-
-            ColumnMetadata.query.filter(
-                ColumnMetadata.id == persisted_column.id,
-            ).update({
-                'example': column_example,
-            })
+    persisted_columns = ColumnMetadata.query.filter(
+        ColumnMetadata.name.in_(existing_columns),
+        ColumnMetadata.table_id == table_id,
+    ).options(load_only('id')).all()
+
+    # If a column exists, add a sample to it.
+    column_examples = []
+    for persisted_column in persisted_columns:
+        column_example = sample.get(persisted_column.name, None)
+        column_example = column_example if isinstance(column_example, unicode) else (
+            str(column_example).decode("utf-8", errors="replace").strip()
+        )
+        column_example = truncate_long_string(column_example, 4000)
+
+        column_examples.append({
+            "id": persisted_column.id,
+            "example": column_example
+        })
+
+    models.db.session.bulk_update_mappings(
+        ColumnMetadata,
+        column_examples
+    )
     models.db.session.commit()
 
 def cleanup_data_in_table(table_model):
@@ -280,90 +291,145 @@ def cleanup_schema_metadata():
     cleanup_data_in_table(TableMetadata)
     cleanup_data_in_table(ColumnMetadata)
 
-@celery.task(name="redash.tasks.refresh_schema", time_limit=90, soft_time_limit=60)
+def insert_or_update_table_metadata(ds, existing_tables_set, table_data):
+    # Update all persisted tables that exist to reflect this.
+    persisted_tables = TableMetadata.query.filter(
+        TableMetadata.name.in_(tuple(existing_tables_set)),
+        TableMetadata.data_source_id == ds.id,
+    )
+    persisted_tables.update({"exists": True}, synchronize_session='fetch')
+
+    # Find the tables that need to be created by subtracting the sets:
+    # existing_table_set - persisted_table_set
+    persisted_table_set = set([
+        persisted_table.name for persisted_table in persisted_tables.all()
+    ])
+
+    tables_to_create = existing_tables_set.difference(persisted_table_set)
+    table_metadata = [table_data[table_name] for table_name in list(tables_to_create)]
+
+    models.db.session.bulk_insert_mappings(
+        TableMetadata,
+        table_metadata
+    )
+
+def insert_or_update_column_metadata(table, existing_columns_set, column_data):
+    persisted_columns = ColumnMetadata.query.filter(
+        ColumnMetadata.name.in_(tuple(existing_columns_set)),
+        ColumnMetadata.table_id == table.id,
+    ).all()
+
+    persisted_column_data = []
+    for persisted_column in persisted_columns:
+        # Add id's to persisted column data so it can be used for updates.
+        column_data[persisted_column.name]['id'] = persisted_column.id
+        persisted_column_data.append(column_data[persisted_column.name])
+
+    models.db.session.bulk_update_mappings(
+        ColumnMetadata,
+        persisted_column_data
+    )
+    persisted_column_set = set([col_data['name'] for col_data in persisted_column_data])
+    columns_to_create = existing_columns_set.difference(persisted_column_set)
+
+    column_metadata = [column_data[col_name] for col_name in list(columns_to_create)]
+
+    models.db.session.bulk_insert_mappings(
+        ColumnMetadata,
+        column_metadata
+    )
+
+@celery.task(name="redash.tasks.refresh_schema", time_limit=600, soft_time_limit=300)
 def refresh_schema(data_source_id):
     ds = models.DataSource.get_by_id(data_source_id)
     logger.info(u"task=refresh_schema state=start ds_id=%s", ds.id)
     start_time = time.time()
 
+    MAX_TYPE_STRING_LENGTH = 250
     try:
-        existing_tables = set()
         schema = ds.query_runner.get_schema(get_stats=True)
 
+        # Stores data from the updated schema that tells us which
+        # columns and which tables currently exist
+        existing_tables_set = set()
+        existing_columns_set = set()
+
+        # Stores data that will be inserted into postgres
+        table_data = {}
+        column_data = {}
+
+        new_column_names = {}
+        new_column_metadata = {}
         for table in schema:
             table_name = table['name']
-            existing_tables.add(table_name)
-
-            # Assume that there will only exist 1 table with a given name for a given data source so we use first()
-            persisted_table = TableMetadata.query.filter(
-                TableMetadata.name == table_name,
-                TableMetadata.data_source_id == ds.id,
-            ).first()
-
-            if persisted_table:
-                TableMetadata.query.filter(
-                    TableMetadata.id == persisted_table.id,
-                ).update({"exists": True})
-            else:
-                metadata = 'metadata' in table
-                persisted_table = TableMetadata(
-                    org_id=ds.org_id,
-                    name=table_name,
-                    data_source_id=ds.id,
-                    column_metadata=metadata
-                )
-                models.db.session.add(persisted_table)
-                models.db.session.flush()
+            existing_tables_set.add(table_name)
 
-            metadata = 'metadata' in table
+            table_data[table_name] = {
+                "org_id": ds.org_id,
+                "name": table_name,
+                "data_source_id": ds.id,
+                "column_metadata": "metadata" in table
+            }
+            new_column_names[table_name] = table['columns']
+            new_column_metadata[table_name] = table['metadata']
+
+        insert_or_update_table_metadata(ds, existing_tables_set, table_data)
+        models.db.session.flush()
+
+        all_existing_persisted_tables = TableMetadata.query.filter(
+            TableMetadata.exists == True,
+            TableMetadata.data_source_id == ds.id,
+        ).all()
 
-            existing_columns = set()
-            for i, column in enumerate(table['columns']):
-                existing_columns.add(column)
-                column_metadata = {
+        for j, table in enumerate(all_existing_persisted_tables):
+            for i, column in enumerate(new_column_names.get(table.name, [])):
+                existing_columns_set.add(column)
+                column_data[column] = {
                     'org_id': ds.org_id,
-                    'table_id': persisted_table.id,
+                    'table_id': table.id,
                     'name': column,
                     'type': None,
                     'example': None,
                     'exists': True
                 }
-                if 'metadata' in table:
-                    column_metadata['type'] = table['metadata'][i]['type']
 
-                # If the column exists, update it, otherwise create a new one.
-                persisted_column = ColumnMetadata.query.filter(
-                    ColumnMetadata.name == column,
-                    ColumnMetadata.table_id == persisted_table.id,
-                ).options(load_only('id')).first()
-                if persisted_column:
-                    ColumnMetadata.query.filter(
-                        ColumnMetadata.id == persisted_column.id,
-                    ).update(column_metadata)
-                else:
-                    models.db.session.add(ColumnMetadata(**column_metadata))
+                if table.column_metadata:
+                    column_type = new_column_metadata[table.name][i]['type']
+                    column_type = truncate_long_string(column_type, MAX_TYPE_STRING_LENGTH)
+                    column_data[column]['type'] = column_type
+
+            insert_or_update_column_metadata(table, existing_columns_set, column_data)
             models.db.session.commit()
 
-            get_table_sample_data.apply_async(
-                args=(data_source_id, table, persisted_table.id),
-                queue=settings.SCHEMAS_REFRESH_QUEUE
-            )
+            if ds.query_runner.configuration.get('samples', False):
+                get_table_sample_data.apply_async(
+                    args=(tuple(existing_columns_set), ds.id, table.name, table.id),
+                    queue=settings.SCHEMAS_REFRESH_QUEUE
+                )
 
             # If a column did not exist, set the 'column_exists' flag to false.
-            existing_columns_list = tuple(existing_columns)
+            existing_columns_list = tuple(existing_columns_set)
             ColumnMetadata.query.filter(
                 ColumnMetadata.exists == True,
-                ColumnMetadata.table_id == persisted_table.id,
+                ColumnMetadata.table_id == table.id,
                 ~ColumnMetadata.name.in_(existing_columns_list),
             ).update({
                 "exists": False,
                 "updated_at": db.func.now()
             }, synchronize_session='fetch')
 
+            existing_columns_set = set()
+
         # If a table did not exist in the get_schema() response above, set the 'exists' flag to false.
-        existing_tables_list = tuple(existing_tables)
-        tables_to_update = TableMetadata.query.filter(
+        existing_tables_list = tuple(existing_tables_set)
+        TableMetadata.query.filter(
             TableMetadata.exists == True,
             TableMetadata.data_source_id == ds.id,
             ~TableMetadata.name.in_(existing_tables_list)
         ).update({
             "exists": False,
             "updated_at": db.func.now()
         }, synchronize_session='fetch')
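
Note: the new insert_or_update_* helpers replace the old per-row
first()/update()/add() round trips with one bulk UPDATE plus one bulk INSERT,
driven by set arithmetic. The pattern, reduced to plain sets and dicts (names
here are illustrative, not the ORM calls):

    def split_upsert(existing_names, persisted_names, data_by_name):
        # Rows that are already persisted only need an update...
        to_update = existing_names & persisted_names
        # ...and whatever is left over gets bulk-inserted.
        to_insert = existing_names - persisted_names
        return ([data_by_name[n] for n in to_update],
                [data_by_name[n] for n in to_insert])

    data = {n: {'name': n, 'exists': True} for n in ('users', 'orders', 'events')}
    updates, inserts = split_upsert({'users', 'orders', 'events'},  # from get_schema()
                                    {'users', 'orders'},            # already in table_metadata
                                    data)
    # updates cover users and orders; inserts create events
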
16 changes: 8 additions & 8 deletions tests/tasks/test_refresh_schemas.py
@@ -77,10 +77,10 @@ def test_refresh_schema_creates_tables(self):
 
         refresh_schema(self.factory.data_source.id)
         get_table_sample_data(
-            self.factory.data_source.id, {
-                "name": 'table',
-                "columns": [self.COLUMN_NAME]
-            }, 1
+            [self.COLUMN_NAME],
+            self.factory.data_source.id,
+            'table',
+            1
         )
         table_metadata = TableMetadata.query.all()
         column_metadata = ColumnMetadata.query.all()
@@ -144,10 +144,10 @@ def test_refresh_schema_update_column(self):
 
         refresh_schema(self.factory.data_source.id)
         get_table_sample_data(
-            self.factory.data_source.id, {
-                "name": 'table',
-                "columns": [self.COLUMN_NAME]
-            }, 1
+            [self.COLUMN_NAME],
+            self.factory.data_source.id,
+            'table',
+            1
         )
         column_metadata = ColumnMetadata.query.all()
         self.assertEqual(column_metadata[0].to_dict(), self.EXPECTED_COLUMN_METADATA)
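
Note: the tests invoke the Celery task synchronously, which works because a task
is still a plain callable. The new positional shape they exercise is,
schematically (values are illustrative):

    get_table_sample_data(
        ['id'],    # existing column names
        1,         # data source id
        'table',   # table name
        1,         # TableMetadata row id
    )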
