import copy
import json
import logging
import math
import os
import re
import warnings
from datetime import timedelta
from typing import List, Optional, Tuple, Union

import av
import numpy as np
import PIL
import torch
import transformers
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from decord import VideoReader, cpu
from moviepy.video.io.VideoFileClip import VideoFileClip
from packaging import version
from PIL import Image
from tqdm import tqdm
from transformers import AutoConfig

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

warnings.filterwarnings("ignore")

eval_logger = logging.getLogger("lmms-eval")

torch.backends.cuda.matmul.allow_tf32 = True

try:
    from llava.constants import (
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IMAGE_TOKEN,
        IGNORE_INDEX,
        IMAGE_TOKEN_INDEX,
    )
    from llava.conversation import SeparatorStyle, conv_templates
    from llava.mm_utils import (
        KeywordsStoppingCriteria,
        get_model_name_from_path,
        process_images,
        tokenizer_image_token,
    )
    from llava.model.builder import load_pretrained_model
except ImportError as e:
    eval_logger.debug(f"LLaVA_NeXT is not installed. Please install llava from `https://github.com/rese1f/MovieChat.git` to use this model.\nError: {e}")

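# Prefer PyTorch's scaled-dot-product attention when available (torch >= 2.1.2);
# otherwise fall back to eager attention.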
if version.parse(torch.__version__) >= version.parse("2.1.2"):
    best_fit_attn_implementation = "sdpa"
else:
    best_fit_attn_implementation = "eager"


@register_model("llava_onevision_moviechat")
class Llava_OneVision_MovieChat(lmms):
    """
    LLaVA-OneVision model wrapped with MovieChat-style long-video memory.
    """

    def __init__(
        self,
        pretrained: str = "lmms-lab/llava-onevision-qwen2-7b-ov",
        truncation: Optional[bool] = True,
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        model_name: str = "llava_qwen",
        attn_implementation: Optional[str] = best_fit_attn_implementation,
        device_map: Optional[str] = "cuda:0",
        conv_template: Optional[str] = "qwen_1_5",
        use_cache: Optional[bool] = True,
        truncate_context: Optional[bool] = False,
        customized_config: Optional[str] = None,
        short_memory_length: Optional[int] = 18,
        long_memory_length: Optional[int] = 64,
        sliding_window_length: Optional[int] = 8,
        merge_frame_length: Optional[int] = 2,
        tmp_folder: Optional[str] = "tmp/",
        mm_spatial_pool_stride: Optional[int] = 2,
        mm_spatial_pool_mode: Optional[str] = "bilinear",
        token_strategy: Optional[str] = "single",
        video_decode_backend: str = "decord",
        **kwargs,
    ) -> None:
        super().__init__()

        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        llava_model_args = {
            "multimodal": True,
        }
        if customized_config is not None:
            llava_model_args["customized_config"] = customized_config
        if attn_implementation is not None:
            llava_model_args["attn_implementation"] = attn_implementation
        if "use_flash_attention_2" in kwargs:
            llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"]
        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)

        self.pretrained = pretrained
        self.token_strategy = token_strategy
        self.mm_spatial_pool_stride = mm_spatial_pool_stride
        self.mm_spatial_pool_mode = mm_spatial_pool_mode
        self.video_decode_backend = video_decode_backend

        self.short_memory_length = short_memory_length
        self.long_memory_length = long_memory_length
        self.merge_frame_length = merge_frame_length
        self.sliding_window_length = sliding_window_length
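        # Number of sliding-window clips needed to fill long-term memory: each consolidation
        # absorbs (short_memory_length - merge_frame_length) encoded frames gathered over
        # ((short_memory_length - merge_frame_length) // sliding_window_length) windows, and
        # long_memory_length // merge_frame_length consolidations are needed in total.
        # With the defaults: (64 // 2) * ((18 - 2) // 8) = 64 clips.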
        self.num_clips = (self.long_memory_length // self.merge_frame_length) * ((self.short_memory_length - self.merge_frame_length) // self.sliding_window_length)
        self.tmp_folder = tmp_folder

        overwrite_config = {}
        overwrite_config["mm_spatial_pool_stride"] = self.mm_spatial_pool_stride
        overwrite_config["mm_spatial_pool_mode"] = self.mm_spatial_pool_mode
        cfg_pretrained = AutoConfig.from_pretrained(self.pretrained)

        if cfg_pretrained.architectures[0] == "LlavaLlamaForCausalLM":
            if "224" in cfg_pretrained.mm_vision_tower:
                least_token_number = self.max_frames_num * (16 // self.mm_spatial_pool_stride) ** 2 + 1000
            else:
                least_token_number = self.max_frames_num * (24 // self.mm_spatial_pool_stride) ** 2 + 1000

            scaling_factor = math.ceil(least_token_number / 4096)
            if scaling_factor >= 2:
                overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
                overwrite_config["max_sequence_length"] = 4096 * scaling_factor
                overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

        llava_model_args["overwrite_config"] = overwrite_config
        from LLaVA_NeXT.llava.model.builder import load_pretrained_model

        try:
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
        except TypeError:
            llava_model_args.pop("multimodal", None)
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)

        self._config = self._model.config
        self.model.eval()
        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context
        assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."

            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")

            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes

        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1

        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length
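    # Pad a list of 1-D input_ids tensors to a common length. torch's pad_sequence only
    # right-pads, so for left-padding tokenizers the sequences are flipped, padded, and
    # flipped back.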
    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """ """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        try:
            return self.tokenizer.decode(tokens)
        except:
            return self.tokenizer.decode([tokens])

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("MovieChat only supports generation.")

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list
    def load_video(self, video_path, max_frames_num):
        if type(video_path) == str:
            vr = VideoReader(video_path, ctx=cpu(0))
        else:
            vr = VideoReader(video_path[0], ctx=cpu(0))
        total_frame_num = len(vr)
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        spare_frames = vr.get_batch(frame_idx).asnumpy()
        return spare_frames
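
    # Decode one frame per requested keyframe timestamp (in seconds): seek near the
    # timestamp and keep the first decoded frame at or after it, falling back to the
    # last decodable frame when none is found.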
    def extract_keyframes(self, video_path, keyframes):
        container = av.open(video_path)
        video_stream = container.streams.video[0]
        fps = video_stream.average_rate
        time_base = video_stream.time_base
        frames = []

        for keyframe in keyframes:
            keyframe_time = float(keyframe)
            frame_number = int(keyframe_time * fps)
            container.seek(int(keyframe_time / time_base))
            found = False
            for packet in container.demux(video=0):
                for frame in packet.decode():
                    if frame.index >= frame_number:
                        frames.append(frame)
                        found = True
                        break
                if found:
                    break

            if not found:
                container.seek(-1, any_frame=False)
                for packet in container.demux(video=0):
                    for frame in packet.decode():
                        pass
                frames.append(frame)

        video = [x.to_ndarray(format="rgb24") for x in frames]
        video_frames = [Image.fromarray(x) for x in video]
        return video_frames

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
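            # Sort by negative token length so the longest contexts come first: any OOM
            # surfaces immediately and running-time estimates stay conservative.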
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        metadata = requests[0].metadata
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
        for chunk in chunks:
            batched_contexts, all_gen_kwargs, batched_doc_to_visual, batched_doc_id, batched_task, batched_split = zip(*chunk)
            task = batched_task[0]
            split = batched_split[0]
            batched_visuals = [batched_doc_to_visual[0](self.task_dict[task][split][ids]) for ids in batched_doc_id]
            assert len(batched_visuals) == 1

            gen_kwargs = all_gen_kwargs[0]
            if "until" in gen_kwargs:
                gen_kwargs.pop("until")

            question_input = []

            for visual, context in zip(batched_visuals, batched_contexts):
                if len(visual) > 1 or "image_aspect_ratio" not in self._config.__dict__:
                    self._config.image_aspect_ratio = gen_kwargs.get("image_aspect_ratio", "pad")
                    eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")
                if type(visual[0]) == PIL.Image.Image and "task_type" not in metadata and "sample_frames" not in metadata:
                    raise NotImplementedError("MovieChat only supports video inputs.")

                elif "task_type" in metadata and metadata["task_type"] == "video" and "sample_frames" in metadata:
                    raise NotImplementedError("MovieChat only supports video inputs.")

                elif type(visual[0]) == str:
                    try:
                        self.short_memory_buffer = []
                        self.long_memory_buffer = []

                        os.makedirs(self.tmp_folder, exist_ok=True)

                        video = VideoFileClip(visual[0])
                        clip_duration = video.duration / self.num_clips

                        cur_frame = 0
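                        # Split the video into num_clips equal windows and sample
                        # sliding_window_length frames from each window; the encoded frames
                        # feed the short-term memory below.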
                        for i in range(self.num_clips):
                            start_time = i * clip_duration
                            end_time = start_time + clip_duration

                            frames = list(video.subclip(start_time, end_time).iter_frames(fps=self.sliding_window_length / clip_duration))[: self.sliding_window_length]
                            frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
                            preprocess_frames = self._image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half().cuda()
                            encoded_window = self.model.encode_images(preprocess_frames)

                            for frame in encoded_window:
                                if cur_frame < (self.short_memory_length - self.merge_frame_length):
                                    if len(self.short_memory_buffer) == self.short_memory_length:
                                        self.short_memory_buffer.pop(0)
                                    self.short_memory_buffer.append(frame)
                                cur_frame += 1

                            if cur_frame == (self.short_memory_length - self.merge_frame_length):
                                cur_frame = 0
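
                                # MovieChat-style consolidation: repeatedly average the most
                                # similar pair of adjacent short-term frames (mean dot-product of
                                # their token features) until only merge_frame_length frames
                                # remain, then push them into long-term memory.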
                                similar_list = []
                                for frame_i in range(len(self.short_memory_buffer) - 1):
                                    scores = self.short_memory_buffer[frame_i] @ self.short_memory_buffer[frame_i + 1].transpose(-1, -2)
                                    frame_similar = torch.mean(scores)
                                    similar_list.append(frame_similar)

                                while len(self.short_memory_buffer) > self.merge_frame_length:
                                    max_value = max(similar_list)
                                    max_index = similar_list.index(max_value)
                                    new_frame_feature = (self.short_memory_buffer[max_index].cpu() + self.short_memory_buffer[max_index + 1].cpu()) / 2
                                    self.short_memory_buffer[max_index] = new_frame_feature.cuda()
                                    del self.short_memory_buffer[max_index + 1]
                                    similar_list = []
                                    for frame_i in range(len(self.short_memory_buffer) - 1):
                                        scores = self.short_memory_buffer[frame_i] @ self.short_memory_buffer[frame_i + 1].transpose(-1, -2)
                                        frame_similar = torch.mean(scores)
                                        similar_list.append(frame_similar)

                                for frame in self.short_memory_buffer:
                                    self.long_memory_buffer.append(frame)

                        image_features = torch.stack(self.long_memory_buffer)
                    except Exception as e:
                        print(e)
                        eval_logger.error(f"Error {e} in loading video")
                        image_features = None

                task_type = "video"
                placeholder_count = len(frames) if self.token_strategy == "multiple" else 1

                if image_features is not None and DEFAULT_IMAGE_TOKEN not in context:
                    """
                    Four scenarios:
                    1. No image, and therefore no image token should be added.
                    2. An image token is already specified in the context, so we don't need to add it.
                    3. An image token is not specified in the context but there are image inputs, so we need to add it. In this case, we add the image token at the beginning of the context and add a new line.
                    4. For video tasks, we could add a single <image> token or one <image> token per frame in the context. This depends on the training strategy and should be balanced at test time to decide which is better.
                    """
                    image_tokens = [DEFAULT_IMAGE_TOKEN] * placeholder_count
                    image_tokens = " ".join(image_tokens)
                    question = image_tokens + "\n" + context
                else:
                    question = context

if "llama_3" in self.conv_template: |
|
|
conv = copy.deepcopy(conv_templates[self.conv_template]) |
|
|
else: |
|
|
conv = conv_templates[self.conv_template].copy() |
|
|
|
|
|
if utils.is_json(question): |
|
|
question = json.loads(question) |
|
|
for idx, item in enumerate(question): |
|
|
role = conv.roles[idx % 2] |
|
|
message = item["value"] |
|
|
conv.append_message(role, message) |
|
|
|
|
|
assert len(conv.messages) % 2 == 1 |
|
|
conv.append_message(conv.roles[1], None) |
|
|
prompt_question = conv.get_prompt() |
|
|
question_input.append(prompt_question) |
|
|
else: |
|
|
conv.append_message(conv.roles[0], question) |
|
|
conv.append_message(conv.roles[1], None) |
|
|
prompt_question = conv.get_prompt() |
|
|
question_input.append(prompt_question) |
|
|
|
|
|
|
|
|
if "max_new_tokens" not in gen_kwargs: |
|
|
gen_kwargs["max_new_tokens"] = 1024 |
|
|
if "temperature" not in gen_kwargs: |
|
|
gen_kwargs["temperature"] = 0 |
|
|
if "do_sample" not in gen_kwargs: |
|
|
gen_kwargs["do_sample"] = False |
|
|
if "top_p" not in gen_kwargs: |
|
|
gen_kwargs["top_p"] = None |
|
|
if "num_beams" not in gen_kwargs: |
|
|
gen_kwargs["num_beams"] = 1 |
|
|
|
|
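
            # tokenizer_image_token splits each prompt on the <image> placeholder and inserts
            # IMAGE_TOKEN_INDEX between the text segments so visual features can be spliced in.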
            input_ids_list = [tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in question_input]
            pad_token_ids = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_ids).to(self.device)
            attention_masks = input_ids.ne(pad_token_ids).to(self.device)

            if task_type == "image":
                gen_kwargs["image_sizes"] = [batched_visuals[0][idx].size for idx in range(len(batched_visuals[0]))]
            elif task_type == "video":
                stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
                keywords = [stop_str]
                stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
                gen_kwargs["modalities"] = ["video"]
                gen_kwargs["stopping_criteria"] = [stopping_criteria]
                self._config.mm_spatial_pool_stride = self.mm_spatial_pool_stride
                self._config.mm_spatial_pool_mode = self.mm_spatial_pool_mode

            if "image_aspect_ratio" in gen_kwargs.keys():
                gen_kwargs.pop("image_aspect_ratio")
            try:
                with torch.inference_mode():
                    gen_kwargs.pop("modalities")
                    cont = self.model.generate_moviechat(input_ids, attention_mask=attention_masks, pad_token_id=pad_token_ids, image_features=image_features, use_cache=self.use_cache, **gen_kwargs)
                text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)
                text_outputs = [response.strip() for response in text_outputs]
            except Exception as e:
                print(e)
                text_outputs = ["Can not infer the answer."]

            res.extend(text_outputs)
            print(text_outputs)
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
            pbar.update(1)

        res = re_ords.get_original(res)

        pbar.close()
        return res