
Address second iteration of comments.
tneymanov committed Jan 20, 2020
1 parent 5f38e93 commit 059ef86
Showing 7 changed files with 35 additions and 28 deletions.
3 changes: 2 additions & 1 deletion gcp_variant_transforms/options/variant_transform_options.py
@@ -195,7 +195,8 @@ def add_arguments(self, parser):
parser.add_argument(
'--num_bigquery_write_shards',
type=int, default=1,
- help=('This flag is deprecated and may be removed in future releases.'))
+ help=('This flag is deprecated and will be removed in future '
+ 'releases.'))
parser.add_argument(
'--null_numeric_value_replacement',
type=int,
16 changes: 13 additions & 3 deletions gcp_variant_transforms/pipeline_common.py
@@ -71,7 +71,9 @@ def parse_args(argv, command_line_options):
known_args, pipeline_args = parser.parse_known_args(argv)
for transform_options in options:
transform_options.validate(known_args)
- _raise_error_on_invalid_flags(pipeline_args)
+ _raise_error_on_invalid_flags(
+ pipeline_args,
+ known_args.output_table if hasattr(known_args, 'output_table') else None)
if hasattr(known_args, 'input_pattern') or hasattr(known_args, 'input_file'):
known_args.all_patterns = _get_all_patterns(
known_args.input_pattern, known_args.input_file)
@@ -301,8 +303,8 @@ def write_headers(merged_header, file_path):
vcf_header_io.WriteVcfHeaders(file_path))


- def _raise_error_on_invalid_flags(pipeline_args):
- # type: (List[str]) -> None
+ def _raise_error_on_invalid_flags(pipeline_args, output_table):
+ # type: (List[str], Any) -> None
"""Raises an error if there are unrecognized flags."""
parser = argparse.ArgumentParser()
for cls in pipeline_options.PipelineOptions.__subclasses__():
@@ -315,6 +317,14 @@ def _raise_error_on_invalid_flags(pipeline_args):
not known_pipeline_args.setup_file):
raise ValueError('The --setup_file flag is required for DataflowRunner. '
'Please provide a path to the setup.py file.')
+ if output_table:
+ if (not hasattr(known_pipeline_args, 'temp_location') or
+ not known_pipeline_args.temp_location):
+ raise ValueError('--temp_location is required for BigQuery imports.')
+ if not known_pipeline_args.temp_location.startswith('gs://'):
+ raise ValueError(
+ '--temp_location must be valid GCS location for BigQuery imports')



def is_pipeline_direct_runner(pipeline):
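
With this change the --temp_location requirement for BigQuery output is enforced centrally in parse_args rather than in vcf_to_bq.run. Below is a minimal sketch of how the new check behaves, assuming the module path above and placeholder flag and table values:

from gcp_variant_transforms import pipeline_common

# Placeholder pipeline flags; any GCS path works for --temp_location.
pipeline_args = ['--project', 'my-project',
                 '--runner', 'DataflowRunner',
                 '--setup_file', './setup.py',
                 '--temp_location', 'gs://my-bucket/temp']

# Passing an output table switches the temp_location checks on; passing None
# (e.g. for Avro-only output) skips them.
pipeline_common._raise_error_on_invalid_flags(
    pipeline_args, 'my-project:dataset.table')
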
18 changes: 14 additions & 4 deletions gcp_variant_transforms/pipeline_common_test.py
@@ -95,21 +95,31 @@ def test_fail_on_invalid_flags(self):
'gcp-variant-transforms-test',
'--staging_location',
'gs://integration_test_runs/staging']
- pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

# Add Dataflow runner (requires --setup_file).
pipeline_args.extend(['--runner', 'DataflowRunner'])
with self.assertRaisesRegexp(ValueError, 'setup_file'):
- pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

# Add setup.py (required for Variant Transforms run). This is now valid.
pipeline_args.extend(['--setup_file', 'setup.py'])
- pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, None)

+ with self.assertRaisesRegexp(ValueError, '--temp_location is required*'):
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+ pipeline_args.extend(['--temp_location', 'wrong_gcs'])
+ with self.assertRaisesRegexp(ValueError, '--temp_location must be valid*'):
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')
+
+ pipeline_args = pipeline_args[:-1] + ['gs://valid_bucket/temp']
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')

# Add an unknown flag.
pipeline_args.extend(['--unknown_flag', 'somevalue'])
with self.assertRaisesRegexp(ValueError, 'Unrecognized.*unknown_flag'):
- pipeline_common._raise_error_on_invalid_flags(pipeline_args)
+ pipeline_common._raise_error_on_invalid_flags(pipeline_args, 'output')

def test_get_compression_type(self):
vcf_metadata_list = [filesystem.FileMetadata(path, size) for
6 changes: 2 additions & 4 deletions gcp_variant_transforms/transforms/sample_info_to_bigquery.py
@@ -50,7 +50,7 @@ def process(self, vcf_header):
class SampleInfoToBigQuery(beam.PTransform):
"""Writes sample info to BigQuery."""

- def __init__(self, output_table_prefix, temp_location, append=False,
+ def __init__(self, output_table_prefix, append=False,
samples_span_multiple_files=False):
# type: (str, Dict[str, str], bool, bool) -> None
"""Initializes the transform.
@@ -67,7 +67,6 @@ def __init__(self, output_table_prefix, temp_location, append=False,
self._append = append
self._samples_span_multiple_files = samples_span_multiple_files
self._schema = sample_info_table_schema_generator.generate_schema()
- self._temp_location = temp_location

def expand(self, pcoll):
return (pcoll
@@ -82,5 +81,4 @@ def expand(self, pcoll):
beam.io.BigQueryDisposition.WRITE_APPEND
if self._append
else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
- method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
- custom_gcs_temp_location=self._temp_location))
+ method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
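
Dropping custom_gcs_temp_location relies on the default behaviour of beam.io.WriteToBigQuery: with the FILE_LOADS method, Beam stages its load files under the pipeline's own --temp_location, which the new check in pipeline_common now guarantees is a GCS path. A minimal sketch of the simplified write follows, using a hypothetical output table, a placeholder bucket, and a cut-down inline schema in place of sample_info_table_schema_generator.generate_schema():

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Placeholder table, schema and bucket for illustration only.
output_table = 'my-project:dataset.sample_info'
options = PipelineOptions(['--temp_location', 'gs://my-bucket/temp'])

with beam.Pipeline(options=options) as p:
  _ = (p
       | beam.Create([{'sample_id': 1, 'sample_name': 'sample_1'}])
       | 'SampleInfoToBigQuery' >> beam.io.WriteToBigQuery(
           output_table,
           schema='sample_id:INTEGER,sample_name:STRING',
           create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
           write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE,
           # No custom_gcs_temp_location here: FILE_LOADS falls back to the
           # pipeline's --temp_location.
           method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
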
4 changes: 2 additions & 2 deletions gcp_variant_transforms/transforms/sample_info_to_bigquery_test.py
@@ -53,7 +53,7 @@ def test_convert_sample_info_to_row(self):
| transforms.Create([vcf_header_1, vcf_header_2])
| 'ConvertToRow'
>> transforms.ParDo(sample_info_to_bigquery.ConvertSampleInfoToRow(
- ), False))
+ False), ))

assert_that(bigquery_rows, equal_to(expected_rows))
pipeline.run()
@@ -83,7 +83,7 @@ def test_convert_sample_info_to_row_without_file_in_hash(self):
| transforms.Create([vcf_header_1, vcf_header_2])
| 'ConvertToRow'
>> transforms.ParDo(sample_info_to_bigquery.ConvertSampleInfoToRow(
- ), True))
+ True), ))

assert_that(bigquery_rows, equal_to(expected_rows))
pipeline.run()
5 changes: 1 addition & 4 deletions gcp_variant_transforms/transforms/variant_to_bigquery.py
@@ -59,7 +59,6 @@ def __init__(
self,
output_table, # type: str
header_fields, # type: vcf_header_io.VcfHeader
- temp_location, # type: str
variant_merger=None, # type: variant_merge_strategy.VariantMergeStrategy
proc_var_factory=None, # type: processed_variant.ProcessedVariantFactory
# TODO(bashir2): proc_var_factory is a required argument and if `None` is
@@ -99,7 +98,6 @@ def __init__(
"""
self._output_table = output_table
self._header_fields = header_fields
- self._temp_location = temp_location
self._variant_merger = variant_merger
self._proc_var_factory = proc_var_factory
self._append = append
@@ -137,5 +135,4 @@ def expand(self, pcoll):
beam.io.BigQueryDisposition.WRITE_APPEND
if self._append
else beam.io.BigQueryDisposition.WRITE_TRUNCATE),
- method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
- custom_gcs_temp_location=self._temp_location))
+ method=beam.io.WriteToBigQuery.Method.FILE_LOADS))
11 changes: 1 addition & 10 deletions gcp_variant_transforms/vcf_to_bq.py
@@ -384,7 +384,6 @@ def _run_annotation_pipeline(known_args, pipeline_args):
def _create_sample_info_table(pipeline, # type: beam.Pipeline
pipeline_mode, # type: PipelineModes
known_args, # type: argparse.Namespace,
- temp_directory, # str
):
# type: (...) -> None
headers = pipeline_common.read_headers(
@@ -395,7 +394,6 @@ def _create_sample_info_table(pipeline, # type: beam.Pipeline
_ = (headers | 'SampleInfoToBigQuery' >>
sample_info_to_bigquery.SampleInfoToBigQuery(
known_args.output_table,
- temp_directory,
known_args.append,
known_args.samples_span_multiple_files))

@@ -406,8 +404,6 @@ def run(argv=None):
logging.info('Command: %s', ' '.join(argv or sys.argv))
known_args, pipeline_args = pipeline_common.parse_args(argv,
_COMMAND_LINE_OPTIONS)
- if known_args.output_table and '--temp_location' not in pipeline_args:
- raise ValueError('--temp_location is required for BigQuery imports.')
if known_args.auto_flags_experiment:
_get_input_dimensions(known_args, pipeline_args)

@@ -483,10 +479,6 @@ def run(argv=None):
num_partitions = 1

if known_args.output_table:
- temp_directory = pipeline_options.PipelineOptions(pipeline_args).view_as(
- pipeline_options.GoogleCloudOptions).temp_location
- if not temp_directory:
- raise ValueError('--temp_location must be set when writing to BigQuery.')
for i in range(num_partitions):
table_suffix = ''
if partitioner and partitioner.get_partition_name(i):
@@ -496,7 +488,6 @@ def run(argv=None):
variant_to_bigquery.VariantToBigQuery(
table_name,
header_fields,
- temp_directory,
variant_merger,
processed_variant_factory,
append=known_args.append,
@@ -507,7 +498,7 @@ def run(argv=None):
known_args.null_numeric_value_replacement)))
if known_args.generate_sample_info_table:
_create_sample_info_table(
- pipeline, pipeline_mode, known_args, temp_directory)
+ pipeline, pipeline_mode, known_args)

if known_args.output_avro_path:
# TODO(bashir2): Add an integration test that outputs to Avro files and
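
The removed temp_directory plumbing only re-read the pipeline-level option; if that value is ever needed again it remains available directly from the pipeline options. A minimal sketch with a placeholder bucket:

from apache_beam.options import pipeline_options

# Placeholder flags; in vcf_to_bq these come from the command line.
opts = pipeline_options.PipelineOptions(['--temp_location', 'gs://my-bucket/temp'])
temp_directory = opts.view_as(pipeline_options.GoogleCloudOptions).temp_location
print(temp_directory)  # gs://my-bucket/temp
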
