diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index d7566db0..3e4b278b 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -8,7 +8,6 @@ #include #include #include -#include namespace mscclpp { @@ -41,10 +40,10 @@ class ProxyService : public BaseProxyService { /// @return The ID of the semaphore. SemaphoreId addSemaphore(std::shared_ptr connection); - /// Add a pitch pair to the proxy service. - /// @param id The ID of the semaphore. + /// Add a 2D channel to the proxy service. + /// @param connection The connection associated with the channel. /// @param pitch The pitch pair. - void addPitch(SemaphoreId id, std::pair pitch); + SemaphoreId add2DChannel(std::shared_ptr connection, std::pair pitch); /// Register a memory region with the proxy service. /// @param memory The memory region to register. @@ -71,7 +70,7 @@ class ProxyService : public BaseProxyService { Communicator& communicator_; std::vector> semaphores_; std::vector memories_; - std::unordered_map> pitches_; + std::vector> pitches_; Proxy proxy_; int deviceNumaNode; diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index a06eea26..efab94a0 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -29,8 +29,13 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr pitch) { +MSCCLPP_API_CPP SemaphoreId ProxyService::add2DChannel(std::shared_ptr connection, + std::pair pitch) { + semaphores_.push_back(std::make_shared(communicator_, connection)); + SemaphoreId id = semaphores_.size() - 1; + if (id >= pitches_.size()) pitches_.resize(id + 1, std::pair(0, 0)); pitches_[id] = pitch; + return id; } MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) { diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index cd660210..7deee464 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -58,8 +58,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections( communicator->setup(); - mscclpp::SemaphoreId cid = channelService->addSemaphore(conn); - channelService->addPitch(cid, std::pair(pitch, pitch)); + mscclpp::SemaphoreId cid = channelService->add2DChannel(conn, std::pair(pitch, pitch)); communicator->setup(); proxyChannels.emplace_back(mscclpp::deviceHandle( @@ -77,13 +76,13 @@ __device__ size_t getTileElementOffset(int elementId, int width, int rowIndex, i } __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowIndex, int colIndex, int width, - int hight, int* ret) { + int height, int* ret) { DeviceHandle& proxyChan = gChannelOneToOneTestConstProxyChans; volatile int* sendBuff = (volatile int*)buff; int nTries = 1000; int flusher = 0; size_t offset = rowIndex * pitch + colIndex * sizeof(int); - size_t nElem = width * hight; + size_t nElem = width * height; size_t nElemPerPitch = pitch / sizeof(int); for (int i = 0; i < nTries; i++) { if (rank == 0) { @@ -105,7 +104,7 @@ __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowI } __syncthreads(); // __threadfence_system(); // not necessary if we make sendBuff volatile - if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), hight); + if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height); } if (rank == 1) { if (threadIdx.x == 0) proxyChan.wait(); @@ -125,7 +124,7 @@ __global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowI } __syncthreads(); // __threadfence_system(); // not necessary if we make sendBuff volatile - if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), hight); + if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height); } } flusher++;