Skip to content

Commit

Permalink
Make Microsoft.ML.OnnxRuntimeGenAI.Tokenizer a Microsoft.ML.Tokenizer…
Browse files Browse the repository at this point in the history
…s.Tokenizer

This enables an ONNX Runtime GenAI tokenizer instance to be used anywhere a Microsoft.ML.Tokenizers tokenizer is accepted. If we'd prefer, rather than having Tokenizer be a base class for the ONNX Runtime one, we could instead expose some sort of `public Microsoft.ML.Tokenizer.Tokenizer AsTokenizer()` conversion method that returns a wrapper object (though that's a bit confusing given the names of the type are the same, just different namespaces).
  • Loading branch information
stephentoub committed Oct 16, 2024
1 parent 7998f13 commit 0b759a7
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 69 deletions.
1 change: 0 additions & 1 deletion src/csharp/Exceptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
9 changes: 3 additions & 6 deletions src/csharp/GeneratorParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand All @@ -21,12 +18,12 @@ public GeneratorParams(Model model)

public void SetSearchOption(string searchOption, double value)
{
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value));
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchNumber(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(searchOption), value));
}

public void SetSearchOption(string searchOption, bool value)
{
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value));
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(searchOption), value));
}

public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize)
Expand All @@ -52,7 +49,7 @@ public void SetInputSequences(Sequences sequences)

public void SetModelInput(string name, Tensor value)
{
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetModelInput(_generatorParamsHandle, StringUtils.ToUtf8(name), value.Handle));
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetModelInput(_generatorParamsHandle, StringUtils.ToNullTerminatedUtf8(name), value.Handle));
}

public void SetInputs(NamedTensors namedTensors)
Expand Down
3 changes: 1 addition & 2 deletions src/csharp/Images.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand All @@ -23,7 +22,7 @@ public static Images Load(string[] imagePaths)
Result.VerifySuccess(NativeMethods.OgaCreateStringArray(out IntPtr stringArray));
foreach (string imagePath in imagePaths)
{
Result.VerifySuccess(NativeMethods.OgaStringArrayAddString(stringArray, StringUtils.ToUtf8(imagePath)));
Result.VerifySuccess(NativeMethods.OgaStringArrayAddString(stringArray, StringUtils.ToNullTerminatedUtf8(imagePath)));
}
Result.VerifySuccess(NativeMethods.OgaLoadImages(stringArray, out IntPtr imagesHandle));
NativeMethods.OgaDestroyStringArray(stringArray);
Expand Down
4 changes: 4 additions & 0 deletions src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,8 @@
<PackageReference Include="System.Memory" Version="4.5.5" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24378.1" />
</ItemGroup>

</Project>
3 changes: 1 addition & 2 deletions src/csharp/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand All @@ -13,7 +12,7 @@ public class Model : IDisposable

public Model(string modelPath)
{
Result.VerifySuccess(NativeMethods.OgaCreateModel(StringUtils.ToUtf8(modelPath), out _modelHandle));
Result.VerifySuccess(NativeMethods.OgaCreateModel(StringUtils.ToNullTerminatedUtf8(modelPath), out _modelHandle));
}

internal IntPtr Handle { get { return _modelHandle; } }
Expand Down
5 changes: 2 additions & 3 deletions src/csharp/MultiModalProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand All @@ -21,7 +20,7 @@ public MultiModalProcessor(Model model)
public NamedTensors ProcessImages(string prompt, Images images)
{
IntPtr imagesHandle = images == null ? IntPtr.Zero : images.Handle;
Result.VerifySuccess(NativeMethods.OgaProcessorProcessImages(_processorHandle, StringUtils.ToUtf8(prompt),
Result.VerifySuccess(NativeMethods.OgaProcessorProcessImages(_processorHandle, StringUtils.ToNullTerminatedUtf8(prompt),
imagesHandle, out IntPtr namedTensorsHandle));
return new NamedTensors(namedTensorsHandle);
}
Expand All @@ -38,7 +37,7 @@ public string Decode(ReadOnlySpan<int> sequence)
}
try
{
return StringUtils.FromUtf8(outStr);
return StringUtils.FromNullTerminatedUtf8(outStr);
}
finally
{
Expand Down
13 changes: 8 additions & 5 deletions src/csharp/Result.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
class Result
internal static class Result
{
private static string GetErrorMessage(IntPtr nativeResult)
internal static string GetErrorMessage(IntPtr nativeResult)
{

return StringUtils.FromUtf8(NativeMethods.OgaResultGetError(nativeResult));
return StringUtils.FromNullTerminatedUtf8(NativeMethods.OgaResultGetError(nativeResult));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void VerifySuccess(IntPtr nativeResult)
{
if (nativeResult != IntPtr.Zero)
{
Throw(nativeResult);
}

static void Throw(IntPtr nativeResult)
{
try
{
Expand Down
1 change: 0 additions & 1 deletion src/csharp/Sequences.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
1 change: 0 additions & 1 deletion src/csharp/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
Loading

0 comments on commit 0b759a7

Please sign in to comment.