From 01efd5e597bf403bcb2e43427feb69881f4e1f07 Mon Sep 17 00:00:00 2001 From: Glenn Date: Sat, 22 Apr 2023 16:19:31 +0200 Subject: [PATCH] Add support for auto-split --- shard.lock | 4 ++ shard.yml | 2 + src/builders/openai_chat.cr | 12 ++-- src/builders/prompt_string.cr | 23 +++---- src/main.cr | 109 +++++++++++++++++++++++----------- src/parsers/prompt_string.cr | 4 +- 6 files changed, 102 insertions(+), 52 deletions(-) diff --git a/shard.lock b/shard.lock index 980beb4..fb64305 100644 --- a/shard.lock +++ b/shard.lock @@ -4,3 +4,7 @@ shards: git: https://github.com/lancecarlson/openai.cr.git version: 0.1.0+git.commit.852bcd9b37d8472a4c72a2498ebb90351048fa68 + spinner: + git: https://github.com/askn/spinner.git + version: 0.1.1 + diff --git a/shard.yml b/shard.yml index ed456d4..8c86777 100644 --- a/shard.yml +++ b/shard.yml @@ -12,6 +12,8 @@ targets: # Short description of ai-storyteller dependencies: + spinner: + github: askn/spinner openai: github: lancecarlson/openai.cr diff --git a/src/builders/openai_chat.cr b/src/builders/openai_chat.cr index ec1ff2f..7ad7e66 100644 --- a/src/builders/openai_chat.cr +++ b/src/builders/openai_chat.cr @@ -17,21 +17,25 @@ module Builder def build(prompt : Prompt) : Chat chat = [] of Message - token_limit = 2_900 + token_limit = 2_700 mandatory_token_count = ( prompt.system_zone.token_count + prompt.present_zone.token_count ) ## Build mandatory system messages - prompt.system_zone.each do |content| - chat << { role: "system", content: content } - end + # prompt.system_zone.each do |content| + # next if content.empty? + # chat << { role: "system", content: content } + # end + chat << { role: "system", content: prompt.system_zone.content.join("\n") } ## Build mandatory system messages tmp_chat = [] of Message tmp_token_count = 0 prompt.past_zone.reverse_each do |content| + next if content.empty? + estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count pp tmp_chat.reverse if @verbose puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose diff --git a/src/builders/prompt_string.cr b/src/builders/prompt_string.cr index 2a81865..f1a7284 100644 --- a/src/builders/prompt_string.cr +++ b/src/builders/prompt_string.cr @@ -5,6 +5,9 @@ require "./generic" module Builder class PromptString < PromptGeneric + + SEPARATOR = "\n\n" + getter use_color : Bool def initialize(@use_color) Colorize.enabled = @use_color @@ -13,28 +16,26 @@ module Builder def build(prompt : Prompt) str = "" - prompt.prelude_zone.each do |content| - str += content - end + str += prompt.prelude_zone.content.join(SEPARATOR) - prompt.system_zone.each do |content| + if ! prompt.system_zone.content.empty? str += "@@system".colorize(:yellow).to_s - str += content + str += prompt.system_zone.content.join(SEPARATOR) end - prompt.past_zone.each do |content| + if ! prompt.past_zone.content.empty? str += "@@before".colorize(:yellow).to_s - str += content + str += prompt.past_zone.content.join(SEPARATOR) end - prompt.present_zone.each do |content| + if ! prompt.present_zone.content.empty? str += "@@current".colorize(:yellow).to_s - str += content.colorize(:light_cyan).to_s + str += prompt.present_zone.content.join(SEPARATOR).colorize(:light_cyan).to_s end - prompt.future_zone.each do |content| + if ! prompt.future_zone.content.empty? str += "@@after".colorize(:yellow).to_s - str += content + str += prompt.future_zone.content.join("\n") end str diff --git a/src/main.cr b/src/main.cr index cd99784..483fa43 100644 --- a/src/main.cr +++ b/src/main.cr @@ -1,6 +1,7 @@ require "option_parser" require "pretty_print" require "openai" +require "spinner" require "./zone" require "./parsers/prompt_string" @@ -35,30 +36,26 @@ class Storyteller make_request = true gpt_mode = OpenAIMode::Chat verbose = false + gpt_temperature = 0.82 + gpt_presence_penalty = 1 + gpt_frequency_penalty = 1 + gpt_max_tokens = 256 - parser = OptionParser.parse do |parser| + parser = OptionParser.new do |parser| parser.banner = "Usage: storyteller [options]" - parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode| - result_mode = OpenAIMode.from_s(chosen_mode.downcase) - if result_mode.nil? - STDERR.puts "ERROR: unknown mode #{chosen_mode}" - exit 1 - end - gpt_mode = result_mode unless result_mode.nil? - end + parser.separator("Options:") + parser.on("-i FILE", "--input=FILE", "Path to input file") do |file| input_file_path = file end - parser.on("-v", "--verbose", "Be verbose (cumulative)") do - verbose = true + parser.on("-h", "--help", "Show this help") do + puts parser + exit end - parser.on("--dry-run", "Don't call the API") do - make_request = false - end parser.on("-n", "--no-color", "Disable color output") do use_color = false @@ -69,11 +66,41 @@ class Storyteller output_file_path = file end - parser.on("-h", "--help", "Show this help") do - puts parser - exit + parser.on("-v", "--verbose", "Be verbose (cumulative)") do + verbose = true end + + parser.on("--dry-run", "Don't call the API") do + make_request = false + end + + + parser.separator("GPT options") + + parser.on("--gpt-mode MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode| + result_mode = OpenAIMode.from_s(chosen_mode.downcase) + if result_mode.nil? + STDERR.puts "ERROR: unknown mode #{chosen_mode}" + exit 1 + end + gpt_mode = result_mode unless result_mode.nil? + end + + parser.on("--gpt-temperature TEMPERATURE", "GPT Temperature") do |temperature| + gpt_temperature = temperature + end + parser.on("--gpt-presence-penalty PENALTY", "GPT Presence Penalty") do |presence_penalty| + gpt_presence_penalty = presence_penalty + end + parser.on("--gpt-frequency-penalty PENALTY", "GPT Frequency Penalty") do |frequency_penalty| + gpt_frequency_penalty = frequency_penalty + end + parser.on("--gpt-max-tokens TOKENS", "GPT Max Tokens") do |max_tokens| + gpt_max_tokens = max_tokens + end + end + parser.parse(ARGV) # Create Storyteller instance storyteller = Storyteller.new() @@ -97,48 +124,58 @@ class Storyteller end storyteller.write_file(output_file, prompt, use_color) output_file.close + rescue ex : OptionParser::InvalidOption + STDERR.puts parser + STDERR.puts "\nERROR: #{ex.message}" + exit 1 end def complete(prompt : Prompt, make_request : Bool, verbose : Bool) builder = Builder::OpenAIChat.new(verbose: verbose) messages = builder.build(prompt) + STDERR.puts messages if verbose + return prompt if !make_request channel_ready = Channel(Bool).new channel_tick = Channel(Bool).new - spawn do - openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY")) - result = openai.chat( - "gpt-3.5-turbo", - messages, - { - "temperature" => 0.82, - "presence_penalty" => 1, - "frequency_penalty" => 1, - "max_tokens" => 256 - } - ) - prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n" - channel_ready.send(true) - end + # sp = Spin.new(0.5, Spinner::Charset[:progress]) spawn do + tick = 0 loop do + tick += 1 + print "." + if tick > (3 * 60) + print "(timeout)" + exit 1 + end channel_tick.send(true) sleep 1.seconds end end spawn do - while channel_tick.receive? - print "." - end + openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY")) + result = openai.chat( + "gpt-3.5-turbo", + messages, + { + "temperature" => 0.82, + "presence_penalty" => 1, + "frequency_penalty" => 1, + "max_tokens" => 256 + } + ) + prompt.past_zone.content << "\n" + result.choices.first["message"]["content"] + "\n" + channel_ready.send(true) end + # sp.start channel_ready.receive channel_ready.close - channel_tick.close + # sp.stop prompt end diff --git a/src/parsers/prompt_string.cr b/src/parsers/prompt_string.cr index 26f6690..237762e 100644 --- a/src/parsers/prompt_string.cr +++ b/src/parsers/prompt_string.cr @@ -3,6 +3,8 @@ require "../prompt" module Parser class PromptString + SEPARATOR = "\n\n" + def parse(current : String) : Prompt prompt = Prompt.new remaining = read_string(prompt.zones, current) @@ -44,7 +46,7 @@ module Parser ## Handle recognized tag, skip tag & parse & add remains # puts "-- found tag #{zone.tag}" remaining = read_string(zone_list, current[(2+zone.tag.size)..]) - zone.content << remaining + zone.content.concat remaining.split(SEPARATOR).reverse! return "" end end