Add support for auto-split
This commit is contained in:
parent
3bd2f9343f
commit
01efd5e597
6 changed files with 102 additions and 52 deletions
|
@ -4,3 +4,7 @@ shards:
|
|||
git: https://github.com/lancecarlson/openai.cr.git
|
||||
version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68
|
||||
|
||||
spinner:
|
||||
git: https://github.com/askn/spinner.git
|
||||
version: 0.1.1
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ targets:
|
|||
# Short description of ai-storyteller
|
||||
|
||||
dependencies:
|
||||
spinner:
|
||||
github: askn/spinner
|
||||
openai:
|
||||
github: lancecarlson/openai.cr
|
||||
|
||||
|
|
|
@ -17,21 +17,25 @@ module Builder
|
|||
def build(prompt : Prompt) : Chat
|
||||
chat = [] of Message
|
||||
|
||||
token_limit = 2_900
|
||||
token_limit = 2_700
|
||||
mandatory_token_count = (
|
||||
prompt.system_zone.token_count +
|
||||
prompt.present_zone.token_count
|
||||
)
|
||||
|
||||
## Build mandatory system messages
|
||||
prompt.system_zone.each do |content|
|
||||
chat << { role: "system", content: content }
|
||||
end
|
||||
# prompt.system_zone.each do |content|
|
||||
# next if content.empty?
|
||||
# chat << { role: "system", content: content }
|
||||
# end
|
||||
chat << { role: "system", content: prompt.system_zone.content.join("\n") }
|
||||
|
||||
## Build mandatory system messages
|
||||
tmp_chat = [] of Message
|
||||
tmp_token_count = 0
|
||||
prompt.past_zone.reverse_each do |content|
|
||||
next if content.empty?
|
||||
|
||||
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
|
||||
pp tmp_chat.reverse if @verbose
|
||||
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
|
||||
|
|
|
@ -5,6 +5,9 @@ require "./generic"
|
|||
|
||||
module Builder
|
||||
class PromptString < PromptGeneric
|
||||
|
||||
SEPARATOR = "\n\n"
|
||||
|
||||
getter use_color : Bool
|
||||
def initialize(@use_color)
|
||||
Colorize.enabled = @use_color
|
||||
|
@ -13,28 +16,26 @@ module Builder
|
|||
def build(prompt : Prompt)
|
||||
str = ""
|
||||
|
||||
prompt.prelude_zone.each do |content|
|
||||
str += content
|
||||
end
|
||||
str += prompt.prelude_zone.content.join(SEPARATOR)
|
||||
|
||||
prompt.system_zone.each do |content|
|
||||
if ! prompt.system_zone.content.empty?
|
||||
str += "@@system".colorize(:yellow).to_s
|
||||
str += content
|
||||
str += prompt.system_zone.content.join(SEPARATOR)
|
||||
end
|
||||
|
||||
prompt.past_zone.each do |content|
|
||||
if ! prompt.past_zone.content.empty?
|
||||
str += "@@before".colorize(:yellow).to_s
|
||||
str += content
|
||||
str += prompt.past_zone.content.join(SEPARATOR)
|
||||
end
|
||||
|
||||
prompt.present_zone.each do |content|
|
||||
if ! prompt.present_zone.content.empty?
|
||||
str += "@@current".colorize(:yellow).to_s
|
||||
str += content.colorize(:light_cyan).to_s
|
||||
str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s
|
||||
end
|
||||
|
||||
prompt.future_zone.each do |content|
|
||||
if ! prompt.future_zone.content.empty?
|
||||
str += "@@after".colorize(:yellow).to_s
|
||||
str += content
|
||||
str += prompt.future_zone.content.join("\n")
|
||||
end
|
||||
|
||||
str
|
||||
|
|
109
src/main.cr
109
src/main.cr
|
@ -1,6 +1,7 @@
|
|||
require "option_parser"
|
||||
require "pretty_print"
|
||||
require "openai"
|
||||
require "spinner"
|
||||
|
||||
require "./zone"
|
||||
require "./parsers/prompt_string"
|
||||
|
@ -35,30 +36,26 @@ class Storyteller
|
|||
make_request = true
|
||||
gpt_mode = OpenAIMode::Chat
|
||||
verbose = false
|
||||
gpt_temperature = 0.82
|
||||
gpt_presence_penalty = 1
|
||||
gpt_frequency_penalty = 1
|
||||
gpt_max_tokens = 256
|
||||
|
||||
parser = OptionParser.parse do |parser|
|
||||
parser = OptionParser.new do |parser|
|
||||
parser.banner = "Usage: storyteller [options]"
|
||||
|
||||
parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
||||
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
|
||||
if result_mode.nil?
|
||||
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
||||
exit 1
|
||||
end
|
||||
gpt_mode = result_mode unless result_mode.nil?
|
||||
end
|
||||
parser.separator("Options:")
|
||||
|
||||
|
||||
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
|
||||
input_file_path = file
|
||||
end
|
||||
|
||||
parser.on("-v", "--verbose", "Be verbose (cumulative)") do
|
||||
verbose = true
|
||||
parser.on("-h", "--help", "Show this help") do
|
||||
puts parser
|
||||
exit
|
||||
end
|
||||
|
||||
parser.on("--dry-run", "Don't call the API") do
|
||||
make_request = false
|
||||
end
|
||||
|
||||
parser.on("-n", "--no-color", "Disable color output") do
|
||||
use_color = false
|
||||
|
@ -69,11 +66,41 @@ class Storyteller
|
|||
output_file_path = file
|
||||
end
|
||||
|
||||
parser.on("-h", "--help", "Show this help") do
|
||||
puts parser
|
||||
exit
|
||||
parser.on("-v", "--verbose", "Be verbose (cumulative)") do
|
||||
verbose = true
|
||||
end
|
||||
|
||||
parser.on("--dry-run", "Don't call the API") do
|
||||
make_request = false
|
||||
end
|
||||
|
||||
|
||||
parser.separator("GPT options")
|
||||
|
||||
parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
||||
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
|
||||
if result_mode.nil?
|
||||
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
||||
exit 1
|
||||
end
|
||||
gpt_mode = result_mode unless result_mode.nil?
|
||||
end
|
||||
|
||||
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature|
|
||||
gpt_temperature = temperature
|
||||
end
|
||||
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty|
|
||||
gpt_presence_penalty = presence_penalty
|
||||
end
|
||||
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty|
|
||||
gpt_frequency_penalty = frequency_penalty
|
||||
end
|
||||
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens|
|
||||
gpt_max_tokens = max_tokens
|
||||
end
|
||||
|
||||
end
|
||||
parser.parse(ARGV)
|
||||
|
||||
# Create Storyteller instance
|
||||
storyteller = Storyteller.new()
|
||||
|
@ -97,48 +124,58 @@ class Storyteller
|
|||
end
|
||||
storyteller.write_file(output_file, prompt, use_color)
|
||||
output_file.close
|
||||
rescue ex : OptionParser::InvalidOption
|
||||
STDERR.puts parser
|
||||
STDERR.puts "\nERROR: #{ex.message}"
|
||||
exit 1
|
||||
end
|
||||
|
||||
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
||||
builder = Builder::OpenAIChat.new(verbose: verbose)
|
||||
messages = builder.build(prompt)
|
||||
|
||||
STDERR.puts messages if verbose
|
||||
|
||||
return prompt if !make_request
|
||||
|
||||
channel_ready = Channel(Bool).new
|
||||
channel_tick = Channel(Bool).new
|
||||
spawn do
|
||||
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
||||
result = openai.chat(
|
||||
"gpt-3.5-turbo",
|
||||
messages,
|
||||
{
|
||||
"temperature" => 0.82,
|
||||
"presence_penalty" => 1,
|
||||
"frequency_penalty" => 1,
|
||||
"max_tokens" => 256
|
||||
}
|
||||
)
|
||||
prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
|
||||
channel_ready.send(true)
|
||||
end
|
||||
# sp = Spin.new(0.5, Spinner::Charset[:progress])
|
||||
|
||||
spawn do
|
||||
tick = 0
|
||||
loop do
|
||||
tick += 1
|
||||
print "."
|
||||
if tick > (3 * 60)
|
||||
print "(timeout)"
|
||||
exit 1
|
||||
end
|
||||
channel_tick.send(true)
|
||||
sleep 1.seconds
|
||||
end
|
||||
end
|
||||
|
||||
spawn do
|
||||
while channel_tick.receive?
|
||||
print "."
|
||||
end
|
||||
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
||||
result = openai.chat(
|
||||
"gpt-3.5-turbo",
|
||||
messages,
|
||||
{
|
||||
"temperature" => 0.82,
|
||||
"presence_penalty" => 1,
|
||||
"frequency_penalty" => 1,
|
||||
"max_tokens" => 256
|
||||
}
|
||||
)
|
||||
prompt.past_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
|
||||
channel_ready.send(true)
|
||||
end
|
||||
|
||||
# sp.start
|
||||
channel_ready.receive
|
||||
channel_ready.close
|
||||
channel_tick.close
|
||||
# sp.stop
|
||||
|
||||
prompt
|
||||
end
|
||||
|
|
|
@ -3,6 +3,8 @@ require "../prompt"
|
|||
|
||||
module Parser
|
||||
class PromptString
|
||||
SEPARATOR = "\n\n"
|
||||
|
||||
def parse(current : String) : Prompt
|
||||
prompt = Prompt.new
|
||||
remaining = read_string(prompt.zones, current)
|
||||
|
@ -44,7 +46,7 @@ module Parser
|
|||
## Handle recognized tag, skip tag & parse & add remains
|
||||
# puts "-- found tag #{zone.tag}"
|
||||
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
|
||||
zone.content << remaining
|
||||
zone.content.concat remaining.split(SEPARATOR).reverse!
|
||||
return ""
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue