ML.NET NER - Mismatched state_dict sizes: expected 60, but found 126 entries. #7350

Open
piercarlo62 opened this issue Dec 22, 2024 · 0 comments
Labels: untriaged (New issue has not been triaged)

Comments

@piercarlo62

Hello,
I'm testing the NER capabilities of ML.NET, and during training I get the following error:
Error: Mismatched state_dict sizes: expected 60, but found 126 entries.


System Information:

  • OS & Version: Windows 10
  • ML.NET Version: ML.NET v4.0.0
  • .NET Version: .NET 8.0

Description of the bug
Calling var transformer = estimator.Fit(dataView); throws the exception below.

Mismatched state_dict sizes: expected 60, but found 126 entries.
in TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
   in TorchSharp.torch.nn.Module.load(String location, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
   in Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase.CreateModule(IChannel ch, IDataView input)
   in Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.TrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
   in Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
   in Microsoft.ML.TorchSharp.NasBert.NerTrainer.Trainer..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
   in Microsoft.ML.TorchSharp.NasBert.NerTrainer.CreateTrainer(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
   in Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.Fit(IDataView input)
   in Microsoft.ML.Data.EstimatorChain`1.Fit(IDataView input)
   in Program.Main(String[] args) in C:\Users\pierc\source\repos\ML_NER_TEST\Program.cs: line 64

Sample Projects

using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp;

namespace ML_NER_TEST
{
    public class Program
    {
        public static void Main(string[] args)
        {
            try
            {
                var context = new MLContext()
                {
                    FallbackToCpu = true,
                    GpuDeviceId = 0
                };

                var labels = context.Data.LoadFromEnumerable(
                    [
                            new Label { Key = "PERSON" },       // People, including fictional.
                            new Label { Key = "NORP" },         // Nationalities or religious or political groups.
                            new Label { Key = "FAC" },          // Buildings, airports, highways, bridges, etc.
                            new Label { Key = "ORG" },          // Companies, agencies, institutions, etc.
                            new Label { Key = "GPE" },          // Countries, cities, states.
                            new Label { Key = "LOC" },          // Non-GPE locations, mountain ranges, bodies of water.
                            new Label { Key = "PRODUCT" },      // Objects, vehicles, foods, etc. (Not services.)
                            new Label { Key = "EVENT" },        // Named hurricanes, battles, wars, sports events, etc.
                            new Label { Key = "WORK_OF_ART" },  // Titles of books, songs, etc.
                            new Label { Key = "LAW" },          // Named documents made into laws.
                            new Label { Key = "LANGUAGE" },     // Any named language.
                            new Label { Key = "DATE" },         // Absolute or relative dates or periods.
                            new Label { Key = "TIME" },         // Times smaller than a day.
                            new Label { Key = "PERCENT" },      // Percentage, including "%".
                            new Label { Key = "MONEY" },        // Monetary values, including unit.
                            new Label { Key = "QUANTITY" },     // Measurements, as of weight or distance.
                            new Label { Key = "ORDINAL" },      // "first", "second", etc.
                            new Label { Key = "CARDINAL" },     // Numerals that do not fall under another type.
                            new Label { Key = "OBJECT" },       // An Object, Entity might be a Spoon, or a Soccer Ball. Needs Sub Categories.
                ]);

                var dataView = context.Data.LoadFromEnumerable(
                    new List<InputTrainingData>([
                        new InputTrainingData()
                        {
                            // One label per word. "0" and "COUNTRY" are not in the key list above,
                            // so MapValueToKey maps them to the missing key.
                            Sentence = "Alice and Bob live in the USA",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY"]
                        },
                        new InputTrainingData()
                        {
                            Sentence = "Frank and Alice traveled along the California coast.",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "COUNTRY", "0"]
                        },
                    ]));

                var chain = new EstimatorChain<ITransformer>();

                var estimator = chain.Append(context.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
                   .Append(context.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "Predictions"))
                   .Append(context.Transforms.Conversion.MapKeyToValue("Predictions"));

                Console.WriteLine("Training the model...");

                var transformer = estimator.Fit(dataView);

                Console.WriteLine("Model trained!");

                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);

                string sentence = "Alice and Bob live in the USA";
                var engine = context.Model.CreatePredictionEngine<Input, Output>(transformer);

                Console.WriteLine("Predicting...");

                Output predictions = engine.Predict(new Input { Sentence = sentence });

                Console.WriteLine($"Predictions: {sentence} - {string.Join(", ", predictions.Predictions)}");

                transformer.Dispose();
                Console.WriteLine("Success!");
                Console.ReadLine();
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Error: {ex.Message}");
                Console.ReadLine();
            }
        }
        private class Input
        {
            public string Sentence;
            public string[] Label;
        }
        private class Output
        {
            public string[] Predictions;
        }
        public class Label
        {
            public string Key { get; set; }
        }
        private class InputTrainingData
        {
            public string Sentence;
            public string[] Label;
        }
    }
}

Additional context

<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>disable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="libtorch-cpu-win-x64" Version="2.5.1" />
    <PackageReference Include="Microsoft.ML" Version="4.0.0" />
    <PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.22.0" />
    <PackageReference Include="TorchSharp" Version="0.105.0" />
  </ItemGroup>

</Project>
dotnet-policy-service (bot) added the untriaged label on Dec 22, 2024