From dfc56671a713006ec321fa44752ff3d90cc83e40 Mon Sep 17 00:00:00 2001 From: Lifann Date: Thu, 12 May 2022 21:58:35 +0800 Subject: [PATCH] Add ops of ExportToFile and ImportFromFile without full volume copying --- .../core/kernels/cuckoo_hashtable_op.cc | 85 ++++++- .../kernels/cuckoo_hashtable_op_gpu.cu.cc | 184 +++++++++++--- .../kernels/lookup_impl/lookup_table_op_cpu.h | 86 +++++++ .../kernels/lookup_impl/lookup_table_op_gpu.h | 100 ++++++++ .../core/ops/cuckoo_hashtable_ops.cc | 14 ++ .../dynamic_embedding/core/utils/filebuffer.h | 225 ++++++++++++++++++ .../kernel_tests/cuckoo_hashtable_ops_test.py | 47 ++++ .../python/ops/cuckoo_hashtable_ops.py | 47 ++++ 8 files changed, 757 insertions(+), 31 deletions(-) create mode 100644 tensorflow_recommenders_addons/dynamic_embedding/core/utils/filebuffer.h diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc index 6d4bfd79b..686a0bc53 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op.cc @@ -304,6 +304,18 @@ class CuckooHashTableOfTensors final : public LookupInterface { return table_->export_values(ctx, value_dim); } + Status SaveToFile(OpKernelContext* ctx, const string filepath, + const size_t buffer_size) { + int64 value_dim = value_shape_.dim_size(0); + return table_->save_to_file(ctx, value_dim, filepath, buffer_size); + } + + Status LoadFromFile(OpKernelContext* ctx, const string filepath, + const size_t buffer_size) { + int64 value_dim = value_shape_.dim_size(0); + return table_->load_from_file(ctx, value_dim, filepath, buffer_size); + } + DataType key_dtype() const override { return DataTypeToEnum::v(); } DataType value_dtype() const override { return DataTypeToEnum::v(); } @@ -607,6 +619,36 @@ class HashTableExportOp : public HashTableOpKernel { } }; +// Op that export all keys and values to file. +template +class HashTableExportToFileOp : public HashTableOpKernel { + public: + explicit HashTableExportToFileOp(OpKernelConstruction* ctx) + : HashTableOpKernel(ctx) { + int64 signed_buffer_size = 0; + ctx->GetAttr("buffer_size", &signed_buffer_size); + buffer_size_ = static_cast(signed_buffer_size); + } + + void Compute(OpKernelContext* ctx) override { + LookupInterface* table; + OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); + core::ScopedUnref unref_me(table); + + const Tensor& ftensor = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()), + errors::InvalidArgument("filepath must be scalar.")); + string filepath = string(ftensor.scalar()().data()); + + lookup::CuckooHashTableOfTensors* table_cuckoo = + (lookup::CuckooHashTableOfTensors*)table; + OP_REQUIRES_OK(ctx, table_cuckoo->SaveToFile(ctx, filepath, buffer_size_)); + } + + private: + size_t buffer_size_; +}; + // Clear the table and insert data. class HashTableImportOp : public HashTableOpKernel { public: @@ -637,6 +679,37 @@ class HashTableImportOp : public HashTableOpKernel { } }; +// Op that export all keys and values to file. +template +class HashTableImportFromFileOp : public HashTableOpKernel { + public: + explicit HashTableImportFromFileOp(OpKernelConstruction* ctx) + : HashTableOpKernel(ctx) { + int64 signed_buffer_size = 0; + ctx->GetAttr("buffer_size", &signed_buffer_size); + buffer_size_ = static_cast(signed_buffer_size); + } + + void Compute(OpKernelContext* ctx) override { + LookupInterface* table; + OP_REQUIRES_OK(ctx, GetTable(ctx, &table)); + core::ScopedUnref unref_me(table); + + const Tensor& ftensor = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()), + errors::InvalidArgument("filepath must be scalar.")); + string filepath = string(ftensor.scalar()().data()); + + lookup::CuckooHashTableOfTensors* table_cuckoo = + (lookup::CuckooHashTableOfTensors*)table; + OP_REQUIRES_OK(ctx, + table_cuckoo->LoadFromFile(ctx, filepath, buffer_size_)); + } + + private: + size_t buffer_size_; +}; + REGISTER_KERNEL_BUILDER( Name(PREFIX_OP_NAME(CuckooHashTableFind)).Device(DEVICE_CPU), HashTableFindOp); @@ -679,7 +752,17 @@ REGISTER_KERNEL_BUILDER( .Device(DEVICE_CPU) \ .TypeConstraint("Tin") \ .TypeConstraint("Tout"), \ - HashTableFindWithExistsOp); + HashTableFindWithExistsOp); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableExportToFileOp); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \ + .Device(DEVICE_CPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableImportFromFileOp); REGISTER_KERNEL(int32, double); REGISTER_KERNEL(int32, float); diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc index c40f43ed2..d371ee859 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/cuckoo_hashtable_op_gpu.cu.cc @@ -214,10 +214,12 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface { if (cur_size > 0) { CUDA_CHECK(cudaMallocManaged((void**)&d_dump_counter, sizeof(size_t))); CUDA_CHECK(cudaMallocManaged((void**)&d_keys, sizeof(K) * cur_size)); - CUDA_CHECK(cudaMallocManaged((void**)&d_values, sizeof(V) * runtime_dim_ * cur_size)); + CUDA_CHECK(cudaMallocManaged((void**)&d_values, + sizeof(V) * runtime_dim_ * cur_size)); table_->dump(d_keys, (gpu::ValueArrayBase*)d_values, 0, capacity, - d_dump_counter, stream); - cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), cudaMemcpyDeviceToHost, stream); + d_dump_counter, stream); + cudaMemcpyAsync(&h_dump_counter, d_dump_counter, sizeof(size_t), + cudaMemcpyDeviceToHost, stream); CUDA_CHECK(cudaStreamSynchronize(stream)); } @@ -226,8 +228,9 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface { CreateTable(new_max_size, &table_); if (cur_size > 0) { - table_->upsert((const K*)d_keys, (const gpu::ValueArrayBase*)d_values, - h_dump_counter, stream); + table_->upsert((const K*)d_keys, + (const gpu::ValueArrayBase*)d_values, h_dump_counter, + stream); cudaStreamSynchronize(stream); cudaFree(d_keys); cudaFree(d_values); @@ -387,6 +390,54 @@ class CuckooHashTableOfTensorsGpu final : public LookupInterface { return Status::OK(); } + Status ExportValuesToFile(OpKernelContext* ctx, const string filepath, + const size_t buffer_size) { + cudaStream_t _stream; + CUDA_CHECK(cudaStreamCreate(&_stream)); + + { + tf_shared_lock l(mu_); + table_->dump_to_file(ctx, filepath, runtime_dim_, _stream, buffer_size); + CUDA_CHECK(cudaStreamSynchronize(_stream)); + } + CUDA_CHECK(cudaStreamDestroy(_stream)); + return Status::OK(); + } + + Status ImportValuesFromFile(OpKernelContext* ctx, const string filepath, + const size_t buffer_size) { + cudaStream_t _stream; + CUDA_CHECK(cudaStreamCreate(&_stream)); + + { + tf_shared_lock l(mu_); + + string keyfile = filepath + ".keys"; + FILE* tmpfd = fopen(keyfile.c_str(), "rb"); + if (tmpfd == nullptr) { + return errors::NotFound("Failed to read key file", keyfile); + } + fseek(tmpfd, 0, SEEK_END); + long int filesize = ftell(tmpfd); + if (filesize <= 0) { + fclose(tmpfd); + return errors::NotFound("Empty key file."); + } + size_t size = static_cast(filesize) / sizeof(K); + fseek(tmpfd, 0, SEEK_SET); + fclose(tmpfd); + + table_->clear(_stream); + CUDA_CHECK(cudaStreamSynchronize(_stream)); + RehashIfNeeded(_stream, size); + table_->load_from_file(ctx, filepath, size, runtime_dim_, _stream, + buffer_size); + CUDA_CHECK(cudaStreamSynchronize(_stream)); + } + CUDA_CHECK(cudaStreamDestroy(_stream)); + return Status::OK(); + } + DataType key_dtype() const override { return DataTypeToEnum::v(); } DataType value_dtype() const override { return DataTypeToEnum::v(); } TensorShape key_shape() const final { return TensorShape(); } @@ -625,6 +676,36 @@ REGISTER_KERNEL_BUILDER( Name(PREFIX_OP_NAME(CuckooHashTableExport)).Device(DEVICE_GPU), HashTableExportGpuOp); +// Op that export all keys and values to file. +template +class HashTableExportToFileGpuOp : public OpKernel { + public: + explicit HashTableExportToFileGpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + int64 signed_buffer_size = 0; + ctx->GetAttr("buffer_size", &signed_buffer_size); + buffer_size_ = static_cast(signed_buffer_size); + } + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + const Tensor& ftensor = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()), + errors::InvalidArgument("filepath must be scalar.")); + string filepath = string(ftensor.scalar()().data()); + lookup::CuckooHashTableOfTensorsGpu* table_cuckoo = + (lookup::CuckooHashTableOfTensorsGpu*)table; + OP_REQUIRES_OK( + ctx, table_cuckoo->ExportValuesToFile(ctx, filepath, buffer_size_)); + } + + private: + size_t buffer_size_; +}; + // Clear the table and insert data. class HashTableImportGpuOp : public OpKernel { public: @@ -651,33 +732,76 @@ REGISTER_KERNEL_BUILDER( Name(PREFIX_OP_NAME(CuckooHashTableImport)).Device(DEVICE_GPU), HashTableImportGpuOp); +// Op that import from file. +template +class HashTableImportFromFileGpuOp : public OpKernel { + public: + explicit HashTableImportFromFileGpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + int64 signed_buffer_size = 0; + ctx->GetAttr("buffer_size", &signed_buffer_size); + buffer_size_ = static_cast(signed_buffer_size); + } + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table; + OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table)); + core::ScopedUnref unref_me(table); + + const Tensor& ftensor = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ftensor.shape()), + errors::InvalidArgument("filepath must be scalar.")); + string filepath = string(ftensor.scalar()().data()); + lookup::CuckooHashTableOfTensorsGpu* table_cuckoo = + (lookup::CuckooHashTableOfTensorsGpu*)table; + OP_REQUIRES_OK( + ctx, table_cuckoo->ImportValuesFromFile(ctx, filepath, buffer_size_)); + } + + private: + size_t buffer_size_; +}; + // Register the CuckooHashTableOfTensors op. -#define REGISTER_KERNEL(key_dtype, value_dtype) \ - REGISTER_KERNEL_BUILDER( \ - Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \ - .Device(DEVICE_GPU) \ - .TypeConstraint("key_dtype") \ - .TypeConstraint("value_dtype"), \ - HashTableGpuOp< \ - lookup::CuckooHashTableOfTensorsGpu, \ - key_dtype, value_dtype>); \ - REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \ - .Device(DEVICE_GPU) \ - .TypeConstraint("key_dtype") \ - .TypeConstraint("value_dtype"), \ - HashTableClearGpuOp) \ - REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \ - .Device(DEVICE_GPU) \ - .TypeConstraint("key_dtype") \ - .TypeConstraint("value_dtype"), \ - HashTableAccumGpuOp) \ - REGISTER_KERNEL_BUILDER( \ - Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \ - .Device(DEVICE_GPU) \ - .TypeConstraint("Tin") \ - .TypeConstraint("Tout"), \ - HashTableFindWithExistsGpuOp) +#define REGISTER_KERNEL(key_dtype, value_dtype) \ + REGISTER_KERNEL_BUILDER( \ + Name(PREFIX_OP_NAME(CuckooHashTableOfTensors)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableGpuOp< \ + lookup::CuckooHashTableOfTensorsGpu, \ + key_dtype, value_dtype>); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableClear)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableClearGpuOp); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableAccum)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableAccumGpuOp); \ + REGISTER_KERNEL_BUILDER(Name(PREFIX_OP_NAME(CuckooHashTableExportToFile)) \ + .Device(DEVICE_GPU) \ + .HostMemory("filepath") \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableExportToFileGpuOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) \ + .Device(DEVICE_GPU) \ + .HostMemory("filepath") \ + .TypeConstraint("key_dtype") \ + .TypeConstraint("value_dtype"), \ + HashTableImportFromFileGpuOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(PREFIX_OP_NAME(CuckooHashTableFindWithExists)) \ + .Device(DEVICE_GPU) \ + .TypeConstraint("Tin") \ + .TypeConstraint("Tout"), \ + HashTableFindWithExistsGpuOp); REGISTER_KERNEL(int64, float); REGISTER_KERNEL(int64, Eigen::half); diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h index 00bd4b395..b14c42b52 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_cpu.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow_recommenders_addons/dynamic_embedding/core/lib/cuckoo/cuckoohash_map.hh" +#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/filebuffer.h" #include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/types.h" namespace tensorflow { @@ -129,6 +130,12 @@ class TableWrapperBase { virtual void clear() {} virtual bool erase(const K& key) {} virtual Status export_values(OpKernelContext* ctx, int64 value_dim) {} + virtual Status save_to_file(OpKernelContext* ctx, int64 value_dim, + const string filepath, const size_t buffer_size) { + } + virtual Status load_from_file(OpKernelContext* ctx, int64 value_dim, + const string filepath, + const size_t buffer_size) {} }; template @@ -232,6 +239,85 @@ class TableWrapperOptimized final : public TableWrapperBase { return Status::OK(); } + Status save_to_file(OpKernelContext* ctx, int64 value_dim, + const string filepath, + const size_t buffer_size) override { + auto lt = table_->lock_table(); + int64 size = lt.size(); + + size_t key_buffer_size = buffer_size; + string key_tmpfile = filepath + ".keys.tmp"; + string key_file = filepath + ".keys"; + auto key_buffer = filebuffer::HostFileBuffer( + key_tmpfile, key_buffer_size, filebuffer::MODE::WRITE); + + size_t value_buffer_size = key_buffer_size * static_cast(value_dim); + string value_tmpfile = filepath + ".values.tmp"; + string value_file = filepath + ".values"; + auto value_buffer = filebuffer::HostFileBuffer( + value_tmpfile, value_buffer_size, filebuffer::MODE::WRITE); + + for (auto it = lt.begin(); it != lt.end(); ++it) { + key_buffer.Put(it->first); + value_buffer.BatchPut(it->second.data(), it->second.size()); + } + key_buffer.Flush(); + value_buffer.Flush(); + key_buffer.Close(); + value_buffer.Close(); + + if (rename(key_tmpfile.c_str(), key_file.c_str()) != 0) { + return errors::NotFound("key_tmpfile ", key_tmpfile, " is not found."); + } + if (rename(value_tmpfile.c_str(), value_file.c_str()) != 0) { + return errors::NotFound("value_tmpfile ", value_tmpfile, + " is not found."); + } + return Status::OK(); + } + + Status load_from_file(OpKernelContext* ctx, int64 value_dim, + const string filepath, + const size_t buffer_size) override { + size_t dim = static_cast(value_dim); + size_t key_buffer_size = buffer_size; + size_t value_buffer_size = key_buffer_size * dim; + string key_file = filepath + ".keys"; + string value_file = filepath + ".values"; + auto key_buffer = filebuffer::HostFileBuffer(key_file, key_buffer_size, + filebuffer::MODE::READ); + auto value_buffer = filebuffer::HostFileBuffer( + value_file, value_buffer_size, filebuffer::MODE::READ); + size_t nkeys = 1; + + size_t total_keys = 0; + size_t total_values = 0; + while (nkeys > 0) { + nkeys = key_buffer.Fill(); + value_buffer.Fill(); + total_keys += key_buffer.size(); + total_values += value_buffer.size(); + for (size_t i = 0; i < key_buffer.size(); i++) { + ValueType value_vec; + K key = key_buffer[i]; + for (size_t j = 0; j < dim; j++) { + V value = value_buffer[i * dim + j]; + value_vec[j] = value; + } + table_->insert_or_assign(key, value_vec); + } + key_buffer.Clear(); + value_buffer.Clear(); + } + if (total_keys * dim != total_values) { + LOG(ERROR) << "DataLoss: restore get " << total_keys << " and " + << total_values << " in file " << filepath << " with dim " + << dim; + exit(1); + } + return Status::OK(); + } + private: size_t init_size_; Table* table_; diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h index 8d389b098..4b29a2f07 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_gpu.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow_recommenders_addons/dynamic_embedding/core/lib/nvhash/nv_hashtable.cuh" +#include "tensorflow_recommenders_addons/dynamic_embedding/core/utils/filebuffer.h" namespace tensorflow { namespace recommenders_addons { @@ -60,6 +61,13 @@ class TableWrapperBase { virtual void dump(K* d_key, ValueType* d_val, const size_t offset, const size_t search_length, size_t* d_dump_counter, cudaStream_t stream) const {} + virtual void dump_to_file(OpKernelContext* ctx, const string filepath, + size_t dim, cudaStream_t stream, + const size_t buffer_size) const {} + virtual void load_from_file(OpKernelContext* ctx, const string filepath, + const size_t key_num, size_t dim, + cudaStream_t stream, + const size_t buffer_size) const {} virtual void get(const K* d_keys, ValueType* d_vals, bool* d_status, size_t len, ValueType* d_def_val, cudaStream_t stream, bool is_full_size_default) const {} @@ -98,6 +106,98 @@ class TableWrapper final : public TableWrapperBase { table_->dump(d_key, d_val, offset, search_length, d_dump_counter, stream); } + void dump_to_file(OpKernelContext* ctx, const string filepath, size_t dim, + cudaStream_t stream, + const size_t buffer_size) const override { + CUDA_CHECK(cudaStreamSynchronize(stream)); + K* keys = nullptr; + V* values = nullptr; + size_t offset = 0; + size_t* d_dump_counter; + size_t dump_counter; + size_t table_capacity = get_capacity(); + + CUDA_CHECK(cudaMalloc(&keys, sizeof(K) * buffer_size)); + CUDA_CHECK(cudaMalloc(&values, sizeof(V) * buffer_size * dim)); + CUDA_CHECK(cudaMalloc(&d_dump_counter, sizeof(size_t))); + + string key_file = filepath + ".keys"; + string value_file = filepath + ".values"; + string key_tmpfile = filepath + ".keys.tmp"; + string value_tmpfile = filepath + ".values.tmp"; + auto key_buffer = filebuffer::DeviceFileBuffer(key_tmpfile, buffer_size, + filebuffer::MODE::WRITE); + auto value_buffer = filebuffer::DeviceFileBuffer( + value_tmpfile, buffer_size * dim, filebuffer::MODE::WRITE); + size_t search_length = 0; + + size_t total_dumped = 0; + while (offset < table_capacity) { + if (offset + buffer_size >= table_capacity) { + search_length = table_capacity - offset; + } else { + search_length = buffer_size; + } + table_->dump(keys, (ValueType*)values, offset, search_length, + d_dump_counter, stream); + CUDA_CHECK(cudaMemcpyAsync(&dump_counter, d_dump_counter, sizeof(size_t), + cudaMemcpyDeviceToHost, stream)); + + key_buffer.BatchPut(keys, dump_counter, stream); + value_buffer.BatchPut(values, dump_counter * dim, stream); + cudaStreamSynchronize(stream); + offset += search_length; + total_dumped += dump_counter; + } + + LOG(INFO) << "Dump finish, offset=" << offset + << ", total_dumped=" << total_dumped; + + CUDA_CHECK(cudaFree(keys)); + CUDA_CHECK(cudaFree(values)); + CUDA_CHECK(cudaFree(d_dump_counter)); + + key_buffer.Close(); + value_buffer.Close(); + OP_REQUIRES( + ctx, rename(key_tmpfile.c_str(), key_file.c_str()) == 0, + errors::NotFound("key_tmpfile ", key_tmpfile, " is not found.")); + OP_REQUIRES( + ctx, rename(value_tmpfile.c_str(), value_file.c_str()) == 0, + errors::NotFound("value_tmpfile ", value_tmpfile, " is not found.")); + } + + void load_from_file(OpKernelContext* ctx, const string filepath, + const size_t key_num, size_t dim, cudaStream_t stream, + const size_t buffer_size) const override { + CUDA_CHECK(cudaStreamSynchronize(stream)); + string key_file = filepath + ".keys"; + string value_file = filepath + ".values"; + auto key_buffer = filebuffer::DeviceFileBuffer(key_file, buffer_size, + filebuffer::MODE::READ); + auto value_buffer = filebuffer::DeviceFileBuffer( + value_file, buffer_size * dim, filebuffer::MODE::READ); + + size_t nkeys = 1; + size_t total_keys = 0; + size_t total_values = 0; + while (nkeys > 0) { + nkeys = key_buffer.Fill(); + value_buffer.Fill(); + total_keys += key_buffer.size(); + total_values += value_buffer.size(); + + table_->upsert(key_buffer.data(), (ValueType*)value_buffer.data(), + nkeys, stream); + cudaStreamSynchronize(stream); + key_buffer.Clear(); + value_buffer.Clear(); + } + OP_REQUIRES(ctx, total_keys * dim == total_values, + errors::DataLoss("load from file get invalid ", total_keys, + " keys and", total_values, " values.")); + } + void get(const K* d_keys, ValueType* d_vals, bool* d_status, size_t len, ValueType* d_def_val, cudaStream_t stream, bool is_full_size_default) const override { diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/cuckoo_hashtable_ops.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/cuckoo_hashtable_ops.cc index 718faccd5..c6e24f1fc 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/ops/cuckoo_hashtable_ops.cc +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/ops/cuckoo_hashtable_ops.cc @@ -254,6 +254,13 @@ REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableExport)) return Status::OK(); }); +REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableExportToFile)) + .Input("table_handle: resource") + .Input("filepath: string") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("buffer_size: int >= 1"); + REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableImport)) .Input("table_handle: resource") .Input("keys: Tin") @@ -270,6 +277,13 @@ REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableImport)) return Status::OK(); }); +REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableImportFromFile)) + .Input("table_handle: resource") + .Input("filepath: string") + .Attr("key_dtype: type") + .Attr("value_dtype: type") + .Attr("buffer_size: int >= 1"); + REGISTER_OP(PREFIX_OP_NAME(CuckooHashTableOfTensors)) .Output("table_handle: resource") .Attr("container: string = ''") diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/utils/filebuffer.h b/tensorflow_recommenders_addons/dynamic_embedding/core/utils/filebuffer.h new file mode 100644 index 000000000..7ea1f7134 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/utils/filebuffer.h @@ -0,0 +1,225 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TFRA_FILEBUFFER_H_ +#define TFRA_FILEBUFFER_H_ + +#include +#include + +#include +#include + +#if GOOGLE_CUDA +#include "cuda_runtime.h" +#endif + +namespace filebuffer { + +enum MODE { READ = 0, WRITE = 1 }; + +template +class FileBuffer { + public: + virtual void Put(const T value) {} + virtual void Flush() {} + virtual size_t Fill() { return 0; } + virtual void Clear() {} + virtual void Close() {} + virtual size_t size() { return 0; } + virtual size_t capacity() { return 0; } +}; + +template +class HostFileBuffer : public FileBuffer { + public: + HostFileBuffer(const std::string path, size_t capacity, MODE mode) + : filepath_(path), capacity_(capacity), mode_(mode) { + offset_ = 0; + buf_ = (T*)malloc(capacity_ * sizeof(T)); + if (!buf_) { + throw std::runtime_error("Failed to allocate HostFileBuffer."); + } + if (mode_ == MODE::READ) { + fp_ = fopen(filepath_.c_str(), "rb"); + } else if (mode == MODE::WRITE) { + fp_ = fopen(filepath_.c_str(), "wb"); + } else { + throw std::invalid_argument("File mode must be READ or WRITE"); + } + } + + void Close() override { + if (buf_) { + free(buf_); + buf_ = nullptr; + } + if (fp_) { + fclose(fp_); + fp_ = nullptr; + } + } + + ~HostFileBuffer() { Close(); } + + void Put(const T value) override { + buf_[offset_++] = value; + if (offset_ == capacity_) { + Flush(); + } + } + + // Must set capacity to be multiples of n. + void BatchPut(const T* value, size_t n) { + for (size_t i = 0; i < n; i++) { + buf_[offset_++] = value[i]; + } + if (offset_ == capacity_) { + Flush(); + } + } + + void Flush() override { + if (mode_ != MODE::WRITE) { + throw std::invalid_argument( + "Can only flush buffer created in WRITE mode."); + } + if (offset_ == 0) return; + size_t nwritten = fwrite(buf_, sizeof(T), offset_, fp_); + if (nwritten != offset_) { + throw std::runtime_error("write to " + filepath_ + " expecting " + + std::to_string(offset_) + " bytes, but write " + + std::to_string(nwritten) + " bytes."); + } + offset_ = 0; + } + + size_t Fill() override { + offset_ = fread(buf_, sizeof(T), capacity_ - offset_, fp_); + return offset_; + } + + void Clear() override { offset_ = 0; } + + T operator[](size_t i) { return buf_[i]; } + + size_t size() override { return offset_; } + size_t capacity() override { return capacity_; } + + void set_offset(size_t offset) { offset_ = offset; } + + private: + const std::string filepath_; + FILE* fp_; + T* buf_; + size_t capacity_; + size_t offset_; + MODE mode_; +}; + +#if GOOGLE_CUDA + +#ifndef CUDACHECK +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) +#endif + +template +class DeviceFileBuffer : public FileBuffer { + public: + DeviceFileBuffer(const std::string path, size_t size, MODE mode) + : filepath_(path), capacity_(size), mode_(mode) { + offset_ = 0; + CUDACHECK(cudaMallocHost(&buf_, capacity_ * sizeof(T))); + if (!buf_) { + throw std::runtime_error("Failed to allocate DeviceFileBuffer"); + } + if (mode_ == MODE::READ) { + fp_ = fopen(filepath_.c_str(), "rb"); + } else if (mode == MODE::WRITE) { + fp_ = fopen(filepath_.c_str(), "wb"); + } else { + throw std::invalid_argument("File mode must be READ or WRITE"); + } + } + + ~DeviceFileBuffer() { Close(); } + + void BatchPut(T* value, size_t n, cudaStream_t stream) { + CUDACHECK(cudaMemcpyAsync(buf_, value, sizeof(T) * n, + cudaMemcpyDeviceToHost, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + offset_ += n; + Flush(); + } + + void Flush() override { + if (mode_ != MODE::WRITE) { + throw std::invalid_argument( + "Can only flush buffer created in WRITE mode"); + } + if (offset_ == 0) return; + size_t nwritten = fwrite(buf_, sizeof(T), offset_, fp_); + if (nwritten != offset_) { + throw std::runtime_error("write to " + filepath_ + " expecting " + + std::to_string(offset_) + " bytes, but write " + + std::to_string(nwritten) + " bytes."); + } + offset_ = 0; + } + + size_t Fill() override { + offset_ = fread(buf_, sizeof(T), capacity_ - offset_, fp_); + return offset_; + } + + void Clear() override { offset_ = 0; } + + void Close() override { + if (buf_) { + CUDACHECK(cudaFreeHost(buf_)); + buf_ = nullptr; + } + if (fp_) { + fclose(fp_); + fp_ = nullptr; + } + } + + T* data() { return buf_; } + size_t size() override { return offset_; } + size_t capacity() override { return capacity_; } + + private: + const std::string filepath_; + FILE* fp_; + T* buf_; + size_t capacity_; + size_t offset_; + MODE mode_; +}; + +} // namespace filebuffer + +#endif // GOOGLE_CUDA + +#endif // TFRA_FILEBUFFER_H_ diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/cuckoo_hashtable_ops_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/cuckoo_hashtable_ops_test.py index c8d05cecb..311ae9536 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/cuckoo_hashtable_ops_test.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/cuckoo_hashtable_ops_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import sys +import os +import tensorflow as tf from tensorflow_recommenders_addons import dynamic_embedding as de @@ -55,6 +57,51 @@ def test_dynamic_embedding_variable_set_init_size(self): self.assertTrue(dev_str in printed.contents()) self.assertTrue("_size={}".format(expect_size) in printed.contents()) + @test_util.run_in_graph_and_eager_modes() + def test_cuckoo_hashtable_save(self): + initializer = tf.keras.initializers.RandomNormal() + dim = 8 + + test_devices = ['/CPU:0'] + if test_util.is_gpu_available(): + test_devices.append('/GPU:0') + for idx, device in enumerate(test_devices): + var1 = de.get_variable('vmas142_' + str(idx), + key_dtype=tf.int64, + value_dtype=tf.float32, + initializer=initializer, + devices=[device], + dim=dim) + var2 = de.get_variable('lfwa031_' + str(idx), + key_dtype=tf.int64, + value_dtype=tf.float32, + initializer=initializer, + devices=[device], + dim=dim) + init_keys = tf.range(0, 10000, dtype=tf.int64) + init_values = var1.lookup(init_keys) + + sess_config = config_pb2.ConfigProto( + allow_soft_placement=True, + gpu_options=config_pb2.GPUOptions(allow_growth=True)) + with self.session(use_gpu=True, config=default_config): + self.evaluate(var1.upsert(init_keys, init_values)) + + np_keys = self.evaluate(init_keys) + np_values = self.evaluate(init_values) + + test_dir = self.get_temp_dir() + filepath = os.path.join(test_dir, 'table') + self.evaluate(var1.tables[0].save(filepath, buffer_size=4096)) + self.evaluate(var2.tables[0].load(filepath, buffer_size=4096)) + load_keys, load_values = self.evaluate(var2.export()) + sort_idx = load_keys.argsort() + load_keys = load_keys[sort_idx[::1]] + load_values = load_values[sort_idx[::1]] + + self.assertAllEqual(np_keys, load_keys) + self.assertAllEqual(np_values, load_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py index 7d42ca7d4..4821d54f2 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py @@ -338,6 +338,53 @@ def export(self, name=None): self.resource_handle, self._key_dtype, self._value_dtype) return keys, values + def save(self, filepath, buffer_size=4194304, name=None): + """ + Returns an operation to save the keys and values in table to + filepath. The keys and values will be stored in files with + suffix ".keys" and ".values", appended to the filepath. + + Args: + filepath: A path to save the table. + name: Name for the operation. + buffer_size: Number of kv pairs buffer write to file. + + Returns: + An operation to save the table. + """ + with ops.name_scope(name, "%s_save_table" % self.name, + [self.resource_handle]): + with ops.colocate_with(None, ignore_existing=True): + return cuckoo_ops.tfra_cuckoo_hash_table_export_to_file( + self.resource_handle, + filepath, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + buffer_size=buffer_size) + + def load(self, filepath, buffer_size=4194304, name=None): + """ + Returns an operation to load keys and values to table from + file. The keys and values files are generated from `save`. + + Args: + filepath: A file path stored the table keys and values. + name: Name for the operation. + buffer_size: Number of kv pairs buffer to read file. + + Returns: + An operation to load keys and values to table from file. + """ + with ops.name_scope(name, "%s_load_table" % self.name, + [self.resource_handle]): + with ops.colocate_with(None, ignore_existing=True): + return cuckoo_ops.tfra_cuckoo_hash_table_import_from_file( + self.resource_handle, + filepath, + key_dtype=self._key_dtype, + value_dtype=self._value_dtype, + buffer_size=buffer_size) + def _gather_saveables_for_checkpoint(self): """For object-based checkpointing.""" # full_name helps to figure out the name-based Saver's name for this saveable.