Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hooks for improved customization #106

Draft
wants to merge 7 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diart/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .blocks import OnlineSpeakerDiarization, PipelineConfig
from .blocks import OnlineSpeakerDiarization, PipelineConfig, OnlineSpeakerDiarizationHook
2 changes: 1 addition & 1 deletion src/diart/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
from .diarization import OnlineSpeakerDiarization, PipelineConfig
from .diarization import OnlineSpeakerDiarization, PipelineConfig, OnlineSpeakerDiarizationHook
from .utils import Binarize, Resample, AdjustVolume
13 changes: 11 additions & 2 deletions src/diart/blocks/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,20 @@ def num_blocked_speakers(self) -> int:
@property
def inactive_centers(self) -> List[int]:
    """Indices of centroid slots that are currently not usable.

    A slot is inactive when it was never activated (not in
    ``self.active_centers``) or has been explicitly blocked
    (in ``self.blocked_centers``).

    Returns
    -------
    List[int]
        Slot indices in ``range(self.max_speakers)`` that are inactive.
    """
    return [
        c for c in range(self.max_speakers)
        if c not in self.active_centers or c in self.blocked_centers
    ]

@property
def center_matrix(self) -> Optional[np.ndarray]:
    """Matrix of the active, non-blocked centroids.

    Returns
    -------
    Optional[np.ndarray]
        The rows of ``self.centers`` corresponding to slots that are
        active and not blocked, or ``None`` if centers have not been
        initialized yet.
    """
    if self.centers is None:
        return None
    # NOTE: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin `int` is the documented replacement for index arrays.
    active = np.array([
        c for c in range(self.max_speakers)
        if c in self.active_centers and c not in self.blocked_centers
    ], dtype=int)
    return self.centers[active]

def get_next_center_position(self) -> Optional[int]:
for center in range(self.max_speakers):
if center not in self.active_centers and center not in self.blocked_centers:
Expand Down
114 changes: 110 additions & 4 deletions src/diart/blocks/diarization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Optional, Any, Union, Tuple, Sequence

import numpy as np
Expand All @@ -11,6 +13,7 @@
from .segmentation import SpeakerSegmentation
from .utils import Binarize
from .. import models as m
from ..features import TemporalFeatures


class PipelineConfig:
Expand Down Expand Up @@ -85,9 +88,86 @@ def from_namespace(args: Any) -> 'PipelineConfig':
)


class OnlineSpeakerDiarizationHook:
    """Base class for observers of an ``OnlineSpeakerDiarization`` pipeline.

    Subclass and override any of the callbacks below to customize or
    inspect the pipeline at specific stages. Every default implementation
    is a no-op, so subclasses only need to implement the events they care
    about. Each callback receives the pipeline instance plus the
    intermediate data available at that stage.
    """

    def on_local_segmentation_batch(
        self,
        pipeline: OnlineSpeakerDiarization,
        audio_batch: torch.Tensor,
        segmentation_batch: TemporalFeatures
    ):
        """Called after the segmentation model has run on a batch of chunks."""
        pass

    def on_embedding_batch(
        self,
        pipeline: OnlineSpeakerDiarization,
        audio_batch: torch.Tensor,
        embedding_batch: torch.Tensor
    ):
        """Called after speaker embeddings have been extracted for a batch."""
        pass

    def on_local_segmentation(
        self,
        pipeline: OnlineSpeakerDiarization,
        waveform: SlidingWindowFeature,
        segmentation: SlidingWindowFeature
    ):
        """Called with the local segmentation of a single chunk."""
        pass

    def on_embeddings(
        self,
        pipeline: OnlineSpeakerDiarization,
        waveform: SlidingWindowFeature,
        embeddings: torch.Tensor
    ):
        """Called with the speaker embeddings of a single chunk."""
        pass

    def on_before_clustering(
        self,
        pipeline: OnlineSpeakerDiarization,
        waveform: SlidingWindowFeature
    ):
        """Called right before the clustering step of a single chunk."""
        pass

    def on_after_clustering(
        self,
        pipeline: OnlineSpeakerDiarization,
        waveform: SlidingWindowFeature,
        clustering: OnlineSpeakerClustering,
        segmentation: SlidingWindowFeature
    ):
        """Called after clustering with the permuted segmentation."""
        pass

    def on_soft_prediction(
        self,
        pipeline: OnlineSpeakerDiarization,
        waveform: SlidingWindowFeature,
        segmentation: SlidingWindowFeature
    ):
        """Called with the aggregated (soft) prediction for a time step."""
        pass

    def on_binary_prediction(
        self,
        pipeline: OnlineSpeakerDiarization,
        waveform: SlidingWindowFeature,
        diarization: Annotation
    ):
        """Called with the binarized diarization output for a time step."""
        pass

    def on_before_reset(self, pipeline: OnlineSpeakerDiarization):
        """Called before the pipeline's internal state is reset."""
        pass

    # Fixed: removed a stray trailing comma after the parameter, for
    # consistency with the other single-parameter callbacks.
    def on_after_reset(self, pipeline: OnlineSpeakerDiarization):
        """Called after the pipeline's internal state has been reset."""
        pass


class OnlineSpeakerDiarization:
def __init__(self, config: Optional[PipelineConfig] = None):
def __init__(
self,
config: Optional[PipelineConfig] = None,
hooks: Optional[Sequence[OnlineSpeakerDiarizationHook]] = None,
):
self.config = PipelineConfig() if config is None else config
self.hooks = [] if hooks is None else hooks

msg = f"Latency should be in the range [{self.config.step}, {self.config.duration}]"
assert self.config.step <= self.config.latency <= self.config.duration, msg
Expand All @@ -111,11 +191,14 @@ def __init__(self, config: Optional[PipelineConfig] = None):
self.binarize = Binarize(self.config.tau_active)

# Internal state, handle with care
self.clustering = None
self.clustering: Optional[OnlineSpeakerClustering] = None
self.chunk_buffer, self.pred_buffer = [], []
self.reset()

def reset(self):
for hook in self.hooks:
hook.on_before_reset(self)

self.clustering = OnlineSpeakerClustering(
self.config.tau_active,
self.config.rho_update,
Expand All @@ -125,10 +208,13 @@ def reset(self):
)
self.chunk_buffer, self.pred_buffer = [], []

for hook in self.hooks:
hook.on_after_reset(self)

def __call__(
self,
waveforms: Sequence[SlidingWindowFeature]
) -> Sequence[Optional[Tuple[Annotation, SlidingWindowFeature]]]:
) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
batch_size = len(waveforms)
msg = "Pipeline expected at least 1 input"
assert batch_size >= 1, msg
Expand All @@ -142,7 +228,12 @@ def __call__(

# Extract segmentation and embeddings
segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
for hook in self.hooks:
hook.on_local_segmentation_batch(self, batch, segmentations)

embeddings = self.embedding(batch, segmentations) # shape (batch, speakers, emb_dim)
for hook in self.hooks:
hook.on_embedding_batch(self, batch, embeddings)

seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]

Expand All @@ -156,8 +247,17 @@ def __call__(
)
seg = SlidingWindowFeature(seg.cpu().numpy(), sw)

for hook in self.hooks:
hook.on_local_segmentation(self, wav, seg)
for hook in self.hooks:
hook.on_embeddings(self, wav, emb)
for hook in self.hooks:
hook.on_before_clustering(self, wav)

# Update clustering state and permute segmentation
permuted_seg = self.clustering(seg, emb)
for hook in self.hooks:
hook.on_after_clustering(self, wav, self.clustering, permuted_seg)

# Update sliding buffer
self.chunk_buffer.append(wav)
Expand All @@ -166,7 +266,13 @@ def __call__(
# Aggregate buffer outputs for this time step
agg_waveform = self.audio_aggregation(self.chunk_buffer)
agg_prediction = self.pred_aggregation(self.pred_buffer)
outputs.append((self.binarize(agg_prediction), agg_waveform))
for hook in self.hooks:
hook.on_soft_prediction(self, agg_waveform, agg_prediction)

bin_prediction = self.binarize(agg_prediction)
outputs.append((bin_prediction, agg_waveform))
for hook in self.hooks:
hook.on_binary_prediction(self, agg_waveform, bin_prediction)

# Make place for new chunks in buffer if required
if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
Expand Down
Loading