Add support for auto-split
This commit is contained in:
parent
3bd2f9343f
commit
01efd5e597
6 changed files with 102 additions and 52 deletions
|
@ -4,3 +4,7 @@ shards:
|
||||||
git: https://github.com/lancecarlson/openai.cr.git
|
git: https://github.com/lancecarlson/openai.cr.git
|
||||||
version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68
|
version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68
|
||||||
|
|
||||||
|
spinner:
|
||||||
|
git: https://github.com/askn/spinner.git
|
||||||
|
version: 0.1.1
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,8 @@ targets:
|
||||||
# Short description of ai-storyteller
|
# Short description of ai-storyteller
|
||||||
|
|
||||||
dependencies:
|
dependencies:
|
||||||
|
spinner:
|
||||||
|
github: askn/spinner
|
||||||
openai:
|
openai:
|
||||||
github: lancecarlson/openai.cr
|
github: lancecarlson/openai.cr
|
||||||
|
|
||||||
|
|
|
@ -17,21 +17,25 @@ module Builder
|
||||||
def build(prompt : Prompt) : Chat
|
def build(prompt : Prompt) : Chat
|
||||||
chat = [] of Message
|
chat = [] of Message
|
||||||
|
|
||||||
token_limit = 2_900
|
token_limit = 2_700
|
||||||
mandatory_token_count = (
|
mandatory_token_count = (
|
||||||
prompt.system_zone.token_count +
|
prompt.system_zone.token_count +
|
||||||
prompt.present_zone.token_count
|
prompt.present_zone.token_count
|
||||||
)
|
)
|
||||||
|
|
||||||
## Build mandatory system messages
|
## Build mandatory system messages
|
||||||
prompt.system_zone.each do |content|
|
# prompt.system_zone.each do |content|
|
||||||
chat << { role: "system", content: content }
|
# next if content.empty?
|
||||||
end
|
# chat << { role: "system", content: content }
|
||||||
|
# end
|
||||||
|
chat << { role: "system", content: prompt.system_zone.content.join("\n") }
|
||||||
|
|
||||||
## Build mandatory system messages
|
## Build mandatory system messages
|
||||||
tmp_chat = [] of Message
|
tmp_chat = [] of Message
|
||||||
tmp_token_count = 0
|
tmp_token_count = 0
|
||||||
prompt.past_zone.reverse_each do |content|
|
prompt.past_zone.reverse_each do |content|
|
||||||
|
next if content.empty?
|
||||||
|
|
||||||
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
|
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
|
||||||
pp tmp_chat.reverse if @verbose
|
pp tmp_chat.reverse if @verbose
|
||||||
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
|
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
|
||||||
|
|
|
@ -5,6 +5,9 @@ require "./generic"
|
||||||
|
|
||||||
module Builder
|
module Builder
|
||||||
class PromptString < PromptGeneric
|
class PromptString < PromptGeneric
|
||||||
|
|
||||||
|
SEPARATOR = "\n\n"
|
||||||
|
|
||||||
getter use_color : Bool
|
getter use_color : Bool
|
||||||
def initialize(@use_color)
|
def initialize(@use_color)
|
||||||
Colorize.enabled = @use_color
|
Colorize.enabled = @use_color
|
||||||
|
@ -13,28 +16,26 @@ module Builder
|
||||||
def build(prompt : Prompt)
|
def build(prompt : Prompt)
|
||||||
str = ""
|
str = ""
|
||||||
|
|
||||||
prompt.prelude_zone.each do |content|
|
str += prompt.prelude_zone.content.join(SEPARATOR)
|
||||||
str += content
|
|
||||||
end
|
|
||||||
|
|
||||||
prompt.system_zone.each do |content|
|
if ! prompt.system_zone.content.empty?
|
||||||
str += "@@system".colorize(:yellow).to_s
|
str += "@@system".colorize(:yellow).to_s
|
||||||
str += content
|
str += prompt.system_zone.content.join(SEPARATOR)
|
||||||
end
|
end
|
||||||
|
|
||||||
prompt.past_zone.each do |content|
|
if ! prompt.past_zone.content.empty?
|
||||||
str += "@@before".colorize(:yellow).to_s
|
str += "@@before".colorize(:yellow).to_s
|
||||||
str += content
|
str += prompt.past_zone.content.join(SEPARATOR)
|
||||||
end
|
end
|
||||||
|
|
||||||
prompt.present_zone.each do |content|
|
if ! prompt.present_zone.content.empty?
|
||||||
str += "@@current".colorize(:yellow).to_s
|
str += "@@current".colorize(:yellow).to_s
|
||||||
str += content.colorize(:light_cyan).to_s
|
str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s
|
||||||
end
|
end
|
||||||
|
|
||||||
prompt.future_zone.each do |content|
|
if ! prompt.future_zone.content.empty?
|
||||||
str += "@@after".colorize(:yellow).to_s
|
str += "@@after".colorize(:yellow).to_s
|
||||||
str += content
|
str += prompt.future_zone.content.join("\n")
|
||||||
end
|
end
|
||||||
|
|
||||||
str
|
str
|
||||||
|
|
101
src/main.cr
101
src/main.cr
|
@ -1,6 +1,7 @@
|
||||||
require "option_parser"
|
require "option_parser"
|
||||||
require "pretty_print"
|
require "pretty_print"
|
||||||
require "openai"
|
require "openai"
|
||||||
|
require "spinner"
|
||||||
|
|
||||||
require "./zone"
|
require "./zone"
|
||||||
require "./parsers/prompt_string"
|
require "./parsers/prompt_string"
|
||||||
|
@ -35,30 +36,26 @@ class Storyteller
|
||||||
make_request = true
|
make_request = true
|
||||||
gpt_mode = OpenAIMode::Chat
|
gpt_mode = OpenAIMode::Chat
|
||||||
verbose = false
|
verbose = false
|
||||||
|
gpt_temperature = 0.82
|
||||||
|
gpt_presence_penalty = 1
|
||||||
|
gpt_frequency_penalty = 1
|
||||||
|
gpt_max_tokens = 256
|
||||||
|
|
||||||
parser = OptionParser.parse do |parser|
|
parser = OptionParser.new do |parser|
|
||||||
parser.banner = "Usage: storyteller [options]"
|
parser.banner = "Usage: storyteller [options]"
|
||||||
|
|
||||||
parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
parser.separator("Options:")
|
||||||
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
|
|
||||||
if result_mode.nil?
|
|
||||||
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
|
||||||
exit 1
|
|
||||||
end
|
|
||||||
gpt_mode = result_mode unless result_mode.nil?
|
|
||||||
end
|
|
||||||
|
|
||||||
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
|
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
|
||||||
input_file_path = file
|
input_file_path = file
|
||||||
end
|
end
|
||||||
|
|
||||||
parser.on("-v", "--verbose", "Be verbose (cumulative)") do
|
parser.on("-h", "--help", "Show this help") do
|
||||||
verbose = true
|
puts parser
|
||||||
|
exit
|
||||||
end
|
end
|
||||||
|
|
||||||
parser.on("--dry-run", "Don't call the API") do
|
|
||||||
make_request = false
|
|
||||||
end
|
|
||||||
|
|
||||||
parser.on("-n", "--no-color", "Disable color output") do
|
parser.on("-n", "--no-color", "Disable color output") do
|
||||||
use_color = false
|
use_color = false
|
||||||
|
@ -69,12 +66,42 @@ class Storyteller
|
||||||
output_file_path = file
|
output_file_path = file
|
||||||
end
|
end
|
||||||
|
|
||||||
parser.on("-h", "--help", "Show this help") do
|
parser.on("-v", "--verbose", "Be verbose (cumulative)") do
|
||||||
puts parser
|
verbose = true
|
||||||
exit
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
parser.on("--dry-run", "Don't call the API") do
|
||||||
|
make_request = false
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
parser.separator("GPT options")
|
||||||
|
|
||||||
|
parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
||||||
|
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
|
||||||
|
if result_mode.nil?
|
||||||
|
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
||||||
|
exit 1
|
||||||
|
end
|
||||||
|
gpt_mode = result_mode unless result_mode.nil?
|
||||||
|
end
|
||||||
|
|
||||||
|
parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature|
|
||||||
|
gpt_temperature = temperature
|
||||||
|
end
|
||||||
|
parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty|
|
||||||
|
gpt_presence_penalty = presence_penalty
|
||||||
|
end
|
||||||
|
parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty|
|
||||||
|
gpt_frequency_penalty = frequency_penalty
|
||||||
|
end
|
||||||
|
parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens|
|
||||||
|
gpt_max_tokens = max_tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
end
|
||||||
|
parser.parse(ARGV)
|
||||||
|
|
||||||
# Create Storyteller instance
|
# Create Storyteller instance
|
||||||
storyteller = Storyteller.new()
|
storyteller = Storyteller.new()
|
||||||
|
|
||||||
|
@ -97,16 +124,38 @@ class Storyteller
|
||||||
end
|
end
|
||||||
storyteller.write_file(output_file, prompt, use_color)
|
storyteller.write_file(output_file, prompt, use_color)
|
||||||
output_file.close
|
output_file.close
|
||||||
|
rescue ex : OptionParser::InvalidOption
|
||||||
|
STDERR.puts parser
|
||||||
|
STDERR.puts "\nERROR: #{ex.message}"
|
||||||
|
exit 1
|
||||||
end
|
end
|
||||||
|
|
||||||
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
||||||
builder = Builder::OpenAIChat.new(verbose: verbose)
|
builder = Builder::OpenAIChat.new(verbose: verbose)
|
||||||
messages = builder.build(prompt)
|
messages = builder.build(prompt)
|
||||||
|
|
||||||
|
STDERR.puts messages if verbose
|
||||||
|
|
||||||
return prompt if !make_request
|
return prompt if !make_request
|
||||||
|
|
||||||
channel_ready = Channel(Bool).new
|
channel_ready = Channel(Bool).new
|
||||||
channel_tick = Channel(Bool).new
|
channel_tick = Channel(Bool).new
|
||||||
|
# sp = Spin.new(0.5, Spinner::Charset[:progress])
|
||||||
|
|
||||||
|
spawn do
|
||||||
|
tick = 0
|
||||||
|
loop do
|
||||||
|
tick += 1
|
||||||
|
print "."
|
||||||
|
if tick > (3 * 60)
|
||||||
|
print "(timeout)"
|
||||||
|
exit 1
|
||||||
|
end
|
||||||
|
channel_tick.send(true)
|
||||||
|
sleep 1.seconds
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
spawn do
|
spawn do
|
||||||
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
||||||
result = openai.chat(
|
result = openai.chat(
|
||||||
|
@ -119,26 +168,14 @@ class Storyteller
|
||||||
"max_tokens" => 256
|
"max_tokens" => 256
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
|
prompt.past_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
|
||||||
channel_ready.send(true)
|
channel_ready.send(true)
|
||||||
end
|
end
|
||||||
|
|
||||||
spawn do
|
# sp.start
|
||||||
loop do
|
|
||||||
channel_tick.send(true)
|
|
||||||
sleep 1.seconds
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
spawn do
|
|
||||||
while channel_tick.receive?
|
|
||||||
print "."
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
channel_ready.receive
|
channel_ready.receive
|
||||||
channel_ready.close
|
channel_ready.close
|
||||||
channel_tick.close
|
# sp.stop
|
||||||
|
|
||||||
prompt
|
prompt
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,6 +3,8 @@ require "../prompt"
|
||||||
|
|
||||||
module Parser
|
module Parser
|
||||||
class PromptString
|
class PromptString
|
||||||
|
SEPARATOR = "\n\n"
|
||||||
|
|
||||||
def parse(current : String) : Prompt
|
def parse(current : String) : Prompt
|
||||||
prompt = Prompt.new
|
prompt = Prompt.new
|
||||||
remaining = read_string(prompt.zones, current)
|
remaining = read_string(prompt.zones, current)
|
||||||
|
@ -44,7 +46,7 @@ module Parser
|
||||||
## Handle recognized tag, skip tag & parse & add remains
|
## Handle recognized tag, skip tag & parse & add remains
|
||||||
# puts "-- found tag #{zone.tag}"
|
# puts "-- found tag #{zone.tag}"
|
||||||
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
|
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
|
||||||
zone.content << remaining
|
zone.content.concat remaining.split(SEPARATOR).reverse!
|
||||||
return ""
|
return ""
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in a new issue