diff --git a/UsbDk/ControlDevice.cpp b/UsbDk/ControlDevice.cpp index a244a31..ab2b7de 100644 --- a/UsbDk/ControlDevice.cpp +++ b/UsbDk/ControlDevice.cpp @@ -1025,9 +1025,9 @@ NTSTATUS CUsbDkControlDevice::AddRedirectionToSet(const USB_DK_DEVICE_ID &Device return STATUS_SUCCESS; } -NTSTATUS CUsbDkControlDevice::RemoveRedirect(const USB_DK_DEVICE_ID &DeviceId) +NTSTATUS CUsbDkControlDevice::RemoveRedirect(const USB_DK_DEVICE_ID &DeviceId, ULONG pid) { - if (NotifyRedirectorRemovalStarted(DeviceId)) + if (NotifyRedirectorRemovalStarted(DeviceId, pid)) { auto res = ResetUsbDevice(DeviceId, false); if (NT_SUCCESS(res)) @@ -1079,9 +1079,8 @@ bool CUsbDkControlDevice::NotifyRedirectorAttached(CRegText *DeviceID, CRegText return m_Redirections.ModifyOne(&ID, [RedirectorDevice](CUsbDkRedirection *R){ R->NotifyRedirectorCreated(RedirectorDevice); }); } -bool CUsbDkControlDevice::NotifyRedirectorRemovalStarted(const USB_DK_DEVICE_ID &ID) +bool CUsbDkControlDevice::NotifyRedirectorRemovalStarted(const USB_DK_DEVICE_ID &ID, ULONG pid) { - ULONG pid = (ULONG)(ULONG_PTR)PsGetCurrentProcessId(); return m_Redirections.ModifyOne(&ID, [](CUsbDkRedirection *R){ R->NotifyRedirectionRemovalStarted(); }, pid); } diff --git a/UsbDk/ControlDevice.h b/UsbDk/ControlDevice.h index 9223036..1ed72a7 100644 --- a/UsbDk/ControlDevice.h +++ b/UsbDk/ControlDevice.h @@ -274,7 +274,7 @@ class CUsbDkControlDevice : private CWdfControlDevice, public CAllocatableUsbDkFilter->OnFileCreate(Request); + auto filter = UsbDkFilterGetContext(Device)->UsbDkFilter; + filter->m_open_count.AddRef(); + filter->OnFileCreate(Request); }, [](_In_ WDFFILEOBJECT FileObject) { WDFDEVICE Device = WdfFileObjectGetDevice(FileObject); - Strategy(Device)->OnClose(); + auto filter = UsbDkFilterGetContext(Device)->UsbDkFilter; + ULONG pid = 0; // zero means always match + + // Check PID only if there are multiple open references to the file. + // If this was the last reference, always close the redirection. + // + // This callback function might run in a different process-context + // than the initiator process, therefore the 'current process ID' + // isn't always the ID of the 'owning' process. + // + // In the worst case, the USB redirection will be kept until the last + // open file handle to the device is closed. + // + // From KMDF 1.21, there's a new method that should give us the expected ID: + // WdfFileObjectGetInitiatorProcessId(FileObject) + + if (filter->m_open_count.Release()) { + pid = (ULONG)(ULONG_PTR)PsGetCurrentProcessId(); + } + Strategy(Device)->OnClose(pid); }, WDF_NO_EVENT_CALLBACK); diff --git a/UsbDk/FilterDevice.h b/UsbDk/FilterDevice.h index 76b38ca..1de5597 100644 --- a/UsbDk/FilterDevice.h +++ b/UsbDk/FilterDevice.h @@ -176,6 +176,9 @@ class CUsbDkFilterDevice : public CWdfDevice, { m_SerialNumber = Number; } void OnFileCreate(WDFREQUEST Request); + + CWdmRefCounter m_open_count; + private: ~CUsbDkFilterDevice() { diff --git a/UsbDk/FilterStrategy.h b/UsbDk/FilterStrategy.h index 82a664c..fdab0b0 100644 --- a/UsbDk/FilterStrategy.h +++ b/UsbDk/FilterStrategy.h @@ -62,7 +62,8 @@ class CUsbDkFilterStrategy CUsbDkControlDevice* GetControlDevice() { return m_ControlDevice; } - virtual void OnClose(){} + virtual void OnClose(ULONG pid) + { UNREFERENCED_PARAMETER(pid); } protected: CUsbDkFilterDevice *m_Owner = nullptr; diff --git a/UsbDk/RedirectorStrategy.cpp b/UsbDk/RedirectorStrategy.cpp index f8755f3..262db43 100644 --- a/UsbDk/RedirectorStrategy.cpp +++ b/UsbDk/RedirectorStrategy.cpp @@ -83,7 +83,7 @@ NTSTATUS CUsbDkRedirectorStrategy::Create(CUsbDkFilterDevice *Owner) return status; } -using USBDK_REDIRECTOR_REQUEST_CONTEXT = struct : public USBDK_TARGET_REQUEST_CONTEXT +struct USBDK_REDIRECTOR_REQUEST_CONTEXT : public USBDK_TARGET_REQUEST_CONTEXT { bool PreprocessingDone; @@ -626,13 +626,13 @@ size_t CUsbDkRedirectorStrategy::GetRequestContextSize() return sizeof(USBDK_REDIRECTOR_REQUEST_CONTEXT); } -void CUsbDkRedirectorStrategy::OnClose() +void CUsbDkRedirectorStrategy::OnClose(ULONG pid) { USB_DK_DEVICE_ID ID; UsbDkFillIDStruct(&ID, *m_DeviceID->begin(), *m_InstanceID->begin()); TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_REDIRECTOR, "%!FUNC!"); - auto status = m_ControlDevice->RemoveRedirect(ID); + auto status = m_ControlDevice->RemoveRedirect(ID, pid); if (!NT_SUCCESS(status)) { TraceEvents(TRACE_LEVEL_ERROR, TRACE_REDIRECTOR, "%!FUNC! RemoveRedirect failed: %!STATUS!", status); diff --git a/UsbDk/RedirectorStrategy.h b/UsbDk/RedirectorStrategy.h index 31d07c0..d78c05d 100644 --- a/UsbDk/RedirectorStrategy.h +++ b/UsbDk/RedirectorStrategy.h @@ -75,7 +75,7 @@ class CUsbDkRedirectorStrategy : public CUsbDkHiderStrategy size_t InputBufferLength, ULONG IoControlCode) override; - virtual void OnClose() override; + virtual void OnClose(ULONG pid) override; void SetDeviceID(CRegText *DevID) { m_DeviceID = DevID; } diff --git a/UsbDk/UsbTarget.h b/UsbDk/UsbTarget.h index 9c2a3fa..eec896a 100644 --- a/UsbDk/UsbTarget.h +++ b/UsbDk/UsbTarget.h @@ -28,7 +28,7 @@ #include "Urb.h" #include "WdfRequest.h" -using USBDK_TARGET_REQUEST_CONTEXT = struct : public WDF_REQUEST_CONTEXT +struct USBDK_TARGET_REQUEST_CONTEXT : public WDF_REQUEST_CONTEXT { ULONG64 RequestId; };