Add ExportToFile and ImportFromFile ops without full-volume copying
Lifann committed Jun 10, 2022
1 parent 3592668 commit dfc5667
Showing 8 changed files with 757 additions and 31 deletions.
@@ -304,6 +304,18 @@ class CuckooHashTableOfTensors final : public LookupInterface {
return table_->export_values(ctx, value_dim);
}

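// Stream the table contents to or from disk in `buffer_size`-record batches
// instead of materializing a full copy (see the layout sketch below).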
Status SaveToFile(OpKernelContext* ctx, const string filepath,
const size_t buffer_size) {
int64 value_dim = value_shape_.dim_size(0);
return table_->save_to_file(ctx, value_dim, filepath, buffer_size);
}

Status LoadFromFile(OpKernelContext* ctx, const string filepath,
const size_t buffer_size) {
int64 value_dim = value_shape_.dim_size(0);
return table_->load_from_file(ctx, value_dim, filepath, buffer_size);
}

DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
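
For orientation, the sketch below shows one plausible on-disk layout for these methods, inferred from the GPU import path later in this commit (which opens `<filepath>.keys` and derives the record count from that file's size). The `.values` suffix, the helper name SaveToFileSketch, and the exact record packing are illustrative assumptions, not something this diff confirms:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Illustrative sketch only; error handling elided for brevity. Keys go to
// `<filepath>.keys` and the flattened value matrix (keys.size() * value_dim
// entries) to `<filepath>.values` as fixed-width binary records, written in
// chunks of at most `buffer_size` records so the table is never copied into
// memory wholesale.
template <class K, class V>
void SaveToFileSketch(const std::string& filepath, const std::vector<K>& keys,
                      const std::vector<V>& values, size_t value_dim,
                      size_t buffer_size) {
  std::FILE* kfd = std::fopen((filepath + ".keys").c_str(), "wb");
  std::FILE* vfd = std::fopen((filepath + ".values").c_str(), "wb");
  const size_t chunk = std::max<size_t>(buffer_size, 1);  // guard against 0
  for (size_t i = 0; i < keys.size(); i += chunk) {
    const size_t n = std::min(chunk, keys.size() - i);  // records this round
    std::fwrite(keys.data() + i, sizeof(K), n, kfd);
    std::fwrite(values.data() + i * value_dim, sizeof(V), n * value_dim, vfd);
  }
  std::fclose(kfd);
  std::fclose(vfd);
}

Under such a layout, an importer can recover the record count as filesize / sizeof(K) from the key file alone, which is how ImportValuesFromFile below pre-sizes the table before loading.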
@@ -607,6 +619,36 @@ class HashTableExportOp : public HashTableOpKernel {
}
};

// Op that exports all keys and values to a file.
template <class K, class V>
class HashTableExportToFileOp : public HashTableOpKernel {
public:
explicit HashTableExportToFileOp(OpKernelConstruction* ctx)
: HashTableOpKernel(ctx) {
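// `buffer_size` caps how many records are staged in memory per I/O batch,
// which is what lets these ops avoid a full-volume copy of the table.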
int64 signed_buffer_size = 0;
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
buffer_size_ = static_cast<size_t>(signed_buffer_size);
}

void Compute(OpKernelContext* ctx) override {
LookupInterface* table;
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
core::ScopedUnref unref_me(table);

const Tensor& ftensor = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
errors::InvalidArgument("filepath must be scalar."));
string filepath = string(ftensor.scalar<tstring>()().data());

lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
(lookup::CuckooHashTableOfTensors<K, V>*)table;
OP_REQUIRES_OK(ctx, table_cuckoo->SaveToFile(ctx, filepath, buffer_size_));
}

private:
size_t buffer_size_;
};

// Clear the table and insert data.
class HashTableImportOp : public HashTableOpKernel {
public:
Expand Down Expand Up @@ -637,6 +679,37 @@ class HashTableImportOp : public HashTableOpKernel {
}
};

// Clear the table and import all keys and values from a file.
template <class K, class V>
class HashTableImportFromFileOp : public HashTableOpKernel {
public:
explicit HashTableImportFromFileOp(OpKernelConstruction* ctx)
: HashTableOpKernel(ctx) {
int64 signed_buffer_size = 0;
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
buffer_size_ = static_cast<size_t>(signed_buffer_size);
}

void Compute(OpKernelContext* ctx) override {
LookupInterface* table;
OP_REQUIRES_OK(ctx, GetTable(ctx, &table));
core::ScopedUnref unref_me(table);

const Tensor& ftensor = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
errors::InvalidArgument("filepath must be scalar."));
string filepath = string(ftensor.scalar<tstring>()().data());

lookup::CuckooHashTableOfTensors<K, V>* table_cuckoo =
(lookup::CuckooHashTableOfTensors<K, V>*)table;
OP_REQUIRES_OK(ctx,
table_cuckoo->LoadFromFile(ctx, filepath, buffer_size_));
}

private:
size_t buffer_size_;
};

REGISTER_KERNEL_BUILDER(
Name(PREFIX_OP_NAME(CuckooHashTableFind)).Device(DEVICE_CPU),
HashTableFindOp);
@@ -679,7 +752,17 @@ REGISTER_KERNEL_BUILDER(
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("Tin") \
.TypeConstraint<value_dtype>("Tout"), \
HashTableFindWithExistsOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableExportToFileOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableImportFromFileOp<key_dtype, value_dtype>);

REGISTER_KERNEL(int32, double);
REGISTER_KERNEL(int32, float);
@@ -214,10 +214,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
if (cur_size > 0) {
CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t)));
CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size));
CUDA_CHECK(cudaMallocManaged((void**)&d_values,
sizeof(V) * runtime_dim_ * cur_size));
table_->dump(d_keys, (gpu::ValueArrayBase<V>*)d_values, 0, capacity,
d_dump_counter, stream);
cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t),
cudaMemcpyDeviceToHost, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
}

@@ -226,8 +228,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
CreateTable(new_max_size, &table_);

if (cur_size > 0) {
table_->upsert((const K*)d_keys,
(const gpu::ValueArrayBase<V>*)d_values, h_dump_counter,
stream);
cudaStreamSynchronize(stream);
cudaFree(d_keys);
cudaFree(d_values);
@@ -387,6 +390,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface {
return Status::OK();
}

Status ExportValuesToFile(OpKernelContext* ctx, const string filepath,
const size_t buffer_size) {
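// The dump runs on a dedicated CUDA stream, which is synchronized and
// destroyed before this method returns.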
cudaStream_t _stream;
CUDA_CHECK(cudaStreamCreate(&_stream));

{
tf_shared_lock l(mu_);
table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size);
CUDA_CHECK(cudaStreamSynchronize(_stream));
}
CUDA_CHECK(cudaStreamDestroy(_stream));
return Status::OK();
}

Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath,
const size_t buffer_size) {
cudaStream_t _stream;
CUDA_CHECK(cudaStreamCreate(&_stream));

{
tf_shared_lock l(mu_);

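// Probe the key file's size first: keys are fixed-width, so the number of
// records on disk is simply the file size divided by sizeof(K).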
string keyfile = filepath + ".keys";
FILE* tmpfd = fopen(keyfile.c_str(), "rb");
if (tmpfd == nullptr) {
CUDA_CHECK(cudaStreamDestroy(_stream));
return errors::NotFound("Failed to read key file ", keyfile);
}
fseek(tmpfd, 0, SEEK_END);
long int filesize = ftell(tmpfd);
if (filesize <= 0) {
fclose(tmpfd);
CUDA_CHECK(cudaStreamDestroy(_stream));
return errors::NotFound("Empty key file.");
}
size_t size = static_cast<size_t>(filesize) / sizeof(K);
fseek(tmpfd, 0, SEEK_SET);
fclose(tmpfd);

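// Reset the table, grow it to fit the incoming records, then stream the
// records in from disk through the bounded buffer.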
table_->clear(_stream);
CUDA_CHECK(cudaStreamSynchronize(_stream));
RehashIfNeeded(_stream, size);
table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream,
buffer_size);
CUDA_CHECK(cudaStreamSynchronize(_stream));
}
CUDA_CHECK(cudaStreamDestroy(_stream));
return Status::OK();
}

DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
TensorShape key_shape() const final { return TensorShape(); }
@@ -625,6 +676,36 @@ REGISTER_KERNEL_BUILDER(
Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU),
HashTableExportGpuOp);

// Op that exports all keys and values to a file.
template <class K, class V>
class HashTableExportToFileGpuOp : public OpKernel {
public:
explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
int64 signed_buffer_size = 0;
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
buffer_size_ = static_cast<size_t>(signed_buffer_size);
}

void Compute(OpKernelContext* ctx) override {
lookup::LookupInterface* table;
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);

const Tensor& ftensor = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
errors::InvalidArgument("filepath must be scalar."));
string filepath = string(ftensor.scalar<tstring>()().data());
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
OP_REQUIRES_OK(
ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_));
}

private:
size_t buffer_size_;
};

// Clear the table and insert data.
class HashTableImportGpuOp : public OpKernel {
public:
@@ -651,33 +732,76 @@ REGISTER_KERNEL_BUILDER(
Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU),
HashTableImportGpuOp);

// Op that clears the table and imports all keys and values from a file.
template <class K, class V>
class HashTableImportFromFileGpuOp : public OpKernel {
public:
explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
int64 signed_buffer_size = 0;
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &signed_buffer_size));
buffer_size_ = static_cast<size_t>(signed_buffer_size);
}

void Compute(OpKernelContext* ctx) override {
lookup::LookupInterface* table;
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);

const Tensor& ftensor = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()),
errors::InvalidArgument("filepath must be scalar."));
string filepath = string(ftensor.scalar<tstring>()().data());
lookup::CuckooHashTableOfTensorsGpu<K, V>* table_cuckoo =
(lookup::CuckooHashTableOfTensorsGpu<K, V>*)table;
OP_REQUIRES_OK(
ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_));
}

private:
size_t buffer_size_;
};

// Register the CuckooHashTableOfTensors op.

#define REGISTER_KERNEL(key_dtype, value_dtype) \
REGISTER_KERNEL_BUILDER( \
Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \
.Device(DEVICE_GPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableGpuOp< \
lookup::CuckooHashTableOfTensorsGpu<key_dtype, value_dtype>, \
key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \
.Device(DEVICE_GPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableClearGpuOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \
.Device(DEVICE_GPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableAccumGpuOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \
.Device(DEVICE_GPU) \
.HostMemory("filepath") \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableExportToFileGpuOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER( \
Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \
.Device(DEVICE_GPU) \
.HostMemory("filepath") \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
HashTableImportFromFileGpuOp<key_dtype, value_dtype>); \
REGISTER_KERNEL_BUILDER( \
Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \
.Device(DEVICE_GPU) \
.TypeConstraint<key_dtype>("Tin") \
.TypeConstraint<value_dtype>("Tout"), \
HashTableFindWithExistsGpuOp<key_dtype, value_dtype>);
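
Note that the two file ops are registered with .HostMemory("filepath"): the filepath input is consumed on the host (the kernels perform ordinary host-side file I/O), so this scalar string tensor must live in host memory even for the GPU kernels.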

REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(int64, Eigen::half);