// Copyright (c) Microsoft. All rights reserved.

namespace Microsoft.TypeChat;

/// <summary>
/// A lightweight ILanguageModel implementation over the OpenAI or Azure OpenAI Chat Completion REST API endpoint
/// </summary>
public class LanguageModel : ILanguageModel, IDisposable
{
    static TranslationSettings s_defaultSettings = new TranslationSettings();

    OpenAIConfig _config;
    ModelInfo _model;
    HttpClient _client;
    string _endPoint;

    /// <summary>
    /// Create a LanguageModel object using the given OpenAIConfig.
    /// config.Endpoint must support the Chat Completion API
    /// </summary>
    /// <param name="config">configuration to use</param>
    /// <param name="model">information about the target model</param>
    /// <param name="client">http client to use</param>
    public LanguageModel(OpenAIConfig config, ModelInfo? model = null, HttpClient? client = null)
    {
        ArgumentVerify.ThrowIfNull(config, nameof(config));
        config.Validate();

        _config = config;
        _model = model ?? _config.Model;
        _client = client ?? new HttpClient();
        ConfigureClient();
    }
    /// <summary>
    /// Information about the language model
    /// </summary>
    public ModelInfo ModelInfo => _model;

    /// <summary>
    /// Get a completion for the given prompt
    /// </summary>
    /// <param name="prompt">prompt</param>
    /// <param name="settings">translation settings such as temperature</param>
    /// <param name="cancelToken">cancellation token</param>
    /// <returns>the completion text returned by the model</returns>
    public async Task<string> CompleteAsync(Prompt prompt, TranslationSettings? settings = null, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNullOrEmpty<IPromptSection>(prompt, nameof(prompt));

        var request = CreateRequest(prompt, settings);
        // If a token provider is configured, fetch an access token for this request;
        // otherwise the static API key set up in ConfigureClient is used
        string apiToken = _config.HasTokenProvider ? await _config.ApiTokenProvider.GetAccessTokenAsync(cancelToken) : null;
        // Post the request with the configured retry policy and return the text of the first choice
        var response = await _client.GetJsonResponseAsync<Request, Response>(_endPoint, request, _config.MaxRetries, _config.MaxPauseMs, apiToken).ConfigureAwait(false);
        return response.GetText();
    }

    Request CreateRequest(Prompt prompt, TranslationSettings? settings = null)
    {
        var request = Request.Create(prompt, settings ?? s_defaultSettings);
        if (!_config.Azure)
        {
            // The OpenAI service requires the model name in the request body;
            // Azure OpenAI identifies the model by the deployment name in the endpoint URL
            request.model = _model.Name;
        }
        return request;
    }
    void ConfigureClient()
    {
        if (_config.Azure)
        {
            if (_config.Endpoint.IndexOf(@"chat/completions", StringComparison.OrdinalIgnoreCase) >= 0)
            {
                // The configured endpoint is already a full chat/completions URL; use it as-is
                _endPoint = _config.Endpoint;
            }
            else
            {
                // Build the chat completions URL from the Azure OpenAI resource endpoint
                // (typically https://{your-resource}.openai.azure.com/), the deployment
                // (model) name and the configured API version
                string path = $"openai/deployments/{_model.Name}/chat/completions?api-version={_config.ApiVersion}";
                _endPoint = new Uri(new Uri(_config.Endpoint), path).AbsoluteUri;
            }
            if (!_config.HasTokenProvider)
            {
                _client.DefaultRequestHeaders.Add("api-key", _config.ApiKey);
            }
        }
        else
        {
            // OpenAI: use the configured endpoint directly and authenticate with a Bearer token
            _endPoint = _config.Endpoint;
            _client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", _config.ApiKey);
            if (!string.IsNullOrEmpty(_config.Organization))
            {
                _client.DefaultRequestHeaders.Add("OpenAI-Organization", _config.Organization);
            }
        }
        if (_config.TimeoutMs > 0)
        {
            _client.Timeout = TimeSpan.FromMilliseconds(_config.TimeoutMs);
        }
    }
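
    // For reference, a Request built by the struct below serializes to a chat completion body of
    // roughly this shape (a sketch; the exact JSON depends on how GetJsonResponseAsync serializes
    // it, which is not shown in this file):
    //
    //   { "model": "<model name>", "messages": [ { "role": "user", "content": "..." } ],
    //     "temperature": 0, "max_tokens": 100 }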
    struct Request
    {
        public string? model { get; set; }
        public Message[] messages { get; set; }
        public double? temperature { get; set; }
        public int? max_tokens { get; set; }

        public static Request Create(Prompt prompt, TranslationSettings settings)
        {
            return new Request
            {
                messages = Message.Create(prompt),
                temperature = (settings.Temperature > 0) ? settings.Temperature : 0,
                max_tokens = (settings.MaxTokens > 0) ? settings.MaxTokens : null
            };
        }
    }
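
    // The Response, Message and Choice structs below model only the fields this class reads from
    // a chat completion reply; a sketch of the expected JSON (any other fields are ignored):
    //
    //   { "choices": [ { "message": { "role": "assistant", "content": "..." } } ] }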
    struct Response
    {
        public Choice[] choices { get; set; }

        public string GetText()
        {
            string response = null;
            if (choices is not null && choices.Length > 0)
            {
                response = choices[0].message.content;
            }
            return response ?? string.Empty;
        }
    }
    struct Message
    {
        public string role { get; set; }
        public string content { get; set; }

        public static Message[] Create(Prompt prompt)
        {
            Message[] messages = new Message[prompt.Count];
            for (int i = 0; i < prompt.Count; ++i)
            {
                messages[i] = new Message { role = prompt[i].Source, content = prompt[i].GetText() };
            }
            return messages;
        }
    }

    struct Choice
    {
        public Message message { get; set; }
    }
    protected virtual void Dispose(bool disposing)
    {
        if (disposing)
        {
            _client?.Dispose();
            _client = null;
        }
    }

    public void Dispose()
    {
        // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
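
// Example usage (a minimal sketch, not part of this file): assumes an OpenAIConfig has already
// been created and validated elsewhere, and that a Prompt can be built from a plain string
// (adjust to however prompts are constructed in your code):
//
//   using var languageModel = new LanguageModel(config);
//   Prompt prompt = "Translate 'hello world' into French.";
//   string completion = await languageModel.CompleteAsync(prompt);
//   Console.WriteLine(completion);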