Add text parsing & request

This commit is contained in:
Glenn Y. Rolland 2023-04-18 15:16:32 +02:00
parent 13f1a784a1
commit 65d80b090c
10 changed files with 267 additions and 150 deletions

View file

@ -1,3 +1,6 @@
abstract class PromptGenericBuilder
module Builder
abstract class PromptGeneric
abstract def build(prompt : Prompt) abstract def build(prompt : Prompt)
end
end end

View file

@ -2,9 +2,10 @@
require "pretty_print" require "pretty_print"
require "colorize" require "colorize"
class OpenAIChatBuilder module Builder
alias OpenAIMessage = NamedTuple(role: String, content: String) class OpenAIChat
alias OpenAIChat = Array(OpenAIMessage) alias Message = NamedTuple(role: String, content: String)
alias Chat = Array(Message)
getter verbose : Bool getter verbose : Bool
def initialize(@verbose) def initialize(@verbose)
@ -13,8 +14,8 @@ class OpenAIChatBuilder
# skip prelude_zone # skip prelude_zone
# skip future_zone # skip future_zone
def build(prompt : Prompt) : OpenAIChat def build(prompt : Prompt) : Chat
chat = [] of OpenAIMessage chat = [] of Message
token_limit = 2_900 token_limit = 2_900
mandatory_token_count = ( mandatory_token_count = (
@ -28,7 +29,7 @@ class OpenAIChatBuilder
end end
## Build mandatory system messages ## Build mandatory system messages
tmp_chat = [] of OpenAIMessage tmp_chat = [] of Message
tmp_token_count = 0 tmp_token_count = 0
prompt.past_zone.reverse_each do |content| prompt.past_zone.reverse_each do |content|
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
@ -49,5 +50,5 @@ class OpenAIChatBuilder
# pp chat # pp chat
chat chat
end end
end
end end

View file

@ -0,0 +1,59 @@
require "pretty_print"
require "colorize"
module Builder
class OpenAIInsert
alias Message = NamedTuple(role: String, content: String)
alias Chat = Array(Message)
getter verbose : Bool
def initialize(@verbose)
end
# skip prelude_zone
# skip future_zone
def build(prompt : Prompt) : Chat
chat = [] of Message
token_limit = 2_900
mandatory_token_count = (
prompt.system_zone.token_count +
prompt.present_zone.token_count
)
## Build mandatory system messages
prompt.system_zone.each do |content|
chat << { role: "system", content: content }
end
## Build mandatory system messages
tmp_chat = [] of Message
tmp_token_count = 0
prompt.past_zone.reverse_each do |content|
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
pp tmp_chat.reverse if @verbose
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
break if estimated_token_count >= token_limit
tmp_chat << { role: "user", content: content }
tmp_token_count += (content.size / 4)
end
chat.concat(tmp_chat.reverse)
prompt.present_zone.each do |content|
chat << { role: "user", content: content }
end
prompt.future_zone.each do |content|
chat << { role: "user", content: content }
end
# pp chat
chat
end
end
end

View file

@ -0,0 +1,43 @@
require "colorize"
require "./generic"
module Builder
class PromptString < PromptGeneric
getter use_color : Bool
def initialize(@use_color)
Colorize.enabled = @use_color
end
def build(prompt : Prompt)
str = ""
prompt.prelude_zone.each do |content|
str += content
end
prompt.system_zone.each do |content|
str += "@@system".colorize(:yellow).to_s
str += content
end
prompt.past_zone.each do |content|
str += "@@before".colorize(:yellow).to_s
str += content
end
prompt.present_zone.each do |content|
str += "@@current".colorize(:yellow).to_s
str += content.colorize(:light_cyan).to_s
end
prompt.future_zone.each do |content|
str += "@@after".colorize(:yellow).to_s
str += content
end
str
end
end
end

View file

@ -1,41 +0,0 @@
require "colorize"
require "./generic"
class StringBuilder < PromptGenericBuilder
getter use_color : Bool
def initialize(@use_color)
Colorize.enabled = @use_color
end
def build(prompt : Prompt)
str = ""
prompt.prelude_zone.each do |content|
str += content
end
prompt.system_zone.each do |content|
str += "@@system".colorize(:yellow).to_s
str += content
end
prompt.past_zone.each do |content|
str += "@@before".colorize(:yellow).to_s
str += content
end
prompt.present_zone.each do |content|
str += "@@current".colorize(:yellow).to_s
str += content.colorize(:light_cyan).to_s
end
prompt.future_zone.each do |content|
str += "@@after".colorize(:yellow).to_s
str += content
end
str
end
end

View file

@ -3,12 +3,25 @@ require "pretty_print"
require "openai" require "openai"
require "./zone" require "./zone"
require "./parsers/string" require "./parsers/prompt_string"
require "./builders/string" require "./builders/prompt_string"
require "./builders/openai_chat" require "./builders/openai_chat"
require "./builders/openai_insert"
class Storyteller class Storyteller
enum OpenAIMode
Chat
Complete
Insert
def self.from_s(str_mode)
return OpenAIMode.values.find do |openai_mode|
openai_mode.to_s.downcase == str_mode.downcase
end
end
end
def initialize() def initialize()
end end
@ -20,11 +33,21 @@ class Storyteller
past_characters_limit = 1000 past_characters_limit = 1000
use_color = true use_color = true
make_request = true make_request = true
gpt_mode = OpenAIMode::Chat
verbose = false verbose = false
parser = OptionParser.parse do |parser| parser = OptionParser.parse do |parser|
parser.banner = "Usage: storyteller [options]" parser.banner = "Usage: storyteller [options]"
parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
if result_mode.nil?
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
exit 1
end
gpt_mode = result_mode unless result_mode.nil?
end
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file| parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
input_file_path = file input_file_path = file
end end
@ -65,6 +88,7 @@ class Storyteller
# Build GPT-3 request # Build GPT-3 request
prompt = storyteller.complete(prompt, make_request, verbose) prompt = storyteller.complete(prompt, make_request, verbose)
pp prompt if verbose
exit 0 if !make_request exit 0 if !make_request
if !output_file_path.empty? if !output_file_path.empty?
@ -76,11 +100,14 @@ class Storyteller
end end
def complete(prompt : Prompt, make_request : Bool, verbose : Bool) def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
builder = OpenAIChatBuilder.new(verbose: verbose) builder = Builder::OpenAIChat.new(verbose: verbose)
messages = builder.build(prompt) messages = builder.build(prompt)
return prompt if !make_request return prompt if !make_request
channel_ready = Channel(Bool).new
channel_tick = Channel(Bool).new
spawn do
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY")) openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
result = openai.chat( result = openai.chat(
"gpt-3.5-turbo", "gpt-3.5-turbo",
@ -93,6 +120,26 @@ class Storyteller
} }
) )
prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n" prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
channel_ready.send(true)
end
spawn do
loop do
channel_tick.send(true)
sleep 1.seconds
end
end
spawn do
while channel_tick.receive?
print "."
end
end
channel_ready.receive
channel_ready.close
channel_tick.close
prompt prompt
end end
@ -100,7 +147,7 @@ class Storyteller
content = input_file.gets_to_end content = input_file.gets_to_end
# puts "d: building parser" # puts "d: building parser"
parser = StringParser.new parser = Parser::PromptString.new
# puts "d: parsing" # puts "d: parsing"
prompt = parser.parse(content) prompt = parser.parse(content)
# pp prompt # pp prompt
@ -108,7 +155,7 @@ class Storyteller
def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool) def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
# STDERR.puts "d: building builder" # STDERR.puts "d: building builder"
builder = StringBuilder.new(use_color) builder = Builder::PromptString.new(use_color)
# STDERR.puts "d: building" # STDERR.puts "d: building"
text = builder.build(prompt) text = builder.build(prompt)
output_file.write_string(text.to_slice) output_file.write_string(text.to_slice)

