src: refactor + clean code using better types

This commit is contained in:
Glenn Y. Rolland 2023-07-07 22:09:26 +02:00
parent bb26268e3b
commit 4e3e6aa8f8
19 changed files with 251 additions and 84 deletions

View file

@ -1,6 +0,0 @@
module Builder
abstract class PromptGeneric
abstract def build(prompt : Prompt)
end
end

14
src/config.cr Normal file
View file

@ -0,0 +1,14 @@
require "./openai_config"
class AppConfig
property input_file = STDIN
property input_file_path = ""
property output_file = STDOUT
property output_file_path = ""
property past_characters_limit = 1000
property use_color = true
property make_request = true
property verbose = false
property gpt_config = AIUtils::GptConfig.new
end

View file

@ -0,0 +1,6 @@
module ContextBuilder
abstract class GenericContextBuilder
abstract def build(prompt : Prompt)
end
end

View file

@ -14,10 +14,12 @@ module Builder
# skip prelude_zone
# skip future_zone
def build(prompt : Prompt) : Chat
def build(prompt : Prompt, context_token_limit : UInt32) : Chat
chat = [] of Message
token_limit = 2_700
# token_limit = 10_000
token_limit = (0.65 * context_token_limit).to_i
STDERR.puts "messages_token_limit = #{token_limit}"
mandatory_token_count = (
prompt.system_zone.token_count +
prompt.present_zone.token_count

View file

@ -0,0 +1,45 @@
require "colorize"
require "../config"
require "./generic_builder"
module Builder
class PromptDir < PromptGeneric
SEPARATOR = "\n\n"
getter use_color : Bool
def initialize(@use_color)
Colorize.enabled = @use_color
end
def build(prompt : Prompt)
str = ""
str += prompt.prelude_zone.content.join(SEPARATOR)
if ! prompt.system_zone.content.empty?
str += "@@system".colorize(:yellow).to_s
str += prompt.system_zone.content.join(SEPARATOR)
end
if ! prompt.past_zone.content.empty?
str += "@@before".colorize(:yellow).to_s
str += prompt.past_zone.content.join(SEPARATOR)
end
if ! prompt.present_zone.content.empty?
# str += "@@current".colorize(:yellow).to_s
str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s
end
if ! prompt.future_zone.content.empty?
str += "@@after".colorize(:yellow).to_s
str += prompt.future_zone.content.join("\n")
end
str
end
end
end

View file

@ -1,10 +1,10 @@
require "colorize"
require "./generic"
require "./generic_builder"
module Builder
class PromptString < PromptGeneric
module ContextBuilder
class PromptString < GenericContextBuilder
SEPARATOR = "\n\n"

View file

@ -3,15 +3,15 @@ require "pretty_print"
require "openai"
require "spinner"
require "./zone"
require "./prompts/zone"
require "./parsers/prompt_string"
require "./builders/prompt_string"
require "./builders/openai_chat"
require "./builders/openai_insert"
require "./context_builders/prompt_string"
# require "./context_builders/prompt_dir"
require "./context_builders/openai_chat"
require "./context_builders/openai_insert"
require "./storyteller"
storyteller = Storyteller.new()
storyteller.prepare()
storyteller.run()

11
src/modifiers/readme.md Normal file
View file

@ -0,0 +1,11 @@
Des classes qui modifient un prompt d'apres des contraintes
truncate(max_tokens) # le prompt aux X derniers tokens
summarize() # réduis la longueur du prompt
worldinfo() # introduis les éléments du monde

64
src/openai_config.cr Normal file
View file

@ -0,0 +1,64 @@
module AIUtils
enum OpenAIModel
Gpt_4
Gpt_3_5_Turbo
Gpt_3_5_Turbo_16k
def to_s
vals = [
"gpt-4",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k"
]
vals[self.value]
end
def self.from_s?(str_model : String) : OpenAIModel?
return OpenAIModel.values.find do |openai_model|
openai_model.to_s.downcase == str_model.downcase
end
end
end
enum OpenAIMode
Chat
Complete
Insert
def self.from_s?(str_mode) : OpenAIMode?
return OpenAIMode.parse?(str_mode)
# return OpenAIMode.values.find do |openai_mode|
# openai_mode.to_s.downcase == str_mode.downcase
# end
end
end
def self.context_token_limit(model : OpenAIModel) : UInt32
hash = {
OpenAIModel::Gpt_3_5_Turbo => 1024_u32 * 4_u32,
OpenAIModel::Gpt_3_5_Turbo_16k => 1024_u32 * 16_u32,
OpenAIModel::Gpt_4 => 1024_u32 * 8_u32,
}
return 0_u32 unless hash[model]?
return hash[model]
end
class GptConfig
# defaults
# property temperature = 0.82
property model : OpenAIModel = OpenAIModel::Gpt_3_5_Turbo
# property model = "gpt-3.5-turbo-16k"
# property prompt_size = 10_000
property prompt_size = 2_700
property temperature = 0.8
property presence_penalty = 1.0
property frequency_penalty = 1.0
property max_tokens = 384
property mode = OpenAIMode::Chat
property openai_key : String = ENV.fetch("OPENAI_API_KEY")
end
end

View file

@ -1,5 +1,5 @@
require "../prompt"
require "../prompts/generic_prompt"
module Parser
class PromptString

7
src/requests/chat.cr Normal file
View file

@ -0,0 +1,7 @@
require "./generic_request"
module Requests
class ChatRequest < GenericRequest
end
end

8
src/requests/complete.cr Normal file
View file

@ -0,0 +1,8 @@
require "./generic_request"
module Requests
class CompleteRequest < GenericRequest
end
end

View file

@ -0,0 +1,23 @@
module Requests
class GenericRequest
getter context_size : Int32 = 0
def initialize(@config : AIUtils::GptConfig)
end
def perform(messages)
openai = OpenAI::Client.new(access_token: @config.openai_key)
result = openai.chat(
@config.model.to_s,
messages,
{
"temperature" => @config.temperature,
"presence_penalty" => @config.presence_penalty,
"frequency_penalty" => @config.frequency_penalty,
"max_tokens" => @config.max_tokens,
}
)
end
end
end

8
src/requests/insert.cr Normal file
View file

@ -0,0 +1,8 @@
require "./generic_request"
module Requests
class InsertRequest < GenericRequest
end
end

View file

@ -1,35 +1,6 @@
enum OpenAIMode
Chat
Complete
Insert
def self.from_s(str_mode)
return OpenAIMode.values.find do |openai_mode|
openai_mode.to_s.downcase == str_mode.downcase
end
end
end
class GptConfig
# defaults
property gpt_temperature = 0.82
property gpt_presence_penalty = 1
property gpt_frequency_penalty = 1
property gpt_max_tokens = 256
property gpt_mode = OpenAIMode::Chat
end
class AppConfig
property input_file = STDIN
property input_file_path = ""
property output_file = STDOUT
property output_file_path = ""
property past_characters_limit = 1000
property use_color = true
property make_request = true
property verbose = false
property gpt_config = GptConfig.new
end
require "./config"
require "./requests/chat"
class Storyteller
property config = AppConfig.new
@ -79,26 +50,35 @@ class Storyteller
parser.separator("GPT options")
parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
if result_mode.nil?
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
parser.on("--gpt-model MODEL", "GPT model (...) (default: gpt-3.5-turbo)") do |chosen_model|
result_model = AIUtils::OpenAIModel.from_s?(chosen_model.downcase)
if result_model.nil?
STDERR.puts "ERROR: unknown model #{chosen_model}"
exit 1
end
gpt_mode = result_mode unless result_mode.nil?
@config.gpt_config.model = result_model unless result_model.nil?
end
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature|
gpt_temperature = temperature
# parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
# result_mode = AIUtils::OpenAIMode.from_s(chosen_mode.downcase)
# if result_mode.nil?
# STDERR.puts "ERROR: unknown mode #{chosen_mode}"
# exit 1
# end
# @config.gpt_config.mode = result_mode unless result_mode.nil?
# end
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature (default #{@config.gpt_config.temperature})") do |temperature|
@config.gpt_config.temperature = temperature.to_f
end
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty|
gpt_presence_penalty = presence_penalty
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty (default #{@config.gpt_config.presence_penalty})") do |presence_penalty|
@config.gpt_config.presence_penalty = presence_penalty.to_f
end
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty|
gpt_frequency_penalty = frequency_penalty
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty (default #{@config.gpt_config.frequency_penalty})") do |frequency_penalty|
@config.gpt_config.frequency_penalty = frequency_penalty.to_f
end
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens|
gpt_max_tokens = max_tokens
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens (default #{@config.gpt_config.max_tokens})") do |max_tokens|
@config.gpt_config.max_tokens = max_tokens.to_i
end
end
@ -130,21 +110,15 @@ class Storyteller
end
self.write_file(@config.output_file, prompt, @config.use_color)
@config.output_file.close
# rescue ex : OptionParser::InvalidOption
# STDERR.puts @parser
# STDERR.puts "\nERROR: #{ex.message}"
# exit 1
# rescue ex: OpenAI::Client::ClientError
# STDERR.puts "ERROR: #{ex.message}"
# exit 1
end
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
builder = Builder::OpenAIChat.new(verbose: verbose)
messages = builder.build(prompt)
context_token_limit = AIUtils.context_token_limit(@config.gpt_config.model)
messages = builder.build(prompt, context_token_limit)
STDERR.puts "model = #{@config.gpt_config.model}"
STDERR.puts "context_token_limit = #{context_token_limit}"
STDERR.puts messages if verbose
return prompt if !make_request
@ -158,7 +132,7 @@ class Storyteller
loop do
tick += 1
print "."
if tick > (3 * 60)
if tick > (2 * 60)
print "(timeout)"
exit 1
end
@ -167,20 +141,31 @@ class Storyteller
end
end
# TODO:
# request = Request::Chat.new(config: @config.gpt_config)
# result = request.perform(message)
spawn do
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
result = openai.chat(
"gpt-3.5-turbo",
messages,
{
"temperature" => 0.82,
"presence_penalty" => 1,
"frequency_penalty" => 1,
"max_tokens" => 256
}
)
request = Requests::ChatRequest.new(config: @config.gpt_config)
result = request.perform(messages)
pp result.choices if @config.verbose
prompt.present_zone.content << result.choices.first["message"]["content"]
channel_ready.send(true)
rescue ex: KeyError
puts "(openai error)"
STDERR.puts "ERROR: #{ex.message}"
exit 1
rescue ex: IO::Error
puts "(network error)"
STDERR.puts "ERROR: #{ex.message}"
exit 1
rescue ex: Socket::Addrinfo::Error
puts "(network error)"
STDERR.puts "ERROR: #{ex.message}"
exit 1
end
# sp.start
@ -203,7 +188,7 @@ class Storyteller
def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
# STDERR.puts "d: building builder"
builder = Builder::PromptString.new(use_color)
builder = ContextBuilder::PromptString.new(use_color)
# STDERR.puts "d: building"
text = builder.build(prompt)
output_file.write_string(text.to_slice)