Add support for auto-split

This commit is contained in:
Glenn Y. Rolland 2023-04-22 16:19:31 +02:00
parent 3bd2f9343f
commit 01efd5e597
6 changed files with 102 additions and 52 deletions

View file

@ -4,3 +4,7 @@ shards:
git: https://github.com/lancecarlson/openai.cr.git git: https://github.com/lancecarlson/openai.cr.git
version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68 version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68
spinner:
git: https://github.com/askn/spinner.git
version: 0.1.1

View file

@ -12,6 +12,8 @@ targets:
# Short description of ai-storyteller # Short description of ai-storyteller
dependencies: dependencies:
spinner:
github: askn/spinner
openai: openai:
github: lancecarlson/openai.cr github: lancecarlson/openai.cr

View file

@ -17,21 +17,25 @@ module Builder
def build(prompt : Prompt) : Chat def build(prompt : Prompt) : Chat
chat = [] of Message chat = [] of Message
token_limit = 2_900 token_limit = 2_700
mandatory_token_count = ( mandatory_token_count = (
prompt.system_zone.token_count + prompt.system_zone.token_count +
prompt.present_zone.token_count prompt.present_zone.token_count
) )
## Build mandatory system messages ## Build mandatory system messages
prompt.system_zone.each do |content| # prompt.system_zone.each do |content|
chat << { role: "system", content: content } # next if content.empty?
end # chat << { role: "system", content: content }
# end
chat << { role: "system", content: prompt.system_zone.content.join("\n") }
## Build mandatory system messages ## Build mandatory system messages
tmp_chat = [] of Message tmp_chat = [] of Message
tmp_token_count = 0 tmp_token_count = 0
prompt.past_zone.reverse_each do |content| prompt.past_zone.reverse_each do |content|
next if content.empty?
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
pp tmp_chat.reverse if @verbose pp tmp_chat.reverse if @verbose
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose

View file

@ -5,6 +5,9 @@ require "./generic"
module Builder module Builder
class PromptString < PromptGeneric class PromptString < PromptGeneric
SEPARATOR = "\n\n"
getter use_color : Bool getter use_color : Bool
def initialize(@use_color) def initialize(@use_color)
Colorize.enabled = @use_color Colorize.enabled = @use_color
@ -13,28 +16,26 @@ module Builder
def build(prompt : Prompt) def build(prompt : Prompt)
str = "" str = ""
prompt.prelude_zone.each do |content| str += prompt.prelude_zone.content.join(SEPARATOR)
str += content
end
prompt.system_zone.each do |content| if ! prompt.system_zone.content.empty?
str += "@@system".colorize(:yellow).to_s str += "@@system".colorize(:yellow).to_s
str += content str += prompt.system_zone.content.join(SEPARATOR)
end end
prompt.past_zone.each do |content| if ! prompt.past_zone.content.empty?
str += "@@before".colorize(:yellow).to_s str += "@@before".colorize(:yellow).to_s
str += content str += prompt.past_zone.content.join(SEPARATOR)
end end
prompt.present_zone.each do |content| if ! prompt.present_zone.content.empty?
str += "@@current".colorize(:yellow).to_s str += "@@current".colorize(:yellow).to_s
str += content.colorize(:light_cyan).to_s str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s
end end
prompt.future_zone.each do |content| if ! prompt.future_zone.content.empty?
str += "@@after".colorize(:yellow).to_s str += "@@after".colorize(:yellow).to_s
str += content str += prompt.future_zone.content.join("\n")
end end
str str

View file