View file

@ -1,3 +1,5 @@
abstract class PromptGenericParser module Parser
abstract class PromptGenericParser
abstract def parse : Prompt abstract def parse : Prompt
end
end end

View file

@ -0,0 +1,51 @@
require "../prompt"
module Parser
class PromptString
def parse(current : String) : Prompt
prompt = Prompt.new
remaining = read_string(prompt.zones, current)
prompt.prelude_zone.content << remaining
prompt.system_zone.content.reverse!
prompt.past_zone.content.reverse!
prompt.present_zone.content.reverse!
prompt.future_zone.content.reverse!
return prompt
end
def read_string(zone_list : Array(Zone), current : String)
# puts "== read_string(current=#{current})"
pos = current.index("@@")
## If there is no remaining @@, then return current
if pos.nil?
# puts "-- no remaining @@, returning"
return current
end
## If @@ is not at position 0, then parse its content first and return remains
if pos > 0
return current[0..(pos-1)] + read_string(zone_list, current[pos..])
end
## If @@ is at position 0, try detecting tag
zone = zone_list.find { |zone| current.starts_with?("@@" + zone.tag) }
## When there is not recognizable tag, skip fake tag and parse
if zone.nil?
# puts "-- no recognizable tag, returning as is"
return "@@" + read_string(zone_list, current[2..])
end
## Handle recognized tag, skip tag & parse & add remains
# puts "-- found tag #{zone.tag}"
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
zone.content << remaining
return ""
end
end
end

View file

@ -1,7 +1,9 @@
class PromptRequestParser module Parser
class PromptRequest
getter prompt getter prompt
def initialize(@prompt = prompt) def initialize(@prompt = prompt)
end end
end
end end

View file

@ -1,50 +0,0 @@
require "../prompt"
class StringParser
def parse(current : String) : Prompt
prompt = Prompt.new
remaining = read_string(prompt.zones, current)
prompt.prelude_zone.content << remaining
prompt.system_zone.content.reverse!
prompt.past_zone.content.reverse!
prompt.present_zone.content.reverse!
prompt.future_zone.content.reverse!
return prompt
end
def read_string(zone_list : Array(Zone), current : String)
# puts "== read_string(current=#{current})"
pos = current.index("@@")
## If there is no remaining @@, then return current
if pos.nil?
# puts "-- no remaining @@, returning"
return current
end
## If @@ is not at position 0, then parse its content first and return remains
if pos > 0
return current[0..(pos-1)] + read_string(zone_list, current[pos..])
end
## If @@ is at position 0, try detecting tag
zone = zone_list.find { |zone| current.starts_with?("@@" + zone.tag) }
## When there is not recognizable tag, skip fake tag and parse
if zone.nil?
# puts "-- no recognizable tag, returning as is"
return "@@" + read_string(zone_list, current[2..])
end
## Handle recognized tag, skip tag & parse & add remains
# puts "-- found tag #{zone.tag}"
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
zone.content << remaining
return ""
end
end