src: refactor + clean code using better types
This commit is contained in:
parent
bb26268e3b
commit
4e3e6aa8f8
19 changed files with 251 additions and 84 deletions
|
@ -1,6 +0,0 @@
|
||||||
|
|
||||||
module Builder
|
|
||||||
abstract class PromptGeneric
|
|
||||||
abstract def build(prompt : Prompt)
|
|
||||||
end
|
|
||||||
end
|
|
14
src/config.cr
Normal file
14
src/config.cr
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
|
||||||
|
require "./openai_config"
|
||||||
|
|
||||||
|
class AppConfig
|
||||||
|
property input_file = STDIN
|
||||||
|
property input_file_path = ""
|
||||||
|
property output_file = STDOUT
|
||||||
|
property output_file_path = ""
|
||||||
|
property past_characters_limit = 1000
|
||||||
|
property use_color = true
|
||||||
|
property make_request = true
|
||||||
|
property verbose = false
|
||||||
|
property gpt_config = AIUtils::GptConfig.new
|
||||||
|
end
|
6
src/context_builders/generic_builder.cr
Normal file
6
src/context_builders/generic_builder.cr
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
|
||||||
|
module ContextBuilder
|
||||||
|
abstract class GenericContextBuilder
|
||||||
|
abstract def build(prompt : Prompt)
|
||||||
|
end
|
||||||
|
end
|
|
@ -14,10 +14,12 @@ module Builder
|
||||||
|
|
||||||
# skip prelude_zone
|
# skip prelude_zone
|
||||||
# skip future_zone
|
# skip future_zone
|
||||||
def build(prompt : Prompt) : Chat
|
def build(prompt : Prompt, context_token_limit : UInt32) : Chat
|
||||||
chat = [] of Message
|
chat = [] of Message
|
||||||
|
|
||||||
token_limit = 2_700
|
# token_limit = 10_000
|
||||||
|
token_limit = (0.65 * context_token_limit).to_i
|
||||||
|
STDERR.puts "messages_token_limit = #{token_limit}"
|
||||||
mandatory_token_count = (
|
mandatory_token_count = (
|
||||||
prompt.system_zone.token_count +
|
prompt.system_zone.token_count +
|
||||||
prompt.present_zone.token_count
|
prompt.present_zone.token_count
|
45
src/context_builders/prompt_dir.cr
Normal file
45
src/context_builders/prompt_dir.cr
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
|
||||||
|
require "colorize"
|
||||||
|
|
||||||
|
require "../config"
|
||||||
|
require "./generic_builder"
|
||||||
|
|
||||||
|
module Builder
|
||||||
|
class PromptDir < PromptGeneric
|
||||||
|
|
||||||
|
SEPARATOR = "\n\n"
|
||||||
|
|
||||||
|
getter use_color : Bool
|
||||||
|
def initialize(@use_color)
|
||||||
|
Colorize.enabled = @use_color
|
||||||
|
end
|
||||||
|
|
||||||
|
def build(prompt : Prompt)
|
||||||
|
str = ""
|
||||||
|
|
||||||
|
str += prompt.prelude_zone.content.join(SEPARATOR)
|
||||||
|
|
||||||
|
if ! prompt.system_zone.content.empty?
|
||||||
|
str += "@@system".colorize(:yellow).to_s
|
||||||
|
str += prompt.system_zone.content.join(SEPARATOR)
|
||||||
|
end
|
||||||
|
|
||||||
|
if ! prompt.past_zone.content.empty?
|
||||||
|
str += "@@before".colorize(:yellow).to_s
|
||||||
|
str += prompt.past_zone.content.join(SEPARATOR)
|
||||||
|
end
|
||||||
|
|
||||||
|
if ! prompt.present_zone.content.empty?
|
||||||
|
# str += "@@current".colorize(:yellow).to_s
|
||||||
|
str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s
|
||||||
|
end
|
||||||
|
|
||||||
|
if ! prompt.future_zone.content.empty?
|
||||||
|
str += "@@after".colorize(:yellow).to_s
|
||||||
|
str += prompt.future_zone.content.join("\n")
|
||||||
|
end
|
||||||
|
|
||||||
|
str
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,10 +1,10 @@
|
||||||
|
|
||||||
require "colorize"
|
require "colorize"
|
||||||
|
|
||||||
require "./generic"
|
require "./generic_builder"
|
||||||
|
|
||||||
module Builder
|
module ContextBuilder
|
||||||
class PromptString < PromptGeneric
|
class PromptString < GenericContextBuilder
|
||||||
|
|
||||||
SEPARATOR = "\n\n"
|
SEPARATOR = "\n\n"
|
||||||
|
|
10
src/main.cr
10
src/main.cr
|
@ -3,15 +3,15 @@ require "pretty_print"
|
||||||
require "openai"
|
require "openai"
|
||||||
require "spinner"
|
require "spinner"
|
||||||
|
|
||||||
require "./zone"
|
require "./prompts/zone"
|
||||||
require "./parsers/prompt_string"
|
require "./parsers/prompt_string"
|
||||||
require "./builders/prompt_string"
|
require "./context_builders/prompt_string"
|
||||||
require "./builders/openai_chat"
|
# require "./context_builders/prompt_dir"
|
||||||
require "./builders/openai_insert"
|
require "./context_builders/openai_chat"
|
||||||
|
require "./context_builders/openai_insert"
|
||||||
require "./storyteller"
|
require "./storyteller"
|
||||||
|
|
||||||
storyteller = Storyteller.new()
|
storyteller = Storyteller.new()
|
||||||
storyteller.prepare()
|
storyteller.prepare()
|
||||||
storyteller.run()
|
storyteller.run()
|
||||||
|
|
||||||
|
|
||||||
|
|
11
src/modifiers/readme.md
Normal file
11
src/modifiers/readme.md
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
|
||||||
|
Des classes qui modifient un prompt d'apres des contraintes
|
||||||
|
|
||||||
|
truncate(max_tokens) # le prompt aux X derniers tokens
|
||||||
|
|
||||||
|
summarize() # réduis la longueur du prompt
|
||||||
|
|
||||||
|
worldinfo() # introduis les éléments du monde
|
||||||
|
|
||||||
|
|
||||||
|
|
64
src/openai_config.cr
Normal file
64
src/openai_config.cr
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
|
||||||
|
|
||||||
|
module AIUtils
|
||||||
|
enum OpenAIModel
|
||||||
|
Gpt_4
|
||||||
|
Gpt_3_5_Turbo
|
||||||
|
Gpt_3_5_Turbo_16k
|
||||||
|
|
||||||
|
def to_s
|
||||||
|
vals = [
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-16k"
|
||||||
|
]
|
||||||
|
vals[self.value]
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.from_s?(str_model : String) : OpenAIModel?
|
||||||
|
return OpenAIModel.values.find do |openai_model|
|
||||||
|
openai_model.to_s.downcase == str_model.downcase
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
enum OpenAIMode
|
||||||
|
Chat
|
||||||
|
Complete
|
||||||
|
Insert
|
||||||
|
|
||||||
|
def self.from_s?(str_mode) : OpenAIMode?
|
||||||
|
return OpenAIMode.parse?(str_mode)
|
||||||
|
# return OpenAIMode.values.find do |openai_mode|
|
||||||
|
# openai_mode.to_s.downcase == str_mode.downcase
|
||||||
|
# end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.context_token_limit(model : OpenAIModel) : UInt32
|
||||||
|
hash = {
|
||||||
|
OpenAIModel::Gpt_3_5_Turbo => 1024_u32 * 4_u32,
|
||||||
|
OpenAIModel::Gpt_3_5_Turbo_16k => 1024_u32 * 16_u32,
|
||||||
|
OpenAIModel::Gpt_4 => 1024_u32 * 8_u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0_u32 unless hash[model]?
|
||||||
|
return hash[model]
|
||||||
|
end
|
||||||
|
|
||||||
|
class GptConfig
|
||||||
|
# defaults
|
||||||
|
# property temperature = 0.82
|
||||||
|
property model : OpenAIModel = OpenAIModel::Gpt_3_5_Turbo
|
||||||
|
# property model = "gpt-3.5-turbo-16k"
|
||||||
|
# property prompt_size = 10_000
|
||||||
|
property prompt_size = 2_700
|
||||||
|
property temperature = 0.8
|
||||||
|
property presence_penalty = 1.0
|
||||||
|
property frequency_penalty = 1.0
|
||||||
|
property max_tokens = 384
|
||||||
|
property mode = OpenAIMode::Chat
|
||||||
|
property openai_key : String = ENV.fetch("OPENAI_API_KEY")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
|
||||||
require "../prompt"
|
require "../prompts/generic_prompt"
|
||||||
|
|
||||||
module Parser
|
module Parser
|
||||||
class PromptString
|
class PromptString
|
||||||
|
|
7
src/requests/chat.cr
Normal file
7
src/requests/chat.cr
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
require "./generic_request"
|
||||||
|
|
||||||
|
module Requests
|
||||||
|
class ChatRequest < GenericRequest
|
||||||
|
end
|
||||||
|
end
|
8
src/requests/complete.cr
Normal file
8
src/requests/complete.cr
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
|
require "./generic_request"
|
||||||
|
|
||||||
|
module Requests
|
||||||
|
class CompleteRequest < GenericRequest
|
||||||
|
end
|
||||||
|
end
|
23
src/requests/generic_request.cr
Normal file
23
src/requests/generic_request.cr
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
|
||||||
|
module Requests
|
||||||
|
class GenericRequest
|
||||||
|
getter context_size : Int32 = 0
|
||||||
|
|
||||||
|
def initialize(@config : AIUtils::GptConfig)
|
||||||
|
end
|
||||||
|
|
||||||
|
def perform(messages)
|
||||||
|
openai = OpenAI::Client.new(access_token: @config.openai_key)
|
||||||
|
result = openai.chat(
|
||||||
|
@config.model.to_s,
|
||||||
|
messages,
|
||||||
|
{
|
||||||
|
"temperature" => @config.temperature,
|
||||||
|
"presence_penalty" => @config.presence_penalty,
|
||||||
|
"frequency_penalty" => @config.frequency_penalty,
|
||||||
|
"max_tokens" => @config.max_tokens,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
8
src/requests/insert.cr
Normal file
8
src/requests/insert.cr
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
|
require "./generic_request"
|
||||||
|
|
||||||
|
module Requests
|
||||||
|
class InsertRequest < GenericRequest
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,35 +1,6 @@
|
||||||
enum OpenAIMode
|
|
||||||
Chat
|
|
||||||
Complete
|
|
||||||
Insert
|
|
||||||
|
|
||||||
def self.from_s(str_mode)
|
require "./config"
|
||||||
return OpenAIMode.values.find do |openai_mode|
|
require "./requests/chat"
|
||||||
openai_mode.to_s.downcase == str_mode.downcase
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class GptConfig
|
|
||||||
# defaults
|
|
||||||
property gpt_temperature = 0.82
|
|
||||||
property gpt_presence_penalty = 1
|
|
||||||
property gpt_frequency_penalty = 1
|
|
||||||
property gpt_max_tokens = 256
|
|
||||||
property gpt_mode = OpenAIMode::Chat
|
|
||||||
end
|
|
||||||
|
|
||||||
class AppConfig
|
|
||||||
property input_file = STDIN
|
|
||||||
property input_file_path = ""
|
|
||||||
property output_file = STDOUT
|
|
||||||
property output_file_path = ""
|
|
||||||
property past_characters_limit = 1000
|
|
||||||
property use_color = true
|
|
||||||
property make_request = true
|
|
||||||
property verbose = false
|
|
||||||
property gpt_config = GptConfig.new
|
|
||||||
end
|
|
||||||
|
|
||||||
class Storyteller
|
class Storyteller
|
||||||
property config = AppConfig.new
|
property config = AppConfig.new
|
||||||
|
@ -79,26 +50,35 @@ class Storyteller
|
||||||
|
|
||||||
parser.separator("GPT options")
|
parser.separator("GPT options")
|
||||||
|
|
||||||
parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
parser.on("--gpt-model MODEL", "GPT model (...) (default: gpt-3.5-turbo)") do |chosen_model|
|
||||||
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
|
result_model = AIUtils::OpenAIModel.from_s?(chosen_model.downcase)
|
||||||
if result_mode.nil?
|
if result_model.nil?
|
||||||
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
STDERR.puts "ERROR: unknown model #{chosen_model}"
|
||||||
exit 1
|
exit 1
|
||||||
end
|
end
|
||||||
gpt_mode = result_mode unless result_mode.nil?
|
@config.gpt_config.model = result_model unless result_model.nil?
|
||||||
end
|
end
|
||||||
|
|
||||||
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature|
|
# parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
||||||
gpt_temperature = temperature
|
# result_mode = AIUtils::OpenAIMode.from_s(chosen_mode.downcase)
|
||||||
|
# if result_mode.nil?
|
||||||
|
# STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
||||||
|
# exit 1
|
||||||
|
# end
|
||||||
|
# @config.gpt_config.mode = result_mode unless result_mode.nil?
|
||||||
|
# end
|
||||||
|
|
||||||
|
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature (default #{@config.gpt_config.temperature})") do |temperature|
|
||||||
|
@config.gpt_config.temperature = temperature.to_f
|
||||||
end
|
end
|
||||||
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty|
|
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty (default #{@config.gpt_config.presence_penalty})") do |presence_penalty|
|
||||||
gpt_presence_penalty = presence_penalty
|
@config.gpt_config.presence_penalty = presence_penalty.to_f
|
||||||
end
|
end
|
||||||
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty|
|
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty (default #{@config.gpt_config.frequency_penalty})") do |frequency_penalty|
|
||||||
gpt_frequency_penalty = frequency_penalty
|
@config.gpt_config.frequency_penalty = frequency_penalty.to_f
|
||||||
end
|
end
|
||||||
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens|
|
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens (default #{@config.gpt_config.max_tokens})") do |max_tokens|
|
||||||
gpt_max_tokens = max_tokens
|
@config.gpt_config.max_tokens = max_tokens.to_i
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -130,21 +110,15 @@ class Storyteller
|
||||||
end
|
end
|
||||||
self.write_file(@config.output_file, prompt, @config.use_color)
|
self.write_file(@config.output_file, prompt, @config.use_color)
|
||||||
@config.output_file.close
|
@config.output_file.close
|
||||||
|
|
||||||
# rescue ex : OptionParser::InvalidOption
|
|
||||||
# STDERR.puts @parser
|
|
||||||
# STDERR.puts "\nERROR: #{ex.message}"
|
|
||||||
# exit 1
|
|
||||||
|
|
||||||
# rescue ex: OpenAI::Client::ClientError
|
|
||||||
# STDERR.puts "ERROR: #{ex.message}"
|
|
||||||
# exit 1
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
||||||
builder = Builder::OpenAIChat.new(verbose: verbose)
|
builder = Builder::OpenAIChat.new(verbose: verbose)
|
||||||
messages = builder.build(prompt)
|
context_token_limit = AIUtils.context_token_limit(@config.gpt_config.model)
|
||||||
|
messages = builder.build(prompt, context_token_limit)
|
||||||
|
|
||||||
|
STDERR.puts "model = #{@config.gpt_config.model}"
|
||||||
|
STDERR.puts "context_token_limit = #{context_token_limit}"
|
||||||
STDERR.puts messages if verbose
|
STDERR.puts messages if verbose
|
||||||
|
|
||||||
return prompt if !make_request
|
return prompt if !make_request
|
||||||
|
@ -158,7 +132,7 @@ class Storyteller
|
||||||
loop do
|
loop do
|
||||||
tick += 1
|
tick += 1
|
||||||
print "."
|
print "."
|
||||||
if tick > (3 * 60)
|
if tick > (2 * 60)
|
||||||
print "(timeout)"
|
print "(timeout)"
|
||||||
exit 1
|
exit 1
|
||||||
end
|
end
|
||||||
|
@ -167,20 +141,31 @@ class Storyteller
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
# request = Request::Chat.new(config: @config.gpt_config)
|
||||||
|
# result = request.perform(message)
|
||||||
|
|
||||||
spawn do
|
spawn do
|
||||||
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
request = Requests::ChatRequest.new(config: @config.gpt_config)
|
||||||
result = openai.chat(
|
result = request.perform(messages)
|
||||||
"gpt-3.5-turbo",
|
|
||||||
messages,
|
pp result.choices if @config.verbose
|
||||||
{
|
|
||||||
"temperature" => 0.82,
|
|
||||||
"presence_penalty" => 1,
|
|
||||||
"frequency_penalty" => 1,
|
|
||||||
"max_tokens" => 256
|
|
||||||
}
|
|
||||||
)
|
|
||||||
prompt.present_zone.content << result.choices.first["message"]["content"]
|
prompt.present_zone.content << result.choices.first["message"]["content"]
|
||||||
channel_ready.send(true)
|
channel_ready.send(true)
|
||||||
|
rescue ex: KeyError
|
||||||
|
puts "(openai error)"
|
||||||
|
STDERR.puts "ERROR: #{ex.message}"
|
||||||
|
exit 1
|
||||||
|
|
||||||
|
rescue ex: IO::Error
|
||||||
|
puts "(network error)"
|
||||||
|
STDERR.puts "ERROR: #{ex.message}"
|
||||||
|
exit 1
|
||||||
|
|
||||||
|
rescue ex: Socket::Addrinfo::Error
|
||||||
|
puts "(network error)"
|
||||||
|
STDERR.puts "ERROR: #{ex.message}"
|
||||||
|
exit 1
|
||||||
end
|
end
|
||||||
|
|
||||||
# sp.start
|
# sp.start
|
||||||
|
@ -203,7 +188,7 @@ class Storyteller
|
||||||
|
|
||||||
def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
|
def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
|
||||||
# STDERR.puts "d: building builder"
|
# STDERR.puts "d: building builder"
|
||||||
builder = Builder::PromptString.new(use_color)
|
builder = ContextBuilder::PromptString.new(use_color)
|
||||||
# STDERR.puts "d: building"
|
# STDERR.puts "d: building"
|
||||||
text = builder.build(prompt)
|
text = builder.build(prompt)
|
||||||
output_file.write_string(text.to_slice)
|
output_file.write_string(text.to_slice)
|
||||||
|
|
Loading…
Reference in a new issue