From 65d80b090c6d6134311f87a12a273e150758077f Mon Sep 17 00:00:00 2001 From: Glenn Date: Tue, 18 Apr 2023 15:16:32 +0200 Subject: [PATCH] Add text parsing & request --- src/builders/generic.cr | 7 ++- src/builders/openai_chat.cr | 81 ++++++++++++++++++----------------- src/builders/openai_insert.cr | 59 +++++++++++++++++++++++++ src/builders/prompt_string.cr | 43 +++++++++++++++++++ src/builders/string.cr | 41 ------------------ src/main.cr | 71 ++++++++++++++++++++++++------ src/parsers/generic.cr | 6 ++- src/parsers/prompt_string.cr | 51 ++++++++++++++++++++++ src/parsers/request.cr | 8 ++-- src/parsers/string.cr | 50 --------------------- 10 files changed, 267 insertions(+), 150 deletions(-) create mode 100644 src/builders/openai_insert.cr create mode 100644 src/builders/prompt_string.cr delete mode 100644 src/builders/string.cr create mode 100644 src/parsers/prompt_string.cr delete mode 100644 src/parsers/string.cr diff --git a/src/builders/generic.cr b/src/builders/generic.cr index a6b300c..259eca6 100644 --- a/src/builders/generic.cr +++ b/src/builders/generic.cr @@ -1,3 +1,6 @@ -abstract class PromptGenericBuilder - abstract def build(prompt : Prompt) + +module Builder + abstract class PromptGeneric + abstract def build(prompt : Prompt) + end end diff --git a/src/builders/openai_chat.cr b/src/builders/openai_chat.cr index 5eb7491..ec1ff2f 100644 --- a/src/builders/openai_chat.cr +++ b/src/builders/openai_chat.cr @@ -2,52 +2,53 @@ require "pretty_print" require "colorize" -class OpenAIChatBuilder - alias OpenAIMessage = NamedTuple(role: String, content: String) - alias OpenAIChat = Array(OpenAIMessage) +module Builder + class OpenAIChat + alias Message = NamedTuple(role: String, content: String) + alias Chat = Array(Message) - getter verbose : Bool - def initialize(@verbose) + getter verbose : Bool + def initialize(@verbose) - end - - # skip prelude_zone - # skip future_zone - def build(prompt : Prompt) : OpenAIChat - chat = [] of OpenAIMessage - - token_limit = 2_900 - mandatory_token_count = ( - prompt.system_zone.token_count + - prompt.present_zone.token_count - ) - - ## Build mandatory system messages - prompt.system_zone.each do |content| - chat << { role: "system", content: content } end - ## Build mandatory system messages - tmp_chat = [] of OpenAIMessage - tmp_token_count = 0 - prompt.past_zone.reverse_each do |content| - estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count - pp tmp_chat.reverse if @verbose - puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose + # skip prelude_zone + # skip future_zone + def build(prompt : Prompt) : Chat + chat = [] of Message - break if estimated_token_count >= token_limit + token_limit = 2_900 + mandatory_token_count = ( + prompt.system_zone.token_count + + prompt.present_zone.token_count + ) - tmp_chat << { role: "user", content: content } - tmp_token_count += (content.size / 4) + ## Build mandatory system messages + prompt.system_zone.each do |content| + chat << { role: "system", content: content } + end + + ## Build mandatory system messages + tmp_chat = [] of Message + tmp_token_count = 0 + prompt.past_zone.reverse_each do |content| + estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count + pp tmp_chat.reverse if @verbose + puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose + + break if estimated_token_count >= token_limit + + tmp_chat << { role: "user", content: content } + tmp_token_count += (content.size / 4) + end + chat.concat(tmp_chat.reverse) + + prompt.present_zone.each do |content| + chat << { role: "user", content: content } + end + + # pp chat + chat end - chat.concat(tmp_chat.reverse) - - prompt.present_zone.each do |content| - chat << { role: "user", content: content } - end - - # pp chat - chat end end - diff --git a/src/builders/openai_insert.cr b/src/builders/openai_insert.cr new file mode 100644 index 0000000..a610b74 --- /dev/null +++ b/src/builders/openai_insert.cr @@ -0,0 +1,59 @@ + +require "pretty_print" +require "colorize" + +module Builder + class OpenAIInsert + alias Message = NamedTuple(role: String, content: String) + alias Chat = Array(Message) + + getter verbose : Bool + def initialize(@verbose) + + end + + # skip prelude_zone + # skip future_zone + def build(prompt : Prompt) : Chat + chat = [] of Message + + token_limit = 2_900 + mandatory_token_count = ( + prompt.system_zone.token_count + + prompt.present_zone.token_count + ) + + ## Build mandatory system messages + prompt.system_zone.each do |content| + chat << { role: "system", content: content } + end + + ## Build mandatory system messages + tmp_chat = [] of Message + tmp_token_count = 0 + prompt.past_zone.reverse_each do |content| + estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count + pp tmp_chat.reverse if @verbose + puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose + + break if estimated_token_count >= token_limit + + tmp_chat << { role: "user", content: content } + tmp_token_count += (content.size / 4) + end + chat.concat(tmp_chat.reverse) + + prompt.present_zone.each do |content| + chat << { role: "user", content: content } + end + + prompt.future_zone.each do |content| + chat << { role: "user", content: content } + end + + # pp chat + chat + end + end +end + diff --git a/src/builders/prompt_string.cr b/src/builders/prompt_string.cr new file mode 100644 index 0000000..2a81865 --- /dev/null +++ b/src/builders/prompt_string.cr @@ -0,0 +1,43 @@ + +require "colorize" + +require "./generic" + +module Builder + class PromptString < PromptGeneric + getter use_color : Bool + def initialize(@use_color) + Colorize.enabled = @use_color + end + + def build(prompt : Prompt) + str = "" + + prompt.prelude_zone.each do |content| + str += content + end + + prompt.system_zone.each do |content| + str += "@@system".colorize(:yellow).to_s + str += content + end + + prompt.past_zone.each do |content| + str += "@@before".colorize(:yellow).to_s + str += content + end + + prompt.present_zone.each do |content| + str += "@@current".colorize(:yellow).to_s + str += content.colorize(:light_cyan).to_s + end + + prompt.future_zone.each do |content| + str += "@@after".colorize(:yellow).to_s + str += content + end + + str + end + end +end diff --git a/src/builders/string.cr b/src/builders/string.cr deleted file mode 100644 index febcc99..0000000 --- a/src/builders/string.cr +++ /dev/null @@ -1,41 +0,0 @@ - -require "colorize" - -require "./generic" - -class StringBuilder < PromptGenericBuilder - getter use_color : Bool - def initialize(@use_color) - Colorize.enabled = @use_color - end - - def build(prompt : Prompt) - str = "" - - prompt.prelude_zone.each do |content| - str += content - end - - prompt.system_zone.each do |content| - str += "@@system".colorize(:yellow).to_s - str += content - end - - prompt.past_zone.each do |content| - str += "@@before".colorize(:yellow).to_s - str += content - end - - prompt.present_zone.each do |content| - str += "@@current".colorize(:yellow).to_s - str += content.colorize(:light_cyan).to_s - end - - prompt.future_zone.each do |content| - str += "@@after".colorize(:yellow).to_s - str += content - end - - str - end -end diff --git a/src/main.cr b/src/main.cr index 72d47a5..cd99784 100644 --- a/src/main.cr +++ b/src/main.cr @@ -3,12 +3,25 @@ require "pretty_print" require "openai" require "./zone" -require "./parsers/string" -require "./builders/string" +require "./parsers/prompt_string" +require "./builders/prompt_string" require "./builders/openai_chat" +require "./builders/openai_insert" class Storyteller + enum OpenAIMode + Chat + Complete + Insert + + def self.from_s(str_mode) + return OpenAIMode.values.find do |openai_mode| + openai_mode.to_s.downcase == str_mode.downcase + end + end + end + def initialize() end @@ -20,11 +33,21 @@ class Storyteller past_characters_limit = 1000 use_color = true make_request = true + gpt_mode = OpenAIMode::Chat verbose = false parser = OptionParser.parse do |parser| parser.banner = "Usage: storyteller [options]" + parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode| + result_mode = OpenAIMode.from_s(chosen_mode.downcase) + if result_mode.nil? + STDERR.puts "ERROR: unknown mode #{chosen_mode}" + exit 1 + end + gpt_mode = result_mode unless result_mode.nil? + end + parser.on("-i FILE", "--input=FILE", "Path to input file") do |file| input_file_path = file end @@ -65,6 +88,7 @@ class Storyteller # Build GPT-3 request prompt = storyteller.complete(prompt, make_request, verbose) + pp prompt if verbose exit 0 if !make_request if !output_file_path.empty? @@ -76,23 +100,46 @@ class Storyteller end def complete(prompt : Prompt, make_request : Bool, verbose : Bool) - builder = OpenAIChatBuilder.new(verbose: verbose) + builder = Builder::OpenAIChat.new(verbose: verbose) messages = builder.build(prompt) return prompt if !make_request - openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY")) - result = openai.chat( - "gpt-3.5-turbo", - messages, - { + channel_ready = Channel(Bool).new + channel_tick = Channel(Bool).new + spawn do + openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY")) + result = openai.chat( + "gpt-3.5-turbo", + messages, + { "temperature" => 0.82, "presence_penalty" => 1, "frequency_penalty" => 1, "max_tokens" => 256 } - ) - prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n" + ) + prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n" + channel_ready.send(true) + end + + spawn do + loop do + channel_tick.send(true) + sleep 1.seconds + end + end + + spawn do + while channel_tick.receive? + print "." + end + end + + channel_ready.receive + channel_ready.close + channel_tick.close + prompt end @@ -100,7 +147,7 @@ class Storyteller content = input_file.gets_to_end # puts "d: building parser" - parser = StringParser.new + parser = Parser::PromptString.new # puts "d: parsing" prompt = parser.parse(content) # pp prompt @@ -108,7 +155,7 @@ class Storyteller def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool) # STDERR.puts "d: building builder" - builder = StringBuilder.new(use_color) + builder = Builder::PromptString.new(use_color) # STDERR.puts "d: building" text = builder.build(prompt) output_file.write_string(text.to_slice) diff --git a/src/parsers/generic.cr b/src/parsers/generic.cr index bde35a3..7d5526c 100644 --- a/src/parsers/generic.cr +++ b/src/parsers/generic.cr @@ -1,3 +1,5 @@ -abstract class PromptGenericParser - abstract def parse : Prompt +module Parser + abstract class PromptGenericParser + abstract def parse : Prompt + end end diff --git a/src/parsers/prompt_string.cr b/src/parsers/prompt_string.cr new file mode 100644 index 0000000..26f6690 --- /dev/null +++ b/src/parsers/prompt_string.cr @@ -0,0 +1,51 @@ + +require "../prompt" + +module Parser + class PromptString + def parse(current : String) : Prompt + prompt = Prompt.new + remaining = read_string(prompt.zones, current) + prompt.prelude_zone.content << remaining + + prompt.system_zone.content.reverse! + prompt.past_zone.content.reverse! + prompt.present_zone.content.reverse! + prompt.future_zone.content.reverse! + + return prompt + end + + def read_string(zone_list : Array(Zone), current : String) + # puts "== read_string(current=#{current})" + + pos = current.index("@@") + + ## If there is no remaining @@, then return current + if pos.nil? + # puts "-- no remaining @@, returning" + return current + end + + ## If @@ is not at position 0, then parse its content first and return remains + if pos > 0 + return current[0..(pos-1)] + read_string(zone_list, current[pos..]) + end + + ## If @@ is at position 0, try detecting tag + zone = zone_list.find { |zone| current.starts_with?("@@" + zone.tag) } + + ## When there is not recognizable tag, skip fake tag and parse + if zone.nil? + # puts "-- no recognizable tag, returning as is" + return "@@" + read_string(zone_list, current[2..]) + end + + ## Handle recognized tag, skip tag & parse & add remains + # puts "-- found tag #{zone.tag}" + remaining = read_string(zone_list, current[(2+zone.tag.size)..]) + zone.content << remaining + return "" + end + end +end diff --git a/src/parsers/request.cr b/src/parsers/request.cr index 43555d8..10e1b2d 100644 --- a/src/parsers/request.cr +++ b/src/parsers/request.cr @@ -1,7 +1,9 @@ -class PromptRequestParser - getter prompt +module Parser + class PromptRequest + getter prompt - def initialize(@prompt = prompt) + def initialize(@prompt = prompt) + end end end diff --git a/src/parsers/string.cr b/src/parsers/string.cr deleted file mode 100644 index 8fa210c..0000000 --- a/src/parsers/string.cr +++ /dev/null @@ -1,50 +0,0 @@ - -require "../prompt" - -class StringParser - def parse(current : String) : Prompt - prompt = Prompt.new - remaining = read_string(prompt.zones, current) - prompt.prelude_zone.content << remaining - - prompt.system_zone.content.reverse! - prompt.past_zone.content.reverse! - prompt.present_zone.content.reverse! - prompt.future_zone.content.reverse! - - return prompt - end - - def read_string(zone_list : Array(Zone), current : String) - # puts "== read_string(current=#{current})" - - pos = current.index("@@") - - ## If there is no remaining @@, then return current - if pos.nil? - # puts "-- no remaining @@, returning" - return current - end - - ## If @@ is not at position 0, then parse its content first and return remains - if pos > 0 - return current[0..(pos-1)] + read_string(zone_list, current[pos..]) - end - - ## If @@ is at position 0, try detecting tag - zone = zone_list.find { |zone| current.starts_with?("@@" + zone.tag) } - - ## When there is not recognizable tag, skip fake tag and parse - if zone.nil? - # puts "-- no recognizable tag, returning as is" - return "@@" + read_string(zone_list, current[2..]) - end - - ## Handle recognized tag, skip tag & parse & add remains - # puts "-- found tag #{zone.tag}" - remaining = read_string(zone_list, current[(2+zone.tag.size)..]) - zone.content << remaining - return "" - end -end -