src: refactor + clean code using better types
parent bb26268e3b
commit 4e3e6aa8f8
19 changed files with 251 additions and 84 deletions
(deleted file, −6 lines)
@@ -1,6 +0,0 @@
-module Builder
-  abstract class PromptGeneric
-    abstract def build(prompt : Prompt)
-  end
-end
-
src/config.cr  (new file, +14 lines)
@@ -0,0 +1,14 @@
+require "./openai_config"
+
+class AppConfig
+  property input_file = STDIN
+  property input_file_path = ""
+  property output_file = STDOUT
+  property output_file_path = ""
+  property past_characters_limit = 1000
+  property use_color = true
+  property make_request = true
+  property verbose = false
+  property gpt_config = AIUtils::GptConfig.new
+end
+
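A minimal usage sketch for the new AppConfig, assuming only the defaults defined above:

    require "./config"

    config = AppConfig.new
    config.verbose = true                  # override a default
    config.gpt_config.temperature = 0.7    # nested GPT settings travel with the app config
    puts config.gpt_config.model           # => gpt-3.5-turbo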
src/context_builders/generic_builder.cr  (new file, +6 lines)
@@ -0,0 +1,6 @@
+module ContextBuilder
+  abstract class GenericContextBuilder
+    abstract def build(prompt : Prompt)
+  end
+end
+
@@ -14,10 +14,12 @@ module Builder
     # skip prelude_zone
     # skip future_zone
-    def build(prompt : Prompt) : Chat
+    def build(prompt : Prompt, context_token_limit : UInt32) : Chat
       chat = [] of Message

-      token_limit = 2_700
-      # token_limit = 10_000
+      token_limit = (0.65 * context_token_limit).to_i
+      STDERR.puts "messages_token_limit = #{token_limit}"
       mandatory_token_count = (
         prompt.system_zone.token_count +
         prompt.present_zone.token_count
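The hard-coded 2_700-token budget becomes a fraction of the model's context window: 65% is reserved for chat messages, the rest is left for the mandatory zones and the completion. A quick check against the old constant, assuming the 4k window of gpt-3.5-turbo:

    context_token_limit = 1024_u32 * 4_u32           # 4096, from AIUtils.context_token_limit
    token_limit = (0.65 * context_token_limit).to_i  # => 2662, close to the old 2_700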
src/context_builders/prompt_dir.cr  (new file, +45 lines)
@@ -0,0 +1,45 @@
+require "colorize"
+
+require "../config"
+require "./generic_builder"
+
+module Builder
+  class PromptDir < PromptGeneric
+
+    SEPARATOR = "\n\n"
+
+    getter use_color : Bool
+    def initialize(@use_color)
+      Colorize.enabled = @use_color
+    end
+
+    def build(prompt : Prompt)
+      str = ""
+
+      str += prompt.prelude_zone.content.join(SEPARATOR)
+
+      if ! prompt.system_zone.content.empty?
+        str += "@@system".colorize(:yellow).to_s
+        str += prompt.system_zone.content.join(SEPARATOR)
+      end
+
+      if ! prompt.past_zone.content.empty?
+        str += "@@before".colorize(:yellow).to_s
+        str += prompt.past_zone.content.join(SEPARATOR)
+      end
+
+      if ! prompt.present_zone.content.empty?
+        # str += "@@current".colorize(:yellow).to_s
+        str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s
+      end
+
+      if ! prompt.future_zone.content.empty?
+        str += "@@after".colorize(:yellow).to_s
+        str += prompt.future_zone.content.join("\n")
+      end
+
+      str
+    end
+  end
+end
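A hedged usage sketch for the directory-style builder, assuming a Prompt whose zones are already filled (the Prompt API is inferred from the calls above):

    builder = Builder::PromptDir.new(true)   # use_color: enable ANSI output
    text = builder.build(prompt)             # zones rendered with @@system/@@before/@@after markers
    STDOUT << text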
@@ -1,10 +1,10 @@
 require "colorize"

-require "./generic"
+require "./generic_builder"

-module Builder
-  class PromptString < PromptGeneric
+module ContextBuilder
+  class PromptString < GenericContextBuilder

     SEPARATOR = "\n\n"
src/main.cr  (10 changed lines)
@@ -3,15 +3,15 @@ require "pretty_print"
 require "openai"
 require "spinner"

-require "./zone"
+require "./prompts/zone"
 require "./parsers/prompt_string"
-require "./builders/prompt_string"
-require "./builders/openai_chat"
-require "./builders/openai_insert"
+require "./context_builders/prompt_string"
+# require "./context_builders/prompt_dir"
+require "./context_builders/openai_chat"
+require "./context_builders/openai_insert"
 require "./storyteller"

 storyteller = Storyteller.new()
 storyteller.prepare()
 storyteller.run()
src/modifiers/readme.md  (new file, +11 lines)
@@ -0,0 +1,11 @@
+Classes that modify a prompt according to constraints.
+
+truncate(max_tokens) # cuts the prompt down to the last X tokens
+
+summarize() # reduces the length of the prompt
+
+worldinfo() # introduces elements of the world
+
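None of these modifier classes exist yet in this commit; a hypothetical sketch of truncate, dropping the oldest past-zone entries until the prompt fits the budget (Modifier::Truncate and apply are assumed names, the zone API comes from this diff):

    module Modifier
      class Truncate
        def initialize(@max_tokens : Int32)
        end

        # Drop oldest entries from the past zone until the budget fits.
        def apply(prompt : Prompt) : Prompt
          while prompt.past_zone.token_count > @max_tokens && !prompt.past_zone.content.empty?
            prompt.past_zone.content.shift
          end
          prompt
        end
      end
    end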
src/openai_config.cr  (new file, +64 lines)
@@ -0,0 +1,64 @@
+module AIUtils
+  enum OpenAIModel
+    Gpt_4
+    Gpt_3_5_Turbo
+    Gpt_3_5_Turbo_16k
+
+    def to_s
+      vals = [
+        "gpt-4",
+        "gpt-3.5-turbo",
+        "gpt-3.5-turbo-16k"
+      ]
+      vals[self.value]
+    end
+
+    def self.from_s?(str_model : String) : OpenAIModel?
+      return OpenAIModel.values.find do |openai_model|
+        openai_model.to_s.downcase == str_model.downcase
+      end
+    end
+  end
+
+  enum OpenAIMode
+    Chat
+    Complete
+    Insert
+
+    def self.from_s?(str_mode) : OpenAIMode?
+      return OpenAIMode.parse?(str_mode)
+      # return OpenAIMode.values.find do |openai_mode|
+      #   openai_mode.to_s.downcase == str_mode.downcase
+      # end
+    end
+  end
+
+  def self.context_token_limit(model : OpenAIModel) : UInt32
+    hash = {
+      OpenAIModel::Gpt_3_5_Turbo     => 1024_u32 * 4_u32,
+      OpenAIModel::Gpt_3_5_Turbo_16k => 1024_u32 * 16_u32,
+      OpenAIModel::Gpt_4             => 1024_u32 * 8_u32,
+    }
+
+    return 0_u32 unless hash[model]?
+    return hash[model]
+  end
+
+  class GptConfig
+    # defaults
+    # property temperature = 0.82
+    property model : OpenAIModel = OpenAIModel::Gpt_3_5_Turbo
+    # property model = "gpt-3.5-turbo-16k"
+    # property prompt_size = 10_000
+    property prompt_size = 2_700
+    property temperature = 0.8
+    property presence_penalty = 1.0
+    property frequency_penalty = 1.0
+    property max_tokens = 384
+    property mode = OpenAIMode::Chat
+    property openai_key : String = ENV.fetch("OPENAI_API_KEY")
+  end
+end
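A short sketch of the round trip these helpers enable, using only the code above:

    model = AIUtils::OpenAIModel.from_s?("gpt-3.5-turbo-16k")
    abort "unknown model" if model.nil?
    puts model.to_s                          # => "gpt-3.5-turbo-16k"
    puts AIUtils.context_token_limit(model)  # => 16384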
@@ -1,5 +1,5 @@
-require "../prompt"
+require "../prompts/generic_prompt"

 module Parser
   class PromptString
src/requests/chat.cr  (new file, +7 lines)
@@ -0,0 +1,7 @@
+require "./generic_request"
+
+module Requests
+  class ChatRequest < GenericRequest
+  end
+end
src/requests/complete.cr  (new file, +8 lines)
@@ -0,0 +1,8 @@
+
+require "./generic_request"
+
+module Requests
+  class CompleteRequest < GenericRequest
+  end
+end
src/requests/generic_request.cr  (new file, +23 lines)
@@ -0,0 +1,23 @@
+module Requests
+  class GenericRequest
+    getter context_size : Int32 = 0
+
+    def initialize(@config : AIUtils::GptConfig)
+    end
+
+    def perform(messages)
+      openai = OpenAI::Client.new(access_token: @config.openai_key)
+      result = openai.chat(
+        @config.model.to_s,
+        messages,
+        {
+          "temperature"       => @config.temperature,
+          "presence_penalty"  => @config.presence_penalty,
+          "frequency_penalty" => @config.frequency_penalty,
+          "max_tokens"        => @config.max_tokens,
+        }
+      )
+    end
+  end
+end
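How the request classes are driven, mirroring the storyteller change later in this diff; the result shape (choices/message/content) matches the openai shard's chat response as used below:

    config = AIUtils::GptConfig.new
    request = Requests::ChatRequest.new(config: config)
    result = request.perform(messages)
    puts result.choices.first["message"]["content"]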
src/requests/insert.cr  (new file, +8 lines)
@@ -0,0 +1,8 @@
+
+require "./generic_request"
+
+module Requests
+  class InsertRequest < GenericRequest
+  end
+end
@@ -1,35 +1,6 @@
-enum OpenAIMode
-  Chat
-  Complete
-  Insert
-
-  def self.from_s(str_mode)
-    return OpenAIMode.values.find do |openai_mode|
-      openai_mode.to_s.downcase == str_mode.downcase
-    end
-  end
-end
-
-class GptConfig
-  # defaults
-  property gpt_temperature = 0.82
-  property gpt_presence_penalty = 1
-  property gpt_frequency_penalty = 1
-  property gpt_max_tokens = 256
-  property gpt_mode = OpenAIMode::Chat
-end
-
-class AppConfig
-  property input_file = STDIN
-  property input_file_path = ""
-  property output_file = STDOUT
-  property output_file_path = ""
-  property past_characters_limit = 1000
-  property use_color = true
-  property make_request = true
-  property verbose = false
-  property gpt_config = GptConfig.new
-end
+require "./config"
+require "./requests/chat"

 class Storyteller
   property config = AppConfig.new
@@ -79,26 +50,35 @@ class Storyteller
       parser.separator("GPT options")

-      parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
-        result_mode = OpenAIMode.from_s(chosen_mode.downcase)
-        if result_mode.nil?
-          STDERR.puts "ERROR: unknown mode #{chosen_mode}"
+      parser.on("--gpt-model MODEL", "GPT model (...) (default: gpt-3.5-turbo)") do |chosen_model|
+        result_model = AIUtils::OpenAIModel.from_s?(chosen_model.downcase)
+        if result_model.nil?
+          STDERR.puts "ERROR: unknown model #{chosen_model}"
           exit 1
         end
-        gpt_mode = result_mode unless result_mode.nil?
+        @config.gpt_config.model = result_model unless result_model.nil?
       end

-      parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature|
-        gpt_temperature = temperature
+      # parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
+      #   result_mode = AIUtils::OpenAIMode.from_s(chosen_mode.downcase)
+      #   if result_mode.nil?
+      #     STDERR.puts "ERROR: unknown mode #{chosen_mode}"
+      #     exit 1
+      #   end
+      #   @config.gpt_config.mode = result_mode unless result_mode.nil?
+      # end
+
+      parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature (default #{@config.gpt_config.temperature})") do |temperature|
+        @config.gpt_config.temperature = temperature.to_f
       end
-      parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty|
-        gpt_presence_penalty = presence_penalty
+      parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty (default #{@config.gpt_config.presence_penalty})") do |presence_penalty|
+        @config.gpt_config.presence_penalty = presence_penalty.to_f
       end
-      parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty|
-        gpt_frequency_penalty = frequency_penalty
+      parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty (default #{@config.gpt_config.frequency_penalty})") do |frequency_penalty|
+        @config.gpt_config.frequency_penalty = frequency_penalty.to_f
       end
-      parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens|
-        gpt_max_tokens = max_tokens
+      parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens (default #{@config.gpt_config.max_tokens})") do |max_tokens|
+        @config.gpt_config.max_tokens = max_tokens.to_i
       end

     end
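A hypothetical invocation exercising the new options (the binary name is assumed):

    ./storyteller --gpt-model gpt-4 --gpt-temperature 0.9 --gpt-max-tokens 512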
@@ -130,21 +110,15 @@ class Storyteller
     end
     self.write_file(@config.output_file, prompt, @config.use_color)
     @config.output_file.close
-
-    # rescue ex : OptionParser::InvalidOption
-    #   STDERR.puts @parser
-    #   STDERR.puts "\nERROR: #{ex.message}"
-    #   exit 1
-
-    # rescue ex: OpenAI::Client::ClientError
-    #   STDERR.puts "ERROR: #{ex.message}"
-    #   exit 1
   end

   def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
     builder = Builder::OpenAIChat.new(verbose: verbose)
-    messages = builder.build(prompt)
+    context_token_limit = AIUtils.context_token_limit(@config.gpt_config.model)
+    messages = builder.build(prompt, context_token_limit)

+    STDERR.puts "model = #{@config.gpt_config.model}"
+    STDERR.puts "context_token_limit = #{context_token_limit}"
     STDERR.puts messages if verbose

     return prompt if !make_request
@@ -158,7 +132,7 @@ class Storyteller
     loop do
       tick += 1
       print "."
-      if tick > (3 * 60)
+      if tick > (2 * 60)
         print "(timeout)"
         exit 1
       end
@@ -167,20 +141,31 @@ class Storyteller
       end
     end

+    # TODO:
+    # request = Request::Chat.new(config: @config.gpt_config)
+    # result = request.perform(message)
+
     spawn do
-      openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
-      result = openai.chat(
-        "gpt-3.5-turbo",
-        messages,
-        {
-          "temperature" => 0.82,
-          "presence_penalty" => 1,
-          "frequency_penalty" => 1,
-          "max_tokens" => 256
-        }
-      )
+      request = Requests::ChatRequest.new(config: @config.gpt_config)
+      result = request.perform(messages)

       pp result.choices if @config.verbose
       prompt.present_zone.content << result.choices.first["message"]["content"]
       channel_ready.send(true)
+    rescue ex: KeyError
+      puts "(openai error)"
+      STDERR.puts "ERROR: #{ex.message}"
+      exit 1
+
+    rescue ex: IO::Error
+      puts "(network error)"
+      STDERR.puts "ERROR: #{ex.message}"
+      exit 1
+
+    rescue ex: Socket::Addrinfo::Error
+      puts "(network error)"
+      STDERR.puts "ERROR: #{ex.message}"
+      exit 1
     end

     # sp.start
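The fiber above signals completion through channel_ready; a minimal sketch of the pattern in isolation, independent of the OpenAI specifics:

    channel_ready = Channel(Bool).new
    spawn do
      # ... perform the long-running request ...
      channel_ready.send(true)
    end
    channel_ready.receive  # block until the fiber signals completion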
@@ -203,7 +188,7 @@ class Storyteller
   def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
     # STDERR.puts "d: building builder"
-    builder = Builder::PromptString.new(use_color)
+    builder = ContextBuilder::PromptString.new(use_color)
     # STDERR.puts "d: building"
     text = builder.build(prompt)
     output_file.write_string(text.to_slice)