@ -1,6 +1,7 @@
require "option_parser" require "option_parser"
require "pretty_print" require "pretty_print"
require "openai" require "openai"
require "spinner"
require "./zone" require "./zone"
require "./parsers/prompt_string" require "./parsers/prompt_string"
@ -35,30 +36,26 @@ class Storyteller
make_request = true make_request = true
gpt_mode = OpenAIMode::Chat gpt_mode = OpenAIMode::Chat
verbose = false verbose = false
gpt_temperature = 0.82
gpt_presence_penalty = 1
gpt_frequency_penalty = 1
gpt_max_tokens = 256
parser = OptionParser.parse do |parser| parser = OptionParser.new do |parser|
parser.banner = "Usage: storyteller [options]" parser.banner = "Usage: storyteller [options]"
parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode| parser.separator("Options:")
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
if result_mode.nil?
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
exit 1
end
gpt_mode = result_mode unless result_mode.nil?
end
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file| parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
input_file_path = file input_file_path = file
end end
parser.on("-v", "--verbose", "Be verbose (cumulative)") do parser.on("-h", "--help", "Show this help") do
verbose = true puts parser
exit
end end
parser.on("--dry-run", "Don't call the API") do
make_request = false
end
parser.on("-n", "--no-color", "Disable color output") do parser.on("-n", "--no-color", "Disable color output") do
use_color = false use_color = false
@ -69,12 +66,42 @@ class Storyteller
output_file_path = file output_file_path = file
end end
parser.on("-h", "--help", "Show this help") do parser.on("-v", "--verbose", "Be verbose (cumulative)") do
puts parser verbose = true
exit
end end
parser.on("--dry-run", "Don't call the API") do
make_request = false
end end
parser.separator("GPT options")
parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
if result_mode.nil?
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
exit 1
end
gpt_mode = result_mode unless result_mode.nil?
end
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature|
gpt_temperature = temperature
end
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty|
gpt_presence_penalty = presence_penalty
end
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty|
gpt_frequency_penalty = frequency_penalty
end
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens|
gpt_max_tokens = max_tokens
end
end
parser.parse(ARGV)
# Create Storyteller instance # Create Storyteller instance
storyteller = Storyteller.new() storyteller = Storyteller.new()
@ -97,16 +124,38 @@ class Storyteller
end end
storyteller.write_file(output_file, prompt, use_color) storyteller.write_file(output_file, prompt, use_color)
output_file.close output_file.close
rescue ex : OptionParser::InvalidOption
STDERR.puts parser
STDERR.puts "\nERROR: #{ex.message}"
exit 1
end end
def complete(prompt : Prompt, make_request : Bool, verbose : Bool) def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
builder = Builder::OpenAIChat.new(verbose: verbose) builder = Builder::OpenAIChat.new(verbose: verbose)
messages = builder.build(prompt) messages = builder.build(prompt)
STDERR.puts messages if verbose
return prompt if !make_request return prompt if !make_request
channel_ready = Channel(Bool).new channel_ready = Channel(Bool).new
channel_tick = Channel(Bool).new channel_tick = Channel(Bool).new
# sp = Spin.new(0.5, Spinner::Charset[:progress])
spawn do
tick = 0
loop do
tick += 1
print "."
if tick > (3 * 60)
print "(timeout)"
exit 1
end
channel_tick.send(true)
sleep 1.seconds
end
end
spawn do spawn do
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY")) openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
result = openai.chat( result = openai.chat(
@ -119,26 +168,14 @@ class Storyteller
"max_tokens" => 256 "max_tokens" => 256
} }
) )
prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n" prompt.past_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
channel_ready.send(true) channel_ready.send(true)
end end
spawn do # sp.start
loop do
channel_tick.send(true)
sleep 1.seconds
end
end
spawn do
while channel_tick.receive?
print "."
end
end
channel_ready.receive channel_ready.receive
channel_ready.close channel_ready.close
channel_tick.close # sp.stop
prompt prompt
end end

View file

@ -3,6 +3,8 @@ require "../prompt"
module Parser module Parser
class PromptString class PromptString
SEPARATOR = "\n\n"
def parse(current : String) : Prompt def parse(current : String) : Prompt
prompt = Prompt.new prompt = Prompt.new
remaining = read_string(prompt.zones, current) remaining = read_string(prompt.zones, current)
@ -44,7 +46,7 @@ module Parser
## Handle recognized tag, skip tag & parse & add remains ## Handle recognized tag, skip tag & parse & add remains
# puts "-- found tag #{zone.tag}" # puts "-- found tag #{zone.tag}"
remaining = read_string(zone_list, current[(2+zone.tag.size)..]) remaining = read_string(zone_list, current[(2+zone.tag.size)..])
zone.content << remaining zone.content.concat remaining.split(SEPARATOR).reverse!
return "" return ""
end end
end end