Add text parsing & request
This commit is contained in:
parent
13f1a784a1
commit
65d80b090c
10 changed files with 267 additions and 150 deletions
|
@ -1,3 +1,6 @@
|
||||||
abstract class PromptGenericBuilder
|
|
||||||
abstract def build(prompt : Prompt)
|
module Builder
|
||||||
|
abstract class PromptGeneric
|
||||||
|
abstract def build(prompt : Prompt)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -2,52 +2,53 @@
|
||||||
require "pretty_print"
|
require "pretty_print"
|
||||||
require "colorize"
|
require "colorize"
|
||||||
|
|
||||||
class OpenAIChatBuilder
|
module Builder
|
||||||
alias OpenAIMessage = NamedTuple(role: String, content: String)
|
class OpenAIChat
|
||||||
alias OpenAIChat = Array(OpenAIMessage)
|
alias Message = NamedTuple(role: String, content: String)
|
||||||
|
alias Chat = Array(Message)
|
||||||
|
|
||||||
getter verbose : Bool
|
getter verbose : Bool
|
||||||
def initialize(@verbose)
|
def initialize(@verbose)
|
||||||
|
|
||||||
end
|
|
||||||
|
|
||||||
# skip prelude_zone
|
|
||||||
# skip future_zone
|
|
||||||
def build(prompt : Prompt) : OpenAIChat
|
|
||||||
chat = [] of OpenAIMessage
|
|
||||||
|
|
||||||
token_limit = 2_900
|
|
||||||
mandatory_token_count = (
|
|
||||||
prompt.system_zone.token_count +
|
|
||||||
prompt.present_zone.token_count
|
|
||||||
)
|
|
||||||
|
|
||||||
## Build mandatory system messages
|
|
||||||
prompt.system_zone.each do |content|
|
|
||||||
chat << { role: "system", content: content }
|
|
||||||
end
|
end
|
||||||
|
|
||||||
## Build mandatory system messages
|
# skip prelude_zone
|
||||||
tmp_chat = [] of OpenAIMessage
|
# skip future_zone
|
||||||
tmp_token_count = 0
|
def build(prompt : Prompt) : Chat
|
||||||
prompt.past_zone.reverse_each do |content|
|
chat = [] of Message
|
||||||
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
|
|
||||||
pp tmp_chat.reverse if @verbose
|
|
||||||
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
|
|
||||||
|
|
||||||
break if estimated_token_count >= token_limit
|
token_limit = 2_900
|
||||||
|
mandatory_token_count = (
|
||||||
|
prompt.system_zone.token_count +
|
||||||
|
prompt.present_zone.token_count
|
||||||
|
)
|
||||||
|
|
||||||
tmp_chat << { role: "user", content: content }
|
## Build mandatory system messages
|
||||||
tmp_token_count += (content.size / 4)
|
prompt.system_zone.each do |content|
|
||||||
|
chat << { role: "system", content: content }
|
||||||
|
end
|
||||||
|
|
||||||
|
## Build mandatory system messages
|
||||||
|
tmp_chat = [] of Message
|
||||||
|
tmp_token_count = 0
|
||||||
|
prompt.past_zone.reverse_each do |content|
|
||||||
|
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
|
||||||
|
pp tmp_chat.reverse if @verbose
|
||||||
|
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
|
||||||
|
|
||||||
|
break if estimated_token_count >= token_limit
|
||||||
|
|
||||||
|
tmp_chat << { role: "user", content: content }
|
||||||
|
tmp_token_count += (content.size / 4)
|
||||||
|
end
|
||||||
|
chat.concat(tmp_chat.reverse)
|
||||||
|
|
||||||
|
prompt.present_zone.each do |content|
|
||||||
|
chat << { role: "user", content: content }
|
||||||
|
end
|
||||||
|
|
||||||
|
# pp chat
|
||||||
|
chat
|
||||||
end
|
end
|
||||||
chat.concat(tmp_chat.reverse)
|
|
||||||
|
|
||||||
prompt.present_zone.each do |content|
|
|
||||||
chat << { role: "user", content: content }
|
|
||||||
end
|
|
||||||
|
|
||||||
# pp chat
|
|
||||||
chat
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
59
src/builders/openai_insert.cr
Normal file
59
src/builders/openai_insert.cr
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
|
||||||
|
require "pretty_print"
|
||||||
|
require "colorize"
|
||||||
|
|
||||||
|
module Builder
|
||||||
|
class OpenAIInsert
|
||||||
|
alias Message = NamedTuple(role: String, content: String)
|
||||||
|
alias Chat = Array(Message)
|
||||||
|
|
||||||
|
getter verbose : Bool
|
||||||
|
def initialize(@verbose)
|
||||||
|
|
||||||
|
end
|
||||||
|
|
||||||
|
# skip prelude_zone
|
||||||
|
# skip future_zone
|
||||||
|
def build(prompt : Prompt) : Chat
|
||||||
|
chat = [] of Message
|
||||||
|
|
||||||
|
token_limit = 2_900
|
||||||
|
mandatory_token_count = (
|
||||||
|
prompt.system_zone.token_count +
|
||||||
|
prompt.present_zone.token_count
|
||||||
|
)
|
||||||
|
|
||||||
|
## Build mandatory system messages
|
||||||
|
prompt.system_zone.each do |content|
|
||||||
|
chat << { role: "system", content: content }
|
||||||
|
end
|
||||||
|
|
||||||
|
## Build mandatory system messages
|
||||||
|
tmp_chat = [] of Message
|
||||||
|
tmp_token_count = 0
|
||||||
|
prompt.past_zone.reverse_each do |content|
|
||||||
|
estimated_token_count = (content.size/4) + tmp_token_count + mandatory_token_count
|
||||||
|
pp tmp_chat.reverse if @verbose
|
||||||
|
puts "ESTIMATE: #{estimated_token_count} (limit=#{token_limit})".colorize(:yellow).to_s if @verbose
|
||||||
|
|
||||||
|
break if estimated_token_count >= token_limit
|
||||||
|
|
||||||
|
tmp_chat << { role: "user", content: content }
|
||||||
|
tmp_token_count += (content.size / 4)
|
||||||
|
end
|
||||||
|
chat.concat(tmp_chat.reverse)
|
||||||
|
|
||||||
|
prompt.present_zone.each do |content|
|
||||||
|
chat << { role: "user", content: content }
|
||||||
|
end
|
||||||
|
|
||||||
|
prompt.future_zone.each do |content|
|
||||||
|
chat << { role: "user", content: content }
|
||||||
|
end
|
||||||
|
|
||||||
|
# pp chat
|
||||||
|
chat
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
43
src/builders/prompt_string.cr
Normal file
43
src/builders/prompt_string.cr
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
|
||||||
|
require "colorize"
|
||||||
|
|
||||||
|
require "./generic"
|
||||||
|
|
||||||
|
module Builder
|
||||||
|
class PromptString < PromptGeneric
|
||||||
|
getter use_color : Bool
|
||||||
|
def initialize(@use_color)
|
||||||
|
Colorize.enabled = @use_color
|
||||||
|
end
|
||||||
|
|
||||||
|
def build(prompt : Prompt)
|
||||||
|
str = ""
|
||||||
|
|
||||||
|
prompt.prelude_zone.each do |content|
|
||||||
|
str += content
|
||||||
|
end
|
||||||
|
|
||||||
|
prompt.system_zone.each do |content|
|
||||||
|
str += "@@system".colorize(:yellow).to_s
|
||||||
|
str += content
|
||||||
|
end
|
||||||
|
|
||||||
|
prompt.past_zone.each do |content|
|
||||||
|
str += "@@before".colorize(:yellow).to_s
|
||||||
|
str += content
|
||||||
|
end
|
||||||
|
|
||||||
|
prompt.present_zone.each do |content|
|
||||||
|
str += "@@current".colorize(:yellow).to_s
|
||||||
|
str += content.colorize(:light_cyan).to_s
|
||||||
|
end
|
||||||
|
|
||||||
|
prompt.future_zone.each do |content|
|
||||||
|
str += "@@after".colorize(:yellow).to_s
|
||||||
|
str += content
|
||||||
|
end
|
||||||
|
|
||||||
|
str
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,41 +0,0 @@
|
||||||
|
|
||||||
require "colorize"
|
|
||||||
|
|
||||||
require "./generic"
|
|
||||||
|
|
||||||
class StringBuilder < PromptGenericBuilder
|
|
||||||
getter use_color : Bool
|
|
||||||
def initialize(@use_color)
|
|
||||||
Colorize.enabled = @use_color
|
|
||||||
end
|
|
||||||
|
|
||||||
def build(prompt : Prompt)
|
|
||||||
str = ""
|
|
||||||
|
|
||||||
prompt.prelude_zone.each do |content|
|
|
||||||
str += content
|
|
||||||
end
|
|
||||||
|
|
||||||
prompt.system_zone.each do |content|
|
|
||||||
str += "@@system".colorize(:yellow).to_s
|
|
||||||
str += content
|
|
||||||
end
|
|
||||||
|
|
||||||
prompt.past_zone.each do |content|
|
|
||||||
str += "@@before".colorize(:yellow).to_s
|
|
||||||
str += content
|
|
||||||
end
|
|
||||||
|
|
||||||
prompt.present_zone.each do |content|
|
|
||||||
str += "@@current".colorize(:yellow).to_s
|
|
||||||
str += content.colorize(:light_cyan).to_s
|
|
||||||
end
|
|
||||||
|
|
||||||
prompt.future_zone.each do |content|
|
|
||||||
str += "@@after".colorize(:yellow).to_s
|
|
||||||
str += content
|
|
||||||
end
|
|
||||||
|
|
||||||
str
|
|
||||||
end
|
|
||||||
end
|
|
71
src/main.cr
71
src/main.cr
|
@ -3,12 +3,25 @@ require "pretty_print"
|
||||||
require "openai"
|
require "openai"
|
||||||
|
|
||||||
require "./zone"
|
require "./zone"
|
||||||
require "./parsers/string"
|
require "./parsers/prompt_string"
|
||||||
require "./builders/string"
|
require "./builders/prompt_string"
|
||||||
require "./builders/openai_chat"
|
require "./builders/openai_chat"
|
||||||
|
require "./builders/openai_insert"
|
||||||
|
|
||||||
class Storyteller
|
class Storyteller
|
||||||
|
|
||||||
|
enum OpenAIMode
|
||||||
|
Chat
|
||||||
|
Complete
|
||||||
|
Insert
|
||||||
|
|
||||||
|
def self.from_s(str_mode)
|
||||||
|
return OpenAIMode.values.find do |openai_mode|
|
||||||
|
openai_mode.to_s.downcase == str_mode.downcase
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def initialize()
|
def initialize()
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -20,11 +33,21 @@ class Storyteller
|
||||||
past_characters_limit = 1000
|
past_characters_limit = 1000
|
||||||
use_color = true
|
use_color = true
|
||||||
make_request = true
|
make_request = true
|
||||||
|
gpt_mode = OpenAIMode::Chat
|
||||||
verbose = false
|
verbose = false
|
||||||
|
|
||||||
parser = OptionParser.parse do |parser|
|
parser = OptionParser.parse do |parser|
|
||||||
parser.banner = "Usage: storyteller [options]"
|
parser.banner = "Usage: storyteller [options]"
|
||||||
|
|
||||||
|
parser.on("-m MODE", "--mode=MODE", "GPT mode (chat,insert,complete) (default: chat)") do |chosen_mode|
|
||||||
|
result_mode = OpenAIMode.from_s(chosen_mode.downcase)
|
||||||
|
if result_mode.nil?
|
||||||
|
STDERR.puts "ERROR: unknown mode #{chosen_mode}"
|
||||||
|
exit 1
|
||||||
|
end
|
||||||
|
gpt_mode = result_mode unless result_mode.nil?
|
||||||
|
end
|
||||||
|
|
||||||
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
|
parser.on("-i FILE", "--input=FILE", "Path to input file") do |file|
|
||||||
input_file_path = file
|
input_file_path = file
|
||||||
end
|
end
|
||||||
|
@ -65,6 +88,7 @@ class Storyteller
|
||||||
|
|
||||||
# Build GPT-3 request
|
# Build GPT-3 request
|
||||||
prompt = storyteller.complete(prompt, make_request, verbose)
|
prompt = storyteller.complete(prompt, make_request, verbose)
|
||||||
|
pp prompt if verbose
|
||||||
exit 0 if !make_request
|
exit 0 if !make_request
|
||||||
|
|
||||||
if !output_file_path.empty?
|
if !output_file_path.empty?
|
||||||
|
@ -76,23 +100,46 @@ class Storyteller
|
||||||
end
|
end
|
||||||
|
|
||||||
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
def complete(prompt : Prompt, make_request : Bool, verbose : Bool)
|
||||||
builder = OpenAIChatBuilder.new(verbose: verbose)
|
builder = Builder::OpenAIChat.new(verbose: verbose)
|
||||||
messages = builder.build(prompt)
|
messages = builder.build(prompt)
|
||||||
|
|
||||||
return prompt if !make_request
|
return prompt if !make_request
|
||||||
|
|
||||||
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
channel_ready = Channel(Bool).new
|
||||||
result = openai.chat(
|
channel_tick = Channel(Bool).new
|
||||||
"gpt-3.5-turbo",
|
spawn do
|
||||||
messages,
|
openai = OpenAI::Client.new(access_token: ENV.fetch("OPENAI_API_KEY"))
|
||||||
{
|
result = openai.chat(
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
messages,
|
||||||
|
{
|
||||||
"temperature" => 0.82,
|
"temperature" => 0.82,
|
||||||
"presence_penalty" => 1,
|
"presence_penalty" => 1,
|
||||||
"frequency_penalty" => 1,
|
"frequency_penalty" => 1,
|
||||||
"max_tokens" => 256
|
"max_tokens" => 256
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
|
prompt.present_zone.content << "\n" + result.choices.first["message"]["content"] + "\n"
|
||||||
|
channel_ready.send(true)
|
||||||
|
end
|
||||||
|
|
||||||
|
spawn do
|
||||||
|
loop do
|
||||||
|
channel_tick.send(true)
|
||||||
|
sleep 1.seconds
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
spawn do
|
||||||
|
while channel_tick.receive?
|
||||||
|
print "."
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
channel_ready.receive
|
||||||
|
channel_ready.close
|
||||||
|
channel_tick.close
|
||||||
|
|
||||||
prompt
|
prompt
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -100,7 +147,7 @@ class Storyteller
|
||||||
content = input_file.gets_to_end
|
content = input_file.gets_to_end
|
||||||
|
|
||||||
# puts "d: building parser"
|
# puts "d: building parser"
|
||||||
parser = StringParser.new
|
parser = Parser::PromptString.new
|
||||||
# puts "d: parsing"
|
# puts "d: parsing"
|
||||||
prompt = parser.parse(content)
|
prompt = parser.parse(content)
|
||||||
# pp prompt
|
# pp prompt
|
||||||
|
@ -108,7 +155,7 @@ class Storyteller
|
||||||
|
|
||||||
def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
|
def write_file(output_file : IO::FileDescriptor, prompt : Prompt, use_color : Bool)
|
||||||
# STDERR.puts "d: building builder"
|
# STDERR.puts "d: building builder"
|
||||||
builder = StringBuilder.new(use_color)
|
builder = Builder::PromptString.new(use_color)
|
||||||
# STDERR.puts "d: building"
|
# STDERR.puts "d: building"
|
||||||
text = builder.build(prompt)
|
text = builder.build(prompt)
|
||||||
output_file.write_string(text.to_slice)
|
output_file.write_string(text.to_slice)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
abstract class PromptGenericParser
|
module Parser
|
||||||
abstract def parse : Prompt
|
abstract class PromptGenericParser
|
||||||
|
abstract def parse : Prompt
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
51
src/parsers/prompt_string.cr
Normal file
51
src/parsers/prompt_string.cr
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
|
||||||
|
require "../prompt"
|
||||||
|
|
||||||
|
module Parser
|
||||||
|
class PromptString
|
||||||
|
def parse(current : String) : Prompt
|
||||||
|
prompt = Prompt.new
|
||||||
|
remaining = read_string(prompt.zones, current)
|
||||||
|
prompt.prelude_zone.content << remaining
|
||||||
|
|
||||||
|
prompt.system_zone.content.reverse!
|
||||||
|
prompt.past_zone.content.reverse!
|
||||||
|
prompt.present_zone.content.reverse!
|
||||||
|
prompt.future_zone.content.reverse!
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
end
|
||||||
|
|
||||||
|
def read_string(zone_list : Array(Zone), current : String)
|
||||||
|
# puts "== read_string(current=#{current})"
|
||||||
|
|
||||||
|
pos = current.index("@@")
|
||||||
|
|
||||||
|
## If there is no remaining @@, then return current
|
||||||
|
if pos.nil?
|
||||||
|
# puts "-- no remaining @@, returning"
|
||||||
|
return current
|
||||||
|
end
|
||||||
|
|
||||||
|
## If @@ is not at position 0, then parse its content first and return remains
|
||||||
|
if pos > 0
|
||||||
|
return current[0..(pos-1)] + read_string(zone_list, current[pos..])
|
||||||
|
end
|
||||||
|
|
||||||
|
## If @@ is at position 0, try detecting tag
|
||||||
|
zone = zone_list.find { |zone| current.starts_with?("@@" + zone.tag) }
|
||||||
|
|
||||||
|
## When there is not recognizable tag, skip fake tag and parse
|
||||||
|
if zone.nil?
|
||||||
|
# puts "-- no recognizable tag, returning as is"
|
||||||
|
return "@@" + read_string(zone_list, current[2..])
|
||||||
|
end
|
||||||
|
|
||||||
|
## Handle recognized tag, skip tag & parse & add remains
|
||||||
|
# puts "-- found tag #{zone.tag}"
|
||||||
|
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
|
||||||
|
zone.content << remaining
|
||||||
|
return ""
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -1,7 +1,9 @@
|
||||||
|
|
||||||
class PromptRequestParser
|
module Parser
|
||||||
getter prompt
|
class PromptRequest
|
||||||
|
getter prompt
|
||||||
|
|
||||||
def initialize(@prompt = prompt)
|
def initialize(@prompt = prompt)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
|
|
||||||
require "../prompt"
|
|
||||||
|
|
||||||
class StringParser
|
|
||||||
def parse(current : String) : Prompt
|
|
||||||
prompt = Prompt.new
|
|
||||||
remaining = read_string(prompt.zones, current)
|
|
||||||
prompt.prelude_zone.content << remaining
|
|
||||||
|
|
||||||
prompt.system_zone.content.reverse!
|
|
||||||
prompt.past_zone.content.reverse!
|
|
||||||
prompt.present_zone.content.reverse!
|
|
||||||
prompt.future_zone.content.reverse!
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
end
|
|
||||||
|
|
||||||
def read_string(zone_list : Array(Zone), current : String)
|
|
||||||
# puts "== read_string(current=#{current})"
|
|
||||||
|
|
||||||
pos = current.index("@@")
|
|
||||||
|
|
||||||
## If there is no remaining @@, then return current
|
|
||||||
if pos.nil?
|
|
||||||
# puts "-- no remaining @@, returning"
|
|
||||||
return current
|
|
||||||
end
|
|
||||||
|
|
||||||
## If @@ is not at position 0, then parse its content first and return remains
|
|
||||||
if pos > 0
|
|
||||||
return current[0..(pos-1)] + read_string(zone_list, current[pos..])
|
|
||||||
end
|
|
||||||
|
|
||||||
## If @@ is at position 0, try detecting tag
|
|
||||||
zone = zone_list.find { |zone| current.starts_with?("@@" + zone.tag) }
|
|
||||||
|
|
||||||
## When there is not recognizable tag, skip fake tag and parse
|
|
||||||
if zone.nil?
|
|
||||||
# puts "-- no recognizable tag, returning as is"
|
|
||||||
return "@@" + read_string(zone_list, current[2..])
|
|
||||||
end
|
|
||||||
|
|
||||||
## Handle recognized tag, skip tag & parse & add remains
|
|
||||||
# puts "-- found tag #{zone.tag}"
|
|
||||||
remaining = read_string(zone_list, current[(2+zone.tag.size)..])
|
|
||||||
zone.content << remaining
|
|
||||||
return ""
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
Loading…
Reference in a new issue